├── iceaxe ├── py.typed ├── __tests__ │ ├── __init__.py │ ├── schemas │ │ ├── __init__.py │ │ ├── test_cli.py │ │ └── test_db_stubs.py │ ├── benchmarks │ │ ├── __init__.py │ │ ├── test_bulk_insert.py │ │ └── test_select.py │ ├── migrations │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_generics.py │ │ ├── test_generator.py │ │ └── test_action_sorter.py │ ├── mountaineer │ │ ├── __init__.py │ │ └── dependencies │ │ │ ├── __init__.py │ │ │ └── test_core.py │ ├── test_helpers.py │ ├── test_field.py │ ├── test_base.py │ ├── test_alias.py │ ├── conf_models.py │ ├── test_modifications.py │ ├── test_queries_str.py │ ├── conftest.py │ ├── docker_helpers.py │ └── helpers.py ├── schemas │ ├── __init__.py │ └── cli.py ├── migrations │ ├── __init__.py │ ├── client_io.py │ ├── migration.py │ ├── action_sorter.py │ ├── migrator.py │ └── cli.py ├── mountaineer │ ├── dependencies │ │ ├── __init__.py │ │ └── core.py │ ├── __init__.py │ ├── config.py │ └── cli.py ├── __init__.py ├── typing.py ├── alias_values.py ├── logging.py ├── io.py ├── sql_types.py ├── postgres.py ├── generics.py ├── modifications.py ├── field.py └── session_optimized.pyx ├── MANIFEST.in ├── media └── header.png ├── docs ├── api │ ├── mountaineer │ │ ├── config │ │ │ └── page.mdx │ │ ├── dependencies │ │ │ └── page.mdx │ │ └── page.mdx │ ├── connection │ │ └── page.mdx │ ├── tables │ │ └── page.mdx │ ├── migrations │ │ └── page.mdx │ ├── fields │ │ └── page.mdx │ └── queries │ │ └── page.mdx ├── sql │ ├── select │ │ └── page.mdx │ └── introduction │ │ └── page.mdx └── guides │ ├── logging │ └── page.mdx │ ├── tables │ └── page.mdx │ ├── quickstart │ └── page.mdx │ └── inserts │ └── page.mdx ├── setup.py ├── LICENSE ├── .github ├── scripts │ ├── update_version.py.lock │ └── update_version.py └── workflows │ └── test.yml ├── pyproject.toml ├── Makefile ├── .gitignore └── README.md /iceaxe/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iceaxe/__tests__/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iceaxe/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include iceaxe/*.pyx -------------------------------------------------------------------------------- /iceaxe/__tests__/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iceaxe/__tests__/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iceaxe/__tests__/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iceaxe/__tests__/mountaineer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iceaxe/__tests__/mountaineer/dependencies/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /media/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piercefreeman/iceaxe/HEAD/media/header.png -------------------------------------------------------------------------------- /docs/api/mountaineer/config/page.mdx: -------------------------------------------------------------------------------- 1 | # Database Config 2 | 3 | ::: mountaineer.database.config.DatabaseConfig 4 | -------------------------------------------------------------------------------- /docs/api/mountaineer/dependencies/page.mdx: -------------------------------------------------------------------------------- 1 | # Database Dependencies 2 | 3 | ::: mountaineer.database.DatabaseDependencies 4 | -------------------------------------------------------------------------------- /iceaxe/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | from .cli import ( 2 | handle_apply as handle_apply, 3 | handle_generate as handle_generate, 4 | handle_rollback as handle_rollback, 5 | ) 6 | -------------------------------------------------------------------------------- /iceaxe/mountaineer/dependencies/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database dependencies for use in API endpoint routes. 3 | 4 | """ 5 | 6 | from .core import get_db_connection as get_db_connection 7 | -------------------------------------------------------------------------------- /docs/api/connection/page.mdx: -------------------------------------------------------------------------------- 1 | # Connection 2 | 3 | The `DBConnection` class manages all Postgres-Python interactions that actually 4 | affect the state of the database. 5 | 6 | ::: iceaxe.session.DBConnection 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from Cython.Build import cythonize 3 | 4 | setup( 5 | name="iceaxe", 6 | ext_modules=cythonize("iceaxe/**/*.pyx", annotate=True, language_level="3"), 7 | ) 8 | -------------------------------------------------------------------------------- /docs/sql/select/page.mdx: -------------------------------------------------------------------------------- 1 | 2 | ```sql 3 | SELECT * FROM employees 4 | ``` 5 | 6 | This query is requesting a few things. It's asking for all columns (`*`) from the `employees` table, and all 7 | rows within that table (lack of a `where` filter). 8 | -------------------------------------------------------------------------------- /docs/sql/introduction/page.mdx: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | If you're new to databases in general, it's worth spending a bit of time 4 | to get acquainted with the new language that underpins everything that 5 | you're going to be doing with Postgres: SQL. 
6 | -------------------------------------------------------------------------------- /iceaxe/__tests__/test_helpers.py: -------------------------------------------------------------------------------- 1 | from iceaxe.__tests__.helpers import pyright_raises 2 | 3 | 4 | def test_basic_type_error(): 5 | def type_error_func(x: int) -> int: 6 | return 10 7 | 8 | with pyright_raises("reportArgumentType"): 9 | type_error_func("20") # type: ignore 10 | -------------------------------------------------------------------------------- /iceaxe/mountaineer/__init__.py: -------------------------------------------------------------------------------- 1 | from iceaxe.mountaineer import ( 2 | dependencies as DatabaseDependencies, # noqa: F401 3 | ) 4 | 5 | from .cli import ( 6 | apply_migration as apply_migration, 7 | generate_migration as generate_migration, 8 | rollback_migration as rollback_migration, 9 | ) 10 | from .config import DatabaseConfig as DatabaseConfig 11 | -------------------------------------------------------------------------------- /iceaxe/__tests__/test_field.py: -------------------------------------------------------------------------------- 1 | from iceaxe.base import TableBase 2 | from iceaxe.field import DBFieldClassDefinition, DBFieldInfo 3 | 4 | 5 | def test_db_field_class_definition_instantiation(): 6 | field_def = DBFieldClassDefinition( 7 | root_model=TableBase, key="test_key", field_definition=DBFieldInfo() 8 | ) 9 | assert field_def.root_model == TableBase 10 | assert field_def.key == "test_key" 11 | assert isinstance(field_def.field_definition, DBFieldInfo) 12 | -------------------------------------------------------------------------------- /docs/api/tables/page.mdx: -------------------------------------------------------------------------------- 1 | # Tables 2 | 3 | Define typehinted tables in Python as the source of truth for what 4 | your tables should look like within your database. These are used by 5 | both the typehint signatures, runtime, and the migration logic to enforce 6 | your schemas are as you define here. 7 | 8 | ## Tables 9 | 10 | ::: iceaxe.base.TableBase 11 | 12 | --- 13 | 14 | ## Constraints 15 | 16 | ::: iceaxe.base.IndexConstraint 17 | 18 | --- 19 | 20 | ::: iceaxe.base.UniqueConstraint 21 | -------------------------------------------------------------------------------- /docs/api/migrations/page.mdx: -------------------------------------------------------------------------------- 1 | # Database Migrations 2 | 3 | ## CLI 4 | 5 | ::: iceaxe.migrations.cli.handle_generate 6 | 7 | --- 8 | 9 | ::: iceaxe.migrations.cli.handle_apply 10 | 11 | --- 12 | 13 | ::: iceaxe.migrations.cli.handle_rollback 14 | 15 | --- 16 | 17 | ## Migrations 18 | 19 | ::: iceaxe.migrations.migration.MigrationRevisionBase 20 | 21 | --- 22 | 23 | ::: iceaxe.migrations.migrator.Migrator 24 | 25 | --- 26 | 27 | ## Table Manipulator 28 | 29 | ::: iceaxe.schemas.actions.DatabaseActions 30 | 31 | --- 32 | ::: iceaxe.sql_types.ColumnType 33 | 34 | --- 35 | 36 | ::: iceaxe.sql_types.ConstraintType 37 | -------------------------------------------------------------------------------- /docs/api/mountaineer/page.mdx: -------------------------------------------------------------------------------- 1 | # Mountaineer & FastAPI 2 | 3 | Iceaxe includes some helper classes that make it easier to connect to databases 4 | from within dependency-injection frameworks (like Mountaineer & FastAPI). 
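Below is a minimal sketch of how this looks in a plain FastAPI app, mirroring the example in the `get_db_connection` docstring. The `User` table here is a hypothetical model, and the connection parameters are assumed to come from a `DatabaseConfig` resolved from your environment.

```python
from fastapi import Depends, FastAPI

from iceaxe import DBConnection, Field, TableBase, select
from iceaxe.mountaineer.dependencies import get_db_connection


class User(TableBase):
    id: int | None = Field(primary_key=True, default=None)
    name: str


app = FastAPI()


@app.get("/users")
async def list_users(db: DBConnection = Depends(get_db_connection)):
    # A connection is opened per-request and closed automatically when the
    # endpoint finishes.
    return await db.exec(select(User))
```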
5 | 6 | ## Config 7 | 8 | ::: iceaxe.mountaineer.config.DatabaseConfig 9 | 10 | --- 11 | 12 | ## Dependencies 13 | 14 | ::: iceaxe.mountaineer.dependencies.core.get_db_connection 15 | 16 | --- 17 | 18 | ## CLI 19 | 20 | ::: iceaxe.mountaineer.cli.generate_migration 21 | 22 | --- 23 | 24 | ::: iceaxe.mountaineer.cli.apply_migration 25 | 26 | --- 27 | 28 | ::: iceaxe.mountaineer.cli.rollback_migration 29 | -------------------------------------------------------------------------------- /iceaxe/__init__.py: -------------------------------------------------------------------------------- 1 | from .alias_values import alias as alias 2 | from .base import ( 3 | IndexConstraint as IndexConstraint, 4 | TableBase as TableBase, 5 | UniqueConstraint as UniqueConstraint, 6 | ) 7 | from .field import Field as Field 8 | from .functions import func as func 9 | from .postgres import PostgresDateTime as PostgresDateTime, PostgresTime as PostgresTime 10 | from .queries import ( 11 | QueryBuilder as QueryBuilder, 12 | and_ as and_, 13 | delete as delete, 14 | or_ as or_, 15 | select as select, 16 | update as update, 17 | ) 18 | from .queries_str import sql as sql 19 | from .session import DBConnection as DBConnection 20 | from .typing import column as column 21 | -------------------------------------------------------------------------------- /docs/api/fields/page.mdx: -------------------------------------------------------------------------------- 1 | # Fields 2 | 3 | Fields let you customize the behavior of the Columns. They accept any parameter 4 | that regular [Pydantic Fields](https://docs.pydantic.dev/latest/concepts/fields/#numeric-constraints) do, plus 5 | a few extra parameters that are specific to your database. 6 | 7 | ::: iceaxe.field.Field 8 | 9 | --- 10 | 11 | ::: iceaxe.field.DBFieldInfo 12 | 13 | --- 14 | 15 | ## Postgres Types 16 | 17 | Some primitives (like datetimes) allow for more configuration in Postgres 18 | than their types do in Python. In these cases we provide additional configuration 19 | classes that can be passed to `postgres_config` within the Field. 
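For example (a short sketch; the `Event` table and its columns are hypothetical), a `datetime` column can be stored as a timezone-aware timestamp by passing `PostgresDateTime` to `postgres_config`:

```python
from datetime import datetime

from iceaxe import Field, PostgresDateTime, TableBase


class Event(TableBase):
    id: int | None = Field(primary_key=True, default=None)
    # Ask Postgres to keep timezone information for this column.
    occurred_at: datetime = Field(postgres_config=PostgresDateTime(timezone=True))
```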
20 | 21 | ::: iceaxe.postgres.PostgresFieldBase 22 | 23 | --- 24 | 25 | ::: iceaxe.postgres.PostgresDateTime 26 | 27 | --- 28 | 29 | ::: iceaxe.postgres.PostgresTime 30 | -------------------------------------------------------------------------------- /iceaxe/__tests__/schemas/test_cli.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from iceaxe.base import TableBase 4 | from iceaxe.queries import select 5 | from iceaxe.schemas.cli import create_all 6 | from iceaxe.session import DBConnection 7 | 8 | 9 | class DemoCustomModel(TableBase): 10 | value: str 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_create_all(db_connection: DBConnection): 15 | # Drop the existing table if it already exists 16 | await db_connection.conn.fetch("DROP TABLE IF EXISTS democustommodel CASCADE") 17 | 18 | await create_all(db_connection, models=[DemoCustomModel]) 19 | 20 | # Now we try to insert a row into the new table 21 | await db_connection.insert([DemoCustomModel(value="test")]) 22 | 23 | result = await db_connection.exec(select(DemoCustomModel)) 24 | assert len(result) == 1 25 | assert result[0].value == "test" 26 | -------------------------------------------------------------------------------- /docs/guides/logging/page.mdx: -------------------------------------------------------------------------------- 1 | # Logging 2 | 3 | At a standard logging level, Iceaxe is conservative with what it logs. 4 | 5 | During development it's often useful to increase this verbosity to help diagnose the 6 | SQL queries that are actually sent to your database. You can set the `ICEAXE_LOG_LEVEL` environment 7 | variable to `DEBUG` to enable verbose logging: 8 | 9 | ```bash 10 | ICEAXE_LOG_LEVEL=DEBUG uv run runserver 11 | ``` 12 | 13 | This results in a lot of output, so you may want to pipe it to a file: 14 | 15 | ```bash 16 | ICEAXE_LOG_LEVEL=DEBUG uv run runserver > debug.log 17 | ``` 18 | 19 | Each entry in the log will be a JSON formatted, single-lined 20 | payload with the timestamp and the logged sql message (or whatever other 21 | message is logged by the iceaxe library). These JSON logs allow for 22 | easier machine parsing and analysis of query performance. 23 | 24 | ```bash 25 | {"level": "DEBUG", "name": "iceaxe", "timestamp": "2021-10-14 15:00:00,000", "message": "SELECT * FROM users WHERE id = 1"} 26 | ``` 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Pierce Freeman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /iceaxe/schemas/cli.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from typing import Type 3 | 4 | from iceaxe.base import DBModelMetaclass, TableBase 5 | from iceaxe.schemas.actions import DatabaseActions 6 | from iceaxe.schemas.db_memory_serializer import DatabaseMemorySerializer 7 | from iceaxe.session import DBConnection 8 | 9 | 10 | async def create_all( 11 | db_connection: DBConnection, models: list[Type[TableBase]] | None = None 12 | ): 13 | # Get all of the instances that have been registered 14 | # in memory scope by the user. 15 | if models is None: 16 | models = [ 17 | cls 18 | for cls in DBModelMetaclass.get_registry() 19 | if isclass(cls) and issubclass(cls, TableBase) 20 | ] 21 | 22 | migrator = DatabaseMemorySerializer() 23 | db_objects = list(migrator.delegate(models)) 24 | next_ordering = migrator.order_db_objects(db_objects) 25 | 26 | # Create the tables in the database 27 | actor = DatabaseActions(dry_run=False, db_connection=db_connection) 28 | await migrator.build_actions( 29 | actor, [], {}, [obj for obj, _ in db_objects], next_ordering 30 | ) 31 | -------------------------------------------------------------------------------- /iceaxe/__tests__/migrations/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest_asyncio 2 | 3 | from iceaxe.session import DBConnection 4 | 5 | 6 | @pytest_asyncio.fixture 7 | async def clear_all_database_objects(db_connection: DBConnection): 8 | """ 9 | Clear all database objects. 
10 | 11 | """ 12 | # Step 1: Drop all tables in the public schema 13 | await db_connection.conn.execute( 14 | """ 15 | DO $$ DECLARE 16 | r RECORD; 17 | BEGIN 18 | FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP 19 | EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; 20 | END LOOP; 21 | END $$; 22 | """ 23 | ) 24 | 25 | # Step 2: Drop all custom types in the public schema 26 | await db_connection.conn.execute( 27 | """ 28 | DO $$ DECLARE 29 | r RECORD; 30 | BEGIN 31 | FOR r IN (SELECT typname FROM pg_type WHERE typtype = 'e' AND typnamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'public')) LOOP 32 | EXECUTE 'DROP TYPE IF EXISTS ' || quote_ident(r.typname) || ' CASCADE'; 33 | END LOOP; 34 | END $$; 35 | """ 36 | ) 37 | -------------------------------------------------------------------------------- /.github/scripts/update_version.py.lock: -------------------------------------------------------------------------------- 1 | version = 1 2 | revision = 1 3 | requires-python = ">=3.13" 4 | 5 | [manifest] 6 | requirements = [ 7 | { name = "packaging" }, 8 | { name = "toml" }, 9 | ] 10 | 11 | [[package]] 12 | name = "packaging" 13 | version = "24.2" 14 | source = { registry = "https://pypi.org/simple" } 15 | sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } 16 | wheels = [ 17 | { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, 18 | ] 19 | 20 | [[package]] 21 | name = "toml" 22 | version = "0.10.2" 23 | source = { registry = "https://pypi.org/simple" } 24 | sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253 } 25 | wheels = [ 26 | { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588 }, 27 | ] 28 | -------------------------------------------------------------------------------- /iceaxe/__tests__/benchmarks/test_bulk_insert.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Sequence 3 | 4 | import pytest 5 | 6 | from iceaxe.__tests__.conf_models import UserDemo 7 | from iceaxe.logging import CONSOLE, LOGGER 8 | from iceaxe.session import DBConnection 9 | 10 | 11 | def generate_test_users(count: int) -> Sequence[UserDemo]: 12 | """ 13 | Generate a sequence of test users for bulk insertion. 14 | 15 | :param count: Number of users to generate 16 | :return: Sequence of UserDemo instances 17 | """ 18 | return [ 19 | UserDemo(name=f"User {i}", email=f"user{i}@example.com") for i in range(count) 20 | ] 21 | 22 | 23 | @pytest.mark.asyncio 24 | @pytest.mark.integration_tests 25 | async def test_bulk_insert_performance(db_connection: DBConnection): 26 | """ 27 | Test the performance of bulk inserting 500k records. 
28 | """ 29 | NUM_USERS = 500_000 30 | users = generate_test_users(NUM_USERS) 31 | LOGGER.info(f"Generated {NUM_USERS} test users") 32 | 33 | start_time = time.time() 34 | 35 | await db_connection.insert(users) 36 | 37 | total_time = time.time() - start_time 38 | records_per_second = NUM_USERS / total_time 39 | 40 | CONSOLE.print("\nBulk Insert Performance:") 41 | CONSOLE.print(f"Total time: {total_time:.2f} seconds") 42 | CONSOLE.print(f"Records per second: {records_per_second:.2f}") 43 | 44 | result = await db_connection.conn.fetchval("SELECT COUNT(*) FROM userdemo") 45 | assert result == NUM_USERS 46 | -------------------------------------------------------------------------------- /iceaxe/mountaineer/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | 3 | from iceaxe.modifications import MODIFICATION_TRACKER_VERBOSITY 4 | 5 | 6 | class DatabaseConfig(BaseSettings): 7 | """ 8 | Configuration settings for PostgreSQL database connection. 9 | This class uses Pydantic's BaseSettings to manage environment-based configuration. 10 | """ 11 | 12 | POSTGRES_HOST: str 13 | """ 14 | The hostname where the PostgreSQL server is running. 15 | This can be a domain name or IP address (e.g., 'localhost' or '127.0.0.1'). 16 | """ 17 | 18 | POSTGRES_USER: str 19 | """ 20 | The username to authenticate with the PostgreSQL server. 21 | This user should have appropriate permissions for the database operations. 22 | """ 23 | 24 | POSTGRES_PASSWORD: str 25 | """ 26 | The password for authenticating the PostgreSQL user. 27 | This should be kept secure and not exposed in code or version control. 28 | """ 29 | 30 | POSTGRES_DB: str 31 | """ 32 | The name of the PostgreSQL database to connect to. 33 | This database should exist on the server before attempting connection. 34 | """ 35 | 36 | POSTGRES_PORT: int = 5432 37 | """ 38 | The port number where PostgreSQL server is listening. 39 | Defaults to the standard PostgreSQL port 5432 if not specified. 40 | """ 41 | 42 | ICEAXE_UNCOMMITTED_VERBOSITY: MODIFICATION_TRACKER_VERBOSITY | None = None 43 | """ 44 | The verbosity level for uncommitted modifications. 45 | If set to None, uncommitted modifications will not be tracked. 46 | """ 47 | -------------------------------------------------------------------------------- /docs/api/queries/page.mdx: -------------------------------------------------------------------------------- 1 | # Queries 2 | 3 | Queries allow you to retrieve data from your database. Every query you build with ORM objects 4 | is a subclass of `QueryBuilder` at its base. This class provides a number of methods to help you 5 | construct your query and pass the correct types for downstream processing. 6 | 7 | ## Shortcuts 8 | 9 | We offer convenience functions for the most common SQL operations. This avoids 10 | the need to instantiate a `QueryBuilder` directly every time you want a new SQL 11 | query and typically results in cleaner code. 
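As a quick example (a minimal sketch; the `User` table and the filter value are hypothetical, and `conn` is assumed to be an open `DBConnection`):

```python
from iceaxe import DBConnection, Field, TableBase, select


class User(TableBase):
    id: int | None = Field(primary_key=True, default=None)
    name: str


async def find_alices(conn: DBConnection) -> list[User]:
    # where() filters the rows and order_by() sorts the result before it is
    # returned as a list of User instances.
    return await conn.exec(select(User).where(User.name == "Alice").order_by(User.id))
```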
12 | 13 | ::: iceaxe.queries.select 14 | 15 | --- 16 | 17 | ::: iceaxe.queries.update 18 | 19 | --- 20 | 21 | ::: iceaxe.queries.delete 22 | 23 | --- 24 | 25 | ## Raw Queries 26 | 27 | ::: iceaxe.queries.QueryBuilder 28 | 29 | --- 30 | 31 | ::: iceaxe.functions.FunctionMetadata 32 | 33 | --- 34 | 35 | ::: iceaxe.comparison.ComparisonBase 36 | 37 | --- 38 | 39 | ## Functions 40 | 41 | ::: iceaxe.functions.func 42 | 43 | --- 44 | 45 | ::: iceaxe.functions.FunctionBuilder 46 | 47 | --- 48 | 49 | ## Conditions 50 | 51 | If you want to add conditions to your query, you can use the `where` method. If you pass multiple 52 | conditions to this, they'll be chained together as an `AND` filter. To 53 | combine OR and AND predicates together, you can use `and_` and `or_`. 54 | 55 | ::: iceaxe.queries.and_ 56 | 57 | --- 58 | 59 | ::: iceaxe.queries.or_ 60 | 61 | --- 62 | 63 | ## Enums 64 | 65 | ::: iceaxe.queries.JoinType 66 | 67 | --- 68 | 69 | ::: iceaxe.queries.OrderDirection 70 | 71 | --- 72 | 73 | ## Raw Queries 74 | 75 | ::: iceaxe.queries_str.sql 76 | 77 | --- 78 | 79 | ::: iceaxe.queries_str.SQLGenerator 80 | -------------------------------------------------------------------------------- /iceaxe/__tests__/test_base.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | 3 | from iceaxe.base import ( 4 | DBModelMetaclass, 5 | TableBase, 6 | ) 7 | from iceaxe.field import DBFieldInfo 8 | 9 | 10 | def test_autodetect(): 11 | class WillAutodetect(TableBase): 12 | pass 13 | 14 | assert WillAutodetect in DBModelMetaclass.get_registry() 15 | 16 | 17 | def test_not_autodetect(): 18 | class WillNotAutodetect(TableBase, autodetect=False): 19 | pass 20 | 21 | assert WillNotAutodetect not in DBModelMetaclass.get_registry() 22 | 23 | 24 | def test_not_autodetect_generic(clear_registry): 25 | T = TypeVar("T") 26 | 27 | class GenericSuperclass(TableBase, Generic[T], autodetect=False): 28 | value: T 29 | 30 | class WillAutodetect(GenericSuperclass[int]): 31 | pass 32 | 33 | assert DBModelMetaclass.get_registry() == [WillAutodetect] 34 | 35 | 36 | def test_model_fields(): 37 | class User(TableBase): 38 | id: int 39 | name: str 40 | 41 | # Check the main fields 42 | assert isinstance(User.model_fields["id"], DBFieldInfo) 43 | assert User.model_fields["id"].annotation == int # noqa: E721 44 | assert User.model_fields["id"].is_required() is True 45 | 46 | assert isinstance(User.model_fields["name"], DBFieldInfo) 47 | assert User.model_fields["name"].annotation == str # noqa: E721 48 | assert User.model_fields["name"].is_required() is True 49 | 50 | # Check that the special fields exist with the right types 51 | assert isinstance(User.model_fields["modified_attrs"], DBFieldInfo) 52 | assert isinstance(User.model_fields["modified_attrs_callbacks"], DBFieldInfo) 53 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "iceaxe" 3 | version = "0.1.0" 4 | description = "A modern, fast ORM for Python." 
5 | readme = "README.md" 6 | authors = [{ name = "Pierce Freeman", email = "pierce@freeman.vc" }] 7 | requires-python = ">=3.11" 8 | dependencies = [ 9 | "asyncpg>=0.30,<1", 10 | "pydantic>=2,<3", 11 | "rich>=13", 12 | ] 13 | 14 | [dependency-groups] 15 | dev = [ 16 | "docker>=7.1.0", 17 | "mountaineer>=0.8.0", 18 | "mypy>=1.15.0", 19 | "pydantic>=2,<2.11", 20 | "pyinstrument>=5.0.1", 21 | "pyright>=1.1.396", 22 | "pytest>=8.3.5", 23 | "pytest-asyncio>=0.25.3", 24 | "ruff>=0.11.0", 25 | ] 26 | 27 | [build-system] 28 | requires = ["setuptools>=61", "cython>=3.0.0"] 29 | build-backend = "setuptools.build_meta" 30 | 31 | [tool.setuptools.packages.find] 32 | include = ["iceaxe", "iceaxe.*"] 33 | 34 | [tool.cython-build] 35 | packages = ["iceaxe"] 36 | include = ["iceaxe/*.pyx"] 37 | 38 | [tool.pyright] 39 | venvPath = "." 40 | venv = ".venv" 41 | 42 | [tool.mypy] 43 | warn_return_any = true 44 | warn_unused_configs = true 45 | check_untyped_defs = true 46 | plugins = ["pydantic.mypy"] 47 | 48 | [[tool.mypy.overrides]] 49 | module = "asyncpg.*" 50 | ignore_missing_imports = true 51 | 52 | [tool.pydantic-mypy] 53 | init_forbid_extra = true 54 | init_typed = true 55 | warn_required_dynamic_aliases = true 56 | 57 | [tool.ruff.lint] 58 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 59 | # Disable print statements 60 | select = ["E4", "E7", "E9", "F", "I001", "T201"] 61 | 62 | [tool.ruff.lint.isort] 63 | section-order = [ 64 | "future", 65 | "standard-library", 66 | "third-party", 67 | "first-party", 68 | "local-folder", 69 | ] 70 | combine-as-imports = true 71 | 72 | [tool.pytest.ini_options] 73 | markers = ["integration_tests: run longer-running integration tests"] 74 | # Default pytest runs shouldn't execute the integration tests 75 | addopts = "-m 'not integration_tests'" 76 | -------------------------------------------------------------------------------- /iceaxe/typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import date, datetime, time, timedelta 4 | from enum import Enum, IntEnum, StrEnum 5 | from inspect import isclass 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Any, 9 | Type, 10 | TypeGuard, 11 | TypeVar, 12 | ) 13 | from uuid import UUID 14 | 15 | if TYPE_CHECKING: 16 | from iceaxe.alias_values import Alias 17 | from iceaxe.base import ( 18 | DBFieldClassDefinition, 19 | TableBase, 20 | ) 21 | from iceaxe.comparison import FieldComparison, FieldComparisonGroup 22 | from iceaxe.functions import FunctionMetadata 23 | 24 | 25 | ALL_ENUM_TYPES = Type[Enum | StrEnum | IntEnum] 26 | PRIMITIVE_TYPES = int | float | str | bool | bytes | UUID 27 | PRIMITIVE_WRAPPER_TYPES = list[PRIMITIVE_TYPES] | PRIMITIVE_TYPES 28 | DATE_TYPES = datetime | date | time | timedelta 29 | JSON_WRAPPER_FALLBACK = list[Any] | dict[Any, Any] 30 | 31 | T = TypeVar("T") 32 | 33 | 34 | def is_base_table(obj: Any) -> TypeGuard[type[TableBase]]: 35 | from iceaxe.base import TableBase 36 | 37 | return isclass(obj) and issubclass(obj, TableBase) 38 | 39 | 40 | def is_column(obj: T) -> TypeGuard[DBFieldClassDefinition[T]]: 41 | from iceaxe.base import DBFieldClassDefinition 42 | 43 | return isinstance(obj, DBFieldClassDefinition) 44 | 45 | 46 | def is_comparison(obj: Any) -> TypeGuard[FieldComparison]: 47 | from iceaxe.comparison import FieldComparison 48 | 49 | return isinstance(obj, FieldComparison) 50 | 51 | 52 | def is_comparison_group(obj: Any) -> TypeGuard[FieldComparisonGroup]: 
53 | from iceaxe.comparison import FieldComparisonGroup 54 | 55 | return isinstance(obj, FieldComparisonGroup) 56 | 57 | 58 | def is_function_metadata(obj: Any) -> TypeGuard[FunctionMetadata]: 59 | from iceaxe.functions import FunctionMetadata 60 | 61 | return isinstance(obj, FunctionMetadata) 62 | 63 | 64 | def is_alias(obj: Any) -> TypeGuard[Alias]: 65 | from iceaxe.alias_values import Alias 66 | 67 | return isinstance(obj, Alias) 68 | 69 | 70 | def column(obj: T) -> DBFieldClassDefinition[T]: 71 | if not is_column(obj): 72 | raise ValueError(f"Invalid column: {obj}") 73 | return obj 74 | -------------------------------------------------------------------------------- /iceaxe/alias_values.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Generic, TypeVar, cast 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | @dataclass(frozen=True, slots=True) 8 | class Alias(Generic[T]): 9 | name: str 10 | value: T 11 | 12 | def __str__(self): 13 | return self.name 14 | 15 | 16 | def alias(name: str, type: T) -> T: 17 | """ 18 | Creates an alias for a field in raw SQL queries, allowing for type-safe mapping of raw SQL results. 19 | This is particularly useful in two main scenarios: 20 | 21 | 1. When using raw SQL queries with aliased columns: 22 | ```python 23 | # Map a COUNT(*) result to an integer 24 | query = select(alias("user_count", int)).text( 25 | "SELECT COUNT(*) AS user_count FROM users" 26 | ) 27 | 28 | # Map multiple aliased columns with different types 29 | query = select(( 30 | alias("full_name", str), 31 | alias("order_count", int), 32 | alias("total_spent", float) 33 | )).text( 34 | ''' 35 | SELECT 36 | concat(first_name, ' ', last_name) AS full_name, 37 | COUNT(orders.id) AS order_count, 38 | SUM(orders.amount) AS total_spent 39 | FROM users 40 | LEFT JOIN orders ON users.id = orders.user_id 41 | GROUP BY users.id 42 | ''' 43 | ) 44 | ``` 45 | 46 | 2. 
When combining ORM models with function results: 47 | ```python 48 | # Select a model alongside a function result 49 | query = select(( 50 | User, 51 | alias("name_length", func.length(User.name)), 52 | alias("upper_name", func.upper(User.name)) 53 | )) 54 | 55 | # Use with aggregation functions 56 | query = select(( 57 | User, 58 | alias("total_orders", func.count(Order.id)) 59 | )).join(Order, User.id == Order.user_id).group_by(User.id) 60 | ``` 61 | 62 | :param name: The name of the alias as it appears in the SQL query's AS clause 63 | :param type: Either a Python type to cast the result to (e.g., int, str, float) or 64 | a function metadata object (e.g., from func.length()) 65 | :return: A type-safe alias that can be used in select() statements 66 | """ 67 | return cast(T, Alias(name, type)) 68 | -------------------------------------------------------------------------------- /iceaxe/__tests__/mountaineer/dependencies/test_core.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, MagicMock, patch 2 | 3 | import asyncpg 4 | import pytest 5 | 6 | from iceaxe.mountaineer.config import DatabaseConfig 7 | from iceaxe.mountaineer.dependencies.core import get_db_connection 8 | from iceaxe.session import DBConnection 9 | 10 | 11 | @pytest.fixture(autouse=True) 12 | def mock_db_connect(mock_connection: AsyncMock): 13 | with patch("asyncpg.connect", new_callable=AsyncMock) as mock: 14 | mock.return_value = mock_connection 15 | 16 | yield mock 17 | 18 | 19 | @pytest.fixture 20 | def mock_config(): 21 | return DatabaseConfig( 22 | POSTGRES_HOST="test-host", 23 | POSTGRES_PORT=5432, 24 | POSTGRES_USER="test-user", 25 | POSTGRES_PASSWORD="test-pass", 26 | POSTGRES_DB="test-db", 27 | ) 28 | 29 | 30 | @pytest.fixture 31 | def mock_connection(): 32 | conn = AsyncMock(spec=asyncpg.Connection) 33 | conn.close = AsyncMock() 34 | 35 | # We need to populate the internal dsn parameters like the real query 36 | conn._addr = ("test-host", 5432) 37 | conn._params.user = "test-user" 38 | conn._params.password = "test-pass" 39 | conn._params.database = "test-db" 40 | 41 | conn._introspect_types.return_value = (MagicMock(), MagicMock()) 42 | 43 | return conn 44 | 45 | 46 | @pytest.mark.asyncio 47 | async def test_get_db_connection_closes_after_yield( 48 | mock_config: DatabaseConfig, 49 | mock_connection: AsyncMock, 50 | mock_db_connect: AsyncMock, 51 | ): 52 | mock_db_connect.return_value = mock_connection 53 | 54 | # Get the generator 55 | db_gen = get_db_connection(mock_config) 56 | 57 | # Get the connection 58 | connection = await anext(db_gen) # noqa: F821 59 | 60 | assert isinstance(connection, DBConnection) 61 | assert connection.conn == mock_connection 62 | mock_db_connect.assert_called_once_with( 63 | host=mock_config.POSTGRES_HOST, 64 | port=mock_config.POSTGRES_PORT, 65 | user=mock_config.POSTGRES_USER, 66 | password=mock_config.POSTGRES_PASSWORD, 67 | database=mock_config.POSTGRES_DB, 68 | ) 69 | 70 | # Simulate the end of the generator's scope 71 | try: 72 | await db_gen.aclose() 73 | except StopAsyncIteration: 74 | pass 75 | 76 | mock_connection.close.assert_called_once() 77 | -------------------------------------------------------------------------------- /iceaxe/mountaineer/dependencies/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optional compatibility layer for `mountaineer` dependency access. 
3 | 4 | """ 5 | 6 | from typing import AsyncGenerator 7 | 8 | import asyncpg 9 | from mountaineer import CoreDependencies, Depends 10 | 11 | from iceaxe.mountaineer.config import DatabaseConfig 12 | from iceaxe.session import DBConnection 13 | 14 | 15 | async def get_db_connection( 16 | config: DatabaseConfig = Depends( 17 | CoreDependencies.get_config_with_type(DatabaseConfig) 18 | ), 19 | ) -> AsyncGenerator[DBConnection, None]: 20 | """ 21 | A dependency that provides a database connection for use in FastAPI endpoints or other 22 | dependency-injected contexts. The connection is automatically closed when the endpoint 23 | finishes processing. 24 | 25 | This dependency: 26 | - Creates a new PostgreSQL connection using the provided configuration 27 | - Wraps it in a DBConnection for ORM functionality 28 | - Initializes the connection's type cache to support enums without per-connection 29 | type introspection 30 | - Automatically closes the connection when done 31 | - Integrates with Mountaineer's dependency injection system 32 | 33 | :param config: DatabaseConfig instance containing connection parameters. 34 | Automatically injected by Mountaineer if not provided. 35 | :return: An async generator yielding a DBConnection instance 36 | 37 | ```python 38 | from fastapi import FastAPI, Depends 39 | from iceaxe.mountaineer.dependencies import get_db_connection 40 | from iceaxe.session import DBConnection 41 | 42 | app = FastAPI() 43 | 44 | # Basic usage in a FastAPI endpoint 45 | @app.get("/users") 46 | async def get_users(db: DBConnection = Depends(get_db_connection)): 47 | users = await db.exec(select(User)) 48 | return users 49 | ``` 50 | """ 51 | conn = await asyncpg.connect( 52 | host=config.POSTGRES_HOST, 53 | port=config.POSTGRES_PORT, 54 | user=config.POSTGRES_USER, 55 | password=config.POSTGRES_PASSWORD, 56 | database=config.POSTGRES_DB, 57 | ) 58 | 59 | connection = DBConnection( 60 | conn, uncommitted_verbosity=config.ICEAXE_UNCOMMITTED_VERBOSITY 61 | ) 62 | await connection.initialize_types() 63 | 64 | try: 65 | yield connection 66 | finally: 67 | await connection.close() 68 | -------------------------------------------------------------------------------- /iceaxe/migrations/client_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Client interface functions to introspect the client migrations and import them appropriately 3 | into the current runtime. 4 | 5 | """ 6 | 7 | import sys 8 | from importlib.util import module_from_spec, spec_from_file_location 9 | from pathlib import Path 10 | 11 | from iceaxe.migrations.migration import MigrationRevisionBase 12 | 13 | 14 | def fetch_migrations(migration_base: Path): 15 | """ 16 | Fetch all migrations from the migration base directory. Instead of using importlib.import_module, 17 | we manually create the module dependencies - this provides some flexibility in the future to avoid 18 | importing the whole client application just to fetch the migrations. 19 | 20 | We enforce that all migration files have the prefix 'rev_'. 
21 | 22 | """ 23 | migrations: list[MigrationRevisionBase] = [] 24 | for file in migration_base.glob("rev_*.py"): 25 | module_name = file.stem 26 | if module_name.isidentifier() and not module_name.startswith("_"): 27 | spec = spec_from_file_location(module_name, str(file)) 28 | if spec and spec.loader: 29 | module = module_from_spec(spec) 30 | sys.modules[module_name] = module 31 | spec.loader.exec_module(module) 32 | migrations.extend( 33 | attribute() 34 | for attribute_name in dir(module) 35 | if isinstance(attribute := getattr(module, attribute_name), type) 36 | and issubclass(attribute, MigrationRevisionBase) 37 | and attribute is not MigrationRevisionBase 38 | ) 39 | return migrations 40 | 41 | 42 | def sort_migrations(migrations: list[MigrationRevisionBase]): 43 | """ 44 | Sort migrations by their (down_revision, up_revision) dependencies. We start with down=current_revision 45 | which should be the original migration. 46 | 47 | """ 48 | migration_dict = {mig.down_revision: mig for mig in migrations} 49 | sorted_revisions: list[MigrationRevisionBase] = [] 50 | 51 | next_revision = None 52 | while next_revision in migration_dict: 53 | next_migration = migration_dict[next_revision] 54 | sorted_revisions.append(next_migration) 55 | next_revision = next_migration.up_revision 56 | 57 | if len(sorted_revisions) != len(migrations): 58 | raise ValueError( 59 | "There are gaps in the migration sequence or unresolved dependencies." 60 | ) 61 | 62 | return sorted_revisions 63 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # 2 | # Main Makefile for project development and CI 3 | # 4 | 5 | # Default shell 6 | SHELL := /bin/bash 7 | # Fail on first error 8 | .SHELLFLAGS := -ec 9 | 10 | # Global variables 11 | ICEAXE := ./ 12 | ICEAXE_NAME := iceaxe 13 | 14 | # Ignore these directories in the local filesystem if they exist 15 | .PHONY: lint test 16 | 17 | # Main lint target 18 | lint: lint-iceaxe 19 | 20 | # Lint validation target 21 | lint-validation: lint-validation-iceaxe 22 | 23 | # Testing target 24 | test: test-iceaxe 25 | 26 | # Install all sub-project dependencies with uv 27 | install-deps: install-deps-iceaxe 28 | 29 | install-deps-iceaxe: 30 | @echo "Installing dependencies for $(ICEAXE)..." 31 | @(cd $(ICEAXE) && uv sync) 32 | 33 | # Clean the current uv.lock files, useful for remote CI machines 34 | # where we're running on a different base architecture than when 35 | # developing locally 36 | clean-uv-lock: 37 | @echo "Cleaning uv.lock files..." 38 | @rm -f $(ICEAXE)/uv.lock 39 | 40 | # Standard linting - local development, with fixing enabled 41 | lint-iceaxe: 42 | $(call lint-common,$(ICEAXE),$(ICEAXE_NAME)) 43 | 44 | # Lint intended for CI usage - will fail on any errors 45 | lint-validation-iceaxe: 46 | $(call lint-validation-common,$(ICEAXE),$(ICEAXE_NAME)) 47 | 48 | # Tests 49 | test-iceaxe: 50 | @set -e; \ 51 | $(call test-common,$(ICEAXE),$(MOUNTAINEER_PLUGINS_NAME)) 52 | 53 | # 54 | # Common helper functions 55 | # 56 | 57 | define test-common 58 | echo "Running tests for $(2)..." 59 | @(cd $(1) && uv run pytest -W error -vv $(test-args) $(2)) 60 | endef 61 | 62 | define lint-common 63 | echo "Running linting for $(2)..." 64 | @(cd $(1) && uv run ruff format $(2)) 65 | @(cd $(1) && uv run ruff check --fix $(2)) 66 | echo "Running pyright for $(2)..." 
67 | @(cd $(1) && uv run pyright $(2)) 68 | endef 69 | 70 | define lint-validation-common 71 | echo "Running lint validation for $(2)..." 72 | @(cd $(1) && uv run ruff format --check $(2)) 73 | @(cd $(1) && uv run ruff check $(2)) 74 | echo "Running mypy for $(2)..." 75 | @(cd $(1) && uv run mypy $(2)) 76 | echo "Running pyright for $(2)..." 77 | @(cd $(1) && uv run pyright $(2)) 78 | endef 79 | 80 | 81 | # Function to wait for PostgreSQL to be ready 82 | define wait-for-postgres 83 | @echo "Waiting for PostgreSQL to be ready..." 84 | @timeout=$(1); \ 85 | while ! nc -z localhost $(2) >/dev/null 2>&1; do \ 86 | timeout=$$((timeout-1)); \ 87 | if [ $$timeout -le 0 ]; then \ 88 | echo "Timed out waiting for PostgreSQL to start on port $(2)"; \ 89 | exit 1; \ 90 | fi; \ 91 | echo "Waiting for PostgreSQL to start..."; \ 92 | sleep 1; \ 93 | done; \ 94 | echo "PostgreSQL is ready on port $(2)." 95 | endef 96 | -------------------------------------------------------------------------------- /iceaxe/mountaineer/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Alternative entrypoints to migrations/cli that use Mountaineer configurations 3 | to simplify the setup of the database connection. 4 | 5 | """ 6 | 7 | from mountaineer import ConfigBase, CoreDependencies, Depends 8 | from mountaineer.dependencies import get_function_dependencies 9 | 10 | from iceaxe.migrations.cli import handle_apply, handle_generate, handle_rollback 11 | from iceaxe.mountaineer.config import DatabaseConfig 12 | from iceaxe.mountaineer.dependencies.core import get_db_connection 13 | from iceaxe.session import DBConnection 14 | 15 | 16 | async def generate_migration(message: str | None = None): 17 | async def _inner( 18 | db_config: DatabaseConfig = Depends( 19 | CoreDependencies.get_config_with_type(DatabaseConfig) 20 | ), 21 | core_config: ConfigBase = Depends( 22 | CoreDependencies.get_config_with_type(ConfigBase) 23 | ), 24 | db_connection: DBConnection = Depends(get_db_connection), 25 | ): 26 | if not core_config.PACKAGE: 27 | raise ValueError("No package provided in the configuration") 28 | 29 | await handle_generate( 30 | package=core_config.PACKAGE, 31 | db_connection=db_connection, 32 | message=message, 33 | ) 34 | 35 | async with get_function_dependencies(callable=_inner) as values: 36 | await _inner(**values) 37 | 38 | 39 | async def apply_migration(): 40 | async def _inner( 41 | core_config: ConfigBase = Depends( 42 | CoreDependencies.get_config_with_type(ConfigBase) 43 | ), 44 | db_connection: DBConnection = Depends(get_db_connection), 45 | ): 46 | if not core_config.PACKAGE: 47 | raise ValueError("No package provided in the configuration") 48 | 49 | await handle_apply( 50 | package=core_config.PACKAGE, 51 | db_connection=db_connection, 52 | ) 53 | 54 | async with get_function_dependencies(callable=_inner) as values: 55 | await _inner(**values) 56 | 57 | 58 | async def rollback_migration(): 59 | async def _inner( 60 | core_config: ConfigBase = Depends( 61 | CoreDependencies.get_config_with_type(ConfigBase) 62 | ), 63 | db_connection: DBConnection = Depends(get_db_connection), 64 | ): 65 | if not core_config.PACKAGE: 66 | raise ValueError("No package provided in the configuration") 67 | 68 | await handle_rollback( 69 | package=core_config.PACKAGE, 70 | db_connection=db_connection, 71 | ) 72 | 73 | async with get_function_dependencies(callable=_inner) as values: 74 | await _inner(**values) 75 | 
-------------------------------------------------------------------------------- /iceaxe/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from json import dumps as json_dumps 4 | from logging import Formatter, StreamHandler, getLogger 5 | from os import environ 6 | from time import monotonic_ns 7 | 8 | from click import secho 9 | from rich.console import Console 10 | 11 | VERBOSITY_MAPPING = { 12 | "INFO": logging.INFO, 13 | "DEBUG": logging.DEBUG, 14 | "WARNING": logging.WARNING, 15 | "ERROR": logging.ERROR, 16 | } 17 | 18 | 19 | class JsonFormatter(Formatter): 20 | def format(self, record): 21 | log_record = { 22 | "level": record.levelname, 23 | "name": record.name, 24 | "timestamp": self.formatTime(record, self.datefmt), 25 | "message": record.getMessage(), 26 | } 27 | if record.exc_info: 28 | log_record["exception"] = self.formatException(record.exc_info) 29 | return json_dumps(log_record) 30 | 31 | 32 | class ColorHandler(StreamHandler): 33 | def emit(self, record): 34 | try: 35 | msg = self.format(record) 36 | if record.levelno == logging.WARNING: 37 | secho(msg, fg="yellow") 38 | elif record.levelno >= logging.ERROR: 39 | secho(msg, fg="red") 40 | else: 41 | secho(msg) 42 | except Exception: 43 | self.handleError(record) 44 | 45 | 46 | def setup_logger(name, log_level=logging.DEBUG): 47 | """ 48 | Constructor for the main logger used by Mountaineer. Provided 49 | convenient defaults for log level and formatting, alongside coloring 50 | of stdout/stderr messages and JSON fields for structured parsing. 51 | 52 | """ 53 | logger = getLogger(name) 54 | logger.setLevel(log_level) 55 | 56 | # Create a handler that writes log records to the standard error 57 | handler = ColorHandler() 58 | handler.setLevel(log_level) 59 | 60 | formatter = JsonFormatter() 61 | handler.setFormatter(formatter) 62 | 63 | logger.addHandler(handler) 64 | 65 | return logger 66 | 67 | 68 | @contextmanager 69 | def log_time_duration(message: str): 70 | """ 71 | Context manager to time a code block at runtime. 
72 | 73 | ```python 74 | with log_time_duration("Long computation"): 75 | # Simulate work 76 | sleep(10) 77 | ``` 78 | 79 | """ 80 | start = monotonic_ns() 81 | yield 82 | LOGGER.debug(f"{message} : Took {(monotonic_ns() - start) / 1e9:.2f}s") 83 | 84 | 85 | # Our global logger should only surface warnings and above by default 86 | LOGGER = setup_logger( 87 | __name__, 88 | log_level=VERBOSITY_MAPPING[environ.get("ICEAXE_LOG_LEVEL", "WARNING")], 89 | ) 90 | 91 | CONSOLE = Console() 92 | -------------------------------------------------------------------------------- /iceaxe/__tests__/migrations/test_generics.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Type, Union 3 | 4 | import pytest 5 | 6 | from iceaxe.generics import ( 7 | _is_type_compatible, 8 | remove_null_type, 9 | ) 10 | 11 | 12 | class SomeEnum(Enum): 13 | A = "A" 14 | 15 | 16 | class SomeSuperClass: 17 | pass 18 | 19 | 20 | class SomeSubClass(SomeSuperClass): 21 | pass 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "obj_type, target_type, expected", 26 | [ 27 | # Basic types 28 | (int, int, True), 29 | (str, str, True), 30 | (int, str, False), 31 | (int, float, False), 32 | # Subclasses 33 | (bool, int, True), 34 | (int, object, True), 35 | (SomeSubClass, SomeSuperClass, True), 36 | # Instance can match classes 37 | (SomeSubClass(), SomeSuperClass, True), 38 | ([SomeSubClass], list[SomeSuperClass], True), 39 | # Enums 40 | (SomeEnum, Type[Enum], True), 41 | # Unions with new syntax 42 | (int, Union[int, str], True), 43 | (str, Union[int, str], True), 44 | (float, Union[int, str], False), 45 | # Unions with old syntax using type hints 46 | (int, Union[int, str], True), 47 | (str, Union[int, str], True), 48 | (float, Union[int, str], False), 49 | # Complex types involving collections 50 | (list[int], list[int], True), 51 | (list[int], list[str], False), 52 | (dict[str, int], dict[str, int], True), 53 | (dict[str, int], dict[str, str], False), 54 | # More complex union cases 55 | (list[int], Union[list[int], dict[str, str]], True), 56 | (dict[str, str], Union[list[int], dict[str, str]], True), 57 | (dict[str, float], Union[list[int], dict[str, str]], False), 58 | (dict[str, str], Union[list[int], dict[str, Any]], True), 59 | # Pipe operator if Python >= 3.10 60 | (int, int | str, True), 61 | (int | str, int | str | float, True), 62 | (str, int | str, True), 63 | (float, int | str, False), 64 | # Nested unions 65 | (int, list[int | float | str] | int | float | str, True), 66 | # Optional types 67 | (None, int | str | None, True), 68 | (int, int | str | None, True), 69 | # Value evaluation for sequences 70 | ([1, 2, 3], list[int], True), 71 | ([1, 2, "3"], list[int], False), 72 | ], 73 | ) 74 | def test_is_type_compatible(obj_type: Any, target_type: Any, expected: bool): 75 | raw_result = _is_type_compatible(obj_type, target_type) 76 | bool_result = raw_result != float("inf") 77 | assert bool_result == expected 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "typehint, expected", 82 | [ 83 | (int, int), 84 | (str, str), 85 | (int | None, int), 86 | (str | None, str), 87 | (Union[int, None], int), 88 | ], 89 | ) 90 | def test_remove_null_type(typehint: Any, expected: bool): 91 | assert remove_null_type(typehint) == expected 92 | -------------------------------------------------------------------------------- /iceaxe/__tests__/test_alias.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 
from iceaxe import alias, func, select, sql 4 | from iceaxe.__tests__.conf_models import UserDemo 5 | from iceaxe.queries import QueryBuilder 6 | from iceaxe.session import ( 7 | DBConnection, 8 | ) 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_alias_with_function(db_connection: DBConnection): 13 | """Test using alias with a function value.""" 14 | demo = UserDemo(id=1, name="Test Title", email="john@example.com") 15 | await db_connection.insert([demo]) 16 | 17 | # Test using string length function with alias 18 | results = await db_connection.exec( 19 | select((UserDemo, alias("name_length", func.length(UserDemo.name)))) 20 | ) 21 | 22 | assert len(results) == 1 23 | assert results[0][0].id == 1 24 | assert isinstance(results[0][1], int) # length result 25 | 26 | 27 | @pytest.mark.asyncio 28 | async def test_alias_with_raw_sql(db_connection: DBConnection): 29 | """ 30 | Test that we can use a text query alongside an alias to map raw SQL results 31 | to typed values. 32 | """ 33 | user = UserDemo(name="John Doe", email="john@example.com") 34 | await db_connection.insert([user]) 35 | 36 | # Create a query that uses text() alongside an alias 37 | query = ( 38 | QueryBuilder() 39 | .select((UserDemo, alias("rollup_value", int))) 40 | .text( 41 | f""" 42 | SELECT {sql.select(UserDemo)}, COUNT(*) AS rollup_value 43 | FROM userdemo 44 | GROUP BY id 45 | """ 46 | ) 47 | ) 48 | result = await db_connection.exec(query) 49 | assert len(result) == 1 50 | assert isinstance(result[0], tuple) 51 | assert isinstance(result[0][0], UserDemo) 52 | assert result[0][0].name == "John Doe" 53 | assert result[0][0].email == "john@example.com" 54 | assert result[0][1] == 1 # The count should be 1 55 | 56 | 57 | @pytest.mark.asyncio 58 | async def test_multiple_aliases(db_connection: DBConnection): 59 | """Test using multiple aliases in a single query.""" 60 | demo1 = UserDemo(id=1, name="First Item", email="john@example.com") 61 | demo2 = UserDemo(id=2, name="Second Item", email="jane@example.com") 62 | await db_connection.insert([demo1, demo2]) 63 | 64 | # Test multiple aliases with different SQL functions 65 | results = await db_connection.exec( 66 | select( 67 | ( 68 | UserDemo, 69 | alias("upper_name", func.upper(UserDemo.name)), 70 | alias("item_count", func.count(UserDemo.id)), 71 | ) 72 | ) 73 | .group_by(UserDemo.id, UserDemo.name) 74 | .order_by(UserDemo.id) 75 | ) 76 | 77 | assert len(results) == 2 78 | assert results[0][0].id == 1 79 | assert results[0][1] == "FIRST ITEM" # uppercase name 80 | assert results[0][2] == 1 # count result 81 | assert results[1][0].id == 2 82 | assert results[1][1] == "SECOND ITEM" # uppercase name 83 | assert results[1][2] == 1 # count result 84 | -------------------------------------------------------------------------------- /iceaxe/migrations/migration.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from contextlib import asynccontextmanager 3 | 4 | from iceaxe.migrations.migrator import Migrator 5 | from iceaxe.session import DBConnection 6 | 7 | 8 | class MigrationRevisionBase: 9 | """ 10 | Base class for all revisions. This class is most often automatically 11 | generated by the `migrate` CLI command, which automatically determines 12 | the proper up and down revisions for your migration class. 13 | 14 | Once added to your project, you can modify the up/down migration methods 15 | however you see fit. 
16 | 17 | """ 18 | 19 | # up and down revision are both set, except for the initial revision 20 | # where down_revision is None 21 | up_revision: str 22 | down_revision: str | None 23 | 24 | use_transaction: bool = True 25 | """ 26 | Disables the transaction for the current migration. Only do this if you're 27 | confident that the migration will succeed on the first try, or is otherwise 28 | independent so it can be run multiple times. 29 | 30 | This can speed up migrations, and in some cases might be even fully required for your 31 | production database to avoid deadlocks when interacting with hot tables. 32 | 33 | """ 34 | 35 | async def _handle_up(self, db_connection: DBConnection): 36 | """ 37 | Internal method to handle the up migration. 38 | """ 39 | # Isolated migrator context just for this migration 40 | async with self._optional_transaction(db_connection): 41 | migrator = Migrator(db_connection) 42 | await self.up(migrator) 43 | await migrator.set_active_revision(self.up_revision) 44 | 45 | async def _handle_down(self, db_connection: DBConnection): 46 | """ 47 | Internal method to handle the down migration. 48 | """ 49 | async with self._optional_transaction(db_connection): 50 | migrator = Migrator(db_connection) 51 | await self.down(migrator) 52 | await migrator.set_active_revision(self.down_revision) 53 | 54 | @asynccontextmanager 55 | async def _optional_transaction(self, db_connection: DBConnection): 56 | if self.use_transaction: 57 | async with db_connection.transaction(): 58 | yield 59 | else: 60 | yield 61 | 62 | @abstractmethod 63 | async def up(self, migrator: Migrator): 64 | """ 65 | Perform the migration "up" action. This converts your old database schema to the new 66 | schema that's found in your code definition. Add any other migration rules to your 67 | data in this function as well. 68 | 69 | It's good practice to make sure any up migrations don't immediately drop data. Instead, 70 | consider moving to a temporary table. 71 | 72 | Support both raw SQL execution and helper functions to manipulate the tables and 73 | columns that you have defined. See the Migrator class for more information. 74 | 75 | """ 76 | pass 77 | 78 | @abstractmethod 79 | async def down(self, migrator: Migrator): 80 | """ 81 | Perform the migration "down" action. This converts the current database state to the 82 | previous snapshot. It should be used in any automatic rollback pipelines if you 83 | had a feature that was rolled out and then rolled back. 84 | 85 | """ 86 | pass 87 | -------------------------------------------------------------------------------- /docs/guides/tables/page.mdx: -------------------------------------------------------------------------------- 1 | # Table Definition 2 | 3 | We define tables in Iceaxe using the `TableBase` class. You define the column 4 | type by attaching a Python typehint to the class attributes. You can further configure 5 | these columns by customizing a `Field`. 6 | 7 | ```python 8 | from iceaxe import TableBase, Field, PostgresDateTime, DBConnection 9 | 10 | class Employee(TableBase): 11 | id: int | None = Field(primary_key=True, default=None) 12 | name: str 13 | age: int 14 | payload: datetime = Field(postgres_config=PostgresDateTime(timezone=True)) 15 | ``` 16 | 17 | Right now this table just exists in Python. To create it in the database, we need to 18 | issue the instructions to create the table. You can manually write this SQL or rely 19 | on our harness to generate it for you. 
20 | 21 | ```python 22 | from iceaxe.schemas.cli import create_all 23 | 24 | conn = DBConnection( 25 | await asyncpg.connect( 26 | host="localhost", 27 | port=5432, 28 | user="db_user", 29 | password="yoursecretpassword", 30 | database="your_db", 31 | ) 32 | ) 33 | await create_all(conn) 34 | ``` 35 | 36 | By default, `create_all` will create all tables that are subclasses of `TableBase` that 37 | are currently imported into your Python environment. Make sure all the models that you 38 | use are imported into thet namespace before you call `create_all`. One convenient convention 39 | to follow is to define your schemas in a `models` folder and then re-export them from 40 | the `models/__init__.py` file. 41 | 42 | ``` 43 | myproject/ 44 | ├── __init__.py 45 | └── models/ 46 | ├── __init__.py 47 | ├── employee.py 48 | └── office.py 49 | ``` 50 | 51 | ```python 52 | # models/__init__.py 53 | from .employee import Employee as Employee 54 | from .office import Office as Office 55 | ``` 56 | 57 | This allows you to import the one models module and have all of your models available 58 | with that one import: 59 | 60 | ```python 61 | from myproject import models 62 | 63 | print(models.Employee) 64 | print(models.Office) 65 | 66 | await create_all(conn) 67 | ``` 68 | 69 | You can also limit the models that are created by passing a list of explicit models to 70 | the `models` argument: 71 | 72 | ```python 73 | await create_all(conn, models=[DemoCustomModel]) 74 | ``` 75 | 76 | ## Fields 77 | 78 | This TableBase is a subclass of Pydantic's `BaseModel`, which means that 79 | you can use Pydantic's validation and serialization features on your 80 | tables in addition to our database specific flags. You can read more 81 | about the Pydantic features in the [Pydantic documentation](https://docs.pydantic.dev/latest/concepts/fields/). We 82 | define our full database options in the Iceaxe [API docs](/iceaxe/api/fields). 83 | 84 | To use one example, let's say that you want to validate that the age is 85 | at least 16. 
Add a new field validator for that particular field: 86 | 87 | ```python 88 | from iceaxe import TableBase, Field 89 | from pydantic import field_validator 90 | 91 | class Employee(TableBase): 92 | id: int | None = Field(primary_key=True, default=None) 93 | name: str 94 | age: int 95 | 96 | @field_validator("age") 97 | def validate_age(cls, v): 98 | if v < 16: 99 | raise ValueError("You must be at least 16 to receive a tax return") 100 | return v 101 | 102 | # Successful validation 103 | person_1 = Person(name="John Doe", age=16) 104 | print(person_1) 105 | 106 | # Raises a pydantic_core._pydantic_core.ValidationError 107 | person_2 = Person(name="John Doe", age=15) 108 | print(person_2) 109 | ``` 110 | -------------------------------------------------------------------------------- /iceaxe/__tests__/conf_models.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from enum import StrEnum 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | from pyinstrument import Profiler 7 | 8 | from iceaxe.base import Field, TableBase, UniqueConstraint 9 | 10 | 11 | class UserDemo(TableBase): 12 | id: int = Field(primary_key=True, default=None) 13 | name: str 14 | email: str 15 | 16 | 17 | class ArtifactDemo(TableBase): 18 | id: int = Field(primary_key=True, default=None) 19 | title: str 20 | user_id: int = Field(foreign_key="userdemo.id") 21 | 22 | 23 | class ComplexDemo(TableBase): 24 | id: int = Field(primary_key=True, default=None) 25 | string_list: list[str] 26 | json_data: dict[str, str] = Field(is_json=True) 27 | 28 | 29 | class Employee(TableBase): 30 | id: int = Field(primary_key=True, default=None) 31 | email: str = Field(unique=True) 32 | first_name: str 33 | last_name: str 34 | department: str 35 | salary: float 36 | 37 | 38 | class Department(TableBase): 39 | id: int = Field(primary_key=True, default=None) 40 | name: str = Field(unique=True) 41 | budget: float 42 | location: str 43 | 44 | 45 | class ProjectAssignment(TableBase): 46 | id: int = Field(primary_key=True, default=None) 47 | employee_id: int = Field(foreign_key="employee.id") 48 | project_name: str 49 | role: str 50 | start_date: str 51 | 52 | 53 | class EmployeeStatus(StrEnum): 54 | ACTIVE = "active" 55 | INACTIVE = "inactive" 56 | ON_LEAVE = "on_leave" 57 | 58 | 59 | class EmployeeMetadata(TableBase): 60 | id: int = Field(primary_key=True, default=None) 61 | employee_id: int = Field(foreign_key="employee.id") 62 | status: EmployeeStatus 63 | tags: list[str] = Field(is_json=True) 64 | additional_info: dict[str, Any] = Field(is_json=True) 65 | 66 | 67 | class FunctionDemoModel(TableBase): 68 | id: int = Field(primary_key=True, default=None) 69 | balance: float 70 | created_at: str 71 | birth_date: str 72 | start_date: str 73 | end_date: str 74 | year: int 75 | month: int 76 | day: int 77 | hour: int 78 | minute: int 79 | second: int 80 | years: int 81 | months: int 82 | days: int 83 | weeks: int 84 | hours: int 85 | minutes: int 86 | seconds: int 87 | name: str 88 | balance_str: str 89 | timestamp_str: str 90 | 91 | 92 | class DemoModelA(TableBase): 93 | id: int = Field(primary_key=True, default=None) 94 | name: str 95 | description: str 96 | code: str = Field(unique=True) 97 | 98 | 99 | class DemoModelB(TableBase): 100 | id: int = Field(primary_key=True, default=None) 101 | name: str 102 | category: str 103 | code: str = Field(unique=True) 104 | 105 | 106 | class JsonDemo(TableBase): 107 | """ 108 | Model for testing JSON field updates. 
109 | """ 110 | 111 | id: int | None = Field(primary_key=True, default=None) 112 | settings: dict[Any, Any] = Field(is_json=True) 113 | metadata: dict[Any, Any] | None = Field(is_json=True) 114 | unique_val: str 115 | 116 | table_args = [UniqueConstraint(columns=["unique_val"])] 117 | 118 | 119 | @contextmanager 120 | def run_profile(request): 121 | TESTS_ROOT = Path.cwd() 122 | PROFILE_ROOT = TESTS_ROOT / ".profiles" 123 | 124 | # Turn profiling on 125 | profiler = Profiler() 126 | profiler.start() 127 | 128 | yield # Run test 129 | 130 | profiler.stop() 131 | PROFILE_ROOT.mkdir(exist_ok=True) 132 | results_file = PROFILE_ROOT / f"{request.node.name}.html" 133 | profiler.write_html(results_file) 134 | -------------------------------------------------------------------------------- /iceaxe/migrations/action_sorter.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from iceaxe.schemas.db_stubs import DBObject 4 | 5 | 6 | class ActionTopologicalSorter: 7 | """ 8 | Extends Python's native TopologicalSorter to group nodes by table_name. This provides 9 | better semantic grouping within the migrations since most actions are oriented by their 10 | parent table. 11 | 12 | - Places cross-table dependencies and non-table actions (like types) according 13 | to their default DAG order. 14 | - Tables are processed in alphabetical order of their names. 15 | 16 | """ 17 | 18 | def __init__(self, graph: dict[DBObject, list[DBObject]]): 19 | self.graph = graph 20 | self.in_degree = defaultdict(int) 21 | self.nodes = set(graph.keys()) 22 | 23 | for node, dependencies in list(graph.items()): 24 | for dep in dependencies: 25 | self.in_degree[node] += 1 26 | if dep not in self.nodes: 27 | self.nodes.add(dep) 28 | self.graph[dep] = [] 29 | 30 | # Order based on the original yield / creation order of the nodes 31 | self.node_to_ordering = {node: i for i, node in enumerate(self.graph.keys())} 32 | 33 | def sort(self): 34 | result = [] 35 | root_nodes_queued = sorted( 36 | [node for node in self.nodes if self.in_degree[node] == 0], 37 | key=self.node_key, 38 | ) 39 | if not root_nodes_queued and self.nodes: 40 | raise ValueError("Graph contains a cycle") 41 | elif not root_nodes_queued: 42 | return [] 43 | 44 | # Sort by the table name and then by the node representation 45 | root_nodes_queued.sort(key=self.node_key) 46 | 47 | # Always put the non-table actions first (things like global types) 48 | non_table_nodes = [ 49 | node for node in root_nodes_queued if not hasattr(node, "table_name") 50 | ] 51 | root_nodes_queued = [ 52 | node for node in root_nodes_queued if node not in non_table_nodes 53 | ] 54 | root_nodes_queued = non_table_nodes + root_nodes_queued 55 | 56 | queue = [root_nodes_queued.pop(0)] 57 | processed = set() 58 | 59 | while True: 60 | if not queue: 61 | # Pop another root node, if available 62 | if root_nodes_queued: 63 | queue.append(root_nodes_queued.pop(0)) 64 | continue 65 | else: 66 | # If no more root nodes are available and no more work to be done 67 | break 68 | 69 | current_node = queue.pop(0) 70 | 71 | result.append(current_node) 72 | processed.add(current_node) 73 | 74 | # Newly unblocked nodes, since we've resolved their dependencies 75 | # with the current processing 76 | new_ready = [] 77 | for dependent, deps in self.graph.items(): 78 | if current_node in deps and dependent not in processed: 79 | self.in_degree[dependent] -= 1 80 | if self.in_degree[dependent] == 0: 81 | new_ready.append(dependent) 
82 | 83 | # Add newly ready nodes to queue in sorted order 84 | queue.extend( 85 | sorted(new_ready, key=lambda node: self.node_to_ordering[node]) 86 | ) 87 | 88 | if len(result) != len(self.nodes): 89 | raise ValueError("Graph contains a cycle") 90 | 91 | return result 92 | 93 | @staticmethod 94 | def node_key(node: DBObject): 95 | # Not all objects specify a table_name, but if they do we want to explicitly 96 | # sort before the representation 97 | table_name = getattr(node, "table_name", "") 98 | return (table_name, node.representation()) 99 | -------------------------------------------------------------------------------- /iceaxe/migrations/migrator.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from iceaxe.logging import LOGGER 4 | from iceaxe.schemas.actions import DatabaseActions 5 | from iceaxe.session import DBConnection 6 | 7 | 8 | class Migrator: 9 | """ 10 | Main interface for client migrations. Mountaineer provides a simple shim on top of 11 | common database migration options within `migrator.actor`. This lets you add columns, 12 | drop columns, migrate types, and the like. For more complex migrations, you can use 13 | the `migrator.db_session` to run raw SQL queries within the current migration transaction. 14 | 15 | """ 16 | 17 | actor: DatabaseActions 18 | """ 19 | The main interface for client migrations. Add tables, columns, and more using this wrapper. 20 | """ 21 | 22 | db_connection: DBConnection 23 | """ 24 | The main database connection for the migration. Use this to run raw SQL queries. We auto-wrap 25 | this connection in a transaction block for you, so successful migrations will be 26 | automatically committed when completed and unsuccessful migrations will be rolled back. 27 | 28 | """ 29 | 30 | def __init__(self, db_connection: DBConnection): 31 | self.actor = DatabaseActions(dry_run=False, db_connection=db_connection) 32 | self.db_connection = db_connection 33 | 34 | async def init_db(self): 35 | """ 36 | Initialize our migration management table if it doesn't already exist 37 | within the attached postgres database. This will be a no-op if the table 38 | already exists. 39 | 40 | Client callers should call this method once before running any migrations. 41 | 42 | """ 43 | # Create the table if it doesn't exist 44 | await self.db_connection.conn.execute( 45 | """ 46 | CREATE TABLE IF NOT EXISTS migration_info ( 47 | active_revision VARCHAR(255) 48 | ) 49 | """ 50 | ) 51 | 52 | # Check if the table is empty and insert a default value if necessary 53 | rows = await self.db_connection.conn.fetch( 54 | "SELECT COUNT(*) AS migration_count FROM migration_info" 55 | ) 56 | count = rows[0]["migration_count"] if rows else 0 57 | if count == 0: 58 | await self.db_connection.conn.execute( 59 | "INSERT INTO migration_info (active_revision) VALUES (NULL)" 60 | ) 61 | 62 | async def set_active_revision(self, value: str | None): 63 | """ 64 | Sets the active revision in the migration_info table. 65 | 66 | """ 67 | LOGGER.info(f"Setting active revision to {value}") 68 | 69 | query = """ 70 | UPDATE migration_info SET active_revision = $1 71 | """ 72 | 73 | await self.db_connection.conn.execute(query, value) 74 | 75 | LOGGER.info("Active revision set") 76 | 77 | async def get_active_revision(self) -> str | None: 78 | """ 79 | Gets the active revision from the migration_info table. 80 | Requires that the migration_info table has been initialized. 
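A minimal usage sketch:

```python {{sticky: True}}
revision = await migrator.get_active_revision()
if revision is None:
    # No migration has been applied yet
    ...
```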
81 | 82 | """ 83 | query = """ 84 | SELECT active_revision FROM migration_info 85 | """ 86 | 87 | result = await self.db_connection.conn.fetch(query) 88 | return cast(str | None, result[0]["active_revision"] if result else None) 89 | 90 | async def raw_sql(self, query: str, *args): 91 | """ 92 | Shortcut to execute a raw SQL query against the database. Raw SQL can be more useful 93 | than using ORM objects within migrations, because you can interact with the old & new data 94 | schemas via text (whereas the runtime ORM is only aware of the current schema). 95 | 96 | ```python {{sticky: True}} 97 | await migrator.execute("CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name VARCHAR(255))") 98 | ``` 99 | 100 | """ 101 | await self.db_connection.conn.execute(query, *args) 102 | -------------------------------------------------------------------------------- /.github/scripts/update_version.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.13" 3 | # dependencies = [ 4 | # "packaging", 5 | # "toml", 6 | # ] 7 | # /// 8 | 9 | import sys 10 | from pathlib import Path 11 | from sys import stdout 12 | 13 | import toml 14 | from packaging.version import parse 15 | 16 | 17 | def update_version_rust(new_version: str): 18 | cargo_path = Path("Cargo.toml") 19 | if not cargo_path.exists(): 20 | stdout.write("Cargo.toml not found, skipping version update") 21 | return 22 | 23 | filedata = toml.loads(cargo_path.read_text()) 24 | 25 | # If the new version is a pre-release version, we need to reformat it 26 | # to align with Cargo standards 27 | # pip format uses "0.1.0.dev1" while Cargo uses "0.1.0-dev1" 28 | cargo_version = format_cargo_version(new_version) 29 | 30 | if "package" not in filedata: 31 | raise ValueError("Cargo.toml is missing the [package] section") 32 | 33 | filedata["package"]["version"] = cargo_version 34 | 35 | cargo_path.write_text(toml.dumps(filedata)) 36 | 37 | 38 | def format_cargo_version(new_version: str) -> str: 39 | parsed_version = parse(new_version) 40 | 41 | cargo_version = ( 42 | f"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}" 43 | ) 44 | if parsed_version.is_prerelease and parsed_version.pre is not None: 45 | pre_release = ".".join(str(x) for x in parsed_version.pre) 46 | cargo_version += f"-{pre_release.replace('.', '')}" 47 | if parsed_version.is_postrelease and parsed_version.post is not None: 48 | cargo_version += f"-post{parsed_version.post}" 49 | if parsed_version.is_devrelease and parsed_version.dev is not None: 50 | cargo_version += f"-dev{parsed_version.dev}" 51 | 52 | return cargo_version 53 | 54 | 55 | def update_version_python(new_version: str): 56 | pyproject_path = Path("pyproject.toml") 57 | if not pyproject_path.exists(): 58 | print("pyproject.toml not found, skipping version update") # noqa: T201 59 | return 60 | 61 | filedata = toml.loads(pyproject_path.read_text()) 62 | 63 | # Parse the new version to ensure it's valid and potentially reformat 64 | python_version = format_python_version(new_version) 65 | 66 | updated = False 67 | 68 | # Update poetry version if it exists 69 | if "tool" in filedata and "poetry" in filedata["tool"]: 70 | filedata["tool"]["poetry"]["version"] = python_version 71 | updated = True 72 | 73 | # Update project.version if it exists 74 | if "project" in filedata and "version" in filedata["project"]: 75 | filedata["project"]["version"] = python_version 76 | updated = True 77 | 78 | if not updated: 79 | print( # noqa: T201 
80 | "Warning: Neither [tool.poetry] nor [project] sections found in pyproject.toml" 81 | ) 82 | return 83 | 84 | pyproject_path.write_text(toml.dumps(filedata)) 85 | 86 | 87 | def format_python_version(new_version: str) -> str: 88 | parsed_version = parse(new_version) 89 | # Assuming semantic versioning, format it as needed 90 | python_version = ( 91 | f"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}" 92 | ) 93 | if parsed_version.is_prerelease and parsed_version.pre is not None: 94 | pre_release = ".".join(str(x) for x in parsed_version.pre) 95 | python_version += ( 96 | f".dev{parsed_version.pre[-1]}" 97 | if pre_release.startswith("dev") 98 | else f"b{parsed_version.pre[-1]}" 99 | ) 100 | if parsed_version.is_postrelease and parsed_version.post is not None: 101 | python_version += f".post{parsed_version.post}" 102 | if parsed_version.is_devrelease and parsed_version.dev is not None: 103 | python_version += f".dev{parsed_version.dev}" 104 | return python_version 105 | 106 | 107 | if __name__ == "__main__": 108 | if len(sys.argv) != 2: 109 | stdout.write("Usage: python update_version.py ") 110 | sys.exit(1) 111 | new_version = sys.argv[1].lstrip("v") 112 | update_version_rust(new_version) 113 | update_version_python(new_version) 114 | stdout.write(f"Updated version to: {new_version}") 115 | 116 | -------------------------------------------------------------------------------- /iceaxe/io.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import importlib.metadata 3 | from functools import lru_cache, wraps 4 | from json import loads as json_loads 5 | from pathlib import Path 6 | from re import search as re_search 7 | from typing import Any, Callable, Coroutine, TypeVar 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def lru_cache_async( 13 | maxsize: int | None = 100, 14 | ): 15 | def decorator( 16 | async_function: Callable[..., Coroutine[Any, Any, T]], 17 | ): 18 | @lru_cache(maxsize=maxsize) 19 | @wraps(async_function) 20 | def internal(*args, **kwargs): 21 | coroutine = async_function(*args, **kwargs) 22 | # Unlike regular coroutine functions, futures can be awaited multiple times 23 | # so our caller functions can await the same future on multiple cache hits 24 | return asyncio.ensure_future(coroutine) 25 | 26 | return internal 27 | 28 | return decorator 29 | 30 | 31 | def resolve_package_path(package_name: str): 32 | """ 33 | Given a package distribution, returns the local file directory where the code 34 | is located. This is resolved to the original reference if installed with `-e` 35 | otherwise is a copy of the package. 36 | 37 | NOTE: Copied from Mountaineer, refactor to a shared library. 38 | 39 | """ 40 | dist = importlib.metadata.distribution(package_name) 41 | 42 | def normalize_package(package: str): 43 | return package.replace("-", "_").lower() 44 | 45 | # Recent versions of poetry install development packages (-e .) as direct URLs 46 | # https://the-hitchhikers-guide-to-packaging.readthedocs.io/en/latest/introduction.html 47 | # "Path configuration files have an extension of .pth, and each line must 48 | # contain a single path that will be appended to sys.path." 
49 | package_name = normalize_package(dist.name) 50 | symbolic_links = [ 51 | path 52 | for path in (dist.files or []) 53 | if path.name.lower() == f"{package_name}.pth" 54 | ] 55 | dist_links = [ 56 | path 57 | for path in (dist.files or []) 58 | if path.name == "direct_url.json" 59 | and re_search(package_name + r"-[0-9-.]+\.dist-info", path.parent.name.lower()) 60 | ] 61 | explicit_links = [ 62 | path 63 | for path in (dist.files or []) 64 | if path.parent.name.lower() == package_name 65 | and ( 66 | # Sanity check that the parent is the high level project directory 67 | # by looking for common base files 68 | path.name == "__init__.py" 69 | ) 70 | ] 71 | 72 | # The user installed code as an absolute package (ie. with pip install .) instead of 73 | # as a reference. There's no need to sniff for the additional package path since 74 | # we've already found it 75 | if explicit_links: 76 | # Right now we have a file pointer to __init__.py. Go up one level 77 | # to the main package directory to return a directory 78 | explicit_link = explicit_links[0] 79 | return Path(str(dist.locate_file(explicit_link.parent))) 80 | 81 | # Raw path will capture the path to the pyproject.toml file directory, 82 | # not the actual package code directory 83 | # Find the root, then resolve the package directory 84 | raw_path: Path | None = None 85 | 86 | if symbolic_links: 87 | direct_url_path = symbolic_links[0] 88 | raw_path = Path(str(dist.locate_file(direct_url_path.read_text().strip()))) 89 | elif dist_links: 90 | dist_link = dist_links[0] 91 | direct_metadata = json_loads(dist_link.read_text()) 92 | package_path = "/" + direct_metadata["url"].lstrip("file://").lstrip("/") 93 | raw_path = Path(str(dist.locate_file(package_path))) 94 | 95 | if not raw_path: 96 | raise ValueError( 97 | f"Could not find a valid path for package {dist.name}, found files: {dist.files}" 98 | ) 99 | 100 | # Sniff for the presence of the code directory 101 | for path in raw_path.iterdir(): 102 | if path.is_dir() and normalize_package(path.name) == package_name: 103 | return path 104 | 105 | raise ValueError( 106 | f"No matching package found in root path: {raw_path} {list(raw_path.iterdir())}" 107 | ) 108 | -------------------------------------------------------------------------------- /iceaxe/__tests__/benchmarks/test_select.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from time import monotonic_ns 3 | from typing import Any 4 | 5 | import asyncpg 6 | import pytest 7 | 8 | from iceaxe.__tests__.conf_models import UserDemo, run_profile 9 | from iceaxe.logging import CONSOLE, LOGGER 10 | from iceaxe.queries import QueryBuilder 11 | from iceaxe.session import DBConnection 12 | 13 | 14 | class FetchType(Enum): 15 | ID = "id" 16 | OBJ = "obj" 17 | 18 | 19 | async def insert_users(conn: asyncpg.Connection, num_users: int): 20 | users = [(f"User {i}", f"user{i}@example.com") for i in range(num_users)] 21 | await conn.executemany("INSERT INTO userdemo (name, email) VALUES ($1, $2)", users) 22 | 23 | 24 | async def fetch_users_raw(conn: asyncpg.Connection, fetch_type: FetchType) -> list[Any]: 25 | if fetch_type == FetchType.OBJ: 26 | return await conn.fetch("SELECT * FROM userdemo") # type: ignore 27 | elif fetch_type == FetchType.ID: 28 | return await conn.fetch("SELECT id FROM userdemo") # type: ignore 29 | else: 30 | raise ValueError(f"Invalid run profile: {fetch_type}") 31 | 32 | 33 | def build_iceaxe_query(fetch_type: FetchType): 34 | if fetch_type == 
FetchType.OBJ: 35 | return QueryBuilder().select(UserDemo) 36 | elif fetch_type == FetchType.ID: 37 | return QueryBuilder().select(UserDemo.id) 38 | else: 39 | raise ValueError(f"Invalid run profile: {fetch_type}") 40 | 41 | 42 | @pytest.mark.asyncio 43 | @pytest.mark.integration_tests 44 | @pytest.mark.parametrize( 45 | "fetch_type, allowed_overhead", 46 | [ 47 | (FetchType.ID, 10), 48 | (FetchType.OBJ, 800), 49 | ], 50 | ) 51 | async def test_benchmark( 52 | db_connection: DBConnection, request, fetch_type: FetchType, allowed_overhead: float 53 | ): 54 | num_users = 500_000 55 | num_loops = 100 56 | 57 | # Insert users using raw asyncpg 58 | await insert_users(db_connection.conn, num_users) 59 | 60 | # Benchmark raw asyncpg query 61 | start_time = monotonic_ns() 62 | raw_results: list[Any] = [] 63 | for _ in range(num_loops): 64 | raw_results = await fetch_users_raw(db_connection.conn, fetch_type) 65 | raw_time = monotonic_ns() - start_time 66 | raw_time_seconds = raw_time / 1e9 67 | raw_time_per_query = (raw_time / num_loops) / 1e9 68 | 69 | LOGGER.info( 70 | f"Raw asyncpg query time: {raw_time_per_query:.4f} (total: {raw_time_seconds:.4f}) seconds" 71 | ) 72 | CONSOLE.print( 73 | f"Raw asyncpg query time: {raw_time_per_query:.4f} (total: {raw_time_seconds:.4f}) seconds" 74 | ) 75 | 76 | # Benchmark DBConnection.exec query 77 | start_time = monotonic_ns() 78 | query = build_iceaxe_query(fetch_type) 79 | db_results: list[UserDemo] | list[int] = [] 80 | for _ in range(num_loops): 81 | db_results = await db_connection.exec(query) 82 | db_time = monotonic_ns() - start_time 83 | db_time_seconds = db_time / 1e9 84 | db_time_per_query = (db_time / num_loops) / 1e9 85 | 86 | LOGGER.info( 87 | f"DBConnection.exec query time: {db_time_per_query:.4f} (total: {db_time_seconds:.4f}) seconds" 88 | ) 89 | CONSOLE.print( 90 | f"DBConnection.exec query time: {db_time_per_query:.4f} (total: {db_time_seconds:.4f}) seconds" 91 | ) 92 | 93 | # Slower than the raw run since we need to run the performance instrumentation 94 | if False: 95 | with run_profile(request): 96 | # Right now we don't cache results so we can run multiple times to get a better measure of samples 97 | for _ in range(num_loops): 98 | query = build_iceaxe_query(fetch_type) 99 | db_results = await db_connection.exec(query) 100 | 101 | # Compare results 102 | assert len(raw_results) == len(db_results) == num_users, "Result count mismatch" 103 | 104 | # Calculate performance difference 105 | performance_diff = (db_time - raw_time) / raw_time * 100 106 | LOGGER.info(f"Performance difference: {performance_diff:.2f}%") 107 | CONSOLE.print(f"Performance difference: {performance_diff:.2f}%") 108 | 109 | # Assert that DBConnection.exec is at most X% slower than raw query 110 | assert performance_diff <= allowed_overhead, ( 111 | f"DBConnection.exec is {performance_diff:.2f}% slower than raw query, which exceeds the {allowed_overhead}% threshold" 112 | ) 113 | 114 | LOGGER.info("Benchmark completed successfully.") 115 | -------------------------------------------------------------------------------- /iceaxe/sql_types.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime, time, timedelta 2 | from enum import Enum, StrEnum 3 | from uuid import UUID 4 | 5 | 6 | class ColumnType(StrEnum): 7 | # The values of the enum are the actual SQL types. 
When constructing 8 | # the column they can be case-insensitive, but when we're casting from 9 | # the database to memory they must align with the on-disk representation 10 | # which is lowercase. 11 | # 12 | # Note: The SQL standard requires that writing just "timestamp" be equivalent 13 | # to "timestamp without time zone", and PostgreSQL honors that behavior. 14 | # Similarly, "time" is equivalent to "time without time zone". 15 | 16 | # Numeric Types 17 | SMALLINT = "smallint" 18 | INTEGER = "integer" 19 | BIGINT = "bigint" 20 | DECIMAL = "decimal" 21 | NUMERIC = "numeric" 22 | REAL = "real" 23 | DOUBLE_PRECISION = "double precision" 24 | SERIAL = "serial" 25 | BIGSERIAL = "bigserial" 26 | SMALLSERIAL = "smallserial" 27 | 28 | # Monetary Type 29 | MONEY = "money" 30 | 31 | # Character Types 32 | CHAR = "char" 33 | VARCHAR = "character varying" 34 | TEXT = "text" 35 | 36 | # Binary Data Types 37 | BYTEA = "bytea" 38 | 39 | # Date/Time Types 40 | DATE = "date" 41 | TIME_WITHOUT_TIME_ZONE = "time without time zone" 42 | TIME_WITH_TIME_ZONE = "time with time zone" 43 | TIMESTAMP_WITHOUT_TIME_ZONE = "timestamp without time zone" 44 | TIMESTAMP_WITH_TIME_ZONE = "timestamp with time zone" 45 | INTERVAL = "interval" 46 | 47 | # Boolean Type 48 | BOOLEAN = "boolean" 49 | 50 | # Geometric Types 51 | POINT = "point" 52 | LINE = "line" 53 | LSEG = "lseg" 54 | BOX = "box" 55 | PATH = "path" 56 | POLYGON = "polygon" 57 | CIRCLE = "circle" 58 | 59 | # Network Address Types 60 | CIDR = "cidr" 61 | INET = "inet" 62 | MACADDR = "macaddr" 63 | MACADDR8 = "macaddr8" 64 | 65 | # Bit String Types 66 | BIT = "bit" 67 | BIT_VARYING = "bit varying" 68 | 69 | # Text Search Types 70 | TSVECTOR = "tsvector" 71 | TSQUERY = "tsquery" 72 | 73 | # UUID Type 74 | UUID = "uuid" 75 | 76 | # XML Type 77 | XML = "xml" 78 | 79 | # JSON Types 80 | JSON = "json" 81 | JSONB = "jsonb" 82 | 83 | # Range Types 84 | INT4RANGE = "int4range" 85 | NUMRANGE = "numrange" 86 | TSRANGE = "tsrange" 87 | TSTZRANGE = "tstzrange" 88 | DATERANGE = "daterange" 89 | 90 | # Object Identifier Type 91 | OID = "oid" 92 | 93 | @classmethod 94 | def _missing_(cls, value: object): 95 | """ 96 | Handle SQL standard aliases when the exact enum value is not found. 97 | 98 | The SQL standard requires that "timestamp" be equivalent to "timestamp without time zone" 99 | and "time" be equivalent to "time without time zone". 100 | """ 101 | # Only handle string values for SQL type aliases 102 | if not isinstance(value, str): 103 | return None 104 | 105 | aliases = { 106 | "timestamp": "timestamp without time zone", 107 | "time": "time without time zone", 108 | } 109 | 110 | # Check if this is an alias we can resolve 111 | if value in aliases: 112 | # Return the actual enum member for the aliased value 113 | return cls(aliases[value]) 114 | 115 | # If not an alias, let the default enum behavior handle it 116 | return None 117 | 118 | 119 | class ConstraintType(StrEnum): 120 | PRIMARY_KEY = "PRIMARY KEY" 121 | FOREIGN_KEY = "FOREIGN KEY" 122 | UNIQUE = "UNIQUE" 123 | CHECK = "CHECK" 124 | INDEX = "INDEX" 125 | 126 | 127 | def get_python_to_sql_mapping(): 128 | """ 129 | Returns a mapping of Python types to their corresponding SQL types. 
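A quick illustration of the lookup, based on the mapping defined below:

```python
mapping = get_python_to_sql_mapping()
assert mapping[int] == ColumnType.INTEGER
assert mapping[str] == ColumnType.VARCHAR
```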
130 | """ 131 | return { 132 | int: ColumnType.INTEGER, 133 | float: ColumnType.DOUBLE_PRECISION, 134 | str: ColumnType.VARCHAR, 135 | bool: ColumnType.BOOLEAN, 136 | bytes: ColumnType.BYTEA, 137 | UUID: ColumnType.UUID, 138 | datetime: ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE, 139 | date: ColumnType.DATE, 140 | time: ColumnType.TIME_WITHOUT_TIME_ZONE, 141 | timedelta: ColumnType.INTERVAL, 142 | } 143 | 144 | 145 | def enum_to_name(enum: Enum) -> str: 146 | """ 147 | Returns the name of the enum as a string. 148 | """ 149 | return enum.__name__.lower() 150 | -------------------------------------------------------------------------------- /iceaxe/postgres.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | from typing import Literal 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class LexemePriority(StrEnum): 8 | """Enum representing text search lexeme priority weights in Postgres.""" 9 | 10 | HIGHEST = "A" 11 | HIGH = "B" 12 | LOW = "C" 13 | LOWEST = "D" 14 | 15 | 16 | class PostgresFieldBase(BaseModel): 17 | """ 18 | Extensions to python core types that specify addition arguments 19 | used by Postgres. 20 | 21 | """ 22 | 23 | pass 24 | 25 | 26 | class PostgresDateTime(PostgresFieldBase): 27 | """ 28 | Extension to Python's datetime type that specifies additional Postgres-specific configuration. 29 | Used to customize the timezone behavior of datetime fields in Postgres. 30 | 31 | ```python {{sticky: True}} 32 | from iceaxe import Field, TableBase 33 | class Event(TableBase): 34 | id: int = Field(primary_key=True) 35 | created_at: datetime = Field(postgres_config=PostgresDateTime(timezone=True)) 36 | ``` 37 | """ 38 | 39 | timezone: bool = False 40 | """ 41 | Whether the datetime field should include timezone information in Postgres. 42 | If True, maps to TIMESTAMP WITH TIME ZONE. 43 | If False, maps to TIMESTAMP WITHOUT TIME ZONE. 44 | Defaults to False. 45 | 46 | """ 47 | 48 | 49 | class PostgresTime(PostgresFieldBase): 50 | """ 51 | Extension to Python's time type that specifies additional Postgres-specific configuration. 52 | Used to customize the timezone behavior of time fields in Postgres. 53 | 54 | ```python {{sticky: True}} 55 | from iceaxe import Field, TableBase 56 | class Schedule(TableBase): 57 | id: int = Field(primary_key=True) 58 | start_time: time = Field(postgres_config=PostgresTime(timezone=True)) 59 | ``` 60 | """ 61 | 62 | timezone: bool = False 63 | """ 64 | Whether the time field should include timezone information in Postgres. 65 | If True, maps to TIME WITH TIME ZONE. 66 | If False, maps to TIME WITHOUT TIME ZONE. 67 | Defaults to False. 68 | 69 | """ 70 | 71 | 72 | class PostgresFullText(PostgresFieldBase): 73 | """ 74 | Extension to Python's string type that specifies additional Postgres-specific configuration 75 | for full-text search. Used to customize the behavior of text search fields in Postgres. 76 | 77 | ```python {{sticky: True}} 78 | from iceaxe import TableBase, Field 79 | from iceaxe.postgres import PostgresFullText, LexemePriority 80 | 81 | class Article(TableBase): 82 | id: int = Field(primary_key=True) 83 | title: str = Field(postgres_config=PostgresFullText( 84 | language="english", 85 | weight=LexemePriority.HIGHEST # or "A" 86 | )) 87 | content: str = Field(postgres_config=PostgresFullText( 88 | language="english", 89 | weight=LexemePriority.HIGH # or "B" 90 | )) 91 | ``` 92 | """ 93 | 94 | language: str = "english" 95 | """ 96 | The language to use for text search operations. 
97 | Defaults to 'english'. 98 | """ 99 | weight: Literal["A", "B", "C", "D"] | LexemePriority = LexemePriority.HIGHEST 100 | """ 101 | The weight to assign to matches in this column. 102 | Can be specified either as a string literal ("A", "B", "C", "D") or using LexemePriority enum. 103 | A/HIGHEST is highest priority, D/LOWEST is lowest priority. 104 | Defaults to LexemePriority.HIGHEST (A). 105 | """ 106 | 107 | 108 | ForeignKeyModifications = Literal[ 109 | "RESTRICT", "NO ACTION", "CASCADE", "SET DEFAULT", "SET NULL" 110 | ] 111 | 112 | 113 | class PostgresForeignKey(PostgresFieldBase): 114 | """ 115 | Extension to Python's ForeignKey type that specifies additional Postgres-specific configuration. 116 | Used to customize the behavior of foreign key constraints in Postgres. 117 | 118 | ```python {{sticky: True}} 119 | from iceaxe import TableBase, Field 120 | 121 | class Office(TableBase): 122 | id: int = Field(primary_key=True) 123 | name: str 124 | 125 | class Employee(TableBase): 126 | id: int = Field(primary_key=True) 127 | name: str 128 | office_id: int = Field(foreign_key="office.id", postgres_config=PostgresForeignKey(on_delete="CASCADE", on_update="CASCADE")) 129 | ``` 130 | """ 131 | 132 | on_delete: ForeignKeyModifications = "NO ACTION" 133 | on_update: ForeignKeyModifications = "NO ACTION" 134 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## python ## 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 
96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .env.local 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | target 165 | *.ipynb 166 | *.ipynb_checkpoints 167 | 168 | pyrightconfig.json 169 | 170 | .profiles 171 | 172 | *.c 173 | setup.py 174 | 175 | ### macOS ### 176 | # General 177 | .DS_Store 178 | .AppleDouble 179 | .LSOverride 180 | 181 | # Icon must end with two \r 182 | Icon 183 | 184 | 185 | # Thumbnails 186 | ._* 187 | 188 | # Files that might appear in the root of a volume 189 | .DocumentRevisions-V100 190 | .fseventsd 191 | .Spotlight-V100 192 | .TemporaryItems 193 | .Trashes 194 | .VolumeIcon.icns 195 | .com.apple.timemachine.donotpresent 196 | 197 | # Directories potentially created on remote AFP share 198 | .AppleDB 199 | .AppleDesktop 200 | Network Trash Folder 201 | Temporary Items 202 | .apdisk 203 | 204 | ### macOS Patch ### 205 | # iCloud generated files 206 | *.icloud 207 | 208 | # Build pipeline 209 | session_optimized.html 210 | -------------------------------------------------------------------------------- /iceaxe/__tests__/migrations/test_generator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import Any 4 | 5 | import pytest 6 | from pydantic import BaseModel 7 | 8 | from iceaxe.migrations.generator import MigrationGenerator 9 | from iceaxe.migrations.migration import MigrationRevisionBase 10 | from iceaxe.schemas.actions import ColumnType, DatabaseActions, DryRunAction 11 | from iceaxe.schemas.db_memory_serializer import DatabaseMemorySerializer 12 | from iceaxe.schemas.db_stubs import DBTable 13 | 14 | 15 | @pytest.mark.asyncio 16 | async def test_new_migration(): 17 | migration_generator = MigrationGenerator() 18 | 19 | code, up_revision = await migration_generator.new_migration( 20 | down_objects_with_dependencies=[ 21 | ( 22 | DBTable( 23 | table_name="test_table_a", 24 | ), 25 | [], 26 | ) 27 | ], 28 | up_objects_with_dependencies=[ 29 | ( 30 | DBTable( 31 | table_name="test_table_b", 32 | ), 33 | [], 34 | ) 35 | ], 36 | down_revision="test_down_revision", 37 | user_message="test_user_message", 38 | ) 39 | 40 | # Expected up 41 | assert 'await migrator.actor.add_table(table_name="test_table_b")' in code 42 | assert 'await migrator.actor.drop_table(table_name="test_table_a")' in code 43 | 44 | # Expected down 45 | assert 'await migrator.actor.drop_table(table_name="test_table_b")' in code 46 | assert 'await migrator.actor.add_table(table_name="test_table_a")' in code 47 | 48 | assert "Context: test_user_message" in code 49 | assert 'down_revision: str | None = "test_down_revision"' in code 50 | 51 | 52 | def test_actions_to_code(): 53 | actor = DatabaseActions() 54 | migration_generator = MigrationGenerator() 55 | 56 | code = migration_generator.actions_to_code( 57 | [ 58 | DryRunAction( 59 | fn=actor.add_column, 60 | kwargs={ 61 | "table_name": "test_table", 62 | "column_name": "test_column", 63 | "explicit_data_type": ColumnType.VARCHAR, 64 | }, 65 | ) 66 | ] 67 | ) 68 | assert code == [ 69 | 'await migrator.actor.add_column(table_name="test_table", column_name="test_column", explicit_data_type=ColumnType.VARCHAR)' 70 | ] 71 | 72 | 73 | def test_actions_to_code_pass(): 74 | """ 75 | We support generating migrations where there are no schema-level changes, so users can 76 | write their own data migration logic. In these cases we should pass the code-block 77 | so the resulting file is still legitimate. 
78 | 79 | """ 80 | migration_generator = MigrationGenerator() 81 | code = migration_generator.actions_to_code([]) 82 | assert code == ["pass"] 83 | 84 | 85 | class ExampleEnum(Enum): 86 | A = "a" 87 | B = "b" 88 | 89 | 90 | class ExampleModel(BaseModel): 91 | value: str 92 | 93 | 94 | @dataclass 95 | class ExampleDataclass: 96 | value: str 97 | 98 | 99 | @pytest.mark.parametrize( 100 | "value, expected_value", 101 | [ 102 | (ExampleEnum.A, "ExampleEnum.A"), 103 | ("example_arg", '"example_arg"'), 104 | (1, "1"), 105 | ( 106 | {"key": "value", "nested": {"key": "value2"}}, 107 | '{"key": "value", "nested": {"key": "value2"}}', 108 | ), 109 | ( 110 | { 111 | "key": ExampleModel(value="test"), 112 | "enum": ExampleEnum.B, 113 | }, 114 | '{"key": ExampleModel(value="test"), "enum": ExampleEnum.B}', 115 | ), 116 | (ExampleModel(value="test"), 'ExampleModel(value="test")'), 117 | (ExampleDataclass(value="test"), 'ExampleDataclass(value="test")'), 118 | (True, "True"), 119 | (False, "False"), 120 | (frozenset({"A", "B"}), 'frozenset({"A", "B"})'), 121 | ({"A", "B"}, '{"A", "B"}'), 122 | (("A",), '("A",)'), 123 | (("A", "B"), '("A", "B")'), 124 | ], 125 | ) 126 | def test_format_arg(value: Any, expected_value: str): 127 | migration_generator = MigrationGenerator() 128 | assert migration_generator.format_arg(value) == expected_value 129 | 130 | 131 | def test_track_import(): 132 | migration_generator = MigrationGenerator() 133 | 134 | migration_generator.track_import(DatabaseMemorySerializer) 135 | migration_generator.track_import(MigrationRevisionBase) 136 | 137 | assert dict(migration_generator.import_tracker) == { 138 | "iceaxe.migrations.migration": {"MigrationRevisionBase"}, 139 | "iceaxe.schemas.db_memory_serializer": {"DatabaseMemorySerializer"}, 140 | } 141 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Library Tests and Release 2 | on: 3 | push: 4 | branches: [main] 5 | tags: 6 | - "v*" 7 | pull_request: 8 | branches: [main] 9 | types: [opened, synchronize, reopened, labeled] 10 | jobs: 11 | test: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest] 17 | python-version: ["3.11", "3.12", "3.13"] 18 | services: 19 | postgres: 20 | image: postgres:latest 21 | env: 22 | POSTGRES_USER: iceaxe 23 | POSTGRES_PASSWORD: mysecretpassword 24 | POSTGRES_DB: iceaxe_test_db 25 | ports: 26 | - 5438:5432 27 | options: >- 28 | --health-cmd pg_isready 29 | --health-interval 10s 30 | --health-timeout 5s 31 | --health-retries 5 32 | steps: 33 | - uses: actions/checkout@v4 34 | - name: Install uv and with python version ${{ matrix.python-version }} 35 | uses: astral-sh/setup-uv@v6.3.0 36 | with: 37 | enable-cache: false 38 | python-version: ${{ matrix.python-version }} 39 | - name: Run tests 40 | run: uv run pytest -v --continue-on-collection-errors 41 | env: 42 | ICEAXE_LOG_LEVEL: DEBUG 43 | 44 | lint: 45 | runs-on: ubuntu-latest 46 | steps: 47 | - uses: actions/checkout@v4 48 | - name: Install uv with Python 3.13 49 | uses: astral-sh/setup-uv@v6.3.0 50 | with: 51 | enable-cache: false 52 | python-version: 3.13 53 | - name: Run lint 54 | run: make lint 55 | 56 | full-build: 57 | needs: [test, lint] 58 | if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'Full Build') || startsWith(github.ref, 'refs/tags/v') 59 | strategy: 60 | matrix: 61 | os: [ubuntu-latest, 
macos-latest] 62 | runs-on: ${{ matrix.os }} 63 | steps: 64 | - uses: actions/checkout@v4 65 | - name: Install uv with Python 3.13 66 | uses: astral-sh/setup-uv@v6.3.0 67 | with: 68 | enable-cache: false 69 | python-version: 3.13 70 | 71 | - name: Update version 72 | if: startsWith(github.ref, 'refs/tags/v') 73 | shell: bash 74 | run: uv run .github/scripts/update_version.py ${{ github.ref_name }} 75 | 76 | - name: Clean build artifacts 77 | run: | 78 | rm -rf build/ 79 | rm -rf iceaxe/*.so 80 | rm -rf *.egg-info/ 81 | find . -type f -name "*.so" -delete 82 | find . -type f -name "*.o" -delete 83 | find . -type f -name "*.c" -delete 84 | 85 | - name: Set up QEMU 86 | if: runner.os == 'Linux' 87 | uses: docker/setup-qemu-action@v3 88 | with: 89 | platforms: arm64 90 | 91 | - name: Build SDist 92 | if: matrix.os == 'ubuntu-latest' 93 | run: uv build --sdist 94 | 95 | - name: Build wheels 96 | uses: pypa/cibuildwheel@v3.0.0 97 | env: 98 | CIBW_ARCHS_LINUX: x86_64 aarch64 99 | CIBW_BUILD: cp311-* cp312-* cp313-* 100 | CIBW_SKIP: "*-musllinux*" 101 | CIBW_MANYLINUX_AARCH64_IMAGE: manylinux2014 102 | CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014 103 | CIBW_ENVIRONMENT: "CFLAGS='-std=c99'" 104 | CIBW_BEFORE_BUILD: | 105 | rm -rf {project}/build/ 106 | rm -rf {project}/*.egg-info/ 107 | find {project} -type f -name "*.so" -delete 108 | find {project} -type f -name "*.o" -delete 109 | find {project} -type f -name "*.c" -delete 110 | 111 | - name: Upload artifacts 112 | uses: actions/upload-artifact@v4 113 | with: 114 | name: dist-${{ matrix.os }} 115 | path: | 116 | dist/ 117 | wheelhouse/ 118 | 119 | release: 120 | needs: [full-build] 121 | if: startsWith(github.ref, 'refs/tags/v') 122 | runs-on: ubuntu-latest 123 | environment: 124 | name: pypi 125 | url: https://pypi.org/p/iceaxe 126 | permissions: 127 | id-token: write 128 | steps: 129 | - name: Download Ubuntu build 130 | uses: actions/download-artifact@v4 131 | with: 132 | name: dist-ubuntu-latest 133 | path: dist-raw/ubuntu 134 | 135 | - name: Download macOS build 136 | uses: actions/download-artifact@v4 137 | with: 138 | name: dist-macos-latest 139 | path: dist-raw/macos 140 | 141 | - name: Prepare distribution files 142 | run: | 143 | mkdir -p dist 144 | cp -r dist-raw/ubuntu/dist/* dist/ 2>/dev/null || true 145 | cp -r dist-raw/ubuntu/wheelhouse/* dist/ 2>/dev/null || true 146 | cp -r dist-raw/macos/dist/* dist/ 2>/dev/null || true 147 | cp -r dist-raw/macos/wheelhouse/* dist/ 2>/dev/null || true 148 | echo "Content of dist directory:" 149 | ls -la dist/ 150 | 151 | - name: Publish package distributions to PyPI 152 | uses: pypa/gh-action-pypi-publish@release/v1 153 | -------------------------------------------------------------------------------- /iceaxe/generics.py: -------------------------------------------------------------------------------- 1 | import types 2 | from inspect import isclass 3 | from typing import Any, Type, TypeGuard, TypeVar, Union, get_args, get_origin 4 | 5 | 6 | def mro_distance(obj_type: Type, target_type: Type) -> float: 7 | """ 8 | Calculate the MRO distance between obj_type and target_type. 9 | Returns a large number if no match is found. 
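For example, illustrative values produced by the MRO lookup below:

```python
mro_distance(bool, int)  # 1: int is one step up bool's MRO
mro_distance(int, int)   # 0: exact match
mro_distance(str, int)   # float("inf"): no relationship
```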
10 | 11 | """ 12 | if not isclass(obj_type): 13 | obj_type = type(obj_type) 14 | if not isclass(target_type): 15 | target_type = type(target_type) 16 | 17 | # Compare class types for exact match 18 | if obj_type == target_type: 19 | return 0 20 | 21 | # Check if obj_type is a subclass of target_type using the MRO 22 | try: 23 | return obj_type.mro().index(target_type) # type: ignore 24 | except ValueError: 25 | return float("inf") 26 | 27 | 28 | T = TypeVar("T") 29 | 30 | 31 | def is_type_compatible(obj_type: Type, target_type: T) -> TypeGuard[T]: 32 | return _is_type_compatible(obj_type, target_type) < float("inf") 33 | 34 | 35 | def _is_type_compatible(obj_type: Type, target_type: Any) -> float: 36 | """ 37 | Relatively comprehensive type compatibility checker. This function is 38 | used to check if a type has has a registered object that can 39 | handle it. 40 | 41 | Specifically returns the MRO distance where 0 indicates 42 | an exact match, 1 indicates a direct ancestor, and so on. Returns a large number 43 | if no compatibility is found. 44 | 45 | """ 46 | # Any type is compatible with any other type 47 | if target_type is Any: 48 | return 0 49 | 50 | # If obj_type is a nested type, each of these types must be compatible 51 | # with the corresponding type in target_type 52 | if get_origin(obj_type) is Union or isinstance(obj_type, types.UnionType): 53 | return max(_is_type_compatible(t, target_type) for t in get_args(obj_type)) 54 | 55 | # Handle OR types 56 | if get_origin(target_type) is Union or isinstance(target_type, types.UnionType): 57 | return min(_is_type_compatible(obj_type, t) for t in get_args(target_type)) 58 | 59 | # Handle Type[Values] like typehints where we want to typehint a class 60 | if get_origin(target_type) == type: # noqa: E721 61 | return _is_type_compatible(obj_type, get_args(target_type)[0]) 62 | 63 | # Handle dict[str, str] like typehints 64 | # We assume that each arg in order must be matched with the target type 65 | obj_origin = get_origin(obj_type) 66 | target_origin = get_origin(target_type) 67 | if obj_origin and target_origin: 68 | if obj_origin == target_origin: 69 | return max( 70 | _is_type_compatible(t1, t2) 71 | for t1, t2 in zip(get_args(obj_type), get_args(target_type)) 72 | ) 73 | else: 74 | return float("inf") 75 | 76 | # For lists, sets, and tuple objects make sure that each object matches 77 | # the target type 78 | if isinstance(obj_type, (list, set, tuple)): 79 | if type(obj_type) != get_origin(target_type): # noqa: E721 80 | return float("inf") 81 | return max( 82 | _is_type_compatible(obj, get_args(target_type)[0]) for obj in obj_type 83 | ) 84 | 85 | if isinstance(target_type, type): 86 | return mro_distance(obj_type, target_type) 87 | 88 | # Default case 89 | return float("inf") 90 | 91 | 92 | def remove_null_type(typehint: Type) -> Type: 93 | if get_origin(typehint) is Union or isinstance(typehint, types.UnionType): 94 | return Union[ # type: ignore 95 | tuple( # type: ignore 96 | [t for t in get_args(typehint) if t != type(None)] # noqa: E721 97 | ) 98 | ] 99 | return typehint 100 | 101 | 102 | def has_null_type(typehint: Type) -> bool: 103 | if get_origin(typehint) is Union or isinstance(typehint, types.UnionType): 104 | return any(arg == type(None) for arg in get_args(typehint)) # noqa: E721 105 | return typehint == type(None) # noqa: E721 106 | 107 | 108 | def get_typevar_mapping(cls): 109 | """ 110 | Get the raw typevar mappings {typvar: generic values} for each 111 | typevar in the class hierarchy of `cls`. 
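For example, a rough sketch of the resolution (the class names here are illustrative):

```python
from typing import Generic, TypeVar

T = TypeVar("T")

class Container(Generic[T]):
    pass

class IntContainer(Container[int]):
    pass

get_typevar_mapping(IntContainer)  # {T: int}
```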
112 | 113 | Shared logic with Mountaineer. TODO: Move to a shared package. 114 | 115 | """ 116 | mapping: dict[Any, Any] = {} 117 | 118 | # Traverse MRO in reverse order, except `object` 119 | for base in reversed(cls.__mro__[:-1]): 120 | # Skip non-generic classes 121 | if not hasattr(base, "__orig_bases__"): 122 | continue 123 | 124 | for origin_base in base.__orig_bases__: 125 | origin = get_origin(origin_base) 126 | if origin: 127 | base_params = getattr(origin, "__parameters__", []) 128 | instantiated_params = get_args(origin_base) 129 | 130 | # Update mapping with current base's mappings 131 | base_mapping = dict(zip(base_params, instantiated_params)) 132 | for key, value in base_mapping.items(): 133 | # If value is another TypeVar, resolve it if possible 134 | if isinstance(value, TypeVar) and value in mapping: 135 | mapping[key] = mapping[value] 136 | else: 137 | mapping[key] = value 138 | 139 | # Exclude TypeVars from the final mapping 140 | return mapping 141 | -------------------------------------------------------------------------------- /docs/guides/quickstart/page.mdx: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | ## Installation 4 | 5 | If you're using `uv` to manage your dependencies: 6 | 7 | ```bash 8 | uv add iceaxe 9 | ``` 10 | 11 | Otherwise install with pip: 12 | 13 | ```bash 14 | pip install iceaxe 15 | ``` 16 | 17 | ## Usage 18 | 19 | Define your models as a `TableBase` subclass: 20 | 21 | ```python 22 | from iceaxe import TableBase 23 | 24 | class Person(TableBase): 25 | id: int 26 | name: str 27 | age: int 28 | ``` 29 | 30 | TableBase is a subclass of Pydantic's `BaseModel`, so you get all of the validation and Field customization 31 | out of the box. We provide our own `Field` constructor that adds database-specific configuration. For instance, to make the 32 | `id` field a primary key / auto-incrementing you can do: 33 | 34 | ```python 35 | from iceaxe import Field 36 | 37 | class Person(TableBase): 38 | id: int = Field(primary_key=True) 39 | name: str 40 | age: int 41 | ``` 42 | 43 | Okay now you have a model. How do you interact with it? 44 | 45 | Databases are based on a few core primitives to insert data, update it, and fetch it out again. 46 | To do so you'll need a _database connection_, which is a connection over the network from your code 47 | to your Postgres database. The `DBConnection` is the core class for all ORM actions against the database. 48 | 49 | ```python 50 | from iceaxe import DBConnection 51 | import asyncpg 52 | 53 | conn = DBConnection( 54 | await asyncpg.connect( 55 | host="localhost", 56 | port=5432, 57 | user="db_user", 58 | password="yoursecretpassword", 59 | database="your_db", 60 | ) 61 | ) 62 | ``` 63 | 64 | The Person class currently just lives in memory. 
To back it with a full 65 | database table, we can run raw SQL or run a [migration](/iceaxe/guides/migrations) to 66 | add it automatically: 67 | 68 | ```python 69 | await conn.conn.execute( 70 | """ 71 | CREATE TABLE IF NOT EXISTS person ( 72 | id SERIAL PRIMARY KEY, 73 | name TEXT NOT NULL, 74 | age INT NOT NULL 75 | ) 76 | """ 77 | ) 78 | ``` 79 | 80 | ### Inserting Data 81 | 82 | Instantiate object classes as you normally do: 83 | 84 | ```python 85 | people = [ 86 | Person(name="Alice", age=30), 87 | Person(name="Bob", age=40), 88 | Person(name="Charlie", age=50), 89 | ] 90 | await conn.insert(people) 91 | 92 | print(people[0].id) # 1 93 | print(people[1].id) # 2 94 | ``` 95 | 96 | Because we're using an auto-incrementing primary key, the `id` field will be populated after the insert. 97 | Iceaxe will automatically update the object in place with the newly assigned value. 98 | 99 | ### Updating data 100 | 101 | Now that we have these lovely people, let's modify them. 102 | 103 | ```python 104 | person = people[0] 105 | person.name = "Blice" 106 | ``` 107 | 108 | Right now, we have a Python object that's out of state with the database. But that's often okay. We can inspect it 109 | and further write logic - it's fully decoupled from the database. 110 | 111 | ```python 112 | def ensure_b_letter(person: Person): 113 | if person.name[0].lower() != "b": 114 | raise ValueError("Name must start with 'B'") 115 | 116 | ensure_b_letter(person) 117 | ``` 118 | 119 | To sync the values back to the database, we can call `update`: 120 | 121 | ```python 122 | await conn.update([person]) 123 | ``` 124 | 125 | If we were to query the database directly, we see that the name has been updated: 126 | 127 | ``` 128 | id | name | age 129 | ----+-------+----- 130 | 1 | Blice | 31 131 | 2 | Bob | 40 132 | 3 | Charlie | 50 133 | ``` 134 | 135 | But no other fields have been touched. This lets a potentially concurrent process 136 | modify `Alice`'s record - say, updating the age to 31. By the time we update the data, we'll 137 | change the name but nothing else. Under the hood we do this by tracking the fields that 138 | have been modified in-memory and creating a targeted UPDATE to modify only those values. 139 | 140 | ### Selecting data 141 | 142 | To select data, we can use a `QueryBuilder`. For a shortcut to `select` query functions, 143 | you can also just import select directly. This method takes the desired value parameters 144 | and returns a list of the desired objects. 145 | 146 | ```python 147 | from iceaxe import select 148 | 149 | query = select(Person).where(Person.name == "Blice", Person.age > 25) 150 | results = await conn.exec(query) 151 | ``` 152 | 153 | If we inspect the typing of `results`, we see that it's a `list[Person]` objects. This matches 154 | the typehint of the `select` function. You can also target columns directly: 155 | 156 | ```python 157 | query = select((Person.id, Person.name)).where(Person.age > 25) 158 | results = await conn.exec(query) 159 | ``` 160 | 161 | This will return a list of tuples, where each tuple is the id and name of the person: `list[tuple[int, str]]`. 162 | 163 | We support most of the common SQL operations. Just like the results, these are typehinted 164 | to their proper types as well. Static typecheckers and your IDE will throw an error if you try to compare 165 | a string column to an integer, for instance. 
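For example (the exact diagnostic depends on your type checker):

```python
# OK: integer column compared against an int
select(Person).where(Person.age > 25)

# Flagged by mypy/pyright: string column compared against an int
select(Person).where(Person.name > 25)
```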
A more complex example of a query: 166 | 167 | ```python 168 | query = select(( 169 | Person.id, 170 | FavoriteColor, 171 | )).join( 172 | FavoriteColor, 173 | Person.id == FavoriteColor.person_id, 174 | ).where( 175 | Person.age > 25, 176 | Person.name == "Blice", 177 | ).order_by( 178 | Person.age.desc(), 179 | ).limit(10) 180 | results = await conn.exec(query) 181 | ``` 182 | 183 | As expected this will deliver results - and typehint - as a `list[tuple[int, FavoriteColor]]` 184 | -------------------------------------------------------------------------------- /iceaxe/__tests__/test_modifications.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | from iceaxe.__tests__.conf_models import ArtifactDemo, UserDemo 6 | from iceaxe.modifications import ( 7 | MODIFICATION_TRACKER_VERBOSITY, 8 | Modification, 9 | ModificationTracker, 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def tracker(): 15 | """Create a fresh ModificationTracker for each test.""" 16 | return ModificationTracker(known_first_party=["test_modifications"]) 17 | 18 | 19 | @pytest.fixture 20 | def demo_instance(): 21 | """Create a demo model instance for testing.""" 22 | return UserDemo(id=1, name="test", email="test@example.com") 23 | 24 | 25 | def test_get_current_stack_trace(): 26 | """Test that get_current_stack_trace returns both traces.""" 27 | full_trace, user_trace = Modification.get_current_stack_trace() 28 | 29 | assert isinstance(full_trace, str) 30 | assert isinstance(user_trace, str) 31 | assert "test_modifications.py" in user_trace 32 | assert len(full_trace) >= len(user_trace) 33 | 34 | 35 | def test_track_modification_new_instance( 36 | tracker: ModificationTracker, demo_instance: UserDemo 37 | ): 38 | """Test tracking a new modification.""" 39 | tracker.track_modification(demo_instance) 40 | 41 | instance_id = id(demo_instance) 42 | assert instance_id in tracker.modified_models 43 | 44 | modification = tracker.modified_models[instance_id] 45 | assert modification.instance == demo_instance 46 | assert "test_modifications.py" in modification.user_stack_trace 47 | 48 | 49 | def test_track_modification_duplicate( 50 | tracker: ModificationTracker, demo_instance: UserDemo 51 | ): 52 | """Test that tracking the same instance twice only records it once.""" 53 | tracker.track_modification(demo_instance) 54 | tracker.track_modification(demo_instance) 55 | 56 | assert len(tracker.modified_models) == 1 57 | 58 | 59 | def test_clear_status_single(tracker: ModificationTracker, demo_instance: UserDemo): 60 | """Test committing a single model.""" 61 | tracker.track_modification(demo_instance) 62 | tracker.clear_status([demo_instance]) 63 | 64 | assert id(demo_instance) not in tracker.modified_models 65 | 66 | 67 | def test_clear_status_partial(tracker: ModificationTracker): 68 | """Test committing some but not all models.""" 69 | instance1 = UserDemo(id=1, name="test1", email="test1@example.com") 70 | instance2 = UserDemo(id=2, name="test2", email="test2@example.com") 71 | 72 | tracker.track_modification(instance1) 73 | tracker.track_modification(instance2) 74 | tracker.clear_status([instance1]) 75 | 76 | assert id(instance1) not in tracker.modified_models 77 | assert id(instance2) in tracker.modified_models 78 | assert tracker.modified_models[id(instance2)].instance == instance2 79 | 80 | 81 | @pytest.mark.parametrize("verbosity", ["ERROR", "WARNING", "INFO", None]) 82 | def test_log_with_different_verbosity( 83 | tracker: ModificationTracker, 84 
| demo_instance: UserDemo, 85 | verbosity: MODIFICATION_TRACKER_VERBOSITY, 86 | caplog, 87 | ): 88 | """Test logging with different verbosity levels.""" 89 | tracker.verbosity = verbosity 90 | tracker.track_modification(demo_instance) 91 | 92 | with caplog.at_level(logging.INFO): 93 | tracker.log() 94 | 95 | if verbosity: 96 | assert len(caplog.records) > 0 97 | assert "Object modified locally but not committed" in caplog.records[0].message 98 | if verbosity == "INFO": 99 | assert any( 100 | "Full stack trace" in record.message for record in caplog.records 101 | ) 102 | else: 103 | assert len(caplog.records) == 0 104 | 105 | 106 | def test_multiple_model_types(tracker: ModificationTracker): 107 | """Test tracking modifications for different model types.""" 108 | instance1 = UserDemo(id=1, name="test", email="test@example.com") 109 | instance2 = ArtifactDemo(id=2, title="test", user_id=1) 110 | 111 | tracker.track_modification(instance1) 112 | tracker.track_modification(instance2) 113 | 114 | assert len(tracker.modified_models) == 2 115 | assert id(instance1) in tracker.modified_models 116 | assert id(instance2) in tracker.modified_models 117 | 118 | 119 | def test_clear_status_cleanup(tracker: ModificationTracker): 120 | """Test that clear_status properly cleans up empty model lists.""" 121 | instance = UserDemo(id=1, name="test", email="test@example.com") 122 | tracker.track_modification(instance) 123 | 124 | assert id(instance) in tracker.modified_models 125 | tracker.clear_status([instance]) 126 | assert id(instance) not in tracker.modified_models 127 | 128 | 129 | def test_callback_registration(tracker: ModificationTracker): 130 | """ 131 | Test that registering the tracker as a callback on a model instance 132 | properly tracks modifications when the model is changed. 
133 | """ 134 | instance = UserDemo(id=1, name="test", email="test@example.com") 135 | instance.register_modified_callback(tracker.track_modification) 136 | 137 | # Initially no modifications 138 | assert id(instance) not in tracker.modified_models 139 | 140 | # Modify the instance 141 | instance.name = "new name" 142 | 143 | # Should have tracked the modification 144 | assert id(instance) in tracker.modified_models 145 | modification = tracker.modified_models[id(instance)] 146 | assert modification.instance == instance 147 | assert "test_modifications.py" in modification.user_stack_trace 148 | 149 | # Another modification shouldn't create a new entry 150 | instance.email = "new@example.com" 151 | assert len(tracker.modified_models) == 1 152 | -------------------------------------------------------------------------------- /iceaxe/__tests__/test_queries_str.py: -------------------------------------------------------------------------------- 1 | from iceaxe.base import TableBase 2 | from iceaxe.queries_str import QueryIdentifier, QueryLiteral, sql 3 | 4 | 5 | class DemoModel(TableBase): 6 | field_one: str 7 | field_two: int 8 | 9 | 10 | def test_query_identifier(): 11 | """Test the QueryIdentifier class for proper SQL identifier quoting.""" 12 | identifier = QueryIdentifier("test_field") 13 | assert str(identifier) == '"test_field"' 14 | 15 | # Test with special characters and SQL keywords 16 | identifier = QueryIdentifier("group") 17 | assert str(identifier) == '"group"' 18 | 19 | identifier = QueryIdentifier("test.field") 20 | assert str(identifier) == '"test.field"' 21 | 22 | 23 | def test_query_literal(): 24 | """Test the QueryLiteral class for raw SQL inclusion.""" 25 | literal = QueryLiteral("COUNT(*)") 26 | assert str(literal) == "COUNT(*)" 27 | 28 | literal = QueryLiteral("CASE WHEN x > 0 THEN 1 ELSE 0 END") 29 | assert str(literal) == "CASE WHEN x > 0 THEN 1 ELSE 0 END" 30 | 31 | 32 | def test_query_element_hashable(): 33 | """Test that QueryElementBase subclasses are hashable and work in sets/dicts.""" 34 | # Test set behavior 35 | elements = { 36 | QueryIdentifier("users"), 37 | QueryIdentifier("users"), # Duplicate should be removed 38 | QueryIdentifier("posts"), 39 | QueryLiteral("COUNT(*)"), 40 | QueryLiteral("COUNT(*)"), # Duplicate should be removed 41 | } 42 | assert len(elements) == 3 43 | assert QueryIdentifier("users") in elements 44 | assert QueryLiteral("COUNT(*)") in elements 45 | 46 | # Test dictionary behavior 47 | element_map = { 48 | QueryIdentifier("users"): "users table", 49 | QueryLiteral("COUNT(*)"): "count function", 50 | } 51 | assert element_map[QueryIdentifier("users")] == "users table" 52 | assert element_map[QueryLiteral("COUNT(*)")] == "count function" 53 | 54 | # Test hash consistency with equality 55 | id1 = QueryIdentifier("test") 56 | id2 = QueryIdentifier("test") 57 | assert id1 == id2 58 | assert hash(id1) == hash(id2) 59 | 60 | lit1 = QueryLiteral("COUNT(*)") 61 | lit2 = QueryLiteral("COUNT(*)") 62 | assert lit1 == lit2 63 | assert hash(lit1) == hash(lit2) 64 | 65 | 66 | def test_query_element_sortable(): 67 | """Test that QueryElementBase subclasses can be sorted.""" 68 | # Test sorting of identifiers 69 | identifiers = [ 70 | QueryIdentifier("users"), 71 | QueryIdentifier("posts"), 72 | QueryIdentifier("comments"), 73 | ] 74 | sorted_identifiers = sorted(identifiers) 75 | assert [str(literal) for literal in sorted_identifiers] == [ 76 | '"comments"', 77 | '"posts"', 78 | '"users"', 79 | ] 80 | 81 | # Test sorting of literals 82 | literals = [ 
83 | QueryLiteral("SUM(*)"), 84 | QueryLiteral("COUNT(*)"), 85 | QueryLiteral("AVG(*)"), 86 | ] 87 | sorted_literals = sorted(literals) 88 | assert [str(literal) for literal in sorted_literals] == [ 89 | "AVG(*)", 90 | "COUNT(*)", 91 | "SUM(*)", 92 | ] 93 | 94 | # Test sorting mixed elements 95 | mixed = [ 96 | QueryIdentifier("users"), 97 | QueryLiteral("COUNT(*)"), 98 | QueryIdentifier("posts"), 99 | ] 100 | sorted_mixed = sorted(mixed) 101 | assert [str(literal) for literal in sorted_mixed] == [ 102 | '"posts"', 103 | '"users"', 104 | "COUNT(*)", 105 | ] 106 | 107 | # Test sorting with duplicates 108 | with_duplicates = [ 109 | QueryIdentifier("users"), 110 | QueryIdentifier("posts"), 111 | QueryIdentifier("users"), 112 | QueryIdentifier("comments"), 113 | ] 114 | sorted_with_duplicates = sorted(with_duplicates) 115 | assert [str(literal) for literal in sorted_with_duplicates] == [ 116 | '"comments"', 117 | '"posts"', 118 | '"users"', 119 | '"users"', 120 | ] 121 | 122 | # Test reverse sorting 123 | reverse_sorted = sorted(identifiers, reverse=True) 124 | assert [str(literal) for literal in reverse_sorted] == [ 125 | '"users"', 126 | '"posts"', 127 | '"comments"', 128 | ] 129 | 130 | 131 | def test_sql_call_column(): 132 | """Test SQLGenerator's __call__ method with a column.""" 133 | result = sql(DemoModel.field_one) 134 | assert isinstance(result, QueryLiteral) 135 | assert str(result) == '"demomodel"."field_one"' 136 | 137 | 138 | def test_sql_call_table(): 139 | """Test SQLGenerator's __call__ method with a table.""" 140 | result = sql(DemoModel) 141 | assert isinstance(result, QueryIdentifier) 142 | assert str(result) == '"demomodel"' 143 | 144 | 145 | def test_sql_select_column(): 146 | """Test SQLGenerator's select method with a column.""" 147 | result = sql.select(DemoModel.field_one) 148 | assert isinstance(result, QueryLiteral) 149 | assert str(result) == '"demomodel"."field_one" AS "demomodel_field_one"' 150 | 151 | 152 | def test_sql_select_table(): 153 | """Test SQLGenerator's select method with a table.""" 154 | result = sql.select(DemoModel) 155 | assert isinstance(result, QueryLiteral) 156 | assert str(result) == ( 157 | '"demomodel"."field_one" AS "demomodel_field_one", ' 158 | '"demomodel"."field_two" AS "demomodel_field_two"' 159 | ) 160 | 161 | 162 | def test_sql_raw_column(): 163 | """Test SQLGenerator's raw method with a column.""" 164 | result = sql.raw(DemoModel.field_one) 165 | assert isinstance(result, QueryIdentifier) 166 | assert str(result) == '"field_one"' 167 | 168 | 169 | def test_sql_raw_table(): 170 | """Test SQLGenerator's raw method with a table.""" 171 | result = sql.raw(DemoModel) 172 | assert isinstance(result, QueryIdentifier) 173 | assert str(result) == '"demomodel"' 174 | -------------------------------------------------------------------------------- /iceaxe/__tests__/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import asyncpg 4 | import pytest 5 | import pytest_asyncio 6 | 7 | from iceaxe.__tests__ import docker_helpers 8 | from iceaxe.base import DBModelMetaclass 9 | from iceaxe.session import DBConnection 10 | 11 | # Configure logging 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def docker_postgres(): 18 | """ 19 | Fixture that creates a PostgreSQL container using the Python Docker API. 20 | This allows running individual tests without needing Docker Compose. 
21 | """ 22 | # Create and start a PostgreSQL container 23 | postgres_container = docker_helpers.PostgresContainer() 24 | 25 | # Start the container and yield connection details 26 | connection_info = postgres_container.start() 27 | yield connection_info 28 | 29 | # Cleanup: stop the container 30 | postgres_container.stop() 31 | 32 | 33 | @pytest_asyncio.fixture 34 | async def db_connection(docker_postgres): 35 | """ 36 | Create a database connection using the PostgreSQL container. 37 | """ 38 | conn = DBConnection( 39 | await asyncpg.connect( 40 | host=docker_postgres["host"], 41 | port=docker_postgres["port"], 42 | user=docker_postgres["user"], 43 | password=docker_postgres["password"], 44 | database=docker_postgres["database"], 45 | ) 46 | ) 47 | 48 | # Drop all tables first to ensure clean state 49 | known_tables = [ 50 | "artifactdemo", 51 | "userdemo", 52 | "complexdemo", 53 | "article", 54 | "employee", 55 | "department", 56 | "projectassignment", 57 | "employeemetadata", 58 | "functiondemomodel", 59 | "demomodela", 60 | "demomodelb", 61 | "jsondemo", 62 | "complextypedemo", 63 | ] 64 | known_types = ["statusenum", "employeestatus"] 65 | 66 | for table in known_tables: 67 | await conn.conn.execute(f"DROP TABLE IF EXISTS {table} CASCADE", timeout=30.0) 68 | 69 | for known_type in known_types: 70 | await conn.conn.execute( 71 | f"DROP TYPE IF EXISTS {known_type} CASCADE", timeout=30.0 72 | ) 73 | 74 | # Create tables 75 | await conn.conn.execute( 76 | """ 77 | CREATE TABLE IF NOT EXISTS userdemo ( 78 | id SERIAL PRIMARY KEY, 79 | name TEXT, 80 | email TEXT 81 | ) 82 | """, 83 | timeout=30.0, 84 | ) 85 | 86 | await conn.conn.execute( 87 | """ 88 | CREATE TABLE IF NOT EXISTS artifactdemo ( 89 | id SERIAL PRIMARY KEY, 90 | title TEXT, 91 | user_id INT REFERENCES userdemo(id) 92 | ) 93 | """, 94 | timeout=30.0, 95 | ) 96 | 97 | await conn.conn.execute( 98 | """ 99 | CREATE TABLE IF NOT EXISTS complexdemo ( 100 | id SERIAL PRIMARY KEY, 101 | string_list TEXT[], 102 | json_data JSON 103 | ) 104 | """, 105 | timeout=30.0, 106 | ) 107 | 108 | await conn.conn.execute( 109 | """ 110 | CREATE TABLE IF NOT EXISTS article ( 111 | id SERIAL PRIMARY KEY, 112 | title TEXT, 113 | content TEXT, 114 | summary TEXT 115 | ) 116 | """, 117 | timeout=30.0, 118 | ) 119 | 120 | # Create each index separately to handle errors better 121 | yield conn 122 | 123 | # Drop all tables after tests 124 | for table in known_tables: 125 | await conn.conn.execute(f"DROP TABLE IF EXISTS {table} CASCADE", timeout=30.0) 126 | 127 | # Drop all types after tests 128 | for known_type in known_types: 129 | await conn.conn.execute( 130 | f"DROP TYPE IF EXISTS {known_type} CASCADE", timeout=30.0 131 | ) 132 | 133 | await conn.conn.close() 134 | 135 | 136 | @pytest_asyncio.fixture() 137 | async def indexed_db_connection(db_connection: DBConnection): 138 | await db_connection.conn.execute( 139 | "CREATE INDEX IF NOT EXISTS article_title_tsv_idx ON article USING GIN (to_tsvector('english', title))", 140 | timeout=30.0, 141 | ) 142 | await db_connection.conn.execute( 143 | "CREATE INDEX IF NOT EXISTS article_content_tsv_idx ON article USING GIN (to_tsvector('english', content))", 144 | timeout=30.0, 145 | ) 146 | await db_connection.conn.execute( 147 | "CREATE INDEX IF NOT EXISTS article_summary_tsv_idx ON article USING GIN (to_tsvector('english', summary))", 148 | timeout=30.0, 149 | ) 150 | 151 | yield db_connection 152 | 153 | 154 | @pytest_asyncio.fixture(autouse=True) 155 | async def clear_table(db_connection): 156 | # Clear all 
tables and reset sequences 157 | await db_connection.conn.execute( 158 | "TRUNCATE TABLE userdemo, article RESTART IDENTITY CASCADE", timeout=30.0 159 | ) 160 | 161 | 162 | @pytest_asyncio.fixture 163 | async def clear_all_database_objects(db_connection: DBConnection): 164 | """ 165 | Clear all database objects. 166 | """ 167 | # Step 1: Drop all tables in the public schema 168 | await db_connection.conn.execute( 169 | """ 170 | DO $$ DECLARE 171 | r RECORD; 172 | BEGIN 173 | FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP 174 | EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; 175 | END LOOP; 176 | END $$; 177 | """, 178 | timeout=30.0, 179 | ) 180 | 181 | # Step 2: Drop all custom types in the public schema 182 | await db_connection.conn.execute( 183 | """ 184 | DO $$ DECLARE 185 | r RECORD; 186 | BEGIN 187 | FOR r IN (SELECT typname FROM pg_type WHERE typtype = 'e' AND typnamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'public')) LOOP 188 | EXECUTE 'DROP TYPE IF EXISTS ' || quote_ident(r.typname) || ' CASCADE'; 189 | END LOOP; 190 | END $$; 191 | """, 192 | timeout=30.0, 193 | ) 194 | 195 | 196 | @pytest.fixture 197 | def clear_registry(): 198 | current_registry = DBModelMetaclass._registry 199 | DBModelMetaclass._registry = [] 200 | 201 | try: 202 | yield 203 | finally: 204 | DBModelMetaclass._registry = current_registry 205 | -------------------------------------------------------------------------------- /iceaxe/modifications.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | from dataclasses import dataclass 4 | from typing import Literal, Sequence, TypeVar 5 | 6 | from iceaxe.base import TableBase 7 | from iceaxe.logging import LOGGER 8 | 9 | MODIFICATION_TRACKER_VERBOSITY = Literal["ERROR", "WARNING", "INFO"] | None 10 | T = TypeVar("T", bound=TableBase) 11 | 12 | 13 | @dataclass 14 | class Modification: 15 | """ 16 | Tracks a single modification to a database model instance, including stack trace information. 17 | 18 | This class stores both the full stack trace and a simplified user-specific stack trace 19 | that excludes library code. This helps with debugging by showing where in the user's 20 | code a modification was made. 21 | 22 | :param instance: The model instance that was modified 23 | :param stack_trace: The complete stack trace at the time of modification 24 | :param user_stack_trace: The most relevant user code stack trace, excluding library code 25 | """ 26 | 27 | instance: TableBase 28 | 29 | # The full stack trace of the modification. 30 | stack_trace: str 31 | 32 | # Most specific line of the stack trace that is part of the user's code. 33 | user_stack_trace: str 34 | 35 | @staticmethod 36 | def get_current_stack_trace( 37 | package_allow_list: list[str] | None = None, 38 | package_deny_list: list[str] | None = None, 39 | ) -> tuple[str, str]: 40 | """ 41 | Get both the full stack trace and the most specific user code stack trace. 42 | 43 | The user stack trace filters out library code and frozen code to focus on 44 | the most relevant user code location where a modification occurred. 
45 | 
46 |         :return: A tuple containing (full_stack_trace, user_stack_trace)
47 |         :rtype: tuple[str, str]
48 |         """
49 |         stack = traceback.extract_stack()[:-1] # Remove the current frame
50 |         full_trace = "".join(traceback.format_list(stack))
51 | 
52 |         # Find the most specific user code stack trace by filtering out library code
53 |         user_traces = [
54 |             frame
55 |             for frame in stack
56 |             if (
57 |                 package_allow_list is None
58 |                 or any(pkg in frame.filename for pkg in package_allow_list)
59 |             )
60 |             and (
61 |                 package_deny_list is None
62 |                 or not any(pkg in frame.filename for pkg in package_deny_list)
63 |             )
64 |         ]
65 | 
66 |         user_trace = ""
67 |         if user_traces:
68 |             user_trace = "".join(traceback.format_list([user_traces[-1]]))
69 | 
70 |         return full_trace, user_trace
71 | 
72 | 
73 | class ModificationTracker:
74 |     """
75 |     Tracks modifications to database model instances and manages their lifecycle.
76 | 
77 |     This class maintains a record of all modified model instances that haven't been
78 |     committed yet. It provides functionality to track new modifications, handle commits,
79 |     and log any remaining uncommitted modifications.
80 | 
81 |     The tracker organizes modifications by object identity and prevents duplicate tracking
82 |     of the same instance. It also captures stack traces at the point of modification
83 |     to help with debugging.
84 |     """
85 | 
86 |     modified_models: dict[int, Modification]
87 |     """
88 |     Dictionary mapping instance ids (via `id()`) to their pending modifications
89 |     """
90 | 
91 |     verbosity: MODIFICATION_TRACKER_VERBOSITY | None
92 |     """
93 |     The logging level to use when reporting uncommitted modifications
94 |     """
95 | 
96 |     def __init__(
97 |         self,
98 |         verbosity: MODIFICATION_TRACKER_VERBOSITY | None = None,
99 |         known_first_party: list[str] | None = None,
100 |     ):
101 |         """
102 |         Initialize a new ModificationTracker.
103 | 
104 |         Creates an empty modification tracking dictionary and stores the provided
105 |         verbosity level and list of known first-party packages.
106 |         """
107 |         self.modified_models = {}
108 |         self.verbosity = verbosity
109 |         self.known_first_party = known_first_party
110 | 
111 |     def track_modification(self, instance: TableBase) -> None:
112 |         """
113 |         Track a modification to a model instance along with its stack trace.
114 | 
115 |         This method records a modification to a model instance if it hasn't already
116 |         been tracked. It captures both the full stack trace and a user-specific
117 |         stack trace at the point of modification.
118 | 
119 |         :param instance: The model instance that was modified
120 |         :type instance: TableBase
121 |         """
122 |         # Get stack traces. By default we filter out all iceaxe code, but allow callers to override
123 |         # this behavior so the logic can still be exercised under test.
124 |         full_trace, user_trace = Modification.get_current_stack_trace(
125 |             package_allow_list=self.known_first_party,
126 |             package_deny_list=["iceaxe"] if not self.known_first_party else None,
127 |         )
128 | 
129 |         # Only track if we haven't already tracked this instance
130 |         instance_id = id(instance)
131 |         if instance_id not in self.modified_models:
132 |             modification = Modification(
133 |                 instance=instance, stack_trace=full_trace, user_stack_trace=user_trace
134 |             )
135 |             self.modified_models[instance_id] = modification
136 | 
137 |     def clear_status(self, models: Sequence[TableBase]) -> None:
138 |         """
139 |         Remove models that are about to be committed from tracking.
140 | 
141 |         This method should be called before committing changes to the database.
142 | It removes the specified models from tracking since they will no longer 143 | be in an uncommitted state. 144 | 145 | :param models: List of model instances that will be committed 146 | :type models: list[TableBase] 147 | """ 148 | for instance in models: 149 | instance_id = id(instance) 150 | if instance_id in self.modified_models: 151 | del self.modified_models[instance_id] 152 | 153 | def log(self) -> None: 154 | """ 155 | Log all uncommitted modifications with their stack traces. 156 | 157 | This method logs information about all tracked modifications that haven't 158 | been committed yet. The logging level is determined by the tracker's 159 | verbosity setting. At the INFO level, it includes the full stack trace 160 | in addition to the user stack trace. 161 | 162 | If verbosity is not set (None), this method does nothing. 163 | """ 164 | if not self.verbosity: 165 | return 166 | 167 | log_level = getattr(logging, self.verbosity) 168 | LOGGER.setLevel(log_level) # Ensure logger will capture messages at this level 169 | 170 | for mod in self.modified_models.values(): 171 | LOGGER.log( 172 | log_level, f"Object modified locally but not committed: {mod.instance}" 173 | ) 174 | LOGGER.log(log_level, f"Modified at:\n{mod.user_stack_trace}") 175 | if self.verbosity == "INFO": 176 | LOGGER.log(log_level, f"Full stack trace:\n{mod.stack_trace}") 177 | -------------------------------------------------------------------------------- /iceaxe/__tests__/docker_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Docker helper utilities for testing. 3 | 4 | This module provides classes and functions to manage Docker containers for testing, 5 | particularly focusing on PostgreSQL database containers. 6 | """ 7 | 8 | import logging 9 | import socket 10 | import time 11 | import uuid 12 | from typing import Any, Dict, Optional, cast 13 | 14 | import docker 15 | from docker.errors import APIError 16 | 17 | # Configure logging 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def get_free_port() -> int: 22 | """Find a free port on the host machine.""" 23 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 24 | s.bind(("", 0)) 25 | return s.getsockname()[1] 26 | 27 | 28 | class PostgresContainer: 29 | """ 30 | A class that manages a PostgreSQL Docker container for testing. 31 | 32 | This class handles the lifecycle of a PostgreSQL container, including: 33 | - Starting the container with appropriate configuration 34 | - Finding available ports 35 | - Waiting for the container to be ready 36 | - Providing connection information 37 | - Cleaning up after tests 38 | """ 39 | 40 | def __init__( 41 | self, 42 | pg_user: str = "iceaxe", 43 | pg_password: str = "mysecretpassword", 44 | pg_db: str = "iceaxe_test_db", 45 | postgres_version: str = "16", 46 | ): 47 | self.pg_user = pg_user 48 | self.pg_password = pg_password 49 | self.pg_db = pg_db 50 | self.postgres_version = postgres_version 51 | self.port = get_free_port() 52 | self.container: Optional[Any] = None 53 | self.client = docker.from_env() 54 | self.container_name = f"iceaxe-postgres-test-{uuid.uuid4().hex[:8]}" 55 | 56 | def start(self) -> Dict[str, Any]: 57 | """ 58 | Start the PostgreSQL container. 
59 | 60 | Returns: 61 | Dict[str, Any]: Connection information for the PostgreSQL container 62 | 63 | Raises: 64 | RuntimeError: If the container fails to start or become ready 65 | """ 66 | logger.info(f"Starting PostgreSQL container on port {self.port}") 67 | 68 | max_attempts = 3 69 | attempt = 0 70 | 71 | while attempt < max_attempts: 72 | attempt += 1 73 | try: 74 | self.container = self._run_container(self.port) 75 | break 76 | except APIError as e: 77 | if "port is already allocated" in str(e) and attempt < max_attempts: 78 | logger.warning( 79 | f"Port {self.port} is still in use. Trying with a new port (attempt {attempt}/{max_attempts})." 80 | ) 81 | self.port = get_free_port() 82 | else: 83 | raise RuntimeError(f"Failed to start PostgreSQL container: {e}") 84 | 85 | # Wait for PostgreSQL to be ready 86 | if not self._wait_for_container_ready(): 87 | self.stop() 88 | raise RuntimeError("Failed to connect to PostgreSQL container") 89 | 90 | return self.get_connection_info() 91 | 92 | def _run_container( 93 | self, port: int 94 | ) -> Any: # Type as Any since docker.models.containers.Container isn't imported 95 | """ 96 | Run the Docker container with the specified port. 97 | 98 | Args: 99 | port: The port to map PostgreSQL to on the host 100 | 101 | Returns: 102 | The Docker container object 103 | """ 104 | return self.client.containers.run( 105 | f"postgres:{self.postgres_version}", 106 | name=self.container_name, 107 | detach=True, 108 | environment={ 109 | "POSTGRES_USER": self.pg_user, 110 | "POSTGRES_PASSWORD": self.pg_password, 111 | "POSTGRES_DB": self.pg_db, 112 | # Additional settings for faster startup in testing 113 | "POSTGRES_HOST_AUTH_METHOD": "trust", 114 | }, 115 | ports={"5432/tcp": port}, 116 | remove=True, # Auto-remove container when stopped 117 | ) 118 | 119 | def _wait_for_container_ready(self) -> bool: 120 | """ 121 | Wait for the PostgreSQL container to be ready. 122 | 123 | Returns: 124 | bool: True if the container is ready, False otherwise 125 | """ 126 | max_retries = 30 127 | retry_interval = 1 128 | 129 | for i in range(max_retries): 130 | try: 131 | if self.container is None: 132 | logger.warning("Container is None, cannot proceed") 133 | return False 134 | 135 | # We've already checked that self.container is not None 136 | container = cast(Any, self.container) 137 | container.reload() # Refresh container status 138 | if container.status != "running": 139 | logger.warning(f"Container status: {container.status}") 140 | return False 141 | 142 | # Try to connect to PostgreSQL 143 | conn = socket.create_connection(("localhost", self.port), timeout=1) 144 | conn.close() 145 | # Wait a bit more to ensure PostgreSQL is fully initialized 146 | time.sleep(2) 147 | logger.info(f"PostgreSQL container is ready after {i + 1} attempt(s)") 148 | return True 149 | except (socket.error, ConnectionRefusedError) as e: 150 | if i == max_retries - 1: 151 | logger.warning( 152 | f"Failed to connect after {max_retries} attempts: {e}" 153 | ) 154 | return False 155 | time.sleep(retry_interval) 156 | except Exception as e: 157 | logger.warning(f"Unexpected error checking container readiness: {e}") 158 | if i == max_retries - 1: 159 | return False 160 | time.sleep(retry_interval) 161 | 162 | return False 163 | 164 | def stop(self) -> None: 165 | """ 166 | Stop the PostgreSQL container. 167 | 168 | This method ensures the container is properly stopped and removed. 
169 | """ 170 | if self.container is not None: 171 | try: 172 | logger.info(f"Stopping PostgreSQL container {self.container_name}") 173 | # We've already checked that self.container is not None 174 | container = cast(Any, self.container) 175 | container.stop(timeout=10) # Allow 10 seconds for graceful shutdown 176 | except Exception as e: 177 | logger.warning(f"Failed to stop container: {e}") 178 | try: 179 | # Force remove as a fallback 180 | if self.container is not None: 181 | self.container.remove(force=True) 182 | logger.info("Forced container removal") 183 | except Exception as e2: 184 | logger.warning(f"Failed to force remove container: {e2}") 185 | 186 | def get_connection_info(self) -> Dict[str, Any]: 187 | """ 188 | Get the connection information for the PostgreSQL container. 189 | 190 | Returns: 191 | Dict[str, Any]: A dictionary containing connection parameters 192 | """ 193 | return { 194 | "host": "localhost", 195 | "port": self.port, 196 | "user": self.pg_user, 197 | "password": self.pg_password, 198 | "database": self.pg_db, 199 | } 200 | 201 | def get_connection_string(self) -> str: 202 | """ 203 | Get a PostgreSQL connection string. 204 | 205 | Returns: 206 | str: A connection string in the format 'postgresql://user:password@host:port/database' 207 | """ 208 | return f"postgresql://{self.pg_user}:{self.pg_password}@localhost:{self.port}/{self.pg_db}" 209 | -------------------------------------------------------------------------------- /iceaxe/__tests__/schemas/test_db_stubs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from iceaxe.schemas.db_stubs import ConstraintPointerInfo, DBObjectPointer, DBType 4 | 5 | 6 | class MockDBObjectPointer(DBObjectPointer): 7 | """Mock implementation of DBObjectPointer for testing parser methods.""" 8 | 9 | representation_str: str 10 | 11 | def representation(self) -> str: 12 | return self.representation_str 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "representation_str,expected_result", 17 | [ 18 | # Valid constraint pointer formats 19 | ( 20 | "users.['id'].PRIMARY KEY", 21 | ConstraintPointerInfo("users", ["id"], "PRIMARY KEY"), 22 | ), 23 | ( 24 | "orders.['user_id', 'product_id'].UNIQUE", 25 | ConstraintPointerInfo("orders", ["user_id", "product_id"], "UNIQUE"), 26 | ), 27 | ( 28 | "products.['name'].INDEX", 29 | ConstraintPointerInfo("products", ["name"], "INDEX"), 30 | ), 31 | ( 32 | "table_name.['col1', 'col2', 'col3'].FOREIGN KEY", 33 | ConstraintPointerInfo( 34 | "table_name", ["col1", "col2", "col3"], "FOREIGN KEY" 35 | ), 36 | ), 37 | # Single quotes 38 | ("users.['email'].UNIQUE", ConstraintPointerInfo("users", ["email"], "UNIQUE")), 39 | # Double quotes 40 | ('users.["email"].UNIQUE', ConstraintPointerInfo("users", ["email"], "UNIQUE")), 41 | # Mixed quotes 42 | ( 43 | "users.[\"col1\", 'col2'].UNIQUE", 44 | ConstraintPointerInfo("users", ["col1", "col2"], "UNIQUE"), 45 | ), 46 | # Extra whitespace 47 | ( 48 | "users.[ 'col1' , 'col2' ].UNIQUE", 49 | ConstraintPointerInfo("users", ["col1", "col2"], "UNIQUE"), 50 | ), 51 | # Empty column list 52 | ("users.[].CHECK", ConstraintPointerInfo("users", [], "CHECK")), 53 | # Schema-qualified table names (dots in table names are valid when representing schema.table) 54 | ( 55 | "public.users.['column'].PRIMARY KEY", 56 | ConstraintPointerInfo("public.users", ["column"], "PRIMARY KEY"), 57 | ), 58 | # Complex constraint types 59 | ( 60 | "users.['id'].PRIMARY KEY AUTOINCREMENT", 61 | ConstraintPointerInfo("users", 
["id"], "PRIMARY KEY AUTOINCREMENT"), 62 | ), 63 | # Table names with underscores and numbers (valid PostgreSQL identifiers) 64 | ( 65 | "user_table_2.['id'].PRIMARY KEY", 66 | ConstraintPointerInfo("user_table_2", ["id"], "PRIMARY KEY"), 67 | ), 68 | # Column names with underscores and numbers 69 | ( 70 | "users.['user_id_2', 'created_at'].UNIQUE", 71 | ConstraintPointerInfo("users", ["user_id_2", "created_at"], "UNIQUE"), 72 | ), 73 | # Invalid formats that should return None 74 | ("users.column.UNIQUE", None), # Missing brackets 75 | ("users.['column']", None), # Missing constraint type 76 | ("['column'].UNIQUE", None), # Missing table name 77 | ("users", None), # Just table name 78 | ("", None), # Empty string 79 | ("users.column", None), # Simple table.column format 80 | ("invalid_format", None), # Random string 81 | # Malformed bracket syntax 82 | ("users.[column].UNIQUE", None), # Missing quotes in brackets 83 | ("users.['column.UNIQUE", None), # Unclosed bracket 84 | ("users.column'].UNIQUE", None), # Missing opening bracket 85 | ], 86 | ) 87 | def test_parse_constraint_pointer( 88 | representation_str: str, expected_result: ConstraintPointerInfo | None 89 | ): 90 | """Test parsing of constraint pointer representations.""" 91 | pointer = MockDBObjectPointer(representation_str=representation_str) 92 | result = pointer.parse_constraint_pointer() 93 | 94 | if expected_result is None: 95 | assert result is None 96 | else: 97 | assert result is not None 98 | assert result.table_name == expected_result.table_name 99 | assert result.column_names == expected_result.column_names 100 | assert result.constraint_type == expected_result.constraint_type 101 | 102 | 103 | @pytest.mark.parametrize( 104 | "representation_str,expected_table_name", 105 | [ 106 | # Constraint pointer formats 107 | ("users.['id'].PRIMARY KEY", "users"), 108 | ("orders.['user_id', 'product_id'].UNIQUE", "orders"), 109 | ("public.users.['column'].INDEX", "public.users"), 110 | # Simple table.column formats 111 | ("users.email", "users"), 112 | ("products.name", "products"), 113 | ("public.users.column", "public.users"), # Schema.table.column format 114 | # Edge cases 115 | ("table_only", "table_only"), 116 | ("", None), # Empty string should return None 117 | ("users.['id'].PRIMARY KEY", "users"), # Constraint format takes precedence 118 | # Complex table names with underscores and numbers 119 | ("user_table_123.column", "user_table_123"), 120 | ("schema_1.table_2.column", "schema_1.table_2"), 121 | # Multiple dots in representation (should extract the table part correctly) 122 | ("very.long.schema.table.['col'].UNIQUE", "very.long.schema.table"), 123 | ], 124 | ) 125 | def test_get_table_name(representation_str: str, expected_table_name: str | None): 126 | """Test extraction of table names from pointer representations.""" 127 | pointer = MockDBObjectPointer(representation_str=representation_str) 128 | result = pointer.get_table_name() 129 | assert result == expected_table_name 130 | 131 | 132 | @pytest.mark.parametrize( 133 | "representation_str,expected_column_names", 134 | [ 135 | # Constraint pointer formats 136 | ("users.['id'].PRIMARY KEY", ["id"]), 137 | ("orders.['user_id', 'product_id'].UNIQUE", ["user_id", "product_id"]), 138 | ("products.['name', 'category', 'price'].INDEX", ["name", "category", "price"]), 139 | ("users.[].CHECK", []), # Empty column list 140 | # Simple table.column formats 141 | ("users.email", ["email"]), 142 | ("products.name", ["name"]), 143 | ("public.users.column", ["column"]), # 
Schema.table.column format 144 | # Edge cases 145 | ("table_only", []), # No columns 146 | ("", []), # Empty string 147 | # Whitespace handling 148 | ("users.[ 'col1' , 'col2' ].UNIQUE", ["col1", "col2"]), 149 | # Quote handling 150 | ("users.[\"col1\", 'col2'].UNIQUE", ["col1", "col2"]), 151 | # Column names with underscores and numbers 152 | ( 153 | "users.['user_id_2', 'created_at_timestamp'].UNIQUE", 154 | ["user_id_2", "created_at_timestamp"], 155 | ), 156 | # Complex schema.table.column cases 157 | ("schema.table.column_name", ["column_name"]), 158 | ("very.long.schema.table.column", ["column"]), 159 | ], 160 | ) 161 | def test_get_column_names(representation_str: str, expected_column_names: list[str]): 162 | """Test extraction of column names from pointer representations.""" 163 | pointer = MockDBObjectPointer(representation_str=representation_str) 164 | result = pointer.get_column_names() 165 | assert result == expected_column_names 166 | 167 | 168 | def test_merge_type_columns(): 169 | """ 170 | Allow separately yielded type definitions to collect their reference columns. If an 171 | enum is referenced in one place, this should build up to the full definition. 172 | 173 | """ 174 | type_a = DBType( 175 | name="type_a", 176 | values=frozenset({"A"}), 177 | reference_columns=frozenset({("table_a", "column_a")}), 178 | ) 179 | type_b = DBType( 180 | name="type_a", 181 | values=frozenset({"A"}), 182 | reference_columns=frozenset({("table_b", "column_b")}), 183 | ) 184 | 185 | merged = type_a.merge(type_b) 186 | assert merged.name == "type_a" 187 | assert merged.values == frozenset({"A"}) 188 | assert merged.reference_columns == frozenset( 189 | {("table_a", "column_a"), ("table_b", "column_b")} 190 | ) 191 | -------------------------------------------------------------------------------- /iceaxe/__tests__/migrations/test_action_sorter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from iceaxe.migrations.action_sorter import ActionTopologicalSorter 4 | from iceaxe.schemas.actions import DatabaseActions 5 | from iceaxe.schemas.db_stubs import DBObject 6 | 7 | 8 | class MockNode(DBObject): 9 | name: str 10 | table_name: str = "None" 11 | 12 | model_config = { 13 | "frozen": True, 14 | } 15 | 16 | def representation(self): 17 | return f"MockNode({self.name}, {self.table_name})" 18 | 19 | async def create(self, actor: DatabaseActions): 20 | pass 21 | 22 | async def migrate(self, previous: DBObject, actor: DatabaseActions): 23 | pass 24 | 25 | async def destroy(self, actor: DatabaseActions): 26 | pass 27 | 28 | def __hash__(self): 29 | return hash(self.representation()) 30 | 31 | 32 | def custom_topological_sort(graph_edges): 33 | sorter = ActionTopologicalSorter(graph_edges) 34 | sorted_objects = sorter.sort() 35 | return {obj: i for i, obj in enumerate(sorted_objects)} 36 | 37 | 38 | def test_simple_dag(): 39 | A = MockNode(name="A") 40 | B = MockNode(name="B") 41 | C = MockNode(name="C") 42 | D = MockNode(name="D") 43 | graph = {D: [B, C], C: [A], B: [A], A: []} 44 | result = custom_topological_sort(graph) 45 | assert list(result.keys()) == [A, C, B, D] 46 | 47 | 48 | def test_disconnected_graph(): 49 | A = MockNode(name="A") 50 | B = MockNode(name="B") 51 | C = MockNode(name="C") 52 | D = MockNode(name="D") 53 | E = MockNode(name="E") 54 | graph = {B: [A], A: [], D: [C], C: [], E: []} 55 | result = custom_topological_sort(graph) 56 | assert set(result.keys()) == {A, B, C, D, E} 57 | assert result[A] < result[B] 58 | assert 
result[C] < result[D] 59 | 60 | 61 | def test_single_table_grouping(): 62 | A = MockNode(name="A", table_name="table1") 63 | B = MockNode(name="B", table_name="table1") 64 | C = MockNode(name="C", table_name="table1") 65 | graph = {C: [], B: [C], A: [B]} 66 | result = custom_topological_sort(graph) 67 | assert list(result.keys()) == [C, B, A] 68 | 69 | 70 | def test_multiple_table_grouping(): 71 | A = MockNode(name="A", table_name="table1") 72 | B = MockNode(name="B", table_name="table1") 73 | C = MockNode(name="C", table_name="table2") 74 | D = MockNode(name="D", table_name="table2") 75 | E = MockNode(name="E", table_name="table3") 76 | graph = {E: [], D: [], C: [D, E], A: [C], B: [C]} 77 | result = custom_topological_sort(graph) 78 | assert set(result.keys()) == {A, B, C, D, E} 79 | assert result[C] < result[A] and result[C] < result[B] 80 | assert result[D] < result[C] and result[E] < result[C] 81 | 82 | 83 | def test_cross_table_references(): 84 | A = MockNode(name="A", table_name="table1") 85 | B = MockNode(name="B", table_name="table2") 86 | C = MockNode(name="C", table_name="table1") 87 | D = MockNode(name="D", table_name="table2") 88 | graph = {D: [], C: [D], B: [C], A: [B]} 89 | result = custom_topological_sort(graph) 90 | assert list(result.keys()) == [D, C, B, A] 91 | 92 | 93 | def test_nodes_without_table_name(): 94 | A = MockNode(name="A", table_name="table1") 95 | B = MockNode(name="B") 96 | C = MockNode(name="C", table_name="table2") 97 | D = MockNode(name="D") 98 | graph = {D: [], C: [D], B: [C], A: [B]} 99 | result = custom_topological_sort(graph) 100 | assert list(result.keys()) == [D, C, B, A] 101 | 102 | 103 | def test_complex_graph(): 104 | A = MockNode(name="A", table_name="table1") 105 | B = MockNode(name="B", table_name="table1") 106 | C = MockNode(name="C", table_name="table2") 107 | D = MockNode(name="D", table_name="table2") 108 | E = MockNode(name="E", table_name="table3") 109 | F = MockNode(name="F") 110 | G = MockNode(name="G", table_name="table3") 111 | graph = {G: [], F: [G], E: [G], D: [F], C: [E, F], A: [C, D], B: [C]} 112 | result = custom_topological_sort(graph) 113 | assert set(result.keys()) == {A, B, C, D, E, F, G} 114 | assert result[C] < result[A] and result[D] < result[A] 115 | assert result[C] < result[B] 116 | assert result[E] < result[C] and result[F] < result[C] 117 | assert result[F] < result[D] 118 | assert result[G] < result[E] and result[G] < result[F] 119 | 120 | 121 | def test_cyclic_graph(): 122 | A = MockNode(name="A") 123 | B = MockNode(name="B") 124 | C = MockNode(name="C") 125 | graph = {A: [B], B: [C], C: [A]} 126 | with pytest.raises(ValueError, match="Graph contains a cycle"): 127 | custom_topological_sort(graph) 128 | 129 | 130 | def test_empty_graph(): 131 | graph = {} 132 | result = custom_topological_sort(graph) 133 | assert result == {} 134 | 135 | 136 | def test_single_node_graph(): 137 | A = MockNode(name="A") 138 | graph = {A: []} 139 | result = custom_topological_sort(graph) 140 | assert result == {A: 0} 141 | 142 | 143 | def test_all_nodes_same_table(): 144 | A = MockNode(name="A", table_name="table1") 145 | B = MockNode(name="B", table_name="table1") 146 | C = MockNode(name="C", table_name="table1") 147 | D = MockNode(name="D", table_name="table1") 148 | graph = {D: [], B: [D], C: [D], A: [B, C]} 149 | result = custom_topological_sort(graph) 150 | assert list(result.keys()) == [D, B, C, A] 151 | 152 | 153 | def test_mixed_node_types(): 154 | A = MockNode(name="A", table_name="table1") 155 | B = MockNode(name="B") 156 | 
C = MockNode(name="C", table_name="table2") 157 | D = MockNode(name="D", table_name="table3") 158 | graph = {D: [], C: [D], B: [C], A: [B]} 159 | result = custom_topological_sort(graph) 160 | assert list(result.keys()) == [D, C, B, A] 161 | 162 | 163 | def test_large_graph_performance(): 164 | import random 165 | import string 166 | import time 167 | 168 | def generate_large_graph(size): 169 | nodes = [MockNode(name=c) for c in string.ascii_uppercase] + [ 170 | MockNode(name=f"N{i}") for i in range(size - 26) 171 | ] 172 | graph = {node: set() for node in nodes} 173 | for i, node in enumerate(nodes): 174 | graph[node] = set(random.sample(nodes[i + 1 :], min(5, len(nodes) - i - 1))) 175 | return graph 176 | 177 | large_graph = generate_large_graph(1000) 178 | start_time = time.time() 179 | result = custom_topological_sort(large_graph) 180 | end_time = time.time() 181 | assert len(result) == 1000 182 | assert end_time - start_time < 5 183 | 184 | 185 | def test_graph_with_isolated_nodes(): 186 | A = MockNode(name="A") 187 | B = MockNode(name="B") 188 | C = MockNode(name="C") 189 | D = MockNode(name="D") 190 | E = MockNode(name="E") 191 | graph = {A: [B], B: [], C: [], D: [E], E: []} 192 | result = custom_topological_sort(graph) 193 | assert set(result.keys()) == {A, B, C, D, E} 194 | assert result[B] < result[A] 195 | assert result[E] < result[D] 196 | 197 | 198 | @pytest.mark.parametrize( 199 | "graph, expected_order", 200 | [ 201 | ( 202 | { 203 | MockNode(name="C"): set(), 204 | MockNode(name="B"): {MockNode(name="C")}, 205 | MockNode(name="A"): {MockNode(name="B")}, 206 | }, 207 | ["C", "B", "A"], 208 | ), 209 | ( 210 | { 211 | MockNode(name="D"): set(), 212 | MockNode(name="B"): {MockNode(name="D")}, 213 | MockNode(name="C"): {MockNode(name="D")}, 214 | MockNode(name="A"): {MockNode(name="B"), MockNode(name="C")}, 215 | }, 216 | ["D", "B", "C", "A"], 217 | ), 218 | ], 219 | ) 220 | def test_various_graph_structures(graph, expected_order): 221 | result = custom_topological_sort(graph) 222 | assert [node.name for node in result.keys()] == expected_order 223 | 224 | 225 | def test_consistent_results(): 226 | """ 227 | Test for consistent results with same input 228 | 229 | """ 230 | A = MockNode(name="A", table_name="table1") 231 | B = MockNode(name="B", table_name="table1") 232 | C = MockNode(name="C", table_name="table2") 233 | D = MockNode(name="D", table_name="table2") 234 | graph = {D: [], C: [D], B: [D], A: [B, C]} 235 | result1 = custom_topological_sort(graph) 236 | result2 = custom_topological_sort(graph) 237 | assert list(result1.keys()) == list(result2.keys()) 238 | -------------------------------------------------------------------------------- /iceaxe/field.py: -------------------------------------------------------------------------------- 1 | from json import dumps as json_dumps 2 | from typing import ( 3 | TYPE_CHECKING, 4 | Any, 5 | Callable, 6 | Concatenate, 7 | Generic, 8 | ParamSpec, 9 | Type, 10 | TypeVar, 11 | Unpack, 12 | cast, 13 | ) 14 | 15 | from pydantic import Field as PydanticField 16 | from pydantic.fields import FieldInfo, _FieldInfoInputs 17 | from pydantic_core import PydanticUndefined 18 | 19 | from iceaxe.comparison import ComparisonBase 20 | from iceaxe.postgres import PostgresFieldBase 21 | from iceaxe.queries_str import QueryIdentifier, QueryLiteral 22 | from iceaxe.sql_types import ColumnType 23 | 24 | if TYPE_CHECKING: 25 | from iceaxe.base import TableBase 26 | 27 | P = ParamSpec("P") 28 | 29 | _Unset: Any = PydanticUndefined 30 | 31 | 32 | class 
DBFieldInputs(_FieldInfoInputs, total=False): 33 | primary_key: bool 34 | autoincrement: bool 35 | postgres_config: PostgresFieldBase | None 36 | foreign_key: str | None 37 | unique: bool 38 | index: bool 39 | check_expression: str | None 40 | is_json: bool 41 | explicit_type: ColumnType | None 42 | 43 | 44 | class DBFieldInfo(FieldInfo): 45 | """ 46 | Extended field information for database fields, building upon Pydantic's FieldInfo. 47 | This class adds database-specific attributes and functionality for field configuration 48 | in SQL databases, particularly PostgreSQL. 49 | 50 | This class is used internally by the Field constructor to store metadata about 51 | database columns, including constraints, foreign keys, and PostgreSQL-specific 52 | configurations. 53 | """ 54 | 55 | primary_key: bool = False 56 | """ 57 | Indicates if this field serves as the primary key for the table. 58 | When True, this field will be used as the unique identifier for rows. 59 | """ 60 | 61 | autoincrement: bool = False 62 | """ 63 | Controls whether the field should automatically increment. 64 | By default, this is True for primary key fields that have no default value set. 65 | """ 66 | 67 | postgres_config: PostgresFieldBase | None = None 68 | """ 69 | Custom PostgreSQL configuration for the field. 70 | Allows for type-specific customization of PostgreSQL parameters. 71 | """ 72 | 73 | foreign_key: str | None = None 74 | """ 75 | Specifies a foreign key relationship to another table. 76 | Format should be "table_name.column_name" if set. 77 | """ 78 | 79 | unique: bool = False 80 | """ 81 | When True, enforces that all values in this column must be unique. 82 | Creates a unique constraint in the database. 83 | """ 84 | 85 | index: bool = False 86 | """ 87 | When True, creates an index on this column to optimize query performance. 88 | """ 89 | 90 | check_expression: str | None = None 91 | """ 92 | SQL expression for a CHECK constraint on this column. 93 | Allows for custom validation rules at the database level. 94 | """ 95 | 96 | is_json: bool = False 97 | """ 98 | Indicates if this field should be stored as JSON in the database. 99 | When True, the field's value will be JSON serialized before storage. 100 | """ 101 | 102 | explicit_type: ColumnType | None = None 103 | """ 104 | Explicitly specify the SQL column type for this field. 105 | When set, this type takes precedence over automatic type inference. 106 | """ 107 | 108 | def __init__(self, **kwargs: Unpack[DBFieldInputs]): 109 | """ 110 | Initialize a new DBFieldInfo instance with the given field configuration. 111 | 112 | :param kwargs: Keyword arguments that configure the field's behavior. 113 | Includes all standard Pydantic field options plus database-specific options. 
114 | 115 | """ 116 | # The super call should persist all kwargs as _attributes_set 117 | # We're intentionally passing kwargs that we know aren't in the 118 | # base typehinted dict 119 | super().__init__(**kwargs) # type: ignore 120 | self.primary_key = kwargs.pop("primary_key", False) 121 | self.autoincrement = kwargs.pop( 122 | "autoincrement", (self.primary_key and self.default is None) 123 | ) 124 | self.postgres_config = kwargs.pop("postgres_config", None) 125 | self.foreign_key = kwargs.pop("foreign_key", None) 126 | self.unique = kwargs.pop("unique", False) 127 | self.index = kwargs.pop("index", False) 128 | self.check_expression = kwargs.pop("check_expression", None) 129 | self.is_json = kwargs.pop("is_json", False) 130 | self.explicit_type = kwargs.pop("explicit_type", None) 131 | 132 | @classmethod 133 | def extend_field( 134 | cls, 135 | field: FieldInfo, 136 | primary_key: bool, 137 | postgres_config: PostgresFieldBase | None, 138 | foreign_key: str | None, 139 | unique: bool, 140 | index: bool, 141 | check_expression: str | None, 142 | is_json: bool, 143 | explicit_type: ColumnType | None, 144 | ): 145 | """ 146 | Helper function to extend a Pydantic FieldInfo with database-specific attributes. 147 | 148 | """ 149 | return cls( 150 | primary_key=primary_key, 151 | postgres_config=postgres_config, 152 | foreign_key=foreign_key, 153 | unique=unique, 154 | index=index, 155 | check_expression=check_expression, 156 | is_json=is_json, 157 | explicit_type=explicit_type, 158 | **field._attributes_set, # type: ignore 159 | ) 160 | 161 | def to_db_value(self, value: Any): 162 | if self.is_json: 163 | return json_dumps(value) 164 | return value 165 | 166 | 167 | def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # type: ignore 168 | """ 169 | Workaround constructor to pass typehints through our function subclass 170 | to the original PydanticField constructor 171 | 172 | """ 173 | 174 | def func( 175 | primary_key: bool = False, 176 | postgres_config: PostgresFieldBase | None = None, 177 | foreign_key: str | None = None, 178 | unique: bool = False, 179 | index: bool = False, 180 | check_expression: str | None = None, 181 | is_json: bool = False, 182 | explicit_type: ColumnType | None = None, 183 | default: Any = _Unset, 184 | default_factory: ( 185 | Callable[[], Any] | Callable[[dict[str, Any]], Any] | None 186 | ) = _Unset, 187 | *args: P.args, 188 | **kwargs: P.kwargs, 189 | ): 190 | raw_field = PydanticField( 191 | default=default, default_factory=default_factory, **kwargs 192 | ) # type: ignore 193 | 194 | # The Any request is required for us to be able to assign fields to any 195 | # arbitrary type, like `value: str = Field()` 196 | return cast( 197 | Any, 198 | DBFieldInfo.extend_field( 199 | raw_field, 200 | primary_key=primary_key, 201 | postgres_config=postgres_config, 202 | foreign_key=foreign_key, 203 | unique=unique, 204 | index=index, 205 | check_expression=check_expression, 206 | is_json=is_json, 207 | explicit_type=explicit_type, 208 | ), 209 | ) 210 | 211 | return func 212 | 213 | 214 | T = TypeVar("T") 215 | 216 | 217 | class DBFieldClassDefinition(Generic[T], ComparisonBase[T]): 218 | """ 219 | The returned model when users access a field directly from 220 | the table class, e.g. 
`User.id` 221 | 222 | """ 223 | 224 | root_model: Type["TableBase"] 225 | key: str 226 | field_definition: DBFieldInfo 227 | 228 | def __init__( 229 | self, 230 | root_model: Type["TableBase"], 231 | key: str, 232 | field_definition: DBFieldInfo, 233 | ): 234 | self.root_model = root_model 235 | self.key = key 236 | self.field_definition = field_definition 237 | 238 | def to_query(self): 239 | table = QueryIdentifier(self.root_model.get_table_name()) 240 | column = QueryIdentifier(self.key) 241 | return QueryLiteral(f"{table}.{column}"), [] 242 | 243 | 244 | Field = __get_db_field() 245 | """ 246 | Create a new database field with optional database-specific configurations. 247 | 248 | This function extends Pydantic's Field with additional database functionality. It accepts 249 | all standard Pydantic Field parameters plus all the database-specific parameters defined 250 | in DBFieldInfo. 251 | 252 | ```python {{sticky: True}} 253 | from iceaxe import Field 254 | from iceaxe.base import TableBase 255 | 256 | class User(TableBase): 257 | id: int = Field(primary_key=True) 258 | username: str = Field(unique=True, index=True) 259 | settings: dict = Field(is_json=True, default_factory=dict) 260 | department_id: int = Field(foreign_key="departments.id") 261 | ``` 262 | 263 | """ 264 | -------------------------------------------------------------------------------- /iceaxe/session_optimized.pyx: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Tuple 2 | from iceaxe.base import TableBase 3 | from iceaxe.queries import FunctionMetadata 4 | from iceaxe.alias_values import Alias 5 | from json import loads as json_loads 6 | from cpython.ref cimport PyObject 7 | from cpython.object cimport PyObject_GetItem 8 | from libc.stdlib cimport malloc, free 9 | from libc.string cimport memcpy 10 | from cpython.ref cimport Py_INCREF, Py_DECREF 11 | 12 | cdef struct FieldInfo: 13 | char* name # Field name 14 | char* select_attribute # Corresponding attribute in the select_raw 15 | bint is_json # Flag indicating if the field is JSON 16 | 17 | cdef char* allocate_cstring(bytes data): 18 | cdef Py_ssize_t length = len(data) 19 | cdef char* c_str = malloc((length + 1) * sizeof(char)) 20 | if not c_str: 21 | raise MemoryError("Failed to allocate memory for C string.") 22 | memcpy(c_str, data, length) # Cast bytes to char* for memcpy 23 | c_str[length] = 0 # Null-terminate the string 24 | return c_str 25 | 26 | cdef void free_fields(FieldInfo** fields, Py_ssize_t* num_fields_array, Py_ssize_t num_selects): 27 | cdef Py_ssize_t j, k 28 | if fields: 29 | for j in range(num_selects): 30 | if fields[j]: 31 | for k in range(num_fields_array[j]): 32 | free(fields[j][k].name) 33 | free(fields[j][k].select_attribute) 34 | free(fields[j]) 35 | free(fields) 36 | if num_fields_array: 37 | free(num_fields_array) 38 | 39 | cdef FieldInfo** precompute_fields(list select_raws, list select_types, Py_ssize_t num_selects, Py_ssize_t* num_fields_array): 40 | cdef FieldInfo** fields = malloc(num_selects * sizeof(FieldInfo*)) 41 | cdef Py_ssize_t j, k, num_fields 42 | cdef dict field_dict 43 | cdef bytes select_bytes, field_bytes 44 | cdef char* c_select 45 | cdef char* c_field 46 | cdef object select_raw 47 | cdef bint raw_is_table, raw_is_column, raw_is_function_metadata 48 | 49 | if not fields: 50 | raise MemoryError("Failed to allocate memory for fields.") 51 | 52 | for j in range(num_selects): 53 | select_raw = select_raws[j] 54 | raw_is_table, raw_is_column, 
raw_is_function_metadata = select_types[j] 55 | 56 | if raw_is_table: 57 | field_dict = {field: info.is_json for field, info in select_raw.get_client_fields().items() if not info.exclude} 58 | num_fields = len(field_dict) 59 | num_fields_array[j] = num_fields 60 | fields[j] = malloc(num_fields * sizeof(FieldInfo)) 61 | if not fields[j]: 62 | raise MemoryError("Failed to allocate memory for FieldInfo.") 63 | 64 | for k, (field, is_json) in enumerate(field_dict.items()): 65 | select_bytes = f"{select_raw.get_table_name()}_{field}".encode('utf-8') 66 | c_select = allocate_cstring(select_bytes) 67 | 68 | field_bytes = field.encode('utf-8') 69 | c_field = allocate_cstring(field_bytes) 70 | 71 | fields[j][k].select_attribute = c_select 72 | fields[j][k].name = c_field 73 | fields[j][k].is_json = is_json 74 | else: 75 | num_fields_array[j] = 0 76 | fields[j] = NULL 77 | 78 | return fields 79 | 80 | cdef list process_values( 81 | list values, 82 | FieldInfo** fields, 83 | Py_ssize_t* num_fields_array, 84 | list select_raws, 85 | list select_types, 86 | Py_ssize_t num_selects 87 | ): 88 | cdef Py_ssize_t num_values = len(values) 89 | cdef list result_all = [None] * num_values 90 | cdef Py_ssize_t i, j, k, num_fields 91 | cdef PyObject** result_value 92 | cdef object value, obj, item 93 | cdef dict obj_dict 94 | cdef bint raw_is_table, raw_is_column, raw_is_function_metadata, raw_is_alias 95 | cdef char* field_name_c 96 | cdef char* select_name_c 97 | cdef str field_name 98 | cdef str select_name 99 | cdef object field_value 100 | cdef object select_raw 101 | cdef PyObject* temp_obj 102 | cdef bint all_none 103 | 104 | for i in range(num_values): 105 | value = values[i] 106 | result_value = malloc(num_selects * sizeof(PyObject*)) 107 | if not result_value: 108 | raise MemoryError("Failed to allocate memory for result_value.") 109 | try: 110 | for j in range(num_selects): 111 | select_raw = select_raws[j] 112 | raw_is_table, raw_is_column, raw_is_function_metadata = select_types[j] 113 | raw_is_alias = isinstance(select_raw, Alias) 114 | 115 | if raw_is_table: 116 | obj_dict = {} 117 | num_fields = num_fields_array[j] 118 | all_none = True 119 | 120 | # First pass: collect all fields and check if they're all None 121 | for k in range(num_fields): 122 | field_name_c = fields[j][k].name 123 | select_name_c = fields[j][k].select_attribute 124 | field_name = field_name_c.decode('utf-8') 125 | select_name = select_name_c.decode('utf-8') 126 | 127 | try: 128 | field_value = value[select_name] 129 | except KeyError: 130 | raise KeyError(f"Key '{select_name}' not found in value.") 131 | 132 | if field_value is not None: 133 | all_none = False 134 | if fields[j][k].is_json: 135 | field_value = json_loads(field_value) 136 | 137 | obj_dict[field_name] = field_value 138 | 139 | # If all fields are None, store None instead of creating the table object 140 | if all_none: 141 | result_value[j] = None 142 | Py_INCREF(None) 143 | else: 144 | obj = select_raw(**obj_dict) 145 | result_value[j] = obj 146 | Py_INCREF(obj) 147 | 148 | elif raw_is_column: 149 | try: 150 | # Use the table-qualified column name 151 | table_name = select_raw.root_model.get_table_name() 152 | column_name = select_raw.key 153 | item = value[f"{table_name}_{column_name}"] 154 | except KeyError: 155 | raise KeyError(f"Key '{table_name}_{column_name}' not found in value.") 156 | result_value[j] = item 157 | Py_INCREF(item) 158 | 159 | elif raw_is_function_metadata: 160 | try: 161 | item = value[select_raw.local_name] 162 | except KeyError: 163 | 
raise KeyError(f"Key '{select_raw.local_name}' not found in value.") 164 | result_value[j] = item 165 | Py_INCREF(item) 166 | 167 | elif raw_is_alias: 168 | try: 169 | item = value[select_raw.name] 170 | except KeyError: 171 | raise KeyError(f"Key '{select_raw.name}' not found in value.") 172 | result_value[j] = item 173 | Py_INCREF(item) 174 | 175 | # Assemble the result 176 | if num_selects == 1: 177 | result_all[i] = result_value[0] 178 | Py_DECREF(result_value[0]) 179 | else: 180 | result_tuple = tuple([result_value[j] for j in range(num_selects)]) 181 | for j in range(num_selects): 182 | Py_DECREF(result_value[j]) 183 | result_all[i] = result_tuple 184 | 185 | finally: 186 | free(result_value) 187 | 188 | return result_all 189 | 190 | cdef list optimize_casting(list values, list select_raws, list select_types): 191 | cdef Py_ssize_t num_selects = len(select_raws) 192 | cdef Py_ssize_t* num_fields_array = malloc(num_selects * sizeof(Py_ssize_t)) 193 | cdef FieldInfo** fields 194 | cdef list result_all 195 | 196 | if not num_fields_array: 197 | raise MemoryError("Failed to allocate memory for num_fields_array.") 198 | 199 | try: 200 | fields = precompute_fields(select_raws, select_types, num_selects, num_fields_array) 201 | result_all = process_values(values, fields, num_fields_array, select_raws, select_types, num_selects) 202 | finally: 203 | free_fields(fields, num_fields_array, num_selects) 204 | 205 | return result_all 206 | 207 | def optimize_exec_casting( 208 | values: List[Any], 209 | select_raws: List[Any], 210 | select_types: List[Tuple[bool, bool, bool]] 211 | ) -> List[Any]: 212 | return optimize_casting(values, select_raws, select_types) 213 | -------------------------------------------------------------------------------- /iceaxe/migrations/cli.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from time import monotonic_ns 3 | 4 | from iceaxe.base import DBModelMetaclass, TableBase 5 | from iceaxe.io import resolve_package_path 6 | from iceaxe.logging import CONSOLE 7 | from iceaxe.schemas.db_serializer import DatabaseSerializer 8 | from iceaxe.session import DBConnection 9 | 10 | 11 | async def handle_generate( 12 | package: str, db_connection: DBConnection, message: str | None = None 13 | ): 14 | """ 15 | Creates a new migration definition file, comparing the previous version 16 | (if it exists) with the current schema. 17 | 18 | :param package: The current python package name. This should match the name of the 19 | project that's specified in pyproject.toml or setup.py. 20 | 21 | :param message: An optional message to include in the migration file. Helps 22 | with describing changes and searching for past migration logic over time. 23 | 24 | ```python {{sticky: True}} 25 | from iceaxe.migrations.cli import handle_generate 26 | from click import command, option 27 | 28 | @command() 29 | @option("--message", help="A message to include in the migration file.") 30 | def generate_migration(message: str): 31 | db_connection = DBConnection(...) 32 | handle_generate("my_project", db_connection, message=message) 33 | ``` 34 | """ 35 | # Any local imports must be done here to avoid circular imports because migrations.__init__ 36 | # imports this file. 
37 | from iceaxe.migrations.client_io import fetch_migrations 38 | from iceaxe.migrations.generator import MigrationGenerator 39 | from iceaxe.migrations.migrator import Migrator 40 | 41 | CONSOLE.print("[bold blue]Generating migration to current schema") 42 | 43 | CONSOLE.print( 44 | "[grey58]Note that Iceaxe's migration support is well tested but still in beta." 45 | ) 46 | CONSOLE.print( 47 | "[grey58]File an issue @ https://github.com/piercefreeman/iceaxe/issues if you encounter any problems." 48 | ) 49 | 50 | # Locate the migrations directory that belongs to this project 51 | package_path = resolve_package_path(package) 52 | migrations_path = package_path / "migrations" 53 | 54 | # Create the path if it doesn't exist 55 | migrations_path.mkdir(exist_ok=True) 56 | if not (migrations_path / "__init__.py").exists(): 57 | (migrations_path / "__init__.py").touch() 58 | 59 | # Get all of the instances that have been registered 60 | # in memory scope by the user. 61 | models = [ 62 | cls 63 | for cls in DBModelMetaclass.get_registry() 64 | if isclass(cls) and issubclass(cls, TableBase) 65 | ] 66 | 67 | db_serializer = DatabaseSerializer() 68 | db_objects = [] 69 | async for values in db_serializer.get_objects(db_connection): 70 | db_objects.append(values) 71 | 72 | migration_generator = MigrationGenerator() 73 | up_objects = list(migration_generator.serializer.delegate(models)) 74 | 75 | # Get the current revision from the database, this should represent the "down" revision 76 | # for the new migration 77 | migrator = Migrator(db_connection) 78 | await migrator.init_db() 79 | current_revision = await migrator.get_active_revision() 80 | 81 | # Make sure there's not a duplicate revision that already have this down revision. If so that means 82 | # that we will have two conflicting migration chains 83 | migration_revisions = fetch_migrations(migrations_path) 84 | conflict_migrations = [ 85 | migration 86 | for migration in migration_revisions 87 | if migration.down_revision == current_revision 88 | ] 89 | if conflict_migrations: 90 | up_revisions = {migration.up_revision for migration in conflict_migrations} 91 | raise ValueError( 92 | f"Found conflicting migrations with down revision {current_revision} (conflicts: {up_revisions}).\n" 93 | "If you're trying to generate a new migration, make sure to apply the previous migration first - or delete the old one and recreate." 94 | ) 95 | 96 | migration_code, revision = await migration_generator.new_migration( 97 | db_objects, 98 | up_objects, 99 | down_revision=current_revision, 100 | user_message=message, 101 | ) 102 | 103 | # Create the migration file. The change of a conflict with this timestamp is very low, but we make sure 104 | # not to override any existing files anyway. 105 | migration_file_path = migrations_path / f"rev_{revision}.py" 106 | if migration_file_path.exists(): 107 | raise ValueError( 108 | f"Migration file {migration_file_path} already exists. Wait a second and try again." 109 | ) 110 | 111 | migration_file_path.write_text(migration_code) 112 | 113 | CONSOLE.print(f"[bold green]New migration added: {migration_file_path.name}") 114 | 115 | 116 | async def handle_apply( 117 | package: str, 118 | db_connection: DBConnection, 119 | ): 120 | """ 121 | Applies all migrations that have not been applied to the database. 122 | 123 | :param package: The current python package name. This should match the name of the 124 | project that's specified in pyproject.toml or setup.py. 
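A usage sketch, mirroring the `handle_generate` example above (the CLI wiring and the asyncpg connection arguments are placeholders for your own project):

```python {{sticky: True}}
import asyncio

import asyncpg
from click import command

from iceaxe import DBConnection
from iceaxe.migrations.cli import handle_apply


async def _apply():
    # Placeholder connection arguments - substitute your own database settings
    db_connection = DBConnection(await asyncpg.connect(...))
    await handle_apply("my_project", db_connection)


@command()
def apply_migrations():
    asyncio.run(_apply())
```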
125 | 126 | """ 127 | from iceaxe.migrations.client_io import fetch_migrations, sort_migrations 128 | from iceaxe.migrations.migrator import Migrator 129 | 130 | migrations_path = resolve_package_path(package) / "migrations" 131 | if not migrations_path.exists(): 132 | raise ValueError(f"Migrations path {migrations_path} does not exist.") 133 | 134 | # Load all the migration files into memory and locate the subclasses of MigrationRevisionBase 135 | migration_revisions = fetch_migrations(migrations_path) 136 | migration_revisions = sort_migrations(migration_revisions) 137 | 138 | # Get the current revision from the database 139 | migrator = Migrator(db_connection) 140 | await migrator.init_db() 141 | current_revision = await migrator.get_active_revision() 142 | 143 | CONSOLE.print(f"Current revision: {current_revision}") 144 | 145 | # Find the item in the sequence that has down_revision equal to the current_revision 146 | # This indicates the next migration to apply 147 | next_migration_index = None 148 | for i, revision in enumerate(migration_revisions): 149 | if revision.down_revision == current_revision: 150 | next_migration_index = i 151 | break 152 | 153 | if next_migration_index is None: 154 | raise ValueError( 155 | f"Could not find a migration to apply after revision {current_revision}." 156 | ) 157 | 158 | # Get the chain after this index, this should indicate the next migration to apply 159 | migration_chain = migration_revisions[next_migration_index:] 160 | CONSOLE.print(f"Applying {len(migration_chain)} migrations...") 161 | 162 | for migration in migration_chain: 163 | with CONSOLE.status( 164 | f"[bold blue]Applying {migration.up_revision}...", spinner="dots" 165 | ): 166 | start = monotonic_ns() 167 | await migration._handle_up(db_connection) 168 | 169 | CONSOLE.print( 170 | f"[bold green]🚀 Applied {migration.up_revision} in {(monotonic_ns() - start) / 1e9:.2f}s" 171 | ) 172 | 173 | 174 | async def handle_rollback( 175 | package: str, 176 | db_connection: DBConnection, 177 | ): 178 | """ 179 | Rolls back the last migration that was applied to the database. 180 | 181 | :param package: The current python package name. This should match the name of the 182 | project that's specified in pyproject.toml or setup.py. 
183 | 184 | """ 185 | from iceaxe.migrations.client_io import fetch_migrations, sort_migrations 186 | from iceaxe.migrations.migrator import Migrator 187 | 188 | migrations_path = resolve_package_path(package) / "migrations" 189 | if not migrations_path.exists(): 190 | raise ValueError(f"Migrations path {migrations_path} does not exist.") 191 | 192 | # Load all the migration files into memory and locate the subclasses of MigrationRevisionBase 193 | migration_revisions = fetch_migrations(migrations_path) 194 | migration_revisions = sort_migrations(migration_revisions) 195 | 196 | # Get the current revision from the database 197 | migrator = Migrator(db_connection) 198 | await migrator.init_db() 199 | current_revision = await migrator.get_active_revision() 200 | 201 | CONSOLE.print(f"Current revision: {current_revision}") 202 | 203 | # Find the item in the sequence that has down_revision equal to the current_revision 204 | # This indicates the next migration to apply 205 | this_migration_index = None 206 | for i, revision in enumerate(migration_revisions): 207 | if revision.up_revision == current_revision: 208 | this_migration_index = i 209 | break 210 | 211 | if this_migration_index is None: 212 | raise ValueError( 213 | f"Could not find a migration matching {current_revision} for rollback." 214 | ) 215 | 216 | # Get the chain after this index, this should indicate the next migration to apply 217 | this_migration = migration_revisions[this_migration_index] 218 | 219 | with CONSOLE.status( 220 | f"[bold blue]Rolling back revision {this_migration.up_revision} to {this_migration.down_revision}...", 221 | spinner="dots", 222 | ): 223 | start = monotonic_ns() 224 | await this_migration._handle_down(db_connection) 225 | 226 | CONSOLE.print( 227 | f"[bold green]🪃 Rolled back migration to {this_migration.down_revision} in {(monotonic_ns() - start) / 1e9:.2f}s" 228 | ) 229 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # iceaxe 2 | 3 | ![Iceaxe Logo](https://raw.githubusercontent.com/piercefreeman/iceaxe/main/media/header.png) 4 | 5 | ![Python Version](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Fpiercefreeman%2Ficeaxe%2Frefs%2Fheads%2Fmain%2Fpyproject.toml) [![Test status](https://github.com/piercefreeman/iceaxe/actions/workflows/test.yml/badge.svg)](https://github.com/piercefreeman/iceaxe/actions) 6 | 7 | A modern, fast ORM for Python. We have the following goals: 8 | 9 | - 🏎️ **Performance**: We want to exceed or match the fastest ORMs in Python. We want our ORM 10 | to be as close as possible to raw-[asyncpg](https://github.com/MagicStack/asyncpg) speeds. See the "Benchmarks" section for more. 11 | - 📝 **Typehinting**: Everything should be typehinted with expected types. Declare your data as you 12 | expect in Python and it should bidirectionally sync to the database. 13 | - 🐘 **Postgres only**: Leverage native Postgres features and simplify the implementation. 14 | - ⚡ **Common things are easy, rare things are possible**: 99% of the SQL queries we write are 15 | vanilla SELECT/INSERT/UPDATEs. These should be natively supported by your ORM. If you're writing _really_ 16 | complex queries, these are better done by hand so you can see exactly what SQL will be run. 17 | 18 | Iceaxe is used in production at several companies. It's also an independent project. 
It's compatible with the [Mountaineer](https://github.com/piercefreeman/mountaineer) ecosystem, but you can use it in whatever 19 | project and web framework you're using. 20 | 21 | For comprehensive documentation, visit [https://iceaxe.sh](https://iceaxe.sh). 22 | 23 | To auto-optimize your self-hosted Postgres install, check out our new [autopg](https://github.com/piercefreeman/autopg) project. 24 | 25 | ## Installation 26 | 27 | If you're using uv to manage your dependencies: 28 | 29 | ```bash 30 | uv add iceaxe 31 | ``` 32 | 33 | Otherwise install with pip: 34 | 35 | ```bash 36 | pip install iceaxe 37 | ``` 38 | 39 | ## Usage 40 | 41 | Define your models as a `TableBase` subclass: 42 | 43 | ```python 44 | from iceaxe import TableBase 45 | 46 | class Person(TableBase): 47 | id: int 48 | name: str 49 | age: int 50 | ``` 51 | 52 | TableBase is a subclass of Pydantic's `BaseModel`, so you get all of the validation and Field customization 53 | out of the box. We provide our own `Field` constructor that adds database-specific configuration. For instance, to make the 54 | `id` field an auto-incrementing primary key you can do: 55 | 56 | ```python 57 | from iceaxe import Field 58 | 59 | class Person(TableBase): 60 | id: int = Field(primary_key=True) 61 | name: str 62 | age: int 63 | ``` 64 | 65 | Okay, now you have a model. How do you interact with it? 66 | 67 | Databases are based on a few core primitives to insert data, update it, and fetch it out again. 68 | To do so you'll need a _database connection_, which is a connection over the network from your code 69 | to your Postgres database. The `DBConnection` is the core class for all ORM actions against the database. 70 | 71 | ```python 72 | from iceaxe import DBConnection 73 | import asyncpg 74 | 75 | conn = DBConnection( 76 | await asyncpg.connect( 77 | host="localhost", 78 | port=5432, 79 | user="db_user", 80 | password="yoursecretpassword", 81 | database="your_db", 82 | ) 83 | ) 84 | ``` 85 | 86 | The Person class currently just lives in memory. To back it with a full 87 | database table, we can run raw SQL or run a migration to add it: 88 | 89 | ```python 90 | await conn.conn.execute( 91 | """ 92 | CREATE TABLE IF NOT EXISTS person ( 93 | id SERIAL PRIMARY KEY, 94 | name TEXT NOT NULL, 95 | age INT NOT NULL 96 | ) 97 | """ 98 | ) 99 | ``` 100 | 101 | ### Inserting Data 102 | 103 | Instantiate object classes as you normally do: 104 | 105 | ```python 106 | people = [ 107 | Person(name="Alice", age=30), 108 | Person(name="Bob", age=40), 109 | Person(name="Charlie", age=50), 110 | ] 111 | await conn.insert(people) 112 | 113 | print(people[0].id) # 1 114 | print(people[1].id) # 2 115 | ``` 116 | 117 | Because we're using an auto-incrementing primary key, the `id` field will be populated after the insert. 118 | Iceaxe will automatically update the object in place with the newly assigned value. 119 | 120 | ### Updating data 121 | 122 | Now that we have these lovely people, let's modify them. 123 | 124 | ```python 125 | person = people[0] 126 | person.name = "Blice" 127 | ``` 128 | 129 | Right now, we have a Python object that's out of sync with the database. But that's often okay. We can inspect it 130 | and further write logic - it's fully decoupled from the database.
131 | 132 | ```python 133 | def ensure_b_letter(person: Person): 134 | if person.name[0].lower() != "b": 135 | raise ValueError("Name must start with 'B'") 136 | 137 | ensure_b_letter(person) 138 | ``` 139 | 140 | To sync the values back to the database, we can call `update`: 141 | 142 | ```python 143 | await conn.update([person]) 144 | ``` 145 | 146 | If we were to query the database directly, we see that the name has been updated: 147 | 148 | ``` 149 | id | name | age 150 | ----+-------+----- 151 | 1 | Blice | 31 152 | 2 | Bob | 40 153 | 3 | Charlie | 50 154 | ``` 155 | 156 | But no other fields have been touched. This lets a potentially concurrent process 157 | modify `Alice`'s record - say, updating the age to 31. By the time we update the data, we'll 158 | change the name but nothing else. Under the hood we do this by tracking the fields that 159 | have been modified in memory and creating a targeted UPDATE to modify only those values. 160 | 161 | ### Selecting data 162 | 163 | To select data, we can use a `QueryBuilder`. For a shortcut to `select` query functions, 164 | you can also just import `select` directly. This function takes the models or columns you want returned 165 | and produces a query whose results are typed to match. 166 | 167 | ```python 168 | from iceaxe import select 169 | 170 | query = select(Person).where(Person.name == "Blice", Person.age > 25) 171 | results = await conn.exec(query) 172 | ``` 173 | 174 | If we inspect the typing of `results`, we see that it's a `list[Person]`. This matches 175 | the typehint of the `select` function. You can also target columns directly: 176 | 177 | ```python 178 | query = select((Person.id, Person.name)).where(Person.age > 25) 179 | results = await conn.exec(query) 180 | ``` 181 | 182 | This will return a list of tuples, where each tuple is the id and name of the person: `list[tuple[int, str]]`. 183 | 184 | We support most of the common SQL operations. Just like the results, these are typehinted 185 | to their proper types as well. Static typecheckers and your IDE will throw an error if you try to compare 186 | a string column to an integer, for instance. A more complex example of a query: 187 | 188 | ```python 189 | query = select(( 190 | Person.id, 191 | FavoriteColor, 192 | )).join( 193 | FavoriteColor, 194 | Person.id == FavoriteColor.person_id, 195 | ).where( 196 | Person.age > 25, 197 | Person.name == "Blice", 198 | ).order_by( 199 | Person.age.desc(), 200 | ).limit(10) 201 | results = await conn.exec(query) 202 | ``` 203 | 204 | As expected, this will deliver results - and a typehint - of `list[tuple[int, FavoriteColor]]`. 205 | 206 | ## Production 207 | 208 | > [!IMPORTANT] 209 | > Iceaxe is in early alpha. We're using it internally and slowly rolling out to our production 210 | applications, but we're not yet ready to recommend it for general use. The API and overall 211 | stability are subject to change. 212 | 213 | Note that the underlying Postgres connection wrapped by `conn` will stay alive for as long as your object is in memory. This uses up one 214 | of the allowable connections to your database. Your overall limit depends on your Postgres configuration 215 | or hosting provider, but most managed solutions top out around 150-300. If you need more concurrent clients 216 | connected (and even if you don't - connection creation at the Postgres level is expensive), you can adopt 217 | a load balancer like `pgbouncer` to better scale to traffic. More deployment notes to come.
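If you need to release that connection deterministically - at the end of a worker task, for example - you can close the wrapped asyncpg connection yourself. A minimal sketch, assuming the raw connection is exposed as `conn.conn` the same way it is used in the table-creation example above:

```python
conn = DBConnection(
    await asyncpg.connect(
        host="localhost",
        port=5432,
        user="db_user",
        password="yoursecretpassword",
        database="your_db",
    )
)

try:
    await conn.insert(people)
finally:
    # Close the raw asyncpg connection so it no longer counts
    # against Postgres' connection limit.
    await conn.conn.close()
```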
218 | 219 | It's also worth noting the absence of request pooling in this initialization. This is a feature of many ORMs that lets you limit 220 | the overall connections you make to Postgres, and re-use these over time. We specifically don't offer request 221 | pooling as part of Iceaxe, despite it being supported by our underlying engine `asyncpg`. This is a bit more 222 | aligned with how things should be structured in production. Python apps are always bound to one process thanks to 223 | the GIL. So, no matter what, your connection pool will always be tied to the current Python process / runtime. When you're deploying onto a server with multiple cores, the pool will be duplicated across CPUs, which largely defeats the purpose of capping 224 | network connections in the first place. 225 | 226 | ## Benchmarking 227 | 228 | We have basic benchmarking tests in the `__tests__/benchmarks` directory. To run them, you'll need to execute the pytest suite: 229 | 230 | ```bash 231 | uv run pytest -m integration_tests 232 | ``` 233 | 234 | Current benchmarking as of October 11, 2024 is: 235 | 236 | | | raw asyncpg | iceaxe | external overhead | 237 | |-------------------|-------------|--------|-----------------------------------------------| 238 | | TableBase columns | 0.098s | 0.093s | | 239 | | TableBase full | 0.164s | 1.345s | 10% dict construction, 90% pydantic overhead | 240 | 241 | ## Development 242 | 243 | If you update your Cython implementation during development, you'll need to re-compile the Cython code. This can be done with 244 | a simple `uv sync`. 245 | 246 | ```bash 247 | uv sync 248 | ``` -------------------------------------------------------------------------------- /iceaxe/__tests__/helpers.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | import os 4 | from contextlib import contextmanager 5 | from dataclasses import dataclass 6 | from json import JSONDecodeError, dump as json_dump, loads as json_loads 7 | from re import Pattern 8 | from tempfile import NamedTemporaryFile, TemporaryDirectory 9 | from textwrap import dedent 10 | 11 | from pyright import run 12 | 13 | 14 | @dataclass 15 | class PyrightDiagnostic: 16 | file: str 17 | severity: str 18 | message: str 19 | rule: str | None 20 | line: int 21 | column: int 22 | 23 | 24 | class ExpectedPyrightError(Exception): 25 | """ 26 | Exception raised when Pyright doesn't produce the expected error 27 | 28 | """ 29 | 30 | pass 31 | 32 | 33 | def get_imports_from_module(module_source: str) -> set[str]: 34 | """ 35 | Extract all import statements from module source 36 | 37 | """ 38 | tree = ast.parse(module_source) 39 | imports: set[str] = set() 40 | 41 | for node in ast.walk(tree): 42 | if isinstance(node, ast.Import): 43 | for name in node.names: 44 | imports.add(f"import {name.name}") 45 | elif isinstance(node, ast.ImportFrom): 46 | names = ", ".join(name.name for name in node.names) 47 | if node.module is None: 48 | # Handle "from . import x" case 49 | imports.add(f"from . 
import {names}") 50 | else: 51 | imports.add(f"from {node.module} import {names}") 52 | 53 | return imports 54 | 55 | 56 | def strip_type_ignore(line: str) -> str: 57 | """ 58 | Strip type: ignore comments from a line while preserving the line content 59 | 60 | """ 61 | if "#" not in line: 62 | return line 63 | 64 | # Split only on the first # 65 | code_part, *comment_parts = line.split("#", 1) 66 | if not comment_parts: 67 | return line 68 | 69 | comment = comment_parts[0] 70 | # If this is a type: ignore comment, return just the code 71 | if "type:" in comment and "ignore" in comment: 72 | return code_part.rstrip() 73 | 74 | # Otherwise return the full line 75 | return line 76 | 77 | 78 | def extract_current_function_code(): 79 | """ 80 | Extracts the source code of the function calling this utility, 81 | along with any necessary imports at the module level. This only works for 82 | functions in a pytest testing context that are prefixed with `test_`. 83 | 84 | """ 85 | # Get the frame of the calling function 86 | frame = inspect.currentframe() 87 | 88 | try: 89 | # Go up until we find the test function; workaround to not 90 | # knowing the entrypoint of our contextmanager at runtime 91 | while frame is not None: 92 | func_name = frame.f_code.co_name 93 | if func_name.startswith("test_"): 94 | test_frame = frame 95 | break 96 | frame = frame.f_back 97 | else: 98 | raise RuntimeError("Could not find test function frame") 99 | 100 | # Source code of the function 101 | func_source = inspect.getsource(test_frame.f_code) 102 | 103 | # Source code of the larger test file, which contains the test function 104 | # All the imports used by the test function should be within this file 105 | module = inspect.getmodule(test_frame) 106 | if not module: 107 | raise RuntimeError("Could not find module for test function") 108 | 109 | module_source = inspect.getsource(module) 110 | 111 | # Postprocess the source code to build into a valid new module 112 | imports = get_imports_from_module(module_source) 113 | filtered_lines = [strip_type_ignore(line) for line in func_source.split("\n")] 114 | return "\n".join(sorted(imports)) + "\n\n" + dedent("\n".join(filtered_lines)) 115 | 116 | finally: 117 | del frame # Avoid reference cycles 118 | 119 | 120 | def create_pyright_config(): 121 | """ 122 | Creates a new pyright configuration that ignores unused imports or other 123 | issues that are not related to context-manager wrapped type checking. 
124 | 125 | """ 126 | return { 127 | "include": ["."], 128 | "exclude": [], 129 | "ignore": [], 130 | "strict": [], 131 | "typeCheckingMode": "strict", 132 | "reportUnusedImport": False, 133 | "reportUnusedVariable": False, 134 | # Focus only on type checking 135 | "reportOptionalMemberAccess": True, 136 | "reportGeneralTypeIssues": True, 137 | "reportPropertyTypeMismatch": True, 138 | "reportFunctionMemberAccess": True, 139 | "reportTypeCommentUsage": True, 140 | "reportMissingTypeStubs": False, 141 | # Only typehint intentional typehints, not inferred values 142 | "reportUnknownParameterType": False, 143 | "reportUnknownVariableType": False, 144 | "reportUnknownMemberType": False, 145 | "reportUnknownArgumentType": False, 146 | "reportMissingParameterType": False, 147 | "extraPaths": [ 148 | os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) 149 | ], 150 | "reportAttributeAccessIssue": True, 151 | "reportArgumentType": True, 152 | } 153 | 154 | 155 | def run_pyright(file_path: str) -> list[PyrightDiagnostic]: 156 | """ 157 | Run pyright on a file and return the diagnostics 158 | 159 | """ 160 | try: 161 | with TemporaryDirectory() as temp_dir: 162 | # Create pyright config 163 | config_path = os.path.join(temp_dir, "pyrightconfig.json") 164 | with open(config_path, "w") as f: 165 | json_dump(create_pyright_config(), f) 166 | 167 | # Copy the file to analyze into the project directory 168 | test_file = os.path.join(temp_dir, "test.py") 169 | with open(file_path, "r") as src, open(test_file, "w") as dst: 170 | dst.write(src.read()) 171 | 172 | # Run pyright with the config 173 | result = run( 174 | "--project", 175 | temp_dir, 176 | "--outputjson", 177 | test_file, 178 | capture_output=True, 179 | text=True, 180 | ) 181 | 182 | try: 183 | output = json_loads(result.stdout) 184 | except JSONDecodeError: 185 | print(f"Failed to parse pyright output: {result.stdout}") # noqa: T201 186 | print(f"Stderr: {result.stderr}") # noqa: T201 187 | raise 188 | 189 | if "generalDiagnostics" not in output: 190 | raise RuntimeError( 191 | f"Unknown pyright output, missing generalDiagnostics: {output}" 192 | ) 193 | 194 | diagnostics: list[PyrightDiagnostic] = [] 195 | for diag in output["generalDiagnostics"]: 196 | diagnostics.append( 197 | PyrightDiagnostic( 198 | file=diag["file"], 199 | severity=diag["severity"], 200 | message=diag["message"], 201 | rule=diag.get("rule"), 202 | line=diag["range"]["start"]["line"] + 1, # Convert to 1-based 203 | column=( 204 | diag["range"]["start"]["character"] 205 | + 1 # Convert to 1-based 206 | ), 207 | ) 208 | ) 209 | 210 | return diagnostics 211 | 212 | except Exception as e: 213 | raise RuntimeError(f"Failed to run pyright: {str(e)}") 214 | 215 | 216 | @contextmanager 217 | def pyright_raises( 218 | expected_rule: str, 219 | expected_line: int | None = None, 220 | matches: Pattern | None = None, 221 | ): 222 | """ 223 | Context manager that verifies code produces a specific Pyright error. 
224 | 225 | :param expected_rule: The Pyright rule that should be violated 226 | :param expected_line: Optional line number where the error should occur 227 | 228 | :raises ExpectedPyrightError: If Pyright doesn't produce the expected error 229 | 230 | """ 231 | # Create a temporary file to store the code 232 | with NamedTemporaryFile(mode="w", suffix=".py") as temp_file: 233 | temp_path = temp_file.name 234 | 235 | # Extract the source code of the calling function 236 | source_code = extract_current_function_code() 237 | print(f"Running Pyright on:\n{source_code}") # noqa: T201 238 | 239 | # Write the source code to the temporary file 240 | temp_file.write(source_code) 241 | temp_file.flush() 242 | 243 | # At runtime, our actual code is probably a no-op but we still let it run 244 | # inside the scope of the contextmanager 245 | yield 246 | 247 | # Run Pyright on the temporary file 248 | diagnostics = run_pyright(temp_path) 249 | 250 | # Check if any of the diagnostics match our expected error 251 | for diagnostic in diagnostics: 252 | if diagnostic.rule == expected_rule: 253 | if expected_line is not None and diagnostic.line != expected_line: 254 | continue 255 | if matches and not matches.search(diagnostic.message): 256 | continue 257 | # Found matching error 258 | return 259 | 260 | # If we get here, we didn't find the expected error 261 | actual_errors = [ 262 | f"{d.rule or 'unknown'} on line {d.line}: {d.message}" for d in diagnostics 263 | ] 264 | raise ExpectedPyrightError( 265 | f"Expected Pyright error {expected_rule}" 266 | f"{f' on line {expected_line}' if expected_line else ''}" 267 | f" but got: {', '.join(actual_errors) if actual_errors else 'no errors'}" 268 | ) 269 | -------------------------------------------------------------------------------- /docs/guides/inserts/page.mdx: -------------------------------------------------------------------------------- 1 | # Inserts, Updates, and Upserts 2 | 3 | Now that we've looked at how to query data, let's see how to store it in the 4 | database in the first place. In each of these cases you'll build up your TableBase 5 | instances as regular Python objects and use the DBConnection to sync them directly into the database. 6 | The operations are split into functions based on what you intend to do with the data: 7 | 8 | - DBConnection.insert(): Insert a list of objects into the database (INSERT). 9 | - DBConnection.update(): Update a list of objects in the database (UPDATE). 10 | - DBConnection.upsert(): Insert or update a list of objects in the database (INSERT ... ON CONFLICT DO UPDATE). 11 | 12 | ## Inserting data 13 | 14 | Let's consider this table schema. You'll have a simple `employee` table with a 15 | bit of metadata about each employee. 16 | 17 | ```python 18 | from iceaxe import DBConnection, select, TableBase, Field 19 | 20 | class Employee(TableBase): 21 | id: int | None = Field(primary_key=True, default=None) 22 | name: str 23 | age: int 24 | ``` 25 | 26 | We explicitly configure the `id` column as the primary key by configuring 27 | its properties in Field(). We also set it to `None` by default to signal to 28 | Iceaxe and Postgres that the column should be auto-incremented. Upon insert 29 | it will automatically assign a unique integer to each row. 30 | 31 | Define a few employees with a typical constructor. Because our TableBase subclasses 32 | Pydantic, we'll validate all input arguments for you on construction. If you pass in a string for 33 | the age, it will raise a validation error.
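For example, here's a quick sketch of that failure mode (the exact error output comes from Pydantic and may vary between versions):

```python
from pydantic import ValidationError

try:
    Employee(name="Marcus Holloway", age="twenty-eight")
except ValidationError as exc:
    print(exc)  # reports that `age` could not be parsed as an integer
```

With valid arguments, construction works as you'd expect: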
34 | 35 | ```python 36 | employees = [ 37 | Employee(name="Marcus Holloway", age=28), 38 | Employee(name="Astrid Keller", age=42), 39 | Employee(name="Jasper Okafor", age=36), 40 | ] 41 | 42 | employees[0].id 43 | > None 44 | 45 | await conn.insert(employees) 46 | 47 | employees[0].id 48 | > 1 49 | ``` 50 | 51 | Notice that before we insert the data, the `id` field is `None`. Until we call 52 | the `insert` method, these objects are just sitting in Python memory without any 53 | database backing. It takes our explicit insert call to sync the data to the database 54 | and assign the auto-generated primary keys back into the instance element. 55 | 56 | ## Updating data 57 | 58 | Once you have Iceaxe instances in Python, either fetched from a select or after 59 | you've made an insert, you can manipulate the attributes like you do with 60 | any regular Python object. In the background we track all of your field modifications 61 | so we can write them out when requested. When you're ready to save your changes back to the 62 | database, you can call the `update` method and pass it your instance. 63 | 64 | ```python 65 | from iceaxe import select 66 | 67 | query = select(Employee).where(Employee.name == "Marcus Holloway") 68 | results = await conn.exec(query) 69 | 70 | marcus = results[0] 71 | marcus.age = 29 72 | 73 | await conn.update([marcus]) 74 | ``` 75 | 76 | Before the conn.update(), the database will have the old value. After the update, 77 | the database will have the new value which we can confirm with a fresh select. 78 | 79 | ```python 80 | results = await conn.exec(query) 81 | results[0].age 82 | > 29 83 | ``` 84 | 85 | ## Upserting data 86 | 87 | Sometimes you want to insert a row if it doesn't exist, or update it if it does. This is useful 88 | if you need to enforce some uniqueness constraint on your data where you only want one instance of 89 | an object but might not know at call-time whether you already have one stored. In some cases you 90 | could do a `SELECT` followed by an `INSERT` to accomplish this same thing, but under high load this approach 91 | can run into race conditions where you might end up inserting a duplicate row. 92 | 93 | This is called an upsert operation (INSERT ... ON CONFLICT DO UPDATE). Iceaxe provides a 94 | dedicated `upsert` method for this purpose. 95 | 96 | Let's say we want to add some employees to our database, but we want to update their information if they already exist 97 | based on their name. We'll need to modify our table to add a unique constraint on the name field. 98 | 99 | ```python 100 | from iceaxe import UniqueConstraint 101 | 102 | class Employee(TableBase): 103 | id: int | None = Field(primary_key=True, default=None) 104 | name: str 105 | age: int 106 | 107 | table_args = [ 108 | UniqueConstraint(columns=["name"]) 109 | ] 110 | ``` 111 | 112 | ```python 113 | employees = [ 114 | Employee(name="Marcus Holloway", age=28), 115 | Employee(name="Astrid Keller", age=42), 116 | ] 117 | 118 | # Upsert based on name, updating age if the name already exists 119 | await conn.upsert( 120 | employees, 121 | conflict_fields=(Employee.name,), 122 | update_fields=(Employee.age,) 123 | ) 124 | ``` 125 | 126 | In this example, if an employee with the name "Marcus Holloway" already exists, their age will be updated to 28. 127 | If they don't exist, a new record will be created. 
The `conflict_fields` parameter specifies which fields should 128 | be used to detect conflicts, while `update_fields` specifies which fields should be updated when a conflict is found. 129 | 130 | ### Returning Values 131 | 132 | Often, you'll want to know which records were affected by the upsert operation. The `returning_fields` parameter 133 | allows you to specify which fields should be returned for each upserted record. 134 | 135 | ```python 136 | results = await conn.upsert( 137 | employees, 138 | conflict_fields=(Employee.name,), 139 | update_fields=(Employee.age,), 140 | returning_fields=(Employee.id, Employee.name) 141 | ) 142 | for employee_id, name in results: 143 | print(f"Upserted employee {name} with ID {employee_id}") 144 | ``` 145 | 146 | The returned values can be useful for tracking which records were affected or getting automatically generated 147 | fields like primary keys. 148 | 149 | 150 | ## Query updates 151 | 152 | If you have a large number of rows to update in a consistent way, you can also 153 | make use of a database-level `UPDATE` statement. This avoids the SQL->Python->SQL 154 | round trip of doing a fetch before your updates. We export an `update` query builder 155 | for this purpose. It's close in nature to a `select` query but lets you specify 156 | the new values to set. 157 | 158 | ```python 159 | from iceaxe import update 160 | 161 | query = update(Employee).set(Employee.name, "Marcus Keller").where(Employee.name == "Marcus Holloway") 162 | 163 | query.build() 164 | > ('UPDATE "employee" SET "employee"."name" = $1 165 | WHERE "employee"."name" = $2', ['Marcus Keller', 'Marcus Holloway']) 166 | ``` 167 | 168 | ## Batch updates 169 | 170 | If you're making a large number of updates at the same time, you'll want to avoid individual 171 | calls to DBConnection functions. Since each call incurs a round trip to the database, looping over 172 | individual calls can be quite inefficient when compared to raw Postgres speeds. 173 | 174 | All our manipulation functions (insert, update, upsert) accept a list of objects. If you can create 175 | your objects in bulk, you'll get the best performance by passing them all to the DBConnection 176 | functions. Internally we take care of optimizing these network operations for you to respect Postgres limits. 177 | Specifically: 178 | 179 | - Splitting up the list of table instances into batches of parameters. 180 | - Identifying manipulations that change the same keys and can be batched together. 181 | - Sending separate operations for each underlying database table. 182 | 183 | This addresses the following Postgres limitations: 184 | 185 | - Postgres enforces a limit of ~32k parameters per query statement. 186 | - Prepared statements can't batch multiple statements in the same query execution. 187 | 188 | Because of our internal ORM bookkeeping, we can optimize these network requests to issue the minimum required 189 | queries and bundle as many operations as possible into a single query execution. This makes sure your database 190 | communication is as fast as possible even during large data manipulation. 191 | 192 | ## Dirty objects 193 | 194 | When you modify a local object in Python, you'll need to explicitly call `conn.update()` to sync the changes back to the database. 195 | If you don't, the changes will be lost when the object is garbage collected. During development and testing, it can be helpful 196 | to track these "dirty" objects - instances that have been modified but not synced to the database.
197 | 198 | You can enable this tracking by setting the `uncommitted_verbosity` parameter when creating your DBConnection: 199 | 200 | ```python 201 | # Enable tracking with different warning levels 202 | conn = DBConnection( 203 | pg_connection, 204 | uncommitted_verbosity="ERROR" 205 | ) 206 | 207 | # Disable tracking (recommended for production) 208 | conn = DBConnection( 209 | pg_connection, 210 | uncommitted_verbosity=None 211 | ) 212 | ``` 213 | 214 | The verbosity levels determine how aggressively the system will notify you about uncommitted changes: 215 | 216 | - `ERROR`: Raises errors in your logs when objects are modified but not committed 217 | - `WARNING`: Logs warnings when objects are modified but not committed 218 | - `INFO`: Provides informational messages about modified but uncommitted objects 219 | - `None`: Disables tracking entirely (default) 220 | 221 | For example, with WARNING level enabled: 222 | 223 | ```python 224 | user = await conn.exec(select(User).where(User.id == 1)) 225 | user.name = "New Name" 226 | # Session closes without calling conn.update([user]) 227 | # WARNING: Object has uncommitted changes: {'name'} 228 | # File "...", line 10, in 229 | # user.name = "New Name" 230 | ``` 231 | 232 | We also provide the initial location of this object's modification in the log message so you can track down where the object was modified 233 | and determine if this was an intentional change. 234 | 235 | ### Performance Considerations 236 | 237 | It's important to note that dirty object tracking introduces non-trivial performance overhead. The system needs to: 238 | 239 | 1. Track all modifications to objects 240 | 2. Maintain a registry of modified objects and the location of the modification 241 | 3. Check the state of objects when the session closes 242 | 243 | For this reason, it's recommended to: 244 | 245 | - Enable tracking during development and testing to catch potential bugs 246 | - Set `uncommitted_verbosity=None` in production environments 247 | - Use tracking selectively when debugging specific issues 248 | --------------------------------------------------------------------------------
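One way to follow those recommendations is to drive the setting from your environment configuration. A minimal sketch, assuming a hypothetical `APP_ENV` environment variable; the asyncpg connection arguments are placeholders:

```python
import os

import asyncpg
from iceaxe import DBConnection


async def build_connection() -> DBConnection:
    pg_connection = await asyncpg.connect(
        host="localhost",
        port=5432,
        user="db_user",
        password="yoursecretpassword",
        database="your_db",
    )
    # Only pay the dirty-object tracking overhead outside of production
    verbosity = None if os.environ.get("APP_ENV") == "production" else "WARNING"
    return DBConnection(pg_connection, uncommitted_verbosity=verbosity)
```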