├── docs ├── index.md ├── examples │ └── index.md ├── getting-started │ ├── security.md │ ├── changelog.md │ ├── getting-help.md │ ├── installation.md │ ├── welcome.md │ └── contributing.md ├── api │ ├── cli.md │ └── api-reference.md └── assets │ └── otter-solid.svg ├── tests ├── conftest.py ├── unit │ ├── conftest.py │ ├── test_main.py │ ├── db │ │ ├── drivers │ │ │ └── mysql │ │ │ │ └── test_constants.py │ │ ├── constants │ │ │ └── test_db_constants.py │ │ ├── marshalers │ │ │ ├── postgres │ │ │ │ ├── test_analyze_table_relationships.py │ │ │ │ ├── test_postgres_constraint_marshaler.py │ │ │ │ └── test_postgres_relationship_marshaler.py │ │ │ ├── abstract │ │ │ │ └── test_base_relationship_marshaler.py │ │ │ ├── test_column.py │ │ │ └── mysql │ │ │ │ └── test_mysql_constraint_marshaler.py │ │ ├── graph │ │ │ └── test_dependencies.py │ │ ├── utils │ │ │ └── test_url_parser.py │ │ └── factories │ │ │ └── test_connection_factory.py │ ├── core │ │ ├── config │ │ │ ├── test_writer_config.py │ │ │ └── test_config.py │ │ └── writers │ │ │ ├── test_factories.py │ │ │ ├── test_writer_util.py │ │ │ ├── test_pydantic_constraints.py │ │ │ └── utils │ │ │ └── test_writer_utils.py │ ├── utils │ │ ├── strings │ │ │ └── test_strings.py │ │ ├── serialization │ │ │ └── test_json_encoder.py │ │ ├── formatting │ │ │ └── test_formatting.py │ │ ├── types │ │ │ └── test_types.py │ │ └── config │ │ │ └── test_utils_config.py │ └── cli │ │ └── common │ │ └── test_common.py └── integration │ ├── conftest.py │ └── db │ ├── test_seed_generation.py │ └── test_schema_introspection.py ├── src └── supabase_pydantic │ ├── __init__.py │ ├── core │ ├── __init__.py │ ├── writers │ │ ├── __init__.py │ │ ├── utils.py │ │ └── factories.py │ ├── models.py │ └── config.py │ ├── db │ ├── __init__.py │ ├── drivers │ │ ├── __init__.py │ │ ├── postgres │ │ │ ├── __init__.py │ │ │ └── constants.py │ │ └── mysql │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ └── queries.py │ ├── marshalers │ │ ├── postgres │ │ │ ├── __init__.py │ │ │ ├── constraints.py │ │ │ ├── column.py │ │ │ ├── relationship.py │ │ │ └── schema.py │ │ ├── mysql │ │ │ ├── __init__.py │ │ │ ├── column.py │ │ │ ├── constraints.py │ │ │ └── relationship.py │ │ ├── abstract │ │ │ ├── __init__.py │ │ │ ├── base_relationship_marshaler.py │ │ │ ├── base_column_marshaler.py │ │ │ ├── base_constraint_marshaler.py │ │ │ └── base_schema_marshaler.py │ │ ├── __init__.py │ │ └── constraints.py │ ├── connectors │ │ ├── mysql │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── postgres │ │ │ ├── __init__.py │ │ │ └── schema_reader.py │ ├── utils │ │ ├── __init__.py │ │ └── url_parser.py │ ├── factories │ │ ├── __init__.py │ │ └── connection_factory.py │ ├── abstract │ │ ├── __init__.py │ │ └── base_schema_reader.py │ ├── exceptions.py │ ├── seed │ │ └── __init__.py │ ├── database_type.py │ ├── constants.py │ ├── type_factory.py │ ├── registrations.py │ └── graph.py │ ├── utils │ ├── __init__.py │ ├── constants.py │ ├── serialization.py │ ├── strings.py │ ├── config.py │ ├── formatting.py │ ├── io.py │ └── types.py │ ├── cli │ ├── commands │ │ ├── __init__.py │ │ └── clean.py │ ├── __init__.py │ └── common.py │ └── __main__.py ├── .example.env ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── feature-request.md │ └── bug-report.md └── workflows │ ├── ci.yml │ ├── python-multi-version.yml │ └── python-publish.yml ├── CODEOWNERS ├── SECURITY.md ├── LICENSE ├── requirements.txt ├── .pre-commit-config.yaml ├── SUPPORT.md ├── mkdocs.yml ├── .gitignore └── README.md /docs/index.md: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/examples/index.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/cli/commands/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/core/writers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/drivers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/getting-started/security.md: -------------------------------------------------------------------------------- 1 | --8<-- "SECURITY.md" -------------------------------------------------------------------------------- /src/supabase_pydantic/db/drivers/postgres/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/postgres/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/drivers/mysql/__init__.py: -------------------------------------------------------------------------------- 1 | """MySQL driver package.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/connectors/mysql/__init__.py: -------------------------------------------------------------------------------- 1 | """MySQL connector package.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for database operations.""" 2 | 
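The `.example.env` template a little further below defines the five connection variables the tool reads: DB_NAME, DB_USER, DB_PASS, DB_HOST, and DB_PORT. As a minimal, hypothetical sketch (not part of the package) of loading and checking them by hand — assuming `python-dotenv` and `psycopg2-binary`, both pinned in requirements.txt, with fallbacks mirroring `local_default_env_configuration()` from `src/supabase_pydantic/utils/config.py`:

import os

import psycopg2
from dotenv import load_dotenv

load_dotenv()  # copy .example.env to .env and fill in the blanks first

# Fallbacks mirror local_default_env_configuration() for a local Supabase stack;
# `or` (rather than a getenv default) also covers the empty-string values in .example.env.
defaults = {'DB_NAME': 'postgres', 'DB_USER': 'postgres', 'DB_PASS': 'postgres', 'DB_HOST': 'localhost', 'DB_PORT': '54322'}

conn = psycopg2.connect(
    dbname=os.getenv('DB_NAME') or defaults['DB_NAME'],
    user=os.getenv('DB_USER') or defaults['DB_USER'],
    password=os.getenv('DB_PASS') or defaults['DB_PASS'],
    host=os.getenv('DB_HOST') or defaults['DB_HOST'],
    port=os.getenv('DB_PORT') or defaults['DB_PORT'],
)
print(conn.closed)  # 0 while the connection is open
conn.close()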
-------------------------------------------------------------------------------- /src/supabase_pydantic/db/connectors/__init__.py: -------------------------------------------------------------------------------- 1 | """Database connector module package.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/mysql/__init__.py: -------------------------------------------------------------------------------- 1 | """MySQL marshalers package.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/connectors/postgres/__init__.py: -------------------------------------------------------------------------------- 1 | """PostgreSQL connector module.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/factories/__init__.py: -------------------------------------------------------------------------------- 1 | """Factory classes for database operations.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/abstract/__init__.py: -------------------------------------------------------------------------------- 1 | """Abstract base classes for database operations.""" 2 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/abstract/__init__.py: -------------------------------------------------------------------------------- 1 | """Abstract base classes for database marshaling.""" 2 | -------------------------------------------------------------------------------- /.example.env: -------------------------------------------------------------------------------- 1 | # connection parameters 2 | DB_NAME = '' 3 | DB_USER = '' 4 | DB_PASS = '' 5 | DB_HOST = '' 6 | DB_PORT = '' -------------------------------------------------------------------------------- /src/supabase_pydantic/__main__.py: -------------------------------------------------------------------------------- 1 | """Main entry point for the package.""" 2 | 3 | from supabase_pydantic.cli import cli 4 | 5 | if __name__ == '__main__': 6 | cli() 7 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/exceptions.py: -------------------------------------------------------------------------------- 1 | class ConnectionError(Exception): 2 | """Raised when a connection to the Supabase API cannot be established.""" 3 | 4 | pass 5 | -------------------------------------------------------------------------------- /docs/getting-started/changelog.md: -------------------------------------------------------------------------------- 1 | --8<-- "CHANGELOG.md::200" 2 | 3 |
4 | ... and so on. For the full changelog, please see the [releases](https://github.com/kmbhm1/supabase-pydantic/releases) page. -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # Funding 2 | # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/displaying-a-sponsor-button-in-your-repository 3 | 4 | github: kmbhm1 5 | custom: ["paypal.me/kmbhm1", "venmo.com/u/kmbhm1"] 6 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/seed/__init__.py: -------------------------------------------------------------------------------- 1 | # Re-export the required functions for gen.py 2 | from supabase_pydantic.core.writers.utils import write_seed_file 3 | from supabase_pydantic.db.seed.generator import generate_seed_data 4 | 5 | __all__ = ['generate_seed_data', 'write_seed_file'] 6 | -------------------------------------------------------------------------------- /docs/api/cli.md: -------------------------------------------------------------------------------- 1 | # CLI Reference 2 | 3 | 11 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/database_type.py: -------------------------------------------------------------------------------- 1 | """Database type enum for supported database backends.""" 2 | 3 | from enum import Enum 4 | 5 | 6 | class DatabaseType(Enum): 7 | """Enum for supported database types.""" 8 | 9 | POSTGRES = 'postgres' 10 | MYSQL = 'mysql' 11 | # Add other database types here as they are implemented 12 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/drivers/postgres/constants.py: -------------------------------------------------------------------------------- 1 | # Maps constraint types to their string representations 2 | CONSTRAINT_TYPE_MAP = {'p': 'PRIMARY KEY', 'f': 'FOREIGN KEY', 'u': 'UNIQUE', 'c': 'CHECK', 'x': 'EXCLUDE'} 3 | 4 | # Maps user defined types to their string representations 5 | USER_DEFINED_TYPE_MAP = {'d': 'DOMAIN', 'c': 'COMPOSITE', 'e': 'ENUM', 'r': 'RANGE'} 6 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/drivers/mysql/constants.py: -------------------------------------------------------------------------------- 1 | """MySQL-specific constants.""" 2 | 3 | # Maps constraint types to their string representations 4 | MYSQL_CONSTRAINT_TYPE_MAP = { 5 | 'PRIMARY KEY': 'PRIMARY KEY', 6 | 'FOREIGN KEY': 'FOREIGN KEY', 7 | 'UNIQUE': 'UNIQUE', 8 | 'CHECK': 'CHECK', 9 | } 10 | 11 | # Maps MySQL data types to their Python equivalents 12 | MYSQL_USER_DEFINED_TYPE_MAP = {'ENUM': 'ENUM', 'SET': 'SET'} 13 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # https://help.github.com/articles/about-codeowners/ 2 | 3 | # These owners will be the default owners for everything in 4 | # the repo. Unless a later match takes precedence, they will 5 | # be requested for review when someone opens a PR. 6 | * @kmbhm1 7 | 8 | # Order is important; the last matching pattern takes the most 9 | # precedence. 
When someone opens a PR that only modifies 10 | # .yml files, only the following people and NOT the global 11 | # owner(s) will be requested for a review. 12 | -------------------------------------------------------------------------------- /tests/unit/test_main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest.mock import patch 3 | 4 | 5 | def test_main_function(): 6 | """Test the __main__ module executes cli function when run directly.""" 7 | with patch('supabase_pydantic.cli.cli') as mock_cli: 8 | # Directly execute the conditional code from __main__.py 9 | from supabase_pydantic.__main__ import cli 10 | 11 | # Manually execute what would happen in the if __name__ == '__main__' block 12 | cli() 13 | 14 | # Verify cli was called 15 | mock_cli.assert_called_once() 16 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/constants.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | 3 | 4 | class AppConfig(TypedDict, total=False): 5 | """Configuration for the Supabase Pydantic tool.""" 6 | 7 | default_directory: str 8 | overwrite_existing_files: bool 9 | nullify_base_schema: bool 10 | disable_model_prefix_protection: bool 11 | 12 | 13 | class ToolConfig(TypedDict): 14 | """Configuration for the Supabase Pydantic tool.""" 15 | 16 | supabase_pydantic: AppConfig 17 | 18 | 19 | # File Names 20 | STD_PYDANTIC_FILENAME = 'schemas.py' 21 | STD_SQLALCHEMY_FILENAME = 'database.py' 22 | STD_SEED_DATA_FILENAME = 'seed.sql' 23 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/__init__.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.db.marshalers.column import get_alias, standardize_column_name 2 | from supabase_pydantic.db.marshalers.constraints import add_constraints_to_table_details 3 | from supabase_pydantic.db.marshalers.relationships import analyze_table_relationships, determine_relationship_type 4 | from supabase_pydantic.db.marshalers.schema import construct_table_info, get_table_details_from_columns 5 | 6 | __all__ = [ 7 | 'construct_table_info', 8 | 'standardize_column_name', 9 | 'get_alias', 10 | 'analyze_table_relationships', 11 | 'determine_relationship_type', 12 | 'add_constraints_to_table_details', 13 | 'get_table_details_from_columns', 14 | ] 15 | -------------------------------------------------------------------------------- /tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | """Configuration and fixtures for integration tests.""" 2 | 3 | import pytest 4 | 5 | 6 | def pytest_configure(config): 7 | """Configure pytest with custom markers.""" 8 | config.addinivalue_line('markers', 'integration: mark test as an integration test') 9 | config.addinivalue_line('markers', 'singular_names: mark test as related to singular names functionality') 10 | 11 | 12 | @pytest.fixture 13 | def integration_test_marker(): 14 | """Marker to identify integration tests.""" 15 | return pytest.mark.integration 16 | 17 | 18 | @pytest.fixture 19 | def singular_names_marker(): 20 | """Marker to identify singular names tests.""" 21 | return pytest.mark.singular_names 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE]" 5 | labels: enhancement 6 | assignees: kmbhm1 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | If you discover a security issue, please bring it to our attention right away! 4 | 5 | ## Supported Versions 6 | 7 | Project versions that are currently being supported with security updates vary per project. Please see specific project repositories for details. If nothing is specified, only the latest major versions are supported. 8 | 9 | ## Reporting a Vulnerability 10 | 11 | Please **DO NOT** file a public issue to report a security vulnerability; instead, send your report privately to **[@kmbhm1](mailto:kmbhm1@gmail.com)**. This will help ensure that any vulnerabilities that are found can be [disclosed responsibly](https://en.wikipedia.org/wiki/Responsible_disclosure) to any affected parties. 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: bug 6 | assignees: kmbhm1 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /docs/getting-started/getting-help.md: -------------------------------------------------------------------------------- 1 | # Getting help with supabase-pydantic 2 | 3 | ## Have an issue? 4 | 5 | If you have any issues with supabase-pydantic, feel free to [open an issue](https://github.com/kmbhm1/supabase-pydantic/issues/new/choose). However, please understand that this is a growing project and I may not be able to address all issues immediately or monitor it as actively as a full-time maintainer or team would. 6 | 7 | ## Need help? 8 | 9 | The API [documents](../api/api-reference.md) & [examples](../examples/setup-slack-simple-fastapi.md) are the best sources of reference documentation for supabase-pydantic.
10 | 11 | ## Stack Overflow 12 | 13 | If you have a question about how to use supabase-pydantic, please ask on [Stack Overflow](https://stackoverflow.com/questions/tagged/supabase-pydantic). Tag your question with `supabase-pydantic` so that the community can see it. -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import decimal 2 | import json 3 | from dataclasses import asdict, dataclass 4 | from datetime import date, datetime 5 | from typing import Any 6 | 7 | 8 | @dataclass 9 | class AsDictParent: 10 | def as_dict(self) -> dict[str, Any]: 11 | """Convert the dataclass instance to a dictionary.""" 12 | return asdict(self) 13 | 14 | def __str__(self) -> str: 15 | return json.dumps(asdict(self), indent=4) 16 | 17 | 18 | class CustomJsonEncoder(json.JSONEncoder): 19 | """Custom JSON encoder for encoding decimal and datetime.""" 20 | 21 | def default(self, o: object) -> Any: 22 | """Encode decimal and datetime objects.""" 23 | if isinstance(o, decimal.Decimal): 24 | return str(o) 25 | elif isinstance(o, datetime | date): 26 | return o.isoformat() 27 | return super().default(o) 28 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class RelationType(str, Enum): 5 | """Enum for relation types.""" 6 | 7 | ONE_TO_ONE = 'One-to-One' 8 | ONE_TO_MANY = 'One-to-Many' 9 | MANY_TO_MANY = 'Many-to-Many' 10 | MANY_TO_ONE = 'Many-to-One' # When a table has a foreign key to another table (e.g., File -> Project) 11 | 12 | 13 | class DatabaseConnectionType(Enum): 14 | """Enum for database connection types.""" 15 | 16 | LOCAL = 'local' 17 | DB_URL = 'db_url' 18 | 19 | 20 | class DatabaseUserDefinedType(str, Enum): 21 | """Enum for database user defined types.""" 22 | 23 | DOMAIN = 'DOMAIN' 24 | COMPOSITE = 'COMPOSITE' 25 | ENUM = 'ENUM' 26 | RANGE = 'RANGE' 27 | 28 | 29 | # Regex 30 | 31 | POSTGRES_SQL_CONN_REGEX = ( 32 | r'(postgresql|postgres)://([^:@\s]*(?::[^@\s]*)?@)?(?P<host>[^/\?\s:]+)(:\d+)?(/[^?\s]*)?(\?[^\s]*)?$' 33 | ) 34 | -------------------------------------------------------------------------------- /tests/unit/db/drivers/mysql/test_constants.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from supabase_pydantic.db.drivers.mysql.constants import ( 4 | MYSQL_CONSTRAINT_TYPE_MAP, 5 | MYSQL_USER_DEFINED_TYPE_MAP, 6 | ) 7 | 8 | 9 | @pytest.mark.unit 10 | @pytest.mark.constants 11 | def test_mysql_constraint_type_map_contents() -> None: 12 | assert MYSQL_CONSTRAINT_TYPE_MAP['PRIMARY KEY'] == 'PRIMARY KEY' 13 | assert MYSQL_CONSTRAINT_TYPE_MAP['FOREIGN KEY'] == 'FOREIGN KEY' 14 | assert MYSQL_CONSTRAINT_TYPE_MAP['UNIQUE'] == 'UNIQUE' 15 | assert MYSQL_CONSTRAINT_TYPE_MAP['CHECK'] == 'CHECK' 16 | 17 | # Ensure there are no unexpected keys 18 | assert set(MYSQL_CONSTRAINT_TYPE_MAP.keys()) == {'PRIMARY KEY', 'FOREIGN KEY', 'UNIQUE', 'CHECK'} 19 | 20 | 21 | @pytest.mark.unit 22 | @pytest.mark.constants 23 | def test_mysql_user_defined_type_map_contents() -> None: 24 | assert MYSQL_USER_DEFINED_TYPE_MAP == {'ENUM': 'ENUM', 'SET': 'SET'} 25 | -------------------------------------------------------------------------------- /tests/unit/core/config/test_writer_config.py: 
-------------------------------------------------------------------------------- 1 | """Tests for WriterConfig in supabase_pydantic.core.config.""" 2 | 3 | from supabase_pydantic.core.config import WriterConfig 4 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 5 | 6 | 7 | import pytest 8 | 9 | 10 | @pytest.mark.unit 11 | @pytest.mark.config 12 | def test_WriterConfig_methods(): 13 | """Test WriterConfig methods.""" 14 | writer_config = WriterConfig( 15 | file_type=OrmType.PYDANTIC, 16 | framework_type=FrameWorkType.FASTAPI, 17 | filename='test.py', 18 | directory='foo', 19 | enabled=True, 20 | ) 21 | assert writer_config.ext() == 'py' 22 | assert writer_config.name() == 'test' 23 | assert writer_config.fpath() == 'foo/test.py' 24 | assert writer_config.to_dict() == { 25 | 'file_type': 'OrmType.PYDANTIC', 26 | 'framework_type': 'FrameWorkType.FASTAPI', 27 | 'filename': 'test.py', 28 | 'directory': 'foo', 29 | 'enabled': 'True', 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Kevin Boehm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/unit/core/config/test_config.py: -------------------------------------------------------------------------------- 1 | """Tests for configuration utilities in supabase_pydantic.core.config.""" 2 | 3 | from supabase_pydantic.core.config import ( 4 | WriterConfig, 5 | get_standard_jobs, 6 | local_default_env_configuration, 7 | ) 8 | 9 | 10 | import pytest 11 | 12 | 13 | @pytest.mark.unit 14 | @pytest.mark.config 15 | def test_get_standard_jobs_returns_jobs(): 16 | models = ('pydantic',) 17 | frameworks = ('fastapi',) 18 | dirs: dict[str, str | None] = { 19 | 'fastapi': 'fastapi_directory', 20 | } 21 | schemas = ('public',) 22 | 23 | job_dict = get_standard_jobs(models, frameworks, dirs, schemas) 24 | 25 | assert 'public' in job_dict 26 | jobs = job_dict['public'] 27 | assert all([isinstance(job, WriterConfig) for job in jobs.values()]) 28 | 29 | 30 | @pytest.mark.unit 31 | @pytest.mark.config 32 | def test_local_default_env_configuration(): 33 | env_vars = local_default_env_configuration() 34 | assert env_vars == { 35 | 'DB_NAME': 'postgres', 36 | 'DB_USER': 'postgres', 37 | 'DB_PASS': 'postgres', 38 | 'DB_HOST': 'localhost', 39 | 'DB_PORT': '54322', 40 | } 41 | -------------------------------------------------------------------------------- /tests/unit/utils/strings/test_strings.py: -------------------------------------------------------------------------------- 1 | """Tests for string utilities in supabase_pydantic.utils.strings.""" 2 | 3 | import pytest # noqa: F401 4 | 5 | from supabase_pydantic.utils.strings import ( 6 | chunk_text, 7 | to_pascal_case, 8 | ) 9 | 10 | 11 | @pytest.mark.unit 12 | @pytest.mark.strings 13 | def test_to_pascal_case(): 14 | """Test to_pascal_case.""" 15 | assert to_pascal_case('snake_case') == 'SnakeCase' 16 | assert to_pascal_case('snake_case_with_more_words') == 'SnakeCaseWithMoreWords' 17 | assert to_pascal_case('snake') == 'Snake' 18 | assert to_pascal_case('snake_case_with_1_number') == 'SnakeCaseWith1Number' 19 | assert to_pascal_case('snake_case_with_1_number_and_1_symbol!') == 'SnakeCaseWith1NumberAnd1Symbol!' 20 | 21 | 22 | @pytest.mark.unit 23 | @pytest.mark.strings 24 | def test_chunk_text(): 25 | """Test chunk_text.""" 26 | text = 'This is a test text that will be split into lines with a maximum number of characters.' 
27 | lines = chunk_text(text, nchars=40) 28 | assert lines == [ 29 | 'This is a test text that will be split', 30 | 'into lines with a maximum number of', 31 | 'characters.', 32 | ] 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "4.0" 2 | click-option-group==0.5.6 ; python_version >= "3.10" and python_version < "4" 3 | click==8.1.8 ; python_version >= "3.10" and python_version < "4.0" 4 | colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Windows" 5 | faker==24.14.1 ; python_version >= "3.10" and python_version < "4.0" 6 | inflection==0.5.1 ; python_version >= "3.10" and python_version < "4.0" 7 | psycopg2-binary==2.9.10 ; python_version >= "3.10" and python_version < "4.0" 8 | pydantic-core==2.27.2 ; python_version >= "3.10" and python_version < "4.0" 9 | pydantic==2.10.6 ; python_version >= "3.10" and python_version < "4.0" 10 | python-dateutil==2.9.0.post0 ; python_version >= "3.10" and python_version < "4.0" 11 | python-dotenv==1.0.1 ; python_version >= "3.10" and python_version < "4.0" 12 | six==1.17.0 ; python_version >= "3.10" and python_version < "4.0" 13 | toml==0.10.2 ; python_version >= "3.10" and python_version < "4.0" 14 | types-psycopg2==2.9.21.20250121 ; python_version >= "3.10" and python_version < "4.0" 15 | typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "4.0" 16 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/strings.py: -------------------------------------------------------------------------------- 1 | def to_pascal_case(string: str) -> str: 2 | """Converts a string to PascalCase.""" 3 | words = string.split('_') 4 | camel_case = ''.join(word.capitalize() for word in words) 5 | return camel_case 6 | 7 | 8 | def chunk_text(text: str, nchars: int = 79) -> list[str]: 9 | """Split text into lines with a maximum number of characters.""" 10 | words = text.split() # Split the text into words 11 | lines: list[str] = [] # This will store the final lines 12 | current_line: list[str] = [] # This will store words for the current line 13 | 14 | for word in words: 15 | # Check if adding the next word would exceed the length limit 16 | if (sum(len(w) for w in current_line) + len(word) + len(current_line)) > nchars: 17 | # If adding the word would exceed the limit, join current_line into a string and add to lines 18 | lines.append(' '.join(current_line)) 19 | current_line = [word] # Start a new line with the current word 20 | else: 21 | current_line.append(word) # Add the word to the current line 22 | 23 | # Add the last line to lines if any words are left unadded 24 | if current_line: 25 | lines.append(' '.join(current_line)) 26 | 27 | return lines 28 | -------------------------------------------------------------------------------- /docs/api/api-reference.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | ## Database Constants 4 | 5 | ::: supabase_pydantic.db.constants 6 | handler: python 7 | options: 8 | heading_level: 3 9 | show_root_toc_entry: false 10 | 11 | ## Core Constants 12 | 13 | ::: supabase_pydantic.core.constants 14 | handler: python 15 | options: 16 | heading_level: 3 17 | show_root_toc_entry: false 18 | 19 | ## Utility Constants 20 | 21 | ::: 
supabase_pydantic.utils.constants 22 | handler: python 23 | options: 24 | heading_level: 3 25 | show_root_toc_entry: false 26 | 27 | ## Database Models 28 | 29 | ::: supabase_pydantic.db.models 30 | handler: python 31 | options: 32 | heading_level: 3 33 | show_root_toc_entry: false 34 | 35 | ## Core Models 36 | 37 | ::: supabase_pydantic.core.models 38 | handler: python 39 | options: 40 | heading_level: 3 41 | show_root_toc_entry: false 42 | 43 | ## Writers 44 | 45 | ::: supabase_pydantic.core.writers 46 | handler: python 47 | options: 48 | heading_level: 3 49 | show_root_toc_entry: false 50 | 51 | ## Serialization 52 | 53 | ::: supabase_pydantic.utils.serialization 54 | handler: python 55 | options: 56 | heading_level: 3 57 | show_root_toc_entry: false -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/abstract/base_relationship_marshaler.py: -------------------------------------------------------------------------------- 1 | """Abstract base class for relationship marshaling.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | from supabase_pydantic.db.constants import RelationType 6 | from supabase_pydantic.db.models import TableInfo 7 | 8 | 9 | class BaseRelationshipMarshaler(ABC): 10 | """Abstract base class for relationship data marshaling.""" 11 | 12 | @abstractmethod 13 | def determine_relationship_type( 14 | self, source_table: TableInfo, target_table: TableInfo, source_column: str, target_column: str 15 | ) -> RelationType: 16 | """Determine the type of relationship between two tables. 17 | 18 | Args: 19 | source_table: Source table information. 20 | target_table: Target table information. 21 | source_column: Column name in source table. 22 | target_column: Column name in target table. 23 | 24 | Returns: 25 | RelationType enum value representing the relationship type. 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def analyze_table_relationships(self, tables: dict[tuple[str, str], TableInfo]) -> None: 31 | """Analyze relationships between tables. 32 | 33 | Args: 34 | tables: Dictionary of table information objects. 35 | """ 36 | pass 37 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Configure Git Credentials 14 | run: | 15 | git config user.name github-actions[bot] 16 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.13" 20 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 21 | - uses: actions/cache@v4 22 | with: 23 | key: mkdocs-material-${{ env.cache_id }} 24 | path: .cache 25 | restore-keys: | 26 | mkdocs-material- 27 | - name: Install Poetry 28 | uses: snok/install-poetry@v1 29 | with: 30 | version: 1.7.1 31 | virtualenvs-create: true 32 | virtualenvs-in-project: true 33 | - name: Install dependencies and package 34 | run: | 35 | poetry install 36 | pip install -e . 
37 | - name: Install documentation dependencies 38 | run: pip install mkdocs-material mkdocstrings 'mkdocstrings[python]' mkdocs-redirects pymdown-extensions 39 | - name: Build and deploy docs 40 | run: mkdocs gh-deploy --force 41 | -------------------------------------------------------------------------------- /src/supabase_pydantic/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """Command line interface for Supabase Pydantic.""" 2 | 3 | from typing import Any 4 | 5 | import click 6 | 7 | from supabase_pydantic.cli.commands.clean import clean 8 | from supabase_pydantic.cli.commands.gen import gen 9 | from supabase_pydantic.utils.logging import setup_logging 10 | 11 | 12 | @click.group(invoke_without_command=True) 13 | @click.pass_context 14 | def cli(ctx: Any) -> None: 15 | """A CLI tool for generating Pydantic models from a Supabase/PostgreSQL database. 16 | 17 | In the future, more ORM frameworks and databases will be supported. In the works: 18 | Django, REST Flask, SQLAlchemy, Tortoise-ORM, and more. 19 | 20 | Additionally, more REST API generators will be supported ... 21 | 22 | Stay tuned! 23 | """ 24 | # Initialize logging globally (defaults to INFO). Individual commands may 25 | # reconfigure (e.g., --debug) if needed. 26 | setup_logging('INFO') 27 | 28 | # ensure that ctx.obj exists and is a dict (in case `cli()` is called 29 | # by means other than the `if` block below) 30 | ctx.ensure_object(dict) 31 | 32 | # When invoked without a subcommand, just show help 33 | if ctx.invoked_subcommand is None: 34 | click.echo(ctx.get_help()) 35 | return 36 | 37 | 38 | # Register commands 39 | cli.add_command(clean) 40 | cli.add_command(gen) 41 | 42 | if __name__ == '__main__': 43 | cli() 44 | -------------------------------------------------------------------------------- /tests/unit/db/constants/test_db_constants.py: -------------------------------------------------------------------------------- 1 | """Tests for database constants in supabase_pydantic.db.constants.""" 2 | 3 | import re 4 | 5 | from supabase_pydantic.db.constants import POSTGRES_SQL_CONN_REGEX 6 | 7 | 8 | import pytest 9 | 10 | 11 | @pytest.mark.unit 12 | @pytest.mark.db 13 | @pytest.mark.constants 14 | def test_postgres_db_string_regex(): 15 | """Test the regex for a postgres database string.""" 16 | pattern = re.compile(POSTGRES_SQL_CONN_REGEX) 17 | 18 | assert pattern.match('postgresql://localhost') 19 | assert pattern.match('postgresql://localhost:5433') 20 | assert pattern.match('postgresql://localhost/mydb') 21 | assert pattern.match('postgresql://user@localhost') 22 | assert pattern.match('postgresql://user:secret@localhost') 23 | assert pattern.match('postgresql://other@localhost/otherdb?connect_timeout=10&application_name=myapp') 24 | assert not pattern.match('postgresql://host1:123,host2:456/somedb?application_name=myapp') 25 | assert pattern.match('postgres://myuser:mypassword@192.168.0.100:5432/mydatabase') 26 | assert pattern.match('postgresql://192.168.0.100/mydb') 27 | 28 | assert not pattern.match('postgresql://') 29 | assert not pattern.match('postgresql://localhost:') 30 | assert pattern.match('postgresql://localhost:5433/') 31 | assert pattern.match('postgresql://localhost:5433/mydb?') 32 | assert pattern.match('postgresql://localhost:5433/mydb?connect_timeout=10&application_name=myapp') 33 | assert pattern.match('postgresql://postgres:postgres@127.0.0.1:54333') 34 | -------------------------------------------------------------------------------- 
/src/supabase_pydantic/db/utils/url_parser.py: -------------------------------------------------------------------------------- 1 | """URL parser utilities for database connections.""" 2 | 3 | import re 4 | 5 | from supabase_pydantic.db.database_type import DatabaseType 6 | 7 | # Connection string regexes for different database types 8 | POSTGRES_CONN_REGEX = ( 9 | r'(postgresql|postgres)://([^:@\s]*(?::[^@\s]*)?@)?(?P<host>[^/\?\s:]+)(:\d+)?(/[^?\s]*)?(\?[^\s]*)?$' 10 | ) 11 | MYSQL_CONN_REGEX = r'(mysql|mariadb)://([^:@\s]*(?::[^@\s]*)?@)?(?P<host>[^/\?\s:]+)(:\d+)?(/[^?\s]*)?(\?[^\s]*)?$' 12 | 13 | 14 | def detect_database_type(db_url: str) -> DatabaseType | None: 15 | """Detect the database type from a connection URL. 16 | 17 | Args: 18 | db_url: Database connection URL 19 | 20 | Returns: 21 | DatabaseType enum value or None if unknown 22 | """ 23 | if not db_url: 24 | return None 25 | 26 | if re.match(POSTGRES_CONN_REGEX, db_url): 27 | return DatabaseType.POSTGRES 28 | elif re.match(MYSQL_CONN_REGEX, db_url): 29 | return DatabaseType.MYSQL 30 | 31 | return None 32 | 33 | 34 | def get_database_name_from_url(db_url: str) -> str | None: 35 | """Extract database name from connection URL. 36 | 37 | Args: 38 | db_url: Database connection URL 39 | 40 | Returns: 41 | Database name or None if not found 42 | """ 43 | if not db_url: 44 | return None 45 | 46 | # Extract the path component 47 | path_match = re.search(r'://[^/]+/([^?]+)', db_url) 48 | if path_match: 49 | return path_match.group(1) 50 | 51 | return None 52 | -------------------------------------------------------------------------------- /tests/unit/core/writers/test_factories.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from supabase_pydantic.db.models import TableInfo 4 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 5 | from supabase_pydantic.core.writers.factories import FileWriterFactory 6 | from supabase_pydantic.core.writers.pydantic import PydanticFastAPIWriter 7 | from supabase_pydantic.core.writers.sqlalchemy import SqlAlchemyFastAPIWriter 8 | 9 | 10 | @pytest.fixture 11 | def table_info(): 12 | return [TableInfo(name='test_table', columns=[])] 13 | 14 | 15 | @pytest.fixture 16 | def file_path(): 17 | return 'test_path.py' 18 | 19 | 20 | @pytest.mark.unit 21 | @pytest.mark.writers 22 | @pytest.mark.parametrize( 23 | 'file_type,framework_type,expected_writer', 24 | [ 25 | (OrmType.SQLALCHEMY, FrameWorkType.FASTAPI, SqlAlchemyFastAPIWriter), 26 | (OrmType.PYDANTIC, FrameWorkType.FASTAPI, PydanticFastAPIWriter), 27 | ], 28 | ) 29 | def test_get_file_writer_valid_cases(table_info, file_path, file_type, framework_type, expected_writer): 30 | writer = FileWriterFactory.get_file_writer(table_info, file_path, file_type, framework_type) 31 | assert isinstance(writer, expected_writer), f'Expected {expected_writer}, got {type(writer)}' 32 | 33 | 34 | @pytest.mark.unit 35 | @pytest.mark.writers 36 | def test_get_file_writer_invalid_cases(table_info, file_path): 37 | with pytest.raises(ValueError) as excinfo: 38 | FileWriterFactory.get_file_writer(table_info, file_path, OrmType.PYDANTIC, FrameWorkType) 39 | assert 'Unsupported file type and framework:' in str(excinfo.value) 40 | -------------------------------------------------------------------------------- /docs/assets/otter-solid.svg: 1 | --------------------------------------------------------------------------------
/docs/getting-started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | To install the package, you can use the following command: 4 | 5 | ```bash 6 | pip install supabase-pydantic 7 | ``` 8 | 9 | There are a few evolving dependencies that will be installed as well and are worth taking note of: 10 | 11 | - [pydantic](https://pydantic-docs.helpmanual.io/)-v2: A data validation and settings management library using Python type annotations. We may add support for v1 in the future. 12 | - [psycopg2-binary](https://pypi.org/project/psycopg2-binary/): A PostgreSQL database adapter for Python. We will likely integrate a more robust adapter in the future, possibly integrating the lower-level [asyncpg](https://pypi.org/project/asyncpg/) or golang client with Supabase. Other database connections will be added in the future as well. 13 | - [types-psycopg2](https://pypi.org/project/types-psycopg2/): Type hints for psycopg2. 14 | - [faker](https://pypi.org/project/Faker/): A Python package that generates fake data. This package will be critical for adding a future feature to generate seed data. 15 | 16 | If you have Python 3.10+ installed, you should be fine, but please [report](https://github.com/kmbhm1/supabase-pydantic/issues/new/choose) any issues you encounter. 17 | 18 | The package is also available on [conda](https://docs.conda.io/en/latest/) under the [conda-forge](https://conda-forge.org/) channel: 19 | 20 | ```bash 21 | conda install -c conda-forge supabase-pydantic 22 | ``` 23 | 24 | ## Install from source 25 | 26 | To install the package from source, you can use the following command: 27 | 28 | ```bash 29 | pip install git+https://github.com/kmbhm1/supabase-pydantic.git@main 30 | ``` -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/postgres/constraints.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from supabase_pydantic.db.marshalers.abstract.base_constraint_marshaler import BaseConstraintMarshaler 4 | from supabase_pydantic.db.marshalers.constraints import ( 5 | add_constraints_to_table_details as add_constraints, 6 | ) 7 | from supabase_pydantic.db.marshalers.constraints import ( 8 | parse_constraint_definition_for_fk, 9 | ) 10 | from supabase_pydantic.db.models import TableInfo 11 | 12 | 13 | class PostgresConstraintMarshaler(BaseConstraintMarshaler): 14 | """PostgreSQL implementation of constraint marshaling.""" 15 | 16 | def parse_foreign_key(self, constraint_def: str) -> tuple[str, str, str] | None: 17 | """Parse foreign key definition from database-specific syntax.""" 18 | result = parse_constraint_definition_for_fk(constraint_def) 19 | return result if result else None 20 | 21 | def parse_unique_constraint(self, constraint_def: str) -> list[str]: 22 | """Parse unique constraint from database-specific syntax.""" 23 | match = re.match(r'UNIQUE \(([^)]+)\)', constraint_def) 24 | if match: 25 | columns = match.group(1).split(',') 26 | return [c.strip() for c in columns] 27 | return [] 28 | 29 | def parse_check_constraint(self, constraint_def: str) -> str: 30 | """Parse check constraint from database-specific syntax.""" 31 | # For PostgreSQL, just return the constraint definition as is 32 | return constraint_def 33 | 34 | def add_constraints_to_table_details( 35 | self, tables: dict[tuple[str, str], TableInfo], constraint_data: list[tuple], schema: str 36 | ) -> None: 37 | """Add constraint
information to table details.""" 38 | add_constraints(tables, schema, constraint_data) 39 | -------------------------------------------------------------------------------- /src/supabase_pydantic/cli/commands/clean.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import click 5 | 6 | from supabase_pydantic.cli.common import common_options, load_config 7 | from supabase_pydantic.utils.io import clean_directories, get_working_directories 8 | 9 | # Define framework choices - can be imported from a constants file later 10 | framework_choices = ['fastapi'] 11 | 12 | 13 | @click.command() 14 | @common_options 15 | @click.pass_context 16 | @click.option( 17 | '-d', 18 | '--dir', 19 | '--directory', 20 | 'directory', 21 | default=None, # Will use the config value if None 22 | required=False, 23 | help='The directory to clear of generated files and directories. Defaults to config value or "entities".', 24 | ) 25 | def clean(ctx: Any, directory: str | None, config: str | None) -> None: 26 | """Clean the project directory by removing generated files and clearing caches.""" 27 | # Load config if provided 28 | config_dict = {} 29 | if config: 30 | config_dict = load_config() 31 | 32 | click.echo('Cleaning up the project...') 33 | cfg_default_dir = config_dict.get('default_directory', 'entities') 34 | actual_directory = directory if directory is not None else cfg_default_dir 35 | logging.info(f'Cleaning directory: {actual_directory}') 36 | try: 37 | directories = get_working_directories(actual_directory, tuple(framework_choices), auto_create=False) 38 | clean_directories(directories) 39 | except (FileExistsError, FileNotFoundError): 40 | click.echo(f'Directory doesn\'t exist: "{actual_directory}". 
Exiting...') 41 | return 42 | except Exception as e: 43 | click.echo(f'An error occurred while cleaning the project: {e}') 44 | return 45 | else: 46 | click.echo('Project cleaned.') 47 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/postgres/column.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.db.marshalers.abstract.base_column_marshaler import BaseColumnMarshaler 2 | from supabase_pydantic.db.marshalers.column import ( 3 | get_alias as get_col_alias, 4 | ) 5 | from supabase_pydantic.db.marshalers.column import ( 6 | process_udt_field, 7 | ) 8 | from supabase_pydantic.db.marshalers.column import ( 9 | standardize_column_name as std_column_name, 10 | ) 11 | 12 | 13 | class PostgresColumnMarshaler(BaseColumnMarshaler): 14 | """PostgreSQL-specific implementation of column marshaling.""" 15 | 16 | def standardize_column_name(self, column_name: str, disable_model_prefix_protection: bool = False) -> str: 17 | """Standardize column names across database types.""" 18 | result = std_column_name(column_name, disable_model_prefix_protection) 19 | return str(result) if result is not None else '' 20 | 21 | def get_alias(self, column_name: str) -> str: 22 | """Get alias for a column name.""" 23 | result = get_col_alias(column_name) 24 | return str(result) if result is not None else '' 25 | 26 | def process_column_type(self, db_type: str, type_info: str, enum_types: list[str] | None = None) -> str: 27 | """Process database-specific column type into standard Python type.""" 28 | result = process_udt_field(type_info, db_type, known_enum_types=enum_types) 29 | if result is None: 30 | return '' 31 | # Make sure we check for None again and return empty string 32 | # This can happen if the underlying function changes behavior 33 | return '' if result is None else str(result) 34 | 35 | def process_array_type(self, element_type: str) -> str: 36 | """Process array types for the database.""" 37 | return f'list[{element_type}]' 38 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/config.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.core.config import WriterConfig 2 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 3 | 4 | 5 | def get_standard_jobs( 6 | models: tuple[str], frameworks: tuple[str], dirs: dict[str, str | None], schemas: tuple[str, ...] 
= ('public',) 7 | ) -> dict[str, dict[str, WriterConfig]]: 8 | """Get the standard jobs for the writer.""" 9 | 10 | jobs: dict[str, dict[str, WriterConfig]] = {} 11 | 12 | for schema in schemas: 13 | # Set the file names 14 | pydantic_fname, sqlalchemy_fname = f'schema_{schema}.py', f'database_{schema}.py' 15 | 16 | # Add the jobs 17 | if 'fastapi' in dirs and dirs['fastapi'] is not None: 18 | jobs[schema] = { 19 | 'Pydantic': WriterConfig( 20 | file_type=OrmType.PYDANTIC, 21 | framework_type=FrameWorkType.FASTAPI, 22 | filename=pydantic_fname, 23 | directory=dirs['fastapi'], 24 | enabled='pydantic' in models and 'fastapi' in frameworks, 25 | ), 26 | 'SQLAlchemy': WriterConfig( 27 | file_type=OrmType.SQLALCHEMY, 28 | framework_type=FrameWorkType.FASTAPI, 29 | filename=sqlalchemy_fname, 30 | directory=dirs['fastapi'], 31 | enabled='sqlalchemy' in models and 'fastapi' in frameworks, 32 | ), 33 | } 34 | 35 | return jobs 36 | 37 | 38 | def local_default_env_configuration() -> dict[str, str | None]: 39 | """Get the environment variables for a local connection.""" 40 | return { 41 | 'DB_NAME': 'postgres', 42 | 'DB_USER': 'postgres', 43 | 'DB_PASS': 'postgres', 44 | 'DB_HOST': 'localhost', 45 | 'DB_PORT': '54322', 46 | } 47 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/abstract/base_column_marshaler.py: -------------------------------------------------------------------------------- 1 | """Abstract base class for column marshaling.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class BaseColumnMarshaler(ABC): 7 | """Abstract base class for column data marshaling.""" 8 | 9 | @abstractmethod 10 | def standardize_column_name(self, column_name: str, disable_model_prefix_protection: bool = False) -> str: 11 | """Standardize column names across database types. 12 | 13 | Args: 14 | column_name: Raw column name from database. 15 | disable_model_prefix_protection: Whether to disable model prefix protection. 16 | 17 | Returns: 18 | Standardized column name suitable for Python code. 19 | """ 20 | pass 21 | 22 | @abstractmethod 23 | def get_alias(self, column_name: str) -> str: 24 | """Get alias for a column name. 25 | 26 | Args: 27 | column_name: Raw column name from database. 28 | 29 | Returns: 30 | Aliased column name if needed. 31 | """ 32 | pass 33 | 34 | @abstractmethod 35 | def process_column_type(self, db_type: str, type_info: str, enum_types: list[str] | None = None) -> str: 36 | """Process database-specific column type into standard Python type. 37 | 38 | Args: 39 | db_type: Database-specific type name. 40 | type_info: Additional type information. 41 | enum_types: Optional list of known enum type names. 42 | 43 | Returns: 44 | Python/Pydantic type name. 45 | """ 46 | pass 47 | 48 | @abstractmethod 49 | def process_array_type(self, element_type: str) -> str: 50 | """Process array types for the database. 51 | 52 | Args: 53 | element_type: Type of array elements. 54 | 55 | Returns: 56 | Python representation of array type. 
57 | """ 58 | pass 59 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/factories/connection_factory.py: -------------------------------------------------------------------------------- 1 | """Factory for creating connection parameter objects.""" 2 | 3 | from typing import Any 4 | 5 | from supabase_pydantic.db.database_type import DatabaseType 6 | from supabase_pydantic.db.models import MySQLConnectionParams, PostgresConnectionParams 7 | from supabase_pydantic.db.utils.url_parser import detect_database_type 8 | 9 | 10 | class ConnectionParamFactory: 11 | """Factory for creating database connection parameter models.""" 12 | 13 | @staticmethod 14 | def create_connection_params( 15 | conn_params: dict[str, Any], db_type: DatabaseType | None = None 16 | ) -> PostgresConnectionParams | MySQLConnectionParams: 17 | """Create connection parameter object based on database type. 18 | 19 | Args: 20 | conn_params: Dictionary of connection parameters 21 | db_type: Database type enum value, or None to auto-detect from DB_URL 22 | 23 | Returns: 24 | Connection parameter model instance 25 | 26 | Raises: 27 | ValueError: If database type cannot be determined or is not supported 28 | """ 29 | # Auto-detect database type if not specified and db_url is available 30 | if db_type is None and conn_params.get('db_url'): 31 | detected_db_type = detect_database_type(conn_params['db_url']) 32 | if detected_db_type: 33 | db_type = detected_db_type 34 | 35 | # If still no db_type, default to PostgreSQL for backward compatibility 36 | if db_type is None: 37 | db_type = DatabaseType.POSTGRES 38 | 39 | # Create the appropriate connection parameter object 40 | if db_type == DatabaseType.POSTGRES: 41 | return PostgresConnectionParams(**conn_params) 42 | elif db_type == DatabaseType.MYSQL: 43 | return MySQLConnectionParams(**conn_params) 44 | else: 45 | raise ValueError(f'Unsupported database type: {db_type}') 46 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/abstract/base_constraint_marshaler.py: -------------------------------------------------------------------------------- 1 | """Abstract base class for constraint marshaling.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | from supabase_pydantic.db.models import TableInfo 6 | 7 | 8 | class BaseConstraintMarshaler(ABC): 9 | """Abstract base class for constraint data marshaling.""" 10 | 11 | @abstractmethod 12 | def parse_foreign_key(self, constraint_def: str) -> tuple[str, str, str] | None: 13 | """Parse foreign key definition from database-specific syntax. 14 | 15 | Args: 16 | constraint_def: Constraint definition string. 17 | 18 | Returns: 19 | Tuple containing (source_column, target_table, target_column) or None. 20 | """ 21 | pass 22 | 23 | @abstractmethod 24 | def parse_unique_constraint(self, constraint_def: str) -> list[str]: 25 | """Parse unique constraint from database-specific syntax. 26 | 27 | Args: 28 | constraint_def: Constraint definition string. 29 | 30 | Returns: 31 | List of column names in the unique constraint. 32 | """ 33 | pass 34 | 35 | @abstractmethod 36 | def parse_check_constraint(self, constraint_def: str) -> str: 37 | """Parse check constraint from database-specific syntax. 38 | 39 | Args: 40 | constraint_def: Constraint definition string. 41 | 42 | Returns: 43 | Parsed check constraint expression. 
44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def add_constraints_to_table_details( 49 | self, tables: dict[tuple[str, str], TableInfo], constraint_data: list[tuple], schema: str 50 | ) -> None: 51 | """Add constraint information to table details. 52 | 53 | Args: 54 | tables: Dictionary of table information objects. 55 | constraint_data: Raw constraint data from database. 56 | schema: Schema name. 57 | """ 58 | pass 59 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/type_factory.py: -------------------------------------------------------------------------------- 1 | """Factory for database-specific type maps.""" 2 | 3 | from typing import cast 4 | 5 | from supabase_pydantic.core.constants import PYDANTIC_TYPE_MAP, SQLALCHEMY_TYPE_MAP, SQLALCHEMY_V2_TYPE_MAP 6 | from supabase_pydantic.db.database_type import DatabaseType 7 | from supabase_pydantic.db.drivers.mysql.type_maps import ( 8 | MYSQL_PYDANTIC_TYPE_MAP, 9 | MYSQL_SQLALCHEMY_TYPE_MAP, 10 | MYSQL_SQLALCHEMY_V2_TYPE_MAP, 11 | ) 12 | 13 | # Type annotations for imported maps 14 | PydanticTypeMap = dict[str, tuple[str, str | None]] 15 | SQLAlchemyTypeMap = dict[str, tuple[str, str | None]] 16 | SQLAlchemyV2TypeMap = dict[str, tuple[str, str | None]] 17 | 18 | 19 | class TypeMapFactory: 20 | """Factory for database-specific type maps.""" 21 | 22 | @staticmethod 23 | def get_pydantic_type_map(db_type: DatabaseType) -> dict[str, tuple[str, str | None]]: 24 | """Get the Pydantic type map for a specific database type.""" 25 | if db_type == DatabaseType.MYSQL: 26 | return cast(PydanticTypeMap, MYSQL_PYDANTIC_TYPE_MAP) 27 | # Default to PostgreSQL 28 | return cast(PydanticTypeMap, PYDANTIC_TYPE_MAP) 29 | 30 | @staticmethod 31 | def get_sqlalchemy_type_map(db_type: DatabaseType) -> dict[str, tuple[str, str | None]]: 32 | """Get the SQLAlchemy type map for a specific database type.""" 33 | if db_type == DatabaseType.MYSQL: 34 | return cast(SQLAlchemyTypeMap, MYSQL_SQLALCHEMY_TYPE_MAP) 35 | # Default to PostgreSQL 36 | return cast(SQLAlchemyTypeMap, SQLALCHEMY_TYPE_MAP) 37 | 38 | @staticmethod 39 | def get_sqlalchemy_v2_type_map(db_type: DatabaseType) -> dict[str, tuple[str, str | None]]: 40 | """Get the SQLAlchemy V2 type map for a specific database type.""" 41 | if db_type == DatabaseType.MYSQL: 42 | return cast(SQLAlchemyV2TypeMap, MYSQL_SQLALCHEMY_V2_TYPE_MAP) 43 | # Default to PostgreSQL 44 | return cast(SQLAlchemyV2TypeMap, SQLALCHEMY_V2_TYPE_MAP) 45 | -------------------------------------------------------------------------------- /src/supabase_pydantic/core/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class EnumInfo: 6 | name: str # The name of the enum type in the DB 7 | values: list[str] # The possible values for the enum 8 | schema: str = 'public' # The schema, defaulting to 'public' 9 | 10 | def python_class_name(self) -> str: 11 | """Converts DB enum name to PascalCase for Python class, prefixed by schema. 
12 | 13 | Handles various input formats: 14 | - snake_case: 'order_status' -> 'OrderStatusEnum' 15 | - camelCase: 'thirdType' -> 'ThirdTypeEnum' 16 | - PascalCase: 'FourthType' -> 'FourthTypeEnum' 17 | - mixed: 'Fifth_Type' -> 'FifthTypeEnum' 18 | - with leading underscore: '_first_type' -> 'FirstTypeEnum' 19 | 20 | Final class name is prefixed by schema: 'public.order_status' -> 'PublicOrderStatusEnum' 21 | """ 22 | # Handle empty name edge case 23 | if not self.name: 24 | return f'{self.schema.capitalize()}Enum' 25 | 26 | # Remove leading underscore if present (for PostgreSQL compatibility) 27 | clean_name = self.name 28 | if clean_name.startswith('_'): 29 | clean_name = clean_name[1:] 30 | 31 | # Special case: if the name is already PascalCase or camelCase without underscores 32 | if '_' not in clean_name and any(c.isupper() for c in clean_name): 33 | # Ensure first character is uppercase 34 | class_name = clean_name[0].upper() + clean_name[1:] + 'Enum' 35 | else: 36 | # Standard snake_case processing 37 | class_name = ''.join(word.capitalize() for word in clean_name.split('_')) + 'Enum' 38 | 39 | # Add schema prefix 40 | return f'{self.schema.capitalize()}{class_name}' 41 | 42 | def python_member_name(self, value: str) -> str: 43 | """Converts enum value to a valid Python identifier. 44 | 45 | e.g., 'pending_new' -> 'pending_new' 46 | """ 47 | return value.lower() 48 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/formatting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import subprocess 3 | 4 | # Get Logger 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class RuffNotFoundError(Exception): 9 | """Custom exception raised when the ruff executable is not found.""" 10 | 11 | def __init__(self, file_path: str, message: str = 'Ruff executable not found. 
Formatting skipped.'): 12 | self.file_path = file_path 13 | self.message = f'{message} For file: {file_path}' 14 | super().__init__(self.message) 15 | 16 | 17 | def format_with_ruff(file_path: str) -> None: 18 | """Run the ruff formatter, import sorter, and fixer on a specified Python file.""" 19 | try: 20 | # First run ruff check --fix to handle imports and other issues 21 | import_result = subprocess.run( 22 | ['ruff', 'check', '--select', 'I', '--fix', file_path], text=True, capture_output=True 23 | ) 24 | if import_result.returncode != 0: 25 | logger.debug(f'Ruff import sorting had issues: {import_result.stderr}') 26 | 27 | # Then run ruff format for code formatting 28 | format_result = subprocess.run(['ruff', 'format', file_path], text=True, capture_output=True) 29 | if format_result.returncode != 0: 30 | logger.debug(f'Ruff formatting had issues: {format_result.stderr}') 31 | 32 | # Finally run ruff check --fix for any remaining issues 33 | fix_result = subprocess.run(['ruff', 'check', '--fix', file_path], text=True, capture_output=True) 34 | if fix_result.returncode != 0: 35 | logger.debug(f'Ruff fixing had issues: {fix_result.stderr}') 36 | 37 | # Even if there were warnings, the file is likely formatted correctly 38 | logger.info(f'File formatted successfully: {file_path}') 39 | 40 | except subprocess.CalledProcessError as e: 41 | logger.warning(f'An error occurred while trying to format {file_path} with ruff:') 42 | logger.warning(e.stderr if e.stderr else 'No stderr output available') 43 | logger.warning('The file was generated, but not formatted.') 44 | except FileNotFoundError: 45 | raise RuffNotFoundError(file_path=file_path) 46 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/postgres/relationship.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.db.constants import RelationType 2 | from supabase_pydantic.db.marshalers.abstract.base_relationship_marshaler import BaseRelationshipMarshaler 3 | from supabase_pydantic.db.marshalers.relationships import ( 4 | analyze_table_relationships as analyze_relationships, 5 | ) 6 | from supabase_pydantic.db.models import TableInfo 7 | 8 | 9 | class PostgresRelationshipMarshaler(BaseRelationshipMarshaler): 10 | """PostgreSQL implementation of relationship marshaler.""" 11 | 12 | def determine_relationship_type( 13 | self, source_table: TableInfo, target_table: TableInfo, source_column: str, target_column: str 14 | ) -> RelationType: 15 | """Determine the type of relationship between two tables.""" 16 | # Check if the source column is unique or primary key 17 | is_source_unique = any( 18 | col.name == source_column and (col.is_unique or col.primary) for col in source_table.columns 19 | ) 20 | 21 | # Check if the target column is unique or primary key 22 | is_target_unique = any( 23 | col.name == target_column and (col.is_unique or col.primary) for col in target_table.columns 24 | ) 25 | 26 | # Determine relationship type based on uniqueness 27 | if is_source_unique and is_target_unique: 28 | # If both sides are unique, it's a one-to-one relationship 29 | return RelationType.ONE_TO_ONE 30 | elif is_source_unique and not is_target_unique: 31 | # If only source is unique, it's one-to-many from source to target 32 | return RelationType.ONE_TO_MANY 33 | elif not is_source_unique and is_target_unique: 34 | # If only target is unique, it's many-to-one from source to target 35 | return RelationType.MANY_TO_ONE 36 | 
else: 37 | # If neither side is unique, it's many-to-many 38 | return RelationType.MANY_TO_MANY 39 | 40 | def analyze_table_relationships(self, tables: dict[tuple[str, str], TableInfo]) -> None: 41 | """Analyze relationships between tables.""" 42 | analyze_relationships(tables) 43 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/mysql/column.py: -------------------------------------------------------------------------------- 1 | """MySQL column marshaler implementation.""" 2 | 3 | import logging 4 | 5 | from supabase_pydantic.db.database_type import DatabaseType 6 | from supabase_pydantic.db.marshalers.abstract.base_column_marshaler import BaseColumnMarshaler 7 | from supabase_pydantic.db.marshalers.column import ( 8 | get_alias as get_col_alias, 9 | ) 10 | from supabase_pydantic.db.marshalers.column import ( 11 | process_udt_field, 12 | ) 13 | from supabase_pydantic.db.marshalers.column import ( 14 | standardize_column_name as std_column_name, 15 | ) 16 | 17 | # Get Logger 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class MySQLColumnMarshaler(BaseColumnMarshaler): 22 | """MySQL column marshaler implementation.""" 23 | 24 | def standardize_column_name(self, column_name: str, disable_model_prefix_protection: bool = False) -> str: 25 | """Standardize column names across database types.""" 26 | result = std_column_name(column_name, disable_model_prefix_protection) 27 | return str(result) if result is not None else '' 28 | 29 | def get_alias(self, column_name: str) -> str: 30 | """Get alias for a column name.""" 31 | result = get_col_alias(column_name) 32 | return str(result) if result is not None else '' 33 | 34 | def process_column_type(self, db_type: str, type_info: str, enum_types: list[str] | None = None) -> str: 35 | """Process database-specific column type into standard Python type.""" 36 | # Call process_udt_field with the correct parameter order: 37 | # (udt_name, data_type, db_type, known_enum_types) 38 | result = process_udt_field(type_info, db_type, db_type=DatabaseType.MYSQL, known_enum_types=enum_types) 39 | 40 | # Handle None return value by returning empty string 41 | if result is None: 42 | return '' 43 | 44 | return str(result) 45 | 46 | def process_array_type(self, element_type: str) -> str: 47 | """Process array types for the database. 48 | 49 | MySQL doesn't natively support array types like PostgreSQL, but we implement 50 | this method for compatibility with the base class. 51 | """ 52 | return f'list[{element_type}]' 53 | -------------------------------------------------------------------------------- /tests/unit/utils/serialization/test_json_encoder.py: -------------------------------------------------------------------------------- 1 | """Tests for the CustomJsonEncoder in supabase_pydantic.utils.serialization.""" 2 | 3 | import decimal 4 | import json 5 | from datetime import date, datetime 6 | 7 | import pytest 8 | 9 | from supabase_pydantic.utils.serialization import CustomJsonEncoder 10 | 11 | 12 | @pytest.mark.unit 13 | @pytest.mark.utils 14 | @pytest.mark.serialization 15 | def test_encode_decimal(): 16 | """Test encoding of decimal.Decimal.""" 17 | value = decimal.Decimal('10.5') 18 | expected = '"10.5"' # JSON representation should be a string of the decimal value 19 | result = json.dumps(value, cls=CustomJsonEncoder) 20 | assert result == expected, 'Decimal should be serialized to a JSON string.' 
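# For orientation, a minimal sketch of how an encoder like CustomJsonEncoder is
# commonly implemented (an illustrative assumption, not the library's actual
# source): subclass json.JSONEncoder and override default() for the types the
# stdlib cannot serialize.
#
#     class SketchJsonEncoder(json.JSONEncoder):
#         def default(self, o):
#             if isinstance(o, decimal.Decimal):
#                 return str(o)  # e.g. Decimal('10.5') -> '10.5'
#             if isinstance(o, (datetime, date)):
#                 return o.isoformat()  # e.g. '2022-07-20T15:30:45'
#             return super().default(o)  # raise TypeError for anything else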
21 | 22 | 23 | @pytest.mark.unit 24 | @pytest.mark.utils 25 | @pytest.mark.serialization 26 | def test_encode_datetime(): 27 | """Test encoding of datetime.datetime.""" 28 | dt = datetime(2022, 7, 20, 15, 30, 45) 29 | expected = '"2022-07-20T15:30:45"' # ISO format for datetime 30 | result = json.dumps(dt, cls=CustomJsonEncoder) 31 | assert result == expected, 'Datetime should be serialized to ISO 8601 string.' 32 | 33 | 34 | @pytest.mark.unit 35 | @pytest.mark.utils 36 | @pytest.mark.serialization 37 | def test_encode_date(): 38 | """Test encoding of datetime.date.""" 39 | d = date(2022, 7, 20) 40 | expected = '"2022-07-20"' # ISO format for date 41 | result = json.dumps(d, cls=CustomJsonEncoder) 42 | assert result == expected, 'Date should be serialized to ISO 8601 string.' 43 | 44 | 45 | @pytest.mark.unit 46 | @pytest.mark.utils 47 | @pytest.mark.serialization 48 | def test_encode_complex_object(): 49 | """Test encoding of complex object with mixed types.""" 50 | obj = { 51 | 'decimal': decimal.Decimal('12.34'), 52 | 'datetime': datetime(2023, 1, 1, 12, 0), 53 | 'date': date(2023, 1, 1), 54 | 'string': 'hello', 55 | 'integer': 42, 56 | } 57 | expected = '{"decimal": "12.34", "datetime": "2023-01-01T12:00:00", "date": "2023-01-01", "string": "hello", "integer": 42}' 58 | result = json.dumps(obj, cls=CustomJsonEncoder) 59 | assert result == expected, 'Complex object should be serialized correctly with custom handling for specified types.' 60 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | fail_fast: true 4 | exclude: 'docs/.*|cli/.*|.venv/.*|.vscode/.*|serverless.yml|poetry.lock|Makefile|Dockerfile*|.env*' 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.5.0 8 | hooks: 9 | - id: trailing-whitespace 10 | stages: [commit] 11 | - id: end-of-file-fixer 12 | stages: [commit] 13 | - id: check-yaml 14 | stages: [commit] 15 | exclude: mkdocs.yml 16 | - id: sort-simple-yaml 17 | stages: [commit] 18 | - id: check-added-large-files 19 | stages: [commit] 20 | - id: check-ast 21 | stages: [commit] 22 | - id: check-json 23 | stages: [commit] 24 | - id: pretty-format-json 25 | stages: [commit] 26 | args: [--autofix, --indent, "2"] 27 | - id: check-toml 28 | stages: [commit] 29 | - id: debug-statements 30 | stages: [commit] 31 | - id: mixed-line-ending 32 | stages: [commit] 33 | - id: requirements-txt-fixer 34 | stages: [commit] 35 | - id: sort-simple-yaml 36 | stages: [commit] 37 | - id: detect-private-key 38 | stages: [commit] 39 | - id: no-commit-to-branch 40 | stages: [commit] 41 | args: [--branch, staging, --branch, main, --branch, master] 42 | - id: check-case-conflict 43 | stages: [commit] 44 | - id: check-merge-conflict 45 | stages: [commit] 46 | - repo: https://github.com/commitizen-tools/commitizen 47 | rev: v3.15.0 48 | hooks: 49 | - id: commitizen 50 | stages: [commit-msg] 51 | - repo: https://github.com/astral-sh/ruff-pre-commit 52 | rev: v0.2.2 53 | hooks: 54 | - id: ruff # linter without fixer 55 | stages: [ push ] 56 | types_or: [ python3, python ] 57 | exclude: '^poc/' 58 | - repo: https://github.com/pre-commit/mirrors-mypy 59 | rev: v1.8.0 60 | hooks: 61 | - id: mypy 62 | stages: [push] 63 | exclude: '^tests/|^poc/|whitelist.py' 64 | args: [--ignore-missing-imports, --warn-unused-configs, --check-untyped-defs] 65 | additional_dependencies: 66 | - pydantic 67 
| - pydantic-settings 68 | - types-toml 69 | - types-pytz 70 | - fastapi 71 | - passlib 72 | - supabase 73 | -------------------------------------------------------------------------------- /.github/workflows/python-multi-version.yml: -------------------------------------------------------------------------------- 1 | name: Python Multi-Version Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ['3.10', '3.11', '3.12', '3.13'] 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Install Poetry 30 | uses: snok/install-poetry@v1 31 | with: 32 | version: 1.7.1 33 | virtualenvs-create: true 34 | virtualenvs-in-project: true 35 | 36 | - name: Install dependencies 37 | run: poetry install 38 | 39 | - name: Install tox directly 40 | run: pip install "tox>=4.11.3" 41 | 42 | - name: Run tests 43 | run: tox -e py$(echo ${{ matrix.python-version }} | tr -d '.') 44 | 45 | # Only run coverage report on Python 3.10 46 | - name: Run tests with coverage and analytics 47 | if: matrix.python-version == '3.10' 48 | id: coverage 49 | continue-on-error: true 50 | run: | 51 | poetry run pytest --cov=supabase_pydantic --cov-report=term-missing --cov-report=xml \ 52 | --junitxml=junit.xml -o junit_family=legacy 53 | 54 | COVERAGE=$(cat coverage.xml | grep -o 'line-rate="[0-9]\.[0-9]\+"' | head -1 | cut -d'"' -f2) 55 | COVERAGE=$(echo "$COVERAGE * 100" | bc) 56 | echo "Coverage percentage is $COVERAGE%" 57 | if (( $(echo "$COVERAGE < 90" | bc -l) )); then 58 | echo "Coverage is below 90% threshold" 59 | exit 1 60 | fi 61 | 62 | - name: Upload coverage to Codecov 63 | if: matrix.python-version == '3.10' 64 | uses: codecov/codecov-action@v5 65 | with: 66 | files: ./coverage.xml 67 | flags: unittests 68 | fail_ci_if_error: false 69 | env: 70 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 71 | 72 | - name: Upload test results to Codecov 73 | if: ${{ !cancelled() && matrix.python-version == '3.10' }} 74 | uses: codecov/test-results-action@v1 75 | with: 76 | files: ./junit.xml 77 | env: 78 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 79 | -------------------------------------------------------------------------------- /tests/unit/db/marshalers/postgres/test_analyze_table_relationships.py: -------------------------------------------------------------------------------- 1 | """Tests for PostgresRelationshipMarshaler.analyze_table_relationships method.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.db.constants import RelationType 6 | from supabase_pydantic.db.marshalers.postgres.relationship import PostgresRelationshipMarshaler 7 | from supabase_pydantic.db.models import ColumnInfo, ForeignKeyInfo, TableInfo 8 | 9 | 10 | @pytest.fixture 11 | def marshaler(): 12 | """Fixture to create a PostgresRelationshipMarshaler instance.""" 13 | return PostgresRelationshipMarshaler() 14 | 15 | 16 | @pytest.mark.unit 17 | @pytest.mark.db 18 | @pytest.mark.marshalers 19 | def test_analyze_table_relationships(marshaler): 20 | """Test that analyze_table_relationships correctly processes table relationships.""" 21 | # Create mock tables with foreign keys 22 | tables = { 23 | ('public', 'users'): TableInfo( 24 | name='users', 25 | 
schema='public', 26 | columns=[ 27 | ColumnInfo(name='id', post_gres_datatype='integer', datatype='integer', is_unique=True, primary=True) 28 | ], 29 | ), 30 | ('public', 'orders'): TableInfo( 31 | name='orders', 32 | schema='public', 33 | columns=[ 34 | ColumnInfo(name='id', post_gres_datatype='integer', datatype='integer', is_unique=True, primary=True), 35 | ColumnInfo( 36 | name='user_id', post_gres_datatype='integer', datatype='integer', is_unique=False, primary=False 37 | ), 38 | ], 39 | ), 40 | } 41 | 42 | # Add foreign key to orders table 43 | tables[('public', 'orders')].foreign_keys.append( 44 | ForeignKeyInfo( 45 | constraint_name='fk_user_id', 46 | column_name='user_id', 47 | foreign_table_name='users', 48 | foreign_column_name='id', 49 | foreign_table_schema='public', 50 | ) 51 | ) 52 | 53 | # Call the method 54 | marshaler.analyze_table_relationships(tables) 55 | 56 | # Verify relationships were analyzed 57 | assert tables[('public', 'orders')].foreign_keys[0].relation_type is not None 58 | 59 | # Since user_id is not unique in orders but id is unique in users, 60 | # this should be a MANY_TO_ONE relationship from orders to users 61 | assert tables[('public', 'orders')].foreign_keys[0].relation_type == RelationType.MANY_TO_ONE 62 | -------------------------------------------------------------------------------- /tests/unit/utils/formatting/test_formatting.py: -------------------------------------------------------------------------------- 1 | """Tests for formatting utilities in supabase_pydantic.utils.formatting.""" 2 | 3 | import logging 4 | import subprocess 5 | import pytest 6 | 7 | from supabase_pydantic.utils.formatting import format_with_ruff, RuffNotFoundError 8 | 9 | 10 | @pytest.mark.unit 11 | @pytest.mark.formatting 12 | def test_format_with_ruff_failure_fails_silently(mocker, caplog): 13 | """Test format_with_ruff handles CalledProcessError correctly by logging warnings.""" 14 | error_output = 'An error occurred' 15 | mocker.patch('subprocess.run', side_effect=subprocess.CalledProcessError(1, 'isort', stderr=error_output)) 16 | # Set log level to capture warnings 17 | caplog.set_level(logging.WARNING) 18 | # Run the function that should log a warning 19 | format_with_ruff('non_existent_file.py') 20 | # Check that the error was logged 21 | assert error_output in caplog.text 22 | assert 'An error occurred while trying to format non_existent_file.py with ruff' in caplog.text 23 | assert 'The file was generated, but not formatted' in caplog.text 24 | 25 | 26 | @pytest.mark.unit 27 | @pytest.mark.formatting 28 | def test_format_with_ruff_file_not_found(mocker): 29 | """Test format_with_ruff raises RuffNotFoundError when ruff is not found.""" 30 | # Simulate ruff executable not being found 31 | mocker.patch('subprocess.run', side_effect=FileNotFoundError('ruff')) 32 | 33 | # Check that the correct exception is raised 34 | with pytest.raises(RuffNotFoundError) as excinfo: 35 | format_with_ruff('test_file.py') 36 | 37 | # Check the exception message 38 | assert 'Ruff executable not found. Formatting skipped.' in str(excinfo.value) 39 | assert 'test_file.py' in str(excinfo.value) 40 | 41 | 42 | @pytest.mark.unit 43 | @pytest.mark.formatting 44 | def test_ruff_not_found_error_constructor(): 45 | """Test the RuffNotFoundError constructor.""" 46 | # Test with default message 47 | error1 = RuffNotFoundError(file_path='test_file.py') 48 | assert error1.file_path == 'test_file.py' 49 | assert 'Ruff executable not found. Formatting skipped.' 
in str(error1) 50 | assert 'For file: test_file.py' in str(error1) 51 | 52 | # Test with custom message 53 | error2 = RuffNotFoundError(file_path='another_file.py', message='Custom error message') 54 | assert error2.file_path == 'another_file.py' 55 | assert 'Custom error message' in str(error2) 56 | assert 'For file: another_file.py' in str(error2) 57 | -------------------------------------------------------------------------------- /tests/unit/utils/types/test_types.py: -------------------------------------------------------------------------------- 1 | """Tests for type utilities in supabase_pydantic.utils.types.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 6 | from supabase_pydantic.utils.types import ( 7 | get_enum_member_from_string, 8 | get_pydantic_type, 9 | get_sqlalchemy_type, 10 | ) 11 | 12 | 13 | @pytest.mark.unit 14 | @pytest.mark.types 15 | def test_all_postgres_data_types_are_mapped_to_pydantic_and_sqlalchemy(): 16 | """Test that all PostgreSQL data types are mapped to Pydantic and SQLAlchemy types.""" 17 | postgres_data_types = [ 18 | 'smallint', 19 | 'integer', 20 | 'bigint', 21 | 'decimal', 22 | 'numeric', 23 | 'real', 24 | 'double precision', 25 | 'serial', 26 | 'bigserial', 27 | 'money', 28 | 'character varying(n)', 29 | 'varchar(n)', 30 | 'character(n)', 31 | 'char(n)', 32 | 'text', 33 | 'bytea', 34 | 'timestamp', 35 | 'timestamp with time zone', 36 | 'date', 37 | 'time', 38 | 'time with time zone', 39 | 'interval', 40 | 'boolean', 41 | 'enum', 42 | 'point', 43 | 'line', 44 | 'lseg', 45 | 'box', 46 | 'path', 47 | 'polygon', 48 | 'circle', 49 | 'cidr', 50 | 'inet', 51 | 'macaddr', 52 | 'macaddr8', 53 | 'bit', 54 | 'bit varying', 55 | 'tsvector', 56 | 'tsquery', 57 | 'uuid', 58 | 'xml', 59 | 'json', 60 | 'jsonb', 61 | 'integer[]', 62 | 'any', 63 | 'void', 64 | 'record', 65 | ] 66 | 67 | for data_type in postgres_data_types: 68 | assert get_pydantic_type(data_type) is not None 69 | assert get_sqlalchemy_type(data_type) is not None 70 | 71 | 72 | @pytest.mark.unit 73 | @pytest.mark.types 74 | def test_get_enum_member_from_string_and_enums(): 75 | """Test get_enum_member_from_string and FrameWorkType & OrmType.""" 76 | assert get_enum_member_from_string(OrmType, 'pydantic') == OrmType.PYDANTIC 77 | assert get_enum_member_from_string(OrmType, 'sqlalchemy') == OrmType.SQLALCHEMY 78 | with pytest.raises(ValueError): 79 | get_enum_member_from_string(OrmType, 'invalid') 80 | 81 | assert get_enum_member_from_string(FrameWorkType, 'fastapi') == FrameWorkType.FASTAPI 82 | with pytest.raises(ValueError): 83 | get_enum_member_from_string(FrameWorkType, 'invalid') 84 | -------------------------------------------------------------------------------- /src/supabase_pydantic/cli/common.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections.abc import Callable 3 | from typing import Any 4 | 5 | import click 6 | import toml 7 | 8 | from supabase_pydantic.utils.constants import AppConfig, ToolConfig 9 | 10 | # Standard choices 11 | model_choices = ['pydantic', 'sqlalchemy'] 12 | framework_choices = ['fastapi'] 13 | 14 | 15 | def check_readiness(env_vars: dict[str, str | None]) -> bool: 16 | """Check if environment variables are set correctly.""" 17 | if not env_vars: 18 | logging.error('No environment variables provided.') 19 | return False 20 | for k, v in env_vars.items(): 21 | logging.debug(f'Checking environment variable: {k}') 22 | if v is None: 23 | 
logging.error(f'Environment variables not set correctly. {k} is missing. Please set it in your .env file.') 24 | return False 25 | 26 | logging.debug('All required environment variables are set') 27 | return True 28 | 29 | 30 | def load_config(config_path: str | None = None) -> AppConfig: 31 | """Load the configuration from pyproject.toml or a specified config file.""" 32 | if config_path: 33 | # Load from specified config file 34 | try: 35 | with open(config_path) as f: 36 | # Handle YAML config files 37 | if config_path.endswith('.yaml') or config_path.endswith('.yml'): 38 | import yaml 39 | 40 | return yaml.safe_load(f) 41 | # Handle TOML files 42 | elif config_path.endswith('.toml'): 43 | return toml.load(f).get('tool', {}).get('supabase_pydantic', {}) 44 | except FileNotFoundError: 45 | click.echo(f'Config file not found: {config_path}') 46 | return {} 47 | else: 48 | # Default to pyproject.toml 49 | try: 50 | with open('pyproject.toml') as f: 51 | config_data: dict[str, Any] = toml.load(f) 52 | tool_config: ToolConfig = config_data.get('tool', {}) 53 | app_config: AppConfig = tool_config.get('supabase_pydantic', {}) 54 | return app_config 55 | except FileNotFoundError: 56 | return {} 57 | 58 | 59 | # Common option groups and decorators 60 | def common_options(f: Callable) -> Callable: 61 | """Common options for CLI commands.""" 62 | decorators = [ 63 | click.option('--config', '-c', help='Path to config file.'), 64 | # Add other common options here 65 | ] 66 | 67 | # Apply all decorators 68 | for decorator in reversed(decorators): 69 | f = decorator(f) 70 | return f 71 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support and Help 2 | 3 | Need help getting started or using a project? Here's how. 4 | 5 | ## How to get help 6 | 7 | Generally, we do not use GitHub as a support forum, but we recognize that this community is in a nascent state. For any usage questions that are not specific to the project itself, please ask on [Stack Overflow](https://stackoverflow.com) instead. By doing so, you are more likely to quickly solve your problem, and you will allow anyone else with the same question to find the answer. This also allows maintainers to focus on improving the project for others. 8 | 9 | Please seek support in the following ways: 10 | 11 | 1. :book: **Read the documentation and other guides** for the project to see if you can figure it out on your own. These should be located in a root `docs/` directory. If there is an example project, explore that to learn how it works to see if you can answer your question. 12 | 13 | 1. :bulb: **Search for answers and ask questions on [Stack Overflow](https://stackoverflow.com).** This is the most appropriate place for debugging issues specific to your use of the project, or figuring out how to use the project in a specific way. 14 | 15 | 1. :memo: As a **last resort**, you may open an issue on GitHub to ask for help. However, please clearly explain what you are trying to do, and list what you have already attempted to solve the problem. Provide code samples, but **do not** attach your entire project for someone else to debug & **do not** fill your issue details with *excessive* logs. Review our [contributing guidelines](https://github.com/kmbhm1/supabase-pydantic/blob/main/CONTRIBUTING.md). 16 | 17 | ## What NOT to do 18 | 19 | Please **do not** do any of the following: 20 | 21 | 1. 
:x: Do not reach out to the author or contributor on Twitter (or other social media) by tweeting or sending a direct message. 22 | 23 | 1. :x: Do not email the author or contributor. 24 | 25 | 1. :x: Do not open duplicate issues or litter an existing issue with +1's. 26 | 27 | These are not appropriate avenues for seeking help or support with an open-source project. Please follow the guidelines in the previous section. Public questions get public answers, which benefits everyone in the community. ✌️ 28 | 29 | ## Customer Support 30 | 31 | I do not provide any sort of "customer support" for open-source projects. However, I am available for hire. For details on contracting and consulting, visit [LinkedIn](https://www.linkedin.com/in/kevin-b-stl/) or email [@kmbhm1](mailto:kmbhm1@gmail.com). Please do not abuse this privilege. 🙏 32 | 33 | ## Credits 34 | 35 | Written by [@kmbhm1](https://github.com/kmbhm1) based on docs by [@jessesquires](https://github.com/jessesquires). 36 | 37 | **Please feel free to adopt this guide in your own projects. Fork it wholesale or remix it for your needs.** 38 | -------------------------------------------------------------------------------- /tests/unit/core/writers/test_writer_util.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import re 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from supabase_pydantic.core.constants import WriterClassType 8 | from supabase_pydantic.core.writers.utils import ( 9 | generate_unique_filename, 10 | get_base_class_post_script, 11 | get_section_comment, 12 | ) 13 | 14 | 15 | @pytest.mark.unit 16 | @pytest.mark.writers 17 | @pytest.mark.utils 18 | @pytest.mark.parametrize( 19 | 'table_type, class_type, expected_output', 20 | [ 21 | ('VIEW', WriterClassType.PARENT, 'ViewBaseSchemaParent'), 22 | ('VIEW', WriterClassType.BASE, 'ViewBaseSchema'), 23 | ('TABLE', WriterClassType.PARENT, 'BaseSchemaParent'), 24 | ('TABLE', WriterClassType.BASE, 'BaseSchema'), 25 | ], 26 | ) 27 | def test_get_base_class_post_script(table_type, class_type, expected_output): 28 | assert get_base_class_post_script(table_type, class_type) == expected_output 29 | 30 | 31 | @pytest.mark.unit 32 | @pytest.mark.writers 33 | @pytest.mark.utils 34 | def test_generate_unique_filename(): 35 | base_name = 'test' 36 | extension = '.py' 37 | directory = '/fake/directory' 38 | 39 | with patch('os.path.exists', side_effect=lambda x: x.endswith('test.py') or x.endswith('test_1.py')) as _: 40 | unique_filename = generate_unique_filename(base_name, extension, directory) 41 | match = re.search(r'test_(\d{20})\.py$', unique_filename) 42 | assert match is not None, f'Filename does not match the expected pattern: {unique_filename}' 43 | 44 | # Validate the datetime format 45 | # The expected format is YYYYMMDDHHMMSSffffff (year, month, day, hour, minute, second, microsecond), 20 digits in total 46 | datetime_str = match.group(1) 47 | try: 48 | datetime.datetime.strptime(datetime_str, '%Y%m%d%H%M%S%f') 49 | except ValueError: 50 | assert False, 'Datetime format is incorrect' 51 | 52 | 53 | @pytest.mark.unit 54 | @pytest.mark.writers 55 | @pytest.mark.utils 56 | def test_get_section_comment_without_notes(): 57 | comment_title = 'Test Section' 58 | expected_output = '# TEST SECTION' 59 | assert get_section_comment(comment_title) == expected_output 60 | 61 | 62 | @pytest.mark.unit 63 | @pytest.mark.writers 64 | @pytest.mark.utils 65 | def test_get_section_comment_with_notes(): 66 | comment_title = 'Test Section' 67 | 
notes = ['This is a very long note that should be chunked into multiple lines to fit the width restriction.'] 68 | expected_output = ( 69 | '# TEST SECTION\n# Note: This is a ' 70 | 'very long note that should be chunked into multiple lines to\n' 71 | '# fit the width restriction.' 72 | ) 73 | assert get_section_comment(comment_title, notes) == expected_output 74 | -------------------------------------------------------------------------------- /tests/unit/db/marshalers/abstract/test_base_relationship_marshaler.py: -------------------------------------------------------------------------------- 1 | """Tests for BaseRelationshipMarshaler abstract class.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.db.marshalers.abstract.base_relationship_marshaler import BaseRelationshipMarshaler 6 | 7 | 8 | class TestBaseRelationshipMarshaler: 9 | """Tests for BaseRelationshipMarshaler class.""" 10 | 11 | def test_base_relationship_marshaler_initialization(self): 12 | """Test that BaseRelationshipMarshaler can't be instantiated directly.""" 13 | with pytest.raises(TypeError): 14 | BaseRelationshipMarshaler() 15 | 16 | def test_abstract_determine_relationship_type_method(self): 17 | """Test that a subclass cannot be instantiated while determine_relationship_type is left unimplemented.""" 18 | 19 | # Create a minimal subclass that doesn't override the abstract methods 20 | class IncompleteMarshaler(BaseRelationshipMarshaler): 21 | pass 22 | 23 | # Should raise TypeError when instantiating because abstract methods aren't implemented 24 | with pytest.raises(TypeError): 25 | IncompleteMarshaler() 26 | 27 | # Create a minimal implementation that implements only one method 28 | class PartialMarshaler(BaseRelationshipMarshaler): 29 | def analyze_table_relationships(self, tables): 30 | pass 31 | 32 | # Should still raise TypeError since not all abstract methods are implemented 33 | with pytest.raises(TypeError): 34 | PartialMarshaler() 35 | 36 | def test_abstract_analyze_table_relationships_method(self): 37 | """Test that a subclass cannot be instantiated while analyze_table_relationships is left unimplemented.""" 38 | 39 | # Create a minimal implementation that implements only one method 40 | class PartialMarshaler(BaseRelationshipMarshaler): 41 | def determine_relationship_type(self, source_table, target_table, source_column, target_column): 42 | pass 43 | 44 | # Should still raise TypeError since not all abstract methods are implemented 45 | with pytest.raises(TypeError): 46 | PartialMarshaler() 47 | 48 | def test_complete_implementation(self): 49 | """Test that a complete implementation can be instantiated.""" 50 | 51 | class CompleteMarshaler(BaseRelationshipMarshaler): 52 | def determine_relationship_type(self, source_table, target_table, source_column, target_column): 53 | pass 54 | 55 | def analyze_table_relationships(self, tables): 56 | pass 57 | 58 | # Should not raise any exception 59 | marshaler = CompleteMarshaler() 60 | assert isinstance(marshaler, BaseRelationshipMarshaler) 61 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/io.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | 5 | # Get Logger 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def clean_directory(directory: str) -> None: 10 | """Remove all files & directories in the specified directory.""" 11 | if not (os.path.isdir(directory) and len(os.listdir(directory)) == 0): 12 | 
for file in os.listdir(directory): 13 | file_path = os.path.join(directory, file) 14 | try: 15 | if os.path.isfile(file_path): 16 | os.unlink(file_path) 17 | elif os.path.isdir(file_path): 18 | shutil.rmtree(file_path) 19 | except Exception as e: 20 | logger.error(f'An error occurred while deleting {file_path}.') 21 | logger.error(str(e)) 22 | 23 | shutil.rmtree(directory) 24 | 25 | 26 | def clean_directories(directories: dict[str, str | None]) -> None: 27 | """Remove all files & directories in the specified directories.""" 28 | for k, d in directories.items(): 29 | if k == 'default' or d is None: 30 | continue 31 | 32 | logger.info(f'Checking for directory: {d}') 33 | if not os.path.isdir(d): 34 | logger.warning(f'Directory "{d}" does not exist. Skipping ...') 35 | continue 36 | 37 | logger.info(f'Directory found. Removing {d} and files...') 38 | clean_directory(d) 39 | logger.info(f'Directory {d} removed.') 40 | 41 | logger.info(f'Removing default directory {directories["default"]} ...') 42 | assert directories['default'] is not None 43 | clean_directory(directories['default']) 44 | logger.info('Default directory removed.') 45 | 46 | 47 | def get_working_directories( 48 | default_directory: str, frameworks: tuple, auto_create: bool = False 49 | ) -> dict[str, str | None]: 50 | """Get the directories for the generated files.""" 51 | if not os.path.exists(default_directory) or not os.path.isdir(default_directory): 52 | logger.warning(f'Directory {default_directory} does not exist.') 53 | 54 | directories = { 55 | 'default': default_directory, 56 | 'fastapi': None if 'fastapi' not in frameworks else os.path.join(default_directory, 'fastapi'), 57 | } 58 | 59 | if auto_create: 60 | create_directories_if_not_exist(directories) 61 | 62 | return directories 63 | 64 | 65 | def create_directories_if_not_exist(directories: dict[str, str | None]) -> None: 66 | """Create the directories for the generated files if they do not exist.""" 67 | # Check if the directory exists, if not, create it 68 | for _, d in directories.items(): 69 | if d is None: 70 | continue 71 | if not os.path.exists(d) or not os.path.isdir(d): 72 | logger.info(f'Creating directory: {d}') 73 | os.makedirs(d) 74 | -------------------------------------------------------------------------------- /src/supabase_pydantic/utils/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from supabase_pydantic.db.database_type import DatabaseType 4 | from supabase_pydantic.db.type_factory import TypeMapFactory 5 | 6 | # enum helpers 7 | 8 | 9 | def get_enum_member_from_string(cls: Any, value: str) -> Any: 10 | """Get an Enum member from a string value.""" 11 | value_lower = value.lower() 12 | for member in cls: 13 | if member.value == value_lower: 14 | return member 15 | raise ValueError(f"'{value}' is not a valid {cls.__name__}") 16 | 17 | 18 | # type map helpers 19 | 20 | 21 | def adapt_type_map( 22 | db_type: str, 23 | default_type: tuple[str, str | None], 24 | type_map: dict[str, tuple[str, str | None]], 25 | ) -> tuple[str, str | None]: 26 | """Adapt a database data type to a Pydantic and SQLAlchemy type.""" 27 | array_suffix = '[]' 28 | if db_type.endswith(array_suffix): 29 | base_type = db_type[: -len(array_suffix)] 30 | sqlalchemy_type, import_statement = type_map.get(base_type, default_type) 31 | adapted_type = f'ARRAY({sqlalchemy_type})' 32 | import_statement = ( 33 | f'{import_statement}, ARRAY' if import_statement else 'from sqlalchemy.dialects.postgresql import 
ARRAY' 34 | ) 35 | else: 36 | adapted_type, import_statement = type_map.get(db_type, default_type) 37 | 38 | return (adapted_type, import_statement) 39 | 40 | 41 | def get_sqlalchemy_type( 42 | db_type: str, 43 | database_type: DatabaseType = DatabaseType.POSTGRES, 44 | default: tuple[str, str | None] = ('String', 'from sqlalchemy import String'), 45 | ) -> tuple[str, str | None]: 46 | """Get the SQLAlchemy type from the database type.""" 47 | type_map = TypeMapFactory.get_sqlalchemy_type_map(database_type) 48 | return adapt_type_map(db_type, default, type_map) 49 | 50 | 51 | def get_sqlalchemy_v2_type( 52 | db_type: str, 53 | database_type: DatabaseType = DatabaseType.POSTGRES, 54 | default: tuple[str, str | None] = ('String,str', 'from sqlalchemy import String'), 55 | ) -> tuple[str, str, str | None]: 56 | """Get the SQLAlchemy v2 type from the database type.""" 57 | type_map = TypeMapFactory.get_sqlalchemy_v2_type_map(database_type) 58 | both_types, imports = adapt_type_map(db_type, default, type_map) 59 | sql, py = both_types.split(',') 60 | # print(f'get_sqlalchemy_v2_type({db_type}, {database_type}, {default}) -> ({sql}, {py}, {imports})') 61 | return (sql, py, imports) 62 | 63 | 64 | def get_pydantic_type( 65 | db_type: str, 66 | database_type: DatabaseType = DatabaseType.POSTGRES, 67 | default: tuple[str, str | None] = ('Any', 'from typing import Any'), 68 | ) -> tuple[str, str | None]: 69 | """Get the Pydantic type from the database type.""" 70 | type_map = TypeMapFactory.get_pydantic_type_map(database_type) 71 | return adapt_type_map(db_type, default, type_map) 72 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/registrations.py: -------------------------------------------------------------------------------- 1 | """Database component registrations for the factory.""" 2 | 3 | import logging 4 | 5 | from supabase_pydantic.db.connectors.mysql.connector import MySQLConnector 6 | from supabase_pydantic.db.connectors.mysql.schema_reader import MySQLSchemaReader 7 | from supabase_pydantic.db.connectors.postgres.connector import PostgresConnector 8 | from supabase_pydantic.db.connectors.postgres.schema_reader import PostgresSchemaReader 9 | from supabase_pydantic.db.database_type import DatabaseType 10 | from supabase_pydantic.db.factory import DatabaseFactory 11 | from supabase_pydantic.db.marshalers.mysql.column import MySQLColumnMarshaler 12 | from supabase_pydantic.db.marshalers.mysql.constraints import MySQLConstraintMarshaler 13 | from supabase_pydantic.db.marshalers.mysql.relationship import MySQLRelationshipMarshaler 14 | from supabase_pydantic.db.marshalers.mysql.schema import MySQLSchemaMarshaler 15 | from supabase_pydantic.db.marshalers.postgres.column import PostgresColumnMarshaler 16 | from supabase_pydantic.db.marshalers.postgres.constraints import PostgresConstraintMarshaler 17 | from supabase_pydantic.db.marshalers.postgres.relationship import PostgresRelationshipMarshaler 18 | from supabase_pydantic.db.marshalers.postgres.schema import PostgresSchemaMarshaler 19 | 20 | # Get Logger 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def register_database_components() -> None: 25 | """Register all database components with the factory.""" 26 | # Register PostgreSQL components 27 | logger.debug('Registering PostgreSQL components with factory') 28 | DatabaseFactory.register_connector(DatabaseType.POSTGRES, PostgresConnector) 29 | DatabaseFactory.register_schema_reader(DatabaseType.POSTGRES, 
PostgresSchemaReader) 30 | DatabaseFactory.register_column_marshaler(DatabaseType.POSTGRES, PostgresColumnMarshaler) 31 | DatabaseFactory.register_constraint_marshaler(DatabaseType.POSTGRES, PostgresConstraintMarshaler) 32 | DatabaseFactory.register_relationship_marshaler(DatabaseType.POSTGRES, PostgresRelationshipMarshaler) 33 | DatabaseFactory.register_schema_marshaler(DatabaseType.POSTGRES, PostgresSchemaMarshaler) 34 | 35 | # Register MySQL components 36 | logger.debug('Registering MySQL components with factory') 37 | DatabaseFactory.register_connector(DatabaseType.MYSQL, MySQLConnector) 38 | DatabaseFactory.register_schema_reader(DatabaseType.MYSQL, MySQLSchemaReader) 39 | DatabaseFactory.register_column_marshaler(DatabaseType.MYSQL, MySQLColumnMarshaler) 40 | DatabaseFactory.register_constraint_marshaler(DatabaseType.MYSQL, MySQLConstraintMarshaler) 41 | DatabaseFactory.register_relationship_marshaler(DatabaseType.MYSQL, MySQLRelationshipMarshaler) 42 | DatabaseFactory.register_schema_marshaler(DatabaseType.MYSQL, MySQLSchemaMarshaler) 43 | 44 | # TODO: Add other database types as needed 45 | 46 | logger.debug('Database components registered with factory') 47 | -------------------------------------------------------------------------------- /tests/unit/db/graph/test_dependencies.py: -------------------------------------------------------------------------------- 1 | """Tests for dependency graph functions in supabase_pydantic.db.graph.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.db.models import ForeignKeyInfo, TableInfo 6 | from supabase_pydantic.db.graph import ( 7 | build_dependency_graph, 8 | sort_tables_by_in_degree, 9 | topological_sort, 10 | ) 11 | 12 | 13 | @pytest.mark.unit 14 | @pytest.mark.db 15 | @pytest.mark.graph 16 | def test_build_dependency_graph(): 17 | table_a = TableInfo(name='A') 18 | table_b = TableInfo( 19 | name='B', 20 | foreign_keys=[ 21 | ForeignKeyInfo(foreign_table_name='A', constraint_name='foo', column_name='id', foreign_column_name='baz') 22 | ], 23 | ) 24 | table_c = TableInfo( 25 | name='C', 26 | foreign_keys=[ 27 | ForeignKeyInfo(foreign_table_name='B', constraint_name='foo', column_name='id', foreign_column_name='baz') 28 | ], 29 | ) 30 | table_d = TableInfo( 31 | name='D', 32 | foreign_keys=[ 33 | ForeignKeyInfo(foreign_table_name='A', constraint_name='foo', column_name='id', foreign_column_name='baz') 34 | ], 35 | ) 36 | 37 | tables = [table_a, table_b, table_c, table_d] 38 | 39 | graph, in_degree = build_dependency_graph(tables) 40 | 41 | expected_graph = {'A': ['B', 'D'], 'B': ['C']} 42 | 43 | expected_in_degree = {'A': 0, 'B': 1, 'C': 1, 'D': 1} 44 | 45 | assert graph == expected_graph 46 | assert in_degree == expected_in_degree 47 | 48 | 49 | @pytest.mark.unit 50 | @pytest.mark.db 51 | @pytest.mark.graph 52 | def test_topological_sort(): 53 | table_a = TableInfo(name='A') 54 | table_b = TableInfo( 55 | name='B', 56 | foreign_keys=[ 57 | ForeignKeyInfo(foreign_table_name='A', constraint_name='foo', column_name='id', foreign_column_name='baz') 58 | ], 59 | ) 60 | table_c = TableInfo( 61 | name='C', 62 | foreign_keys=[ 63 | ForeignKeyInfo(foreign_table_name='B', constraint_name='foo', column_name='id', foreign_column_name='baz') 64 | ], 65 | ) 66 | table_d = TableInfo( 67 | name='D', 68 | foreign_keys=[ 69 | ForeignKeyInfo(foreign_table_name='A', constraint_name='foo', column_name='id', foreign_column_name='baz') 70 | ], 71 | ) 72 | 73 | tables = [table_a, table_b, table_c, table_d] 74 | 75 | in_degree, graph = 
topological_sort(tables) 76 | 77 | expected_graph = {'A': ['B', 'D'], 'B': ['C']} 78 | 79 | expected_in_degree = {'A': 0, 'B': 1, 'C': 1, 'D': 1} 80 | 81 | assert graph == expected_graph 82 | assert in_degree == expected_in_degree 83 | 84 | 85 | @pytest.mark.unit 86 | @pytest.mark.db 87 | @pytest.mark.graph 88 | def test_sort_tables_by_in_degree(): 89 | in_degree = {'A': 0, 'B': 1, 'C': 2, 'D': 1} 90 | 91 | sorted_tables = sort_tables_by_in_degree(in_degree) 92 | 93 | expected_sorted_tables = ['A', 'B', 'D', 'C'] 94 | 95 | assert sorted_tables == expected_sorted_tables 96 | -------------------------------------------------------------------------------- /src/supabase_pydantic/core/writers/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime, timezone 3 | from pathlib import Path 4 | 5 | from supabase_pydantic.core.constants import BASE_CLASS_POSTFIX, WriterClassType 6 | from supabase_pydantic.utils.strings import chunk_text 7 | 8 | 9 | def get_base_class_post_script(table_type: str, class_type: WriterClassType) -> str: 10 | """Generate the post-script (suffix) appended to a base class name.""" 11 | post: str = 'View' + BASE_CLASS_POSTFIX if table_type == 'VIEW' else BASE_CLASS_POSTFIX 12 | result: str = post + 'Parent' if class_type == WriterClassType.PARENT else post 13 | return result 14 | 15 | 16 | def generate_unique_filename(base_name: str, extension: str, directory: str = '.') -> str: 17 | """Generate a unique filename based on the base name & extension. 18 | 19 | Args: 20 | base_name (str): The base name of the file (without extension) 21 | extension (str): The extension of the file (e.g., 'py', 'json', 'txt') 22 | directory (str): The directory where the file will be saved 23 | 24 | Returns: 25 | str: The unique file name 26 | 27 | """ 28 | extension = extension.lstrip('.') 29 | dt_str = datetime.now(tz=timezone.utc).strftime('%Y%m%d%H%M%S%f') 30 | file_name = f'{base_name}_{dt_str}.{extension}' 31 | 32 | return os.path.join(directory, file_name) 33 | 34 | 35 | def get_section_comment(comment_title: str, notes: list[str] | None = None) -> str: 36 | """Generate a section comment, with any notes wrapped to fit the line width.""" 37 | comment = f'# {comment_title.upper()}' 38 | if notes is not None and len(notes) > 0: 39 | chunked_notes = [chunk_text(n, 70) for n in notes] 40 | for cn in chunked_notes: 41 | cn[0] = f'\n# Note: {cn[0]}' 42 | comment += '\n# '.join(cn) 43 | 44 | return comment 45 | 46 | 47 | def get_latest_filename(file_path: str) -> str: 48 | """Get the latest filename for a given file path.""" 49 | fp = Path(file_path) 50 | base, ext, directory = fp.stem, fp.suffix, str(fp.parent) 51 | latest_file = os.path.join(directory, f'{base}_latest{ext}') 52 | 53 | return latest_file 54 | 55 | 56 | def write_seed_file(seed_data: dict[str, list[list[str]]], file_path: str, overwrite: bool = False) -> list[str]: 57 | """Write seed data to a file.""" 58 | 59 | fp = Path(file_path) 60 | base, ext, directory = fp.stem, fp.suffix, str(fp.parent) 61 | file_paths = [get_latest_filename(file_path)] 62 | 63 | if not overwrite and os.path.exists(file_path): 64 | file_paths.append(generate_unique_filename(base, ext, directory)) 65 | 66 | for fpath in file_paths: 67 | with open(fpath, 'w') as f: 68 | for table_name, data in seed_data.items(): 69 | f.write(f'-- {table_name}\n') 70 | headers = data[0] 71 | data = data[1:] 72 | for row in data: 73 | f.write( 74 | f'INSERT INTO {table_name} ({", ".join(headers)}) VALUES ({", ".join([str(r) for r in 
row])});\n' # noqa: E501 75 | ) 76 | f.write('\n') 77 | 78 | return file_paths 79 | -------------------------------------------------------------------------------- /tests/unit/db/marshalers/postgres/test_postgres_constraint_marshaler.py: -------------------------------------------------------------------------------- 1 | """Tests for PostgreSQL constraint marshaler implementation.""" 2 | 3 | import pytest 4 | from unittest.mock import patch 5 | 6 | from supabase_pydantic.db.marshalers.postgres.constraints import PostgresConstraintMarshaler 7 | 8 | 9 | @pytest.fixture 10 | def marshaler(): 11 | """Fixture to create a PostgresConstraintMarshaler instance.""" 12 | return PostgresConstraintMarshaler() 13 | 14 | 15 | @pytest.mark.unit 16 | @pytest.mark.db 17 | @pytest.mark.marshalers 18 | @pytest.mark.parametrize( 19 | 'constraint_def, expected', 20 | [ 21 | ('FOREIGN KEY (order_id) REFERENCES orders(id)', ('order_id', 'orders', 'id')), 22 | ('FOREIGN KEY (user_id) REFERENCES users(id)', ('user_id', 'users', 'id')), 23 | ('NOT A FOREIGN KEY', None), 24 | ], 25 | ) 26 | def test_parse_foreign_key(marshaler, constraint_def, expected): 27 | """Test parsing foreign key constraint definition.""" 28 | with patch('supabase_pydantic.db.marshalers.constraints.parse_constraint_definition_for_fk', return_value=expected): 29 | result = marshaler.parse_foreign_key(constraint_def) 30 | assert result == expected, f'Failed for {constraint_def}, got {result}, expected {expected}' 31 | 32 | 33 | @pytest.mark.unit 34 | @pytest.mark.db 35 | @pytest.mark.marshalers 36 | @pytest.mark.parametrize( 37 | 'constraint_def, expected', 38 | [ 39 | ('UNIQUE (email)', ['email']), 40 | ('UNIQUE (first_name, last_name)', ['first_name', 'last_name']), 41 | ('NOT A UNIQUE CONSTRAINT', []), 42 | ], 43 | ) 44 | def test_parse_unique_constraint(marshaler, constraint_def, expected): 45 | """Test parsing unique constraint definition.""" 46 | result = marshaler.parse_unique_constraint(constraint_def) 47 | assert result == expected, f'Failed for {constraint_def}, got {result}, expected {expected}' 48 | 49 | 50 | @pytest.mark.unit 51 | @pytest.mark.db 52 | @pytest.mark.marshalers 53 | @pytest.mark.parametrize( 54 | 'constraint_def, expected', 55 | [ 56 | ('CHECK (age > 0)', 'CHECK (age > 0)'), 57 | ('CHECK (price > 0 AND price < 1000)', 'CHECK (price > 0 AND price < 1000)'), 58 | ], 59 | ) 60 | def test_parse_check_constraint(marshaler, constraint_def, expected): 61 | """Test parsing check constraint definition.""" 62 | result = marshaler.parse_check_constraint(constraint_def) 63 | assert result == expected, f'Failed for {constraint_def}, got {result}, expected {expected}' 64 | 65 | 66 | @pytest.mark.unit 67 | @pytest.mark.db 68 | @pytest.mark.marshalers 69 | def test_add_constraints_to_table_details(marshaler): 70 | """Test adding constraints to table details.""" 71 | # Mock data 72 | tables = {} 73 | constraint_data = [] 74 | schema = 'public' 75 | 76 | # Use the correct import path with the alias as used in constraints.py 77 | with patch('supabase_pydantic.db.marshalers.postgres.constraints.add_constraints') as mock_add: 78 | # Call the method 79 | marshaler.add_constraints_to_table_details(tables, constraint_data, schema) 80 | 81 | # Verify the underlying function was called with correct arguments 82 | mock_add.assert_called_once_with(tables, schema, constraint_data) 83 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/abstract/base_schema_reader.py: 
-------------------------------------------------------------------------------- 1 | """Abstract base schema reader for database introspection.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any 5 | 6 | 7 | class BaseSchemaReader(ABC): 8 | """Abstract base class for database schema introspection.""" 9 | 10 | def __init__(self, connector: Any): 11 | """Initialize schema reader with a connector. 12 | 13 | Args: 14 | connector: Database connector instance. 15 | """ 16 | self.connector = connector 17 | 18 | @abstractmethod 19 | def get_schemas(self, conn: Any) -> list[str]: 20 | """Get all schemas in the database. 21 | 22 | Args: 23 | conn: Database connection object. 24 | 25 | Returns: 26 | List of schema names. 27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def get_tables(self, conn: Any, schema: str) -> list[tuple]: 32 | """Get all tables in the specified schema. 33 | 34 | Args: 35 | conn: Database connection object. 36 | schema: Schema name. 37 | 38 | Returns: 39 | List of table information as tuples. 40 | """ 41 | pass 42 | 43 | @abstractmethod 44 | def get_columns(self, conn: Any, schema: str) -> list[tuple]: 45 | """Get all columns in the specified schema. 46 | 47 | Args: 48 | conn: Database connection object. 49 | schema: Schema name. 50 | 51 | Returns: 52 | List of column information as tuples. 53 | """ 54 | pass 55 | 56 | @abstractmethod 57 | def get_foreign_keys(self, conn: Any, schema: str) -> list[tuple]: 58 | """Get all foreign keys in the specified schema. 59 | 60 | Args: 61 | conn: Database connection object. 62 | schema: Schema name. 63 | 64 | Returns: 65 | List of foreign key information as tuples. 66 | """ 67 | pass 68 | 69 | @abstractmethod 70 | def get_constraints(self, conn: Any, schema: str) -> list[tuple]: 71 | """Get all constraints in the specified schema. 72 | 73 | Args: 74 | conn: Database connection object. 75 | schema: Schema name. 76 | 77 | Returns: 78 | List of constraint information as tuples. 79 | """ 80 | pass 81 | 82 | @abstractmethod 83 | def get_user_defined_types(self, conn: Any, schema: str) -> list[tuple]: 84 | """Get all user-defined types in the specified schema. 85 | 86 | Args: 87 | conn: Database connection object. 88 | schema: Schema name. 89 | 90 | Returns: 91 | List of user-defined type information as tuples. 92 | """ 93 | pass 94 | 95 | @abstractmethod 96 | def get_type_mappings(self, conn: Any, schema: str) -> list[tuple]: 97 | """Get type mapping information for the specified schema. 98 | 99 | Args: 100 | conn: Database connection object. 101 | schema: Schema name. 102 | 103 | Returns: 104 | List of type mapping information as tuples. 
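            Example (purely illustrative; each driver defines its own row shape):
            a PostgreSQL reader might return one row per mapped column, e.g.
            ('orders', 'status', 'USER-DEFINED', 'order_status').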
105 | """ 106 | pass 107 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/abstract/base_schema_marshaler.py: -------------------------------------------------------------------------------- 1 | """Abstract base class for schema marshaling.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | from supabase_pydantic.db.marshalers.abstract.base_column_marshaler import BaseColumnMarshaler 6 | from supabase_pydantic.db.marshalers.abstract.base_constraint_marshaler import BaseConstraintMarshaler 7 | from supabase_pydantic.db.marshalers.abstract.base_relationship_marshaler import BaseRelationshipMarshaler 8 | from supabase_pydantic.db.models import TableInfo 9 | 10 | 11 | class BaseSchemaMarshaler(ABC): 12 | """Abstract base class for database schema marshaling.""" 13 | 14 | def __init__( 15 | self, 16 | column_marshaler: BaseColumnMarshaler, 17 | constraint_marshaler: BaseConstraintMarshaler, 18 | relationship_marshaler: BaseRelationshipMarshaler, 19 | ): 20 | """Initialize schema marshaler with component marshalers. 21 | 22 | Args: 23 | column_marshaler: Component for handling column details. 24 | constraint_marshaler: Component for handling constraint details. 25 | relationship_marshaler: Component for handling relationship details. 26 | """ 27 | self.column_marshaler = column_marshaler 28 | self.constraint_marshaler = constraint_marshaler 29 | self.relationship_marshaler = relationship_marshaler 30 | 31 | @abstractmethod 32 | def get_table_details_from_columns(self, column_data: list[tuple]) -> dict[tuple[str, str], TableInfo]: 33 | """Process database-specific column data into TableInfo objects. 34 | 35 | Args: 36 | column_data: Raw column data from database. 37 | 38 | Returns: 39 | Dictionary mapping (schema, table) tuples to TableInfo objects. 40 | """ 41 | pass 42 | 43 | @abstractmethod 44 | def process_foreign_keys(self, tables: dict[tuple[str, str], TableInfo], fk_data: list[tuple]) -> None: 45 | """Add foreign key information to tables. 46 | 47 | Args: 48 | tables: Dictionary of table information objects. 49 | fk_data: Raw foreign key data from database. 50 | """ 51 | pass 52 | 53 | @abstractmethod 54 | def construct_table_info( 55 | self, 56 | table_data: list[tuple], 57 | column_data: list[tuple], 58 | fk_data: list[tuple], 59 | constraint_data: list[tuple], 60 | type_data: list[tuple], 61 | type_mapping_data: list[tuple], 62 | schema: str, 63 | disable_model_prefix_protection: bool = False, 64 | ) -> list[TableInfo]: 65 | """Construct TableInfo objects from database-specific data. 66 | 67 | Args: 68 | table_data: Raw table data from database. 69 | column_data: Raw column data from database. 70 | fk_data: Raw foreign key data from database. 71 | constraint_data: Raw constraint data from database. 72 | type_data: Raw user-defined type data from database. 73 | type_mapping_data: Raw type mapping data from database. 74 | schema: Schema name. 75 | disable_model_prefix_protection: Whether to disable model prefix protection. 76 | 77 | Returns: 78 | List of TableInfo objects. 
79 | """ 80 | pass 81 | -------------------------------------------------------------------------------- /src/supabase_pydantic/core/writers/factories.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 2 | from supabase_pydantic.core.writers.abstract import AbstractFileWriter 3 | from supabase_pydantic.core.writers.pydantic import PydanticFastAPIWriter 4 | from supabase_pydantic.core.writers.sqlalchemy import SqlAlchemyFastAPIWriter 5 | from supabase_pydantic.db.database_type import DatabaseType 6 | from supabase_pydantic.db.models import TableInfo 7 | 8 | 9 | class FileWriterFactory: 10 | @staticmethod 11 | def get_file_writer( 12 | tables: list[TableInfo], 13 | file_path: str, 14 | file_type: OrmType = OrmType.PYDANTIC, 15 | framework_type: FrameWorkType = FrameWorkType.FASTAPI, 16 | add_null_parent_classes: bool = False, 17 | generate_crud_models: bool = True, 18 | generate_enums: bool = True, 19 | disable_model_prefix_protection: bool = False, 20 | singular_names: bool = False, 21 | database_type: DatabaseType = DatabaseType.POSTGRES, 22 | ) -> AbstractFileWriter: 23 | """Get the file writer based on the provided parameters. 24 | 25 | Args: 26 | tables (list[TableInfo]): The list of tables. 27 | file_path (str): The file path. 28 | file_type (OrmType, optional): The ORM type. Defaults to OrmType.PYDANTIC. 29 | framework_type (FrameWorkType, optional): The framework type. Defaults to FrameWorkType.FASTAPI. 30 | add_null_parent_classes (bool, optional): Add null parent classes for base classes. Defaults to False. 31 | generate_crud_models (bool, optional): Generate CRUD models. (i.e., Insert, Update) Defaults to True. 32 | generate_enums (bool, optional): Generate Enum classes for enum columns. Defaults to True. 33 | disable_model_prefix_protection (bool, optional): Disable Pydantic's "model_" prefix protection. Defaults to False. 34 | singular_names (bool, optional): Generate class names in singular form. Defaults to False. 35 | database_type (DatabaseType, optional): The database type. Defaults to DatabaseType.POSTGRES. 36 | 37 | Returns: 38 | The file writer instance. 
39 | """ # noqa: E501 40 | match file_type, framework_type: 41 | case OrmType.SQLALCHEMY, FrameWorkType.FASTAPI: 42 | return SqlAlchemyFastAPIWriter( 43 | tables, 44 | file_path, 45 | add_null_parent_classes=add_null_parent_classes, 46 | singular_names=singular_names, 47 | database_type=database_type, 48 | ) 49 | case OrmType.PYDANTIC, FrameWorkType.FASTAPI: 50 | return PydanticFastAPIWriter( 51 | tables, 52 | file_path, 53 | add_null_parent_classes=add_null_parent_classes, 54 | generate_crud_models=generate_crud_models, 55 | generate_enums=generate_enums, 56 | disable_model_prefix_protection=disable_model_prefix_protection, 57 | singular_names=singular_names, 58 | database_type=database_type, 59 | ) 60 | case _: 61 | raise ValueError(f'Unsupported file type and framework: {file_type}, {framework_type}') 62 | -------------------------------------------------------------------------------- /src/supabase_pydantic/core/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 5 | from supabase_pydantic.db.database_type import DatabaseType 6 | 7 | 8 | @dataclass 9 | class WriterConfig: 10 | file_type: OrmType 11 | framework_type: FrameWorkType 12 | filename: str 13 | directory: str 14 | enabled: bool 15 | 16 | def ext(self) -> str: 17 | """Get the file extension based on the file name.""" 18 | return self.filename.split('.')[-1] 19 | 20 | def name(self) -> str: 21 | """Get the file name without the extension.""" 22 | return self.filename.split('.')[0] 23 | 24 | def fpath(self) -> str: 25 | """Get the full file path.""" 26 | return os.path.join(self.directory, self.filename) 27 | 28 | def to_dict(self) -> dict[str, str]: 29 | """Convert the WriterConfig object to a dictionary.""" 30 | return { 31 | 'file_type': str(self.file_type), 32 | 'framework_type': str(self.framework_type), 33 | 'filename': self.filename, 34 | 'directory': self.directory, 35 | 'enabled': str(self.enabled), 36 | } 37 | 38 | 39 | # job helpers 40 | def get_standard_jobs( 41 | models: tuple[str], frameworks: tuple[str], dirs: dict[str, str | None], schemas: tuple[str, ...] = ('public',) 42 | ) -> dict[str, dict[str, WriterConfig]]: 43 | """Get the standard jobs for the writer.""" 44 | 45 | jobs: dict[str, dict[str, WriterConfig]] = {} 46 | 47 | for schema in schemas: 48 | # Set the file names 49 | pydantic_fname, sqlalchemy_fname = f'schema_{schema}.py', f'database_{schema}.py' 50 | 51 | # Add the jobs 52 | if 'fastapi' in dirs and dirs['fastapi'] is not None: 53 | jobs[schema] = { 54 | 'Pydantic': WriterConfig( 55 | file_type=OrmType.PYDANTIC, 56 | framework_type=FrameWorkType.FASTAPI, 57 | filename=pydantic_fname, 58 | directory=dirs['fastapi'], 59 | enabled='pydantic' in models and 'fastapi' in frameworks, 60 | ), 61 | 'SQLAlchemy': WriterConfig( 62 | file_type=OrmType.SQLALCHEMY, 63 | framework_type=FrameWorkType.FASTAPI, 64 | filename=sqlalchemy_fname, 65 | directory=dirs['fastapi'], 66 | enabled='sqlalchemy' in models and 'fastapi' in frameworks, 67 | ), 68 | } 69 | 70 | return jobs 71 | 72 | 73 | def local_default_env_configuration(db_type: DatabaseType = DatabaseType.POSTGRES) -> dict[str, str | None]: 74 | """Get the environment variables for a local connection. 
75 | 76 | Args: 77 | db_type: Database type (POSTGRES or MYSQL) 78 | 79 | Returns: 80 | Dict of environment variables with default values for local connection 81 | """ 82 | if db_type == DatabaseType.MYSQL: 83 | return { 84 | 'DB_NAME': 'mysql', 85 | 'DB_USER': 'root', 86 | 'DB_PASS': 'password', 87 | 'DB_HOST': 'localhost', 88 | 'DB_PORT': '3306', 89 | } 90 | else: # default to postgres 91 | return { 92 | 'DB_NAME': 'postgres', 93 | 'DB_USER': 'postgres', 94 | 'DB_PASS': 'postgres', 95 | 'DB_HOST': 'localhost', 96 | 'DB_PORT': '54322', 97 | } 98 | -------------------------------------------------------------------------------- /tests/unit/db/marshalers/test_column.py: -------------------------------------------------------------------------------- 1 | """Tests for column marshaler functions in supabase_pydantic.db.marshalers.column.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.db.marshalers.column import ( 6 | column_name_is_reserved, 7 | column_name_reserved_exceptions, 8 | get_alias, 9 | standardize_column_name, 10 | string_is_reserved, 11 | ) 12 | 13 | 14 | @pytest.mark.unit 15 | @pytest.mark.db 16 | @pytest.mark.marshalers 17 | @pytest.mark.parametrize( 18 | 'value, expected', 19 | [ 20 | ('int', True), # Python built-in 21 | ('print', True), # Python built-in 22 | ('for', True), # Python keyword 23 | ('def', True), # Python keyword 24 | ('username', False), # Not reserved 25 | ('customfield', False), # Not reserved 26 | ('id', True), # Built-in 27 | ], 28 | ) 29 | def test_string_is_reserved(value, expected): 30 | assert string_is_reserved(value) == expected, f'Failed for {value}' 31 | 32 | 33 | @pytest.mark.unit 34 | @pytest.mark.db 35 | @pytest.mark.marshalers 36 | @pytest.mark.parametrize( 37 | 'column_name, expected', 38 | [ 39 | ('int', True), 40 | ('print', True), 41 | ('model_abc', True), 42 | ('username', False), 43 | ('id', True), # Though 'id' is a built-in, check if it's treated as an exception in another test 44 | ], 45 | ) 46 | def test_column_name_is_reserved(column_name, expected): 47 | assert column_name_is_reserved(column_name) == expected, f'Failed for {column_name}' 48 | 49 | 50 | @pytest.mark.unit 51 | @pytest.mark.db 52 | @pytest.mark.marshalers 53 | @pytest.mark.parametrize( 54 | 'column_name, expected', [('id', True), ('ID', True), ('Id', True), ('username', False), ('int', False)] 55 | ) 56 | def test_column_name_reserved_exceptions(column_name, expected): 57 | assert column_name_reserved_exceptions(column_name) == expected, f'Failed for {column_name}' 58 | 59 | 60 | @pytest.mark.unit 61 | @pytest.mark.db 62 | @pytest.mark.marshalers 63 | @pytest.mark.parametrize( 64 | 'column_name, expected', 65 | [ 66 | ('int', 'field_int'), # Reserved keyword 67 | ('print', 'field_print'), # Built-in name 68 | ('model_abc', 'field_model_abc'), # Starts with 'model_' 69 | ('username', 'username'), # Not a reserved name 70 | ('id', 'id'), # Exception, should not be modified 71 | ('credits', 'credits'), # Business term exception, should not be modified 72 | ('copyright', 'copyright'), # Business term exception, should not be modified 73 | ('license', 'license'), # Business term exception, should not be modified 74 | ('help', 'help'), # Business term exception, should not be modified 75 | ('property', 'property'), # Business term exception, should not be modified 76 | ('sum', 'sum'), # Business term exception, should not be modified 77 | ], 78 | ) 79 | def test_standardize_column_name(column_name, expected): 80 | assert standardize_column_name(column_name) == 
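# A usage sketch for get_standard_jobs and local_default_env_configuration
# from core/config.py above (the unit tests later in this dump exercise the
# same calls via supabase_pydantic.utils.config); the directory value is
# hypothetical.
from supabase_pydantic.core.config import get_standard_jobs, local_default_env_configuration

jobs = get_standard_jobs(('pydantic',), ('fastapi',), {'fastapi': './entities/fastapi'})
cfg = jobs['public']['Pydantic']
print(cfg.fpath(), cfg.enabled)  # ./entities/fastapi/schema_public.py True

env = local_default_env_configuration()
print(env['DB_PORT'])  # 54322, the local Supabase Postgres default above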
expected, f'Failed for {column_name}' 81 | 82 | 83 | @pytest.mark.unit 84 | @pytest.mark.db 85 | @pytest.mark.marshalers 86 | @pytest.mark.parametrize( 87 | 'column_name, expected', 88 | [ 89 | ('int', 'int'), # Reserved keyword not an exception 90 | ('model_name', 'model_name'), # Starts with 'model_' 91 | ('id', None), # Exception to reserved keywords 92 | ('username', None), # Not a reserved name 93 | ], 94 | ) 95 | def test_get_alias(column_name, expected): 96 | assert get_alias(column_name) == expected, f'Failed for {column_name}' 97 | -------------------------------------------------------------------------------- /tests/unit/db/utils/test_url_parser.py: -------------------------------------------------------------------------------- 1 | """Tests for the URL parser module.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.db.database_type import DatabaseType 6 | from supabase_pydantic.db.utils.url_parser import detect_database_type, get_database_name_from_url 7 | 8 | 9 | @pytest.mark.unit 10 | @pytest.mark.db 11 | @pytest.mark.utils 12 | class TestDatabaseTypeDetection: 13 | """Tests for database type detection from URLs.""" 14 | 15 | @pytest.mark.parametrize( 16 | 'db_url, expected_type', 17 | [ 18 | ('postgresql://user:pass@localhost:5432/testdb', DatabaseType.POSTGRES), 19 | ('postgres://user:pass@localhost:5432/testdb', DatabaseType.POSTGRES), 20 | ('postgresql://localhost/testdb', DatabaseType.POSTGRES), 21 | ('postgresql://localhost:5432/testdb?sslmode=require', DatabaseType.POSTGRES), 22 | ('mysql://user:pass@localhost:3306/testdb', DatabaseType.MYSQL), 23 | ('mariadb://user:pass@localhost:3306/testdb', DatabaseType.MYSQL), 24 | ('mysql://localhost/testdb', DatabaseType.MYSQL), 25 | ('mysql://localhost:3306/testdb?charset=utf8mb4', DatabaseType.MYSQL), 26 | ], 27 | ) 28 | def test_detect_database_type_valid(self, db_url, expected_type): 29 | """Test detection of database types from valid URLs.""" 30 | result = detect_database_type(db_url) 31 | assert result == expected_type, f'Failed to detect {expected_type} from {db_url}' 32 | 33 | @pytest.mark.parametrize( 34 | 'db_url', 35 | [ 36 | '', 37 | None, 38 | 'sqlite:///path/to/db.sqlite', 39 | 'mongodb://localhost:27017', 40 | 'oracle://user:pass@localhost:1521/testdb', 41 | 'invalid_url', 42 | 'http://example.com', 43 | ], 44 | ) 45 | def test_detect_database_type_invalid(self, db_url): 46 | """Test handling of invalid or unsupported URLs.""" 47 | result = detect_database_type(db_url) 48 | assert result is None, f'Expected None for {db_url}, got {result}' 49 | 50 | 51 | @pytest.mark.unit 52 | @pytest.mark.db 53 | @pytest.mark.utils 54 | class TestDatabaseNameExtraction: 55 | """Tests for extracting database names from URLs.""" 56 | 57 | @pytest.mark.parametrize( 58 | 'db_url, expected_name', 59 | [ 60 | ('postgresql://user:pass@localhost:5432/testdb', 'testdb'), 61 | ('postgres://user:pass@localhost/mydatabase', 'mydatabase'), 62 | ('mysql://user:pass@localhost:3306/production_db', 'production_db'), 63 | ('postgresql://localhost/test-db-name', 'test-db-name'), 64 | ('mysql://localhost/db_with_underscores', 'db_with_underscores'), 65 | ('postgresql://user:pass@localhost:5432/testdb?sslmode=require', 'testdb'), 66 | ], 67 | ) 68 | def test_get_database_name_valid(self, db_url, expected_name): 69 | """Test extracting database names from valid URLs.""" 70 | result = get_database_name_from_url(db_url) 71 | assert result == expected_name, f'Failed to extract db name from {db_url}' 72 | 73 | @pytest.mark.parametrize( 74 | 
'db_url', 75 | [ 76 | '', 77 | None, 78 | 'postgresql://localhost', 79 | 'mysql://localhost:3306', 80 | 'invalid_url', 81 | 'postgresql://', 82 | 'http://example.com', 83 | ], 84 | ) 85 | def test_get_database_name_invalid(self, db_url): 86 | """Test handling of invalid URLs or URLs without database names.""" 87 | result = get_database_name_from_url(db_url) 88 | assert result is None, f'Expected None for {db_url}, got {result}' 89 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Supabase Pydantic 2 | site_url: https://kmbhm1.github.io/supabase-pydantic/ 3 | repo_url: https://github.com/kmbhm1/supabase-pydantic 4 | repo_name: kmbhm1/supabase-pydantic 5 | site_description: Supabase Pydantic is a tool for generating ORM constructs for rapid prototyping and deployment with Supabase. 6 | site_author: KB 7 | 8 | theme: 9 | name: material 10 | logo: assets/otter-solid.svg 11 | palette: 12 | # Palette toggle for light mode 13 | - media: "(prefers-color-scheme: light)" 14 | scheme: default 15 | primary: purple 16 | accent: pink 17 | toggle: 18 | icon: material/brightness-7 19 | name: Switch to dark mode 20 | 21 | # Palette toggle for dark mode 22 | - media: "(prefers-color-scheme: dark)" 23 | scheme: slate 24 | primary: deep purple 25 | accent: amber 26 | toggle: 27 | icon: material/brightness-4 28 | name: Switch to light mode 29 | icon: 30 | repo: fontawesome/brands/github-alt 31 | font: 32 | text: Roboto 33 | features: 34 | - content.code.copy 35 | - content.code.select 36 | - content.code.annotate 37 | - navigation.expand 38 | - navigation.indexes 39 | - navigation.instant 40 | - navigation.instant.progress 41 | - navigation.path 42 | - navigation.tabs 43 | - navigation.top 44 | - toc.follow 45 | - toc.integrate 46 | 47 | nav: 48 | - Getting Started: 49 | - Welcome: getting-started/welcome.md 50 | - Help with sb-pydantic: getting-started/getting-help.md 51 | - getting-started/installation.md 52 | - getting-started/contributing.md 53 | - Security: getting-started/security.md 54 | - Recent Changes: getting-started/changelog.md 55 | - Examples: 56 | - Create a Slack Clone: examples/setup-slack-simple-fastapi.md 57 | - Add seed.sql Data: examples/add-seed-sql-data.md 58 | - Analyze Multiple Schemas: examples/generate-multiple-basemodels.md 59 | - Working with Insert and Update Models: examples/insert-update-models.md 60 | - Generating Enum Models: examples/enum-model-generation.md 61 | - Singular Class Names: examples/singular-class-names.md 62 | - Model Prefix Protection: examples/model-prefix-example.md 63 | - SQLAlchemy ORM Models: examples/sqlalchemy-models.md 64 | - MySQL Database Support: examples/mysql-support.md 65 | - API: 66 | - API Reference: api/api-reference.md 67 | - CLI Reference: api/cli.md 68 | 69 | plugins: 70 | - search 71 | - mkdocstrings 72 | - redirects: 73 | redirect_maps: 74 | 'index.md': 'getting-started/welcome.md' 75 | 'examples/index.md': 'examples/setup-slack-simple-fastapi.md' 76 | 77 | markdown_extensions: 78 | - toc: 79 | permalink: true 80 | toc_depth: 3 81 | - pymdownx.snippets: 82 | check_paths: true 83 | - pymdownx.highlight: 84 | anchor_linenums: true 85 | line_spans: __span 86 | pygments_lang_class: true 87 | - pymdownx.inlinehilite 88 | - pymdownx.snippets 89 | - pymdownx.superfences 90 | - pymdownx.emoji: 91 | emoji_index: !!python/name:material.extensions.emoji.twemoji 92 | emoji_generator: 
!!python/name:material.extensions.emoji.to_svg 93 | - attr_list 94 | - admonition 95 | - pymdownx.details 96 | 97 | extra: 98 | social: 99 | - icon: fontawesome/brands/github 100 | link: https://github.com/kmbhm1/supabase-pydantic 101 | - icon: fontawesome/brands/linkedin 102 | link: https://www.linkedin.com/in/kevin-b-stl/ 103 | generator: false 104 | 105 | copyright: Copyright © 2025 kmbhm1 106 | -------------------------------------------------------------------------------- /tests/unit/utils/config/test_utils_config.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.core.config import WriterConfig 2 | from supabase_pydantic.core.constants import FrameWorkType, OrmType 3 | from supabase_pydantic.utils.config import get_standard_jobs, local_default_env_configuration 4 | 5 | 6 | class TestGetStandardJobs: 7 | def test_get_standard_jobs_with_fastapi_enabled(self): 8 | models = ('pydantic', 'sqlalchemy') 9 | frameworks = ('fastapi',) 10 | dirs = {'fastapi': '/path/to/fastapi'} 11 | schemas = ('public', 'custom') 12 | 13 | result = get_standard_jobs(models, frameworks, dirs, schemas) 14 | 15 | assert len(result) == 2 # Two schemas 16 | assert 'public' in result 17 | assert 'custom' in result 18 | 19 | # Check public schema config 20 | assert 'Pydantic' in result['public'] 21 | assert 'SQLAlchemy' in result['public'] 22 | 23 | # Verify Pydantic config 24 | pydantic_config = result['public']['Pydantic'] 25 | assert isinstance(pydantic_config, WriterConfig) 26 | assert pydantic_config.file_type == OrmType.PYDANTIC 27 | assert pydantic_config.framework_type == FrameWorkType.FASTAPI 28 | assert pydantic_config.filename == 'schema_public.py' 29 | assert pydantic_config.directory == '/path/to/fastapi' 30 | assert pydantic_config.enabled is True 31 | 32 | # Verify SQLAlchemy config 33 | sqlalchemy_config = result['public']['SQLAlchemy'] 34 | assert isinstance(sqlalchemy_config, WriterConfig) 35 | assert sqlalchemy_config.file_type == OrmType.SQLALCHEMY 36 | assert sqlalchemy_config.framework_type == FrameWorkType.FASTAPI 37 | assert sqlalchemy_config.filename == 'database_public.py' 38 | assert sqlalchemy_config.directory == '/path/to/fastapi' 39 | assert sqlalchemy_config.enabled is True 40 | 41 | # Check custom schema has correct filename 42 | assert result['custom']['Pydantic'].filename == 'schema_custom.py' 43 | assert result['custom']['SQLAlchemy'].filename == 'database_custom.py' 44 | 45 | def test_get_standard_jobs_with_partial_models(self): 46 | models = ('pydantic',) # Only pydantic, no sqlalchemy 47 | frameworks = ('fastapi',) 48 | dirs = {'fastapi': '/path/to/fastapi'} 49 | 50 | result = get_standard_jobs(models, frameworks, dirs) 51 | 52 | assert 'public' in result 53 | assert 'Pydantic' in result['public'] 54 | assert 'SQLAlchemy' in result['public'] 55 | 56 | # Pydantic should be enabled, SQLAlchemy disabled 57 | assert result['public']['Pydantic'].enabled is True 58 | assert result['public']['SQLAlchemy'].enabled is False 59 | 60 | def test_get_standard_jobs_with_no_fastapi_dir(self): 61 | models = ('pydantic', 'sqlalchemy') 62 | frameworks = ('fastapi',) 63 | dirs = {} # No fastapi directory 64 | 65 | result = get_standard_jobs(models, frameworks, dirs) 66 | 67 | assert result == {} # Should return empty dict since no fastapi dir 68 | 69 | def test_get_standard_jobs_with_none_fastapi_dir(self): 70 | models = ('pydantic', 'sqlalchemy') 71 | frameworks = ('fastapi',) 72 | dirs = {'fastapi': None} # None fastapi directory 73 | 
74 | result = get_standard_jobs(models, frameworks, dirs) 75 | 76 | assert result == {} # Should return empty dict since fastapi dir is None 77 | 78 | 79 | class TestLocalDefaultEnvConfiguration: 80 | def test_local_default_env_configuration(self): 81 | result = local_default_env_configuration() 82 | 83 | assert isinstance(result, dict) 84 | assert result['DB_NAME'] == 'postgres' 85 | assert result['DB_USER'] == 'postgres' 86 | assert result['DB_PASS'] == 'postgres' 87 | assert result['DB_HOST'] == 'localhost' 88 | assert result['DB_PORT'] == '54322' 89 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/mysql/constraints.py: -------------------------------------------------------------------------------- 1 | """MySQL constraint marshaler implementation.""" 2 | 3 | import logging 4 | import re 5 | from typing import cast 6 | 7 | from supabase_pydantic.db.marshalers.abstract.base_constraint_marshaler import BaseConstraintMarshaler 8 | from supabase_pydantic.db.marshalers.constraints import ( 9 | add_constraints_to_table_details as add_constraints, 10 | ) 11 | from supabase_pydantic.db.marshalers.constraints import ( 12 | parse_constraint_definition_for_fk, 13 | ) 14 | from supabase_pydantic.db.models import TableInfo 15 | 16 | # Get Logger 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class MySQLConstraintMarshaler(BaseConstraintMarshaler): 21 | """MySQL constraint marshaler implementation.""" 22 | 23 | def parse_foreign_key(self, constraint_def: str) -> tuple[str, str, str] | None: 24 | """Parse foreign key definition from MySQL-specific syntax. 25 | 26 | Args: 27 | constraint_def: Constraint definition string. 28 | 29 | Returns: 30 | Tuple containing (source_column, target_table, target_column) or None. 31 | """ 32 | # Try to use the common parser first 33 | result = parse_constraint_definition_for_fk(constraint_def) 34 | if result: 35 | return cast(tuple[str, str, str], result) 36 | 37 | # MySQL-specific parsing for FOREIGN KEY constraints 38 | # Example: "FOREIGN KEY (`user_id`) REFERENCES `users` (`id`)" 39 | pattern = r'FOREIGN KEY\s+\(`?([^`)]+)`?\)\s+REFERENCES\s+`?([^`(]+)`?\s*\(`?([^`)]+)`?\)' 40 | match = re.match(pattern, constraint_def, re.IGNORECASE) 41 | if match: 42 | source_column, target_table, target_column = match.groups() 43 | # Strip backticks and any extra whitespace 44 | return (source_column.strip('`').strip(), target_table.strip('`').strip(), target_column.strip('`').strip()) 45 | 46 | return None 47 | 48 | def parse_unique_constraint(self, constraint_def: str) -> list[str]: 49 | """Parse unique constraint from MySQL-specific syntax. 50 | 51 | Args: 52 | constraint_def: Constraint definition string. 53 | 54 | Returns: 55 | List of column names in the unique constraint. 56 | """ 57 | # MySQL UNIQUE constraint format: "UNIQUE KEY `name` (`col1`, `col2`)" 58 | # or simply: "UNIQUE (`col1`, `col2`)" 59 | pattern = r'UNIQUE(?:\s+KEY\s+`?[^`]+`?)?\s*\(([^)]+)\)' 60 | match = re.match(pattern, constraint_def, re.IGNORECASE) 61 | if match: 62 | columns_str = match.group(1) 63 | # Split by comma and remove backticks 64 | columns = [col.strip().strip('`') for col in columns_str.split(',')] 65 | return columns 66 | return [] 67 | 68 | def parse_check_constraint(self, constraint_def: str) -> str: 69 | """Parse check constraint from MySQL-specific syntax. 70 | 71 | Args: 72 | constraint_def: Constraint definition string. 73 | 74 | Returns: 75 | Parsed check constraint expression. 
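# A quick demonstration of the MySQL FOREIGN KEY pattern used by
# parse_foreign_key above, run against the sample constraint from its comment.
import re

pattern = r'FOREIGN KEY\s+\(`?([^`)]+)`?\)\s+REFERENCES\s+`?([^`(]+)`?\s*\(`?([^`)]+)`?\)'
match = re.match(pattern, 'FOREIGN KEY (`user_id`) REFERENCES `users` (`id`)', re.IGNORECASE)
print(match.groups())  # ('user_id', 'users', 'id'), before backtick stripping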
76 | """ 77 | # MySQL CHECK constraint format: "CHECK (`col` > 0)" 78 | pattern = r'CHECK\s+\((.+)\)' 79 | match = re.match(pattern, constraint_def, re.IGNORECASE) 80 | if match: 81 | return match.group(1) 82 | return constraint_def 83 | 84 | def add_constraints_to_table_details( 85 | self, tables: dict[tuple[str, str], TableInfo], constraint_data: list[tuple], schema: str 86 | ) -> None: 87 | """Add constraint information to table details. 88 | 89 | Args: 90 | tables: Dictionary of table information objects. 91 | constraint_data: Raw constraint data from database. 92 | schema: Schema name. 93 | """ 94 | # Use the common implementation 95 | add_constraints(tables, schema, constraint_data) 96 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env* 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Generated Types 163 | entities/ 164 | notes/ 165 | 166 | # VS Code 167 | .vscode/ 168 | .python-version 169 | 170 | # macOS 171 | .DS_Store 172 | -------------------------------------------------------------------------------- /tests/integration/db/test_seed_generation.py: -------------------------------------------------------------------------------- 1 | """Integration tests for seed data generation. 2 | 3 | These tests require an actual database connection with test tables to run. 4 | """ 5 | 6 | import os 7 | 8 | import pytest 9 | from dotenv import load_dotenv 10 | 11 | from supabase_pydantic.db.database_type import DatabaseType 12 | from supabase_pydantic.db.factory import DatabaseFactory 13 | from supabase_pydantic.db.models import PostgresConnectionParams 14 | from supabase_pydantic.db.seed.generator import generate_seed_data 15 | 16 | 17 | # Load environment variables from .env file 18 | load_dotenv() 19 | 20 | 21 | @pytest.fixture 22 | def postgres_params(): 23 | """Get database connection parameters from environment variables.""" 24 | return PostgresConnectionParams( 25 | dbname=os.environ.get('TEST_DB_NAME', 'postgres'), 26 | user=os.environ.get('TEST_DB_USER', 'postgres'), 27 | password=os.environ.get('TEST_DB_PASS', 'postgres'), 28 | host=os.environ.get('TEST_DB_HOST', 'localhost'), 29 | port=os.environ.get('TEST_DB_PORT', '5432'), 30 | ) 31 | 32 | 33 | @pytest.fixture 34 | def postgres_url(): 35 | """Get database URL from environment variables.""" 36 | return os.environ.get( 37 | 'TEST_DB_URL', 38 | 'postgresql://postgres:postgres@localhost:5432/postgres', 39 | ) 40 | 41 | 42 | @pytest.fixture 43 | def schema_reader(postgres_params): 44 | """Get schema reader for the database.""" 45 | try: 46 | reader = DatabaseFactory.create_schema_reader( 47 | DatabaseType.POSTGRES, 48 | connector=DatabaseFactory.create_connector(DatabaseType.POSTGRES, connection_params=postgres_params), 49 | ) 50 | return reader 51 | except Exception as e: 52 | pytest.skip(f'Could not create schema reader: {str(e)}') 53 | 54 | 55 | @pytest.fixture 56 | def tables(schema_reader): 57 | """Get table information from the database using schema reader.""" 58 | try: 59 | tables_dict = {} 60 | schemas = schema_reader.get_schemas() 61 | for schema in schemas: 62 | if schema == 
'public':  # Only get 'public' schema 63 | table_models = schema_reader.get_tables(schema) 64 | tables_dict[schema] = {} 65 | for table in table_models: 66 | table.columns = schema_reader.get_columns(schema, table.name) 67 | table.constraints = schema_reader.get_constraints(schema, table.name) 68 | tables_dict[schema][table.name] = table 69 | return tables_dict 70 | except Exception as e: 71 | pytest.skip(f'Could not get tables: {str(e)}') 72 | 73 | 74 | @pytest.mark.integration 75 | @pytest.mark.db 76 | @pytest.mark.seed 77 | @pytest.mark.skipif( 78 | not os.environ.get('RUN_DB_TESTS'), 79 | reason='Database integration tests are disabled. Set RUN_DB_TESTS=1 to enable.', 80 | ) 81 | def test_generate_seed_data_integration(tables): 82 | """Test generating seed data from actual database schema.""" 83 | if not tables or 'public' not in tables: 84 | pytest.skip('No tables available for testing') 85 | 86 | tables_list = list(tables['public'].values()) 87 | if not tables_list: 88 | pytest.skip('No tables available in public schema for testing') 89 | 90 | # Generate seed data 91 | seed_data = generate_seed_data(tables_list) 92 | 93 | # Basic structure validation 94 | assert isinstance(seed_data, dict) 95 | 96 | # Each table should have data rows 97 | for table_name, data_rows in seed_data.items(): 98 | # First row should be column names 99 | assert len(data_rows) > 0 100 | 101 | # At least one row of data (could be just headers if no rows) 102 | if len(data_rows) > 1: 103 | # Check that all data rows have same number of columns as headers 104 | header_count = len(data_rows[0]) 105 | for row in data_rows[1:]: 106 | assert len(row) == header_count 107 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/graph.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | 4 | from supabase_pydantic.db.constants import RelationType 5 | from supabase_pydantic.db.models import RelationshipInfo, TableInfo 6 | 7 | # Get Logger 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def build_dependency_graph(tables: list[TableInfo]) -> tuple[defaultdict[str, list[str]], dict[str, int]]: 12 | """Build a dependency graph from the tables.""" 13 | graph = defaultdict(list) 14 | in_degree: dict[str, int] = defaultdict(int) 15 | base_tables = [table for table in tables if table.table_type != 'VIEW'] # Exclude views 16 | 17 | # Initialize in-degree to 0 for all base tables 18 | in_degree = {table.name: 0 for table in base_tables} 19 | 20 | # Build the graph and in_degree dictionary 21 | for table in base_tables: 22 | for fk in table.foreign_keys: 23 | # fk.foreign_table_name is the table this table depends on 24 | graph[fk.foreign_table_name].append(table.name) 25 | in_degree[table.name] += 1 26 | 27 | return graph, in_degree 28 | 29 | 30 | def topological_sort(tables: list[TableInfo]) -> tuple[dict, defaultdict]: 31 | """Topologically sort the tables based on foreign key relationships.""" 32 | in_degree = {t.name: 0 for t in tables} 33 | graph = defaultdict(list) 34 | 35 | for t in tables: 36 | for fk in t.foreign_keys: 37 | related_table = fk.foreign_table_name 38 | graph[related_table].append(t.name) 39 | in_degree[t.name] += 1 40 | 41 | return in_degree, graph 42 | 43 | 44 | def sort_tables_by_in_degree(in_degree: dict) -> list[str]: 45 | """Sort tables based on in-degree values.""" 46 | sorted_tables = sorted(in_degree.keys(), key=lambda x: (in_degree[x], x)) 47 | return 
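# A toy illustration (plain dicts, hypothetical table names) of the in-degree
# bookkeeping used by build_dependency_graph above, completed with Kahn's
# algorithm so the resulting order places referenced tables first.
from collections import defaultdict, deque

fk_targets = {'orders': ['users', 'products'], 'users': [], 'products': []}

graph: defaultdict[str, list[str]] = defaultdict(list)
in_degree = {table: 0 for table in fk_targets}
for table, targets in fk_targets.items():
    for target in targets:
        graph[target].append(table)  # edge: referenced table -> dependent table
        in_degree[table] += 1

queue = deque(t for t, d in in_degree.items() if d == 0)
order = []
while queue:
    current = queue.popleft()
    order.append(current)
    for dependent in graph[current]:
        in_degree[dependent] -= 1
        if in_degree[dependent] == 0:
            queue.append(dependent)

print(order)  # ['users', 'products', 'orders']: safe insertion order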
sorted_tables 48 | 49 | 50 | def reorganize_tables_by_relationships(sorted_tables: list[str], relationships: list[RelationshipInfo]) -> list[str]: 51 | """Reorganize the initial sorted list of tables based on relationships.""" 52 | new_sorted_tables = sorted_tables.copy() 53 | for relationship in relationships: 54 | table = relationship.table_name 55 | foreign_table = relationship.related_table_name 56 | relation_type = relationship.relation_type 57 | 58 | # Only consider non-many-to-many relationships 59 | if relation_type != RelationType.MANY_TO_MANY: 60 | # Get the indexes of the table and foreign table 61 | table_index = new_sorted_tables.index(table) 62 | foreign_table_index = new_sorted_tables.index(foreign_table) 63 | 64 | # If the foreign table appears after the dependent table, reorder them 65 | if foreign_table_index > table_index: 66 | # Remove the foreign table and insert it after the dependent table 67 | new_sorted_tables.remove(foreign_table) 68 | new_sorted_tables.insert(table_index, foreign_table) 69 | 70 | return new_sorted_tables 71 | 72 | 73 | def separate_tables_list_by_type(tables: list[TableInfo], table_list: list[str]) -> tuple[list[str], list[str]]: 74 | """Separate tables into base tables and views.""" 75 | base_tables = [] 76 | views = [] 77 | 78 | for t in table_list: 79 | table = next((table for table in tables if table.name == t), None) 80 | if table is None: 81 | logger.warning(f'Could not find table {t}') 82 | continue 83 | 84 | if table.table_type != 'VIEW': 85 | base_tables.append(t) 86 | else: 87 | views.append(t) 88 | 89 | return base_tables, views 90 | 91 | 92 | def sort_tables_for_insert(tables: list[TableInfo]) -> tuple[list[str], list[str]]: 93 | """Sort tables based on foreign key relationships for insertion.""" 94 | in_degree, _ = topological_sort(tables) 95 | initial_sorted_list = sort_tables_by_in_degree(in_degree) 96 | relationships = [r for t in tables for r in t.relationships] 97 | sorted_tables = reorganize_tables_by_relationships(initial_sorted_list, relationships) 98 | 99 | return separate_tables_list_by_type(tables, sorted_tables) 100 | -------------------------------------------------------------------------------- /docs/getting-started/welcome.md: -------------------------------------------------------------------------------- 1 | # Supabase Pydantic 2 | 3 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/supabase-pydantic) 4 | [![Pydantic v2](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/pydantic/pydantic/main/docs/badge/v2.json)](https://pydantic.dev) 5 | ![GitHub License](https://img.shields.io/github/license/kmbhm1/supabase-pydantic) 6 | [![codecov](https://codecov.io/github/kmbhm1/supabase-pydantic/graph/badge.svg?token=PYOJPJTOLM)](https://codecov.io/github/kmbhm1/supabase-pydantic) 7 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/supabase-pydantic) 8 | 9 | 10 | Supabase Pydantic is a Python library that **generates** Pydantic models for Supabase - more models & database support to come :wink:. It is designed to enhance the utility of Supabase as a platform for rapid prototyping and development. 
11 | 12 | ``` bash title="A First Example" hl_lines="1" 13 | $ sb-pydantic gen --type pydantic --framework fastapi --local 14 | 15 | 2023-09-12 14:30:15 - INFO - PostgreSQL connection opened successfully 16 | 2023-09-12 14:30:17 - INFO - Processing schema: public 17 | 2023-09-12 14:30:19 - INFO - PostgreSQL connection closed successfully 18 | 2023-09-12 14:30:20 - INFO - Generating FastAPI Pydantic models... 19 | 2023-09-12 14:30:22 - INFO - FastAPI Pydantic models generated successfully: /path/to/your/project/entities/fastapi/schemas_latest.py 20 | 2023-09-12 14:30:23 - INFO - File formatted successfully: /path/to/your/project/entities/fastapi/schemas_latest.py 21 | ``` 22 | 23 | Some users may find it more convenient to integrate a Makefile command: 24 | 25 | ``` bash title="Makefile" 26 | gen-types: 27 | @echo "Generating FastAPI Pydantic models..." 28 | @sb-pydantic gen --type pydantic --framework fastapi --dir --local 29 | ``` 30 | 31 | ## Why use supabase-pydantic? 32 | 33 | The [supabase-py](https://github.com/supabase-community/supabase-py) library currently lacks an automated way to enhance type safety and data validation in Python, an essential feature that is readily available in the JavaScript/TypeScript library, as outlined in [Supabase's documentation](https://supabase.com/docs/reference/javascript/typescript-support#generating-typescript-types). 34 | 35 | [Pydantic](https://docs.pydantic.dev/latest/), a popular library for data validation and settings management in Python, leverages type annotations to validate data, significantly enhancing the robustness and clarity of code. While both TypeScript and Pydantic aim to improve type safety and structure within dynamic programming environments, a key distinction is that Pydantic validates data at runtime. 36 | 37 | This package aims to bridge the gap, delivering an enriched experience for Python developers with Supabase. It not only replicates the functionalities of the TypeScript library but is also finely tuned to the diverse tools and landscape of the Python community. Moreover, it's designed to turbocharge your workflow, making rapid prototyping not just faster but also a delightful adventure in development. 38 | 39 | ## The Bennies 40 | 41 | 1. ***Automated* enhanced type safety**: supabase-pydantic generates Pydantic models for Supabase without the hassle of manual setup, ensuring that your data is validated and structured correctly, requiring less oversight. This feature significantly enhances the robustness of deployment pipelines and clarity of your code, making it easier for you to focus on building your application. A sketch of a generated model follows this list. 42 | 43 | 2. **Rapid prototyping**: With supabase-pydantic, you can quickly generate FastAPI Pydantic models, saving you time and effort in setting up your project's schemas, models, etc. 44 | 45 | 3. **Seamless integration**: supabase-pydantic is intended to integrate seamlessly with development and automated pipelines, allowing you to leverage the power of Supabase while enjoying the benefits of Pydantic's data validation and settings management in a streamlined fashion. 46 | 47 | 4. **Pythonic experience**: Built specifically for the Python community, supabase-pydantic is finely tuned to the tools and landscape of Python development, providing a familiar and efficient workflow. 48 | 49 | 5. **Comprehensive documentation**: supabase-pydantic comes with comprehensive documentation, making it easy to get started and explore its features. 
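Here is a sketch of what a generated model can look like (the table and field names are hypothetical; the exact output depends on your schema, and CHECK length constraints become Pydantic `StringConstraints`):

``` python title="Generated model (sketch)"
from typing import Annotated

from pydantic import BaseModel, Field, StringConstraints


class User(BaseModel):
    country_code: Annotated[str, StringConstraints(min_length=2, max_length=2)]
    username: Annotated[str, StringConstraints(min_length=4, max_length=20)]
    description: Annotated[str, StringConstraints(max_length=100)] | None = Field(default=None)


User(country_code='US', username='alice')  # validated at runtime by Pydantic
```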
50 | 51 | -------------------------------------------------------------------------------- /tests/unit/core/writers/test_pydantic_constraints.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from supabase_pydantic.db.models import ColumnInfo, TableInfo 4 | from supabase_pydantic.core.writers.pydantic import PydanticFastAPIClassWriter 5 | 6 | 7 | @pytest.fixture 8 | def table_info(): 9 | """Return a TableInfo instance with constrained columns for testing.""" 10 | return TableInfo( 11 | name='User', 12 | schema='public', 13 | columns=[ 14 | ColumnInfo( 15 | name='country_code', 16 | post_gres_datatype='text', 17 | is_nullable=False, 18 | datatype='str', 19 | constraint_definition='CHECK (length(country_code) = 2)', 20 | ), 21 | ColumnInfo( 22 | name='username', 23 | post_gres_datatype='text', 24 | is_nullable=False, 25 | datatype='str', 26 | constraint_definition='CHECK (length(username) >= 4 AND length(username) <= 20)', 27 | ), 28 | ColumnInfo( 29 | name='description', 30 | post_gres_datatype='text', 31 | is_nullable=True, 32 | datatype='str', 33 | constraint_definition='CHECK (length(description) <= 100)', 34 | ), 35 | ], 36 | foreign_keys=[], 37 | constraints=[], 38 | ) 39 | 40 | 41 | @pytest.fixture 42 | def writer(table_info): 43 | """Return a PydanticFastAPIClassWriter instance for testing.""" 44 | return PydanticFastAPIClassWriter(table_info) 45 | 46 | 47 | @pytest.mark.unit 48 | @pytest.mark.writers 49 | @pytest.mark.pydantic 50 | def test_write_column_with_length_constraints(writer): 51 | """Test that write_column correctly handles length constraints for text fields.""" 52 | # Test fixed length (country_code) 53 | country_code_col = writer.write_column(writer.table.columns[0]) 54 | assert "country_code: Annotated[str, StringConstraints(**{'min_length': 2, 'max_length': 2})]" in country_code_col 55 | 56 | # Test min and max length (username) 57 | username_col = writer.write_column(writer.table.columns[1]) 58 | assert "username: Annotated[str, StringConstraints(**{'min_length': 4, 'max_length': 20})]" in username_col 59 | 60 | # Test max length with nullable (description) 61 | description_col = writer.write_column(writer.table.columns[2]) 62 | assert ( 63 | "description: Annotated[str, StringConstraints(**{'max_length': 100})] | None = Field(default=None)" 64 | in description_col 65 | ) 66 | 67 | # Test min and max length constraints 68 | table = TableInfo( 69 | name='User', 70 | schema='public', 71 | columns=[ 72 | ColumnInfo( 73 | name='description', 74 | post_gres_datatype='text', 75 | is_nullable=True, 76 | datatype='str', 77 | constraint_definition='CHECK (length(description) >= 10 AND length(description) <= 100)', 78 | ) 79 | ], 80 | ) 81 | writer = PydanticFastAPIClassWriter(table) 82 | result = writer.write_column(table.columns[0]) 83 | expected = "description: Annotated[str, StringConstraints(**{'min_length': 10, 'max_length': 100})] | None = Field(default=None)" 84 | assert result == expected 85 | 86 | # Test only max length constraint 87 | table.columns[0].constraint_definition = 'CHECK (length(description) <= 50)' 88 | result = writer.write_column(table.columns[0]) 89 | expected = "description: Annotated[str, StringConstraints(**{'max_length': 50})] | None = Field(default=None)" 90 | assert result == expected 91 | 92 | # Test only min length constraint 93 | table.columns[0].constraint_definition = 'CHECK (length(description) >= 5)' 94 | result = writer.write_column(table.columns[0]) 95 | expected = 
"description: Annotated[str, StringConstraints(**{'min_length': 5})] | None = Field(default=None)" 96 | assert result == expected 97 | 98 | # Test no length constraints 99 | table.columns[0].constraint_definition = None 100 | result = writer.write_column(table.columns[0]) 101 | expected = 'description: str | None = Field(default=None)' 102 | assert result == expected 103 | -------------------------------------------------------------------------------- /tests/integration/db/test_schema_introspection.py: -------------------------------------------------------------------------------- 1 | """Integration tests for database schema introspection. 2 | 3 | These tests require an actual database connection with test tables to run. 4 | """ 5 | 6 | import os 7 | 8 | import pytest 9 | from dotenv import load_dotenv 10 | 11 | from supabase_pydantic.db.database_type import DatabaseType 12 | from supabase_pydantic.db.factory import DatabaseFactory 13 | from supabase_pydantic.db.models import PostgresConnectionParams 14 | 15 | 16 | # Load environment variables from .env file 17 | load_dotenv() 18 | 19 | 20 | @pytest.fixture 21 | def postgres_params(): 22 | """Get database connection parameters from environment variables.""" 23 | return PostgresConnectionParams( 24 | dbname=os.environ.get('TEST_DB_NAME', 'postgres'), 25 | user=os.environ.get('TEST_DB_USER', 'postgres'), 26 | password=os.environ.get('TEST_DB_PASS', 'postgres'), 27 | host=os.environ.get('TEST_DB_HOST', 'localhost'), 28 | port=os.environ.get('TEST_DB_PORT', '5432'), 29 | ) 30 | 31 | 32 | @pytest.fixture 33 | def postgres_url(): 34 | """Get database URL from environment variables.""" 35 | return os.environ.get( 36 | 'TEST_DB_URL', 37 | 'postgresql://postgres:postgres@localhost:5432/postgres', 38 | ) 39 | 40 | 41 | @pytest.fixture 42 | def schema_reader(postgres_params): 43 | """Create a schema reader for postgres database testing.""" 44 | # Create connector and schema reader using the factory 45 | connector = DatabaseFactory.create_connector(DatabaseType.POSTGRES, connection_params=postgres_params) 46 | return DatabaseFactory.create_schema_reader(DatabaseType.POSTGRES, connector=connector) 47 | 48 | 49 | @pytest.mark.integration 50 | @pytest.mark.db 51 | @pytest.mark.connection 52 | @pytest.mark.construction 53 | @pytest.mark.skipif( 54 | not os.environ.get('RUN_DB_TESTS'), 55 | reason='Database integration tests are disabled. 
Set RUN_DB_TESTS=1 to enable.', 56 | ) 57 | def test_read_schema_integration(schema_reader): 58 | """Test reading schema information using local connection parameters.""" 59 | try: 60 | schemas = schema_reader.get_schemas() 61 | assert isinstance(schemas, list) 62 | assert len(schemas) > 0 63 | 64 | # Get tables for the first schema 65 | schema = schemas[0] # Usually 'public' 66 | tables = schema_reader.get_tables(schema) 67 | assert isinstance(tables, list) 68 | 69 | # If there are tables, check their structure 70 | if tables: 71 | table = tables[0] 72 | assert table.name is not None 73 | 74 | # Get columns for the first table 75 | columns = schema_reader.get_columns(schema, table.name) 76 | assert isinstance(columns, list) 77 | 78 | # Check column structure if columns exist 79 | if columns: 80 | column = columns[0] 81 | assert column.name is not None 82 | assert column.data_type is not None 83 | except Exception as e: 84 | pytest.skip(f'Could not read schema: {str(e)}') 85 | 86 | 87 | @pytest.mark.integration 88 | @pytest.mark.db 89 | @pytest.mark.connection 90 | @pytest.mark.construction 91 | @pytest.mark.skipif( 92 | not os.environ.get('RUN_DB_TESTS'), 93 | reason='Database integration tests are disabled. Set RUN_DB_TESTS=1 to enable.', 94 | ) 95 | def test_read_schema_from_url_integration(postgres_url): 96 | """Test reading schema information using a database URL.""" 97 | try: 98 | connector = DatabaseFactory.create_connector(DatabaseType.POSTGRES, connection_params={'db_url': postgres_url}) 99 | reader = DatabaseFactory.create_schema_reader(DatabaseType.POSTGRES, connector=connector) 100 | 101 | schemas = reader.get_schemas() 102 | assert isinstance(schemas, list) 103 | assert len(schemas) > 0 104 | 105 | # Get tables for the first schema 106 | schema = schemas[0] # Usually 'public' 107 | tables = reader.get_tables(schema) 108 | assert isinstance(tables, list) 109 | 110 | # If there are tables, check their structure 111 | if tables: 112 | table = tables[0] 113 | assert table.name is not None 114 | except Exception as e: 115 | pytest.skip(f'Could not read schema: {str(e)}') 116 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/constraints.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any 3 | 4 | from supabase_pydantic.db.marshalers.column import standardize_column_name 5 | from supabase_pydantic.db.models import ConstraintInfo 6 | 7 | 8 | def parse_constraint_definition_for_fk(constraint_definition: str) -> tuple[str, str, str] | None: 9 | """Parse the foreign key definition from the constraint.""" 10 | match = re.match(r'FOREIGN KEY \(([^)]+)\) REFERENCES (\S+)\(([^)]+)\)', constraint_definition) 11 | if match: 12 | column_name = match.group(1) 13 | foreign_table_name = match.group(2) 14 | foreign_column_name = match.group(3) 15 | 16 | return column_name, foreign_table_name, foreign_column_name 17 | return None 18 | 19 | 20 | def add_constraints_to_table_details(tables: dict, schema: str, constraints: list) -> None: 21 | """Add constraints to the table details.""" 22 | for row in constraints: 23 | (constraint_name, table_name, columns, constraint_type, constraint_definition) = row 24 | 25 | # Remove schema from the beginning of table_name if present 26 | if table_name.startswith(f'{schema}.'): 27 | table_name = table_name[len(schema) + 1 :] # Remove schema and the dot # noqa: E203 28 | table_name = table_name.lstrip('.') # Remove any leading 
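# A quick demonstration of parse_constraint_definition_for_fk's pattern above,
# using an illustrative Postgres-style constraint definition.
import re

print(
    re.match(
        r'FOREIGN KEY \(([^)]+)\) REFERENCES (\S+)\(([^)]+)\)',
        'FOREIGN KEY (user_id) REFERENCES users(id)',
    ).groups()
)  # ('user_id', 'users', 'id')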
dots 29 | table_key = (schema, table_name) 30 | 31 | # Create the constraint and add it to the table 32 | if table_key in tables: 33 | constraint = ConstraintInfo( 34 | constraint_name=constraint_name, 35 | columns=[standardize_column_name(c) or str(c) for c in columns], 36 | raw_constraint_type=constraint_type, 37 | constraint_definition=constraint_definition, 38 | ) 39 | tables[table_key].add_constraint(constraint) 40 | 41 | 42 | def get_unique_columns_from_constraints(constraint: ConstraintInfo) -> list[str | Any]: 43 | """Get unique columns from constraints.""" 44 | unique_columns = [] 45 | if constraint.constraint_type() == 'UNIQUE': 46 | match = re.match(r'UNIQUE \(([^)]+)\)', constraint.constraint_definition) 47 | if match: 48 | columns = match.group(1).split(',') 49 | unique_columns = [c.strip() for c in columns] 50 | return unique_columns 51 | 52 | 53 | def update_column_constraint_definitions(tables: dict) -> None: 54 | """Update columns with their CHECK constraint definitions.""" 55 | for table in tables.values(): 56 | if table.columns is None or len(table.columns) == 0: 57 | continue 58 | if table.constraints is None or len(table.constraints) == 0: 59 | continue 60 | 61 | # iterate through columns and constraints 62 | for column in table.columns: 63 | for constraint in table.constraints: 64 | # Only process CHECK constraints that affect this column 65 | if constraint.constraint_type() == 'CHECK' and len(constraint.columns) == 1: 66 | if column.name == constraint.columns[0]: 67 | column.constraint_definition = constraint.constraint_definition 68 | 69 | 70 | def update_columns_with_constraints(tables: dict) -> None: 71 | """Update columns with constraints.""" 72 | for table in tables.values(): 73 | if table.columns is None or len(table.columns) == 0: 74 | continue 75 | if table.constraints is None or len(table.constraints) == 0: 76 | continue 77 | 78 | # iterate through columns and constraints 79 | for column in table.columns: 80 | for constraint in table.constraints: 81 | for col in constraint.columns: 82 | if column.name == col: 83 | if constraint.constraint_type() == 'PRIMARY KEY': 84 | column.primary = True 85 | if constraint.constraint_type() == 'UNIQUE': 86 | column.is_unique = True 87 | column.unique_partners = get_unique_columns_from_constraints(constraint) 88 | if constraint.constraint_type() == 'FOREIGN KEY': 89 | column.is_foreign_key = True 90 | if constraint.constraint_type() == 'CHECK' and len(constraint.columns) == 1: 91 | column.constraint_definition = constraint.constraint_definition 92 | -------------------------------------------------------------------------------- /tests/unit/core/writers/utils/test_writer_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for writer utilities in supabase_pydantic.core.writers.utils.""" 2 | 3 | from unittest.mock import call, mock_open, patch 4 | 5 | import pytest 6 | 7 | from supabase_pydantic.core.writers.utils import get_latest_filename, write_seed_file 8 | 9 | 10 | @pytest.mark.unit 11 | @pytest.mark.writers 12 | @pytest.mark.io 13 | def test_get_latest_filename(): 14 | file_path = 'directory/test_file.py' 15 | latest_file = get_latest_filename(file_path) 16 | assert latest_file == 'directory/test_file_latest.py' 17 | 18 | 19 | @pytest.mark.unit 20 | @pytest.mark.writers 21 | @pytest.mark.io 22 | def test_write_seed_file_no_overwrite_file_does_not_exist(): 23 | seed_data = { 24 | 'table1': [['id', 'name'], [1, 'Alice'], [2, 'Bob']], 25 | 'table2': [['id', 'value'], [1, 
'Value1'], [2, 'Value2']], 26 | } 27 | file_path = 'path/to/seed.sql' 28 | expected = ['path/to/seed_latest.sql'] 29 | 30 | m = mock_open() 31 | 32 | with patch('builtins.open', m): 33 | with patch('os.path.exists', return_value=False): 34 | with patch('pathlib.Path.write_text'): 35 | with patch('pathlib.Path.exists', return_value=False): 36 | result = write_seed_file(seed_data, file_path, overwrite=False) 37 | 38 | assert result == expected 39 | expected_calls = [ 40 | call('-- table1\n'), 41 | call('INSERT INTO table1 (id, name) VALUES (1, Alice);\n'), 42 | call('INSERT INTO table1 (id, name) VALUES (2, Bob);\n'), 43 | call('\n'), 44 | call('-- table2\n'), 45 | call('INSERT INTO table2 (id, value) VALUES (1, Value1);\n'), 46 | call('INSERT INTO table2 (id, value) VALUES (2, Value2);\n'), 47 | call('\n'), 48 | ] 49 | m().write.assert_has_calls(expected_calls, any_order=True) 50 | 51 | 52 | @pytest.mark.unit 53 | @pytest.mark.writers 54 | @pytest.mark.io 55 | def test_write_seed_file_no_overwrite_file_exists(): 56 | seed_data = { 57 | 'table1': [['id', 'name'], [1, 'Alice'], [2, 'Bob']], 58 | 'table2': [['id', 'value'], [1, 'Value1'], [2, 'Value2']], 59 | } 60 | file_path = 'path/to/seed.sql' 61 | expected_file_path = 'path/to/seed_latest.sql' 62 | unique_file_path = 'path/to/seed_unique.sql' 63 | 64 | m = mock_open() 65 | 66 | with patch('builtins.open', m): 67 | with patch('os.path.exists', return_value=True): 68 | with patch('supabase_pydantic.core.writers.utils.generate_unique_filename', return_value=unique_file_path): 69 | result = write_seed_file(seed_data, file_path, overwrite=False) 70 | 71 | assert result == [expected_file_path, unique_file_path] 72 | expected_calls = [ 73 | call('-- table1\n'), 74 | call('INSERT INTO table1 (id, name) VALUES (1, Alice);\n'), 75 | call('INSERT INTO table1 (id, name) VALUES (2, Bob);\n'), 76 | call('\n'), 77 | call('-- table2\n'), 78 | call('INSERT INTO table2 (id, value) VALUES (1, Value1);\n'), 79 | call('INSERT INTO table2 (id, value) VALUES (2, Value2);\n'), 80 | call('\n'), 81 | ] 82 | m().write.assert_has_calls(expected_calls, any_order=False) 83 | 84 | 85 | @pytest.mark.unit 86 | @pytest.mark.writers 87 | @pytest.mark.io 88 | def test_write_seed_file_overwrite(): 89 | seed_data = { 90 | 'table1': [['id', 'name'], [1, 'Alice'], [2, 'Bob']], 91 | 'table2': [['id', 'value'], [1, 'Value1'], [2, 'Value2']], 92 | } 93 | file_path = 'path/to/seed.sql' 94 | expected = 'path/to/seed_latest.sql' 95 | 96 | m = mock_open() 97 | 98 | with patch('builtins.open', m): 99 | with patch('os.path.exists', return_value=True): 100 | with patch('pathlib.Path.write_text'): 101 | with patch('pathlib.Path.exists', return_value=True): 102 | result = write_seed_file(seed_data, file_path, overwrite=True) 103 | 104 | assert result == [expected] 105 | expected_calls = [ 106 | call('-- table1\n'), 107 | call('INSERT INTO table1 (id, name) VALUES (1, Alice);\n'), 108 | call('INSERT INTO table1 (id, name) VALUES (2, Bob);\n'), 109 | call('\n'), 110 | call('-- table2\n'), 111 | call('INSERT INTO table2 (id, value) VALUES (1, Value1);\n'), 112 | call('INSERT INTO table2 (id, value) VALUES (2, Value2);\n'), 113 | call('\n'), 114 | ] 115 | m().write.assert_has_calls(expected_calls, any_order=True) 116 | -------------------------------------------------------------------------------- /tests/unit/cli/common/test_common.py: -------------------------------------------------------------------------------- 1 | """Tests for cli/common.py functions.""" 2 | 3 | from unittest.mock import 
patch, mock_open 4 | from inspect import signature 5 | 6 | from supabase_pydantic.cli.common import check_readiness, load_config, common_options 7 | 8 | 9 | def test_check_readiness_valid_env_vars(): 10 | """Test check_readiness with valid env vars.""" 11 | env_vars = {'DB_HOST': 'localhost', 'DB_PORT': '5432', 'DB_USER': 'postgres'} 12 | with patch('logging.debug') as mock_debug: 13 | result = check_readiness(env_vars) 14 | assert result is True 15 | mock_debug.assert_called_with('All required environment variables are set') 16 | 17 | 18 | def test_check_readiness_empty_env_vars(): 19 | """Test check_readiness with empty env vars.""" 20 | with patch('logging.error') as mock_error: 21 | result = check_readiness({}) 22 | assert result is False 23 | mock_error.assert_called_with('No environment variables provided.') 24 | 25 | 26 | def test_check_readiness_missing_env_vars(): 27 | """Test check_readiness with missing env vars.""" 28 | env_vars = {'DB_HOST': 'localhost', 'DB_PORT': None, 'DB_USER': 'postgres'} 29 | with patch('logging.error') as mock_error: 30 | result = check_readiness(env_vars) 31 | assert result is False 32 | mock_error.assert_called_with( 33 | 'Environment variables not set correctly. DB_PORT is missing. Please set it in .env file.' 34 | ) 35 | 36 | 37 | def test_load_config_from_pyproject(): 38 | """Test loading config from pyproject.toml.""" 39 | mock_config = {'tool': {'supabase_pydantic': {'output_dir': './models', 'schemas': ['public']}}} 40 | 41 | with patch('builtins.open', mock_open()) as mock_file: 42 | with patch('toml.load', return_value=mock_config) as mock_toml: 43 | config = load_config() 44 | 45 | mock_file.assert_called_once_with('pyproject.toml') 46 | mock_toml.assert_called_once() 47 | assert config == {'output_dir': './models', 'schemas': ['public']} 48 | 49 | 50 | def test_load_config_from_custom_toml(): 51 | """Test loading config from custom toml file.""" 52 | mock_config = {'tool': {'supabase_pydantic': {'output_dir': './custom', 'schemas': ['custom']}}} 53 | 54 | with patch('builtins.open', mock_open()) as mock_file: 55 | with patch('toml.load', return_value=mock_config) as mock_toml: 56 | config = load_config('config.toml') 57 | 58 | mock_file.assert_called_once_with('config.toml') 59 | mock_toml.assert_called_once() 60 | assert config == {'output_dir': './custom', 'schemas': ['custom']} 61 | 62 | 63 | def test_load_config_from_yaml(): 64 | """Test loading config from yaml file.""" 65 | mock_config = {'output_dir': './yaml_models', 'schemas': ['yaml']} 66 | 67 | with patch('builtins.open', mock_open()) as mock_file: 68 | with patch('yaml.safe_load', return_value=mock_config) as mock_yaml: 69 | config = load_config('config.yaml') 70 | 71 | mock_file.assert_called_once_with('config.yaml') 72 | mock_yaml.assert_called_once() 73 | assert config == {'output_dir': './yaml_models', 'schemas': ['yaml']} 74 | 75 | 76 | def test_load_config_file_not_found(): 77 | """Test load_config when file is not found.""" 78 | with patch('builtins.open', side_effect=FileNotFoundError): 79 | with patch('click.echo') as mock_echo: 80 | config = load_config('nonexistent.toml') 81 | 82 | mock_echo.assert_called_once_with('Config file not found: nonexistent.toml') 83 | assert config == {} 84 | 85 | 86 | def test_load_config_pyproject_not_found(): 87 | """Test load_config when pyproject.toml is not found.""" 88 | with patch('builtins.open', side_effect=FileNotFoundError): 89 | config = load_config() 90 | assert config == {} 91 | 92 | 93 | def test_common_options(): 94 | 
"""Test common_options decorator.""" 95 | 96 | @common_options 97 | def dummy_command(config=None): 98 | return config 99 | 100 | # In this case common_options just wraps the function with click options 101 | # but doesn't convert it to a click.Command - that would happen when used with @click.command() 102 | # So we can check that the function signature includes our options 103 | sig = signature(dummy_command) 104 | assert 'config' in sig.parameters 105 | 106 | # Test that we can call the function 107 | result = dummy_command(config='test') 108 | assert result == 'test' 109 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/connectors/postgres/schema_reader.py: -------------------------------------------------------------------------------- 1 | """PostgreSQL schema reader implementation.""" 2 | 3 | import logging 4 | from typing import Any 5 | 6 | from supabase_pydantic.db.abstract.base_connector import BaseDBConnector 7 | from supabase_pydantic.db.abstract.base_schema_reader import BaseSchemaReader 8 | from supabase_pydantic.db.drivers.postgres.queries import ( 9 | GET_ALL_PUBLIC_TABLES_AND_COLUMNS, 10 | GET_COLUMN_TO_USER_DEFINED_TYPE_MAPPING, 11 | GET_CONSTRAINTS, 12 | GET_ENUM_TYPES, 13 | GET_TABLE_COLUMN_DETAILS, 14 | SCHEMAS_QUERY, 15 | TABLES_QUERY, 16 | ) 17 | 18 | # Get Logger 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class PostgresSchemaReader(BaseSchemaReader): 23 | """PostgreSQL schema reader implementation.""" 24 | 25 | def __init__(self, connector: BaseDBConnector): 26 | """Initialize with a PostgreSQL connector. 27 | 28 | Args: 29 | connector: PostgreSQL connector instance. 30 | """ 31 | logger.info('PostgresSchemaReader initialized') 32 | super().__init__(connector) 33 | 34 | def get_schemas(self, conn: Any) -> list[str]: 35 | """Get all schemas in the PostgreSQL database. 36 | 37 | Args: 38 | conn: PostgreSQL connection object. 39 | 40 | Returns: 41 | List of schema names. 42 | """ 43 | schemas_result = self.connector.execute_query(conn, SCHEMAS_QUERY) 44 | return [schema[0] for schema in schemas_result] 45 | 46 | def get_tables(self, conn: Any, schema: str) -> list[tuple]: 47 | """Get all tables in the specified schema. 48 | 49 | Args: 50 | conn: PostgreSQL connection object. 51 | schema: Schema name. 52 | 53 | Returns: 54 | List of table information as tuples. 55 | """ 56 | result = self.connector.execute_query(conn, TABLES_QUERY, (schema,)) 57 | return result if isinstance(result, list) else [] 58 | 59 | def get_columns(self, conn: Any, schema: str) -> list[tuple[Any, ...]]: 60 | """Get all columns in the specified schema. 61 | 62 | Args: 63 | conn: PostgreSQL connection object. 64 | schema: Schema name. 65 | 66 | Returns: 67 | List of column information as tuples. 68 | """ 69 | result = self.connector.execute_query(conn, GET_ALL_PUBLIC_TABLES_AND_COLUMNS, (schema,)) 70 | return result if isinstance(result, list) else [] 71 | 72 | def get_foreign_keys(self, conn: Any, schema: str) -> list[tuple[Any, ...]]: 73 | """Get all foreign keys in the specified schema. 74 | 75 | Args: 76 | conn: PostgreSQL connection object. 77 | schema: Schema name. 78 | 79 | Returns: 80 | List of foreign key information as tuples. 81 | """ 82 | result = self.connector.execute_query(conn, GET_TABLE_COLUMN_DETAILS, (schema,)) 83 | return result if isinstance(result, list) else [] 84 | 85 | def get_constraints(self, conn: Any, schema: str) -> list[tuple[Any, ...]]: 86 | """Get all constraints in the specified schema. 
87 | 88 | Args: 89 | conn: PostgreSQL connection object. 90 | schema: Schema name. 91 | 92 | Returns: 93 | List of constraint information as tuples. 94 | """ 95 | result = self.connector.execute_query(conn, GET_CONSTRAINTS, (schema,)) 96 | return result if isinstance(result, list) else [] 97 | 98 | def get_user_defined_types(self, conn: Any, schema: str) -> list[tuple[Any, ...]]: 99 | """Get all user-defined types in the specified schema. 100 | 101 | Args: 102 | conn: PostgreSQL connection object. 103 | schema: Schema name. 104 | 105 | Returns: 106 | List of user-defined type information as tuples. 107 | """ 108 | result = self.connector.execute_query(conn, GET_ENUM_TYPES, (schema,)) 109 | return result if isinstance(result, list) else [] 110 | 111 | def get_type_mappings(self, conn: Any, schema: str) -> list[tuple[Any, ...]]: 112 | """Get type mapping information for the specified schema. 113 | 114 | Args: 115 | conn: PostgreSQL connection object. 116 | schema: Schema name. 117 | 118 | Returns: 119 | List of type mapping information as tuples. 120 | """ 121 | result = self.connector.execute_query(conn, GET_COLUMN_TO_USER_DEFINED_TYPE_MAPPING, (schema,)) 122 | return result if isinstance(result, list) else [] 123 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/mysql/relationship.py: -------------------------------------------------------------------------------- 1 | """MySQL relationship marshaler implementation.""" 2 | 3 | import logging 4 | from collections import defaultdict 5 | 6 | from supabase_pydantic.db.constants import RelationType 7 | from supabase_pydantic.db.marshalers.abstract.base_relationship_marshaler import BaseRelationshipMarshaler 8 | from supabase_pydantic.db.marshalers.relationships import ( 9 | analyze_table_relationships as analyze_relationships, 10 | ) 11 | from supabase_pydantic.db.marshalers.relationships import ( 12 | determine_relationship_type as determine_rel_type, 13 | ) 14 | from supabase_pydantic.db.models import ForeignKeyInfo, RelationshipInfo, TableInfo 15 | 16 | # Get Logger 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class MySQLRelationshipMarshaler(BaseRelationshipMarshaler): 21 | """MySQL relationship marshaler implementation.""" 22 | 23 | def determine_relationship_type( 24 | self, source_table: TableInfo, target_table: TableInfo, source_column: str, target_column: str 25 | ) -> RelationType: 26 | """Determine the type of relationship between two tables. 27 | 28 | Args: 29 | source_table: Source table information. 30 | target_table: Target table information. 31 | source_column: Column name in source table. 32 | target_column: Column name in target table. 33 | 34 | Returns: 35 | RelationType enum value representing the relationship type. 36 | """ 37 | # Create a meaningful constraint name based on tables and columns 38 | constraint_name = f'fk_{source_table.name}_{source_column}_to_{target_table.name}_{target_column}' 39 | 40 | # Create a ForeignKeyInfo object with the proper constraint name 41 | fk = ForeignKeyInfo( 42 | constraint_name=constraint_name, 43 | column_name=source_column, 44 | foreign_table_name=target_table.name, 45 | foreign_column_name=target_column, 46 | foreign_table_schema=target_table.schema, 47 | ) 48 | forward_type, _ = determine_rel_type(source_table, target_table, fk) 49 | return forward_type 50 | 51 | def analyze_table_relationships(self, tables: dict[tuple[str, str], TableInfo]) -> None: 52 | """Analyze relationships between tables. 
53 | 54 | Args: 55 | tables: Dictionary of table information objects. 56 | """ 57 | analyze_relationships(tables) 58 | 59 | def _determine_relationship_types(self, table_relationships: dict[str, list[RelationshipInfo]]) -> None: 60 | """Determine the relationship types based on foreign key constraints. 61 | 62 | This is a post-processing step that looks at all relationships to determine 63 | if they are one-to-one, one-to-many, or many-to-many. 64 | 65 | Args: 66 | table_relationships: Dictionary of table names to lists of RelationshipInfo objects 67 | """ 68 | # First, collect information about foreign keys pointing to each table 69 | table_references = defaultdict(list) 70 | 71 | # For each table's relationships 72 | for table_name, relationships in table_relationships.items(): 73 | for rel in relationships: 74 | # Record this relationship as pointing to the related table 75 | target_key = rel.related_table_name 76 | table_references[target_key].append((table_name, rel)) 77 | 78 | # Now analyze the relationships to determine their types 79 | for table_name, relationships in table_relationships.items(): 80 | for rel in relationships: 81 | # Check if other tables refer to this table 82 | references_to_this = table_references.get(table_name, []) 83 | 84 | if len(references_to_this) > 0: 85 | # This is part of a bidirectional relationship 86 | 87 | # Many-to-many: if this table refers to another table and that table also refers back 88 | if any(ref_table == rel.related_table_name for ref_table, _ in references_to_this): 89 | rel.relation_type = RelationType.MANY_TO_MANY 90 | else: 91 | # One-to-one or one-to-many, depending on constraints 92 | # For simplicity in MySQL, assuming one-to-many as default 93 | rel.relation_type = RelationType.ONE_TO_MANY 94 | else: 95 | # This is a unidirectional relationship 96 | # For MySQL, we default to many-to-one for unidirectional relationships 97 | rel.relation_type = RelationType.MANY_TO_ONE 98 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/marshalers/postgres/schema.py: -------------------------------------------------------------------------------- 1 | from supabase_pydantic.db.database_type import DatabaseType 2 | from supabase_pydantic.db.marshalers.abstract.base_schema_marshaler import BaseSchemaMarshaler 3 | from supabase_pydantic.db.marshalers.schema import ( 4 | add_constraints_to_table_details, 5 | add_foreign_key_info_to_table_details, 6 | add_relationships_to_table_details, 7 | add_user_defined_types_to_tables, 8 | analyze_bridge_tables, 9 | analyze_table_relationships, 10 | get_enum_types_by_schema, 11 | get_table_details_from_columns, 12 | update_column_constraint_definitions, 13 | update_columns_with_constraints, 14 | ) 15 | from supabase_pydantic.db.models import TableInfo 16 | 17 | 18 | class PostgresSchemaMarshaler(BaseSchemaMarshaler): 19 | """PostgreSQL-specific implementation of schema marshaling.""" 20 | 21 | def process_columns(self, column_data: list[tuple]) -> list[tuple]: 22 | """Process column data from database-specific format. 23 | 24 | In PostgreSQL, we can use the column data as is. 25 | """ 26 | return column_data 27 | 28 | def process_foreign_keys(self, fk_data: list[tuple]) -> list[tuple]: 29 | """Process foreign key data from database-specific format. 30 | 31 | In PostgreSQL, we can use the foreign key data as is. 
32 | """ 33 | return fk_data 34 | 35 | def process_constraints(self, constraint_data: list[tuple]) -> list[tuple]: 36 | """Process constraint data from database-specific format. 37 | 38 | In PostgreSQL, we can use the constraint data as is. 39 | """ 40 | return constraint_data 41 | 42 | def get_table_details_from_columns(self, column_data: list[tuple]) -> dict[tuple[str, str], TableInfo]: 43 | """Get table details from column data. 44 | 45 | Args: 46 | column_data: List of column data from database. 47 | 48 | Returns: 49 | Dictionary mapping (schema, table) tuples to TableInfo objects. 50 | """ 51 | # Process columns first 52 | processed_column_data = self.process_columns(column_data) 53 | 54 | # Get the disable_model_prefix_protection value from the instance or default to False 55 | disable_model_prefix_protection = getattr(self, 'disable_model_prefix_protection', False) 56 | 57 | # Call the common implementation 58 | # Get table details using common function 59 | result: dict[tuple[str, str], TableInfo] = get_table_details_from_columns( 60 | processed_column_data, disable_model_prefix_protection, column_marshaler=self.column_marshaler 61 | ) 62 | return result 63 | 64 | def construct_table_info( 65 | self, 66 | table_data: list[tuple], 67 | column_data: list[tuple], 68 | fk_data: list[tuple], 69 | constraint_data: list[tuple], 70 | type_data: list[tuple], 71 | type_mapping_data: list[tuple], 72 | schema: str, 73 | disable_model_prefix_protection: bool = False, 74 | ) -> list[TableInfo]: 75 | """Construct TableInfo objects from database details.""" 76 | # Process the raw data 77 | processed_column_data = self.process_columns(column_data) 78 | processed_fk_data = self.process_foreign_keys(fk_data) 79 | processed_constraint_data = self.process_constraints(constraint_data) 80 | 81 | # Construct table information 82 | enum_types = get_enum_types_by_schema(type_data, schema) 83 | tables = get_table_details_from_columns( 84 | column_details=processed_column_data, 85 | disable_model_prefix_protection=disable_model_prefix_protection, 86 | column_marshaler=self.column_marshaler, 87 | enum_types=enum_types, 88 | ) 89 | add_foreign_key_info_to_table_details(tables, processed_fk_data) 90 | add_constraints_to_table_details(tables, schema, processed_constraint_data) 91 | add_relationships_to_table_details(tables, processed_fk_data) 92 | add_user_defined_types_to_tables(tables, schema, type_data, type_mapping_data, DatabaseType.POSTGRES) 93 | 94 | # Update columns with constraints 95 | update_columns_with_constraints(tables) 96 | update_column_constraint_definitions(tables) 97 | analyze_bridge_tables(tables) 98 | for _ in range(2): 99 | # TODO: update this fn to avoid running twice. 
100 | # print('running analyze_table_relationships ' + str(i)) 101 | analyze_table_relationships(tables) # run twice to ensure all relationships are captured 102 | 103 | return list(tables.values()) 104 | -------------------------------------------------------------------------------- /tests/unit/db/marshalers/mysql/test_mysql_constraint_marshaler.py: -------------------------------------------------------------------------------- 1 | """Tests for MySQL constraint marshaler implementation.""" 2 | 3 | import pytest 4 | from unittest.mock import patch 5 | 6 | from supabase_pydantic.db.marshalers.mysql.constraints import MySQLConstraintMarshaler 7 | from supabase_pydantic.db.models import TableInfo 8 | 9 | 10 | @pytest.fixture 11 | def marshaler(): 12 | """Fixture to create a MySQLConstraintMarshaler instance.""" 13 | return MySQLConstraintMarshaler() 14 | 15 | 16 | @pytest.fixture 17 | def tables_setup(): 18 | """Create a fixture with tables for testing.""" 19 | # Create mock table info 20 | users_table = TableInfo(name='users', schema='test_schema', table_type='BASE TABLE', columns=[], constraints=[]) 21 | 22 | orders_table = TableInfo(name='orders', schema='test_schema', table_type='BASE TABLE', columns=[], constraints=[]) 23 | 24 | # Dictionary with multiple tables 25 | tables = { 26 | ('test_schema', 'users'): users_table, 27 | ('test_schema', 'orders'): orders_table, 28 | } 29 | return tables 30 | 31 | 32 | @pytest.mark.unit 33 | @pytest.mark.db 34 | @pytest.mark.marshalers 35 | @pytest.mark.parametrize( 36 | 'constraint_def, expected', 37 | [ 38 | # Standard format 39 | ('FOREIGN KEY (user_id) REFERENCES users(id)', ('user_id', 'users', 'id')), 40 | # MySQL-specific format with backticks 41 | ('FOREIGN KEY (`user_id`) REFERENCES `users` (`id`)', ('user_id', 'users', 'id')), 42 | # Mixed format (with and without backticks) 43 | ('FOREIGN KEY (`order_id`) REFERENCES orders (`order_id`)', ('order_id', 'orders', 'order_id')), 44 | # Invalid formats 45 | ('INVALID KEY (user_id) REFERENCES users(id)', None), 46 | ('FOREIGN REF `users` (`id`)', None), 47 | ('', None), 48 | ], 49 | ) 50 | def test_parse_foreign_key(marshaler, constraint_def, expected): 51 | """Test MySQL foreign key parsing.""" 52 | result = marshaler.parse_foreign_key(constraint_def) 53 | assert result == expected, f"Failed for '{constraint_def}', got {result}, expected {expected}" 54 | 55 | 56 | @pytest.mark.unit 57 | @pytest.mark.db 58 | @pytest.mark.marshalers 59 | @pytest.mark.parametrize( 60 | 'constraint_def, expected', 61 | [ 62 | # Standard format 63 | ('UNIQUE (email)', ['email']), 64 | # Multiple columns 65 | ('UNIQUE (`email`, `username`)', ['email', 'username']), 66 | # With KEY keyword 67 | ('UNIQUE KEY `uk_email` (`email`)', ['email']), 68 | # Invalid formats 69 | ('UNIQ (`email`)', []), 70 | ('', []), 71 | ], 72 | ) 73 | def test_parse_unique_constraint(marshaler, constraint_def, expected): 74 | """Test MySQL unique constraint parsing.""" 75 | result = marshaler.parse_unique_constraint(constraint_def) 76 | assert result == expected, f"Failed for '{constraint_def}', got {result}, expected {expected}" 77 | 78 | 79 | @pytest.mark.unit 80 | @pytest.mark.db 81 | @pytest.mark.marshalers 82 | @pytest.mark.parametrize( 83 | 'constraint_def, expected', 84 | [ 85 | ('CHECK (`age` > 0)', '`age` > 0'), 86 | ('CHECK (price > 0 AND price < 1000)', 'price > 0 AND price < 1000'), 87 | ("CHECK (`status` IN ('active', 'inactive'))", "`status` IN ('active', 'inactive')"), 88 | ('INVALID CHECK', 'INVALID CHECK'), # Should return 
original if no match 89 | ], 90 | ) 91 | def test_parse_check_constraint(marshaler, constraint_def, expected): 92 | """Test MySQL check constraint parsing.""" 93 | result = marshaler.parse_check_constraint(constraint_def) 94 | assert result == expected, f"Failed for '{constraint_def}', got {result}, expected {expected}" 95 | 96 | 97 | @pytest.mark.unit 98 | @pytest.mark.db 99 | @pytest.mark.marshalers 100 | def test_add_constraints_to_table_details(marshaler, tables_setup): 101 | """Test adding constraints to table details.""" 102 | # Create mock constraint data 103 | constraint_data = [ 104 | {'table_name': 'users', 'constraint_name': 'pk_users', 'constraint_type': 'PRIMARY KEY'}, 105 | {'table_name': 'users', 'constraint_name': 'uk_email', 'constraint_type': 'UNIQUE'}, 106 | ] 107 | 108 | # We'll patch the common implementation since that's what's used 109 | with patch('supabase_pydantic.db.marshalers.mysql.constraints.add_constraints') as mock_add_constraints: 110 | marshaler.add_constraints_to_table_details(tables_setup, constraint_data, 'test_schema') 111 | 112 | # Verify the common implementation was called with correct parameters 113 | mock_add_constraints.assert_called_once_with(tables_setup, 'test_schema', constraint_data) 114 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: read 14 | strategy: 15 | matrix: 16 | python-version: ["3.10"] 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install Poetry 26 | uses: snok/install-poetry@v1 27 | with: 28 | version: 1.7.1 29 | virtualenvs-create: true 30 | virtualenvs-in-project: true 31 | 32 | - name: Install dependencies 33 | run: poetry install 34 | 35 | - name: Lint checking 36 | run: poetry run ruff check -v --output-format=github --exclude=poc . 37 | 38 | - name: Type checking 39 | run: poetry run mypy -v --exclude='^poc/' . 
40 | 41 | release: 42 | needs: [lint] 43 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 44 | runs-on: ubuntu-latest 45 | permissions: 46 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 47 | contents: write 48 | strategy: 49 | matrix: 50 | python-version: ["3.10"] 51 | steps: 52 | #---------------------------------------------- 53 | # Create GitHub App Token 54 | #---------------------------------------------- 55 | - name: Create GitHub App Token 56 | uses: actions/create-github-app-token@v1 57 | id: app-token 58 | with: 59 | app-id: ${{ vars.VERSION_BUMPER_APPID }} 60 | private-key: ${{ secrets.VERSION_BUMPER_SECRET }} 61 | 62 | #---------------------------------------------- 63 | # Checkout Code and Setup Python 64 | #---------------------------------------------- 65 | - name: Checkout 66 | uses: actions/checkout@v4 67 | with: 68 | fetch-depth: 0 # pulls all history and tags for Lerna to detect which packages changed 69 | token: ${{ steps.app-token.outputs.token }} 70 | 71 | - name: Set up Python ${{ matrix.python-version }} 72 | uses: actions/setup-python@v5 73 | with: 74 | python-version: ${{ matrix.python-version }} 75 | 76 | #---------------------------------------------- 77 | # Install Poetry 78 | #---------------------------------------------- 79 | - name: Install Poetry 80 | uses: snok/install-poetry@v1 81 | with: 82 | version: 1.7.1 83 | virtualenvs-create: true 84 | virtualenvs-in-project: true 85 | installer-parallel: true 86 | 87 | #---------------------------------------------- 88 | # Install Dependencies & Build Package 89 | #---------------------------------------------- 90 | - name: Install dependencies 91 | run: poetry install 92 | 93 | - name: Build package 94 | run: poetry build 95 | 96 | - name: List build artifacts 97 | run: ls -la dist/ 98 | 99 | #---------------------------------------------- 100 | # Upload Coverage to Codecov 101 | #---------------------------------------------- 102 | - name: Upload coverage to Codecov 103 | uses: codecov/codecov-action@v4 104 | with: 105 | verbose: true 106 | env: 107 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 108 | 109 | #---------------------------------------------- 110 | # Python Semantic Release 111 | #---------------------------------------------- 112 | - name: Python Semantic Release (Version) 113 | id: release 114 | uses: python-semantic-release/python-semantic-release@v9.8.5 115 | with: 116 | github_token: ${{ steps.app-token.outputs.token }} 117 | 118 | #---------------------------------------------- 119 | # Publish to PyPI & GitHub Release 120 | #---------------------------------------------- 121 | - name: Publish package distributions to PyPI 122 | uses: pypa/gh-action-pypi-publish@release/v1 123 | with: 124 | packages-dir: dist/ 125 | verbose: true 126 | print-hash: true 127 | 128 | - name: Publish package to GitHub Release 129 | uses: python-semantic-release/upload-to-gh-release@v9.8.5 130 | if: steps.release.outputs.released == 'true' 131 | with: 132 | github_token: ${{ steps.app-token.outputs.token }} 133 | tag: ${{ steps.release.outputs.tag }} 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Supabase Pydantic Schemas 2 | 3 | [![PyPI Version](https://img.shields.io/pypi/v/supabase-pydantic)](https://pypi.org/project/supabase-pydantic/) 4 | [![Conda Version](https://img.shields.io/conda/vn/conda-forge/supabase-pydantic) 5 | 
](https://anaconda.org/conda-forge/supabase-pydantic) 6 | [![Pydantic v2](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/pydantic/pydantic/main/docs/badge/v2.json)](https://pydantic.dev) 7 | ![GitHub License](https://img.shields.io/github/license/kmbhm1/supabase-pydantic) 8 | [![codecov](https://codecov.io/github/kmbhm1/supabase-pydantic/graph/badge.svg?token=PYOJPJTOLM)](https://codecov.io/github/kmbhm1/supabase-pydantic) 9 | ![PePy Downloads](https://static.pepy.tech/badge/supabase-pydantic) 10 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/supabase-pydantic) 11 | 12 | A project for generating Pydantic and SQLAlchemy models from Supabase and MySQL databases. This tool bridges the gap between your database schema and your Python code, providing type-safe models for FastAPI and other frameworks. 13 | 14 | Currently, this is ideal for integrating [FastAPI](https://fastapi.tiangolo.com/) with [supabase-py](https://supabase.com/docs/reference/python/introduction) as a primary use-case, but more updates are coming! This project is inspired by the TS [type generating](https://supabase.com/docs/guides/api/rest/generating-types) capabilities of supabase cli. Its aim is to provide a similar experience for Python developers. 15 | 16 | > **📣 NEW (Aug 2025)**: MySQL support! Generate models directly from MySQL databases with the `--db-type mysql` flag. [See the docs](https://kmbhm1.github.io/supabase-pydantic/examples/mysql-support/) 17 | > 18 | > **📣 NEW (Aug 2025)**: SQLAlchemy models with Insert and Update variants for better type safety. [Learn more](https://kmbhm1.github.io/supabase-pydantic/examples/sqlalchemy-models/) 19 | 20 | ## Installation 21 | 22 | ```bash 23 | # Install with pip 24 | $ pip install supabase-pydantic 25 | 26 | # Or with conda 27 | $ conda install -c conda-forge supabase-pydantic 28 | ``` 29 | 30 | ## Configuration 31 | 32 | Create a `.env` file with your database connection details: 33 | 34 | ```bash 35 | DB_NAME= 36 | DB_USER= 37 | DB_PASS= 38 | DB_HOST= 39 | DB_PORT= 40 | ``` 41 | 42 | ## Usage 43 | 44 | ### Generate Pydantic Models 45 | 46 | Example with full command output (using a local connection): 47 | 48 | ```bash 49 | $ sb-pydantic gen --type pydantic --framework fastapi --local 50 | 51 | 2025-08-18 15:47:42 - INFO - supabase_pydantic.db.connection:check_connection:72 - PostGres connection is open. 52 | 2025-08-18 15:47:42 - INFO - supabase_pydantic.db.connection:construct_tables:136 - Processing schema: public 53 | 2025-08-18 15:47:42 - INFO - supabase_pydantic.db.connection:__exit__:105 - PostGres connection is closed. 54 | 2025-08-18 15:47:42 - INFO - supabase_pydantic.cli.commands.gen:gen:239 - Generating Pydantic models... 
55 | 2025-08-18 15:47:42 - INFO - supabase_pydantic.cli.commands.gen:gen:251 - Pydantic models generated successfully for schema 'public': /path/to/your/project/entities/fastapi/schema_public_latest.py 56 | 2025-08-18 15:47:42 - INFO - supabase_pydantic.cli.commands.gen:gen:258 - File formatted successfully: /path/to/your/project/entities/fastapi/schema_public_latest.py 57 | ``` 58 | 59 | ### Common Commands 60 | 61 | **Using a database URL:** 62 | ```bash 63 | $ sb-pydantic gen --type pydantic --framework fastapi --db-url postgresql://postgres:postgres@127.0.0.1:54322/postgres 64 | ``` 65 | 66 | **Generating models for specific schemas:** 67 | ```bash 68 | $ sb-pydantic gen --type pydantic --framework fastapi --local --schema extensions --schema auth 69 | ``` 70 | 71 | **Generating SQLAlchemy models:** 72 | ```bash 73 | $ sb-pydantic gen --type sqlalchemy --local 74 | ``` 75 | 76 | **Using MySQL:** 77 | ```bash 78 | $ sb-pydantic gen --type pydantic --framework fastapi --db-type mysql --db-url mysql://user:pass@localhost:3306/dbname 79 | ``` 80 | 81 | ### Makefile Integration 82 | 83 | ```makefile 84 | # Makefile examples for both Pydantic and SQLAlchemy generation 85 | 86 | gen-pydantic: 87 | @echo "Generating FastAPI Pydantic models..." 88 | @sb-pydantic gen --type pydantic --framework fastapi --dir ./entities/fastapi --local 89 | 90 | gen-sqlalchemy: 91 | @echo "Generating SQLAlchemy ORM models..." 92 | @sb-pydantic gen --type sqlalchemy --dir ./entities/sqlalchemy --local 93 | 94 | # Generate all model types at once 95 | gen-all: gen-pydantic gen-sqlalchemy 96 | @echo "All models generated successfully." 97 | ``` 98 | 99 | For full documentation, visit [our docs site](https://kmbhm1.github.io/supabase-pydantic/). 100 | -------------------------------------------------------------------------------- /docs/getting-started/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We welcome contributions and volunteers for the project! Please read the following guidelines before contributing. 4 | 5 | ## Issues 6 | 7 | Questions, bug reports, and feature requests are welcome as [discussions or issues](https://github.com/kmbhm1/supabase-pydantic/issues/new/choose). Please search the existing issues before opening a new one. **To report a security vulnerability, please see the [security policy](./security.md).** 8 | 9 | To help us resolve your issue, please provide the following information: 10 | 11 | - **Expected behavior:** A clear and concise description of what you expected to happen. 12 | - **Actual behavior:** A clear and concise description of what actually happened. 13 | - **Steps to reproduce:** How can we reproduce the issue? 14 | - **Environment:** Include relevant details like the operating system, Python version, and any other relevant information. 15 | 16 | ## Pull Requests 17 | 18 | We welcome pull requests for bug fixes, new features, and improvements. We aim to provide feedback regularly and will review your pull request as soon as possible. :tada: 19 | 20 | Unless your change is trivial, please [create an issue](https://github.com/kmbhm1/supabase-pydantic/issues/new/choose) before submitting a pull request. This will allow us to discuss the change and ensure it aligns with the project goals. 
21 | 22 | ### Prerequisites 23 | 24 | You will need the following to start developing: 25 | 26 | - Python 3.10+ 27 | - [virtualenv](https://virtualenv.pypa.io/en/latest/), [poetry](https://python-poetry.org/), or [pipenv](https://pipenv.pypa.io/en/latest/) for the development environment. 28 | - [git](https://git-scm.com/) for version control. 29 | - [make](https://www.gnu.org/software/make/) for running the Makefile commands. 30 | 31 | ### Installation & Setup 32 | 33 | Fork the [repository](https://github.com/kmbhm1/supabase-pydantic) on GitHub and clone your fork locally. Then, install the dependencies: 34 | 35 | ``` bash 36 | # Clone your fork and cd into the repo directory 37 | git clone git@github.com:<your-username>/supabase-pydantic.git 38 | cd supabase-pydantic 39 | 40 | # Install the dependencies 41 | # e.g., using poetry 42 | poetry install 43 | 44 | # Setup pre-commit 45 | make pre-commit-setup 46 | ``` 47 | 48 | ### Checkout a New Branch 49 | 50 | Create a new branch for your changes: 51 | 52 | ``` bash 53 | # Create a new branch 54 | git checkout -b my-new-feature 55 | ``` 56 | 57 | ### Run Tests, Linting, Formatting, Type checking, & Pre-Commit Hooks 58 | 59 | Before submitting a pull request, ensure that the tests pass, the code is formatted correctly, and the code passes the linting and type checking checks: 60 | 61 | ``` bash 62 | # Run tests 63 | make test 64 | 65 | # Run linting & formatting 66 | make lint 67 | make format 68 | make check-types 69 | 70 | # Run pre-commit hooks 71 | make pre-commit 72 | ``` 73 | 74 | ### Build Documentation 75 | 76 | If you make changes to the documentation, you can build and serve the documentation locally: 77 | 78 | ``` bash 79 | # Build the documentation 80 | make serve-docs 81 | ``` 82 | 83 | ### Commit and Push your Changes 84 | 85 | Commit your changes, push to your forked repository, and create a pull request. 86 | 87 | Please follow the pull request template and fill in as much information as possible. Link to any relevant issues and include a description of your changes. 88 | 89 | When your pull request is ready for review, add a comment with the message "please review" and we'll take a look as soon as we can. 90 | 91 | ## Code Style & Conventions 92 | 93 | ### Documentation Style 94 | 95 | Documentation is written in Markdown format and follows the [Markdown Style Guide](https://www.cirosantilli.com/markdown-style-guide/). Please ensure that the documentation is clear, concise, and easy to read. API documentation is generated using [mkdocs](https://www.mkdocs.org/) & [mkdocstrings](https://mkdocstrings.github.io/). We follow [google-style](https://google.github.io/styleguide/pyguide.html) docstrings. 96 | 97 | ### Code Documentation 98 | 99 | Please ensure that the code is well-documented and easy to understand when contributing. The following should be documented using proper docstrings: 100 | 101 | - Modules 102 | - Class Definitions 103 | - Function Definitions 104 | - Module-level Variables 105 | 106 | `supabase-pydantic` uses [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) according to [PEP 257](https://www.python.org/dev/peps/pep-0257/). Please see the [example](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for more information. 107 | 108 | Class attributes and function arguments should be documented in the style of "name: description" & include an annotated type and return type when applicable. Feel free to include example code in docstrings, as in the sketch below. 
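For illustration, here is a minimal sketch of the expected docstring style; the function and its behavior are hypothetical and not part of the codebase:

``` python
def to_pascal_case(name: str) -> str:
    """Convert a snake_case identifier to PascalCase.

    Args:
        name: The snake_case string to convert.

    Returns:
        str: The PascalCase version of the input string.

    Example:
        >>> to_pascal_case('user_profile')
        'UserProfile'
    """
    # Hypothetical helper used only to illustrate docstring style.
    # Split on underscores and capitalize each piece.
    return ''.join(part.title() for part in name.split('_'))
```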
109 | -------------------------------------------------------------------------------- /tests/unit/db/marshalers/postgres/test_postgres_relationship_marshaler.py: -------------------------------------------------------------------------------- 1 | """Tests for PostgreSQL relationship marshaler implementation.""" 2 | 3 | import pytest 4 | 5 | from supabase_pydantic.db.constants import RelationType 6 | from supabase_pydantic.db.marshalers.postgres.relationship import PostgresRelationshipMarshaler 7 | from supabase_pydantic.db.models import ColumnInfo, TableInfo 8 | 9 | 10 | @pytest.fixture 11 | def marshaler(): 12 | """Fixture to create a PostgresRelationshipMarshaler instance.""" 13 | return PostgresRelationshipMarshaler() 14 | 15 | 16 | @pytest.mark.unit 17 | @pytest.mark.db 18 | @pytest.mark.marshalers 19 | def test_initialize(): 20 | """Test the initialization of PostgresRelationshipMarshaler.""" 21 | marshaler = PostgresRelationshipMarshaler() 22 | assert isinstance(marshaler, PostgresRelationshipMarshaler) 23 | 24 | 25 | @pytest.mark.unit 26 | @pytest.mark.db 27 | @pytest.mark.marshalers 28 | @pytest.mark.parametrize( 29 | 'table_name, column_name, ref_table_name, ref_column_name, expected', 30 | [ 31 | ('users', 'id', 'orders', 'user_id', RelationType.ONE_TO_MANY), 32 | ('orders', 'user_id', 'users', 'id', RelationType.MANY_TO_ONE), 33 | ('users', 'id', 'user_details', 'user_id', RelationType.ONE_TO_ONE), 34 | ('products', 'id', 'product_tags', 'product_id', RelationType.ONE_TO_MANY), 35 | ], 36 | ) 37 | def test_determine_relationship_type(marshaler, table_name, column_name, ref_table_name, ref_column_name, expected): 38 | """Test determining relationship type between tables.""" 39 | # Create TableInfo objects with appropriate column properties 40 | source_table = TableInfo(name=table_name, columns=[]) 41 | target_table = TableInfo(name=ref_table_name, columns=[]) 42 | 43 | # Add columns with appropriate uniqueness/primary flags based on the test case 44 | if expected == RelationType.ONE_TO_ONE: 45 | # Both sides are unique 46 | source_table.columns.append( 47 | ColumnInfo( 48 | name=column_name, post_gres_datatype='integer', datatype='integer', is_unique=True, primary=False 49 | ) 50 | ) 51 | target_table.columns.append( 52 | ColumnInfo( 53 | name=ref_column_name, post_gres_datatype='integer', datatype='integer', is_unique=True, primary=False 54 | ) 55 | ) 56 | elif expected == RelationType.ONE_TO_MANY: 57 | # Source is unique, target is not 58 | source_table.columns.append( 59 | ColumnInfo( 60 | name=column_name, post_gres_datatype='integer', datatype='integer', is_unique=True, primary=False 61 | ) 62 | ) 63 | target_table.columns.append( 64 | ColumnInfo( 65 | name=ref_column_name, post_gres_datatype='integer', datatype='integer', is_unique=False, primary=False 66 | ) 67 | ) 68 | elif expected == RelationType.MANY_TO_ONE: 69 | # Source is not unique, target is 70 | source_table.columns.append( 71 | ColumnInfo( 72 | name=column_name, post_gres_datatype='integer', datatype='integer', is_unique=False, primary=False 73 | ) 74 | ) 75 | target_table.columns.append( 76 | ColumnInfo( 77 | name=ref_column_name, post_gres_datatype='integer', datatype='integer', is_unique=True, primary=False 78 | ) 79 | ) 80 | else: 81 | # Neither is unique 82 | source_table.columns.append( 83 | ColumnInfo( 84 | name=column_name, post_gres_datatype='integer', datatype='integer', is_unique=False, primary=False 85 | ) 86 | ) 87 | target_table.columns.append( 88 | ColumnInfo( 89 | name=ref_column_name, 
post_gres_datatype='integer', datatype='integer', is_unique=False, primary=False 90 | ) 91 | ) 92 | 93 | result = marshaler.determine_relationship_type(source_table, target_table, column_name, ref_column_name) 94 | assert result == expected, ( 95 | f'Failed for {table_name}.{column_name} -> {ref_table_name}.{ref_column_name}, ' 96 | f'got {result}, expected {expected}' 97 | ) 98 | 99 | 100 | @pytest.mark.unit 101 | @pytest.mark.db 102 | @pytest.mark.marshalers 103 | def test_determine_relationship_type_none_handling(marshaler): 104 | """Test handling of None values in determine_relationship_type.""" 105 | # Create TableInfo objects with no column data 106 | source_table = TableInfo(name='table1', columns=[]) 107 | target_table = TableInfo(name='table2', columns=[]) 108 | 109 | # When columns don't exist, should default to MANY_TO_MANY 110 | result = marshaler.determine_relationship_type(source_table, target_table, 'col1', 'col2') 111 | assert result == RelationType.MANY_TO_MANY 112 | -------------------------------------------------------------------------------- /tests/unit/db/factories/test_connection_factory.py: -------------------------------------------------------------------------------- 1 | """Tests for the connection parameter factory.""" 2 | 3 | import pytest 4 | from unittest.mock import patch 5 | 6 | from supabase_pydantic.db.database_type import DatabaseType 7 | from supabase_pydantic.db.factories.connection_factory import ConnectionParamFactory 8 | from supabase_pydantic.db.models import PostgresConnectionParams, MySQLConnectionParams 9 | 10 | 11 | @pytest.mark.unit 12 | @pytest.mark.db 13 | class TestConnectionParamFactory: 14 | """Tests for the ConnectionParamFactory class.""" 15 | 16 | def test_create_postgres_explicit(self): 17 | """Test creating PostgreSQL connection params with explicit DB type.""" 18 | params = {'host': 'localhost', 'port': '5432', 'dbname': 'testdb', 'user': 'user', 'password': 'pass'} 19 | result = ConnectionParamFactory.create_connection_params(params, db_type=DatabaseType.POSTGRES) 20 | 21 | assert isinstance(result, PostgresConnectionParams) 22 | assert result.host == 'localhost' 23 | assert result.port == '5432' 24 | assert result.dbname == 'testdb' 25 | assert result.user == 'user' 26 | assert result.password == 'pass' 27 | 28 | def test_create_mysql_explicit(self): 29 | """Test creating MySQL connection params with explicit DB type.""" 30 | params = {'host': 'localhost', 'port': '3306', 'dbname': 'testdb', 'user': 'user', 'password': 'pass'} 31 | result = ConnectionParamFactory.create_connection_params(params, db_type=DatabaseType.MYSQL) 32 | 33 | assert isinstance(result, MySQLConnectionParams) 34 | assert result.host == 'localhost' 35 | assert result.port == '3306' 36 | assert result.dbname == 'testdb' 37 | assert result.user == 'user' 38 | assert result.password == 'pass' 39 | 40 | def test_default_to_postgres(self): 41 | """Test default to PostgreSQL when no DB type is provided.""" 42 | params = {'host': 'localhost', 'port': '5432', 'dbname': 'testdb', 'user': 'user', 'password': 'pass'} 43 | result = ConnectionParamFactory.create_connection_params(params) 44 | 45 | assert isinstance(result, PostgresConnectionParams) 46 | assert result.host == 'localhost' 47 | 48 | @patch('supabase_pydantic.db.factories.connection_factory.detect_database_type') 49 | def test_autodetect_postgres_from_url(self, mock_detect): 50 | """Test auto-detection of PostgreSQL from db_url.""" 51 | mock_detect.return_value = DatabaseType.POSTGRES 52 | 53 | params = 
{'db_url': 'postgresql://user:pass@localhost:5432/testdb'} 54 | result = ConnectionParamFactory.create_connection_params(params) 55 | 56 | assert isinstance(result, PostgresConnectionParams) 57 | assert result.db_url == 'postgresql://user:pass@localhost:5432/testdb' 58 | mock_detect.assert_called_once_with('postgresql://user:pass@localhost:5432/testdb') 59 | 60 | @patch('supabase_pydantic.db.factories.connection_factory.detect_database_type') 61 | def test_autodetect_mysql_from_url(self, mock_detect): 62 | """Test auto-detection of MySQL from db_url.""" 63 | mock_detect.return_value = DatabaseType.MYSQL 64 | 65 | params = {'db_url': 'mysql://user:pass@localhost:3306/testdb'} 66 | result = ConnectionParamFactory.create_connection_params(params) 67 | 68 | assert isinstance(result, MySQLConnectionParams) 69 | assert result.db_url == 'mysql://user:pass@localhost:3306/testdb' 70 | mock_detect.assert_called_once_with('mysql://user:pass@localhost:3306/testdb') 71 | 72 | @patch('supabase_pydantic.db.factories.connection_factory.detect_database_type') 73 | def test_autodetect_fallback_to_postgres(self, mock_detect): 74 | """Test fallback to PostgreSQL when auto-detection returns None.""" 75 | # Simulate detect_database_type returning None (unable to identify DB type) 76 | mock_detect.return_value = None 77 | 78 | # Even though this is a valid PostgreSQL URL, we're testing the behavior 79 | # when the detection mechanism fails to identify the type 80 | params = {'db_url': 'postgresql://user:pass@localhost:5432/testdb'} 81 | result = ConnectionParamFactory.create_connection_params(params) 82 | 83 | assert isinstance(result, PostgresConnectionParams) 84 | assert result.db_url == 'postgresql://user:pass@localhost:5432/testdb' 85 | mock_detect.assert_called_once_with('postgresql://user:pass@localhost:5432/testdb') 86 | 87 | def test_unsupported_db_type(self): 88 | """Test error handling for unsupported DB types.""" 89 | params = {'host': 'localhost', 'dbname': 'testdb'} 90 | 91 | # Create a custom enum value that isn't supported 92 | unsupported_type = 'UNSUPPORTED' 93 | 94 | with pytest.raises(ValueError) as excinfo: 95 | ConnectionParamFactory.create_connection_params(params, db_type=unsupported_type) 96 | 97 | assert 'Unsupported database type' in str(excinfo.value) 98 | -------------------------------------------------------------------------------- /src/supabase_pydantic/db/drivers/mysql/queries.py: -------------------------------------------------------------------------------- 1 | """MySQL-specific queries for schema information retrieval.""" 2 | 3 | TABLES_QUERY = """ 4 | SELECT table_name 5 | FROM information_schema.tables 6 | WHERE table_schema = %s 7 | """ 8 | 9 | COLUMNS_QUERY = """ 10 | SELECT 11 | column_name, 12 | data_type, 13 | column_default, 14 | is_nullable, 15 | character_maximum_length, 16 | column_type, 17 | CASE 18 | WHEN extra LIKE '%auto_increment%' THEN 'IDENTITY' 19 | ELSE NULL 20 | END as identity_generation 21 | FROM information_schema.columns 22 | WHERE table_schema = %s AND table_name = %s 23 | ORDER BY ordinal_position 24 | """ 25 | 26 | SCHEMAS_QUERY = """ 27 | SELECT schema_name 28 | FROM information_schema.schemata 29 | WHERE schema_name NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') 30 | """ 31 | 32 | FOREIGN_KEYS_QUERY = """ 33 | SELECT 34 | tc.constraint_schema AS table_schema, 35 | tc.table_name AS table_name, 36 | kcu.column_name AS column_name, 37 | kcu.referenced_table_schema AS foreign_table_schema, 38 | kcu.referenced_table_name AS 
foreign_table_name, 39 | kcu.referenced_column_name AS foreign_column_name, 40 | tc.constraint_name 41 | FROM 42 | information_schema.table_constraints AS tc 43 | JOIN 44 | information_schema.key_column_usage AS kcu 45 | ON 46 | tc.constraint_name = kcu.constraint_name 47 | AND tc.constraint_schema = kcu.constraint_schema 48 | WHERE 49 | tc.constraint_type = 'FOREIGN KEY' 50 | AND tc.constraint_schema = %s; 51 | """ 52 | 53 | GET_ALL_PUBLIC_TABLES_AND_COLUMNS = """ 54 | SELECT 55 | c.table_schema, 56 | c.table_name, 57 | c.column_name, 58 | c.column_default, 59 | c.is_nullable, 60 | c.data_type, 61 | c.character_maximum_length, 62 | t.table_type, 63 | CASE 64 | WHEN c.extra LIKE '%auto_increment%' THEN 'IDENTITY' 65 | ELSE NULL 66 | END as identity_generation, 67 | c.data_type as udt_name, 68 | NULL as array_element_type, 69 | c.column_comment as description 70 | FROM 71 | information_schema.columns AS c 72 | JOIN 73 | information_schema.tables AS t ON c.table_name = t.table_name AND c.table_schema = t.table_schema 74 | WHERE 75 | c.table_schema = %s 76 | AND (t.table_type = 'BASE TABLE' OR t.table_type = 'VIEW') 77 | ORDER BY 78 | c.table_schema, c.table_name, c.ordinal_position; 79 | """ 80 | 81 | CONSTRAINTS_QUERY = """ 82 | SELECT 83 | tc.constraint_name AS constraint_name, 84 | tc.table_name AS table_name, 85 | GROUP_CONCAT(kcu.column_name ORDER BY kcu.ordinal_position) AS columns, 86 | tc.constraint_type AS constraint_type, 87 | CONCAT(tc.constraint_type, ' (', GROUP_CONCAT(kcu.column_name ORDER BY kcu.ordinal_position), ')') AS constraint_definition 88 | FROM 89 | information_schema.table_constraints tc 90 | JOIN 91 | information_schema.key_column_usage kcu 92 | ON 93 | tc.constraint_catalog = kcu.constraint_catalog 94 | AND tc.constraint_schema = kcu.constraint_schema 95 | AND tc.constraint_name = kcu.constraint_name 96 | AND tc.table_name = kcu.table_name 97 | WHERE 98 | tc.constraint_schema = %s 99 | GROUP BY 100 | tc.constraint_name, tc.constraint_type, tc.table_name 101 | ORDER BY 102 | tc.table_name, tc.constraint_type DESC; 103 | """ # noqa: E501 104 | 105 | ENUMS_QUERY = """ 106 | SELECT 107 | -- Enum type name (use column_name + '_enum' to create a unique type name) 108 | CONCAT(c.column_name, '_enum') AS type_name, 109 | -- Namespace is the schema name 110 | c.table_schema AS namespace, 111 | -- Owner - MySQL doesn't track owner the same way, use current user 112 | 'mysql_user' AS owner, 113 | -- Category - use 'E' for enum like PostgreSQL 114 | 'E' AS category, 115 | -- Is defined - always true for MySQL enums 116 | TRUE AS is_defined, 117 | -- Type - use 'e' for enum like PostgreSQL 118 | 'e' AS type, 119 | -- Enum values as a string that can be parsed as an array 120 | -- Extract values from enum('value1','value2') format 121 | SUBSTRING( 122 | COLUMN_TYPE, 123 | LOCATE('enum(', COLUMN_TYPE) + 5, 124 | LENGTH(COLUMN_TYPE) - LOCATE('enum(', COLUMN_TYPE) - 6 125 | ) AS enum_values 126 | FROM 127 | information_schema.columns c 128 | WHERE 129 | c.table_schema = %s 130 | AND c.column_type LIKE 'enum%'; 131 | """ 132 | 133 | USER_DEFINED_TYPES_QUERY = """ 134 | SELECT 135 | c.column_name AS column_name, 136 | c.table_name AS table_name, 137 | c.table_schema AS namespace, 138 | c.data_type AS type_name, 139 | CASE 140 | WHEN c.data_type = 'enum' THEN 'e' 141 | WHEN c.data_type = 'set' THEN 's' 142 | ELSE 'o' 143 | END AS type_category, 144 | CASE 145 | WHEN c.data_type = 'enum' THEN 'Enum' 146 | WHEN c.data_type = 'set' THEN 'Set' 147 | ELSE 'Other' 148 | END AS 
type_description 149 | FROM 150 | information_schema.columns c 151 | WHERE 152 | c.table_schema = %s 153 | AND (c.data_type = 'enum' OR c.data_type = 'set'); 154 | """ 155 | --------------------------------------------------------------------------------
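A closing note on the MySQL queries above: `ENUMS_QUERY` returns the enum labels as the raw inner text of MySQL's `enum('a','b')` column type, which callers still need to split into individual values. Below is a minimal sketch of one way to do that parsing; the helper name is hypothetical and this is not the library's actual implementation:

``` python
import csv
from io import StringIO


def parse_mysql_enum_values(enum_values: str) -> list[str]:
    """Split the enum_values string from ENUMS_QUERY into a list of labels.

    Assumes input like "'active','inactive'", i.e. the inner portion of a
    MySQL enum(...) column type. A csv reader with ' as the quote character
    also handles labels that themselves contain commas.
    """
    # Hypothetical helper; illustrates one way to split MySQL enum labels.
    reader = csv.reader(StringIO(enum_values), quotechar="'", skipinitialspace=True)
    return next(reader, [])


# Example (hypothetical input):
# parse_mysql_enum_values("'active','inactive','on hold'")
# -> ['active', 'inactive', 'on hold']
```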