├── .flake8 ├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ └── pypi.yml ├── .gitignore ├── .isort.cfg ├── LICENSE ├── README.md ├── apgorm ├── __init__.py ├── __main__.py ├── connection.py ├── constraints │ ├── __init__.py │ ├── check.py │ ├── constraint.py │ ├── exclude.py │ ├── foreign_key.py │ ├── primary_key.py │ └── unique.py ├── converter.py ├── database.py ├── exceptions.py ├── field.py ├── indexes.py ├── manytomany.py ├── migrations │ ├── __init__.py │ ├── applied_migration.py │ ├── apply_migration.py │ ├── create_migration.py │ ├── describe.py │ └── migration.py ├── model.py ├── py.typed ├── sql │ ├── __init__.py │ ├── generators │ │ ├── __init__.py │ │ ├── alter.py │ │ └── query.py │ ├── query_builder.py │ └── sql.py ├── types │ ├── __init__.py │ ├── array.py │ ├── base_type.py │ ├── binary.py │ ├── bitstr.py │ ├── boolean.py │ ├── character.py │ ├── date.py │ ├── geometric.py │ ├── json_type.py │ ├── monetary.py │ ├── network.py │ ├── numeric.py │ ├── uuid_type.py │ └── xml_type.py ├── undefined.py └── utils │ ├── __init__.py │ └── lazy_list.py ├── examples ├── __init__.py ├── basic │ ├── __init__.py │ ├── __main__.py │ ├── main.py │ └── migrations │ │ └── 0000 │ │ ├── describe.json │ │ └── migrations.sql ├── converters │ ├── __init__.py │ ├── __main__.py │ ├── main.py │ └── migrations │ │ └── 0000 │ │ ├── describe.json │ │ └── migrations.sql ├── manytomany │ ├── __init__.py │ ├── __main__.py │ ├── main.py │ └── migrations │ │ └── 0000 │ │ ├── describe.json │ │ └── migrations.sql └── validators │ ├── __init__.py │ ├── __main__.py │ └── main.py ├── mypy.ini ├── noxfile.py ├── pyproject.toml └── tests ├── conftest.py ├── grouptests └── test_migrations.py └── unittests ├── sql ├── generators │ └── test_query.py └── test_query_builder.py ├── test_connection.py ├── test_converter.py ├── test_database.py ├── test_field.py ├── test_indexes.py ├── test_model.py ├── types ├── test_array.py ├── test_bitstr.py ├── test_character.py ├── test_date.py └── test_numeric.py └── utils └── test_lazy_list.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | extend_ignore = E203 3 | exclude=.nox 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | ci: 8 | runs-on: ${{ matrix.os }}-latest 9 | strategy: 10 | matrix: 11 | python-version: [ '3.8', '3.9', '3.10' ] 12 | os: [ ubuntu, windows ] 13 | 14 | name: ${{ matrix.python-version }} ${{ matrix.os }} 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install Nox 22 | run: pip install nox 23 | - name: Run Nox 24 | run: nox 25 | - name: Upload to CodeCov 26 | uses: codecov/codecov-action@v3 27 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v3 14 | with: 15 | python-version: '3.8' 16 | - name: Install Dependencies 17 | run: pip install poetry 18 | - name: Build & Upload 19 | env: 20 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 21 | run: | 22 | poetry config pypi-token.pypi $PYPI_TOKEN 23 | poetry version $(git describe --tags --abbrev=0) 24 | poetry build 25 | poetry publish 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .DS_Store 3 | poetry.lock 4 | tests/migrations/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | line_length = 79 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-present CircuitSacul 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # apgorm 2 | [![pypi](https://github.com/TrigonDev/apgorm/actions/workflows/pypi.yml/badge.svg)](https://pypi.org/project/apgorm) 3 | [![codecov](https://codecov.io/gh/TrigonDev/apgorm/branch/main/graph/badge.svg?token=LEY276K4NS)](https://codecov.io/gh/TrigonDev/apgorm) 4 | 5 | [Documentation](https://github.com/circuitsacul/apgorm/wiki) | [Support](https://discord.gg/dGAzZDaTS9) 6 | 7 | An asynchronous ORM wrapped around asyncpg. Examples can be found under `examples/`. 
Run examples with `python -m examples.` (`python -m examples.basic`). 8 | 9 | Please note that this library is not for those learning SQL or Postgres. Although the basic usage of apgorm is straightforward, you will run into problems, especially with migrations, if you don't understand regular SQL well. 10 | 11 | ## Features 12 | - Fairly straightforward and easy to use. 13 | - Support for basic migrations. 14 | - Protects against SQL injection. 15 | - Python-side converters and validators. 16 | - Decent many-to-many support. 17 | - Fully type-checked. 18 | - Tested. 19 | 20 | ## Limitations 21 | - Limited column namespace. For example, you cannot have a column named `tablename` since that is used to store the name of the model. 22 | - Only supports PostgreSQL with asyncpg. 23 | - Migrations don't natively support field/table renaming or type changes, but you can still write your own migration with raw SQL. 24 | - Converters only work on instances of models and when initializing the model. 25 | - Models can only detect assignments. If `User.nicknames` is a list of nicknames, it won't detect `user.nicknames.append("new nick")`. You need to do `user.nicknames = user.nicknames + ["new nick"]`. 26 | 27 | ## Basic Usage 28 | Defining a model and database: 29 | ```py 30 | class User(apgorm.Model): 31 | username = apgorm.types.VarChar(32).field() 32 | email = apgorm.types.VarChar().nullablefield() 33 | 34 | primary_key = (username,) 35 | 36 | class Database(apgorm.Database): 37 | users = User 38 | ``` 39 | 40 | Initializing the database: 41 | ```py 42 | db = Database(migrations_folder="path/to/migrations") 43 | await db.connect(database="database name") 44 | ``` 45 | 46 | Creating & Applying migrations: 47 | ```py 48 | if db.must_create_migrations(): 49 | db.create_migrations() 50 | if await db.must_apply_migrations(): 51 | await db.apply_migrations() 52 | ``` 53 | 54 | Basic create, fetch, update, and delete: 55 | ```py 56 | user = await User(username="Circuit").create() 57 | print("Created user", user) 58 | 59 | assert user == await User.fetch(username="Circuit") 60 | 61 | user.email = "email@example.com" 62 | await user.save() 63 | 64 | await user.delete() 65 | ``` 66 | -------------------------------------------------------------------------------- /apgorm/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | apgorm is a fully type-checked asynchronous ORM wrapped around asyncpg. 3 | 4 | apgorm is licensed under the MIT license. 5 | 6 | https://github.com/TrigonDev/apgorm 7 | """ 8 | 9 | from importlib import metadata 10 | 11 | from .
import exceptions 12 | from .connection import Connection, Pool, PoolAcquireContext 13 | from .constraints.check import Check 14 | from .constraints.constraint import Constraint 15 | from .constraints.exclude import Exclude 16 | from .constraints.foreign_key import ForeignKey, ForeignKeyAction 17 | from .constraints.primary_key import PrimaryKey 18 | from .constraints.unique import Unique 19 | from .converter import Converter, IntEFConverter 20 | from .database import Database 21 | from .field import BaseField, ConverterField, Field 22 | from .indexes import Index, IndexType 23 | from .manytomany import ManyToMany 24 | from .migrations.describe import ( 25 | Describe, 26 | DescribeConstraint, 27 | DescribeField, 28 | DescribeIndex, 29 | DescribeTable, 30 | ) 31 | from .migrations.migration import Migration 32 | from .model import Model 33 | from .sql.query_builder import ( 34 | BaseQueryBuilder, 35 | DeleteQueryBuilder, 36 | FetchQueryBuilder, 37 | FilterQueryBuilder, 38 | InsertQueryBuilder, 39 | UpdateQueryBuilder, 40 | ) 41 | from .sql.sql import ( 42 | CASTED, 43 | SQL, 44 | Block, 45 | Comparable, 46 | Parameter, 47 | Raw, 48 | Renderer, 49 | and_, 50 | join, 51 | or_, 52 | raw, 53 | sql, 54 | wrap, 55 | ) 56 | from .undefined import UNDEF 57 | from .utils.lazy_list import LazyList 58 | 59 | __version__ = metadata.version(__name__) 60 | 61 | __all__ = ( 62 | "Converter", 63 | "Model", 64 | "Database", 65 | "ManyToMany", 66 | "Constraint", 67 | "Check", 68 | "ForeignKey", 69 | "ForeignKeyAction", 70 | "PrimaryKey", 71 | "Exclude", 72 | "Unique", 73 | "Index", 74 | "IndexType", 75 | "Migration", 76 | "Describe", 77 | "DescribeConstraint", 78 | "DescribeTable", 79 | "DescribeField", 80 | "DescribeIndex", 81 | "CASTED", 82 | "SQL", 83 | "Block", 84 | "Comparable", 85 | "Parameter", 86 | "Raw", 87 | "Renderer", 88 | "BaseQueryBuilder", 89 | "FetchQueryBuilder", 90 | "DeleteQueryBuilder", 91 | "FilterQueryBuilder", 92 | "InsertQueryBuilder", 93 | "UpdateQueryBuilder", 94 | "LazyList", 95 | "Connection", 96 | "Pool", 97 | "Field", 98 | "BaseField", 99 | "ConverterField", 100 | "PoolAcquireContext", 101 | "IntEFConverter", 102 | "UNDEF", 103 | "and_", 104 | "join", 105 | "or_", 106 | "raw", 107 | "sql", 108 | "wrap", 109 | "types", 110 | "migrations", 111 | "exceptions", 112 | "undefined", 113 | ) 114 | -------------------------------------------------------------------------------- /apgorm/__main__.py: -------------------------------------------------------------------------------- 1 | import apgorm 2 | 3 | print( 4 | f"apgorm {apgorm.__version__} - Fully type-checked asynchronous ORM " 5 | "wrapped around asyncpg." 
6 | ) 7 | print("https://github.com/TrigonDev/apgorm") 8 | -------------------------------------------------------------------------------- /apgorm/connection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Coroutine, Iterable 4 | 5 | import asyncpg 6 | from asyncpg.cursor import CursorFactory 7 | from asyncpg.transaction import Transaction 8 | 9 | from .utils.lazy_list import LazyList 10 | 11 | 12 | class PoolAcquireContext: 13 | __slots__: Iterable[str] = ("pac",) 14 | 15 | def __init__(self, pac: asyncpg.pool.PoolAcquireContext) -> None: 16 | self.pac = pac 17 | 18 | async def __aenter__(self) -> Connection: 19 | return Connection(await self.pac.__aenter__()) 20 | 21 | async def __aexit__(self, *exc: Exception) -> None: 22 | await self.pac.__aexit__(*exc) 23 | 24 | 25 | class Pool: 26 | __slots__: Iterable[str] = ("pool",) 27 | 28 | def __init__(self, pool: asyncpg.Pool) -> None: 29 | self.pool = pool 30 | 31 | def acquire(self) -> PoolAcquireContext: 32 | return PoolAcquireContext(self.pool.acquire()) 33 | 34 | def close(self) -> Coroutine[Any, Any, None]: 35 | return self.pool.close() # type: ignore 36 | 37 | 38 | class Connection: 39 | """Wrapper around asyncpg.Connection.""" 40 | 41 | __slots__: Iterable[str] = ("con",) 42 | 43 | def __init__(self, con: asyncpg.Connection) -> None: 44 | self.con = con 45 | 46 | def transaction(self) -> Transaction: 47 | """Enter a transaction. 48 | 49 | Usage: 50 | ``` 51 | async with con.transaction(): 52 | await con.execute("DROP TABLE users", []) 53 | raise Exception("Wait no I don't want to do that!") 54 | ``` 55 | 56 | Alternatively: 57 | ``` 58 | t = con.transaction() 59 | await t.start() 60 | await con.execute("DROP TABLE users", []) 61 | await t.rollback() 62 | 63 | # or, if you did want to do that, replace t.rollback() with t.commit(): 64 | await t.commit() 65 | ``` 66 | 67 | Returns: 68 | Transaction: The asyncpg.Transaction object. 69 | """ 70 | 71 | return self.con.transaction() 72 | 73 | async def execute( 74 | self, query: str, params: list[Any] | None = None 75 | ) -> None: 76 | """Execute SQL. 77 | 78 | Consider using Database.execute() unless you want to manage the 79 | transaction flow. 80 | 81 | Args: 82 | query (str): The raw SQL. 83 | params (list[Any], optional): List of parameters. Defaults to None. 84 | """ 85 | 86 | params = params or [] 87 | await self.con.execute(query, *params) 88 | 89 | async def fetchrow( 90 | self, query: str, params: list[Any] | None = None 91 | ) -> dict[str, Any] | None: 92 | """Execute SQL and return a single row, if any. 93 | 94 | Consider using Database.fetchrow() unless you want to manage the 95 | transaction flow. 96 | 97 | Args: 98 | query (str): The raw SQL. 99 | params (list[Any], optional): List of parameters. Defaults to None. 100 | 101 | Returns: 102 | dict | None: The row or None. 103 | """ 104 | 105 | params = params or [] 106 | res = await self.con.fetchrow(query, *params) 107 | if res is not None: 108 | res = dict(res) 109 | assert res is None or isinstance(res, dict) 110 | return res 111 | 112 | async def fetchmany( 113 | self, query: str, params: list[Any] | None = None 114 | ) -> LazyList[asyncpg.Record, dict[str, Any]]: 115 | """Execute SQL, returning all found rows. 116 | 117 | See LazyList docs. 118 | 119 | Consider using Database.fetchmany() unless you want to manage the 120 | transaction flow. 121 | 122 | Args: 123 | query (str): The raw SQL. 
124 | params (list[Any], optional): List of parameters. Defaults to None. 125 | 126 | Returns: 127 | LazyList[asyncpg.Record, dict[str, Any]]: The matching rows. 128 | """ 129 | 130 | params = params or [] 131 | return LazyList(await self.con.fetch(query, *params), dict) 132 | 133 | async def fetchval( 134 | self, query: str, params: list[Any] | None = None 135 | ) -> Any: 136 | """Execute SQL, returning a single value. 137 | 138 | For example: 139 | ``` 140 | db.fetchrow("SELECT name FROM users") 141 | # -> {"name": "Circuit"} 142 | 143 | db.fetchval("SELECT name FROM users") 144 | # -> "Circuit" 145 | ``` 146 | 147 | Args: 148 | query (str): The raw SQL. 149 | params (list[Any], optional): List of parameters. Defaults to None. 150 | 151 | Returns: 152 | Any: The returned value, if any. 153 | """ 154 | params = params or [] 155 | return await self.con.fetchval(query, *params) 156 | 157 | def cursor( 158 | self, query: str, params: list[Any] | None = None 159 | ) -> CursorFactory: 160 | """Returns a CursorFactory, but outside a transaction. 161 | 162 | Since cursors must be inside a transaction, you must handle the 163 | transaction flow like this: 164 | 165 | ``` 166 | async with con.transaction(): 167 | async for row in con.cursor(): 168 | ... 169 | ``` 170 | 171 | Args: 172 | query (str): The raw SQL. 173 | params (list[Any], optional): List of parameters. Defaults to None. 174 | 175 | Returns: 176 | CursorFactory: The cursor factory. 177 | """ 178 | 179 | params = params or [] 180 | return self.con.cursor(query, *params) 181 | -------------------------------------------------------------------------------- /apgorm/constraints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/apgorm/constraints/__init__.py -------------------------------------------------------------------------------- /apgorm/constraints/check.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Iterable 4 | 5 | from apgorm.sql.sql import Block, raw 6 | 7 | from .constraint import Constraint 8 | 9 | if TYPE_CHECKING: # pragma: no cover 10 | from apgorm.types.boolean import Bool 11 | 12 | 13 | class Check(Constraint): 14 | __slots__: Iterable[str] = ("check",) 15 | 16 | def __init__(self, check: Block[Bool] | str) -> None: 17 | """Specify a check constraint for a table. 18 | 19 | Args: 20 | check (Block[Bool]): The raw SQL for the check constraint. 21 | """ 22 | 23 | self.check = raw(check) if isinstance(check, str) else check 24 | 25 | def _creation_sql(self) -> Block[Any]: 26 | return Block( 27 | raw("CONSTRAINT"), 28 | raw(self.name), 29 | raw("CHECK"), 30 | Block(self.check, wrap=True), 31 | wrap=True, 32 | ) 33 | -------------------------------------------------------------------------------- /apgorm/constraints/constraint.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable 2 | 3 | from apgorm.migrations.describe import DescribeConstraint 4 | from apgorm.sql.sql import Block 5 | 6 | 7 | class Constraint: 8 | """The base class for all constraints.""" 9 | 10 | __slots__: Iterable[str] = ("name",) 11 | 12 | name: str 13 | """The name of the constraint.
14 | 15 | Populated by Database and will not exist until an instance 16 | of Database has been created.""" 17 | 18 | def _creation_sql(self) -> Block[Any]: 19 | raise NotImplementedError # pragma: no cover 20 | 21 | def _describe(self) -> DescribeConstraint: 22 | return DescribeConstraint( 23 | name=self.name, raw_sql=self._creation_sql().render_no_params() 24 | ) 25 | -------------------------------------------------------------------------------- /apgorm/constraints/exclude.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Iterable 4 | 5 | from apgorm.indexes import IndexType, _IndexType 6 | from apgorm.sql.sql import Block, join, raw, wrap 7 | 8 | from .constraint import Constraint 9 | 10 | if TYPE_CHECKING: # pragma: no cover 11 | from apgorm.field import BaseField 12 | from apgorm.types.boolean import Bool 13 | 14 | 15 | class Exclude(Constraint): 16 | __slots__: Iterable[str] = ("using", "elements", "where") 17 | 18 | def __init__( 19 | self, 20 | *elements: tuple[BaseField[Any, Any, Any] | Block[Any] | str, str], 21 | using: IndexType = IndexType.BTREE, 22 | where: Block[Bool] | str | None = None, 23 | ) -> None: 24 | """Specify an Exclusion constraint for a table. 25 | 26 | Args: 27 | *elements: key-value pairs of fields and the operator to use on 28 | that field. 29 | 30 | Kwargs: 31 | using (IndexType, optional): The index to use. Defaults 32 | to IndexType.BTREE. 33 | where (Block[Bool] | str, optional): Specify condition for 34 | applying this constraint. Defaults to None. 35 | """ 36 | 37 | self.using: _IndexType = using.value 38 | self.elements = [ 39 | (raw(f) if isinstance(f, str) else f, op) for f, op in elements 40 | ] 41 | self.where = raw(where) if isinstance(where, str) else where 42 | 43 | def _creation_sql(self) -> Block[Any]: 44 | sql = Block[Any]( 45 | raw("CONSTRAINT"), 46 | raw(self.name), 47 | raw("EXCLUDE USING"), 48 | raw(self.using.name), 49 | join( 50 | raw(","), 51 | *(wrap(f, raw("WITH"), raw(op)) for f, op in self.elements), 52 | wrap=True, 53 | ), 54 | ) 55 | if self.where is not None: 56 | sql += Block(raw("WHERE"), wrap(self.where)) 57 | 58 | return sql 59 | -------------------------------------------------------------------------------- /apgorm/constraints/foreign_key.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import Enum 4 | from typing import Any, Iterable, Sequence 5 | from typing import cast as typingcast 6 | 7 | from apgorm.exceptions import BadArgument 8 | from apgorm.field import BaseField 9 | from apgorm.sql.sql import Block, join, raw 10 | 11 | from .constraint import Constraint 12 | 13 | 14 | class ForeignKeyAction(Enum): 15 | """Action for ON UPDATE or ON DELETE of ForeignKey.""" 16 | 17 | CASCADE = "CASCADE" 18 | """Carry the changes""" 19 | RESTRICT = "RESTRICT" 20 | """Prevent the changes""" 21 | NO_ACTION = "NO ACTION" 22 | """Do nothing""" 23 | 24 | 25 | class ForeignKey(Constraint): 26 | __slots__: Iterable[str] = ( 27 | "fields", 28 | "ref_fields", 29 | "ref_table", 30 | "match_full", 31 | "on_delete", 32 | "on_update", 33 | ) 34 | 35 | def __init__( 36 | self, 37 | fields: Sequence[Block[Any] | BaseField[Any, Any, Any] | str] 38 | | Block[Any] 39 | | BaseField[Any, Any, Any] 40 | | str, 41 | ref_fields: Sequence[Block[Any] | BaseField[Any, Any, Any] | str] 42 | | Block[Any] 43 | | BaseField[Any, Any, Any] 44 | | str, 
45 | ref_table: Block[Any] | str | None = None, 46 | match_full: bool = False, 47 | on_delete: ForeignKeyAction = ForeignKeyAction.CASCADE, 48 | on_update: ForeignKeyAction = ForeignKeyAction.CASCADE, 49 | ) -> None: 50 | """Specify a ForeignKey constraint for a table. 51 | 52 | Args: 53 | fields (Sequence[Block | BaseField]): A list of fields or raw 54 | field names on the current table. 55 | ref_fields (Sequence[Block | BaseField]): A list of fields or raw 56 | field names on the referenced table. 57 | ref_table (Block, optional): If all of `ref_fields` are `Block`, 58 | specify the raw tablename of the referenced table. Defaults to 59 | None. 60 | match_full (bool, optional): Whether or not a full match is 61 | required. If False, MATCH SIMPLE is used instead. Defaults to 62 | False. 63 | on_delete (ForeignKeyAction, optional): The action to perform if the 64 | referenced row is deleted. Defaults to ForeignKeyAction.CASCADE. 65 | on_update (ForeignKeyAction, optional): The action to perform if the 66 | referenced row is updated. Defaults to ForeignKeyAction.CASCADE. 67 | 68 | Raises: 69 | BadArgument: Bad arguments were sent to ForeignKey. 70 | """ 71 | 72 | self.fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] = [ 73 | raw(f) if isinstance(f, str) else f 74 | for f in (fields if isinstance(fields, Sequence) else [fields]) 75 | ] 76 | self.ref_fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] = [ 77 | raw(f) if isinstance(f, str) else f 78 | for f in ( 79 | ref_fields 80 | if isinstance(ref_fields, Sequence) 81 | else [ref_fields] 82 | ) 83 | ] 84 | self.ref_table = ( 85 | raw(ref_table) if isinstance(ref_table, str) else ref_table 86 | ) 87 | self.match_full = match_full 88 | self.on_delete = on_delete 89 | self.on_update = on_update 90 | 91 | if len(self.ref_fields) != len(self.fields): 92 | raise BadArgument( 93 | "Must have same number of fields and ref_fields." 94 | ) 95 | 96 | if not self.fields: 97 | raise BadArgument("Must specify at least one field and ref_field.") 98 | 99 | def _creation_sql(self) -> Block[Any]: 100 | if ( 101 | len( 102 | { 103 | f.model.tablename 104 | for f in self.ref_fields 105 | if isinstance(f, BaseField) 106 | } 107 | ) 108 | > 1 109 | ): 110 | raise BadArgument( 111 | "All fields in ref_fields must be of the same table." 112 | ) 113 | 114 | if self.ref_table is None: 115 | _ref_fields = self.ref_fields 116 | if not all(isinstance(f, BaseField) for f in _ref_fields): 117 | raise BadArgument( 118 | "ref_fields must either all be BaseFields or " 119 | "ref_table must be specified."
120 | ) 121 | _ref_fields = typingcast( 122 | Sequence[BaseField[Any, Any, Any]], _ref_fields 123 | ) 124 | 125 | ref_table = raw(_ref_fields[0].model.tablename) 126 | 127 | else: 128 | ref_table = self.ref_table 129 | 130 | ref_fields = ( 131 | raw(f.name) if isinstance(f, BaseField) else f 132 | for f in self.ref_fields 133 | ) 134 | fields = ( 135 | raw(f.name) if isinstance(f, BaseField) else f for f in self.fields 136 | ) 137 | 138 | return Block( 139 | raw("CONSTRAINT"), 140 | raw(self.name), 141 | raw("FOREIGN KEY ("), 142 | join(raw(","), *fields), 143 | raw(") REFERENCES"), 144 | ref_table, 145 | raw("("), 146 | join(raw(","), *ref_fields), 147 | raw(") MATCH"), 148 | raw("FULL" if self.match_full else "SIMPLE"), 149 | raw(f"ON DELETE {self.on_delete.value}"), 150 | raw(f"ON UPDATE {self.on_update.value}"), 151 | wrap=True, 152 | ) 153 | -------------------------------------------------------------------------------- /apgorm/constraints/primary_key.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Iterable 4 | 5 | from apgorm.sql.sql import Block, join, raw 6 | 7 | from .constraint import Constraint 8 | 9 | if TYPE_CHECKING: # pragma: no cover 10 | from apgorm.field import BaseField 11 | 12 | 13 | class PrimaryKey(Constraint): 14 | __slots__: Iterable[str] = ("fields",) 15 | 16 | def __init__( 17 | self, *fields: BaseField[Any, Any, Any] | Block[Any] | str 18 | ) -> None: 19 | """Do not use. Specify primary keys like this instead: 20 | 21 | ``` 22 | class Users(Model): 23 | username = VarChar(32).field() 24 | 25 | primary_key = (username,) 26 | ``` 27 | """ 28 | 29 | self.fields = fields 30 | 31 | def _creation_sql(self) -> Block[Any]: 32 | fields = ( 33 | raw(f) 34 | if isinstance(f, str) 35 | else f 36 | if isinstance(f, Block) 37 | else raw(f.name) 38 | for f in self.fields 39 | ) 40 | return Block( 41 | raw("CONSTRAINT"), 42 | raw(self.name), 43 | raw("PRIMARY KEY ("), 44 | join(raw(","), *fields), 45 | raw(")"), 46 | wrap=True, 47 | ) 48 | -------------------------------------------------------------------------------- /apgorm/constraints/unique.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Iterable 4 | 5 | from apgorm.sql.sql import Block, join, raw 6 | 7 | from .constraint import Constraint 8 | 9 | if TYPE_CHECKING: # pragma: no cover 10 | from apgorm.field import BaseField 11 | 12 | 13 | class Unique(Constraint): 14 | __slots__: Iterable[str] = ("fields",) 15 | 16 | def __init__( 17 | self, *fields: BaseField[Any, Any, Any] | Block[Any] | str 18 | ) -> None: 19 | """Specify a unique constraint for a table. 20 | 21 | ``` 22 | class User(Model): 23 | ... 24 | nickname = VarChar(32).field() 25 | nickname_unique = Unique(nickname) 26 | ... 
27 | ``` 28 | """ 29 | 30 | self.fields = fields 31 | 32 | def _creation_sql(self) -> Block[Any]: 33 | fields = ( 34 | raw(f) 35 | if isinstance(f, str) 36 | else f 37 | if isinstance(f, Block) 38 | else raw(f.name) 39 | for f in self.fields 40 | ) 41 | return Block( 42 | raw("CONSTRAINT"), 43 | raw(self.name), 44 | raw("UNIQUE ("), 45 | join(raw(","), *fields), 46 | raw(")"), 47 | wrap=True, 48 | ) 49 | -------------------------------------------------------------------------------- /apgorm/converter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from enum import IntEnum, IntFlag 5 | from typing import Generic, Iterable, Type, TypeVar, Union, cast 6 | 7 | _ORIG = TypeVar("_ORIG") 8 | _CONV = TypeVar("_CONV") 9 | 10 | 11 | class Converter(ABC, Generic[_ORIG, _CONV]): 12 | """Base class that must be used for all field converters.""" 13 | 14 | __slots__: Iterable[str] = () 15 | 16 | @abstractmethod 17 | def from_stored(self, value: _ORIG) -> _CONV: 18 | """Take the value given by the database and convert it to the type 19 | used in your code.""" 20 | 21 | @abstractmethod 22 | def to_stored(self, value: _CONV) -> _ORIG: 23 | """Take the type used by your code and convert it to the type 24 | used to store the value in the database.""" 25 | 26 | 27 | _INTEF = TypeVar("_INTEF", bound="Union[IntFlag, IntEnum]") 28 | 29 | 30 | class IntEFConverter(Converter[int, _INTEF], Generic[_INTEF]): 31 | """Converter that converts integers to IntEnums or IntFlags. 32 | ``` 33 | class User(Model): 34 | flags = Int().field().with_converter(IntEFConverter(MyIntFlag)) 35 | ``` 36 | 37 | Args: 38 | type_ (Type[IntEnum] | Type[IntFlag]): The IntEnum or IntFlag to use. 39 | """ 40 | 41 | __slots__: Iterable[str] = ("_type",) 42 | 43 | def __init__(self, type_: Type[_INTEF]): 44 | self._type: Type[_INTEF] = type_ 45 | 46 | def from_stored(self, value: int) -> _INTEF: 47 | return cast(_INTEF, self._type(value)) 48 | 49 | def to_stored(self, value: _INTEF) -> int: 50 | return cast(Union[IntEnum, IntFlag], value).value 51 | -------------------------------------------------------------------------------- /apgorm/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from contextlib import asynccontextmanager 5 | from pathlib import Path 6 | from typing import Any, AsyncGenerator, Awaitable, Iterable, Sequence 7 | 8 | import asyncpg 9 | from asyncpg.cursor import CursorFactory 10 | 11 | from .connection import Connection, Pool 12 | from .exceptions import NoMigrationsToCreate 13 | from .indexes import Index 14 | from .migrations import describe 15 | from .migrations.applied_migration import AppliedMigration 16 | from .migrations.apply_migration import apply_migration 17 | from .migrations.create_migration import create_next_migration 18 | from .migrations.migration import Migration 19 | from .model import Model 20 | from .utils.lazy_list import LazyList 21 | 22 | 23 | class Database: 24 | """Base database class. You must subclass this, as 25 | all tables and indexes are stored as classvars. 
26 | 27 | Example: 28 | ``` 29 | class MyDatabase(Database): 30 | users = User 31 | games = Game 32 | players = Players 33 | 34 | indexes = [SomeIndex, SomeOtherIndex] 35 | ``` 36 | """ 37 | 38 | __slots__: Iterable[str] = ( 39 | "_migrations_folder", 40 | "pool", 41 | "default_padding", 42 | ) 43 | 44 | _migrations: type[AppliedMigration] 45 | """Internal table to track applied migrations.""" 46 | _all_models: list[type[Model]] 47 | """A list of all models on this database.""" 48 | 49 | indexes: Sequence[Index] | None = None 50 | """A list of indexes for the database.""" 51 | 52 | def __init_subclass__(cls) -> None: 53 | cls._migrations = AppliedMigration 54 | cls._all_models = [] 55 | 56 | for attr_name, model in cls.__dict__.items(): 57 | if not (isinstance(model, type) and issubclass(model, Model)): 58 | continue 59 | 60 | model.tablename = attr_name 61 | cls._all_models.append(model) 62 | 63 | def __init__( 64 | self, migrations_folder: Path | str, padding: int = 4 65 | ) -> None: 66 | """Initialize the database. 67 | 68 | Args: 69 | migrations_folder (Path | str): The folder in which migrations 70 | are/will be stored. 71 | """ 72 | 73 | self._migrations_folder = ( 74 | Path(migrations_folder) 75 | if isinstance(migrations_folder, str) 76 | else migrations_folder 77 | ) 78 | 79 | for model in self._all_models: 80 | model.database = self 81 | 82 | self.default_padding = padding 83 | self.pool: Pool | None = None 84 | 85 | # migration functions 86 | def describe(self) -> describe.Describe: 87 | """Return the description of the database. 88 | 89 | Returns: 90 | describe.Describe: The description. 91 | """ 92 | 93 | return describe.Describe( 94 | tables=[m._describe() for m in self._all_models], 95 | indexes=[i._describe() for i in self.indexes or []], 96 | ) 97 | 98 | def load_all_migrations(self) -> list[Migration]: 99 | """Load all migrations and return them. 100 | 101 | Returns: 102 | list[Migration]: The migrations. 103 | """ 104 | 105 | return Migration._load_all_migrations(self._migrations_folder) 106 | 107 | def load_last_migration(self) -> Migration | None: 108 | """Loads and returns the most recent migration, if any. 109 | 110 | Returns: 111 | Migration | None: The migration. 112 | """ 113 | 114 | return Migration._load_last_migration(self._migrations_folder) 115 | 116 | def load_migration_from_id(self, migration_id: int) -> Migration: 117 | """Loads a migration from its id. 118 | 119 | Args: 120 | migration_id (int): The id of the migration. 121 | 122 | Returns: 123 | Migration: The migration. 124 | """ 125 | 126 | return Migration._from_path( 127 | Migration._path_from_id( 128 | migration_id, self._migrations_folder, self.default_padding 129 | ) 130 | ) 131 | 132 | def must_create_migrations(self) -> bool: 133 | """Whether or not you need to call Database.create_migrations(). 134 | 135 | Returns: 136 | bool 137 | """ 138 | 139 | sql = self._create_next_migration() 140 | return sql is not None 141 | 142 | def create_migrations(self, allow_empty: bool = False) -> Migration: 143 | """Create migrations. 144 | 145 | Args: 146 | allow_empty (bool): Whether to allow the migration to be empty 147 | (useful for creating custom migrations). 148 | 149 | Raises: 150 | NoMigrationsToCreate: There are no migrations to create. 151 | 152 | Returns: 153 | Migration: The created migration.
154 | """ 155 | 156 | if (not self.must_create_migrations()) and not allow_empty: 157 | raise NoMigrationsToCreate 158 | sql = self._create_next_migration() or "" 159 | return Migration._create_migration( 160 | self.describe(), sql, self._migrations_folder, self.default_padding 161 | ) 162 | 163 | async def load_unapplied_migrations(self) -> list[Migration]: 164 | """Returns a list of migrations that have not been applied on the 165 | current database. 166 | 167 | Returns: 168 | list[Migration]: The unapplied migrations. 169 | """ 170 | 171 | try: 172 | applied = [ 173 | m.id_ for m in await self._migrations.fetch_query().fetchmany() 174 | ] 175 | except asyncpg.UndefinedTableError: 176 | applied = [] 177 | return [ 178 | m 179 | for m in self.load_all_migrations() 180 | if m.migration_id not in applied 181 | ] 182 | 183 | async def must_apply_migrations(self) -> bool: 184 | """Whether or not there are migrations that need to be applied. 185 | 186 | Returns: 187 | bool 188 | """ 189 | 190 | return len(await self.load_unapplied_migrations()) > 0 191 | 192 | async def apply_migrations(self) -> None: 193 | """Applies all migrations that need to be applied.""" 194 | 195 | unapplied = await self.load_unapplied_migrations() 196 | unapplied.sort(key=lambda m: m.migration_id) 197 | for m in unapplied: 198 | await self._apply_migration(m) 199 | 200 | # database functions 201 | async def connect(self, **connect_kwargs: Any) -> None: 202 | """Connect to a database. Any kwargs that can be passed to 203 | asyncpg.create_pool() can be used here. 204 | """ 205 | 206 | self.pool = Pool(await asyncpg.create_pool(**connect_kwargs)) 207 | 208 | async def cleanup(self, timeout: float = 30) -> None: 209 | """Close the connection. 210 | 211 | Args: 212 | timeout (float): The maximum time pool.close() has. Defaults to 30. 213 | 214 | Raises: 215 | TimeoutError: The operation timed out. 216 | """ 217 | 218 | if self.pool is not None: 219 | await asyncio.wait_for(self.pool.close(), timeout=timeout) 220 | 221 | async def execute(self, query: str, params: list[Any]) -> None: 222 | """Execute SQL within a transaction.""" 223 | 224 | assert self.pool is not None 225 | async with self.pool.acquire() as con: 226 | async with con.transaction(): 227 | await con.execute(query, params) 228 | 229 | async def fetchrow( 230 | self, query: str, params: list[Any] 231 | ) -> dict[str, Any] | None: 232 | """Fetch the first matching row. 233 | 234 | Returns: 235 | dict | None: The row, if any. 236 | """ 237 | 238 | assert self.pool is not None 239 | async with self.pool.acquire() as con: 240 | async with con.transaction(): 241 | return await con.fetchrow(query, params) 242 | 243 | async def fetchmany( 244 | self, query: str, params: list[Any] 245 | ) -> LazyList[asyncpg.Record, dict[str, Any]]: 246 | """Fetch all matching rows. 247 | 248 | Returns: 249 | LazyList[asyncpg.Record, dict]: All matching rows. 
250 | """ 251 | 252 | assert self.pool is not None 253 | async with self.pool.acquire() as con: 254 | async with con.transaction(): 255 | return await con.fetchmany(query, params) 256 | 257 | async def fetchval(self, query: str, params: list[Any]) -> Any: 258 | """Fetch a single value.""" 259 | 260 | assert self.pool is not None 261 | async with self.pool.acquire() as con: 262 | async with con.transaction(): 263 | return await con.fetchval(query, params) 264 | 265 | @asynccontextmanager 266 | async def cursor( 267 | self, query: str, params: list[Any], con: Connection | None = None 268 | ) -> AsyncGenerator[CursorFactory, None]: 269 | """Yields a CursorFactory. 270 | 271 | Usage: 272 | ``` 273 | async with db.cursor(query, args) as cursor: 274 | async for res in cursor: 275 | print(res) 276 | ``` 277 | """ 278 | 279 | if con: 280 | yield con.cursor(query, params) 281 | 282 | else: 283 | assert self.pool is not None 284 | async with self.pool.acquire() as con: 285 | async with con.transaction(): 286 | yield con.cursor(query, params) 287 | 288 | def _create_next_migration(self) -> str | None: 289 | return create_next_migration(self.describe(), self._migrations_folder) 290 | 291 | def _apply_migration(self, migration: Migration) -> Awaitable[None]: 292 | return apply_migration(migration, self) 293 | -------------------------------------------------------------------------------- /apgorm/exceptions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Iterable, Type 4 | 5 | if TYPE_CHECKING: # pragma: no cover 6 | from .field import BaseField 7 | from .model import Model 8 | 9 | 10 | class ApgormBaseException(Exception): 11 | """The base clase for all exceptions in apgorm.""" 12 | 13 | __slots__: Iterable[str] = () 14 | 15 | 16 | # migration-side exceptions 17 | class MigrationException(ApgormBaseException): 18 | """Base class for all exceptions related to migrations.""" 19 | 20 | __slots__: Iterable[str] = () 21 | 22 | 23 | class NoMigrationsToCreate(MigrationException): 24 | """The migration for the given id was not found.""" 25 | 26 | __slots__: Iterable[str] = () 27 | 28 | def __init__(self) -> None: 29 | super().__init__("There are no migrations to create.") 30 | 31 | 32 | class MigrationAlreadyApplied(MigrationException): 33 | """The migration has already been applied.""" 34 | 35 | __slots__: Iterable[str] = () 36 | 37 | def __init__(self, path: str) -> None: 38 | super().__init__(f"The migration at {path} has already been applied.") 39 | 40 | 41 | # apgorm-side exceptions 42 | class ApgormException(ApgormBaseException): 43 | """Base class for all exceptions related to the code of apgorm.""" 44 | 45 | __slots__: Iterable[str] = () 46 | 47 | 48 | class UndefinedFieldValue(ApgormException): 49 | """Raised if you try to get the value for a field that is undefined. 50 | 51 | Usually means that the model has not been created.""" 52 | 53 | __slots__: Iterable[str] = ("field",) 54 | 55 | def __init__(self, field: BaseField[Any, Any, Any]) -> None: 56 | self.field = field 57 | 58 | super().__init__( 59 | f"The field {field.full_name} is undefined. " 60 | "This usually means that the model has not been " 61 | "created." 
62 | ) 63 | 64 | 65 | class InvalidFieldValue(ApgormException): 66 | """The field value failed the validator check.""" 67 | 68 | __slots__: Iterable[str] = () 69 | 70 | def __init__(self, message: str) -> None: 71 | self.message = message 72 | 73 | super().__init__(message) 74 | 75 | 76 | class SpecifiedPrimaryKey(ApgormException): 77 | """You tried to create a primary key constraint by using PrimaryKey 78 | instead of Model.primary_key.""" 79 | 80 | __slots__: Iterable[str] = () 81 | 82 | def __init__(self, cls: str, fields: Iterable[str]) -> None: 83 | super().__init__( 84 | f"You tried to specify a primary key on {cls} by using " 85 | f"the PrimaryKey constraint. Please use {cls}.primary_key " 86 | "instead:\nprimary_key = ({},)".format(", ".join(fields)) 87 | ) 88 | 89 | 90 | class BadArgument(ApgormException): 91 | """Bad arguments were passed.""" 92 | 93 | __slots__: Iterable[str] = ("message",) 94 | 95 | def __init__(self, message: str) -> None: 96 | self.message = message 97 | 98 | super().__init__(message) 99 | 100 | 101 | # postgres-side exceptions 102 | class SqlException(ApgormBaseException): 103 | """Base class for all exceptions related to SQL.""" 104 | 105 | __slots__: Iterable[str] = () 106 | 107 | 108 | class ModelNotFound(SqlException): 109 | """A model for the given parameters was not found.""" 110 | 111 | __slots__: Iterable[str] = ("model", "value") 112 | 113 | def __init__(self, model: Type[Model], values: dict[str, Any]) -> None: 114 | self.model = model 115 | self.values = values 116 | 117 | super().__init__( 118 | "No Model was found for the following parameters:\n - " 119 | + ("\n - ".join(f"{k!r} = {v!r}" for k, v in values.items())) 120 | ) 121 | -------------------------------------------------------------------------------- /apgorm/field.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | TYPE_CHECKING, 5 | Any, 6 | Callable, 7 | Generic, 8 | Iterable, 9 | Type, 10 | TypeVar, 11 | overload, 12 | ) 13 | 14 | from .converter import Converter 15 | from .exceptions import BadArgument, InvalidFieldValue, UndefinedFieldValue 16 | from .migrations.describe import DescribeField 17 | from .sql.sql import Comparable, raw 18 | from .undefined import UNDEF 19 | 20 | if TYPE_CHECKING: # pragma: no cover 21 | from .model import Model 22 | from .sql.sql import Block 23 | from .types.base_type import SqlType 24 | 25 | 26 | _SELF = TypeVar("_SELF", bound="BaseField[Any, Any, Any]", covariant=True) 27 | _T = TypeVar("_T") 28 | _C = TypeVar("_C") 29 | _F = TypeVar("_F", bound="SqlType[Any]") 30 | VALIDATOR = Callable[[_C], bool] 31 | 32 | 33 | class BaseField(Comparable, Generic[_F, _T, _C]): 34 | """Represents a field for a table.""" 35 | 36 | __slots__: Iterable[str] = ( 37 | "name", 38 | "model", 39 | "sql_type", 40 | "_default", 41 | "_default_factory", 42 | "not_null", 43 | "use_repr", 44 | "_validators", 45 | ) 46 | 47 | name: str # populated by Database 48 | """The name of the field, populated when Database is initialized.""" 49 | model: Type[Model] # populated by Database 50 | """The model this field belongs to, populated when Database is 51 | initialized.""" 52 | 53 | def __init__( 54 | self, 55 | sql_type: _F, 56 | *, 57 | default: _T | UNDEF = UNDEF.UNDEF, 58 | default_factory: Callable[[], _T] | None = None, 59 | not_null: bool = False, 60 | use_repr: bool = True, 61 | ) -> None: 62 | self.sql_type = sql_type 63 | 64 | if (default is not UNDEF.UNDEF) and 
(default_factory is not None): 65 | raise BadArgument("Cannot specify default and default_factory.") 66 | 67 | self._default = default 68 | self._default_factory = default_factory 69 | 70 | self.not_null = not_null 71 | 72 | self.use_repr = use_repr 73 | 74 | self._validators: list[VALIDATOR[_C]] = [] 75 | 76 | @overload 77 | def __get__(self: _SELF, inst: None, cls: Type[Model]) -> _SELF: 78 | ... 79 | 80 | @overload 81 | def __get__(self, inst: Model, cls: Type[Model]) -> _C: 82 | ... 83 | 84 | def __get__( 85 | self: _SELF, inst: Model | None, cls: Type[Model] 86 | ) -> _SELF | _C: 87 | raise NotImplementedError 88 | 89 | def __set__(self, inst: Model, value: _C) -> None: 90 | raise NotImplementedError 91 | 92 | @property 93 | def full_name(self) -> str: 94 | """The full name of the field (`tablename.fieldname`).""" 95 | 96 | return f"{self.model.tablename}.{self.name}" 97 | 98 | def add_validator(self: _SELF, validator: VALIDATOR[_C]) -> _SELF: 99 | """Add a validator to the value of this field.""" 100 | 101 | self._validators.append(validator) 102 | return self 103 | 104 | def _validate(self, value: _C) -> None: 105 | for v in self._validators: 106 | if not v(value): 107 | raise InvalidFieldValue( 108 | f"Validator on {self.name} failed for {value!r}" 109 | ) 110 | 111 | def _describe(self) -> DescribeField: 112 | return DescribeField( 113 | name=self.name, type_=self.sql_type._sql, not_null=self.not_null 114 | ) 115 | 116 | def _get_default(self) -> _T | UNDEF: 117 | if self._default is not UNDEF.UNDEF: 118 | return self._default 119 | if self._default_factory is not None: 120 | return self._default_factory() 121 | return UNDEF.UNDEF 122 | 123 | def _get_block(self) -> Block[SqlType[Any]]: 124 | return raw(self.full_name) 125 | 126 | def _copy_kwargs(self) -> dict[str, Any]: 127 | return dict( 128 | sql_type=self.sql_type, 129 | default=self._default, 130 | default_factory=self._default_factory, 131 | not_null=self.not_null, 132 | use_repr=self.use_repr, 133 | ) 134 | 135 | 136 | class Field(BaseField[_F, _T, _T]): 137 | __slots__: Iterable[str] = () 138 | 139 | if not TYPE_CHECKING: 140 | 141 | def __get__(self, inst, cls): 142 | if not inst: 143 | return self 144 | if self.name not in inst._raw_values: 145 | raise UndefinedFieldValue(self) 146 | return inst._raw_values[self.name] 147 | 148 | def __set__(self, inst, value): 149 | self._validate(value) 150 | inst._raw_values[self.name] = value 151 | inst._changed_fields.add(self.name) 152 | 153 | def with_converter( 154 | self, converter: Converter[_T, _C] | Type[Converter[_T, _C]] 155 | ) -> ConverterField[_F, _T, _C]: 156 | """Add a converter to the field. 157 | 158 | Args: 159 | converter (Converter | Type[Converter]): The converter. 160 | 161 | Returns: 162 | ConverterField: The new field with the converter. 
163 | """ 164 | 165 | if isinstance(converter, type) and issubclass(converter, Converter): 166 | converter = converter() 167 | f: ConverterField[_F, _T, _C] = ConverterField( 168 | **self._copy_kwargs(), converter=converter 169 | ) 170 | if hasattr(self, "name"): 171 | f.name = self.name 172 | if hasattr(self, "model"): 173 | f.model = self.model 174 | return f 175 | 176 | 177 | class ConverterField(BaseField[_F, _T, _C]): 178 | __slots__: Iterable[str] = ("converter",) 179 | 180 | def __init__(self, *args: Any, **kwargs: Any) -> None: 181 | self.converter: Converter[_T, _C] = kwargs.pop("converter") 182 | super().__init__(*args, **kwargs) 183 | 184 | if not TYPE_CHECKING: 185 | 186 | def __get__( 187 | self, inst: Model | None, cls: Type[Model] 188 | ) -> ConverterField | Any: 189 | if not inst: 190 | return self 191 | if self.name not in inst._raw_values: 192 | raise UndefinedFieldValue(self) 193 | return self.converter.from_stored(inst._raw_values[self.name]) 194 | 195 | def __set__(self, inst, value): 196 | self._validate(value) 197 | inst._raw_values[self.name] = self.converter.to_stored(value) 198 | inst._changed_fields.add(self.name) 199 | -------------------------------------------------------------------------------- /apgorm/indexes.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from enum import Enum 5 | from typing import TYPE_CHECKING, Any, Iterable, Sequence, Type 6 | 7 | from .exceptions import BadArgument 8 | from .field import BaseField 9 | from .migrations.describe import DescribeIndex 10 | from .sql.sql import Block, join, raw, wrap 11 | 12 | if TYPE_CHECKING: # pragma: no cover 13 | from .model import Model 14 | 15 | 16 | @dataclass 17 | class _IndexType: 18 | name: str 19 | multi: bool = False 20 | unique: bool = False 21 | 22 | 23 | class IndexType(Enum): 24 | """An enum containing every supported index type.""" 25 | 26 | BTREE = _IndexType("BTREE", multi=True, unique=True) 27 | HASH = _IndexType("HASH") 28 | GIST = _IndexType("GIST", multi=True) 29 | SPGIST = _IndexType("SPGIST") 30 | GIN = _IndexType("GIN", multi=True) 31 | BRIN = _IndexType("BRIN", multi=True) 32 | 33 | 34 | class Index: 35 | """Represents an index for a table. 36 | 37 | Must be added to Database.indexes before the database is initialized. For 38 | example: 39 | ``` 40 | class MyDatabase(Database): 41 | ... 42 | 43 | indexes = [Index(...), Index(...), ...] 44 | ``` 45 | 46 | Args: 47 | table (Type[Model]): The table this index is for. 48 | fields (Sequence[BaseField): A list of fields that this index is for. 49 | type_ (IndexType, optional): The index type. Defaults to 50 | IndexType.BTREE. 51 | unique (bool, optional): Whether or not the index should be unique 52 | (BTREE only). Defaults to False. 53 | 54 | Raises: 55 | BadArgument: Bad arguments were passed to Index. 
56 | """ 57 | 58 | __slots__: Iterable[str] = ("type_", "fields", "table", "unique") 59 | 60 | def __init__( 61 | self, 62 | table: Type[Model], 63 | fields: Sequence[BaseField[Any, Any, Any] | Block[Any] | str] 64 | | BaseField[Any, Any, Any] 65 | | Block[Any] 66 | | str, 67 | type_: IndexType = IndexType.BTREE, 68 | unique: bool = False, 69 | ) -> None: 70 | self.type_: _IndexType = type_.value 71 | if not isinstance(fields, Sequence): 72 | fields = [fields] 73 | self.fields = [raw(f) if isinstance(f, str) else f for f in fields] 74 | self.table = table 75 | self.unique = unique 76 | 77 | if not fields: 78 | raise BadArgument("Must specify at least one field for index.") 79 | 80 | if (not self.type_.multi) and len(fields) > 1: 81 | raise BadArgument(f"{type_.name} indexes only support one column.") 82 | 83 | if (not self.type_.unique) and unique: 84 | raise BadArgument( 85 | f"{type_.name} indexes do not support uniqueness." 86 | ) 87 | 88 | def get_name(self) -> str: 89 | """Create a name based on the type, table, and fields of the index. 90 | 91 | Returns: 92 | str: The name. 93 | """ 94 | 95 | fields = (f.name for f in self.fields if isinstance(f, BaseField)) 96 | return "_{type_}_index_{table}__{cols}".format( 97 | type_=self.type_.name, 98 | table=self.table.tablename, 99 | cols="_".join(fields), 100 | ).lower() 101 | 102 | def _creation_sql(self) -> Block[Any]: 103 | names = ( 104 | wrap(f) if isinstance(f, Block) else wrap(raw(f.name)) 105 | for f in self.fields 106 | ) 107 | 108 | return Block( 109 | (raw("UNIQUE INDEX") if self.unique else raw("INDEX")), 110 | raw(self.get_name()), 111 | raw("ON"), 112 | raw(self.table.tablename), 113 | raw("USING"), 114 | raw(self.type_.name), 115 | raw("("), 116 | join(raw(","), *names), 117 | raw(")"), 118 | ) 119 | 120 | def _describe(self) -> DescribeIndex: 121 | return DescribeIndex( 122 | name=self.get_name(), 123 | raw_sql=self._creation_sql().render_no_params(), 124 | ) 125 | -------------------------------------------------------------------------------- /apgorm/manytomany.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | TYPE_CHECKING, 5 | Any, 6 | Generic, 7 | Iterable, 8 | Type, 9 | TypeVar, 10 | cast, 11 | overload, 12 | ) 13 | 14 | if TYPE_CHECKING: # pragma: no cover 15 | from .connection import Connection 16 | from .field import BaseField 17 | from .model import Model 18 | from .utils.lazy_list import LazyList 19 | 20 | 21 | _REF = TypeVar("_REF", bound="Model") 22 | _THROUGH = TypeVar("_THROUGH", bound="Model") 23 | 24 | 25 | class ManyToMany(Generic[_REF, _THROUGH]): 26 | """A useful tool to simplify many-to-many references. 27 | 28 | Args: 29 | here (str): The field name on the current model. 30 | here_ref (str): The model and field name on the "middle" table (in the 31 | example below, the middle table is Player) in the format of 32 | "model.field". 33 | other_ref (str): The model and field name on the middle table that 34 | references the final table (or other table). 35 | other (str): The model and field name on the final table referenced by 36 | the middle table. 37 | 38 | Note: Although unecessary, it is highly recommended to use ForeignKeys on 39 | the middle table where it references the initial and final table. You may 40 | get unexpected behaviour if you don't. 
41 | 42 | Example Usage: 43 | ``` 44 | class User(Model): 45 | username = VarChar(32).field() 46 | primary_key = (username,) 47 | 48 | games = ManyToMany["Game", "Player"]( # the typehints are optional 49 | # the column on this table referenced in Player 50 | "username", 51 | 52 | # the column on Player that references "username" 53 | "players.username", 54 | 55 | # the column on Player that references Game.gameid 56 | "players.gameid", 57 | 58 | # the column on Game referenced by Player 59 | "games.gameid", 60 | ) 61 | 62 | class Game(Model): 63 | gameid = Serial().field() 64 | primary_key = (gameid,) 65 | 66 | users = ManyToMany["User", "Player"]( 67 | "gameid", 68 | "players.gameid", 69 | "players.username", 70 | "users.username", 71 | ) 72 | 73 | class Player(Model): 74 | username = VarChar(32).field() 75 | gameid = Int().field() 76 | primary_key = (username, gameid) 77 | 78 | username_fk = ForeignKey(username, User.username) 79 | gameid_fk = ForeignKey(gameid, Game.gameid) 80 | 81 | class MyDatabase(Database): 82 | users = User 83 | games = Game 84 | players = Player 85 | 86 | ... 87 | 88 | circuit = await User.fetch(username="Circuit") 89 | circuits_games = await circuit.games.fetchmany() 90 | ``` 91 | 92 | If you want typehints to work properly, use 93 | `games = ManyToMany["Game"](...)`. 94 | """ 95 | 96 | __slots__: Iterable[str] = ( 97 | "_here", 98 | "_here_ref", 99 | "_other_ref", 100 | "_other", 101 | "_attribute_name", 102 | ) 103 | 104 | _attribute_name: str 105 | # populated by Model on __init_subclass__ 106 | 107 | def __init__( 108 | self, here: str, here_ref: str, other_ref: str, other: str 109 | ) -> None: 110 | self._here = here 111 | self._here_ref = here_ref 112 | self._other_ref = other_ref 113 | self._other = other 114 | 115 | @overload 116 | def __get__( 117 | self, inst: Model, cls: Type[Model] 118 | ) -> _RealManyToMany[_REF, _THROUGH]: 119 | ... 120 | 121 | @overload 122 | def __get__( 123 | self, inst: None, cls: Type[Model] 124 | ) -> ManyToMany[_REF, _THROUGH]: 125 | ... 126 | 127 | def __get__( 128 | self, inst: Model | None, cls: type[Model] 129 | ) -> ManyToMany[_REF, _THROUGH] | _RealManyToMany[_REF, _THROUGH]: 130 | if inst is None: 131 | return self 132 | real_m2m = self._generate_mtm(inst) 133 | setattr(inst, self._attribute_name, real_m2m) 134 | return real_m2m 135 | 136 | def _generate_mtm(self, inst: Model) -> _RealManyToMany[_REF, _THROUGH]: 137 | return _RealManyToMany(self, inst) 138 | 139 | 140 | class _RealManyToMany(Generic[_REF, _THROUGH]): 141 | __slots__: Iterable[str] = ( 142 | "orig", 143 | "model", 144 | "field", 145 | "mm_model", 146 | "mm_h_field", 147 | "mm_o_field", 148 | "ot_model", 149 | "ot_field", 150 | ) 151 | 152 | def __init__( 153 | self, orig: ManyToMany[_REF, _THROUGH], model_inst: Model 154 | ) -> None: 155 | # NOTE: all these casts are ugly, but truthfully 156 | # there isn't a better way to do this. 
You can't 157 | # actually check that these are Models and Fields 158 | # without creating circular imports (since model.py 159 | # imports this file) 160 | 161 | self.orig = orig 162 | 163 | mm_h_model, _mm_h_field = self.orig._here_ref.split(".") 164 | mm_o_model, _mm_o_field = self.orig._other_ref.split(".") 165 | assert mm_h_model == mm_o_model 166 | 167 | mm_model = cast( 168 | "Type[_THROUGH]", getattr(model_inst.database, mm_h_model) 169 | ) 170 | 171 | mm_h_field = cast( 172 | "BaseField[Any, Any, Any]", getattr(mm_model, _mm_h_field) 173 | ) 174 | mm_o_field = cast( 175 | "BaseField[Any, Any, Any]", getattr(mm_model, _mm_o_field) 176 | ) 177 | 178 | _ot_model, _ot_field = self.orig._other.split(".") 179 | ot_model = cast("Type[_REF]", getattr(model_inst.database, _ot_model)) 180 | 181 | ot_field = cast( 182 | "BaseField[Any, Any, Any]", getattr(ot_model, _ot_field) 183 | ) 184 | 185 | self.model = model_inst 186 | self.field = cast( 187 | "BaseField[Any, Any, Any]", 188 | getattr(model_inst.__class__, self.orig._here), 189 | ) 190 | self.mm_model = mm_model 191 | self.mm_h_field = mm_h_field 192 | self.mm_o_field = mm_o_field 193 | self.ot_model = ot_model 194 | self.ot_field = ot_field 195 | 196 | def __getattr__(self, name: str) -> Any: 197 | return getattr(self.orig, name) 198 | 199 | async def fetchmany( 200 | self, con: Connection | None = None 201 | ) -> LazyList[dict[str, Any], _REF]: 202 | """Fetch all rows from the final table that belong to this instance. 203 | 204 | Returns: 205 | LazyList[dict, Model]: A lazy-list of returned Models. 206 | """ 207 | 208 | return ( 209 | await self.ot_model.fetch_query(con=con) 210 | .where( 211 | self.mm_model.fetch_query() 212 | .where( 213 | self.mm_h_field.eq( 214 | self.model._raw_values[self.field.name] 215 | ), 216 | self.mm_o_field.eq(self.ot_field), 217 | ) 218 | .exists() 219 | ) 220 | .fetchmany() 221 | ) 222 | 223 | async def count(self, con: Connection | None = None) -> int: 224 | """Returns the count. 225 | 226 | Warning: To be efficient, this returns the count of *middle* models, 227 | which may differ from the number of final models if you did not use 228 | ForeignKeys properly. 229 | 230 | Returns: 231 | int: The count. 232 | """ 233 | 234 | return ( 235 | await self.mm_model.fetch_query(con=con) 236 | .where(self.mm_h_field.eq(self.model._raw_values[self.field.name])) 237 | .count() 238 | ) 239 | 240 | async def clear( 241 | self, con: Connection | None = None 242 | ) -> LazyList[dict[str, Any], _THROUGH]: 243 | """Remove all instances of the other model from this instance. 244 | 245 | Both of these lines do the same thing: 246 | ``` 247 | deleted_players = await user.games.clear() 248 | deleted_players = await Player.delete_query().where( 249 | username=user.name 250 | ).execute() 251 | ``` 252 | 253 | Returns: 254 | LazyList[dict, _THROUGH]: A lazy-list of deleted through models (in 255 | the example, it would be a list of Player). 256 | """ 257 | 258 | return ( 259 | await self.mm_model.delete_query(con=con) 260 | .where(self.mm_h_field.eq(self.model._raw_values[self.field.name])) 261 | .execute() 262 | ) 263 | 264 | async def add(self, other: Model, con: Connection | None = None) -> Model: 265 | """Add one or more models to this ManyToMany. 
266 | 267 | Each of these lines does the exact same thing: 268 | ``` 269 | player = await user.games.add(game) 270 | # OR 271 | player = await games.users.add(user) 272 | # OR 273 | player = await Player(username=user.name, gameid=game.id).create() 274 | ``` 275 | 276 | Returns: 277 | Model: The reference model that lines this model and the other 278 | model. In the example, the return would be a Player. 279 | """ 280 | 281 | values = { 282 | self.mm_h_field.name: self.model._raw_values[self.field.name], 283 | self.mm_o_field.name: other._raw_values[self.ot_field.name], 284 | } 285 | return await self.mm_model(**values).create(con=con) 286 | 287 | async def remove( 288 | self, other: Model, con: Connection | None = None 289 | ) -> LazyList[dict[str, Any], _THROUGH]: 290 | """Remove one or models from this ManyToMany. 291 | 292 | Each of these lines does the exact same thing: 293 | ``` 294 | deleted_players = await user.games.remove(game) 295 | # OR 296 | deleted_players = await games.user.remove(user) 297 | # OR 298 | deleted_players = await Player.delete_query().where( 299 | username=user.name, gameid=game.id 300 | ).execute() 301 | ``` 302 | 303 | Note: The fact that .remove() returns a list instead of a single model 304 | was intentional. The reason is ManyToMany does not enforce uniqueness 305 | in any way, so there could be multiple Players that link a single user 306 | to a single game. Thus, user.remove(game) could actually end up 307 | deleting multiple players. 308 | """ 309 | 310 | values = { 311 | self.mm_h_field.name: self.model._raw_values[self.field.name], 312 | self.mm_o_field.name: other._raw_values[self.ot_field.name], 313 | } 314 | return ( 315 | await self.mm_model.delete_query(con=con).where(**values).execute() 316 | ) 317 | -------------------------------------------------------------------------------- /apgorm/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/apgorm/migrations/__init__.py -------------------------------------------------------------------------------- /apgorm/migrations/applied_migration.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from apgorm.model import Model 6 | from apgorm.types.numeric import Int 7 | 8 | 9 | class AppliedMigration(Model): 10 | __slots__: Iterable[str] = () 11 | 12 | id_ = Int().field() 13 | 14 | primary_key = (id_,) 15 | -------------------------------------------------------------------------------- /apgorm/migrations/apply_migration.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import asyncpg 6 | 7 | from apgorm.exceptions import MigrationAlreadyApplied, ModelNotFound 8 | 9 | if TYPE_CHECKING: # pragma: no cover 10 | from apgorm.database import Database 11 | 12 | from .migration import Migration 13 | 14 | 15 | async def apply_migration(migration: Migration, db: Database) -> None: 16 | try: 17 | await db._migrations.fetch(id_=migration.migration_id) 18 | except (asyncpg.exceptions.UndefinedTableError, ModelNotFound): 19 | pass 20 | else: 21 | raise MigrationAlreadyApplied(str(migration.path)) 22 | 23 | assert db.pool is not None 24 | async with db.pool.acquire() as con: 25 | async with con.transaction(): 26 | await 
con.execute(migration.migrations, []) 27 | await db._migrations(id_=migration.migration_id).create(con=con) 28 | -------------------------------------------------------------------------------- /apgorm/migrations/create_migration.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from typing import TYPE_CHECKING 5 | 6 | from apgorm.sql.generators import alter 7 | from apgorm.sql.sql import raw 8 | from apgorm.undefined import UNDEF 9 | 10 | from .describe import DescribeConstraint 11 | from .migration import Migration 12 | 13 | if TYPE_CHECKING: # pragma: no cover 14 | from .describe import Describe 15 | 16 | 17 | def _handle_constraint_list( 18 | tablename: str, 19 | orig: list[DescribeConstraint], 20 | curr: list[DescribeConstraint], 21 | ) -> tuple[list[str], list[str]]: 22 | origd = {c.name: c for c in orig} 23 | currd = {c.name: c for c in curr} 24 | 25 | new_constraints = [ 26 | alter.add_constraint(raw(tablename), c.raw_sql).render_no_params() 27 | for c in currd.values() 28 | if c.name not in origd 29 | ] 30 | drop_constraints = [ 31 | alter.drop_constraint(raw(tablename), raw(c.name)).render_no_params() 32 | for c in origd.values() 33 | if c.name not in currd 34 | ] 35 | 36 | common = ((currd[k], origd[k]) for k in currd.keys() & origd.keys()) 37 | for currc, origc in common: 38 | if currc.raw_sql == origc.raw_sql: 39 | continue 40 | 41 | new_constraints.append( 42 | alter.add_constraint( 43 | raw(tablename), currc.raw_sql 44 | ).render_no_params() 45 | ) 46 | drop_constraints.append( 47 | alter.drop_constraint( 48 | raw(tablename), raw(currc.name) 49 | ).render_no_params() 50 | ) 51 | 52 | return new_constraints, drop_constraints 53 | 54 | 55 | def create_next_migration(cd: Describe, folder: Path) -> str | None: 56 | lm = Migration._load_last_migration(folder) 57 | 58 | curr_tables = {t.name: t for t in cd.tables} 59 | last_tables = {t.name: t for t in lm.describe.tables} if lm else {} 60 | 61 | migrations: list[str] = [] 62 | 63 | # tables # TODO: renamed tables 64 | add_tables = [ 65 | alter.add_table(raw(key)).render_no_params() 66 | for key in curr_tables 67 | if key not in last_tables 68 | ] 69 | drop_tables = [ 70 | alter.drop_table(raw(key)).render_no_params() 71 | for key in last_tables 72 | if key not in curr_tables 73 | ] 74 | 75 | # everything else 76 | curr_indexes = {c.name: c for c in cd.indexes} 77 | last_indexes = {c.name: c for c in lm.describe.indexes} if lm else {} 78 | comm_indexes = ( 79 | (curr_indexes[k], last_indexes[k]) 80 | for k in curr_indexes.keys() & last_indexes.keys() 81 | ) 82 | add_indexes: list[str] = [ 83 | alter.add_index(c.raw_sql).render_no_params() 84 | for name, c in curr_indexes.items() 85 | if name not in last_indexes 86 | ] 87 | drop_indexes: list[str] = [ 88 | alter.drop_index(raw(c.name)).render_no_params() 89 | for name, c in last_indexes.items() 90 | if name not in curr_indexes 91 | ] 92 | for curr, last in comm_indexes: 93 | if curr.raw_sql != last.raw_sql: 94 | drop_indexes.append( 95 | alter.drop_index(raw(curr.name)).render_no_params() 96 | ) 97 | add_indexes.append( 98 | alter.add_index(curr.raw_sql).render_no_params() 99 | ) 100 | 101 | add_unique_constraints: list[str] = [] 102 | add_pk_constraints: list[str] = [] 103 | add_check_constraints: list[str] = [] 104 | add_fk_constraints: list[str] = [] 105 | add_exclude_constraints: list[str] = [] 106 | 107 | drop_unique_constraints: list[str] = [] 108 | 
drop_pk_constraints: list[str] = [] 109 | drop_check_constraints: list[str] = [] 110 | drop_fk_constraints: list[str] = [] 111 | drop_exclude_constraints: list[str] = [] 112 | 113 | add_fields: list[str] = [] 114 | drop_fields: list[str] = [] 115 | 116 | field_not_nulls: list[str] = [] 117 | 118 | for tablename, currtable in curr_tables.items(): 119 | lasttable = ( 120 | last_tables[tablename] if tablename in last_tables else None 121 | ) 122 | 123 | # fields: 124 | curr_fields = {f.name: f for f in currtable.fields} 125 | last_fields = ( 126 | {f.name: f for f in lasttable.fields} if lasttable else {} 127 | ) 128 | 129 | add_fields.extend( 130 | alter.add_field( 131 | raw(tablename), raw(key), raw(f.type_) 132 | ).render_no_params() 133 | for key, f in curr_fields.items() 134 | if key not in last_fields 135 | ) 136 | drop_fields.extend( 137 | alter.drop_field(raw(tablename), raw(key)).render_no_params() 138 | for key, f in last_fields.items() 139 | if key not in curr_fields 140 | ) 141 | 142 | # field not nulls: 143 | for field in curr_fields.values(): 144 | last_not_null = ( 145 | UNDEF.UNDEF 146 | if field.name not in last_fields 147 | else last_fields[field.name].not_null 148 | ) 149 | set_nn_to: bool | None = None 150 | if last_not_null is UNDEF.UNDEF: 151 | if field.not_null: 152 | set_nn_to = True 153 | elif last_not_null != field.not_null: 154 | set_nn_to = field.not_null 155 | 156 | if set_nn_to is not None: 157 | field_not_nulls.append( 158 | alter.set_field_not_null( 159 | raw(tablename), raw(field.name), set_nn_to 160 | ).render_no_params() 161 | ) 162 | 163 | # constraints (in the order they should be applied) 164 | _new, _drop = _handle_constraint_list( 165 | tablename, 166 | lasttable.unique_constraints if lasttable else [], 167 | currtable.unique_constraints, 168 | ) 169 | add_unique_constraints.extend(_new) 170 | drop_unique_constraints.extend(_drop) 171 | 172 | _new, _drop = _handle_constraint_list( 173 | tablename, 174 | [lasttable.pk_constraint] if lasttable else [], 175 | [currtable.pk_constraint], 176 | ) 177 | add_pk_constraints.extend(_new) 178 | drop_pk_constraints.extend(_drop) 179 | 180 | _new, _drop = _handle_constraint_list( 181 | tablename, 182 | lasttable.check_constraints if lasttable else [], 183 | currtable.check_constraints, 184 | ) 185 | add_check_constraints.extend(_new) 186 | drop_check_constraints.extend(_drop) 187 | 188 | _new, _drop = _handle_constraint_list( 189 | tablename, 190 | lasttable.fk_constraints if lasttable else [], 191 | currtable.fk_constraints, 192 | ) 193 | add_fk_constraints.extend(_new) 194 | drop_fk_constraints.extend(_drop) 195 | 196 | _new, _drop = _handle_constraint_list( 197 | tablename, 198 | lasttable.exclude_constraints if lasttable else [], 199 | currtable.exclude_constraints, 200 | ) 201 | add_exclude_constraints.extend(_new) 202 | drop_exclude_constraints.extend(_drop) 203 | 204 | for tablename, lasttable in last_tables.items(): 205 | if tablename in curr_tables: 206 | continue 207 | _, _drop = _handle_constraint_list( 208 | tablename, lasttable.fk_constraints, [] 209 | ) 210 | drop_fk_constraints.extend(_drop) 211 | 212 | # finalization 213 | migrations.extend(add_tables) 214 | 215 | migrations.extend(drop_fk_constraints) 216 | 217 | migrations.extend(drop_indexes) 218 | 219 | migrations.extend(drop_tables) 220 | 221 | migrations.extend(add_fields) 222 | 223 | migrations.extend(drop_pk_constraints) 224 | migrations.extend(drop_unique_constraints) 225 | migrations.extend(drop_check_constraints) 226 | 
migrations.extend(drop_exclude_constraints) 227 | 228 | migrations.extend(drop_fields) 229 | 230 | migrations.extend(field_not_nulls) 231 | 232 | migrations.extend(add_indexes) 233 | 234 | migrations.extend(add_unique_constraints) 235 | migrations.extend(add_check_constraints) 236 | migrations.extend(add_pk_constraints) 237 | migrations.extend(add_fk_constraints) 238 | migrations.extend(add_exclude_constraints) 239 | 240 | if not migrations: 241 | return None 242 | return ";\n".join(migrations) + ";" 243 | -------------------------------------------------------------------------------- /apgorm/migrations/describe.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class DescribeField(BaseModel): 7 | """Field description.""" 8 | 9 | name: str 10 | """The name of the field.""" 11 | type_: str 12 | """The raw SQL name of the type (INTEGER, VARCHAR, etc.).""" 13 | not_null: bool 14 | """Whether or not the field should be marked NOT NULL.""" 15 | 16 | 17 | class DescribeConstraint(BaseModel): 18 | """Constraint description.""" 19 | 20 | name: str 21 | """The name of the constraint.""" 22 | raw_sql: str 23 | """The raw SQL of the constraint.""" 24 | 25 | 26 | class DescribeIndex(BaseModel): 27 | """Index description.""" 28 | 29 | name: str 30 | """The name of the index.""" 31 | raw_sql: str 32 | """The raw SQL of the index.""" 33 | 34 | 35 | class DescribeTable(BaseModel): 36 | """Table description.""" 37 | 38 | name: str 39 | """The name of the table.""" 40 | fields: List[DescribeField] 41 | """List of field descriptions.""" 42 | fk_constraints: List[DescribeConstraint] 43 | """List of ForeignKey descriptions.""" 44 | pk_constraint: DescribeConstraint 45 | """List of PrimaryKey descriptions.""" 46 | unique_constraints: List[DescribeConstraint] 47 | """List of Unique descriptions.""" 48 | check_constraints: List[DescribeConstraint] 49 | """List of Check descriptions.""" 50 | exclude_constraints: List[DescribeConstraint] 51 | """List of Exclude descriptions.""" 52 | 53 | @property 54 | def constraints(self) -> List[DescribeConstraint]: 55 | """List of all constraints (FK, PK, unique, check, and exclude). 56 | 57 | Returns: 58 | List[DescribeConstraint]: List of constraints. 
59 | """ 60 | 61 | return ( 62 | self.fk_constraints 63 | + [self.pk_constraint] 64 | + self.unique_constraints 65 | + self.check_constraints 66 | + self.exclude_constraints 67 | ) 68 | 69 | 70 | class Describe(BaseModel): 71 | """Database description.""" 72 | 73 | tables: List[DescribeTable] 74 | """List of table descriptions.""" 75 | indexes: List[DescribeIndex] 76 | """List of index descriptions.""" 77 | -------------------------------------------------------------------------------- /apgorm/migrations/migration.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from contextlib import suppress 5 | from pathlib import Path 6 | from typing import Iterable 7 | 8 | from .describe import Describe 9 | 10 | 11 | class Migration: 12 | """Represents a single migrations.""" 13 | 14 | __slots__: Iterable[str] = ("describe", "migrations", "path") 15 | 16 | def __init__( 17 | self, describe: Describe, migrations: str, path: Path 18 | ) -> None: 19 | self.describe = describe 20 | """The Database.describe() at the time this migration was created.""" 21 | self.migrations = migrations 22 | """The raw SQL used for applying the migration.""" 23 | self.path = path 24 | """The path at which the migration can be found.""" 25 | 26 | @property 27 | def migration_id(self) -> int: 28 | """The id of the migration. 29 | 30 | Returns: 31 | int: The migration id. 32 | """ 33 | 34 | return self._id_from_path(self.path) 35 | 36 | @classmethod 37 | def _from_path(cls, path: Path) -> Migration: 38 | with (path / "describe.json").open("r") as f: 39 | describe = Describe(**json.loads(f.read())) 40 | 41 | with (path / "migrations.sql").open("r") as f: 42 | sql = f.read() 43 | 44 | return cls(describe, sql, path) 45 | 46 | @staticmethod 47 | def _path_from_id(migration_id: int, folder: Path, padding: int) -> Path: 48 | return folder / str(migration_id).zfill(padding) 49 | 50 | @staticmethod 51 | def _id_from_path(path: Path) -> int: 52 | return int(path.name) 53 | 54 | @classmethod 55 | def _load_all_migrations(cls, folder: Path) -> list[Migration]: 56 | migrations: list[Migration] = [] 57 | for p in folder.glob("*"): 58 | if p.is_file(): 59 | continue 60 | 61 | with suppress(FileNotFoundError): 62 | migrations.append(cls._from_path(p)) 63 | 64 | return migrations 65 | 66 | @classmethod 67 | def _load_last_migration(cls, folder: Path) -> Migration | None: 68 | migrations = cls._load_all_migrations(folder) 69 | if not migrations: 70 | return None 71 | migrations.sort(key=lambda m: m.migration_id) 72 | return migrations[-1] 73 | 74 | @classmethod 75 | def _create_migration( 76 | cls, 77 | describe: Describe, 78 | migrations: str, 79 | folder: Path, 80 | padding: int = 0, 81 | ) -> Migration: 82 | last_migration = cls._load_last_migration(folder) 83 | if last_migration: 84 | next_id = last_migration.migration_id + 1 85 | else: 86 | next_id = 0 87 | next_path = folder / str(next_id).zfill(padding) 88 | next_path.mkdir(parents=True, exist_ok=False) 89 | with (next_path / "describe.json").open("w+") as f: 90 | f.write(json.dumps(describe.dict(), indent=4)) 91 | with (next_path / "migrations.sql").open("w+") as f: 92 | f.write(migrations) 93 | 94 | return cls(describe, migrations, next_path) 95 | 96 | def __eq__(self, other: object) -> bool: 97 | if not isinstance(other, Migration): 98 | raise TypeError(f"Unsupported type {type(other)}") 99 | return self.migration_id == other.migration_id 100 | 
-------------------------------------------------------------------------------- /apgorm/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations, print_function 2 | 3 | from contextlib import suppress 4 | from typing import TYPE_CHECKING, Any, Iterable, Type, TypeVar 5 | 6 | from .constraints.check import Check 7 | from .constraints.constraint import Constraint 8 | from .constraints.exclude import Exclude 9 | from .constraints.foreign_key import ForeignKey 10 | from .constraints.primary_key import PrimaryKey 11 | from .constraints.unique import Unique 12 | from .exceptions import ModelNotFound, SpecifiedPrimaryKey 13 | from .field import BaseField 14 | from .manytomany import ManyToMany 15 | from .migrations.describe import DescribeConstraint, DescribeTable 16 | from .sql.query_builder import ( 17 | DeleteQueryBuilder, 18 | FetchQueryBuilder, 19 | InsertQueryBuilder, 20 | UpdateQueryBuilder, 21 | ) 22 | from .undefined import UNDEF 23 | from .utils.lazy_list import LazyList 24 | 25 | if TYPE_CHECKING: # pragma: no cover 26 | from .connection import Connection 27 | from .database import Database 28 | 29 | 30 | _SELF = TypeVar("_SELF", bound="Model") 31 | 32 | 33 | class Model: 34 | """Base class for all models. To create a new model, subclass this class 35 | and add it to your database like this: 36 | 37 | ``` 38 | class User(Model): 39 | username = VarChar(32).field() 40 | ... 41 | 42 | primary_key = (username,) 43 | 44 | ... 45 | 46 | class MyDatabase(Database): 47 | users = User 48 | ... 49 | ``` 50 | 51 | Use Model() to create an instance of the model. There are two reasons you 52 | might do this: 53 | 1. You used `Database.fetch...` directly instead of `Model.fetch`, and 54 | you want to convert the result to a model. 55 | 2. You want to create a new model (`Model(**values).create()`). 56 | 57 | If you wish to fetch an existing model, please use `Model.fetch` or 58 | `Model.fetch_query`. 59 | """ 60 | 61 | __slots__: Iterable[str] = ("_raw_values", "_changed_fields") 62 | 63 | _all_fields: dict[str, BaseField[Any, Any, Any]] 64 | _all_constraints: dict[str, Constraint] 65 | _all_mtm: dict[str, ManyToMany[Any, Any]] 66 | _changed_fields: set[str] 67 | 68 | _raw_values: dict[str, Any] 69 | _columns: set[str] 70 | 71 | tablename: str # populated by Database 72 | """The name of the table, populated by Database.""" 73 | database: Database # populated by Database 74 | """The database instance, populated by Database.""" 75 | 76 | primary_key: tuple[BaseField[Any, Any, Any], ...] 77 | """The primary key for the model. 
All models MUST have a primary key.""" 78 | 79 | def __init_subclass__(cls) -> None: 80 | cls._all_fields = {} 81 | cls._all_constraints = {} 82 | cls._all_mtm = {} 83 | cls._columns = set() 84 | 85 | for key, value in cls.__dict__.items(): 86 | if isinstance(value, BaseField): 87 | value.name = key 88 | value.model = cls 89 | cls._all_fields[key] = value 90 | cls._columns.add(value.name) 91 | 92 | elif isinstance(value, Constraint): 93 | if isinstance(value, PrimaryKey): 94 | raise SpecifiedPrimaryKey( 95 | cls.__name__, map(str, value.fields) 96 | ) 97 | value.name = key 98 | cls._all_constraints[key] = value 99 | 100 | elif isinstance(value, ManyToMany): 101 | value._attribute_name = key 102 | cls._all_mtm[key] = value 103 | 104 | def __init__(self, **values: Any) -> None: 105 | self._raw_values: dict[str, Any] = {} 106 | self._changed_fields = set() 107 | 108 | for f in self._all_fields.values(): 109 | if f.name in values: 110 | f.__set__(self, values[f.name]) 111 | elif (d := f._get_default()) is not UNDEF.UNDEF: 112 | self._raw_values[f.name] = d 113 | self._changed_fields.clear() 114 | 115 | async def delete(self: _SELF, con: Connection | None = None) -> _SELF: 116 | """Delete the model. Does not update the values of this model, 117 | but the returned model will have updated values. 118 | 119 | Returns: 120 | Model: The deleted model (with updated values). 121 | """ 122 | 123 | deleted = ( 124 | await self.delete_query(con=con) 125 | .where(**(f := self._pk_fields())) 126 | .execute() 127 | ) 128 | if not deleted: 129 | raise ModelNotFound(self.__class__, f) 130 | return deleted[0] 131 | 132 | async def save(self, con: Connection | None = None) -> None: 133 | """Save any changed fields of the model. Updates the values 134 | of this model.""" 135 | 136 | changed_fields = self._get_changed_fields() 137 | if not changed_fields: 138 | return 139 | q = self.update_query(con=con).where(**self._pk_fields()) 140 | q.set(**changed_fields) 141 | result = await q.execute() 142 | self._raw_values.update(result[0]._raw_values) 143 | self._changed_fields.clear() 144 | 145 | async def create(self: _SELF, con: Connection | None = None) -> _SELF: 146 | """Insert the model into the database. Updates the values on this 147 | model. 148 | 149 | Returns: 150 | Model: The model that was inserted. Useful if you want to write 151 | `model = await Model(...).create()` 152 | """ 153 | 154 | q = self.insert_query(con=con) 155 | q.set( 156 | **{ 157 | f.name: self._raw_values[f.name] 158 | for f in self._all_fields.values() 159 | if f.name in self._raw_values 160 | } 161 | ) 162 | self._raw_values.update((await q.execute())._raw_values) 163 | 164 | return self 165 | 166 | async def refetch(self: _SELF, con: Connection | None = None) -> None: 167 | """Updates the model instance by fetching changed values from the 168 | database.""" 169 | 170 | res = await self.fetch(con, **self._pk_fields()) 171 | self._raw_values.update(res._raw_values) 172 | 173 | @classmethod 174 | async def exists( 175 | cls: Type[_SELF], con: Connection | None = None, /, **values: Any 176 | ) -> _SELF | None: 177 | """Check if a model with the given values exists. 178 | 179 | ``` 180 | assert (await User.exists(nickname=None)) is not None 181 | ``` 182 | 183 | This method can also be used to get the `model or None`, whereas 184 | `.fetch()` will raise an exception if the model does not exist. 185 | 186 | Returns: 187 | Model | None: The model if it existed, otherwise None. 
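Pulling the instance methods above together, a short sketch of the typical lifecycle of a row. It assumes a `User` model with a `username` primary key and a nullable `nickname` field, in line with the docstring examples; the values are made up.

```python
# assumes a User model with `username` (primary key) and `nickname` fields
user = await User(username="Circuit").create()  # INSERT; values are refreshed

user.nickname = "Circ"           # marks the field as changed
await user.save()                # UPDATE; only the changed fields are sent

maybe = await User.exists(username="Circuit")    # model or None, never raises
fetched = await User.fetch(username="Circuit")   # raises ModelNotFound if absent
await fetched.refetch()                          # re-pull values from the DB

await user.delete()              # DELETE; the returned model has final values
```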
188 | """ 189 | 190 | with suppress(ModelNotFound): 191 | return await cls.fetch(con, **values) 192 | return None 193 | 194 | @classmethod 195 | async def fetch( 196 | cls: Type[_SELF], con: Connection | None = None, /, **values: Any 197 | ) -> _SELF: 198 | """Fetch an exiting model from the database. 199 | 200 | Example: 201 | ``` 202 | user = await User.fetch(username="Circuit") 203 | ``` 204 | 205 | Raises: 206 | ModelNotFound: No model for the given parameters were found. 207 | 208 | Returns: 209 | Model: The model. 210 | """ 211 | 212 | res = await cls.fetch_query(con=con).where(**values).fetchone() 213 | if res is None: 214 | raise ModelNotFound(cls, values) 215 | return res 216 | 217 | @classmethod 218 | async def fetchmany( 219 | cls: Type[_SELF], con: Connection | None = None, /, **values: Any 220 | ) -> LazyList[dict[str, Any], _SELF]: 221 | """Fetch multiple models from the database. 222 | 223 | Returns: 224 | LazyList[dict[str, Any], Model]: The list of models. 225 | """ 226 | 227 | return await cls.fetch_query(con=con).where(**values).fetchmany() 228 | 229 | @classmethod 230 | async def count( 231 | cls: Type[_SELF], con: Connection | None = None, /, **values: Any 232 | ) -> int: 233 | """Get the number of rows matching these params. 234 | 235 | Returns: 236 | int: The count. 237 | """ 238 | 239 | return await cls.fetch_query(con=con).where(**values).count() 240 | 241 | @classmethod 242 | def fetch_query( 243 | cls: Type[_SELF], con: Connection | None = None 244 | ) -> FetchQueryBuilder[_SELF]: 245 | """Returns a FetchQueryBuilder.""" 246 | 247 | return FetchQueryBuilder(model=cls, con=con) 248 | 249 | @classmethod 250 | def delete_query( 251 | cls: Type[_SELF], con: Connection | None = None 252 | ) -> DeleteQueryBuilder[_SELF]: 253 | """Returns a DeleteQueryBuilder.""" 254 | 255 | return DeleteQueryBuilder(model=cls, con=con) 256 | 257 | @classmethod 258 | def update_query( 259 | cls: Type[_SELF], con: Connection | None = None 260 | ) -> UpdateQueryBuilder[_SELF]: 261 | """Returns an UpdateQueryBuilder.""" 262 | 263 | return UpdateQueryBuilder(model=cls, con=con) 264 | 265 | @classmethod 266 | def insert_query( 267 | cls: Type[_SELF], con: Connection | None = None 268 | ) -> InsertQueryBuilder[_SELF]: 269 | """Returns an InsertQueryBuilder.""" 270 | 271 | return InsertQueryBuilder(model=cls, con=con) 272 | 273 | @classmethod 274 | def _from_raw(cls: type[_SELF], **values: Any) -> _SELF: 275 | n = super().__new__(cls) 276 | n._raw_values = values 277 | n._changed_fields = set() 278 | return n 279 | 280 | @classmethod 281 | def _primary_key(cls) -> PrimaryKey: 282 | pk = PrimaryKey(*cls.primary_key) 283 | pk.name = f"_{cls.tablename}_{{}}_primary_key".format( 284 | "_".join(f.name for f in cls.primary_key) 285 | ) 286 | return pk 287 | 288 | @classmethod 289 | def _describe(cls) -> DescribeTable: 290 | unique: list[DescribeConstraint] = [] 291 | check: list[DescribeConstraint] = [] 292 | fk: list[DescribeConstraint] = [] 293 | exclude: list[DescribeConstraint] = [] 294 | for c in cls._all_constraints.values(): 295 | if isinstance(c, Check): 296 | check.append(c._describe()) 297 | elif isinstance(c, ForeignKey): 298 | fk.append(c._describe()) 299 | elif isinstance(c, Unique): 300 | unique.append(c._describe()) 301 | elif isinstance(c, Exclude): 302 | exclude.append(c._describe()) 303 | 304 | return DescribeTable( 305 | name=cls.tablename, 306 | fields=[f._describe() for f in cls._all_fields.values()], 307 | fk_constraints=fk, 308 | pk_constraint=cls._primary_key()._describe(), 
309 | unique_constraints=unique, 310 | check_constraints=check, 311 | exclude_constraints=exclude, 312 | ) 313 | 314 | def _pk_fields(self) -> dict[str, Any]: 315 | return {f.name: self._raw_values[f.name] for f in self.primary_key} 316 | 317 | def _get_changed_fields(self) -> dict[str, Any]: 318 | return {n: self._raw_values[n] for n in self._changed_fields} 319 | 320 | # magic methods 321 | def __repr__(self) -> str: 322 | return ( 323 | f"<{self.__class__.__name__} " 324 | + " ".join( 325 | f"{f.name}:{getattr(self, f.name)!r}" 326 | for f in self._all_fields.values() 327 | if f.use_repr and f.name in self._raw_values 328 | ) 329 | + ">" 330 | ) 331 | 332 | def __eq__(self, other: object) -> bool: 333 | if not isinstance(other, self.__class__): 334 | raise TypeError(f"Unsupported type {type(other)}.") 335 | 336 | return self._pk_fields() == other._pk_fields() 337 | -------------------------------------------------------------------------------- /apgorm/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/apgorm/py.typed -------------------------------------------------------------------------------- /apgorm/sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/apgorm/sql/__init__.py -------------------------------------------------------------------------------- /apgorm/sql/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/apgorm/sql/generators/__init__.py -------------------------------------------------------------------------------- /apgorm/sql/generators/alter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from apgorm.sql.sql import SQL, Block, raw 6 | 7 | 8 | def add_table(tablename: Block[Any]) -> Block[Any]: 9 | return Block(raw("CREATE TABLE"), tablename, raw("()")) 10 | 11 | 12 | def drop_table(tablename: Block[Any]) -> Block[Any]: 13 | return Block(raw("DROP TABLE"), tablename) 14 | 15 | 16 | def add_index(raw_sql: str) -> Block[Any]: 17 | return Block(raw("CREATE"), raw(raw_sql)) 18 | 19 | 20 | def drop_index(name: Block[Any]) -> Block[Any]: 21 | return Block(raw("DROP INDEX"), name) 22 | 23 | 24 | def _alter_table(tablename: Block[Any], sql: SQL[Any]) -> Block[Any]: 25 | return Block(raw("ALTER TABLE"), tablename, sql) 26 | 27 | 28 | def add_constraint( 29 | tablename: Block[Any], constraint_raw_sql: str 30 | ) -> Block[Any]: 31 | return _alter_table(tablename, Block(raw("ADD"), raw(constraint_raw_sql))) 32 | 33 | 34 | def drop_constraint( 35 | tablename: Block[Any], constraint_name: Block[Any] 36 | ) -> Block[Any]: 37 | return _alter_table( 38 | tablename, Block(raw("DROP CONSTRAINT"), constraint_name) 39 | ) 40 | 41 | 42 | def add_field( 43 | tablename: Block[Any], fieldname: Block[Any], type_: Block[Any] 44 | ) -> Block[Any]: 45 | return _alter_table(tablename, Block(raw("ADD COLUMN"), fieldname, type_)) 46 | 47 | 48 | def drop_field(tablename: Block[Any], fieldname: Block[Any]) -> Block[Any]: 49 | return _alter_table(tablename, Block(raw("DROP COLUMN"), fieldname)) 50 | 51 | 52 | def _alter_field( 53 | tablename: Block[Any], fieldname: 
Block[Any], sql: SQL[Any] 54 | ) -> Block[Any]: 55 | return _alter_table(tablename, Block(raw("ALTER COLUMN"), fieldname, sql)) 56 | 57 | 58 | def set_field_not_null( 59 | tablename: Block[Any], fieldname: Block[Any], not_null: bool 60 | ) -> Block[Any]: 61 | return _alter_field( 62 | tablename, 63 | fieldname, 64 | Block(raw("SET NOT NULL") if not_null else raw("DROP NOT NULL")), 65 | ) 66 | -------------------------------------------------------------------------------- /apgorm/sql/generators/query.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, Literal, Sequence, Type, overload 4 | 5 | from apgorm.field import BaseField 6 | from apgorm.sql.sql import SQL, Block, join, raw, wrap 7 | from apgorm.types.boolean import Bool 8 | from apgorm.undefined import UNDEF 9 | 10 | if TYPE_CHECKING: # pragma: no cover 11 | from apgorm.model import Model 12 | 13 | 14 | @overload 15 | def select( 16 | *, 17 | from_: Model | Type[Model] | Block[Any], 18 | fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] | None = ..., 19 | where: Block[Bool] | None = ..., 20 | order_by: SQL[Any] | UNDEF = ..., 21 | reverse: bool = ..., 22 | limit: int | None = ..., 23 | ) -> Block[Any]: 24 | ... 25 | 26 | 27 | @overload 28 | def select( 29 | *, 30 | from_: Model | Type[Model] | Block[Any], 31 | count: Literal[True] = True, 32 | where: Block[Bool] | None = ..., 33 | limit: int | None = ..., 34 | ) -> Block[Any]: 35 | ... 36 | 37 | 38 | def select( 39 | *, 40 | from_: Model | Type[Model] | Block[Any], 41 | fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] | None = None, 42 | count: bool = False, 43 | where: Block[Bool] | None = None, 44 | order_by: SQL[Any] | UNDEF = UNDEF.UNDEF, 45 | reverse: bool = False, 46 | limit: int | None = None, 47 | ) -> Block[Any]: 48 | sql = Block[Any](raw("SELECT")) 49 | 50 | if count: 51 | sql += Block(raw("COUNT(*)")) 52 | elif fields is not None: 53 | sql += join(raw(","), *fields, wrap=True) 54 | else: 55 | sql += Block(raw("*")) 56 | 57 | tablename = from_ if isinstance(from_, Block) else raw(from_.tablename) 58 | sql += Block(raw("FROM"), tablename) 59 | 60 | if where is not None: 61 | sql += Block(raw("WHERE"), wrap(where)) 62 | 63 | if order_by is not UNDEF.UNDEF: 64 | sql += Block(raw("ORDER BY"), order_by) 65 | sql += raw("DESC" if reverse else "ASC") 66 | 67 | if limit is not None: 68 | sql += Block(raw("LIMIT"), raw(str(limit))) 69 | 70 | return wrap(sql) 71 | 72 | 73 | def delete( 74 | from_: Model | Type[Model] | Block[Any], 75 | where: Block[Bool] | None = None, 76 | return_fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] 77 | | None = None, 78 | ) -> Block[Any]: 79 | tablename = from_ if isinstance(from_, Block) else raw(from_.tablename) 80 | sql = Block[Any](raw("DELETE FROM"), tablename) 81 | if where is not None: 82 | sql += Block(raw("WHERE"), where) 83 | if return_fields is not None: 84 | sql += Block(raw("RETURNING"), join(raw(","), *return_fields)) 85 | return wrap(sql) 86 | 87 | 88 | def update( 89 | table: Model | Type[Model] | Block[Any], 90 | values: dict[SQL[Any], SQL[Any]], 91 | where: Block[Bool] | None = None, 92 | return_fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] 93 | | None = None, 94 | ) -> Block[Any]: 95 | tablename = table if isinstance(table, Block) else raw(table.tablename) 96 | sql = Block[Any](raw("UPDATE"), tablename, raw("SET")) 97 | 98 | set_logic: list[Block[Any]] = [] 99 | for key, value in 
values.items(): 100 | set_logic.append(Block(key, raw("="), value)) 101 | 102 | sql += join(raw(","), *set_logic) 103 | 104 | if where is not None: 105 | sql += Block(raw("WHERE"), where) 106 | 107 | if return_fields is not None: 108 | sql += Block(raw("RETURNING"), join(raw(","), *return_fields)) 109 | 110 | return wrap(sql) 111 | 112 | 113 | def insert( 114 | into: Model | Type[Model] | Block[Any], 115 | fields: Sequence[BaseField[Any, Any, Any] | Block[Any]], 116 | values: Sequence[SQL[Any]], 117 | return_fields: Sequence[BaseField[Any, Any, Any] | Block[Any]] 118 | | None = None, 119 | ) -> Block[Any]: 120 | tablename = into if isinstance(into, Block) else raw(into.tablename) 121 | 122 | sql = Block[Any](raw("INSERT INTO"), tablename) 123 | if fields: 124 | sql += join(raw(","), *fields, wrap=True) 125 | 126 | if values: 127 | sql += Block[Any](raw("VALUES"), join(raw(","), *values, wrap=True)) 128 | else: 129 | sql += Block[Any](raw("DEFAULT VALUES")) 130 | 131 | if return_fields is not None: 132 | sql += raw("RETURNING") 133 | if isinstance(return_fields, (BaseField, Block)): 134 | sql += return_fields 135 | else: 136 | sql += join(raw(","), *return_fields) 137 | 138 | return wrap(sql) 139 | -------------------------------------------------------------------------------- /apgorm/sql/query_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | TYPE_CHECKING, 5 | Any, 6 | AsyncGenerator, 7 | Callable, 8 | Generic, 9 | Iterable, 10 | Type, 11 | TypeVar, 12 | cast, 13 | ) 14 | 15 | from apgorm.connection import Connection 16 | from apgorm.undefined import UNDEF 17 | from apgorm.utils.lazy_list import LazyList 18 | 19 | from .generators.query import delete, insert, select, update 20 | from .sql import SQL, Block, and_, raw, sql, wrap 21 | 22 | if TYPE_CHECKING: # pragma: no cover 23 | from apgorm.model import Model 24 | from apgorm.types.boolean import Bool 25 | 26 | _T = TypeVar("_T", bound="Model") 27 | 28 | 29 | def _dict_model_converter(model: Type[_T]) -> Callable[[dict[str, Any]], _T]: 30 | def converter(values: dict[str, Any]) -> _T: 31 | return model._from_raw(**values) 32 | 33 | return converter 34 | 35 | 36 | class BaseQueryBuilder(Generic[_T]): 37 | """Base class for query builders.""" 38 | 39 | __slots__: Iterable[str] = ("model", "con") 40 | 41 | def __init__(self, model: Type[_T], con: Connection | None = None) -> None: 42 | self.model = model 43 | self.con = con or model.database 44 | 45 | def _get_block(self) -> Block[Any]: 46 | """Convert the data in the query builder to a Block.""" 47 | 48 | raise NotImplementedError # pragma: no cover 49 | 50 | 51 | _S = TypeVar("_S", bound="FilterQueryBuilder[Any]") 52 | 53 | 54 | class FilterQueryBuilder(BaseQueryBuilder[_T]): 55 | """Base class for query builders that have "where logic".""" 56 | 57 | __slots__: Iterable[str] = ("_filters",) 58 | 59 | def __init__(self, model: Type[_T], con: Connection | None = None) -> None: 60 | super().__init__(model, con) 61 | 62 | self._filters: list[Block[Bool]] = [] 63 | 64 | def where(self: _S, *filters: Block[Bool], **values: SQL[Any]) -> _S: 65 | """Extend the current where logic. 66 | 67 | Example: 68 | ``` 69 | builder.where(User.nickname.neq("Nickname"), status=4) 70 | # nickname != "Nickname", status == 4 71 | ``` 72 | 73 | Returns: 74 | FilterQueryBuilder: Returns the query builder to allow for 75 | chaining. 
76 | """ 77 | 78 | # NOTE: Although **values might look like an SQL-injection 79 | # vulnerability, it's really not. Since the keys for **values 80 | # can only contain A-Za-z_ characters, there's no possibly way 81 | # to perform sql injection, even if the keys are user input. 82 | self._filters.extend(filters) 83 | for k, v in values.items(): 84 | self._filters.append(raw(k).eq(v)) 85 | 86 | return self 87 | 88 | def _where_logic(self) -> Block[Bool] | None: 89 | if not self._filters: 90 | return None 91 | return and_(*self._filters) 92 | 93 | 94 | class FetchQueryBuilder(FilterQueryBuilder[_T]): 95 | """Query builder for fetching models.""" 96 | 97 | __slots__: Iterable[str] = ("_order_by_logic", "_reverse") 98 | 99 | def __init__(self, model: Type[_T], con: Connection | None = None) -> None: 100 | super().__init__(model, con) 101 | 102 | self._order_by_logic: SQL[Any] | UNDEF = UNDEF.UNDEF 103 | self._reverse: bool = False 104 | 105 | def order_by( 106 | self, logic: SQL[Any], reverse: bool = False 107 | ) -> FetchQueryBuilder[_T]: 108 | """Specify the order logic of the query. 109 | 110 | Args: 111 | field (Block | BaseField): The field or raw field name to order by. 112 | reverse (bool, optional): If set, will return results decending 113 | instead of ascending. 114 | 115 | Returns: 116 | FetchQueryBuilder: Returns the query builder to allow for chaining. 117 | """ 118 | 119 | self._order_by_logic = logic 120 | self._reverse = reverse 121 | return self 122 | 123 | def exists(self) -> Block[Bool]: 124 | """Returns this query wrapped in EXISTS (). Useful for subqueries: 125 | 126 | ``` 127 | user = await User.fetch(name="Circuit") 128 | games = await Game.fetch_query().where( 129 | Player.fetch_query().where( 130 | gameid=Game.id_, username=user.name 131 | ).exists() 132 | ).fetchmany() 133 | ``` 134 | 135 | Returns: 136 | Block[Bool]: The subquery. 137 | """ 138 | 139 | return sql(raw("EXISTS"), wrap(self._get_block())) 140 | 141 | async def fetchmany( 142 | self, limit: int | None = None 143 | ) -> LazyList[dict[str, Any], _T]: 144 | """Execute the query and return a list of models. 145 | 146 | Args: 147 | limit (int, optional): The maximum number of models to return. 148 | Defaults to None. 149 | 150 | Raises: 151 | TypeError: You specified a limit that wasn't an integer (must be a 152 | python int). 153 | 154 | Returns: 155 | LazyList[dict, Model]: The list of models matching the query. 156 | """ 157 | 158 | if not (limit is None or isinstance(limit, int)): 159 | # NOTE: although limit as a string would work, there is a good 160 | # chance that it's a string because it was user input, meaning 161 | # that allowing limit to be a string would create an SQL-injection 162 | # vulnerability. 163 | raise TypeError("Limit can only be an int.") 164 | res = await self.con.fetchmany(*self._get_block(limit).render()) 165 | return LazyList(res, _dict_model_converter(self.model)) 166 | 167 | async def fetchone(self) -> _T | None: 168 | """Fetch the first model found. 169 | 170 | Returns: 171 | Model | None: Returns the model, or None if none were found. 172 | """ 173 | 174 | res = await self.con.fetchrow(*self._get_block().render()) 175 | if res is None: 176 | return None 177 | return self.model._from_raw(**res) 178 | 179 | async def count(self) -> int: 180 | """SELECT COUNT(*) ... 181 | 182 | Returns: 183 | int: The count. 
184 | """ 185 | 186 | return cast( 187 | int, await self.con.fetchval(*self._get_block(count=True).render()) 188 | ) 189 | 190 | async def cursor(self) -> AsyncGenerator[_T, None]: 191 | """Return an iterator of the resulting models. 192 | 193 | Yields: 194 | Model: The (next) model from the iterator. 195 | """ 196 | 197 | con = self.con if isinstance(self.con, Connection) else None 198 | async with self.model.database.cursor( 199 | *self._get_block().render(), con=con 200 | ) as cursor: 201 | async for res in cursor: 202 | yield self.model._from_raw(**res) 203 | 204 | def _get_block( 205 | self, limit: int | None = None, count: bool = False 206 | ) -> Block[Any]: 207 | if count: 208 | return select( 209 | from_=self.model, 210 | where=self._where_logic(), 211 | count=True, 212 | limit=limit, 213 | ) 214 | return select( 215 | from_=self.model, 216 | where=self._where_logic(), 217 | order_by=self._order_by_logic, 218 | reverse=self._reverse, 219 | limit=limit, 220 | ) 221 | 222 | 223 | class DeleteQueryBuilder(FilterQueryBuilder[_T]): 224 | """Query builder for deleting models.""" 225 | 226 | __slots__: Iterable[str] = () 227 | 228 | async def execute(self) -> LazyList[dict[str, Any], _T]: 229 | """Execute the deletion query. 230 | 231 | Returns: 232 | LazyList[dict, Model]: List of models deleted. 233 | """ 234 | 235 | res = await self.con.fetchmany(*self._get_block().render()) 236 | return LazyList(res, _dict_model_converter(self.model)) 237 | 238 | def _get_block(self) -> Block[Any]: 239 | return delete( 240 | self.model, 241 | self._where_logic(), 242 | list(self.model._all_fields.values()), 243 | ) 244 | 245 | 246 | class UpdateQueryBuilder(FilterQueryBuilder[_T]): 247 | """Query builder for updating models.""" 248 | 249 | __slots__: Iterable[str] = ("_set_values",) 250 | 251 | def __init__(self, model: Type[_T], con: Connection | None = None) -> None: 252 | super().__init__(model, con) 253 | 254 | self._set_values: dict[Block[Any], SQL[Any]] = {} 255 | 256 | def set(self, **values: SQL[Any]) -> UpdateQueryBuilder[_T]: 257 | """Specify changes in the model. 258 | 259 | Example: 260 | ``` 261 | builder.set(username="New Name") 262 | ``` 263 | 264 | Returns: 265 | UpdateQueryBuilder: Returns the query builder to allow for 266 | chaining. 267 | """ 268 | 269 | self._set_values.update({raw(k): v for k, v in values.items()}) 270 | return self 271 | 272 | async def execute(self) -> LazyList[dict[str, Any], _T]: 273 | """Execute the query. 274 | 275 | Returns: 276 | LazyList[dict, Model]: List of updated models. 277 | """ 278 | 279 | res = await self.con.fetchmany(*self._get_block().render()) 280 | return LazyList(res, _dict_model_converter(self.model)) 281 | 282 | def _get_block(self) -> Block[Any]: 283 | return update( 284 | self.model, 285 | {k: v for k, v in self._set_values.items()}, 286 | where=self._where_logic(), 287 | return_fields=list(self.model._all_fields.values()), 288 | ) 289 | 290 | 291 | class InsertQueryBuilder(BaseQueryBuilder[_T]): 292 | """Query builder for creating a model.""" 293 | 294 | __slots__: Iterable[str] = ("_set_values",) 295 | 296 | def __init__(self, model: Type[_T], con: Connection | None = None) -> None: 297 | super().__init__(model, con) 298 | 299 | self._set_values: dict[Block[Any], SQL[Any]] = {} 300 | 301 | def set(self, **values: SQL[Any]) -> InsertQueryBuilder[_T]: 302 | """Specify values to be set in the database. 
303 | 304 | ``` 305 | await User.insert_query().set(username="Circuit").execute() 306 | ``` 307 | 308 | Returns: 309 | InsertQueryBuilder: Returns the query builder to allow for 310 | chaining. 311 | """ 312 | 313 | self._set_values.update({raw(k): v for k, v in values.items()}) 314 | return self 315 | 316 | async def execute(self) -> _T: 317 | """Execute the query. 318 | 319 | Returns: 320 | Model: The model that was inserted. 321 | """ 322 | 323 | res = await self.con.fetchrow(*self._get_block().render()) 324 | assert res is not None 325 | return self.model._from_raw(**res) 326 | 327 | def _get_block(self) -> Block[Any]: 328 | value_names = list(self._set_values.keys()) 329 | value_values = list(self._set_values.values()) 330 | 331 | return insert( 332 | self.model, 333 | value_names, 334 | value_values, 335 | return_fields=list(self.model._all_fields.values()), 336 | ) 337 | -------------------------------------------------------------------------------- /apgorm/sql/sql.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import UserString 4 | from typing import ( 5 | TYPE_CHECKING, 6 | Any, 7 | Callable, 8 | Generic, 9 | Iterable, 10 | Type, 11 | TypeVar, 12 | Union, 13 | overload, 14 | ) 15 | 16 | if TYPE_CHECKING: # pragma: no cover 17 | from apgorm.field import BaseField 18 | from apgorm.types.base_type import SqlType 19 | from apgorm.types.boolean import Bool 20 | from apgorm.types.numeric import Int # noqa 21 | 22 | _T = TypeVar("_T", covariant=True) 23 | _T2 = TypeVar("_T2") 24 | _SQLT_CO = TypeVar("_SQLT_CO", bound="SqlType[Any]", covariant=True) 25 | _SQLT = TypeVar("_SQLT", bound="SqlType[Any]") 26 | SQL = Union[ 27 | "BaseField[SqlType[_T], _T, Any]", 28 | "Block[SqlType[_T]]", 29 | "Parameter[_T]", 30 | _T, 31 | ] 32 | """ 33 | A type alias that allows for any field, raw sql, parameter. 34 | 35 | If used like `SQL[int]`, then it must have a python type of int. 36 | For example: 37 | 38 | ``` 39 | integer: SQL[int] = sql(1) # works 40 | integer: SQL[int] = sql(1).cast(BigInt()) # works 41 | integer: SQL[int] = sql("hi") # fails 42 | """ 43 | 44 | CASTED = Union["BaseField[_SQLT, Any, Any]", "Block[_SQLT]"] 45 | """ 46 | A type alias that allows for any field or raw sql with a specified 47 | SQL type. For example: 48 | 49 | ``` 50 | bigint: CASTED[BigInt] = sql(1).cast(BigInt()) # works 51 | bigint: CASTED[BigInt] = sql(1) # fails 52 | bigint: CASTED[BigInt] = sql(1).cast(SmallInt()) # fails 53 | ``` 54 | """ 55 | 56 | 57 | @overload 58 | def sql(piece: CASTED[_SQLT_CO], /, *, wrap: bool = ...) -> Block[_SQLT_CO]: 59 | ... 60 | 61 | 62 | @overload 63 | def sql(piece: SQL[_T2], /, *, wrap: bool = ...) -> Block[SqlType[_T2]]: 64 | ... 65 | 66 | 67 | @overload 68 | def sql(*pieces: SQL[Any], wrap: bool = ...) -> Block[Any]: 69 | ... 70 | 71 | 72 | def sql(*pieces: SQL[Any], wrap: bool = False) -> Block[Any]: 73 | """Convenience function to wrap content in a Block. 74 | 75 | It is safe to pass user input to this function. 76 | 77 | Example: 78 | ``` 79 | sql( 80 | r("SELECT * FROM mytable WHERE id="), 81 | 5, 82 | r("AND othercol="), 83 | "hello", 84 | ).render() 85 | # "SELECT * FROM mytable WHERE id=$1 AND othercol=$2", [5, "hello"] 86 | ``` 87 | 88 | Returns: 89 | Block: The pieces wrapped in a Block. 
90 | """ 91 | 92 | return Block(*pieces, wrap=wrap) 93 | 94 | 95 | def wrap(*pieces: SQL[Any]) -> Block[Any]: 96 | """Return the SQL pieces as a Block, wrapping them in parantheses.""" 97 | 98 | return sql(*pieces, wrap=True) 99 | 100 | 101 | def join( 102 | joiner: SQL[Any], *values: SQL[Any], wrap: bool = False 103 | ) -> Block[Any]: 104 | """SQL version of "delim".join(values). For example: 105 | 106 | ``` 107 | join(r(","), value1, value2, value3) 108 | ``` 109 | 110 | Args: 111 | joiner (SQL): The joiner or deliminator. 112 | 113 | Kwargs: 114 | wrap (bool): Whether or not to return the result wrapped. Defaults to 115 | False. 116 | """ 117 | 118 | new_values: list[SQL[Any]] = [] 119 | for x, v in enumerate(values): 120 | new_values.append(v) 121 | if x != len(values) - 1: 122 | new_values.append(joiner) 123 | return sql(*new_values, wrap=wrap) 124 | 125 | 126 | def or_(*pieces: SQL[bool]) -> Block[Bool]: 127 | """Shortcut for `join(r("OR"), value1, value2, ...)`.""" 128 | 129 | return join(raw("OR"), *pieces, wrap=True) 130 | 131 | 132 | def and_(*pieces: SQL[bool]) -> Block[Bool]: 133 | """Shortcut for `join(r("AND"), value1, value2, ...)`.""" 134 | 135 | return join(raw("AND"), *pieces, wrap=True) 136 | 137 | 138 | def raw(string: str) -> Block[Any]: 139 | """Treat the string as raw SQL and return a Block. 140 | 141 | Never, ever, ever wrap user input in `r()`. A good rule of thumb is that 142 | if you see a call to `r()`, the input should be text in quotes, not a 143 | variable. 144 | 145 | ``` 146 | r(somevarthatmightbeuserinput) # bad 147 | r("some text") # good 148 | ``` 149 | """ 150 | 151 | return sql(Raw(string)) 152 | 153 | 154 | class Raw(UserString): 155 | __slots__: Iterable[str] = () 156 | 157 | 158 | class Parameter(Generic[_T]): 159 | __slots__: Iterable[str] = ("value",) 160 | 161 | def __init__(self, value: _T) -> None: 162 | self.value = value 163 | 164 | 165 | class _Op(Generic[_SQLT]): 166 | __slots__: Iterable[str] = ("op",) 167 | 168 | def __init__(self, op: str) -> None: 169 | self.op = op 170 | 171 | def __get__( 172 | self, inst: object, cls: Type[object] 173 | ) -> Callable[[SQL[Any]], Block[_SQLT]]: 174 | def operator(other: SQL[Any]) -> Block[_SQLT]: 175 | assert isinstance(inst, Comparable) 176 | return wrap(inst._get_block(), raw(self.op), other) 177 | 178 | return operator 179 | 180 | 181 | class _Func(Generic[_SQLT]): 182 | __slots__: Iterable[str] = ("func", "rside") 183 | 184 | def __init__(self, func: str, rside: bool = False) -> None: 185 | self.func = func 186 | self.rside = rside 187 | 188 | def __get__(self, inst: object, cls: Type[object]) -> Block[_SQLT]: 189 | assert isinstance(inst, Comparable) 190 | if self.rside: 191 | return wrap(wrap(inst._get_block()), raw(self.func)) 192 | return sql(raw(self.func), wrap(inst._get_block())) 193 | 194 | 195 | class Comparable: 196 | __slots__: Iterable[str] = () 197 | 198 | def _get_block(self) -> Block[Any]: 199 | raise NotImplementedError # pragma: no cover 200 | 201 | def cast(self, type_: _SQLT) -> Block[_SQLT]: 202 | return wrap(self._get_block(), raw("::"), raw(type_._sql)) 203 | 204 | # comparison 205 | is_null = _Func["Bool"]("IS NULL", True) 206 | is_true = _Func["Bool"]("IS TRUE", True) 207 | is_false = _Func["Bool"]("IS FALSE", True) 208 | num_nulls = _Func["Int"]("NUM_NULLS") 209 | num_nonnulls = _Func["Int"]("NUM_NONNULLS") 210 | not_ = _Func["Bool"]("NOT") 211 | 212 | eq = _Op["Bool"]("=") 213 | neq = _Op["Bool"]("!=") 214 | lt = _Op["Bool"]("<") 215 | gt = _Op["Bool"](">") 216 | lteq = 
_Op["Bool"]("<=") 217 | gteq = _Op["Bool"](">=") 218 | 219 | any = _Func["Any"]("ANY") 220 | all = _Func["Any"]("ALL") 221 | 222 | 223 | class Block(Comparable, Generic[_SQLT_CO]): 224 | """Represents a list of raw sql and parameters.""" 225 | 226 | __slots__: Iterable[str] = ("_pieces", "_wrap") 227 | 228 | def __init__(self, *pieces: SQL[Any] | Raw, wrap: bool = False) -> None: 229 | """Create a Block. You may find it more convienient to use the `sql()` 230 | helper function (or `wrap(sql())` if you want the result to be wrapped 231 | in "( )"). 232 | 233 | Kwargs: 234 | wrap (bool): Whether or not to add `( )` around the result. 235 | 236 | Example: 237 | ``` 238 | sql = Block(r("SELECT"), "Hello, World", r(","), 17) 239 | sql.render() # "SELECT $1, $2", ["Hello, World", 17] 240 | ``` 241 | """ 242 | 243 | self._pieces: list[Raw | Parameter[Any]] = [] 244 | 245 | if len(pieces) == 1 and isinstance(pieces[0], Block): 246 | block = pieces[0] 247 | assert isinstance(block, Block) 248 | self._wrap: bool = block._wrap or wrap 249 | self._pieces = block._pieces 250 | 251 | else: 252 | self._wrap = wrap 253 | for p in pieces: 254 | if isinstance(p, Comparable): 255 | p = p._get_block() 256 | if isinstance(p, Block): 257 | self._pieces.extend(p.get_pieces()) 258 | elif isinstance(p, (Raw, Parameter)): 259 | self._pieces.append(p) 260 | else: 261 | self._pieces.append(Parameter(p)) 262 | 263 | def render(self) -> tuple[str, list[Any]]: 264 | """Return the rendered result of the Block. 265 | 266 | Returns: 267 | tuple[str, list[Any]]: The raw SQL (as a string) and the list of 268 | parameters. 269 | """ 270 | 271 | return Renderer().render(self) 272 | 273 | def render_no_params(self) -> str: 274 | """Convenience function to get the SQL but not the parameters. 275 | 276 | Although it is technically easier to write `.render()[0]`, this is 277 | cleaner and more obvious to those reading your code. 278 | 279 | Returns: 280 | str: The sql. 
281 | """ 282 | 283 | return self.render()[0] 284 | 285 | def get_pieces( 286 | self, force_wrap: bool | None = None 287 | ) -> list[Raw | Parameter[Any]]: 288 | wrap = self._wrap if force_wrap is None else force_wrap 289 | if wrap: 290 | return [Raw("("), *self._pieces, Raw(")")] 291 | return self._pieces 292 | 293 | def _get_block(self) -> Block[Any]: 294 | return self 295 | 296 | def __iadd__(self, other: object) -> Block[Any]: 297 | if isinstance(other, Block): 298 | self._pieces.extend(other.get_pieces()) 299 | elif isinstance(other, Parameter): 300 | self._pieces.append(other) 301 | else: 302 | raise TypeError(f"Unsupported type {type(other)}") 303 | 304 | return self 305 | 306 | 307 | class Renderer: 308 | __slots__: Iterable[str] = ("_curr_value_id",) 309 | 310 | def __init__(self) -> None: 311 | self._curr_value_id: int = 0 312 | 313 | @property 314 | def _next_value_id(self) -> int: 315 | self._curr_value_id += 1 316 | return self._curr_value_id 317 | 318 | def render(self, sql: Block[Any]) -> tuple[str, list[Any]]: 319 | sql_pieces: list[str] = [] 320 | params: list[Any] = [] 321 | 322 | for piece in sql.get_pieces(force_wrap=False): 323 | if isinstance(piece, Raw): 324 | sql_pieces.append(str(piece)) 325 | else: 326 | sql_pieces.append(f"${self._next_value_id}") 327 | params.append(piece.value) 328 | 329 | return " ".join(sql_pieces), params 330 | -------------------------------------------------------------------------------- /apgorm/types/__init__.py: -------------------------------------------------------------------------------- 1 | from .array import Array 2 | from .binary import ByteA 3 | from .bitstr import Bit, VarBit 4 | from .boolean import Bool, Boolean 5 | from .character import Char, Text, VarChar 6 | from .date import ( 7 | Date, 8 | Interval, 9 | IntervalField, 10 | Time, 11 | Timestamp, 12 | TimestampTZ, 13 | TimeTZ, 14 | ) 15 | from .geometric import Box, Circle, Line, LineSegment, Path, Point, Polygon 16 | from .json_type import Json, JsonB 17 | from .monetary import Money 18 | from .network import CIDR, INET, MacAddr, MacAddr8 19 | from .numeric import ( 20 | FLOAT, 21 | INT, 22 | NUMBER, 23 | SERIAL, 24 | BigInt, 25 | BigSerial, 26 | DoublePrecision, 27 | Int, 28 | Integer, 29 | Numeric, 30 | Real, 31 | Serial, 32 | SmallInt, 33 | SmallSerial, 34 | ) 35 | from .uuid_type import UUID 36 | from .xml_type import XML 37 | 38 | __all__ = ( 39 | "Array", 40 | "ByteA", 41 | "Bit", 42 | "VarBit", 43 | "Bool", 44 | "Boolean", 45 | "Char", 46 | "Text", 47 | "VarChar", 48 | "Date", 49 | "Interval", 50 | "IntervalField", 51 | "Time", 52 | "Timestamp", 53 | "TimestampTZ", 54 | "TimeTZ", 55 | "Box", 56 | "Circle", 57 | "Line", 58 | "LineSegment", 59 | "Path", 60 | "Point", 61 | "Polygon", 62 | "Json", 63 | "JsonB", 64 | "Money", 65 | "CIDR", 66 | "INET", 67 | "MacAddr", 68 | "MacAddr8", 69 | "FLOAT", 70 | "INT", 71 | "NUMBER", 72 | "SERIAL", 73 | "BigInt", 74 | "BigSerial", 75 | "DoublePrecision", 76 | "Int", 77 | "Integer", 78 | "Numeric", 79 | "Real", 80 | "Serial", 81 | "SmallInt", 82 | "SmallSerial", 83 | "UUID", 84 | "XML", 85 | ) 86 | -------------------------------------------------------------------------------- /apgorm/types/array.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Iterable, Optional, Sequence, TypeVar 4 | 5 | from .base_type import SqlType 6 | 7 | _T = TypeVar("_T") 8 | 9 | 10 | class Array(SqlType[Sequence[Optional[_T]]]): 11 | """SQL array 
type. 12 | 13 | Args: 14 | subtype (SqlType): The subtype (can be another Array for multiple 15 | dimensions). 16 | 17 | Example: 18 | ``` 19 | list_of_names = Array(VarChar(32)).field() 20 | game_board = Array(Array(Int(), size=16), size=16).nullablefield() 21 | ``` 22 | 23 | https://www.postgresql.org/docs/14/arrays.html 24 | """ 25 | 26 | __slots__: Iterable[str] = ("_subtype",) 27 | 28 | def __init__(self, subtype: SqlType[_T]) -> None: 29 | self._subtype = subtype 30 | 31 | def _get_arrays( 32 | t: SqlType[Any], arrays: list[Array[Any]] | None = None 33 | ) -> tuple[list[Array[Any]], SqlType[Any]]: 34 | arrays = arrays or [] 35 | if isinstance(t, Array): 36 | arrays.append(t) 37 | return _get_arrays(t.subtype, arrays=arrays) 38 | else: 39 | return arrays, t 40 | 41 | arrays, final = _get_arrays(self) 42 | 43 | self._sql = final._sql + "".join(["[]" for _ in reversed(arrays)]) 44 | 45 | @property 46 | def subtype(self) -> SqlType[_T]: 47 | """The arrays sub type. 48 | 49 | Returns: 50 | SqlType: The sub type. 51 | """ 52 | 53 | return self._subtype 54 | -------------------------------------------------------------------------------- /apgorm/types/base_type.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Generic, Iterable, TypeVar 4 | 5 | from apgorm.field import Field 6 | from apgorm.undefined import UNDEF 7 | 8 | _T = TypeVar("_T", covariant=True) 9 | _S = TypeVar("_S", bound="SqlType[Any]", covariant=True) 10 | 11 | 12 | class SqlType(Generic[_T]): 13 | """Base type for all SQL types.""" 14 | 15 | __slots__: Iterable[str] = ("_sql",) 16 | 17 | _sql: str 18 | """The raw sql for the type.""" 19 | 20 | def field( 21 | self: _S, 22 | default: _T | UNDEF = UNDEF.UNDEF, 23 | default_factory: Callable[[], _T] | None = None, 24 | use_repr: bool = True, 25 | ) -> Field[_S, _T]: 26 | """Generate a field using this type.""" 27 | 28 | return Field( 29 | sql_type=self, 30 | default=default, 31 | default_factory=default_factory, 32 | not_null=True, 33 | use_repr=use_repr, 34 | ) 35 | 36 | def nullablefield( 37 | self: _S, 38 | default: _T | None | UNDEF = UNDEF.UNDEF, 39 | default_factory: Callable[[], _T] | None = None, 40 | use_repr: bool = True, 41 | ) -> Field[_S, _T | None]: 42 | """Generate a nullable field using this type.""" 43 | 44 | return Field( 45 | sql_type=self, 46 | default=default, 47 | default_factory=default_factory, 48 | not_null=False, 49 | use_repr=use_repr, 50 | ) 51 | -------------------------------------------------------------------------------- /apgorm/types/binary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from .base_type import SqlType 6 | 7 | 8 | class ByteA(SqlType[bytes]): 9 | """Variable-length binary string. 10 | 11 | https://www.postgresql.org/docs/14/datatype-binary.html 12 | """ 13 | 14 | __slots__: Iterable[str] = () 15 | 16 | _sql = "BYTEA" 17 | -------------------------------------------------------------------------------- /apgorm/types/bitstr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | import asyncpg 6 | 7 | from .base_type import SqlType 8 | 9 | 10 | class Bit(SqlType[asyncpg.BitString]): 11 | """Fixed length string of 0's and 1's. 
12 | 13 | Args: 14 | length (int, optional): The length of the bitstr. If None, 15 | postgres uses 1. Defaults to None. 16 | 17 | https://www.postgresql.org/docs/14/datatype-bit.html 18 | """ 19 | 20 | __slots__: Iterable[str] = ("_length",) 21 | 22 | def __init__(self, length: int | None = None) -> None: 23 | self._length = length 24 | self._sql = "BIT" 25 | if length is not None: 26 | self._sql += f"({length})" 27 | 28 | @property 29 | def length(self) -> int | None: 30 | """The length of the bitstr. 31 | 32 | Returns: 33 | int | None 34 | """ 35 | 36 | return self._length 37 | 38 | 39 | class VarBit(SqlType[asyncpg.BitString]): 40 | """Variable-length bitstr with max-length. 41 | 42 | https://www.postgresql.org/docs/14/datatype-bit.html 43 | """ 44 | 45 | def __init__(self, max_length: int | None = None) -> None: 46 | """Create a VarBit type. 47 | 48 | Args: 49 | max_length (int, optional): The maximum bitstr length. If None the 50 | length is unlimited. Defaults to None. 51 | """ 52 | 53 | self._max_length = max_length 54 | self._sql = "VARBIT" 55 | if max_length is not None: 56 | self._sql += f"({max_length})" 57 | 58 | @property 59 | def max_length(self) -> int | None: 60 | """The maximum length of the bitstr. 61 | 62 | Returns: 63 | int | None 64 | """ 65 | 66 | return self._max_length 67 | -------------------------------------------------------------------------------- /apgorm/types/boolean.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from .base_type import SqlType 6 | 7 | 8 | class Boolean(SqlType[bool]): 9 | """Boolean type. 10 | 11 | Note: Although the docs say that "null" is valid as a bool value, `col 12 | bool not null` still refuses null values. 13 | 14 | https://www.postgresql.org/docs/14/datatype-boolean.html 15 | """ 16 | 17 | __slots__: Iterable[str] = () 18 | 19 | _sql = "BOOLEAN" 20 | 21 | 22 | Bool = Boolean 23 | -------------------------------------------------------------------------------- /apgorm/types/character.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from .base_type import SqlType 6 | 7 | 8 | class VarChar(SqlType[str]): 9 | """Variable length string with limit. 10 | 11 | Args: 12 | max_length (int, optional): The maximum length of the string. If 13 | None, any length will be accepted. Defaults to None. 14 | 15 | https://www.postgresql.org/docs/14/datatype-character.html 16 | """ 17 | 18 | __slots__: Iterable[str] = ("_max_length",) 19 | 20 | def __init__(self, max_length: int | None = None) -> None: 21 | self._max_length = max_length 22 | self._sql = "VARCHAR" 23 | if max_length is not None: 24 | self._sql += f"({max_length})" 25 | 26 | @property 27 | def max_length(self) -> int | None: 28 | """The maximum length of the VarChar type. 29 | 30 | Returns: 31 | int | None 32 | """ 33 | 34 | return self._max_length 35 | 36 | 37 | class Char(SqlType[str]): 38 | """Fixed length string. If a passed string is too short, it will be 39 | padded with spaces. 40 | 41 | Args: 42 | max_length (int, optional): The length of the string. If None, 43 | postgres will use 1. Defaults to None. 
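Example (an illustrative sketch; the field name is hypothetical): ``` country_code = Char(2).field()  # rendered as CHAR(2); shorter values are space-padded ```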
44 | 45 | https://www.postgresql.org/docs/14/datatype-character.html 46 | """ 47 | 48 | def __init__(self, length: int | None = None) -> None: 49 | self._length = length 50 | self._sql = "CHAR" 51 | if length is not None: 52 | self._sql += f"({length})" 53 | 54 | @property 55 | def length(self) -> int | None: 56 | """The length of the string. 57 | 58 | Returns: 59 | int | None 60 | """ 61 | 62 | return self._length 63 | 64 | 65 | class Text(SqlType[str]): 66 | """Variable unlimited length string. 67 | 68 | https://www.postgresql.org/docs/14/datatype-character.html 69 | """ 70 | 71 | _sql = "TEXT" 72 | -------------------------------------------------------------------------------- /apgorm/types/date.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import datetime 4 | from enum import Enum 5 | from typing import Iterable 6 | 7 | from .base_type import SqlType 8 | 9 | 10 | class Timestamp(SqlType[datetime.datetime]): 11 | """Date and time without time zone. 12 | 13 | Args: 14 | precision (int, optional): Precision (0-6). Defaults to None. 15 | 16 | https://www.postgresql.org/docs/14/datatype-datetime.html 17 | """ 18 | 19 | __slots__: Iterable[str] = ("_precision",) 20 | 21 | def __init__(self, precision: int | None = None) -> None: 22 | self._precision = precision 23 | self._sql = "TIMESTAMP" 24 | if precision is not None: 25 | self._sql += f"({precision})" 26 | 27 | @property 28 | def precision(self) -> int | None: 29 | """The precision of the timestamp. 30 | 31 | Returns: 32 | int | None 33 | """ 34 | 35 | return self._precision 36 | 37 | 38 | class TimestampTZ(SqlType[datetime.datetime]): 39 | """Date and time with time zone. 40 | 41 | Args: 42 | precision (int, optional): Precision (0-6). Defaults to None. 43 | 44 | https://www.postgresql.org/docs/14/datatype-datetime.html 45 | """ 46 | 47 | __slots__: Iterable[str] = ("_precision",) 48 | 49 | def __init__(self, precision: int | None = None) -> None: 50 | self._precision = precision 51 | self._sql = "TIMESTAMPTZ" 52 | if precision is not None: 53 | self._sql += f"({precision})" 54 | 55 | @property 56 | def precision(self) -> int | None: 57 | """The precision of the timestamp. 58 | 59 | Returns: 60 | int | None 61 | """ 62 | 63 | return self._precision 64 | 65 | 66 | class Time(SqlType[datetime.time]): 67 | """Time of day with no date or time zone. 68 | 69 | Args: 70 | precision (int, optional): Precision (0-6). Defaults to None. 71 | 72 | https://www.postgresql.org/docs/14/datatype-datetime.html 73 | """ 74 | 75 | __slots__: Iterable[str] = ("_precision",) 76 | 77 | def __init__(self, precision: int | None = None) -> None: 78 | self._precision = precision 79 | self._sql = "TIME" 80 | if precision is not None: 81 | self._sql += f"({precision})" 82 | 83 | @property 84 | def precision(self) -> int | None: 85 | """The precision of the time. 86 | 87 | Returns: 88 | int | None 89 | """ 90 | 91 | return self._precision 92 | 93 | 94 | class TimeTZ(SqlType[datetime.time]): 95 | """Time of day with time zone (no date). 96 | 97 | Args: 98 | precision (int, optional): Precision (0-6). Defaults to None. 
99 | 100 | https://www.postgresql.org/docs/14/datatype-datetime.html 101 | """ 102 | 103 | __slots__: Iterable[str] = ("_precision",) 104 | 105 | def __init__(self, precision: int | None = None) -> None: 106 | self._precision = precision 107 | self._sql = "TIMETZ" 108 | if precision is not None: 109 | self._sql += f"({precision})" 110 | 111 | @property 112 | def precision(self) -> int | None: 113 | """The precsion of the time. 114 | 115 | Returns: 116 | int | None 117 | """ 118 | 119 | return self._precision 120 | 121 | 122 | class Date(SqlType[datetime.date]): 123 | """Date (no time of day). 124 | 125 | https://www.postgresql.org/docs/14/datatype-datetime.html 126 | """ 127 | 128 | __slots__: Iterable[str] = () 129 | 130 | _sql = "DATE" 131 | 132 | 133 | class IntervalField(Enum): 134 | YEAR = "YEAR" 135 | MONTH = "MONTH" 136 | DAY = "DAY" 137 | HOUR = "HOUR" 138 | MINUTE = "MINUTE" 139 | SECOND = "SECOND" 140 | YEAR_TO_MONTH = "YEAR TO MONTH" 141 | DAY_TO_HOUR = "DAY TO HOUR" 142 | DAY_TO_SECOND = "DAY TO SECOND" 143 | HOUR_TO_MINUTE = "HOUR TO MINUTE" 144 | HOUR_TO_SECOND = "HOUR TO SECOND" 145 | MINUTE_TO_SECOND = "MINUTE TO SECOND" 146 | 147 | 148 | class Interval(SqlType[datetime.timedelta]): 149 | """Interval type. Essentially a time range (5 years or 2 minutes) but not 150 | an actual date or time (05-9-2005). 151 | 152 | Args: 153 | interval_field (IntervalField, optional): Restrict the fields of 154 | the interval. Defaults to None. 155 | precision (int, optional): Precision (0-6). Defaults to None. 156 | 157 | https://www.postgresql.org/docs/14/datatype-datetime.html 158 | """ 159 | 160 | __slots__: Iterable[str] = ("_interval_field", "_precision") 161 | 162 | def __init__( 163 | self, 164 | interval_field: IntervalField | None = None, 165 | precision: int | None = None, 166 | ) -> None: 167 | self._interval_field = interval_field 168 | self._precision = precision 169 | 170 | self._sql = "INTERVAL" 171 | if interval_field is not None: 172 | self._sql += " " + interval_field.value 173 | if precision is not None: 174 | self._sql += f"({precision})" 175 | 176 | @property 177 | def interval_field(self) -> IntervalField | None: 178 | """The allowed fields for the interval. 179 | 180 | Returns: 181 | IntervalField | None 182 | """ 183 | 184 | return self._interval_field 185 | 186 | @property 187 | def precision(self) -> int | None: 188 | """The precision of the interval. 189 | 190 | Returns: 191 | int | None 192 | """ 193 | 194 | return self._precision 195 | -------------------------------------------------------------------------------- /apgorm/types/geometric.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | import asyncpg 6 | 7 | from .base_type import SqlType 8 | 9 | 10 | class Point(SqlType[asyncpg.Point]): 11 | """Point (x, y). 12 | 13 | https://www.postgresql.org/docs/14/datatype-geometric.html#id-1.5.7.16.5 14 | """ 15 | 16 | __slots__: Iterable[str] = () 17 | 18 | _sql = "POINT" 19 | 20 | 21 | class Line(SqlType[asyncpg.Line]): 22 | """Infinite line. 23 | 24 | https://www.postgresql.org/docs/14/datatype-geometric.html#DATATYPE-LINE 25 | """ 26 | 27 | __slots__: Iterable[str] = () 28 | 29 | _sql = "LINE" 30 | 31 | 32 | class LineSegment(SqlType[asyncpg.LineSegment]): 33 | """Finite line. 
34 | 35 | https://www.postgresql.org/docs/14/datatype-geometric.html#DATATYPE-LSEG 36 | """ 37 | 38 | __slots__: Iterable[str] = () 39 | 40 | _sql = "LSEG" 41 | 42 | 43 | class Box(SqlType[asyncpg.Box]): 44 | """Rectangular box (two points specifying corners). 45 | 46 | https://www.postgresql.org/docs/14/datatype-geometric.html#id-1.5.7.16.8 47 | """ 48 | 49 | __slots__: Iterable[str] = () 50 | 51 | _sql = "BOX" 52 | 53 | 54 | class Path(SqlType[asyncpg.Path]): 55 | """Path (multiple LineSegments). Can be open or closed. A closed path is 56 | similar to a polygon. 57 | 58 | https://www.postgresql.org/docs/14/datatype-geometric.html#id-1.5.7.16.9 59 | """ 60 | 61 | __slots__: Iterable[str] = () 62 | 63 | _sql = "PATH" 64 | 65 | 66 | class Polygon(SqlType[asyncpg.Polygon]): 67 | """Polygon type (similar to a closed path). 68 | 69 | https://www.postgresql.org/docs/14/datatype-geometric.html#DATATYPE-POLYGON 70 | """ 71 | 72 | __slots__: Iterable[str] = () 73 | 74 | _sql = "POLYGON" 75 | 76 | 77 | class Circle(SqlType[asyncpg.Circle]): 78 | """Circle type (center Point and radius). 79 | 80 | https://www.postgresql.org/docs/14/datatype-geometric.html#DATATYPE-CIRCLE 81 | """ 82 | 83 | __slots__: Iterable[str] = () 84 | 85 | _sql = "CIRCLE" 86 | -------------------------------------------------------------------------------- /apgorm/types/json_type.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from .base_type import SqlType 6 | 7 | 8 | class Json(SqlType[str]): 9 | """Stores JSON. 10 | 11 | https://www.postgresql.org/docs/14/datatype-json.html 12 | """ 13 | 14 | __slots__: Iterable[str] = () 15 | 16 | _sql = "JSON" 17 | 18 | 19 | class JsonB(SqlType[str]): 20 | """Stores JSON. 21 | 22 | https://www.postgresql.org/docs/14/datatype-json.html 23 | """ 24 | 25 | __slots__: Iterable[str] = () 26 | 27 | _sql = "JSONB" 28 | -------------------------------------------------------------------------------- /apgorm/types/monetary.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from .base_type import SqlType 6 | 7 | 8 | class Money(SqlType[str]): 9 | """Currency type, 8 bytes, range -92233720368547758.08 to 10 | +92233720368547758.07. 11 | 12 | https://www.postgresql.org/docs/14/datatype-money.html 13 | """ 14 | 15 | __slots__: Iterable[str] = () 16 | 17 | _sql = "MONEY" 18 | -------------------------------------------------------------------------------- /apgorm/types/network.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import ipaddress as ipaddr 4 | from typing import Iterable, Union 5 | 6 | from .base_type import SqlType 7 | 8 | 9 | class CIDR(SqlType[Union[ipaddr.IPv4Network, ipaddr.IPv6Network]]): 10 | """IPv4 and IPv6 networks. 7 or 19 bytes. 11 | 12 | https://www.postgresql.org/docs/14/datatype-net-types.html#DATATYPE-CIDR 13 | https://www.postgresql.org/docs/14/datatype-net-types.html#DATATYPE-INET-VS-CIDR 14 | """ 15 | 16 | __slots__: Iterable[str] = () 17 | 18 | _sql = "CIDR" 19 | 20 | 21 | class INET( 22 | SqlType[ 23 | Union[ 24 | ipaddr.IPv6Address, 25 | ipaddr.IPv6Interface, 26 | ipaddr.IPv4Address, 27 | ipaddr.IPv4Interface, 28 | ] 29 | ] 30 | ): 31 | """IPv4 and IPv6 hosts and networks. 7 or 19 bytes. 
32 | 33 | https://www.postgresql.org/docs/14/datatype-net-types.html#DATATYPE-INET 34 | https://www.postgresql.org/docs/14/datatype-net-types.html#DATATYPE-INET-VS-CIDR 35 | """ 36 | 37 | __slots__: Iterable[str] = () 38 | 39 | _sql = "INET" 40 | 41 | 42 | class MacAddr(SqlType[str]): 43 | """MAC addresses. 6 bytes. 44 | 45 | https://www.postgresql.org/docs/14/datatype-net-types.html#DATATYPE-MACADDR 46 | """ 47 | 48 | __slots__: Iterable[str] = () 49 | 50 | _sql = "MACADDR" 51 | 52 | 53 | class MacAddr8(SqlType[str]): 54 | """MAC address (EUI-64 format). 8 bytes. 55 | 56 | https://www.postgresql.org/docs/14/datatype-net-types.html#DATATYPE-MACADDR8 57 | """ 58 | 59 | __slots__: Iterable[str] = () 60 | 61 | _sql = "MACADDR8" 62 | -------------------------------------------------------------------------------- /apgorm/types/numeric.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from decimal import Decimal 4 | from typing import Iterable, TypeVar, Union 5 | 6 | from apgorm.exceptions import BadArgument 7 | from apgorm.field import Field 8 | 9 | from .base_type import SqlType 10 | 11 | INT = Union["SmallInt", "Int", "BigInt"] 12 | """All integral types (excluding serials).""" 13 | FLOAT = Union["Real", "DoublePrecision", "Numeric"] 14 | """All floating-point types.""" 15 | SERIAL = Union["SmallSerial", "Serial", "BigSerial"] 16 | """All serial types.""" 17 | 18 | NUMBER = Union[INT, FLOAT, SERIAL] 19 | """All integral, floating-point, and serial types.""" 20 | 21 | 22 | class SmallInt(SqlType[int]): 23 | """Small integer, 2 bytes, -32768 to +32767. 24 | 25 | https://www.postgresql.org/docs/14/datatype-numeric.html 26 | """ 27 | 28 | __slots__: Iterable[str] = () 29 | 30 | _sql = "SMALLINT" 31 | 32 | 33 | class Int(SqlType[int]): 34 | """Integer, 4 bytes, -2147483648 to +2147483647. 35 | 36 | https://www.postgresql.org/docs/14/datatype-numeric.html 37 | """ 38 | 39 | __slots__: Iterable[str] = () 40 | 41 | _sql = "INTEGER" 42 | 43 | 44 | Integer = Int 45 | 46 | 47 | class BigInt(SqlType[int]): 48 | """Big integer, 8 bytes, -9223372036854775808 to +9223372036854775807. 49 | 50 | https://www.postgresql.org/docs/14/datatype-numeric.html 51 | """ 52 | 53 | __slots__: Iterable[str] = () 54 | 55 | _sql = "BIGINT" 56 | 57 | 58 | class Numeric(SqlType[Decimal]): 59 | """Numeric (AKA decimal) type with a user specified precision and scale. 60 | 61 | Up to 131072 digits before the decimal point and 16383 digits after. 62 | 63 | Args: 64 | precision (int, optional): The total number of significant digits 65 | (`55.55` has 4). Defaults to None. 66 | scale (int, optional): The total number of digits after the 67 | decimal. Specifying a precision without a scale sets the scale to 68 | 0. Defaults to None. 69 | 70 | Raises: 71 | BadArgument: You tried to specify a scale without a precision. 
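Example (an illustrative sketch; the field name is hypothetical): ``` price = Numeric(10, 2).field()  # NUMERIC(10, 2): up to 10 significant digits, 2 of them after the decimal point ```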
72 | 73 | https://www.postgresql.org/docs/14/datatype-numeric.html 74 | """ 75 | 76 | __slots__: Iterable[str] = ("_precision", "_scale") 77 | 78 | def __init__( 79 | self, precision: int | None = None, scale: int | None = None 80 | ) -> None: 81 | self._precision = precision 82 | self._scale = scale 83 | 84 | self._sql = "NUMERIC" 85 | 86 | if precision is not None and scale is not None: 87 | self._sql += f"({precision}, {scale})" 88 | elif precision is not None: 89 | self._sql += f"({precision})" 90 | elif scale is not None: 91 | raise BadArgument("Cannot specify scale without precision.") 92 | 93 | @property 94 | def precision(self) -> int | None: 95 | """The precision (total number of significant digits) of the type. 96 | 97 | Returns: 98 | int | None 99 | """ 100 | 101 | return self._precision 102 | 103 | @property 104 | def scale(self) -> int | None: 105 | """The scale (digits after the decimal point) of the type. 106 | 107 | Returns: 108 | int | None 109 | """ 110 | 111 | return self._scale 112 | 113 | 114 | class Real(SqlType[float]): 115 | """4 byte floating-point number (inexact). 116 | 117 | https://www.postgresql.org/docs/14/datatype-numeric.html 118 | """ 119 | 120 | __slots__: Iterable[str] = () 121 | 122 | _sql = "REAL" 123 | 124 | 125 | class DoublePrecision(SqlType[float]): 126 | """8 byte floating-point number (inexact). 127 | 128 | https://www.postgresql.org/docs/14/datatype-numeric.html 129 | """ 130 | 131 | __slots__: Iterable[str] = () 132 | 133 | _sql = "DOUBLE PRECISION" 134 | 135 | 136 | _S = TypeVar("_S", bound="_BaseSerial", covariant=True) 137 | 138 | 139 | class _BaseSerial(SqlType[int]): 140 | __slots__: Iterable[str] = () 141 | 142 | def field( # type: ignore 143 | self: _S, use_repr: bool = True 144 | ) -> Field[_S, int]: 145 | return Field(sql_type=self, not_null=True, use_repr=use_repr) 146 | 147 | def nullablefield( # type: ignore 148 | self: _S, use_repr: bool = True 149 | ) -> Field[_S, int | None]: 150 | return Field(sql_type=self, not_null=False, use_repr=use_repr) 151 | 152 | 153 | class SmallSerial(_BaseSerial): 154 | """Small autoincrementing integer. You cannot cast to this type. 155 | 156 | Accepts any value accepted by SmallInt, but the default values are 157 | positive only (starting at 1). 158 | 159 | https://www.postgresql.org/docs/14/datatype-numeric.html 160 | """ 161 | 162 | __slots__: Iterable[str] = () 163 | 164 | _sql = "SMALLSERIAL" 165 | 166 | 167 | class Serial(_BaseSerial): 168 | """Autoincrementing integer. You cannot cast to this type. 169 | 170 | Accepts any value accepted by Int, but the default values are positive 171 | only (starting at 1). 172 | 173 | https://www.postgresql.org/docs/14/datatype-numeric.html 174 | """ 175 | 176 | __slots__: Iterable[str] = () 177 | 178 | _sql = "SERIAL" 179 | 180 | 181 | class BigSerial(_BaseSerial): 182 | """Large autoincrementing integer. You cannot cast to this type. 183 | 184 | Accepts any value accepted by BigInt, but the default values are positive 185 | only (starting at 1). 
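Example (an illustrative sketch, mirroring the Serial usage in examples/manytomany/main.py): ``` id_ = BigSerial().field()  # BIGSERIAL column; the serial field() takes no default/default_factory arguments ```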
186 | 187 | https://www.postgresql.org/docs/14/datatype-numeric.html 188 | """ 189 | 190 | __slots__: Iterable[str] = () 191 | 192 | _sql = "BIGSERIAL" 193 | -------------------------------------------------------------------------------- /apgorm/types/uuid_type.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import uuid 4 | from typing import Iterable 5 | 6 | from .base_type import SqlType 7 | 8 | 9 | class UUID(SqlType[uuid.UUID]): 10 | """Universally Unique Identifier (AKA GUID). 11 | 12 | https://www.postgresql.org/docs/14/datatype-uuid.html 13 | """ 14 | 15 | __slots__: Iterable[str] = () 16 | 17 | _sql = "UUID" 18 | -------------------------------------------------------------------------------- /apgorm/types/xml_type.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable 4 | 5 | from .base_type import SqlType 6 | 7 | 8 | class XML(SqlType[str]): 9 | """Type for storing XML data. 10 | 11 | https://www.postgresql.org/docs/14/datatype-xml.html 12 | """ 13 | 14 | __slots__: Iterable[str] = () 15 | 16 | _sql = "XML" 17 | -------------------------------------------------------------------------------- /apgorm/undefined.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | class UNDEF(enum.Enum): 5 | UNDEF = "UNDEFINED" 6 | -------------------------------------------------------------------------------- /apgorm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/apgorm/utils/__init__.py -------------------------------------------------------------------------------- /apgorm/utils/lazy_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Generator, 7 | Generic, 8 | Iterable, 9 | Sequence, 10 | TypeVar, 11 | overload, 12 | ) 13 | 14 | _IN = TypeVar("_IN") 15 | _OUT = TypeVar("_OUT") 16 | 17 | 18 | class LazyList(Generic[_IN, _OUT]): 19 | """Lazily converts each value of an iterable. 20 | 21 | Incredible useful for casting `list[asyncpg.Record]` to `list[dict]`, and 22 | then `list[dict]` to `list[Model]`, especially when there are many rows. 23 | """ 24 | 25 | __slots__: Iterable[str] = ("_data", "_converter") 26 | 27 | def __init__( 28 | self, 29 | data: Sequence[_IN] | LazyList[Any, _IN], 30 | converter: Callable[[_IN], _OUT], 31 | ) -> None: 32 | self._data = data 33 | self._converter = converter 34 | 35 | @overload 36 | def __getitem__(self, index: int) -> _OUT: 37 | ... # pragma: no cover 38 | 39 | @overload 40 | def __getitem__(self, index: slice) -> LazyList[_IN, _OUT]: 41 | ... 
# pragma: no cover 42 | 43 | def __getitem__(self, index: int | slice) -> LazyList[_IN, _OUT] | _OUT: 44 | if isinstance(index, int): 45 | return self._converter(self._data[index]) 46 | return LazyList(self._data[index], self._converter) 47 | 48 | def __iter__(self) -> Generator[_OUT, None, None]: 49 | for r in self._data: 50 | yield self._converter(r) 51 | 52 | def __len__(self) -> int: 53 | return len(self._data) 54 | 55 | def __repr__(self) -> str: 56 | if len(self) == 6: 57 | ddd = ", {}".format(repr(self[-1])) 58 | elif len(self) > 6: 59 | ddd = ", ..., {}".format(repr(self[-1])) 60 | else: 61 | ddd = "" 62 | return "LazyList([{}{}])".format( 63 | ", ".join(repr(c) for c in list(self[0:5])), ddd 64 | ) 65 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/examples/__init__.py -------------------------------------------------------------------------------- /examples/basic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/examples/basic/__init__.py -------------------------------------------------------------------------------- /examples/basic/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | 5 | from .main import main 6 | 7 | if __name__ == "__main__": 8 | asyncio.run(main()) 9 | -------------------------------------------------------------------------------- /examples/basic/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import apgorm 6 | from apgorm.types import VarChar 7 | 8 | 9 | class User(apgorm.Model): 10 | username = VarChar(32).field() 11 | nickname = VarChar(32).nullablefield() 12 | 13 | primary_key = (username,) 14 | 15 | 16 | class Database(apgorm.Database): 17 | users = User 18 | 19 | 20 | async def _main() -> None: 21 | await User.delete_query().execute() 22 | 23 | # create a user 24 | user = await User(username="Circuit").create() 25 | print("Created", user) 26 | 27 | # get the user 28 | _u = await User.fetch(username="Circuit") 29 | print(f"{user} == {_u}:", user == _u) 30 | 31 | # delete the user 32 | await user.delete() 33 | print("Deleted", user) 34 | 35 | # create lots of users 36 | for un in ["Circuit", "Bad Guy", "Super Man", "Batman"]: 37 | await User(username=un).create() 38 | print("Created user with name", un) 39 | 40 | # get all users except for "Bad Guy" 41 | query = User.fetch_query() 42 | query.where(User.username.neq("Bad Guy")) 43 | query.order_by(User.username.eq("Circuit"), True) # put Circuit first 44 | good_users = await query.fetchmany() 45 | print("Users except for 'Bad Guy':", good_users) 46 | 47 | # set the nickname for all good users to "Good Guy" 48 | for user in good_users: 49 | user.nickname = "Good Guy" 50 | await user.save() 51 | 52 | # OR, you can use User.update_query... 
53 | ( 54 | await User.update_query() 55 | .where(User.username.neq("Bad Guy")) 56 | .set(nickname="Good Guy") 57 | .execute() 58 | ) 59 | 60 | print( 61 | "Users after setting nicknames:", await User.fetch_query().fetchmany() 62 | ) 63 | 64 | 65 | async def main() -> None: 66 | db = Database(Path("examples/basic/migrations")) 67 | await db.connect(database="apgorm_testing_database") 68 | 69 | if db.must_create_migrations(): 70 | db.create_migrations() 71 | if await db.must_apply_migrations(): 72 | await db.apply_migrations() 73 | 74 | try: 75 | await _main() 76 | finally: 77 | await db.cleanup() 78 | -------------------------------------------------------------------------------- /examples/basic/migrations/0000/describe.json: -------------------------------------------------------------------------------- 1 | { 2 | "tables": [ 3 | { 4 | "name": "users", 5 | "fields": [ 6 | { 7 | "name": "username", 8 | "type_": "VARCHAR(32)", 9 | "not_null": true 10 | }, 11 | { 12 | "name": "nickname", 13 | "type_": "VARCHAR(32)", 14 | "not_null": false 15 | } 16 | ], 17 | "fk_constraints": [], 18 | "pk_constraint": { 19 | "name": "_users_username_primary_key", 20 | "raw_sql": "CONSTRAINT _users_username_primary_key PRIMARY KEY ( username )" 21 | }, 22 | "unique_constraints": [], 23 | "check_constraints": [], 24 | "exclude_constraints": [] 25 | }, 26 | { 27 | "name": "_migrations", 28 | "fields": [ 29 | { 30 | "name": "id_", 31 | "type_": "INTEGER", 32 | "not_null": true 33 | } 34 | ], 35 | "fk_constraints": [], 36 | "pk_constraint": { 37 | "name": "__migrations_id__primary_key", 38 | "raw_sql": "CONSTRAINT __migrations_id__primary_key PRIMARY KEY ( id_ )" 39 | }, 40 | "unique_constraints": [], 41 | "check_constraints": [], 42 | "exclude_constraints": [] 43 | } 44 | ], 45 | "indexes": [] 46 | } -------------------------------------------------------------------------------- /examples/basic/migrations/0000/migrations.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE users (); 2 | CREATE TABLE _migrations (); 3 | ALTER TABLE users ADD COLUMN username VARCHAR(32); 4 | ALTER TABLE users ADD COLUMN nickname VARCHAR(32); 5 | ALTER TABLE _migrations ADD COLUMN id_ INTEGER; 6 | ALTER TABLE users ALTER COLUMN username SET NOT NULL; 7 | ALTER TABLE _migrations ALTER COLUMN id_ SET NOT NULL; 8 | ALTER TABLE users ADD CONSTRAINT _users_username_primary_key PRIMARY KEY ( username ); 9 | ALTER TABLE _migrations ADD CONSTRAINT __migrations_id__primary_key PRIMARY KEY ( id_ ); -------------------------------------------------------------------------------- /examples/converters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/examples/converters/__init__.py -------------------------------------------------------------------------------- /examples/converters/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | 5 | from .main import main 6 | 7 | if __name__ == "__main__": 8 | asyncio.run(main()) 9 | -------------------------------------------------------------------------------- /examples/converters/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import IntEnum 4 | from pathlib import Path 5 | 6 | import apgorm 7 | from apgorm.types 
import Int, VarChar 8 | 9 | 10 | class PlayerStatus(IntEnum): 11 | NOT_FINISHED = 0 12 | WINNER = 1 13 | LOSER = 2 14 | DROPPED = 3 15 | 16 | 17 | class PlayerStatusConverter(apgorm.Converter[int, PlayerStatus]): 18 | def to_stored(self, value: PlayerStatus) -> int: 19 | return int(value) 20 | 21 | def from_stored(self, value: int) -> PlayerStatus: 22 | return PlayerStatus(value) 23 | 24 | 25 | class Player(apgorm.Model): 26 | username = VarChar(32).field() 27 | status = Int().field(default=0).with_converter(PlayerStatusConverter) 28 | 29 | primary_key = (username,) 30 | 31 | 32 | class Database(apgorm.Database): 33 | players = Player 34 | 35 | 36 | async def _main() -> None: 37 | player = await Player( 38 | username="Circuit", status=PlayerStatus.NOT_FINISHED 39 | ).create() 40 | print("Created player", player) 41 | 42 | print("player.status:", player.status) # PlayerStatus.NOT_FINISHED 43 | print("players.status raw value:", player._raw_values["status"]) # 0 44 | 45 | player.status = PlayerStatus.WINNER 46 | await player.save() 47 | print(f"Set status to {player.status!r}") 48 | 49 | await player.delete() 50 | 51 | 52 | async def main() -> None: 53 | db = Database(Path("examples/converters/migrations")) 54 | await db.connect(database="apgorm_testing_database") 55 | 56 | if db.must_create_migrations(): 57 | db.create_migrations() 58 | if await db.must_apply_migrations(): 59 | await db.apply_migrations() 60 | 61 | try: 62 | await _main() 63 | finally: 64 | await db.cleanup() 65 | -------------------------------------------------------------------------------- /examples/converters/migrations/0000/describe.json: -------------------------------------------------------------------------------- 1 | { 2 | "tables": [ 3 | { 4 | "name": "players", 5 | "fields": [ 6 | { 7 | "name": "username", 8 | "type_": "VARCHAR(32)", 9 | "not_null": true 10 | }, 11 | { 12 | "name": "status", 13 | "type_": "INTEGER", 14 | "not_null": true 15 | } 16 | ], 17 | "fk_constraints": [], 18 | "pk_constraint": { 19 | "name": "_players_username_primary_key", 20 | "raw_sql": "CONSTRAINT _players_username_primary_key PRIMARY KEY ( username )" 21 | }, 22 | "unique_constraints": [], 23 | "check_constraints": [], 24 | "exclude_constraints": [] 25 | }, 26 | { 27 | "name": "_migrations", 28 | "fields": [ 29 | { 30 | "name": "id_", 31 | "type_": "INTEGER", 32 | "not_null": true 33 | } 34 | ], 35 | "fk_constraints": [], 36 | "pk_constraint": { 37 | "name": "__migrations_id__primary_key", 38 | "raw_sql": "CONSTRAINT __migrations_id__primary_key PRIMARY KEY ( id_ )" 39 | }, 40 | "unique_constraints": [], 41 | "check_constraints": [], 42 | "exclude_constraints": [] 43 | } 44 | ], 45 | "indexes": [] 46 | } -------------------------------------------------------------------------------- /examples/converters/migrations/0000/migrations.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE players (); 2 | CREATE TABLE _migrations (); 3 | ALTER TABLE players ADD COLUMN username VARCHAR(32); 4 | ALTER TABLE players ADD COLUMN status INTEGER; 5 | ALTER TABLE _migrations ADD COLUMN id_ INTEGER; 6 | ALTER TABLE players ALTER COLUMN username SET NOT NULL; 7 | ALTER TABLE players ALTER COLUMN status SET NOT NULL; 8 | ALTER TABLE _migrations ALTER COLUMN id_ SET NOT NULL; 9 | ALTER TABLE players ADD CONSTRAINT _players_username_primary_key PRIMARY KEY ( username ); 10 | ALTER TABLE _migrations ADD CONSTRAINT __migrations_id__primary_key PRIMARY KEY ( id_ ); 
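-- Note: the status column above stays a plain INTEGER because PlayerStatusConverter (examples/converters/main.py) converts between PlayerStatus and int purely on the Python side; no enum type is created in the database.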
-------------------------------------------------------------------------------- /examples/manytomany/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/examples/manytomany/__init__.py -------------------------------------------------------------------------------- /examples/manytomany/__main__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from .main import main 4 | 5 | if __name__ == "__main__": 6 | asyncio.run(main()) 7 | -------------------------------------------------------------------------------- /examples/manytomany/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import random 4 | from pathlib import Path 5 | 6 | import apgorm 7 | from apgorm import ForeignKey, ManyToMany 8 | from apgorm.types import Int, Serial, VarChar 9 | 10 | 11 | class User(apgorm.Model): 12 | name = VarChar(32).field() 13 | 14 | games = ManyToMany["Game", "Player"]( 15 | "name", "players.username", "players.gameid", "games.id_" 16 | ) 17 | 18 | primary_key = (name,) 19 | 20 | 21 | class Game(apgorm.Model): 22 | id_ = Serial().field() 23 | 24 | users = ManyToMany[User, "Player"]( 25 | "id_", "players.gameid", "players.username", "users.name" 26 | ) 27 | 28 | primary_key = (id_,) 29 | 30 | 31 | class Player(apgorm.Model): 32 | username = VarChar(32).field() 33 | gameid = Int().field() 34 | 35 | username_fk = ForeignKey(username, User.name) 36 | gameid_fk = ForeignKey(gameid, Game.id_) 37 | 38 | primary_key = (username, gameid) 39 | 40 | 41 | class Database(apgorm.Database): 42 | users = User 43 | games = Game 44 | players = Player 45 | 46 | 47 | async def _main(db: Database) -> None: 48 | # migrations 49 | if db.must_create_migrations(): 50 | db.create_migrations() 51 | if await db.must_apply_migrations(): 52 | await db.apply_migrations() 53 | 54 | # delete stuff (don't want to run into unique violations) 55 | await User.delete_query().execute() 56 | await Game.delete_query().execute() 57 | 58 | # create some stuff 59 | usernames = ["Circuit", "User 2", "James", "Superman", "Batman"] 60 | for un in usernames: 61 | await User(name=un).create() 62 | 63 | for _ in range(0, 5): 64 | await Game().create() 65 | 66 | all_users = await User.fetch_query().fetchmany() 67 | all_games = await Game.fetch_query().fetchmany() 68 | 69 | # add users to games randomly 70 | for g in all_games: 71 | for u in all_users: 72 | if random.randint(0, 1) == 1: 73 | await g.users.add(u) # or, await u.games.add(g) 74 | 75 | # for each user, get all their games 76 | for user in all_users: 77 | print(f"Games {user} is in:") 78 | _games = await user.games.fetchmany() 79 | num = await user.games.count() 80 | assert num == len(_games) 81 | print(" - " + "\n - ".join([repr(g) for g in _games])) 82 | 83 | print("") 84 | 85 | # for each game, get all the users 86 | for g in all_games: 87 | print(f"Users for game {g}:") 88 | _users = await g.users.fetchmany() 89 | print(" - " + "\n - ".join([repr(u) for u in _users])) 90 | 91 | # remove all instances of player (remove all games from all users) 92 | await Player.delete_query().execute() 93 | print("\nRemoved all games from all users") 94 | 95 | # add all games to Circuit 96 | circuit = await User.fetch(name="Circuit") 97 | [await circuit.games.add(g) for g in all_games] 98 | 99 | print(f"\nAdded {all_games} to 
{circuit}") 100 | print("Circuits Games: {}\n".format(await circuit.games.fetchmany())) 101 | 102 | # remove two random games: 103 | to_remove = random.choices(list(all_games), k=2) 104 | [await circuit.games.remove(r) for r in to_remove] 105 | 106 | print(f"Removed {to_remove} from {circuit}") 107 | print("Circuits Games: {}\n".format(await circuit.games.fetchmany())) 108 | 109 | # clear all games from circuit 110 | await circuit.games.clear() 111 | 112 | print(f"Cleared all games from {circuit}") 113 | print("Circuits Games: {}\n".format(await circuit.games.fetchmany())) 114 | 115 | 116 | async def main() -> None: 117 | db = Database(Path("examples/manytomany/migrations")) 118 | await db.connect(database="apgorm_testing_database") 119 | try: 120 | await _main(db) 121 | finally: 122 | await db.cleanup() 123 | -------------------------------------------------------------------------------- /examples/manytomany/migrations/0000/describe.json: -------------------------------------------------------------------------------- 1 | { 2 | "tables": [ 3 | { 4 | "name": "users", 5 | "fields": [ 6 | { 7 | "name": "name", 8 | "type_": "VARCHAR(32)", 9 | "not_null": true 10 | } 11 | ], 12 | "fk_constraints": [], 13 | "pk_constraint": { 14 | "name": "_users_name_primary_key", 15 | "raw_sql": "CONSTRAINT _users_name_primary_key PRIMARY KEY ( name )" 16 | }, 17 | "unique_constraints": [], 18 | "check_constraints": [], 19 | "exclude_constraints": [] 20 | }, 21 | { 22 | "name": "games", 23 | "fields": [ 24 | { 25 | "name": "id_", 26 | "type_": "SERIAL", 27 | "not_null": true 28 | } 29 | ], 30 | "fk_constraints": [], 31 | "pk_constraint": { 32 | "name": "_games_id__primary_key", 33 | "raw_sql": "CONSTRAINT _games_id__primary_key PRIMARY KEY ( id_ )" 34 | }, 35 | "unique_constraints": [], 36 | "check_constraints": [], 37 | "exclude_constraints": [] 38 | }, 39 | { 40 | "name": "players", 41 | "fields": [ 42 | { 43 | "name": "username", 44 | "type_": "VARCHAR(32)", 45 | "not_null": true 46 | }, 47 | { 48 | "name": "gameid", 49 | "type_": "INTEGER", 50 | "not_null": true 51 | } 52 | ], 53 | "fk_constraints": [ 54 | { 55 | "name": "username_fk", 56 | "raw_sql": "CONSTRAINT username_fk FOREIGN KEY ( username ) REFERENCES users ( name ) MATCH SIMPLE ON DELETE CASCADE ON UPDATE CASCADE" 57 | }, 58 | { 59 | "name": "gameid_fk", 60 | "raw_sql": "CONSTRAINT gameid_fk FOREIGN KEY ( gameid ) REFERENCES games ( id_ ) MATCH SIMPLE ON DELETE CASCADE ON UPDATE CASCADE" 61 | } 62 | ], 63 | "pk_constraint": { 64 | "name": "_players_username_gameid_primary_key", 65 | "raw_sql": "CONSTRAINT _players_username_gameid_primary_key PRIMARY KEY ( username , gameid )" 66 | }, 67 | "unique_constraints": [], 68 | "check_constraints": [], 69 | "exclude_constraints": [] 70 | }, 71 | { 72 | "name": "_migrations", 73 | "fields": [ 74 | { 75 | "name": "id_", 76 | "type_": "INTEGER", 77 | "not_null": true 78 | } 79 | ], 80 | "fk_constraints": [], 81 | "pk_constraint": { 82 | "name": "__migrations_id__primary_key", 83 | "raw_sql": "CONSTRAINT __migrations_id__primary_key PRIMARY KEY ( id_ )" 84 | }, 85 | "unique_constraints": [], 86 | "check_constraints": [], 87 | "exclude_constraints": [] 88 | } 89 | ], 90 | "indexes": [] 91 | } -------------------------------------------------------------------------------- /examples/manytomany/migrations/0000/migrations.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE users (); 2 | CREATE TABLE games (); 3 | CREATE TABLE players (); 4 | CREATE TABLE 
_migrations (); 5 | ALTER TABLE users ADD COLUMN name VARCHAR(32); 6 | ALTER TABLE games ADD COLUMN id_ SERIAL; 7 | ALTER TABLE players ADD COLUMN username VARCHAR(32); 8 | ALTER TABLE players ADD COLUMN gameid INTEGER; 9 | ALTER TABLE _migrations ADD COLUMN id_ INTEGER; 10 | ALTER TABLE users ALTER COLUMN name SET NOT NULL; 11 | ALTER TABLE games ALTER COLUMN id_ SET NOT NULL; 12 | ALTER TABLE players ALTER COLUMN username SET NOT NULL; 13 | ALTER TABLE players ALTER COLUMN gameid SET NOT NULL; 14 | ALTER TABLE _migrations ALTER COLUMN id_ SET NOT NULL; 15 | ALTER TABLE users ADD CONSTRAINT _users_name_primary_key PRIMARY KEY ( name ); 16 | ALTER TABLE games ADD CONSTRAINT _games_id__primary_key PRIMARY KEY ( id_ ); 17 | ALTER TABLE players ADD CONSTRAINT _players_username_gameid_primary_key PRIMARY KEY ( username , gameid ); 18 | ALTER TABLE _migrations ADD CONSTRAINT __migrations_id__primary_key PRIMARY KEY ( id_ ); 19 | ALTER TABLE players ADD CONSTRAINT username_fk FOREIGN KEY ( username ) REFERENCES users ( name ) MATCH SIMPLE ON DELETE CASCADE ON UPDATE CASCADE; 20 | ALTER TABLE players ADD CONSTRAINT gameid_fk FOREIGN KEY ( gameid ) REFERENCES games ( id_ ) MATCH SIMPLE ON DELETE CASCADE ON UPDATE CASCADE; -------------------------------------------------------------------------------- /examples/validators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/circuitsacul/apgorm/db1c0045141797b0128b5321d5f37aaf00065e71/examples/validators/__init__.py -------------------------------------------------------------------------------- /examples/validators/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .main import main 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /examples/validators/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import apgorm 6 | from apgorm.exceptions import InvalidFieldValue 7 | from apgorm.types import VarChar 8 | 9 | 10 | def email_validator(email: str | None) -> bool: 11 | if email is None or email.endswith("@gmail.com"): 12 | return True 13 | raise InvalidFieldValue("Email must be None or end with @gmail.com") 14 | 15 | 16 | class User(apgorm.Model): 17 | username = VarChar(32).field() 18 | email = VarChar(32).nullablefield() 19 | email.add_validator(email_validator) 20 | 21 | primary_key = (username,) 22 | 23 | 24 | class Database(apgorm.Database): 25 | users = User 26 | 27 | 28 | def main() -> None: 29 | Database(Path("examples/validators/migrations")) 30 | 31 | user = User(username="Circuit") 32 | try: 33 | user.email = "not an email" 34 | except InvalidFieldValue: 35 | print("'not an email' is not a valid email!") 36 | 37 | user.email = "something@gmail.com" 38 | print(f"{user.email} is a valid email!") 39 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_return_any=True 3 | warn_unused_configs=True 4 | exclude=tests 5 | strict=True 6 | 7 | [mypy-asyncpg.*] 8 | ignore_missing_imports=True 9 | -------------------------------------------------------------------------------- /noxfile.py: 
-------------------------------------------------------------------------------- 1 | import nox 2 | 3 | 4 | @nox.session 5 | def pytest_and_mypy(session: nox.Session) -> None: 6 | session.install("poetry") 7 | session.run("poetry", "install") 8 | 9 | session.run("mypy", ".") 10 | session.run( 11 | "poetry", 12 | "run", 13 | "python", 14 | "-m", 15 | "pytest", 16 | "--cov=apgorm/", 17 | "--cov-report=xml", 18 | "--asyncio-mode=auto", 19 | ) 20 | 21 | 22 | @nox.session 23 | def flake8(session: nox.Session) -> None: 24 | session.install("flake8") 25 | session.run("flake8") 26 | 27 | 28 | @nox.session 29 | def black(session: nox.Session) -> None: 30 | session.install("black") 31 | session.run("black", ".", "--check") 32 | 33 | 34 | @nox.session 35 | def isort(session: nox.Session) -> None: 36 | session.install("isort") 37 | session.run("isort", ".", "--check") 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length=79 3 | skip-magic-trailing-comma=true 4 | 5 | [tool.poetry] 6 | name = "apgorm" 7 | version = "0.0.0" 8 | description = "A fully type-checked asynchronous ORM wrapped around asyncpg." 9 | authors = ["Circuit"] 10 | license = "MIT" 11 | readme = "README.md" 12 | homepage = "https://github.com/TrigonDev/apgorm" 13 | repository = "https://github.com/TrigonDev/apgorm" 14 | documentation = "https://github.com/TrigonDev/apgorm/wiki" 15 | keywords = ["postgres", "postgresql", "asyncpg", "asyncio", "orm"] 16 | 17 | [tool.poetry.dependencies] 18 | python = "^3.8" 19 | asyncpg = ">=0.25,<0.29" 20 | pydantic = ">=1.9,<3.0" 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | black = "^23.3.0" 24 | isort = "^5.12.0" 25 | mypy = "^1.3.0" 26 | flake8 = "^5.0.4" 27 | pytest = "^7.3.2" 28 | pytest-asyncio = "^0.21.0" 29 | pytest-mock = "^3.11.1" 30 | pytest-cov = "^4.1.0" 31 | nox = "^2023.4.22" 32 | 33 | [build-system] 34 | requires = ["poetry-core>=1.0.0"] 35 | build-backend = "poetry.core.masonry.api" 36 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from typing import Generator 5 | 6 | import pytest 7 | 8 | 9 | @pytest.fixture(scope="package") 10 | def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: 11 | loop = asyncio.get_event_loop_policy().new_event_loop() 12 | yield loop 13 | loop.close() 14 | -------------------------------------------------------------------------------- /tests/grouptests/test_migrations.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import shutil 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Any 8 | from unittest.mock import AsyncMock, Mock 9 | 10 | import asyncpg 11 | import pytest 12 | from pytest_mock import MockerFixture 13 | 14 | import apgorm 15 | from apgorm.types import VarChar 16 | 17 | 18 | class EmptyDB(apgorm.Database): 19 | pass 20 | 21 | 22 | class User(apgorm.Model): 23 | name = VarChar(32).field() 24 | primary_key = (name,) 25 | 26 | 27 | class UserDB(apgorm.Database): 28 | user = User 29 | 30 | indexes = [apgorm.Index(User, User.name)] 31 | 32 | 33 | MIG_PATH = Path("tests/migrations") 34 | EDB = EmptyDB(MIG_PATH) 35 | UDB = 
UserDB(MIG_PATH) 36 | 37 | 38 | def erase_migrations(): 39 | if MIG_PATH.exists(): 40 | shutil.rmtree(MIG_PATH) 41 | 42 | 43 | @pytest.fixture 44 | def patch_fetch(mocker: MockerFixture): 45 | ret = asyncio.Future() 46 | mock = mocker.Mock() 47 | mock.fetch_query.return_value.fetchmany.return_value = ret 48 | 49 | mocker.patch.object(UDB, "_migrations", mock) 50 | mocker.patch.object(EDB, "_migrations", mock) 51 | 52 | return ret, mock.fetch_query.return_value.fetchmany 53 | 54 | 55 | @dataclass 56 | class PatchedDBMethods: 57 | pool: Mock 58 | con: Mock 59 | curr: Mock 60 | 61 | 62 | @pytest.fixture 63 | def patch_pool(mocker: MockerFixture) -> PatchedDBMethods: 64 | def areturn(ret: Any = None) -> AsyncMock: 65 | func = mocker.AsyncMock() 66 | func.return_value = ret 67 | return func 68 | 69 | con = mocker.Mock() 70 | tc = mocker.AsyncMock() 71 | tc.__aenter__.return_value = None 72 | tc.__aexit__.return_value = None 73 | con.transaction.return_value = tc 74 | con.execute = areturn() 75 | con.fetchrow = areturn({"hello": "world"}) 76 | con.fetchmany = areturn([{"hello": "world"}]) 77 | con.fetchval = areturn("hello, world") 78 | 79 | curr = mocker.Mock() 80 | con.cursor.return_value = curr 81 | 82 | pool = mocker.Mock() 83 | pac = mocker.AsyncMock() 84 | pac.__aenter__.return_value = con 85 | pac.__aexit__.return_value = None 86 | pool.acquire.return_value = pac 87 | pool.close = areturn() 88 | 89 | mocker.patch.object(UDB, "pool", pool) 90 | mocker.patch.object(EDB, "pool", pool) 91 | mocker.patch.object(UDB._migrations.database, "pool", pool) 92 | mocker.patch.object(EDB._migrations.database, "pool", pool) 93 | 94 | return PatchedDBMethods(pool, con, curr) 95 | 96 | 97 | @pytest.mark.asyncio 98 | async def test_create_and_apply_migrations(patch_pool, mocker): 99 | erase_migrations() 100 | 101 | # Before migration creation 102 | assert not MIG_PATH.exists() 103 | assert len(EDB.load_all_migrations()) == 0 104 | assert EDB.load_last_migration() is None 105 | assert EDB.must_create_migrations() 106 | 107 | # Create migrations 108 | EDB.create_migrations() 109 | assert not EDB.must_create_migrations() 110 | assert UDB.must_create_migrations() 111 | 112 | UDB.create_migrations() 113 | assert EDB.must_create_migrations() 114 | assert not UDB.must_create_migrations() 115 | 116 | EDB.create_migrations() 117 | UDB.create_migrations() 118 | 119 | # Test trying to create migration when none are needed 120 | with pytest.raises(apgorm.exceptions.NoMigrationsToCreate): 121 | UDB.create_migrations() 122 | 123 | # Test with override 124 | UDB.create_migrations(allow_empty=True) 125 | 126 | # After creation (there should be three, the last should be empty) 127 | assert MIG_PATH.exists() 128 | assert len(EDB.load_all_migrations()) == 5 129 | assert EDB.load_last_migration().migration_id == 4 130 | 131 | assert EDB.load_migration_from_id(1).migration_id == 1 132 | 133 | assert ( 134 | EDB.load_migration_from_id(3).describe 135 | == EDB.load_migration_from_id(4).describe 136 | ) 137 | assert EDB.load_migration_from_id(4).migrations == "" 138 | 139 | # test applying the three migrations 140 | def _make_fut( 141 | ret: Any | None = None, exc: Any | None = None 142 | ) -> asyncio.Future: 143 | fut = asyncio.Future() 144 | if exc: 145 | fut.set_exception(exc) 146 | else: 147 | fut.set_result(ret) 148 | return fut 149 | 150 | # patch _migrations.fetch 151 | f = mocker.patch.object(UDB, "_migrations") 152 | f.return_value.create.return_value = _make_fut() 153 | f.fetch.return_value = _make_fut( 154 | 
exc=apgorm.exceptions.ModelNotFound(User, {}) 155 | ) 156 | 157 | mig = mocker.Mock() 158 | mig.id_ = 0 159 | f.fetch_query.return_value.fetchmany.return_value = _make_fut([mig]) 160 | 161 | await UDB.apply_migrations() 162 | 163 | 164 | @pytest.mark.asyncio 165 | async def test_load_unapplied_none(patch_fetch: tuple[asyncio.Future, Mock]): 166 | erase_migrations() 167 | ret, func = patch_fetch 168 | ret.set_result([]) 169 | 170 | UDB.create_migrations() 171 | 172 | assert ( 173 | am := UDB.load_all_migrations() 174 | ) == await UDB.load_unapplied_migrations() 175 | assert len(am) == 1 176 | assert am[0].migration_id == 0 177 | 178 | 179 | @pytest.mark.asyncio 180 | async def test_load_unapplied_all( 181 | patch_fetch: tuple[asyncio.Future, Mock], mocker: MockerFixture 182 | ): 183 | erase_migrations() 184 | ret, func = patch_fetch 185 | mig = mocker.Mock() 186 | mig.id_ = 0 187 | ret.set_result([mig]) 188 | 189 | assert len(await UDB.load_unapplied_migrations()) == 0 190 | 191 | 192 | @pytest.mark.asyncio 193 | async def test_load_unapplied_exc(patch_fetch: tuple[asyncio.Future, Mock]): 194 | erase_migrations() 195 | UDB.create_migrations() 196 | 197 | ret, func = patch_fetch 198 | ret.set_exception(asyncpg.UndefinedTableError) 199 | 200 | assert len(await UDB.load_unapplied_migrations()) == 1 201 | func.assert_called_once_with() 202 | 203 | 204 | @pytest.mark.asyncio 205 | async def test_must_apply_true(patch_fetch: tuple[asyncio.Future, Mock]): 206 | erase_migrations() 207 | UDB.create_migrations() 208 | 209 | ret, func = patch_fetch 210 | ret.set_result([]) 211 | 212 | assert await UDB.must_apply_migrations() 213 | 214 | 215 | @pytest.mark.asyncio 216 | async def test_must_apply_false( 217 | patch_fetch: tuple[asyncio.Future, Mock], mocker: MockerFixture 218 | ): 219 | erase_migrations() 220 | UDB.create_migrations() 221 | 222 | ret, func = patch_fetch 223 | mig = mocker.Mock() 224 | mig.id_ = 0 225 | ret.set_result([mig]) 226 | 227 | assert not await UDB.must_apply_migrations() 228 | 229 | 230 | @pytest.mark.asyncio 231 | async def test_apply_migrations( 232 | mocker: MockerFixture, patch_fetch: tuple[asyncio.Future, Mock] 233 | ): 234 | erase_migrations() 235 | UDB.create_migrations() 236 | 237 | ret, _ = patch_fetch 238 | ret.set_result([]) 239 | func = mocker.patch.object(UDB, "_apply_migration") 240 | func.return_value = asyncio.Future() 241 | func.return_value.set_result(None) 242 | 243 | await UDB.apply_migrations() 244 | 245 | func.assert_called_once() 246 | -------------------------------------------------------------------------------- /tests/unittests/sql/generators/test_query.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | import apgorm 6 | from apgorm.sql.generators import query 7 | from apgorm.sql.sql import Block 8 | from apgorm.types import VarChar 9 | 10 | 11 | class MyModel(apgorm.Model): 12 | somefield = VarChar(32).field() 13 | primary_key = (somefield,) 14 | 15 | 16 | class Database(apgorm.Database): 17 | mymodel = MyModel 18 | 19 | 20 | DB = Database(None) 21 | 22 | 23 | @pytest.mark.parametrize("model", [MyModel, MyModel(), apgorm.raw("mymodel")]) 24 | def test_select_table_model(model): 25 | q = query.select(from_=model) 26 | assert q.render() == ("SELECT * FROM mymodel", []) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "fields", 31 | [ 32 | [apgorm.raw("somefield"), MyModel.somefield], 33 | [apgorm.raw("somefield")], 34 | [MyModel.somefield], 35 | ], 36 | 
) 37 | def test_select_table_fields(fields): 38 | q = query.select(from_=MyModel, fields=fields) 39 | fs = [ 40 | f.render_no_params() if isinstance(f, Block) else f.full_name 41 | for f in fields 42 | ] 43 | fs = " , ".join(fs) 44 | assert q.render() == (f"SELECT ( {fs} ) FROM mymodel", []) 45 | 46 | 47 | def test_where_logic(): 48 | q = query.select(from_=MyModel, where=apgorm.raw("name=$1")) 49 | assert q.render() == ("SELECT * FROM mymodel WHERE ( name=$1 )", []) 50 | 51 | 52 | def test_count(): 53 | q = query.select(from_=MyModel, count=True) 54 | assert q.render() == ("SELECT COUNT(*) FROM mymodel", []) 55 | -------------------------------------------------------------------------------- /tests/unittests/sql/test_query_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from apgorm import ( 6 | DeleteQueryBuilder, 7 | FetchQueryBuilder, 8 | InsertQueryBuilder, 9 | UpdateQueryBuilder, 10 | ) 11 | from apgorm.sql.query_builder import _dict_model_converter 12 | 13 | ALL_BUILDERS = [ 14 | DeleteQueryBuilder, 15 | FetchQueryBuilder, 16 | InsertQueryBuilder, 17 | UpdateQueryBuilder, 18 | ] 19 | FILTER_BUILDERS = [DeleteQueryBuilder, FetchQueryBuilder, UpdateQueryBuilder] 20 | 21 | 22 | def test_model_converter(mocker): 23 | c = _dict_model_converter(m := mocker.Mock()) 24 | ret = c({"hello": "world"}) 25 | 26 | assert m._from_raw.return_value is ret 27 | m._from_raw.assert_called_once_with(hello="world") 28 | 29 | 30 | @pytest.mark.parametrize("type_", ALL_BUILDERS) 31 | def test_init_no_con(type_, mocker): 32 | q = type_(m := mocker.Mock()) 33 | assert q.model is m 34 | assert q.con is m.database 35 | 36 | 37 | @pytest.mark.parametrize("type_", ALL_BUILDERS) 38 | def test_init_with_con(type_, mocker): 39 | q = type_(m := mocker.Mock(), c := mocker.Mock()) 40 | assert q.model is m 41 | assert q.con is c 42 | 43 | 44 | @pytest.mark.parametrize("type_", FILTER_BUILDERS) 45 | def test_where_logic(type_, mocker): 46 | q = type_(mocker.Mock()) 47 | q.where("PARAM", key="value") 48 | 49 | assert q._where_logic().render() == ( 50 | "$1 AND ( key = $2 )", 51 | ["PARAM", "value"], 52 | ) 53 | 54 | 55 | @pytest.mark.parametrize("type_", FILTER_BUILDERS) 56 | def test_where_logic_empty(type_, mocker): 57 | q = type_(mocker.Mock()) 58 | 59 | assert q._where_logic() is None 60 | 61 | 62 | @pytest.mark.parametrize("type_", FILTER_BUILDERS) 63 | def test_where_returns_self(type_, mocker): 64 | q = type_(mocker.Mock()) 65 | nq = q.where() 66 | 67 | assert q is nq 68 | 69 | 70 | @pytest.mark.parametrize("reverse", [True, False]) 71 | def test_fqb_order_by(mocker, reverse): 72 | q = FetchQueryBuilder(mocker.Mock()) 73 | q.order_by("hello, world", reverse=reverse) 74 | 75 | assert q._reverse is reverse 76 | assert q._order_by_logic == "hello, world" 77 | 78 | 79 | def test_fqb_order_by_returns_self(mocker): 80 | q = FetchQueryBuilder(mocker.Mock()) 81 | nq = q.order_by("hello, world") 82 | 83 | assert q is nq 84 | 85 | 86 | def test_fqb_exists(mocker): 87 | q = FetchQueryBuilder(m := mocker.Mock()) 88 | m.tablename = "tablename" 89 | 90 | # f = mocker.patch.object(q, "_get_block") 91 | # f.return_value = apgorm.raw("hello") 92 | 93 | assert q.exists().render() == ("EXISTS ( SELECT * FROM tablename )", []) 94 | 95 | 96 | @pytest.mark.asyncio 97 | async def test_fqb_fetchmany_exc(mocker): 98 | q = FetchQueryBuilder(mocker.Mock()) 99 | 100 | with pytest.raises(TypeError): 101 | await 
q.fetchmany("hello") 102 | 103 | 104 | @pytest.mark.asyncio 105 | async def test_fqb_fetchmany(mocker): 106 | q = FetchQueryBuilder(m := mocker.Mock(), c := mocker.AsyncMock()) 107 | c.fetchmany.return_value = [{"hello": "world"}] 108 | 109 | ll = await q.fetchmany() 110 | cnv = ll[0] 111 | assert ll._data == [{"hello": "world"}] 112 | assert cnv is m._from_raw.return_value 113 | m._from_raw.assert_called_once_with(hello="world") 114 | -------------------------------------------------------------------------------- /tests/unittests/test_connection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from apgorm import Connection, LazyList, Pool 6 | 7 | 8 | @pytest.fixture 9 | def async_mocked_con(mocker): 10 | subcon = mocker.AsyncMock() 11 | subcon.fetchrow.return_value = {"hello": "world"} 12 | subcon.fetch.return_value = [{"hello": "world"}] 13 | subcon.fetchval.return_value = "hello, world" 14 | return subcon 15 | 16 | 17 | @pytest.mark.asyncio 18 | async def test_connection_execute(async_mocked_con): 19 | con = Connection(async_mocked_con) 20 | 21 | await con.execute("SELECT $1", [1]) 22 | 23 | async_mocked_con.execute.assert_called_once_with("SELECT $1", 1) 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_connection_fetchrow(async_mocked_con): 28 | con = Connection(async_mocked_con) 29 | 30 | res = await con.fetchrow("SELECT $1", [1]) 31 | 32 | assert res == {"hello": "world"} 33 | async_mocked_con.fetchrow.assert_called_once_with("SELECT $1", 1) 34 | 35 | 36 | @pytest.mark.asyncio 37 | async def test_connection_fetchmany(async_mocked_con): 38 | con = Connection(async_mocked_con) 39 | 40 | res = await con.fetchmany("SELECT $1", [1]) 41 | 42 | assert isinstance(res, LazyList) 43 | assert list(res) == [{"hello": "world"}] 44 | 45 | async_mocked_con.fetch.assert_called_once_with("SELECT $1", 1) 46 | 47 | 48 | @pytest.mark.asyncio 49 | async def test_connection_fetchval(async_mocked_con): 50 | con = Connection(async_mocked_con) 51 | 52 | res = await con.fetchval("SELECT $1", [1]) 53 | 54 | assert res == "hello, world" 55 | async_mocked_con.fetchval.assert_called_once_with("SELECT $1", 1) 56 | 57 | 58 | def test_connection_cursor(mocker): 59 | mocked_con = mocker.Mock() 60 | con = Connection(mocked_con) 61 | 62 | con.cursor("SELECT $1", [1]) 63 | 64 | mocked_con.cursor.assert_called_once_with("SELECT $1", 1) 65 | 66 | 67 | def test_connection_transaction(mocker): 68 | mocked_con = mocker.Mock() 69 | con = Connection(mocked_con) 70 | 71 | con.transaction() 72 | 73 | mocked_con.transaction.assert_called_once() 74 | 75 | 76 | @pytest.mark.asyncio 77 | async def test_pool(mocker, async_mocked_con): 78 | mocked_pac = mocker.AsyncMock() 79 | mocked_pac.__aenter__.return_value = async_mocked_con 80 | 81 | mocked_pool = mocker.Mock() 82 | mocked_pool.acquire.return_value = mocked_pac 83 | 84 | pool = Pool(mocked_pool) 85 | 86 | async with pool.acquire() as yielded_con: 87 | pass 88 | 89 | pool.close() 90 | 91 | assert yielded_con.con is async_mocked_con 92 | mocked_pool.acquire.assert_called_once() 93 | mocked_pool.close.assert_called_once() 94 | mocked_pac.__aenter__.assert_called_once() 95 | mocked_pac.__aexit__.assert_called_once() 96 | -------------------------------------------------------------------------------- /tests/unittests/test_converter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum 
import IntEnum, IntFlag 4 | 5 | from apgorm import IntEFConverter 6 | 7 | 8 | class IE(IntEnum): 9 | ONE = 1 10 | TWO = 2 11 | THREE = 3 12 | 13 | 14 | class IF(IntFlag): 15 | ONE = 1 << 0 16 | TWO = 1 << 1 17 | THREE = 1 << 3 18 | 19 | 20 | def test_intefconverter_as_intenum() -> None: 21 | iec = IntEFConverter(IE) 22 | assert iec.from_stored(3) is IE.THREE 23 | assert iec.to_stored(IE.THREE) == 3 24 | 25 | 26 | def test_intefconverter_as_intflag() -> None: 27 | ifc = IntEFConverter(IF) 28 | assert ifc.from_stored((IF.ONE | IF.TWO).value) == (IF.ONE | IF.TWO) 29 | assert ifc.to_stored(IF.ONE | IF.THREE) == (IF.ONE | IF.THREE).value 30 | -------------------------------------------------------------------------------- /tests/unittests/test_database.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Any 7 | from unittest.mock import AsyncMock, Mock 8 | 9 | import asyncpg 10 | import pytest 11 | from pytest_mock import MockerFixture 12 | 13 | import apgorm 14 | from apgorm.types import VarChar 15 | 16 | 17 | class User(apgorm.Model): 18 | name = VarChar(32).field() 19 | nick = VarChar(32).nullablefield() 20 | 21 | nick_unique = apgorm.Exclude((nick, "="), where="nick IS NOT NULL") 22 | 23 | primary_key = (name,) 24 | 25 | 26 | class Database(apgorm.Database): 27 | users = User 28 | 29 | indexes = [apgorm.Index(User, User.nick)] 30 | 31 | 32 | DB = Database(Path("tests/migrations")) 33 | 34 | 35 | def test_updates_models_and_fields(): 36 | assert DB.users.tablename == "users" 37 | assert DB.users.name.name == "name" 38 | assert DB.users.nick.name == "nick" 39 | assert DB.users.nick_unique.name == "nick_unique" 40 | assert DB.users.primary_key[0] is DB.users.name 41 | 42 | 43 | def test_str_folder(): 44 | STR = Database("tests/migrations") 45 | assert isinstance(STR._migrations_folder, Path) 46 | assert STR._migrations_folder == DB._migrations_folder 47 | 48 | 49 | def test_collects_models(): 50 | assert User in DB._all_models 51 | assert User is DB.users 52 | 53 | 54 | def test_indexes_match(): 55 | assert DB.indexes[0].table is DB.users 56 | assert DB.indexes[0].fields[0].name == "nick" 57 | 58 | 59 | def test_describe(): 60 | describe = DB.describe() 61 | 62 | assert len(describe.tables) == 2 63 | assert len(describe.indexes) == 1 64 | 65 | 66 | @dataclass 67 | class PatchedDBMethods: 68 | pool: Mock 69 | con: Mock 70 | curr: Mock 71 | 72 | 73 | @pytest.fixture 74 | def db(mocker: MockerFixture) -> PatchedDBMethods: 75 | def areturn(ret: Any = None) -> AsyncMock: 76 | func = mocker.AsyncMock() 77 | func.return_value = ret 78 | return func 79 | 80 | con = mocker.Mock() 81 | tc = mocker.AsyncMock() 82 | tc.__aenter__.return_value = None 83 | tc.__aexit__.return_value = None 84 | con.transaction.return_value = tc 85 | con.execute = areturn() 86 | con.fetchrow = areturn({"hello": "world"}) 87 | con.fetchmany = areturn([{"hello": "world"}]) 88 | con.fetchval = areturn("hello, world") 89 | 90 | curr = mocker.Mock() 91 | con.cursor.return_value = curr 92 | 93 | pool = mocker.Mock() 94 | pac = mocker.AsyncMock() 95 | pac.__aenter__.return_value = con 96 | pac.__aexit__.return_value = None 97 | pool.acquire.return_value = pac 98 | pool.close = areturn() 99 | 100 | mocker.patch.object(DB, "pool", pool) 101 | 102 | return PatchedDBMethods(pool, con, curr) 103 | 104 | 105 | @pytest.mark.asyncio 106 | async def 
test_connect(mocker: MockerFixture): 107 | cn = mocker.patch.object(asyncpg, "create_pool") 108 | fut = asyncio.Future() 109 | fut.set_result(None) 110 | cn.return_value = fut 111 | 112 | await DB.connect(hello="world") 113 | 114 | cn.assert_called_once_with(hello="world") 115 | assert DB.pool.pool is fut.result() 116 | 117 | 118 | @pytest.mark.asyncio 119 | async def test_cleanup(db: PatchedDBMethods): 120 | await DB.cleanup() 121 | 122 | db.pool.close.assert_called_once_with() 123 | 124 | 125 | def assert_transaction(db: PatchedDBMethods): 126 | db.pool.acquire.assert_called_once_with() 127 | db.pool.acquire.return_value.__aenter__.assert_called_once_with() 128 | db.con.transaction.assert_called_once_with() 129 | db.con.transaction.return_value.__aenter__.assert_called_once_with() 130 | 131 | 132 | @pytest.mark.asyncio 133 | async def test_execute(db: PatchedDBMethods): 134 | ret = await DB.execute("HELLO $1", ["world"]) 135 | 136 | assert_transaction(db) 137 | db.con.execute.assert_called_once_with("HELLO $1", ["world"]) 138 | assert ret is None 139 | 140 | 141 | @pytest.mark.asyncio 142 | async def test_fetchrow(db: PatchedDBMethods): 143 | ret = await DB.fetchrow("HELLO $1", ["world"]) 144 | 145 | assert_transaction(db) 146 | db.con.fetchrow.assert_called_once_with("HELLO $1", ["world"]) 147 | assert ret == db.con.fetchrow.return_value 148 | 149 | 150 | @pytest.mark.asyncio 151 | async def test_fetchmany(db: PatchedDBMethods): 152 | ret = await DB.fetchmany("HELLO $1", ["world"]) 153 | 154 | assert_transaction(db) 155 | db.con.fetchmany.assert_called_once_with("HELLO $1", ["world"]) 156 | assert ret == db.con.fetchmany.return_value 157 | 158 | 159 | @pytest.mark.asyncio 160 | async def test_fetchval(db: PatchedDBMethods): 161 | ret = await DB.fetchval("HELLO $1", ["world"]) 162 | 163 | assert_transaction(db) 164 | db.con.fetchval.assert_called_once_with("HELLO $1", ["world"]) 165 | assert ret == db.con.fetchval.return_value 166 | 167 | 168 | @pytest.mark.asyncio 169 | async def test_cursor(db: PatchedDBMethods): 170 | async with DB.cursor("HELLO $1", ["world"]) as curr: 171 | assert curr is db.curr 172 | 173 | assert_transaction(db) 174 | db.con.cursor.assert_called_once_with("HELLO $1", ["world"]) 175 | 176 | 177 | @pytest.mark.asyncio 178 | async def test_cursor_with_con(db: PatchedDBMethods): 179 | async with DB.pool.acquire() as con: 180 | async with con.transaction(): 181 | async with DB.cursor("HELLO $1", ["world"], con=db.con) as cursor: 182 | assert cursor is db.curr 183 | 184 | assert_transaction(db) 185 | db.con.cursor.assert_called_once_with("HELLO $1", ["world"]) 186 | -------------------------------------------------------------------------------- /tests/unittests/test_field.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockerFixture 3 | 4 | from apgorm import Converter, Field 5 | from apgorm.exceptions import BadArgument, InvalidFieldValue 6 | 7 | 8 | @pytest.fixture(scope="function") 9 | def mock_field(mocker: MockerFixture): 10 | f = Field(mocker.Mock()) 11 | f.sql_type.sql = "SQLTYPE" 12 | f.model = mocker.Mock() 13 | f.model.tablename = "table" 14 | f.name = "field" 15 | return f 16 | 17 | 18 | def test_init(): 19 | Field(None) 20 | with pytest.raises(BadArgument): 21 | Field(None, default=None, default_factory=lambda: 1) 22 | 23 | 24 | def test_full_name(mock_field: Field): 25 | assert mock_field.full_name == "table.field" 26 | 27 | 28 | def test_add_validator(mock_field: Field, 
mocker: MockerFixture): 29 | validator = mocker.Mock() 30 | ret_f = mock_field.add_validator(validator) 31 | 32 | assert ret_f is mock_field 33 | assert validator in mock_field._validators 34 | 35 | 36 | def test_validate(mock_field: Field, mocker: MockerFixture): 37 | works = mocker.Mock() 38 | works.return_value = True 39 | fails = mocker.Mock() 40 | fails.return_value = False 41 | 42 | def fails_another_way(v): 43 | raise InvalidFieldValue("Test") 44 | 45 | assert mock_field.add_validator(works)._validate(None) is None 46 | 47 | with pytest.raises(InvalidFieldValue): 48 | mock_field.add_validator(fails)._validate(None) 49 | 50 | mock_field._validators = [] 51 | try: 52 | mock_field.add_validator(fails_another_way)._validate(None) 53 | except InvalidFieldValue: 54 | pass 55 | else: 56 | assert False, "Raising Exception didn't raise InvalidFieldValue." 57 | 58 | # test that Field.v = value calls validators 59 | validator = mocker.Mock() 60 | mock_field._validators = [validator] 61 | mock_field._validate(1) 62 | validator.assert_called_once_with(1) 63 | 64 | 65 | class MyConv(Converter): 66 | def to_stored(self, value: str) -> int: 67 | return int(value) 68 | 69 | def from_stored(self, value: int) -> str: 70 | return str(value) 71 | 72 | 73 | def test_with_converter(mock_field: Field): 74 | convf_init = mock_field.with_converter(i := MyConv()) 75 | convf_type = mock_field.with_converter(MyConv) 76 | 77 | assert convf_init.converter is i 78 | assert isinstance(convf_type.converter, MyConv) 79 | -------------------------------------------------------------------------------- /tests/unittests/test_indexes.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from apgorm import Database, Index, IndexType, Model 6 | from apgorm.exceptions import BadArgument 7 | from apgorm.types import VarChar 8 | 9 | 10 | class SomeModel(Model): 11 | name = VarChar(32).field() 12 | other = VarChar(32).field() 13 | primary_key = (name,) 14 | 15 | 16 | class MyDb(Database): 17 | some_model = SomeModel 18 | 19 | 20 | DB = MyDb(None) 21 | 22 | 23 | def test_creation(): 24 | # single field 25 | single_index = Index(SomeModel, SomeModel.name) 26 | assert single_index.fields == [SomeModel.name] 27 | 28 | # multiple fields 29 | multi_index = Index(SomeModel, [SomeModel.name, SomeModel.other]) 30 | assert multi_index.fields == [SomeModel.name, SomeModel.other] 31 | 32 | # test 0 field error 33 | with pytest.raises(BadArgument): 34 | Index(SomeModel, []) 35 | 36 | # test other exceptions 37 | for t in IndexType: 38 | try: 39 | Index(SomeModel, [SomeModel.name, SomeModel.other], type_=t) 40 | except BadArgument: 41 | multi = False 42 | else: 43 | multi = True 44 | 45 | try: 46 | Index(SomeModel, SomeModel.name, unique=True, type_=t) 47 | except BadArgument: 48 | unique = False 49 | else: 50 | unique = True 51 | 52 | assert multi == t.value.multi 53 | assert unique == t.value.unique 54 | 55 | 56 | def test_get_name(): 57 | single_index = Index(SomeModel, SomeModel.name) 58 | multi_index = Index(SomeModel, [SomeModel.name, SomeModel.other]) 59 | 60 | assert single_index.get_name() == "_btree_index_some_model__name" 61 | assert multi_index.get_name() == "_btree_index_some_model__name_other" 62 | -------------------------------------------------------------------------------- /tests/unittests/test_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 
| 3 | from dataclasses import dataclass 4 | from enum import IntEnum 5 | from typing import Any 6 | from unittest.mock import AsyncMock, Mock 7 | 8 | import pytest 9 | from pytest_mock import MockerFixture 10 | 11 | import apgorm 12 | from apgorm.types import Int, Serial, VarChar 13 | 14 | 15 | class UserStatus(IntEnum): 16 | OFFLINE = 0 17 | INVISIBLE = 1 18 | ONLINE = 2 19 | DND = 3 20 | 21 | 22 | def validate_status(value) -> bool: 23 | assert isinstance(value, UserStatus) 24 | return True 25 | 26 | 27 | class User(apgorm.Model): 28 | userid = Serial().field() 29 | name = VarChar(32).field() 30 | nick = VarChar(32).nullablefield() 31 | status = ( 32 | Int() 33 | .field(default=0) 34 | .with_converter(apgorm.IntEFConverter(UserStatus)) 35 | ) 36 | status.add_validator(lambda v: v is not UserStatus.INVISIBLE) 37 | default_fact = VarChar(32).field(default_factory=lambda: "hello, world") 38 | 39 | name_unique = apgorm.Unique(name) 40 | primary_key = (userid,) 41 | 42 | games = apgorm.ManyToMany( 43 | "userid", "players.userid", "players.gameid", "games.gameid" 44 | ) 45 | 46 | 47 | class Game(apgorm.Model): 48 | gameid = Serial().field() 49 | 50 | users = apgorm.ManyToMany( 51 | "gameid", "players.gameid", "players.userid", "users.userid" 52 | ) 53 | 54 | primary_key = (gameid,) 55 | 56 | 57 | class Player(apgorm.Model): 58 | userid = Int().field() 59 | gameid = Int().field() 60 | 61 | uid_fk = apgorm.ForeignKey(userid, User.userid) 62 | gid_fk = apgorm.ForeignKey(gameid, Game.gameid) 63 | 64 | primary_key = (userid, gameid) 65 | 66 | 67 | class Database(apgorm.Database): 68 | users = User 69 | games = Game 70 | players = Player 71 | 72 | 73 | DB = Database(None) 74 | 75 | 76 | @dataclass 77 | class PatchedDBMethods: 78 | pool: Mock 79 | con: Mock 80 | curr: Mock 81 | 82 | 83 | @pytest.fixture 84 | def db(mocker: MockerFixture) -> PatchedDBMethods: 85 | def areturn(ret: Any = None) -> AsyncMock: 86 | func = mocker.AsyncMock() 87 | func.return_value = ret 88 | return func 89 | 90 | con = mocker.Mock() 91 | tc = mocker.AsyncMock() 92 | tc.__aenter__.return_value = None 93 | tc.__aexit__.return_value = None 94 | con.transaction.return_value = tc 95 | con.execute = areturn() 96 | con.fetchrow = areturn({"hello": "world"}) 97 | con.fetchmany = areturn([{"hello": "world"}]) 98 | con.fetchval = areturn("hello, world") 99 | 100 | curr = mocker.Mock() 101 | con.cursor.return_value = curr 102 | 103 | pool = mocker.Mock() 104 | pac = mocker.AsyncMock() 105 | pac.__aenter__.return_value = con 106 | pac.__aexit__.return_value = None 107 | pool.acquire.return_value = pac 108 | pool.close = areturn() 109 | 110 | mocker.patch.object(DB, "pool", pool) 111 | 112 | return PatchedDBMethods(pool, con, curr) 113 | 114 | 115 | def assert_transaction(db: PatchedDBMethods): 116 | db.pool.acquire.assert_called_once_with() 117 | db.pool.acquire.return_value.__aenter__.assert_called_once_with() 118 | db.con.transaction.assert_called_once_with() 119 | db.con.transaction.return_value.__aenter__.assert_called_once_with() 120 | 121 | 122 | def test_converter_field_in_init(): 123 | user = User(status=UserStatus.OFFLINE) 124 | assert user._raw_values["status"] == UserStatus.OFFLINE.value 125 | 126 | 127 | def test_validate_in_init(): 128 | with pytest.raises(apgorm.exceptions.InvalidFieldValue): 129 | User(status=UserStatus.INVISIBLE) 130 | 131 | User(status=UserStatus.OFFLINE) 132 | User._from_raw(status=UserStatus.INVISIBLE.value) 133 | 134 | 135 | def test_gets_default(): 136 | user = User() 137 | 138 | assert 
user.status is UserStatus.OFFLINE 139 | assert user.default_fact == "hello, world" 140 | assert "name" not in user._raw_values 141 | 142 | 143 | @pytest.mark.asyncio 144 | async def test_delete(db: PatchedDBMethods, mocker: MockerFixture): 145 | async def new_execute(self: apgorm.DeleteQueryBuilder): 146 | assert self._where_logic().render_no_params() == "userid = $1" 147 | return [User(userid=1)] 148 | 149 | mocker.patch.object(apgorm.DeleteQueryBuilder, "execute", new_execute) 150 | 151 | user = User(userid=5) 152 | ret = await user.delete() 153 | assert user is not ret 154 | assert ret.userid == 1 155 | assert user.userid == 5 156 | 157 | 158 | @pytest.mark.asyncio 159 | async def test_delete_none(db: PatchedDBMethods, mocker: MockerFixture): 160 | async def new_execute(self: apgorm.DeleteQueryBuilder): 161 | return [] 162 | 163 | mocker.patch.object(apgorm.DeleteQueryBuilder, "execute", new_execute) 164 | 165 | user = User(userid=5) 166 | with pytest.raises(apgorm.exceptions.ModelNotFound): 167 | await user.delete() 168 | 169 | 170 | @pytest.mark.asyncio 171 | async def test_save_normal(db: PatchedDBMethods, mocker: MockerFixture): 172 | async def new_execute(self: apgorm.UpdateQueryBuilder): 173 | assert self._where_logic().render() == ("userid = $1", [1]) 174 | set_values = self._set_values 175 | assert [(k.render_no_params(), v) for k, v in set_values.items()] == [ 176 | ("name", "New Name") 177 | ] 178 | return [User(userid=5)] 179 | 180 | mocker.patch.object(apgorm.UpdateQueryBuilder, "execute", new_execute) 181 | 182 | user = User(userid=1) 183 | user.name = "New Name" 184 | await user.save() 185 | 186 | assert user.userid == 5 # updated 187 | 188 | 189 | @pytest.mark.asyncio 190 | async def test_save_none_to_save(db: PatchedDBMethods, mocker: MockerFixture): 191 | spy = mocker.spy(apgorm.UpdateQueryBuilder, "execute") 192 | user = User(userid=5) 193 | await user.save() 194 | 195 | spy.assert_not_called() 196 | -------------------------------------------------------------------------------- /tests/unittests/types/test_array.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from apgorm.types import Array, Int # for subtypes 4 | 5 | 6 | @pytest.mark.parametrize("subtype", [Int(), Array(Int()), Array(Array(Int()))]) 7 | def test_array_init(subtype): 8 | a = Array(subtype) 9 | 10 | assert a.subtype is subtype 11 | 12 | 13 | def test_array_sql(): 14 | assert Array(Int())._sql == "INTEGER[]" 15 | assert Array(Array(Int()))._sql == "INTEGER[][]" 16 | -------------------------------------------------------------------------------- /tests/unittests/types/test_bitstr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from apgorm.types import Bit, VarBit 4 | 5 | 6 | def test_bit_no_length(): 7 | bit = Bit() 8 | assert bit._sql == "BIT" 9 | assert bit._length is None 10 | 11 | 12 | def test_bit_with_length(): 13 | bit = Bit(5) 14 | assert bit._sql == "BIT(5)" 15 | assert bit._length == 5 16 | 17 | 18 | def test_varbit_no_length(): 19 | bit = VarBit() 20 | assert bit._sql == "VARBIT" 21 | assert bit.max_length is None 22 | 23 | 24 | def test_varbit_with_length(): 25 | bit = VarBit(5) 26 | assert bit._sql == "VARBIT(5)" 27 | assert bit.max_length == 5 28 | -------------------------------------------------------------------------------- /tests/unittests/types/test_character.py: 
-------------------------------------------------------------------------------- 1 | from apgorm.types import Char, VarChar 2 | 3 | 4 | def test_char_no_length(): 5 | c = Char() 6 | assert c.length is None 7 | assert c._sql == "CHAR" 8 | 9 | 10 | def test_char_with_length(): 11 | c = Char(5) 12 | assert c.length == 5 13 | assert c._sql == "CHAR(5)" 14 | 15 | 16 | def test_varchar_no_length(): 17 | c = VarChar() 18 | assert c.max_length is None 19 | assert c._sql == "VARCHAR" 20 | 21 | 22 | def test_varchar_with_length(): 23 | c = VarChar(5) 24 | assert c.max_length == 5 25 | assert c._sql == "VARCHAR(5)" 26 | -------------------------------------------------------------------------------- /tests/unittests/types/test_date.py: -------------------------------------------------------------------------------- 1 | from apgorm.types import ( 2 | Interval, 3 | IntervalField, 4 | Time, 5 | Timestamp, 6 | TimestampTZ, 7 | TimeTZ, 8 | ) 9 | 10 | 11 | def test_timestamp_nop(): 12 | ts = Timestamp() 13 | assert ts.precision is None 14 | assert ts._sql == "TIMESTAMP" 15 | 16 | 17 | def test_timestamp_wp(): 18 | ts = Timestamp(5) 19 | assert ts.precision == 5 20 | assert ts._sql == "TIMESTAMP(5)" 21 | 22 | 23 | def test_timestamptz_nop(): 24 | ts = TimestampTZ() 25 | assert ts.precision is None 26 | assert ts._sql == "TIMESTAMPTZ" 27 | 28 | 29 | def test_timestamptz_wp(): 30 | ts = TimestampTZ(5) 31 | assert ts.precision == 5 32 | assert ts._sql == "TIMESTAMPTZ(5)" 33 | 34 | 35 | def test_time_nop(): 36 | t = Time() 37 | assert t.precision is None 38 | assert t._sql == "TIME" 39 | 40 | 41 | def test_time_wp(): 42 | t = Time(5) 43 | assert t.precision == 5 44 | assert t._sql == "TIME(5)" 45 | 46 | 47 | def test_timetz_nop(): 48 | t = TimeTZ() 49 | assert t.precision is None 50 | assert t._sql == "TIMETZ" 51 | 52 | 53 | def test_timetz_wp(): 54 | t = TimeTZ(5) 55 | assert t.precision == 5 56 | assert t._sql == "TIMETZ(5)" 57 | 58 | 59 | def test_interval(): 60 | i = Interval() 61 | assert i.precision is None 62 | assert i.interval_field is None 63 | assert i._sql == "INTERVAL" 64 | 65 | 66 | def test_interval_wf(): 67 | i = Interval(interval_field=IntervalField.DAY) 68 | assert i.precision is None 69 | assert i.interval_field is IntervalField.DAY 70 | assert i._sql == "INTERVAL DAY" 71 | 72 | 73 | def test_interval_wp(): 74 | i = Interval(precision=5) 75 | assert i.precision == 5 76 | assert i.interval_field is None 77 | assert i._sql == "INTERVAL(5)" 78 | 79 | 80 | def test_interval_wpf(): 81 | i = Interval(precision=5, interval_field=IntervalField.DAY) 82 | assert i.precision == 5 83 | assert i.interval_field is IntervalField.DAY 84 | assert i._sql == "INTERVAL DAY(5)" 85 | -------------------------------------------------------------------------------- /tests/unittests/types/test_numeric.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import apgorm 4 | from apgorm.types import BigSerial, Numeric, Serial, SmallSerial 5 | from apgorm.types.numeric import _BaseSerial 6 | 7 | 8 | def test_numeric(): 9 | n = Numeric() 10 | assert n.precision is None 11 | assert n.scale is None 12 | assert n._sql == "NUMERIC" 13 | 14 | 15 | def test_numeric_wp(): 16 | n = Numeric(precision=5) 17 | assert n.precision == 5 18 | assert n.scale is None 19 | assert n._sql == "NUMERIC(5)" 20 | 21 | 22 | def test_numeric_ws(): 23 | with pytest.raises(apgorm.exceptions.BadArgument): 24 | Numeric(scale=5) 25 | 26 | 27 | def test_numeric_wps(): 28 | n = Numeric(scale=4, 
precision=5) 29 | assert n.precision == 5 30 | assert n.scale == 4 31 | assert n._sql == "NUMERIC(5, 4)" 32 | 33 | 34 | @pytest.mark.parametrize("s", [Serial, SmallSerial, BigSerial]) 35 | def test_serial_subclass_of_base_serial(s): 36 | assert issubclass(s, _BaseSerial) 37 | 38 | 39 | @pytest.mark.parametrize("s", [Serial, SmallSerial, BigSerial]) 40 | def test_base_serial(s): 41 | nnf = s().field() 42 | nf = s().nullablefield() 43 | 44 | assert nnf.not_null is True 45 | assert nf.not_null is False 46 | -------------------------------------------------------------------------------- /tests/unittests/utils/test_lazy_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from apgorm import LazyList 4 | 5 | 6 | def test_lazy_list(): 7 | initial = range(0, 500) 8 | ll = LazyList(initial, str) 9 | 10 | assert len(initial) == len(ll) 11 | assert list(ll[0:5]) == [str(v) for v in initial[0:5]] 12 | assert ll[6] == str(initial[6]) 13 | assert list(ll) == [str(v) for v in initial] 14 | 15 | assert repr(ll) # essentially make sure it doesn't crash 16 | 17 | sll = LazyList([0, 1], str) 18 | assert repr(sll) 19 | 20 | mll = LazyList([0, 1, 2, 3, 4], str) 21 | assert repr(mll) 22 | --------------------------------------------------------------------------------