├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ ├── config.yml │ └── features_request.yaml ├── pull_request_template.md └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGES.rst ├── CONTRIBUTING.rst ├── LICENSE ├── README.rst ├── pyproject.toml ├── src └── sqlacodegen │ ├── __init__.py │ ├── __main__.py │ ├── cli.py │ ├── generators.py │ ├── models.py │ ├── py.typed │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── test_cli.py ├── test_generator_dataclass.py ├── test_generator_declarative.py ├── test_generator_sqlmodel.py └── test_generator_tables.py /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report 3 | labels: ["bug"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | If you observed a crash in the project, or saw unexpected behavior in it, report 9 | your findings here. 10 | - type: checkboxes 11 | attributes: 12 | label: Things to check first 13 | options: 14 | - label: > 15 | I have searched the existing issues and didn't find my bug already reported 16 | there 17 | required: true 18 | - label: > 19 | I have checked that my bug is still present in the latest release 20 | required: true 21 | - type: input 22 | id: project-version 23 | attributes: 24 | label: Sqlacodegen version 25 | description: What version of Sqlacodegen were you running? 26 | validations: 27 | required: true 28 | - type: input 29 | id: sqlalchemy-version 30 | attributes: 31 | label: SQLAlchemy version 32 | description: What version of SQLAlchemy were you running? 33 | validations: 34 | required: true 35 | - type: dropdown 36 | id: rdbms 37 | attributes: 38 | label: RDBMS vendor 39 | description: > 40 | What RDBMS (relational database management system) did you run the tool against? 
41 | options: 42 | - PostgreSQL 43 | - MySQL (or compatible) 44 | - SQLite 45 | - MSSQL 46 | - Oracle 47 | - DB2 48 | - Other 49 | - N/A 50 | validations: 51 | required: true 52 | - type: textarea 53 | id: what-happened 54 | attributes: 55 | label: What happened? 56 | description: > 57 | Unless you are reporting a crash, tell us what you expected to happen instead. 58 | validations: 59 | required: true 60 | - type: textarea 61 | id: schema 62 | attributes: 63 | label: Database schema for reproducing the bug 64 | description: > 65 | If applicable, paste the database schema (as a series of `CREATE TABLE` and 66 | other SQL commands) here. 67 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/features_request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest a new feature 3 | labels: ["enhancement"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | If you have thought of a new feature that would increase the usefulness of this 9 | project, please use this form to send us your idea. 10 | - type: checkboxes 11 | attributes: 12 | label: Things to check first 13 | options: 14 | - label: > 15 | I have searched the existing issues and didn't find my feature already 16 | requested there 17 | required: true 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: Feature description 22 | description: > 23 | Describe the feature in detail. The more specific the description you can give, 24 | the easier it should be to implement this feature. 
25 | validations: 26 | required: true 27 | - type: textarea 28 | id: usecase 29 | attributes: 30 | label: Use case 31 | description: > 32 | Explain why you need this feature, and why you think it would be useful to 33 | others too. 34 | validations: 35 | required: true 36 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | ## Changes 3 | 4 | Fixes #. 5 | 6 | 7 | 8 | ## Checklist 9 | 10 | If this is a user-facing code change, like a bugfix or a new feature, please ensure that 11 | you've fulfilled the following conditions (where applicable): 12 | 13 | - [ ] You've added tests (in `tests/`) which would fail without your patch 14 | - [ ] You've added a new changelog entry (in `CHANGES.rst`). 15 | 16 | If this is a trivial change, like a typo fix or a code reformatting, then you can ignore 17 | these instructions. 18 | 19 | ### Updating the changelog 20 | 21 | If there are no entries after the last release, use `**UNRELEASED**` as the version. 22 | If, say, your patch fixes issue #123, the entry should look like this: 23 | 24 | ``` 25 | - Fix big bad boo-boo in task groups 26 | (`#123 <https://github.com/agronholm/sqlacodegen/issues/123>`_; PR by @yourgithubaccount) 27 | ``` 28 | 29 | If there's no issue linked, just link to your pull request instead by updating the 30 | changelog after you've created the PR.
31 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish packages to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - "[0-9]+.[0-9]+.[0-9]+" 7 | - "[0-9]+.[0-9]+.[0-9]+.post[0-9]+" 8 | - "[0-9]+.[0-9]+.[0-9]+[a-b][0-9]+" 9 | - "[0-9]+.[0-9]+.[0-9]+rc[0-9]+" 10 | 11 | jobs: 12 | build: 13 | name: Build the source tarball and the wheel 14 | runs-on: ubuntu-latest 15 | environment: release 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: 3.x 22 | - name: Install dependencies 23 | run: pip install build 24 | - name: Create packages 25 | run: python -m build 26 | - name: Archive packages 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: dist 30 | path: dist 31 | 32 | publish: 33 | name: Publish build artifacts to the PyPI 34 | needs: build 35 | runs-on: ubuntu-latest 36 | environment: release 37 | permissions: 38 | id-token: write 39 | steps: 40 | - name: Retrieve packages 41 | uses: actions/download-artifact@v4 42 | - name: Upload packages 43 | uses: pypa/gh-action-pypi-publish@release/v1 44 | 45 | release: 46 | name: Create a GitHub release 47 | needs: build 48 | runs-on: ubuntu-latest 49 | permissions: 50 | contents: write 51 | steps: 52 | - uses: actions/checkout@v4 53 | - id: changelog 54 | uses: agronholm/release-notes@v1 55 | with: 56 | path: CHANGES.rst 57 | - uses: ncipollo/release-action@v1 58 | with: 59 | body: ${{ steps.changelog.outputs.changelog }} 60 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test suite 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | 8 | jobs: 9 | test: 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | 
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | allow-prereleases: true 22 | cache: pip 23 | cache-dependency-path: pyproject.toml 24 | - name: Install dependencies 25 | run: pip install -e .[test] 26 | - name: Test with pytest 27 | run: coverage run -m pytest 28 | - name: Upload Coverage 29 | uses: coverallsapp/github-action@v2 30 | with: 31 | parallel: true 32 | 33 | coveralls: 34 | name: Finish Coveralls 35 | needs: test 36 | runs-on: ubuntu-latest 37 | steps: 38 | - name: Finished 39 | uses: coverallsapp/github-action@v2 40 | with: 41 | parallel-finished: true 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.pyc 3 | .project 4 | .pydevproject 5 | .coverage 6 | .settings 7 | .tox 8 | .idea 9 | .vscode 10 | .cache 11 | .pytest_cache 12 | .mypy_cache 13 | dist 14 | build 15 | venv* 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # This is the configuration file for pre-commit (https://pre-commit.com/). 2 | # To use: 3 | # * Install pre-commit (https://pre-commit.com/#installation) 4 | # * Copy this file as ".pre-commit-config.yaml" 5 | # * Run "pre-commit install". 
6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v5.0.0 9 | hooks: 10 | - id: check-toml 11 | - id: check-yaml 12 | - id: debug-statements 13 | - id: end-of-file-fixer 14 | - id: mixed-line-ending 15 | args: [ "--fix=lf" ] 16 | - id: trailing-whitespace 17 | 18 | - repo: https://github.com/astral-sh/ruff-pre-commit 19 | rev: v0.11.12 20 | hooks: 21 | - id: ruff 22 | args: [--fix, --show-fixes] 23 | - id: ruff-format 24 | 25 | - repo: https://github.com/pre-commit/mirrors-mypy 26 | rev: v1.16.0 27 | hooks: 28 | - id: mypy 29 | additional_dependencies: 30 | - pytest 31 | - "SQLAlchemy >= 2.0.29" 32 | 33 | - repo: https://github.com/pre-commit/pygrep-hooks 34 | rev: v1.10.0 35 | hooks: 36 | - id: rst-backticks 37 | - id: rst-directive-colons 38 | - id: rst-inline-touching-normal 39 | -------------------------------------------------------------------------------- /CHANGES.rst: -------------------------------------------------------------------------------- 1 | Version history 2 | =============== 3 | 4 | **UNRELEASED** 5 | 6 | - Type annotations for ARRAY column attributes now include the Python type of 7 | the array elements 8 | - Added support for specifying engine arguments via ``--engine-arg`` 9 | (PR by @LajosCseppento) 10 | - Fixed incorrect package name used in ``importlib.metadata.version`` for 11 | ``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz) 12 | - Prevent double pluralization (PR by @dkratzert) 13 | 14 | **3.0.0** 15 | 16 | - Dropped support for Python 3.8 17 | - Changed nullable relationships to include ``Optional`` in their type annotations 18 | - Fixed SQLModel code generation 19 | - Fixed two rendering issues in ``ENUM`` columns when a non-default schema is used: an 20 | unwarranted positional argument and missing the ``schema`` argument 21 | - Fixed ``AttributeError`` when metadata contains user defined column types 22 | - Fixed ``AssertionError`` when metadata contains a column type that is a 
type decorator 23 | with an all-uppercase name 24 | - Fixed MySQL ``DOUBLE`` column types being rendered with the wrong arguments 25 | 26 | **3.0.0rc5** 27 | 28 | - Fixed pgvector support not working 29 | 30 | **3.0.0rc4** 31 | 32 | - Dropped support for Python 3.7 33 | - Dropped support for SQLAlchemy 1.x 34 | - Added support for the ``pgvector`` extension (with help from KellyRousselHoomano) 35 | 36 | **3.0.0rc3** 37 | 38 | - Added support for SQLAlchemy 2 (PR by rbuffat with help from mhauru) 39 | - Renamed ``--option`` to ``--options`` and made its values delimited by commas 40 | - Restored CIText and GeoAlchemy2 support (PR by stavvy-rotte) 41 | 42 | **3.0.0rc2** 43 | 44 | - Added support for generating SQLModel classes (PR by Andrii Khirilov) 45 | - Fixed code generation when a single-column index is unique or does not match the 46 | dialect's naming convention (PR by Leonardus Chen) 47 | - Fixed another problem where sequence schemas were not properly separated from the 48 | sequence name 49 | - Fixed invalid generated primary/secondaryjoin expressions in self-referential 50 | many-to-many relationships by using lambdas instead of strings 51 | - Fixed ``AttributeError`` when the declarative generator encounters a table name 52 | already in singular form when ``--option use_inflect`` is enabled 53 | - Increased minimum SQLAlchemy version to 1.4.36 to address issues with ``ForeignKey`` 54 | and indexes, and to eliminate the PostgreSQL UUID column type annotation hack 55 | 56 | **3.0.0rc1** 57 | 58 | - Migrated all packaging/testing configuration to ``pyproject.toml`` 59 | - Fixed unwarranted ``ForeignKey`` declarations appearing in column attributes when 60 | there are named, single column foreign key constraints (PR by Leonardus Chen) 61 | . 
Fixed ``KeyError`` when rendering an index without any columns 62 | - Fixed improper handling of schema prefixes in sequence names in server defaults 63 | - Fixed identically named tables from different schemas resulting in invalid generated 64 | code 65 | - Fixed imports caused by ``server_default`` conflicting with class attribute names 66 | - Worked around PostgreSQL UUID columns getting ``Any`` as the type annotation 67 | 68 | **3.0.0b3** 69 | 70 | - Dropped support for Python < 3.7 71 | - Dropped support for SQLAlchemy 1.3 72 | - Added a ``__main__`` module which can be used as an alternate entry point to the CLI 73 | - Added detection for sequence use in column defaults on PostgreSQL 74 | - Fixed ``sqlalchemy.exc.InvalidRequestError`` when encountering a column named 75 | "metadata" (regression from 2.0) 76 | - Fixed missing ``MetaData`` import with ``DeclarativeGenerator`` when only plain tables 77 | are generated 78 | - Fixed invalid data classes being generated due to some relationships having been 79 | rendered without a default value 80 | - Improved translation of column names into column attributes where the column name has 81 | whitespace at the beginning or end 82 | - Modified constraint and index rendering to add them explicitly instead of using 83 | shortcuts like ``unique=True``, ``index=True`` or ``primary_key=True`` when the constraint 84 | or index has a name that does not match the default naming convention 85 | 86 | **3.0.0b2** 87 | 88 | - Fixed ``IDENTITY`` columns not rendering properly when they are part of the primary 89 | key 90 | 91 | **3.0.0b1** 92 | 93 | **NOTE**: Both the API and the command line interface have been refactored in a 94 | backwards incompatible fashion. Notably several command line options have been moved to 95 | specific generators and are no longer visible from ``sqlacodegen --help``. Their 96 | replacements are documented in the README.
97 | 98 | - Dropped support for Python < 3.6 99 | - Added support for Python 3.10 100 | - Added support for SQLAlchemy 1.4 101 | - Added support for bidirectional relationships (use ``--option nobidi`` to disable) 102 | - Added support for multiple schemas via ``--schemas`` 103 | - Added support for ``IDENTITY`` columns 104 | - Disabled inflection during table/relationship name generation by default 105 | (use ``--option use_inflect`` to re-enable) 106 | - Refactored the old ``CodeGenerator`` class into separate generator classes, selectable 107 | via ``--generator`` 108 | - Refactored several command line options into generator specific options: 109 | 110 | - ``--noindexes`` → ``--option noindexes`` 111 | - ``--noconstraints`` → ``--option noconstraints`` 112 | - ``--nocomments`` → ``--option nocomments`` 113 | - ``--nojoined`` → ``--option nojoined`` (``declarative`` and ``dataclass`` generators 114 | only) 115 | - ``--noinflect`` → (now the default; use ``--option use_inflect`` instead) 116 | (``declarative`` and ``dataclass`` generators only) 117 | - Fixed missing import for ``JSONB`` ``astext_type`` argument 118 | - Fixed generated column or relationship names colliding with imports or each other 119 | - Fixed ``CompileError`` when encountering server defaults that contain colons (``:``) 120 | 121 | **2.3.0** 122 | 123 | - Added support for rendering computed columns 124 | - Fixed ``--nocomments`` not taking effect (fix proposed by AzuresYang) 125 | - Fixed handling of MySQL ``SET`` column types (and possibly others as well) 126 | 127 | **2.2.0** 128 | 129 | - Added support for rendering table comments (PR by David Hirschfeld) 130 | - Fixed bad identifier names being generated for plain tables (PR by softwarepk) 131 | 132 | **2.1.0** 133 | 134 | - Dropped support for Python 3.4 135 | - Dropped support for SQLAlchemy 0.8 136 | - Added support for Python 3.7 and 3.8 137 | - Added support for SQLAlchemy 1.3 138 | - Added support for column comments (requires
SQLAlchemy 1.2+; based on PR by koalas8) 139 | - Fixed crash on unknown column types (``NullType``) 140 | 141 | **2.0.1** 142 | 143 | - Don't adapt dialect specific column types if they need special constructor arguments 144 | (thanks Nicholas Martin for the PR) 145 | 146 | **2.0.0** 147 | 148 | - Refactored code for better reuse 149 | - Dropped support for Python 2.6, 3.2 and 3.3 150 | - Dropped support for SQLAlchemy < 0.8 151 | - Worked around a bug regarding Enum on SQLAlchemy 1.2+ (``name`` was missing) 152 | - Added support for Geoalchemy2 153 | - Fixed invalid class names being generated (fixes #60; PR by Dan O'Huiginn) 154 | - Fixed array item types not being adapted or imported 155 | (fixes #46; thanks to Martin Glauer and Shawn Koschik for help) 156 | - Fixed attribute name of columns named ``metadata`` in mapped classes (fixes #62) 157 | - Fixed rendered column types being changed from the original (fixes #11) 158 | - Fixed server defaults which contain double quotes (fixes #7, #17, #28, #33, #36) 159 | - Fixed ``secondary=`` not taking into account the association table's schema name 160 | (fixes #30) 161 | - Sort models by foreign key dependencies instead of schema and name (fixes #15, #16) 162 | 163 | **1.1.6** 164 | 165 | - Fixed compatibility with SQLAlchemy 1.0 166 | - Added an option to only generate tables 167 | 168 | **1.1.5** 169 | 170 | - Fixed potential assignment of columns or relationships into invalid attribute names 171 | (fixes #10) 172 | - Fixed unique=True missing from unique Index declarations 173 | - Fixed several issues with server defaults 174 | - Fixed potential assignment of columns or relationships into invalid attribute names 175 | - Allowed pascal case for tables already using it 176 | - Switched from Mercurial to Git 177 | 178 | **1.1.4** 179 | 180 | - Fixed compatibility with SQLAlchemy 0.9.0 181 | 182 | **1.1.3** 183 | 184 | - Fixed compatibility with SQLAlchemy 0.8.3+ 185 | - Migrated tests from nose to pytest 186 | 187 | 
**1.1.2** 188 | 189 | - Fixed non-default schema name not being present in ``__table_args__`` (fixes #2) 190 | - Fixed self referential foreign key causing column type to not be rendered 191 | - Fixed missing "deferrable" and "initially" keyword arguments in ForeignKey constructs 192 | - Fixed foreign key and check constraint handling with alternate schemas (fixes #3) 193 | 194 | **1.1.1** 195 | 196 | - Fixed TypeError when inflect could not determine the singular name of a table for a 197 | many-to-1 relationship 198 | - Fixed ``_IntegerType``, ``_StringType`` etc. being rendered instead of proper types on MySQL 199 | 200 | **1.1.0** 201 | 202 | - Added automatic detection of joined-table inheritance 203 | - Fixed missing class name prefix in primary/secondary joins in relationships 204 | - Instead of wildcard imports, generate explicit imports dynamically (fixes #1) 205 | - Use the inflect library to produce better guesses for table to class name conversion 206 | - Automatically detect Boolean columns based on CheckConstraints 207 | - Skip redundant CheckConstraints for Enum and Boolean columns 208 | 209 | **1.0.0** 210 | 211 | - Initial release 212 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing to sqlacodegen 2 | =========================== 3 | 4 | If you wish to contribute a fix or feature to sqlacodegen, please follow these 5 | guidelines. 6 | 7 | When you make a pull request against the main sqlacodegen codebase, GitHub runs the 8 | sqlacodegen test suite against your modified code. Before making a pull request, you 9 | should ensure that the modified code passes tests locally. To that end, the use of tox_ 10 | is recommended. The default tox run first runs ``pre-commit`` and then the actual test 11 | suite. To run the checks on all environments in parallel, invoke tox with ``tox -p``.
12 | 13 | To build the documentation, run ``tox -e docs`` which will generate a directory named 14 | ``build`` in which you may view the formatted HTML documentation. 15 | 16 | sqlacodegen uses pre-commit_ to perform several code style/quality checks. It is 17 | recommended to activate pre-commit_ on your local clone of the repository (using 18 | ``pre-commit install``) to ensure that your changes will pass the same checks on GitHub. 19 | 20 | .. _tox: https://tox.readthedocs.io/en/latest/install.html 21 | .. _pre-commit: https://pre-commit.com/#installation 22 | 23 | Making a pull request on GitHub 24 | ------------------------------- 25 | 26 | To get your changes merged to the main codebase, you need a GitHub account. 27 | 28 | #. Fork the repository (if you don't have your own fork of it yet) by navigating to the 29 | `main sqlacodegen repository`_ and clicking on "Fork" near the top right corner. 30 | #. Clone the forked repository to your local machine with 31 | ``git clone git@github.com:yourusername/sqlacodegen.git``. 32 | #. Create a branch for your pull request, like ``git checkout -b myfixname`` 33 | #. Make the desired changes to the code base. 34 | #. Commit your changes locally. If your changes close an existing issue, add the text 35 | ``Fixes #XXX.`` or ``Closes #XXX.`` to the commit message (where XXX is the issue 36 | number). 37 | #. Push the changeset(s) to your forked repository (``git push``) 38 | #. Navigate to the Pull requests page on the original repository (not your fork) and click 39 | "New pull request" 40 | #. Click on the text "compare across forks". 41 | #. Select your own fork as the head repository and then select the correct branch name. 42 | #. Click on "Create pull request". 43 | 44 | If you have trouble, consult the `pull request making guide`_ on opensource.com. 45 | 46 | .. _main sqlacodegen repository: https://github.com/agronholm/sqlacodegen 47 | ..
_pull request making guide: https://opensource.com/article/19/7/create-pull-request-github 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is the MIT license: http://www.opensource.org/licenses/mit-license.php 2 | 3 | Copyright (c) Alex Grönholm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 6 | software and associated documentation files (the "Software"), to deal in the Software 7 | without restriction, including without limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons 9 | to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all copies or 12 | substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 15 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 16 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 17 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://github.com/agronholm/sqlacodegen/actions/workflows/test.yml/badge.svg 2 | :target: https://github.com/agronholm/sqlacodegen/actions/workflows/test.yml 3 | :alt: Build Status 4 | .. 
image:: https://coveralls.io/repos/github/agronholm/sqlacodegen/badge.svg?branch=master 5 | :target: https://coveralls.io/github/agronholm/sqlacodegen?branch=master 6 | :alt: Code Coverage 7 | 8 | This is a tool that reads the structure of an existing database and generates the 9 | appropriate SQLAlchemy model code, using the declarative style if possible. 10 | 11 | This tool was written as a replacement for `sqlautocode`_, which was suffering from 12 | several issues (including, but not limited to, incompatibility with Python 3 and the 13 | latest SQLAlchemy version). 14 | 15 | .. _sqlautocode: http://code.google.com/p/sqlautocode/ 16 | 17 | 18 | Features 19 | ======== 20 | 21 | * Supports SQLAlchemy 2.x 22 | * Produces declarative code that almost looks like it was hand written 23 | * Produces `PEP 8`_ compliant code 24 | * Accurately determines relationships, including many-to-many, one-to-one 25 | * Automatically detects joined table inheritance 26 | * Excellent test coverage 27 | 28 | .. _PEP 8: http://www.python.org/dev/peps/pep-0008/ 29 | 30 | 31 | Installation 32 | ============ 33 | 34 | To install, do:: 35 | 36 | pip install sqlacodegen 37 | 38 | To include support for the PostgreSQL ``CITEXT`` extension type (which should be 39 | considered as tested only under a few environments) specify the ``citext`` extra:: 40 | 41 | pip install sqlacodegen[citext] 42 | 43 | 44 | To include support for the PostgreSQL ``GEOMETRY``, ``GEOGRAPHY``, and ``RASTER`` types 45 | (which should be considered as tested only under a few environments) specify the 46 | ``geoalchemy2`` extra: 47 | 48 | .. code-block:: bash 49 | 50 | pip install sqlacodegen[geoalchemy2] 51 | 52 | To include support for the PostgreSQL ``PGVECTOR`` extension type, specify the 53 | ``pgvector`` extra:: 54 | 55 | pip install sqlacodegen[pgvector] 56 | 57 | 58 | Quickstart 59 | ========== 60 | 61 | At the minimum, you have to give sqlacodegen a database URL.
The URL is passed directly 62 | to SQLAlchemy's `create_engine()`_ function, so please refer to 63 | `SQLAlchemy's documentation`_ for instructions on how to construct a proper URL. 64 | 65 | Examples:: 66 | 67 | sqlacodegen postgresql:///some_local_db 68 | sqlacodegen --generator tables mysql+pymysql://user:password@localhost/dbname 69 | sqlacodegen --generator dataclasses sqlite:///database.db 70 | # --engine-arg values are parsed with ast.literal_eval 71 | sqlacodegen oracle+oracledb://user:pass@127.0.0.1:1521/XE --engine-arg thick_mode=True 72 | sqlacodegen oracle+oracledb://user:pass@127.0.0.1:1521/XE --engine-arg thick_mode=True --engine-arg connect_args='{"user": "user", "dsn": "..."}' 73 | 74 | To see the list of generic options:: 75 | 76 | sqlacodegen --help 77 | 78 | .. _create_engine(): http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine 79 | .. _SQLAlchemy's documentation: http://docs.sqlalchemy.org/en/latest/core/engines.html 80 | 81 | Available generators 82 | ==================== 83 | 84 | The selection of a generator determines what kind of code is generated. 85 | 86 | The following built-in generators are available: 87 | 88 | * ``tables`` (only generates ``Table`` objects, for those who don't want to use the ORM) 89 | * ``declarative`` (the default; generates classes inheriting from ``declarative_base()``) 90 | * ``dataclasses`` (generates dataclass-based models; v1.4+ only) 91 | * ``sqlmodels`` (generates model classes for SQLModel_) 92 | 93 | .. _SQLModel: https://sqlmodel.tiangolo.com/ 94 | 95 | Generator-specific options 96 | ========================== 97 | 98 | The following options can be turned on by passing them using ``--options`` (multiple 99 | values must be delimited by commas, e.g. ``--options noconstraints,nobidi``): 100 | 101 | * ``tables`` 102 | 103 | * ``noconstraints``: ignore constraints (foreign key, unique etc.)
104 | * ``nocomments``: ignore table/column comments 105 | * ``noindexes``: ignore indexes 106 | 107 | * ``declarative`` 108 | 109 | * all the options from ``tables`` 110 | * ``use_inflect``: use the ``inflect`` library when naming classes and relationships 111 | (turning plural names into singular; see below for details) 112 | * ``nojoined``: don't try to detect joined-table inheritance (see below for details) 113 | * ``nobidi``: generate relationships in a unidirectional fashion, so only the 114 | many-to-one or first side of many-to-many relationships gets a relationship 115 | attribute, as on v2.X 116 | 117 | * ``dataclasses`` 118 | 119 | * all the options from ``declarative`` 120 | 121 | * ``sqlmodels`` 122 | 123 | * all the options from ``declarative`` 124 | 125 | Model class generators 126 | ---------------------- 127 | 128 | The code generators that generate classes try to generate model classes whenever 129 | possible. There are two circumstances in which a ``Table`` is generated instead: 130 | 131 | * the table has no primary key constraint (which is required by SQLAlchemy for every 132 | model class) 133 | * the table is an association table between two other tables (see below for the 134 | specifics) 135 | 136 | Model class naming logic 137 | ++++++++++++++++++++++++ 138 | 139 | By default, table names are converted to valid PEP 8 compliant class names by replacing 140 | all characters unsuitable for Python identifiers with ``_``. Then, each valid part 141 | (separated by underscores) is title cased and then joined together, eliminating the 142 | underscores. So, ``example_name`` becomes ``ExampleName``. 143 | 144 | If the ``use_inflect`` option is used, the table name (which is assumed to be in 145 | English) is converted to singular form using the "inflect" library. For example, 146 | ``sales_invoices`` becomes ``SalesInvoice``.
Since table names are not always in 147 | English, and the inflection process is far from perfect, inflection is disabled by 148 | default. 149 | 150 | Relationship detection logic 151 | ++++++++++++++++++++++++++++ 152 | 153 | Relationships are detected based on existing foreign key constraints as follows: 154 | 155 | * **many-to-one**: a foreign key constraint exists on the table 156 | * **one-to-one**: same as **many-to-one**, but a unique constraint exists on the 157 | column(s) involved 158 | * **many-to-many**: (not implemented on the ``sqlmodels`` generator) an association table 159 | is found to exist between two tables 160 | 161 | A table is considered an association table if it satisfies all of the following 162 | conditions: 163 | 164 | #. has exactly two foreign key constraints 165 | #. all its columns are involved in said constraints 166 | 167 | Relationship naming logic 168 | +++++++++++++++++++++++++ 169 | 170 | Relationships are typically named based on the table name of the opposite class. 171 | For example, if a class has a relationship to another class with the table named 172 | ``companies``, the relationship would be named ``companies`` (unless the ``use_inflect`` 173 | option was enabled, in which case it would be named ``company`` in the case of a 174 | many-to-one or one-to-one relationship). 175 | 176 | A special case for single column many-to-one and one-to-one relationships, however, is 177 | if the column is named like ``employer_id``. Then the relationship is named ``employer`` 178 | due to that ``_id`` suffix. 179 | 180 | For self referential relationships, the reverse side of the relationship will be named 181 | with the ``_reverse`` suffix appended to it. 182 | 183 | Customizing code generation logic 184 | ================================= 185 | 186 | If the built-in generators with all their options don't quite do what you want, you can 187 | customize the logic by subclassing one of the existing code generator classes.
Override 188 | whichever methods you need, and then add an `entry point`_ in the 189 | ``sqlacodegen.generators`` namespace that points to your new class. Once the entry point 190 | is in place (you typically have to install the project with ``pip install``), you can 191 | use ``--generator <entry-point-name>`` to invoke your custom code generator. 192 | 193 | For examples, you can look at sqlacodegen's own entry points in its `pyproject.toml`_. 194 | 195 | .. _entry point: https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html 196 | .. _pyproject.toml: https://github.com/agronholm/sqlacodegen/blob/master/pyproject.toml 197 | 198 | Getting help 199 | ============ 200 | 201 | If you have problems or other questions, you should start a discussion on the 202 | `sqlacodegen discussion forum`_. As an alternative, you could also try your luck on the 203 | sqlalchemy_ room on Gitter. 204 | 205 | .. _sqlacodegen discussion forum: https://github.com/agronholm/sqlacodegen/discussions/categories/q-a 206 | ..
_sqlalchemy: https://app.gitter.im/#/room/#sqlalchemy_community:gitter.im 207 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 64", 4 | "setuptools_scm[toml] >= 6.4" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "sqlacodegen" 10 | description = "Automatic model code generator for SQLAlchemy" 11 | readme = "README.rst" 12 | authors = [{name = "Alex Grönholm", email = "alex.gronholm@nextday.fi"}] 13 | keywords = ["sqlalchemy"] 14 | license = {text = "MIT"} 15 | classifiers = [ 16 | "Development Status :: 5 - Production/Stable", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Environment :: Console", 20 | "Topic :: Database", 21 | "Topic :: Software Development :: Code Generators", 22 | "Programming Language :: Python", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Programming Language :: Python :: 3.13", 29 | ] 30 | requires-python = ">=3.9" 31 | dependencies = [ 32 | "SQLAlchemy >= 2.0.29", 33 | "inflect >= 4.0.0", 34 | "importlib_metadata; python_version < '3.10'", 35 | ] 36 | dynamic = ["version"] 37 | 38 | [project.urls] 39 | "Bug Tracker" = "https://github.com/agronholm/sqlacodegen/issues" 40 | "Source Code" = "https://github.com/agronholm/sqlacodegen" 41 | 42 | [project.optional-dependencies] 43 | test = [ 44 | "sqlacodegen[sqlmodel,pgvector]", 45 | "pytest >= 7.4", 46 | "coverage >= 7", 47 | "psycopg[binary]", 48 | "mysql-connector-python", 49 | ] 50 | sqlmodel = ["sqlmodel >= 0.0.22"] 51 | citext = ["sqlalchemy-citext >= 1.7.0"] 52 | geoalchemy2 = ["geoalchemy2 >= 0.11.1"] 53 | pgvector = ["pgvector >= 0.2.4"] 
54 | 55 | [project.entry-points."sqlacodegen.generators"] 56 | tables = "sqlacodegen.generators:TablesGenerator" 57 | declarative = "sqlacodegen.generators:DeclarativeGenerator" 58 | dataclasses = "sqlacodegen.generators:DataclassGenerator" 59 | sqlmodels = "sqlacodegen.generators:SQLModelGenerator" 60 | 61 | [project.scripts] 62 | sqlacodegen = "sqlacodegen.cli:main" 63 | 64 | [tool.setuptools_scm] 65 | version_scheme = "post-release" 66 | local_scheme = "dirty-tag" 67 | 68 | [tool.ruff] 69 | src = ["src"] 70 | 71 | [tool.ruff.lint] 72 | extend-select = [ 73 | "I", # isort 74 | "ISC", # flake8-implicit-str-concat 75 | "PGH", # pygrep-hooks 76 | "RUF100", # unused noqa (yesqa) 77 | "UP", # pyupgrade 78 | "W", # pycodestyle warnings 79 | ] 80 | 81 | [tool.mypy] 82 | strict = true 83 | disable_error_code = "no-untyped-call" 84 | 85 | [tool.pytest.ini_options] 86 | addopts = "-rsfE --tb=short" 87 | testpaths = ["tests"] 88 | 89 | [tool.coverage.run] 90 | source = ["sqlacodegen"] 91 | relative_files = true 92 | 93 | [tool.coverage.report] 94 | show_missing = true 95 | 96 | [tool.tox] 97 | env_list = ["py39", "py310", "py311", "py312", "py313"] 98 | skip_missing_interpreters = true 99 | 100 | [tool.tox.env_run_base] 101 | package = "editable" 102 | commands = [["python", "-m", "pytest", { replace = "posargs", extend = true }]] 103 | extras = ["test"] 104 | -------------------------------------------------------------------------------- /src/sqlacodegen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agronholm/sqlacodegen/84dcc39d5cef1a840234624603857ab72ec333a0/src/sqlacodegen/__init__.py -------------------------------------------------------------------------------- /src/sqlacodegen/__main__.py: -------------------------------------------------------------------------------- 1 | from .cli import main 2 | 3 | main() 4 | --------------------------------------------------------------------------------
/src/sqlacodegen/cli.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import ast 5 | import sys 6 | from contextlib import ExitStack 7 | from typing import Any, TextIO 8 | 9 | from sqlalchemy.engine import create_engine 10 | from sqlalchemy.schema import MetaData 11 | 12 | try: 13 | import citext 14 | except ImportError: 15 | citext = None 16 | 17 | try: 18 | import geoalchemy2 19 | except ImportError: 20 | geoalchemy2 = None 21 | 22 | try: 23 | import pgvector.sqlalchemy 24 | except ImportError: 25 | pgvector = None 26 | 27 | if sys.version_info < (3, 10): 28 | from importlib_metadata import entry_points, version 29 | else: 30 | from importlib.metadata import entry_points, version 31 | 32 | 33 | def _parse_engine_arg(arg_str: str) -> tuple[str, Any]: 34 | if "=" not in arg_str: 35 | raise argparse.ArgumentTypeError("engine-arg must be in key=value format") 36 | 37 | key, value = arg_str.split("=", 1) 38 | try: 39 | value = ast.literal_eval(value) 40 | except Exception: 41 | pass # Leave as string if literal_eval fails 42 | 43 | return key, value 44 | 45 | 46 | def _parse_engine_args(arg_list: list[str]) -> dict[str, Any]: 47 | result = {} 48 | for arg in arg_list or []: 49 | key, value = _parse_engine_arg(arg) 50 | result[key] = value 51 | 52 | return result 53 | 54 | 55 | def main() -> None: 56 | generators = {ep.name: ep for ep in entry_points(group="sqlacodegen.generators")} 57 | parser = argparse.ArgumentParser( 58 | description="Generates SQLAlchemy model code from an existing database." 
59 | ) 60 | parser.add_argument("url", nargs="?", help="SQLAlchemy url to the database") 61 | parser.add_argument( 62 | "--options", help="options (comma-delimited) passed to the generator class" 63 | ) 64 | parser.add_argument( 65 | "--version", action="store_true", help="print the version number and exit" 66 | ) 67 | parser.add_argument( 68 | "--schemas", help="load tables from the given schemas (comma-delimited)" 69 | ) 70 | parser.add_argument( 71 | "--generator", 72 | choices=generators, 73 | default="declarative", 74 | help="generator class to use", 75 | ) 76 | parser.add_argument( 77 | "--tables", help="tables to process (comma-delimited, default: all)" 78 | ) 79 | parser.add_argument( 80 | "--noviews", 81 | action="store_true", 82 | help="ignore views (always true for sqlmodels generator)", 83 | ) 84 | parser.add_argument( 85 | "--engine-arg", 86 | action="append", 87 | help=( 88 | "engine arguments in key=value format, e.g., " 89 | '--engine-arg=connect_args=\'{"user": "scott"}\' ' 90 | "--engine-arg thick_mode=true or " 91 | '--engine-arg thick_mode=\'{"lib_dir": "/path"}\' ' 92 | "(values are parsed with ast.literal_eval)" 93 | ), 94 | ) 95 | parser.add_argument("--outfile", help="file to write output to (default: stdout)") 96 | args = parser.parse_args() 97 | 98 | if args.version: 99 | print(version("sqlacodegen")) 100 | return 101 | 102 | if not args.url: 103 | print("You must supply a url\n", file=sys.stderr) 104 | parser.print_help() 105 | return 106 | 107 | if citext: 108 | print(f"Using sqlalchemy-citext {version('sqlalchemy-citext')}") 109 | 110 | if geoalchemy2: 111 | print(f"Using geoalchemy2 {version('geoalchemy2')}") 112 | 113 | if pgvector: 114 | print(f"Using pgvector {version('pgvector')}") 115 | 116 | # Use reflection to fill in the metadata 117 | engine_args = _parse_engine_args(args.engine_arg) 118 | engine = create_engine(args.url, **engine_args) 119 | metadata = MetaData() 120 | tables = args.tables.split(",") if args.tables else None 
121 | schemas = args.schemas.split(",") if args.schemas else [None] 122 | options = set(args.options.split(",")) if args.options else set() 123 | 124 | # Instantiate the generator 125 | generator_class = generators[args.generator].load() 126 | generator = generator_class(metadata, engine, options) 127 | 128 | if not generator.views_supported: 129 | name = generator_class.__name__ 130 | print( 131 | f"VIEW models will not be generated when using the '{name}' generator", 132 | file=sys.stderr, 133 | ) 134 | 135 | for schema in schemas: 136 | metadata.reflect( 137 | engine, schema, (generator.views_supported and not args.noviews), tables 138 | ) 139 | 140 | # Open the target file (if given) 141 | with ExitStack() as stack: 142 | outfile: TextIO 143 | if args.outfile: 144 | outfile = open(args.outfile, "w", encoding="utf-8") 145 | stack.enter_context(outfile) 146 | else: 147 | outfile = sys.stdout 148 | 149 | # Write the generated model code to the specified file or standard output 150 | outfile.write(generator.generate()) 151 | -------------------------------------------------------------------------------- /src/sqlacodegen/generators.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import re 5 | import sys 6 | from abc import ABCMeta, abstractmethod 7 | from collections import defaultdict 8 | from collections.abc import Collection, Iterable, Sequence 9 | from dataclasses import dataclass 10 | from importlib import import_module 11 | from inspect import Parameter 12 | from itertools import count 13 | from keyword import iskeyword 14 | from pprint import pformat 15 | from textwrap import indent 16 | from typing import Any, ClassVar, Literal, cast 17 | 18 | import inflect 19 | import sqlalchemy 20 | from sqlalchemy import ( 21 | ARRAY, 22 | Boolean, 23 | CheckConstraint, 24 | Column, 25 | Computed, 26 | Constraint, 27 | DefaultClause, 28 | Enum, 29 | ForeignKey, 30 | 
ForeignKeyConstraint, 31 | Identity, 32 | Index, 33 | MetaData, 34 | PrimaryKeyConstraint, 35 | String, 36 | Table, 37 | Text, 38 | TypeDecorator, 39 | UniqueConstraint, 40 | ) 41 | from sqlalchemy.dialects.postgresql import DOMAIN, JSONB 42 | from sqlalchemy.engine import Connection, Engine 43 | from sqlalchemy.exc import CompileError 44 | from sqlalchemy.sql.elements import TextClause 45 | from sqlalchemy.sql.type_api import UserDefinedType 46 | from sqlalchemy.types import TypeEngine 47 | 48 | from .models import ( 49 | ColumnAttribute, 50 | JoinType, 51 | Model, 52 | ModelClass, 53 | RelationshipAttribute, 54 | RelationshipType, 55 | ) 56 | from .utils import ( 57 | decode_postgresql_sequence, 58 | get_column_names, 59 | get_common_fk_constraints, 60 | get_compiled_expression, 61 | get_constraint_sort_key, 62 | qualified_table_name, 63 | render_callable, 64 | uses_default_name, 65 | ) 66 | 67 | _re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)") 68 | _re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2') 69 | _re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)") 70 | _re_enum_item = re.compile(r"'(.*?)(? bool: 109 | pass 110 | 111 | @abstractmethod 112 | def generate(self) -> str: 113 | """ 114 | Generate the code for the given metadata. 115 | .. note:: May modify the metadata. 
116 | """ 117 | 118 | 119 | @dataclass(eq=False) 120 | class TablesGenerator(CodeGenerator): 121 | valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"} 122 | builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | { 123 | "dataclasses" 124 | } 125 | 126 | def __init__( 127 | self, 128 | metadata: MetaData, 129 | bind: Connection | Engine, 130 | options: Sequence[str], 131 | *, 132 | indentation: str = " ", 133 | ): 134 | super().__init__(metadata, bind, options) 135 | self.indentation: str = indentation 136 | self.imports: dict[str, set[str]] = defaultdict(set) 137 | self.module_imports: set[str] = set() 138 | 139 | @property 140 | def views_supported(self) -> bool: 141 | return True 142 | 143 | def generate_base(self) -> None: 144 | self.base = Base( 145 | literal_imports=[LiteralImport("sqlalchemy", "MetaData")], 146 | declarations=["metadata = MetaData()"], 147 | metadata_ref="metadata", 148 | ) 149 | 150 | def generate(self) -> str: 151 | self.generate_base() 152 | 153 | sections: list[str] = [] 154 | 155 | # Remove unwanted elements from the metadata 156 | for table in list(self.metadata.tables.values()): 157 | if self.should_ignore_table(table): 158 | self.metadata.remove(table) 159 | continue 160 | 161 | if "noindexes" in self.options: 162 | table.indexes.clear() 163 | 164 | if "noconstraints" in self.options: 165 | table.constraints.clear() 166 | 167 | if "nocomments" in self.options: 168 | table.comment = None 169 | 170 | for column in table.columns: 171 | if "nocomments" in self.options: 172 | column.comment = None 173 | 174 | # Use information from column constraints to figure out the intended column 175 | # types 176 | for table in self.metadata.tables.values(): 177 | self.fix_column_types(table) 178 | 179 | # Generate the models 180 | models: list[Model] = self.generate_models() 181 | 182 | # Render module level variables 183 | variables = self.render_module_variables(models) 184 | if variables: 185 | 
sections.append(variables + "\n") 186 | 187 | # Render models 188 | rendered_models = self.render_models(models) 189 | if rendered_models: 190 | sections.append(rendered_models) 191 | 192 | # Render collected imports 193 | groups = self.group_imports() 194 | imports = "\n\n".join("\n".join(line for line in group) for group in groups) 195 | if imports: 196 | sections.insert(0, imports) 197 | 198 | return "\n\n".join(sections) + "\n" 199 | 200 | def collect_imports(self, models: Iterable[Model]) -> None: 201 | for literal_import in self.base.literal_imports: 202 | self.add_literal_import(literal_import.pkgname, literal_import.name) 203 | 204 | for model in models: 205 | self.collect_imports_for_model(model) 206 | 207 | def collect_imports_for_model(self, model: Model) -> None: 208 | if model.__class__ is Model: 209 | self.add_import(Table) 210 | 211 | for column in model.table.c: 212 | self.collect_imports_for_column(column) 213 | 214 | for constraint in model.table.constraints: 215 | self.collect_imports_for_constraint(constraint) 216 | 217 | for index in model.table.indexes: 218 | self.collect_imports_for_constraint(index) 219 | 220 | def collect_imports_for_column(self, column: Column[Any]) -> None: 221 | self.add_import(column.type) 222 | 223 | if isinstance(column.type, ARRAY): 224 | self.add_import(column.type.item_type.__class__) 225 | elif isinstance(column.type, JSONB): 226 | if ( 227 | not isinstance(column.type.astext_type, Text) 228 | or column.type.astext_type.length is not None 229 | ): 230 | self.add_import(column.type.astext_type) 231 | elif isinstance(column.type, DOMAIN): 232 | self.add_import(column.type.data_type.__class__) 233 | 234 | if column.default: 235 | self.add_import(column.default) 236 | 237 | if column.server_default: 238 | if isinstance(column.server_default, (Computed, Identity)): 239 | self.add_import(column.server_default) 240 | elif isinstance(column.server_default, DefaultClause): 241 | self.add_literal_import("sqlalchemy", 
"text") 242 | 243 | def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None: 244 | if isinstance(constraint, Index): 245 | if len(constraint.columns) > 1 or not uses_default_name(constraint): 246 | self.add_literal_import("sqlalchemy", "Index") 247 | elif isinstance(constraint, PrimaryKeyConstraint): 248 | if not uses_default_name(constraint): 249 | self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint") 250 | elif isinstance(constraint, UniqueConstraint): 251 | if len(constraint.columns) > 1 or not uses_default_name(constraint): 252 | self.add_literal_import("sqlalchemy", "UniqueConstraint") 253 | elif isinstance(constraint, ForeignKeyConstraint): 254 | if len(constraint.columns) > 1 or not uses_default_name(constraint): 255 | self.add_literal_import("sqlalchemy", "ForeignKeyConstraint") 256 | else: 257 | self.add_import(ForeignKey) 258 | else: 259 | self.add_import(constraint) 260 | 261 | def add_import(self, obj: Any) -> None: 262 | # Don't store builtin imports 263 | if getattr(obj, "__module__", "builtins") == "builtins": 264 | return 265 | 266 | type_ = type(obj) if not isinstance(obj, type) else obj 267 | pkgname = type_.__module__ 268 | 269 | # The column types have already been adapted towards generic types if possible, 270 | # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to 271 | # use that rather than the generic sqlalchemy type as it might have different 272 | # constructor parameters. 
273 | if pkgname.startswith("sqlalchemy.dialects."): 274 | dialect_pkgname = ".".join(pkgname.split(".")[0:3]) 275 | dialect_pkg = import_module(dialect_pkgname) 276 | 277 | if type_.__name__ in dialect_pkg.__all__: 278 | pkgname = dialect_pkgname 279 | elif type_.__name__ in dir(sqlalchemy): 280 | pkgname = "sqlalchemy" 281 | else: 282 | pkgname = type_.__module__ 283 | 284 | self.add_literal_import(pkgname, type_.__name__) 285 | 286 | def add_literal_import(self, pkgname: str, name: str) -> None: 287 | names = self.imports.setdefault(pkgname, set()) 288 | names.add(name) 289 | 290 | def remove_literal_import(self, pkgname: str, name: str) -> None: 291 | names = self.imports.setdefault(pkgname, set()) 292 | if name in names: 293 | names.remove(name) 294 | 295 | def add_module_import(self, pgkname: str) -> None: 296 | self.module_imports.add(pgkname) 297 | 298 | def group_imports(self) -> list[list[str]]: 299 | future_imports: list[str] = [] 300 | stdlib_imports: list[str] = [] 301 | thirdparty_imports: list[str] = [] 302 | 303 | for package in sorted(self.imports): 304 | imports = ", ".join(sorted(self.imports[package])) 305 | collection = thirdparty_imports 306 | if package == "__future__": 307 | collection = future_imports 308 | elif package in self.builtin_module_names: 309 | collection = stdlib_imports 310 | elif package in sys.modules: 311 | if "site-packages" not in (sys.modules[package].__file__ or ""): 312 | collection = stdlib_imports 313 | 314 | collection.append(f"from {package} import {imports}") 315 | 316 | for module in sorted(self.module_imports): 317 | thirdparty_imports.append(f"import {module}") 318 | 319 | return [ 320 | group 321 | for group in (future_imports, stdlib_imports, thirdparty_imports) 322 | if group 323 | ] 324 | 325 | def generate_models(self) -> list[Model]: 326 | models = [Model(table) for table in self.metadata.sorted_tables] 327 | 328 | # Collect the imports 329 | self.collect_imports(models) 330 | 331 | # Generate names for 
models 332 | global_names = { 333 | name for namespace in self.imports.values() for name in namespace 334 | } 335 | for model in models: 336 | self.generate_model_name(model, global_names) 337 | global_names.add(model.name) 338 | 339 | return models 340 | 341 | def generate_model_name(self, model: Model, global_names: set[str]) -> None: 342 | preferred_name = f"t_{model.table.name}" 343 | model.name = self.find_free_name(preferred_name, global_names) 344 | 345 | def render_module_variables(self, models: list[Model]) -> str: 346 | declarations = self.base.declarations 347 | 348 | if any(not isinstance(model, ModelClass) for model in models): 349 | if self.base.table_metadata_declaration is not None: 350 | declarations.append(self.base.table_metadata_declaration) 351 | 352 | return "\n".join(declarations) 353 | 354 | def render_models(self, models: list[Model]) -> str: 355 | rendered: list[str] = [] 356 | for model in models: 357 | rendered_table = self.render_table(model.table) 358 | rendered.append(f"{model.name} = {rendered_table}") 359 | 360 | return "\n\n".join(rendered) 361 | 362 | def render_table(self, table: Table) -> str: 363 | args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"] 364 | kwargs: dict[str, object] = {} 365 | for column in table.columns: 366 | # Cast is required because of a bug in the SQLAlchemy stubs regarding 367 | # Table.columns 368 | args.append(self.render_column(column, True, is_table=True)) 369 | 370 | for constraint in sorted(table.constraints, key=get_constraint_sort_key): 371 | if uses_default_name(constraint): 372 | if isinstance(constraint, PrimaryKeyConstraint): 373 | continue 374 | elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)): 375 | if len(constraint.columns) == 1: 376 | continue 377 | 378 | args.append(self.render_constraint(constraint)) 379 | 380 | for index in sorted(table.indexes, key=lambda i: cast(str, i.name)): 381 | # One-column indexes should be rendered as index=True on columns 382 | 
if len(index.columns) > 1 or not uses_default_name(index): 383 | args.append(self.render_index(index)) 384 | 385 | if table.schema: 386 | kwargs["schema"] = repr(table.schema) 387 | 388 | table_comment = getattr(table, "comment", None) 389 | if table_comment: 390 | kwargs["comment"] = repr(table.comment) 391 | 392 | return render_callable("Table", *args, kwargs=kwargs, indentation=" ") 393 | 394 | def render_index(self, index: Index) -> str: 395 | extra_args = [repr(col.name) for col in index.columns] 396 | kwargs = {} 397 | if index.unique: 398 | kwargs["unique"] = True 399 | 400 | return render_callable("Index", repr(index.name), *extra_args, kwargs=kwargs) 401 | 402 | # TODO find better solution for is_table 403 | def render_column( 404 | self, column: Column[Any], show_name: bool, is_table: bool = False 405 | ) -> str: 406 | args = [] 407 | kwargs: dict[str, Any] = {} 408 | kwarg = [] 409 | is_sole_pk = column.primary_key and len(column.table.primary_key) == 1 410 | dedicated_fks = [ 411 | c 412 | for c in column.foreign_keys 413 | if c.constraint 414 | and len(c.constraint.columns) == 1 415 | and uses_default_name(c.constraint) 416 | ] 417 | is_unique = any( 418 | isinstance(c, UniqueConstraint) 419 | and set(c.columns) == {column} 420 | and uses_default_name(c) 421 | for c in column.table.constraints 422 | ) 423 | is_unique = is_unique or any( 424 | i.unique and set(i.columns) == {column} and uses_default_name(i) 425 | for i in column.table.indexes 426 | ) 427 | is_primary = ( 428 | any( 429 | isinstance(c, PrimaryKeyConstraint) 430 | and column.name in c.columns 431 | and uses_default_name(c) 432 | for c in column.table.constraints 433 | ) 434 | or column.primary_key 435 | ) 436 | has_index = any( 437 | set(i.columns) == {column} and uses_default_name(i) 438 | for i in column.table.indexes 439 | ) 440 | 441 | if show_name: 442 | args.append(repr(column.name)) 443 | 444 | # Render the column type if there are no foreign keys on it or any of them 445 | # 
points back to itself 446 | if not dedicated_fks or any(fk.column is column for fk in dedicated_fks): 447 | args.append(self.render_column_type(column.type)) 448 | 449 | for fk in dedicated_fks: 450 | args.append(self.render_constraint(fk)) 451 | 452 | if column.default: 453 | args.append(repr(column.default)) 454 | 455 | if column.key != column.name: 456 | kwargs["key"] = column.key 457 | if is_primary: 458 | kwargs["primary_key"] = True 459 | if not column.nullable and not is_sole_pk and is_table: 460 | kwargs["nullable"] = False 461 | 462 | if is_unique: 463 | column.unique = True 464 | kwargs["unique"] = True 465 | if has_index: 466 | column.index = True 467 | kwarg.append("index") 468 | kwargs["index"] = True 469 | 470 | if isinstance(column.server_default, DefaultClause): 471 | kwargs["server_default"] = render_callable( 472 | "text", repr(cast(TextClause, column.server_default.arg).text) 473 | ) 474 | elif isinstance(column.server_default, Computed): 475 | expression = str(column.server_default.sqltext) 476 | 477 | computed_kwargs = {} 478 | if column.server_default.persisted is not None: 479 | computed_kwargs["persisted"] = column.server_default.persisted 480 | 481 | args.append( 482 | render_callable("Computed", repr(expression), kwargs=computed_kwargs) 483 | ) 484 | elif isinstance(column.server_default, Identity): 485 | args.append(repr(column.server_default)) 486 | elif column.server_default: 487 | kwargs["server_default"] = repr(column.server_default) 488 | 489 | comment = getattr(column, "comment", None) 490 | if comment: 491 | kwargs["comment"] = repr(comment) 492 | 493 | return self.render_column_callable(is_table, *args, **kwargs) 494 | 495 | def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str: 496 | if is_table: 497 | self.add_import(Column) 498 | return render_callable("Column", *args, kwargs=kwargs) 499 | else: 500 | return render_callable("mapped_column", *args, kwargs=kwargs) 501 | 502 | def 
render_column_type(self, coltype: object) -> str: 503 | args = [] 504 | kwargs: dict[str, Any] = {} 505 | sig = inspect.signature(coltype.__class__.__init__) 506 | defaults = {param.name: param.default for param in sig.parameters.values()} 507 | missing = object() 508 | use_kwargs = False 509 | for param in list(sig.parameters.values())[1:]: 510 | # Remove annoyances like _warn_on_bytestring 511 | if param.name.startswith("_"): 512 | continue 513 | elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD): 514 | use_kwargs = True 515 | continue 516 | 517 | value = getattr(coltype, param.name, missing) 518 | default = defaults.get(param.name, missing) 519 | if isinstance(value, TextClause): 520 | self.add_literal_import("sqlalchemy", "text") 521 | rendered_value = render_callable("text", repr(value.text)) 522 | else: 523 | rendered_value = repr(value) 524 | 525 | if value is missing or value == default: 526 | use_kwargs = True 527 | elif use_kwargs: 528 | kwargs[param.name] = rendered_value 529 | else: 530 | args.append(rendered_value) 531 | 532 | vararg = next( 533 | ( 534 | param.name 535 | for param in sig.parameters.values() 536 | if param.kind is Parameter.VAR_POSITIONAL 537 | ), 538 | None, 539 | ) 540 | if vararg and hasattr(coltype, vararg): 541 | varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)] 542 | args.extend(varargs_repr) 543 | 544 | # These arguments cannot be autodetected from the Enum initializer 545 | if isinstance(coltype, Enum): 546 | for colname in "name", "schema": 547 | if (value := getattr(coltype, colname)) is not None: 548 | kwargs[colname] = repr(value) 549 | 550 | if isinstance(coltype, JSONB): 551 | # Remove astext_type if it's the default 552 | if ( 553 | isinstance(coltype.astext_type, Text) 554 | and coltype.astext_type.length is None 555 | ): 556 | del kwargs["astext_type"] 557 | 558 | if args or kwargs: 559 | return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs) 560 | else: 561 | return 
coltype.__class__.__name__ 562 | 563 | def render_constraint(self, constraint: Constraint | ForeignKey) -> str: 564 | def add_fk_options(*opts: Any) -> None: 565 | args.extend(repr(opt) for opt in opts) 566 | for attr in "ondelete", "onupdate", "deferrable", "initially", "match": 567 | value = getattr(constraint, attr, None) 568 | if value: 569 | kwargs[attr] = repr(value) 570 | 571 | args: list[str] = [] 572 | kwargs: dict[str, Any] = {} 573 | if isinstance(constraint, ForeignKey): 574 | remote_column = ( 575 | f"{constraint.column.table.fullname}.{constraint.column.name}" 576 | ) 577 | add_fk_options(remote_column) 578 | elif isinstance(constraint, ForeignKeyConstraint): 579 | local_columns = get_column_names(constraint) 580 | remote_columns = [ 581 | f"{fk.column.table.fullname}.{fk.column.name}" 582 | for fk in constraint.elements 583 | ] 584 | add_fk_options(local_columns, remote_columns) 585 | elif isinstance(constraint, CheckConstraint): 586 | args.append(repr(get_compiled_expression(constraint.sqltext, self.bind))) 587 | elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)): 588 | args.extend(repr(col.name) for col in constraint.columns) 589 | else: 590 | raise TypeError( 591 | f"Cannot render constraint of type {constraint.__class__.__name__}" 592 | ) 593 | 594 | if isinstance(constraint, Constraint) and not uses_default_name(constraint): 595 | kwargs["name"] = repr(constraint.name) 596 | 597 | return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs) 598 | 599 | def should_ignore_table(self, table: Table) -> bool: 600 | # Support for Alembic and sqlalchemy-migrate -- never expose the schema version 601 | # tables 602 | return table.name in ("alembic_version", "migrate_version") 603 | 604 | def find_free_name( 605 | self, name: str, global_names: set[str], local_names: Collection[str] = () 606 | ) -> str: 607 | """ 608 | Generate an attribute name that does not clash with other local or global names. 
609 | """ 610 | name = name.strip() 611 | assert name, "Identifier cannot be empty" 612 | name = _re_invalid_identifier.sub("_", name) 613 | if name[0].isdigit(): 614 | name = "_" + name 615 | elif iskeyword(name) or name == "metadata": 616 | name += "_" 617 | 618 | original = name 619 | for i in count(): 620 | if name not in global_names and name not in local_names: 621 | break 622 | 623 | name = original + (str(i) if i else "_") 624 | 625 | return name 626 | 627 | def fix_column_types(self, table: Table) -> None: 628 | """Adjust the reflected column types.""" 629 | # Detect check constraints for boolean and enum columns 630 | for constraint in table.constraints.copy(): 631 | if isinstance(constraint, CheckConstraint): 632 | sqltext = get_compiled_expression(constraint.sqltext, self.bind) 633 | 634 | # Turn any integer-like column with a CheckConstraint like 635 | # "column IN (0, 1)" into a Boolean 636 | match = _re_boolean_check_constraint.match(sqltext) 637 | if match: 638 | colname_match = _re_column_name.match(match.group(1)) 639 | if colname_match: 640 | colname = colname_match.group(3) 641 | table.constraints.remove(constraint) 642 | table.c[colname].type = Boolean() 643 | continue 644 | 645 | # Turn any string-type column with a CheckConstraint like 646 | # "column IN (...)" into an Enum 647 | match = _re_enum_check_constraint.match(sqltext) 648 | if match: 649 | colname_match = _re_column_name.match(match.group(1)) 650 | if colname_match: 651 | colname = colname_match.group(3) 652 | items = match.group(2) 653 | if isinstance(table.c[colname].type, String): 654 | table.constraints.remove(constraint) 655 | if not isinstance(table.c[colname].type, Enum): 656 | options = _re_enum_item.findall(items) 657 | table.c[colname].type = Enum( 658 | *options, native_enum=False 659 | ) 660 | 661 | continue 662 | 663 | for column in table.c: 664 | try: 665 | column.type = self.get_adapted_type(column.type) 666 | except CompileError: 667 | pass 668 | 669 | # PostgreSQL 
specific fix: detect sequences from server_default 670 | if column.server_default and self.bind.dialect.name == "postgresql": 671 | if isinstance(column.server_default, DefaultClause) and isinstance( 672 | column.server_default.arg, TextClause 673 | ): 674 | schema, seqname = decode_postgresql_sequence( 675 | column.server_default.arg 676 | ) 677 | if seqname: 678 | # Add an explicit sequence 679 | if seqname != f"{column.table.name}_{column.name}_seq": 680 | column.default = sqlalchemy.Sequence(seqname, schema=schema) 681 | 682 | column.server_default = None 683 | 684 | def get_adapted_type(self, coltype: Any) -> Any: 685 | compiled_type = coltype.compile(self.bind.engine.dialect) 686 | for supercls in coltype.__class__.__mro__: 687 | if not supercls.__name__.startswith("_") and hasattr( 688 | supercls, "__visit_name__" 689 | ): 690 | # Don't try to adapt UserDefinedType as it's not a proper column type 691 | if supercls is UserDefinedType or issubclass(supercls, TypeDecorator): 692 | return coltype 693 | 694 | # Hack to fix adaptation of the Enum class which is broken since 695 | # SQLAlchemy 1.2 696 | kw = {} 697 | if supercls is Enum: 698 | kw["name"] = coltype.name 699 | if coltype.schema: 700 | kw["schema"] = coltype.schema 701 | 702 | try: 703 | new_coltype = coltype.adapt(supercls) 704 | except TypeError: 705 | # If the adaptation fails, don't try again 706 | break 707 | 708 | for key, value in kw.items(): 709 | setattr(new_coltype, key, value) 710 | 711 | if isinstance(coltype, ARRAY): 712 | new_coltype.item_type = self.get_adapted_type(new_coltype.item_type) 713 | 714 | try: 715 | # If the adapted column type does not render the same as the 716 | # original, don't substitute it 717 | if new_coltype.compile(self.bind.engine.dialect) != compiled_type: 718 | break 719 | except CompileError: 720 | # If the adapted column type can't be compiled, don't substitute it 721 | break 722 | 723 | # Stop on the first valid non-uppercase column type class 724 | coltype 
= new_coltype 725 | if supercls.__name__ != supercls.__name__.upper(): 726 | break 727 | 728 | return coltype 729 | 730 | 731 | class DeclarativeGenerator(TablesGenerator): 732 | valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | { 733 | "use_inflect", 734 | "nojoined", 735 | "nobidi", 736 | } 737 | 738 | def __init__( 739 | self, 740 | metadata: MetaData, 741 | bind: Connection | Engine, 742 | options: Sequence[str], 743 | *, 744 | indentation: str = " ", 745 | base_class_name: str = "Base", 746 | ): 747 | super().__init__(metadata, bind, options, indentation=indentation) 748 | self.base_class_name: str = base_class_name 749 | self.inflect_engine = inflect.engine() 750 | 751 | def generate_base(self) -> None: 752 | self.base = Base( 753 | literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")], 754 | declarations=[ 755 | f"class {self.base_class_name}(DeclarativeBase):", 756 | f"{self.indentation}pass", 757 | ], 758 | metadata_ref=f"{self.base_class_name}.metadata", 759 | ) 760 | 761 | def collect_imports(self, models: Iterable[Model]) -> None: 762 | super().collect_imports(models) 763 | if any(isinstance(model, ModelClass) for model in models): 764 | self.add_literal_import("sqlalchemy.orm", "Mapped") 765 | self.add_literal_import("sqlalchemy.orm", "mapped_column") 766 | 767 | def collect_imports_for_model(self, model: Model) -> None: 768 | super().collect_imports_for_model(model) 769 | if isinstance(model, ModelClass): 770 | if model.relationships: 771 | self.add_literal_import("sqlalchemy.orm", "relationship") 772 | 773 | def generate_models(self) -> list[Model]: 774 | models_by_table_name: dict[str, Model] = {} 775 | 776 | # Pick association tables from the metadata into their own set, don't process 777 | # them normally 778 | links: defaultdict[str, list[Model]] = defaultdict(lambda: []) 779 | for table in self.metadata.sorted_tables: 780 | qualified_name = qualified_table_name(table) 781 | 782 | # Link tables have exactly two 
foreign key constraints and all columns are 783 | # involved in them 784 | fk_constraints = sorted( 785 | table.foreign_key_constraints, key=get_constraint_sort_key 786 | ) 787 | if len(fk_constraints) == 2 and all( 788 | col.foreign_keys for col in table.columns 789 | ): 790 | model = models_by_table_name[qualified_name] = Model(table) 791 | tablename = fk_constraints[0].elements[0].column.table.name 792 | links[tablename].append(model) 793 | continue 794 | 795 | # Only form model classes for tables that have a primary key and are not 796 | # association tables 797 | if not table.primary_key: 798 | models_by_table_name[qualified_name] = Model(table) 799 | else: 800 | model = ModelClass(table) 801 | models_by_table_name[qualified_name] = model 802 | 803 | # Fill in the columns 804 | for column in table.c: 805 | column_attr = ColumnAttribute(model, column) 806 | model.columns.append(column_attr) 807 | 808 | # Add relationships 809 | for model in models_by_table_name.values(): 810 | if isinstance(model, ModelClass): 811 | self.generate_relationships( 812 | model, models_by_table_name, links[model.table.name] 813 | ) 814 | 815 | # Nest inherited classes in their superclasses to ensure proper ordering 816 | if "nojoined" not in self.options: 817 | for model in list(models_by_table_name.values()): 818 | if not isinstance(model, ModelClass): 819 | continue 820 | 821 | pk_column_names = {col.name for col in model.table.primary_key.columns} 822 | for constraint in model.table.foreign_key_constraints: 823 | if set(get_column_names(constraint)) == pk_column_names: 824 | target = models_by_table_name[ 825 | qualified_table_name(constraint.elements[0].column.table) 826 | ] 827 | if isinstance(target, ModelClass): 828 | model.parent_class = target 829 | target.children.append(model) 830 | 831 | # Change base if we only have tables 832 | if not any( 833 | isinstance(model, ModelClass) for model in models_by_table_name.values() 834 | ): 835 | super().generate_base() 836 | 837 | # 
Collect the imports 838 | self.collect_imports(models_by_table_name.values()) 839 | 840 | # Rename models and their attributes that conflict with imports or other 841 | # attributes 842 | global_names = { 843 | name for namespace in self.imports.values() for name in namespace 844 | } 845 | for model in models_by_table_name.values(): 846 | self.generate_model_name(model, global_names) 847 | global_names.add(model.name) 848 | 849 | return list(models_by_table_name.values()) 850 | 851 | def generate_relationships( 852 | self, 853 | source: ModelClass, 854 | models_by_table_name: dict[str, Model], 855 | association_tables: list[Model], 856 | ) -> list[RelationshipAttribute]: 857 | relationships: list[RelationshipAttribute] = [] 858 | reverse_relationship: RelationshipAttribute | None 859 | 860 | # Add many-to-one (and one-to-many) relationships 861 | pk_column_names = {col.name for col in source.table.primary_key.columns} 862 | for constraint in sorted( 863 | source.table.foreign_key_constraints, key=get_constraint_sort_key 864 | ): 865 | target = models_by_table_name[ 866 | qualified_table_name(constraint.elements[0].column.table) 867 | ] 868 | if isinstance(target, ModelClass): 869 | if "nojoined" not in self.options: 870 | if set(get_column_names(constraint)) == pk_column_names: 871 | parent = models_by_table_name[ 872 | qualified_table_name(constraint.elements[0].column.table) 873 | ] 874 | if isinstance(parent, ModelClass): 875 | source.parent_class = parent 876 | parent.children.append(source) 877 | continue 878 | 879 | # Add uselist=False to One-to-One relationships 880 | column_names = get_column_names(constraint) 881 | if any( 882 | isinstance(c, (PrimaryKeyConstraint, UniqueConstraint)) 883 | and {col.name for col in c.columns} == set(column_names) 884 | for c in constraint.table.constraints 885 | ): 886 | r_type = RelationshipType.ONE_TO_ONE 887 | else: 888 | r_type = RelationshipType.MANY_TO_ONE 889 | 890 | relationship = RelationshipAttribute(r_type, 
source, target, constraint) 891 | source.relationships.append(relationship) 892 | 893 | # For self referential relationships, remote_side needs to be set 894 | if source is target: 895 | relationship.remote_side = [ 896 | source.get_column_attribute(col.name) 897 | for col in constraint.referred_table.primary_key 898 | ] 899 | 900 | # If the two tables share more than one foreign key constraint, 901 | # SQLAlchemy needs an explicit primaryjoin to figure out which column(s) 902 | # it needs 903 | common_fk_constraints = get_common_fk_constraints( 904 | source.table, target.table 905 | ) 906 | if len(common_fk_constraints) > 1: 907 | relationship.foreign_keys = [ 908 | source.get_column_attribute(key) 909 | for key in constraint.column_keys 910 | ] 911 | 912 | # Generate the opposite end of the relationship in the target class 913 | if "nobidi" not in self.options: 914 | if r_type is RelationshipType.MANY_TO_ONE: 915 | r_type = RelationshipType.ONE_TO_MANY 916 | 917 | reverse_relationship = RelationshipAttribute( 918 | r_type, 919 | target, 920 | source, 921 | constraint, 922 | foreign_keys=relationship.foreign_keys, 923 | backref=relationship, 924 | ) 925 | relationship.backref = reverse_relationship 926 | target.relationships.append(reverse_relationship) 927 | 928 | # For self referential relationships, remote_side needs to be set 929 | if source is target: 930 | reverse_relationship.remote_side = [ 931 | source.get_column_attribute(colname) 932 | for colname in constraint.column_keys 933 | ] 934 | 935 | # Add many-to-many relationships 936 | for association_table in association_tables: 937 | fk_constraints = sorted( 938 | association_table.table.foreign_key_constraints, 939 | key=get_constraint_sort_key, 940 | ) 941 | target = models_by_table_name[ 942 | qualified_table_name(fk_constraints[1].elements[0].column.table) 943 | ] 944 | if isinstance(target, ModelClass): 945 | relationship = RelationshipAttribute( 946 | RelationshipType.MANY_TO_MANY, 947 | source, 948 
| target, 949 | fk_constraints[1], 950 | association_table, 951 | ) 952 | source.relationships.append(relationship) 953 | 954 | # Generate the opposite end of the relationship in the target class 955 | reverse_relationship = None 956 | if "nobidi" not in self.options: 957 | reverse_relationship = RelationshipAttribute( 958 | RelationshipType.MANY_TO_MANY, 959 | target, 960 | source, 961 | fk_constraints[0], 962 | association_table, 963 | relationship, 964 | ) 965 | relationship.backref = reverse_relationship 966 | target.relationships.append(reverse_relationship) 967 | 968 | # Add a primary/secondary join for self-referential many-to-many 969 | # relationships 970 | if source is target: 971 | both_relationships = [relationship] 972 | reverse_flags = [False, True] 973 | if reverse_relationship: 974 | both_relationships.append(reverse_relationship) 975 | 976 | for relationship, reverse in zip(both_relationships, reverse_flags): 977 | if ( 978 | not relationship.association_table 979 | or not relationship.constraint 980 | ): 981 | continue 982 | 983 | constraints = sorted( 984 | relationship.constraint.table.foreign_key_constraints, 985 | key=get_constraint_sort_key, 986 | reverse=reverse, 987 | ) 988 | pri_pairs = zip( 989 | get_column_names(constraints[0]), constraints[0].elements 990 | ) 991 | sec_pairs = zip( 992 | get_column_names(constraints[1]), constraints[1].elements 993 | ) 994 | relationship.primaryjoin = [ 995 | ( 996 | relationship.source, 997 | elem.column.name, 998 | relationship.association_table, 999 | col, 1000 | ) 1001 | for col, elem in pri_pairs 1002 | ] 1003 | relationship.secondaryjoin = [ 1004 | ( 1005 | relationship.target, 1006 | elem.column.name, 1007 | relationship.association_table, 1008 | col, 1009 | ) 1010 | for col, elem in sec_pairs 1011 | ] 1012 | 1013 | return relationships 1014 | 1015 | def generate_model_name(self, model: Model, global_names: set[str]) -> None: 1016 | if isinstance(model, ModelClass): 1017 | preferred_name = 
_re_invalid_identifier.sub("_", model.table.name) 1018 | preferred_name = "".join( 1019 | part[:1].upper() + part[1:] for part in preferred_name.split("_") 1020 | ) 1021 | if "use_inflect" in self.options: 1022 | singular_name = self.inflect_engine.singular_noun(preferred_name) 1023 | if singular_name: 1024 | preferred_name = singular_name 1025 | 1026 | model.name = self.find_free_name(preferred_name, global_names) 1027 | 1028 | # Fill in the names for column attributes 1029 | local_names: set[str] = set() 1030 | for column_attr in model.columns: 1031 | self.generate_column_attr_name(column_attr, global_names, local_names) 1032 | local_names.add(column_attr.name) 1033 | 1034 | # Fill in the names for relationship attributes 1035 | for relationship in model.relationships: 1036 | self.generate_relationship_name(relationship, global_names, local_names) 1037 | local_names.add(relationship.name) 1038 | else: 1039 | super().generate_model_name(model, global_names) 1040 | 1041 | def generate_column_attr_name( 1042 | self, 1043 | column_attr: ColumnAttribute, 1044 | global_names: set[str], 1045 | local_names: set[str], 1046 | ) -> None: 1047 | column_attr.name = self.find_free_name( 1048 | column_attr.column.name, global_names, local_names 1049 | ) 1050 | 1051 | def generate_relationship_name( 1052 | self, 1053 | relationship: RelationshipAttribute, 1054 | global_names: set[str], 1055 | local_names: set[str], 1056 | ) -> None: 1057 | # Self referential reverse relationships 1058 | preferred_name: str 1059 | if ( 1060 | relationship.type 1061 | in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE) 1062 | and relationship.source is relationship.target 1063 | and relationship.backref 1064 | and relationship.backref.name 1065 | ): 1066 | preferred_name = relationship.backref.name + "_reverse" 1067 | else: 1068 | preferred_name = relationship.target.table.name 1069 | 1070 | # If there's a constraint with a single column that ends with "_id", use the 1071 | # preceding 
part as the relationship name 1072 | if relationship.constraint: 1073 | is_source = relationship.source.table is relationship.constraint.table 1074 | if is_source or relationship.type not in ( 1075 | RelationshipType.ONE_TO_ONE, 1076 | RelationshipType.ONE_TO_MANY, 1077 | ): 1078 | column_names = [c.name for c in relationship.constraint.columns] 1079 | if len(column_names) == 1 and column_names[0].endswith("_id"): 1080 | preferred_name = column_names[0][:-3] 1081 | 1082 | if "use_inflect" in self.options: 1083 | inflected_name: str | Literal[False] 1084 | if relationship.type in ( 1085 | RelationshipType.ONE_TO_MANY, 1086 | RelationshipType.MANY_TO_MANY, 1087 | ): 1088 | if not self.inflect_engine.singular_noun(preferred_name): 1089 | preferred_name = self.inflect_engine.plural_noun(preferred_name) 1090 | else: 1091 | inflected_name = self.inflect_engine.singular_noun(preferred_name) 1092 | if inflected_name: 1093 | preferred_name = inflected_name 1094 | 1095 | relationship.name = self.find_free_name( 1096 | preferred_name, global_names, local_names 1097 | ) 1098 | 1099 | def render_models(self, models: list[Model]) -> str: 1100 | rendered: list[str] = [] 1101 | for model in models: 1102 | if isinstance(model, ModelClass): 1103 | rendered.append(self.render_class(model)) 1104 | else: 1105 | rendered.append(f"{model.name} = {self.render_table(model.table)}") 1106 | 1107 | return "\n\n\n".join(rendered) 1108 | 1109 | def render_class(self, model: ModelClass) -> str: 1110 | sections: list[str] = [] 1111 | 1112 | # Render class variables / special declarations 1113 | class_vars: str = self.render_class_variables(model) 1114 | if class_vars: 1115 | sections.append(class_vars) 1116 | 1117 | # Render column attributes 1118 | rendered_column_attributes: list[str] = [] 1119 | for nullable in (False, True): 1120 | for column_attr in model.columns: 1121 | if column_attr.column.nullable is nullable: 1122 | rendered_column_attributes.append( 1123 | 
self.render_column_attribute(column_attr) 1124 | ) 1125 | 1126 | if rendered_column_attributes: 1127 | sections.append("\n".join(rendered_column_attributes)) 1128 | 1129 | # Render relationship attributes 1130 | rendered_relationship_attributes: list[str] = [ 1131 | self.render_relationship(relationship) 1132 | for relationship in model.relationships 1133 | ] 1134 | 1135 | if rendered_relationship_attributes: 1136 | sections.append("\n".join(rendered_relationship_attributes)) 1137 | 1138 | declaration = self.render_class_declaration(model) 1139 | rendered_sections = "\n\n".join( 1140 | indent(section, self.indentation) for section in sections 1141 | ) 1142 | return f"{declaration}\n{rendered_sections}" 1143 | 1144 | def render_class_declaration(self, model: ModelClass) -> str: 1145 | parent_class_name = ( 1146 | model.parent_class.name if model.parent_class else self.base_class_name 1147 | ) 1148 | return f"class {model.name}({parent_class_name}):" 1149 | 1150 | def render_class_variables(self, model: ModelClass) -> str: 1151 | variables = [f"__tablename__ = {model.table.name!r}"] 1152 | 1153 | # Render constraints and indexes as __table_args__ 1154 | table_args = self.render_table_args(model.table) 1155 | if table_args: 1156 | variables.append(f"__table_args__ = {table_args}") 1157 | 1158 | return "\n".join(variables) 1159 | 1160 | def render_table_args(self, table: Table) -> str: 1161 | args: list[str] = [] 1162 | kwargs: dict[str, str] = {} 1163 | 1164 | # Render constraints 1165 | for constraint in sorted(table.constraints, key=get_constraint_sort_key): 1166 | if uses_default_name(constraint): 1167 | if isinstance(constraint, PrimaryKeyConstraint): 1168 | continue 1169 | if ( 1170 | isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)) 1171 | and len(constraint.columns) == 1 1172 | ): 1173 | continue 1174 | 1175 | args.append(self.render_constraint(constraint)) 1176 | 1177 | # Render indexes 1178 | for index in sorted(table.indexes, key=lambda i: 
cast(str, i.name)): 1179 | if len(index.columns) > 1 or not uses_default_name(index): 1180 | args.append(self.render_index(index)) 1181 | 1182 | if table.schema: 1183 | kwargs["schema"] = table.schema 1184 | 1185 | if table.comment: 1186 | kwargs["comment"] = table.comment 1187 | 1188 | if kwargs: 1189 | formatted_kwargs = pformat(kwargs) 1190 | if not args: 1191 | return formatted_kwargs 1192 | else: 1193 | args.append(formatted_kwargs) 1194 | 1195 | if args: 1196 | rendered_args = f",\n{self.indentation}".join(args) 1197 | if len(args) == 1: 1198 | rendered_args += "," 1199 | 1200 | return f"(\n{self.indentation}{rendered_args}\n)" 1201 | else: 1202 | return "" 1203 | 1204 | def render_column_attribute(self, column_attr: ColumnAttribute) -> str: 1205 | column = column_attr.column 1206 | rendered_column = self.render_column(column, column_attr.name != column.name) 1207 | 1208 | def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]: 1209 | column_type = column.type 1210 | pre: list[str] = [] 1211 | post_size = 0 1212 | if column.nullable: 1213 | self.add_literal_import("typing", "Optional") 1214 | pre.append("Optional[") 1215 | post_size += 1 1216 | 1217 | if isinstance(column_type, ARRAY): 1218 | dim = getattr(column_type, "dimensions", None) or 1 1219 | pre.extend("list[" for _ in range(dim)) 1220 | post_size += dim 1221 | 1222 | column_type = column_type.item_type 1223 | 1224 | return "".join(pre), column_type, "]" * post_size 1225 | 1226 | def render_python_type(column_type: TypeEngine[Any]) -> str: 1227 | python_type = column_type.python_type 1228 | python_type_name = python_type.__name__ 1229 | python_type_module = python_type.__module__ 1230 | if python_type_module == "builtins": 1231 | return python_type_name 1232 | 1233 | try: 1234 | self.add_module_import(python_type_module) 1235 | return f"{python_type_module}.{python_type_name}" 1236 | except NotImplementedError: 1237 | self.add_literal_import("typing", "Any") 1238 | return "Any" 1239 | 1240 | 
pre, col_type, post = get_type_qualifiers() 1241 | column_python_type = f"{pre}{render_python_type(col_type)}{post}" 1242 | return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}" 1243 | 1244 | def render_relationship(self, relationship: RelationshipAttribute) -> str: 1245 | def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str: 1246 | rendered = [] 1247 | for attr in column_attrs: 1248 | if attr.model is relationship.source: 1249 | rendered.append(attr.name) 1250 | else: 1251 | rendered.append(repr(f"{attr.model.name}.{attr.name}")) 1252 | 1253 | return "[" + ", ".join(rendered) + "]" 1254 | 1255 | def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str: 1256 | rendered = [] 1257 | render_as_string = False 1258 | # Assume that column_attrs are all in relationship.source or none 1259 | for attr in column_attrs: 1260 | if attr.model is relationship.source: 1261 | rendered.append(attr.name) 1262 | else: 1263 | rendered.append(f"{attr.model.name}.{attr.name}") 1264 | render_as_string = True 1265 | 1266 | if render_as_string: 1267 | return "'[" + ", ".join(rendered) + "]'" 1268 | else: 1269 | return "[" + ", ".join(rendered) + "]" 1270 | 1271 | def render_join(terms: list[JoinType]) -> str: 1272 | rendered_joins = [] 1273 | for source, source_col, target, target_col in terms: 1274 | rendered = f"lambda: {source.name}.{source_col} == {target.name}." 1275 | if target.__class__ is Model: 1276 | rendered += "c." 
1277 | 1278 | rendered += str(target_col) 1279 | rendered_joins.append(rendered) 1280 | 1281 | if len(rendered_joins) > 1: 1282 | rendered = ", ".join(rendered_joins) 1283 | return f"and_({rendered})" 1284 | else: 1285 | return rendered_joins[0] 1286 | 1287 | # Render keyword arguments 1288 | kwargs: dict[str, Any] = {} 1289 | if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint: 1290 | if relationship.constraint.referred_table is relationship.source.table: 1291 | kwargs["uselist"] = False 1292 | 1293 | # Add the "secondary" keyword for many-to-many relationships 1294 | if relationship.association_table: 1295 | table_ref = relationship.association_table.table.name 1296 | if relationship.association_table.schema: 1297 | table_ref = f"{relationship.association_table.schema}.{table_ref}" 1298 | 1299 | kwargs["secondary"] = repr(table_ref) 1300 | 1301 | if relationship.remote_side: 1302 | kwargs["remote_side"] = render_column_attrs(relationship.remote_side) 1303 | 1304 | if relationship.foreign_keys: 1305 | kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys) 1306 | 1307 | if relationship.primaryjoin: 1308 | kwargs["primaryjoin"] = render_join(relationship.primaryjoin) 1309 | 1310 | if relationship.secondaryjoin: 1311 | kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin) 1312 | 1313 | if relationship.backref: 1314 | kwargs["back_populates"] = repr(relationship.backref.name) 1315 | 1316 | rendered_relationship = render_callable( 1317 | "relationship", repr(relationship.target.name), kwargs=kwargs 1318 | ) 1319 | 1320 | relationship_type: str 1321 | if relationship.type == RelationshipType.ONE_TO_MANY: 1322 | relationship_type = f"list['{relationship.target.name}']" 1323 | elif relationship.type in ( 1324 | RelationshipType.ONE_TO_ONE, 1325 | RelationshipType.MANY_TO_ONE, 1326 | ): 1327 | relationship_type = f"'{relationship.target.name}'" 1328 | if relationship.constraint and any( 1329 | col.nullable for col in 
relationship.constraint.columns 1330 | ): 1331 | self.add_literal_import("typing", "Optional") 1332 | relationship_type = f"Optional[{relationship_type}]" 1333 | elif relationship.type == RelationshipType.MANY_TO_MANY: 1334 | relationship_type = f"list['{relationship.target.name}']" 1335 | else: 1336 | self.add_literal_import("typing", "Any") 1337 | relationship_type = "Any" 1338 | 1339 | return ( 1340 | f"{relationship.name}: Mapped[{relationship_type}] " 1341 | f"= {rendered_relationship}" 1342 | ) 1343 | 1344 | 1345 | class DataclassGenerator(DeclarativeGenerator): 1346 | def __init__( 1347 | self, 1348 | metadata: MetaData, 1349 | bind: Connection | Engine, 1350 | options: Sequence[str], 1351 | *, 1352 | indentation: str = " ", 1353 | base_class_name: str = "Base", 1354 | quote_annotations: bool = False, 1355 | metadata_key: str = "sa", 1356 | ): 1357 | super().__init__( 1358 | metadata, 1359 | bind, 1360 | options, 1361 | indentation=indentation, 1362 | base_class_name=base_class_name, 1363 | ) 1364 | self.metadata_key: str = metadata_key 1365 | self.quote_annotations: bool = quote_annotations 1366 | 1367 | def generate_base(self) -> None: 1368 | self.base = Base( 1369 | literal_imports=[ 1370 | LiteralImport("sqlalchemy.orm", "DeclarativeBase"), 1371 | LiteralImport("sqlalchemy.orm", "MappedAsDataclass"), 1372 | ], 1373 | declarations=[ 1374 | (f"class {self.base_class_name}(MappedAsDataclass, DeclarativeBase):"), 1375 | f"{self.indentation}pass", 1376 | ], 1377 | metadata_ref=f"{self.base_class_name}.metadata", 1378 | ) 1379 | 1380 | 1381 | class SQLModelGenerator(DeclarativeGenerator): 1382 | def __init__( 1383 | self, 1384 | metadata: MetaData, 1385 | bind: Connection | Engine, 1386 | options: Sequence[str], 1387 | *, 1388 | indentation: str = " ", 1389 | base_class_name: str = "SQLModel", 1390 | ): 1391 | super().__init__( 1392 | metadata, 1393 | bind, 1394 | options, 1395 | indentation=indentation, 1396 | base_class_name=base_class_name, 1397 | ) 1398 | 
1399 | @property 1400 | def views_supported(self) -> bool: 1401 | return False 1402 | 1403 | def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str: 1404 | self.add_import(Column) 1405 | return render_callable("Column", *args, kwargs=kwargs) 1406 | 1407 | def generate_base(self) -> None: 1408 | self.base = Base( 1409 | literal_imports=[], 1410 | declarations=[], 1411 | metadata_ref="", 1412 | ) 1413 | 1414 | def collect_imports(self, models: Iterable[Model]) -> None: 1415 | super(DeclarativeGenerator, self).collect_imports(models) 1416 | if any(isinstance(model, ModelClass) for model in models): 1417 | self.remove_literal_import("sqlalchemy", "MetaData") 1418 | self.add_literal_import("sqlmodel", "SQLModel") 1419 | self.add_literal_import("sqlmodel", "Field") 1420 | 1421 | def collect_imports_for_model(self, model: Model) -> None: 1422 | super(DeclarativeGenerator, self).collect_imports_for_model(model) 1423 | if isinstance(model, ModelClass): 1424 | for column_attr in model.columns: 1425 | if column_attr.column.nullable: 1426 | self.add_literal_import("typing", "Optional") 1427 | break 1428 | 1429 | if model.relationships: 1430 | self.add_literal_import("sqlmodel", "Relationship") 1431 | 1432 | def collect_imports_for_column(self, column: Column[Any]) -> None: 1433 | super().collect_imports_for_column(column) 1434 | try: 1435 | python_type = column.type.python_type 1436 | except NotImplementedError: 1437 | self.add_literal_import("typing", "Any") 1438 | else: 1439 | self.add_import(python_type) 1440 | 1441 | def render_module_variables(self, models: list[Model]) -> str: 1442 | declarations: list[str] = [] 1443 | if any(not isinstance(model, ModelClass) for model in models): 1444 | if self.base.table_metadata_declaration is not None: 1445 | declarations.append(self.base.table_metadata_declaration) 1446 | 1447 | return "\n".join(declarations) 1448 | 1449 | def render_class_declaration(self, model: ModelClass) -> str: 1450 | if 
model.parent_class: 1451 | parent = model.parent_class.name 1452 | else: 1453 | parent = self.base_class_name 1454 | 1455 | superclass_part = f"({parent}, table=True)" 1456 | return f"class {model.name}{superclass_part}:" 1457 | 1458 | def render_class_variables(self, model: ModelClass) -> str: 1459 | variables = [] 1460 | 1461 | if model.table.name != model.name.lower(): 1462 | variables.append(f"__tablename__ = {model.table.name!r}") 1463 | 1464 | # Render constraints and indexes as __table_args__ 1465 | table_args = self.render_table_args(model.table) 1466 | if table_args: 1467 | variables.append(f"__table_args__ = {table_args}") 1468 | 1469 | return "\n".join(variables) 1470 | 1471 | def render_column_attribute(self, column_attr: ColumnAttribute) -> str: 1472 | column = column_attr.column 1473 | try: 1474 | python_type = column.type.python_type 1475 | except NotImplementedError: 1476 | python_type_name = "Any" 1477 | else: 1478 | python_type_name = python_type.__name__ 1479 | 1480 | kwargs: dict[str, Any] = {} 1481 | if ( 1482 | column.autoincrement and column.name in column.table.primary_key 1483 | ) or column.nullable: 1484 | self.add_literal_import("typing", "Optional") 1485 | kwargs["default"] = None 1486 | python_type_name = f"Optional[{python_type_name}]" 1487 | 1488 | rendered_column = self.render_column(column, True) 1489 | kwargs["sa_column"] = f"{rendered_column}" 1490 | rendered_field = render_callable("Field", kwargs=kwargs) 1491 | return f"{column_attr.name}: {python_type_name} = {rendered_field}" 1492 | 1493 | def render_relationship(self, relationship: RelationshipAttribute) -> str: 1494 | rendered = super().render_relationship(relationship).partition(" = ")[2] 1495 | args = self.render_relationship_args(rendered) 1496 | kwargs: dict[str, Any] = {} 1497 | annotation = repr(relationship.target.name) 1498 | 1499 | if relationship.type in ( 1500 | RelationshipType.ONE_TO_MANY, 1501 | RelationshipType.MANY_TO_MANY, 1502 | ): 1503 | annotation = 
f"list[{annotation}]" 1504 | else: 1505 | self.add_literal_import("typing", "Optional") 1506 | annotation = f"Optional[{annotation}]" 1507 | 1508 | rendered_field = render_callable("Relationship", *args, kwargs=kwargs) 1509 | return f"{relationship.name}: {annotation} = {rendered_field}" 1510 | 1511 | def render_relationship_args(self, arguments: str) -> list[str]: 1512 | argument_list = arguments.split(",") 1513 | # delete ')' and ' ' from args 1514 | argument_list[-1] = argument_list[-1][:-1] 1515 | argument_list = [argument[1:] for argument in argument_list] 1516 | 1517 | rendered_args: list[str] = [] 1518 | for arg in argument_list: 1519 | if "back_populates" in arg: 1520 | rendered_args.append(arg) 1521 | if "uselist=False" in arg: 1522 | rendered_args.append("sa_relationship_kwargs={'uselist': False}") 1523 | 1524 | return rendered_args 1525 | -------------------------------------------------------------------------------- /src/sqlacodegen/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass, field 4 | from enum import Enum, auto 5 | from typing import Any, Union 6 | 7 | from sqlalchemy.sql.schema import Column, ForeignKeyConstraint, Table 8 | 9 | 10 | @dataclass 11 | class Model: 12 | table: Table 13 | name: str = field(init=False, default="") 14 | 15 | @property 16 | def schema(self) -> str | None: 17 | return self.table.schema 18 | 19 | 20 | @dataclass 21 | class ModelClass(Model): 22 | columns: list[ColumnAttribute] = field(default_factory=list) 23 | relationships: list[RelationshipAttribute] = field(default_factory=list) 24 | parent_class: ModelClass | None = None 25 | children: list[ModelClass] = field(default_factory=list) 26 | 27 | def get_column_attribute(self, column_name: str) -> ColumnAttribute: 28 | for column in self.columns: 29 | if column.column.name == column_name: 30 | return column 31 | 32 | raise LookupError(f"Cannot find 
column attribute for {column_name!r}") 33 | 34 | 35 | class RelationshipType(Enum): 36 | ONE_TO_ONE = auto() 37 | ONE_TO_MANY = auto() 38 | MANY_TO_ONE = auto() 39 | MANY_TO_MANY = auto() 40 | 41 | 42 | @dataclass 43 | class ColumnAttribute: 44 | model: ModelClass 45 | column: Column[Any] 46 | name: str = field(init=False, default="") 47 | 48 | def __repr__(self) -> str: 49 | return f"{self.__class__.__name__}(name={self.name!r}, type={self.column.type})" 50 | 51 | def __str__(self) -> str: 52 | return self.name 53 | 54 | 55 | JoinType = tuple[Model, Union[ColumnAttribute, str], Model, Union[ColumnAttribute, str]] 56 | 57 | 58 | @dataclass 59 | class RelationshipAttribute: 60 | type: RelationshipType 61 | source: ModelClass 62 | target: ModelClass 63 | constraint: ForeignKeyConstraint | None = None 64 | association_table: Model | None = None 65 | backref: RelationshipAttribute | None = None 66 | remote_side: list[ColumnAttribute] = field(default_factory=list) 67 | foreign_keys: list[ColumnAttribute] = field(default_factory=list) 68 | primaryjoin: list[JoinType] = field(default_factory=list) 69 | secondaryjoin: list[JoinType] = field(default_factory=list) 70 | name: str = field(init=False, default="") 71 | 72 | def __repr__(self) -> str: 73 | return ( 74 | f"{self.__class__.__name__}(name={self.name!r}, type={self.type}, " 75 | f"target={self.target.name})" 76 | ) 77 | 78 | def __str__(self) -> str: 79 | return self.name 80 | -------------------------------------------------------------------------------- /src/sqlacodegen/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agronholm/sqlacodegen/84dcc39d5cef1a840234624603857ab72ec333a0/src/sqlacodegen/py.typed -------------------------------------------------------------------------------- /src/sqlacodegen/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 
import re 4 | from collections.abc import Mapping 5 | from typing import Any, Literal, cast 6 | 7 | from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint 8 | from sqlalchemy.engine import Connection, Engine 9 | from sqlalchemy.sql import ClauseElement 10 | from sqlalchemy.sql.elements import TextClause 11 | from sqlalchemy.sql.schema import ( 12 | CheckConstraint, 13 | ColumnCollectionConstraint, 14 | Constraint, 15 | ForeignKeyConstraint, 16 | Index, 17 | Table, 18 | ) 19 | 20 | _re_postgresql_nextval_sequence = re.compile(r"nextval\('(.+)'::regclass\)") 21 | _re_postgresql_sequence_delimiter = re.compile(r'(.*?)([."]|$)') 22 | 23 | 24 | def get_column_names(constraint: ColumnCollectionConstraint) -> list[str]: 25 | return list(constraint.columns.keys()) 26 | 27 | 28 | def get_constraint_sort_key(constraint: Constraint) -> str: 29 | if isinstance(constraint, CheckConstraint): 30 | return f"C{constraint.sqltext}" 31 | elif isinstance(constraint, ColumnCollectionConstraint): 32 | return constraint.__class__.__name__[0] + repr(get_column_names(constraint)) 33 | else: 34 | return str(constraint) 35 | 36 | 37 | def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str: 38 | """Return the statement in a form where any placeholders have been filled in.""" 39 | return str(statement.compile(bind, compile_kwargs={"literal_binds": True})) 40 | 41 | 42 | def get_common_fk_constraints( 43 | table1: Table, table2: Table 44 | ) -> set[ForeignKeyConstraint]: 45 | """ 46 | Return a set of foreign key constraints the two tables have against each other. 
47 | 48 | """ 49 | c1 = { 50 | c 51 | for c in table1.constraints 52 | if isinstance(c, ForeignKeyConstraint) and c.elements[0].column.table == table2 53 | } 54 | c2 = { 55 | c 56 | for c in table2.constraints 57 | if isinstance(c, ForeignKeyConstraint) and c.elements[0].column.table == table1 58 | } 59 | return c1.union(c2) 60 | 61 | 62 | def uses_default_name(constraint: Constraint | Index) -> bool: 63 | if not constraint.name or constraint.table is None: 64 | return True 65 | 66 | table = constraint.table 67 | values: dict[str, Any] = { 68 | "table_name": table.name, 69 | "constraint_name": constraint.name, 70 | } 71 | if isinstance(constraint, (Index, ColumnCollectionConstraint)): 72 | values.update( 73 | { 74 | "column_0N_name": "".join(col.name for col in constraint.columns), 75 | "column_0_N_name": "_".join(col.name for col in constraint.columns), 76 | "column_0N_label": "".join( 77 | col.label(col.name).name for col in constraint.columns 78 | ), 79 | "column_0_N_label": "_".join( 80 | col.label(col.name).name for col in constraint.columns 81 | ), 82 | "column_0N_key": "".join( 83 | col.key for col in constraint.columns if col.key 84 | ), 85 | "column_0_N_key": "_".join( 86 | col.key for col in constraint.columns if col.key 87 | ), 88 | } 89 | ) 90 | if constraint.columns: 91 | columns = constraint.columns.values() 92 | values.update( 93 | { 94 | "column_0_name": columns[0].name, 95 | "column_0_label": columns[0].label(columns[0].name).name, 96 | "column_0_key": columns[0].key, 97 | } 98 | ) 99 | 100 | key: Literal["fk", "pk", "ix", "ck", "uq"] 101 | if isinstance(constraint, Index): 102 | key = "ix" 103 | elif isinstance(constraint, CheckConstraint): 104 | key = "ck" 105 | elif isinstance(constraint, UniqueConstraint): 106 | key = "uq" 107 | elif isinstance(constraint, PrimaryKeyConstraint): 108 | key = "pk" 109 | elif isinstance(constraint, ForeignKeyConstraint): 110 | key = "fk" 111 | values.update( 112 | { 113 | "referred_table_name": 
constraint.referred_table, 114 | "referred_column_0_name": constraint.elements[0].column.name, 115 | "referred_column_0N_name": "".join( 116 | fk.column.name for fk in constraint.elements 117 | ), 118 | "referred_column_0_N_name": "_".join( 119 | fk.column.name for fk in constraint.elements 120 | ), 121 | "referred_column_0_label": constraint.elements[0] 122 | .column.label(constraint.elements[0].column.name) 123 | .name, 124 | "referred_fk.column_0N_label": "".join( 125 | fk.column.label(fk.column.name).name for fk in constraint.elements 126 | ), 127 | "referred_fk.column_0_N_label": "_".join( 128 | fk.column.label(fk.column.name).name for fk in constraint.elements 129 | ), 130 | "referred_fk.column_0_key": constraint.elements[0].column.key, 131 | "referred_fk.column_0N_key": "".join( 132 | fk.column.key for fk in constraint.elements if fk.column.key 133 | ), 134 | "referred_fk.column_0_N_key": "_".join( 135 | fk.column.key for fk in constraint.elements if fk.column.key 136 | ), 137 | } 138 | ) 139 | else: 140 | raise TypeError(f"Unknown constraint type: {constraint.__class__.__qualname__}") 141 | 142 | try: 143 | convention = cast( 144 | Mapping[str, str], 145 | table.metadata.naming_convention, 146 | )[key] 147 | return constraint.name == (convention % values) 148 | except KeyError: 149 | return False 150 | 151 | 152 | def render_callable( 153 | name: str, 154 | *args: object, 155 | kwargs: Mapping[str, object] | None = None, 156 | indentation: str = "", 157 | ) -> str: 158 | """ 159 | Render a function call. 
160 | 161 | :param name: name of the callable 162 | :param args: positional arguments 163 | :param kwargs: keyword arguments 164 | :param indentation: if given, each argument will be rendered on its own line with 165 | this value used as the indentation 166 | 167 | """ 168 | if kwargs: 169 | args += tuple(f"{key}={value}" for key, value in kwargs.items()) 170 | 171 | if indentation: 172 | prefix = f"\n{indentation}" 173 | suffix = "\n" 174 | delimiter = f",\n{indentation}" 175 | else: 176 | prefix = suffix = "" 177 | delimiter = ", " 178 | 179 | rendered_args = delimiter.join(str(arg) for arg in args) 180 | return f"{name}({prefix}{rendered_args}{suffix})" 181 | 182 | 183 | def qualified_table_name(table: Table) -> str: 184 | if table.schema: 185 | return f"{table.schema}.{table.name}" 186 | else: 187 | return str(table.name) 188 | 189 | 190 | def decode_postgresql_sequence(clause: TextClause) -> tuple[str | None, str | None]: 191 | match = _re_postgresql_nextval_sequence.match(clause.text) 192 | if not match: 193 | return None, None 194 | 195 | schema: str | None = None 196 | sequence: str = "" 197 | in_quotes = False 198 | for match in _re_postgresql_sequence_delimiter.finditer(match.group(1)): 199 | sequence += match.group(1) 200 | if match.group(2) == '"': 201 | in_quotes = not in_quotes 202 | elif match.group(2) == ".": 203 | if in_quotes: 204 | sequence += "." 
205 | else: 206 | schema, sequence = sequence, "" 207 | 208 | return schema, sequence 209 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agronholm/sqlacodegen/84dcc39d5cef1a840234624603857ab72ec333a0/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from textwrap import dedent 2 | 3 | import pytest 4 | from pytest import FixtureRequest 5 | from sqlalchemy.engine import Engine, create_engine 6 | from sqlalchemy.orm import clear_mappers, configure_mappers 7 | from sqlalchemy.schema import MetaData 8 | 9 | 10 | @pytest.fixture 11 | def engine(request: FixtureRequest) -> Engine: 12 | dialect = getattr(request, "param", None) 13 | if dialect == "postgresql": 14 | return create_engine("postgresql+psycopg:///testdb") 15 | elif dialect == "mysql": 16 | return create_engine("mysql+mysqlconnector://testdb") 17 | else: 18 | return create_engine("sqlite:///:memory:") 19 | 20 | 21 | @pytest.fixture 22 | def metadata() -> MetaData: 23 | return MetaData() 24 | 25 | 26 | def validate_code(generated_code: str, expected_code: str) -> None: 27 | expected_code = dedent(expected_code) 28 | assert generated_code == expected_code 29 | try: 30 | exec(generated_code, {}) 31 | configure_mappers() 32 | finally: 33 | clear_mappers() 34 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sqlite3 4 | import subprocess 5 | import sys 6 | from importlib.metadata import version 7 | from pathlib import Path 8 | 9 | import pytest 10 | 11 | future_imports = "from __future__ import annotations\n\n" 
# This part of the dump holds the rest of tests/test_cli.py, all of
# tests/test_generator_dataclass.py and the import header of
# tests/test_generator_declarative.py.
# - test_cli.py: the db_path fixture creates a SQLite file containing a single table
#   `foo`. Each CLI test runs the `sqlacodegen` executable through subprocess.run
#   with check=True and compares the --outfile contents with an exact expected
#   module. test_cli_invalid_engine_arg expects CalledProcessError for an unknown
#   connect_args key and checks the stderr wording that matches the running Python
#   version. test_main checks that `python -m sqlacodegen --version` prints the
#   installed package version.
# - test_generator_dataclass.py: the `generator` fixture builds a DataclassGenerator
#   whose options come from request.param (default []). Each test declares tables on
#   generator.metadata and passes the generated source and the expected source to
#   conftest.validate_code, which dedents the expectation before an exact comparison.
# The expected-source strings are whitespace-exact.
12 | 13 | 14 | @pytest.fixture 15 | def db_path(tmp_path: Path) -> Path: 16 | path = tmp_path / "test.db" 17 | with sqlite3.connect(str(path)) as conn: 18 | cursor = conn.cursor() 19 | cursor.execute( 20 | "CREATE TABLE foo (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL)" 21 | ) 22 | 23 | return path 24 | 25 | 26 | def test_cli_tables(db_path: Path, tmp_path: Path) -> None: 27 | output_path = tmp_path / "outfile" 28 | subprocess.run( 29 | [ 30 | "sqlacodegen", 31 | f"sqlite:///{db_path}", 32 | "--generator", 33 | "tables", 34 | "--outfile", 35 | str(output_path), 36 | ], 37 | check=True, 38 | ) 39 | 40 | assert ( 41 | output_path.read_text() 42 | == """\ 43 | from sqlalchemy import Column, Integer, MetaData, Table, Text 44 | 45 | metadata = MetaData() 46 | 47 | 48 | t_foo = Table( 49 | 'foo', metadata, 50 | Column('id', Integer, primary_key=True), 51 | Column('name', Text, nullable=False) 52 | ) 53 | """ 54 | ) 55 | 56 | 57 | def test_cli_declarative(db_path: Path, tmp_path: Path) -> None: 58 | output_path = tmp_path / "outfile" 59 | subprocess.run( 60 | [ 61 | "sqlacodegen", 62 | f"sqlite:///{db_path}", 63 | "--generator", 64 | "declarative", 65 | "--outfile", 66 | str(output_path), 67 | ], 68 | check=True, 69 | ) 70 | 71 | assert ( 72 | output_path.read_text() 73 | == """\ 74 | from sqlalchemy import Integer, Text 75 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 76 | 77 | class Base(DeclarativeBase): 78 | pass 79 | 80 | 81 | class Foo(Base): 82 | __tablename__ = 'foo' 83 | 84 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 85 | name: Mapped[str] = mapped_column(Text) 86 | """ 87 | ) 88 | 89 | 90 | def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None: 91 | output_path = tmp_path / "outfile" 92 | subprocess.run( 93 | [ 94 | "sqlacodegen", 95 | f"sqlite:///{db_path}", 96 | "--generator", 97 | "dataclasses", 98 | "--outfile", 99 | str(output_path), 100 | ], 101 | check=True, 102 | ) 103 | 104 | assert ( 105 |
output_path.read_text() 106 | == """\ 107 | from sqlalchemy import Integer, Text 108 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column 109 | 110 | class Base(MappedAsDataclass, DeclarativeBase): 111 | pass 112 | 113 | 114 | class Foo(Base): 115 | __tablename__ = 'foo' 116 | 117 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 118 | name: Mapped[str] = mapped_column(Text) 119 | """ 120 | ) 121 | 122 | 123 | def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None: 124 | output_path = tmp_path / "outfile" 125 | subprocess.run( 126 | [ 127 | "sqlacodegen", 128 | f"sqlite:///{db_path}", 129 | "--generator", 130 | "sqlmodels", 131 | "--outfile", 132 | str(output_path), 133 | ], 134 | check=True, 135 | ) 136 | 137 | assert ( 138 | output_path.read_text() 139 | == """\ 140 | from typing import Optional 141 | 142 | from sqlalchemy import Column, Integer, Text 143 | from sqlmodel import Field, SQLModel 144 | 145 | class Foo(SQLModel, table=True): 146 | id: Optional[int] = Field(default=None, sa_column=Column('id', Integer, \ 147 | primary_key=True)) 148 | name: str = Field(sa_column=Column('name', Text)) 149 | """ 150 | ) 151 | 152 | 153 | def test_cli_engine_arg(db_path: Path, tmp_path: Path) -> None: 154 | output_path = tmp_path / "outfile" 155 | subprocess.run( 156 | [ 157 | "sqlacodegen", 158 | f"sqlite:///{db_path}", 159 | "--generator", 160 | "tables", 161 | "--engine-arg", 162 | 'connect_args={"timeout": 10}', 163 | "--outfile", 164 | str(output_path), 165 | ], 166 | check=True, 167 | ) 168 | 169 | assert ( 170 | output_path.read_text() 171 | == """\ 172 | from sqlalchemy import Column, Integer, MetaData, Table, Text 173 | 174 | metadata = MetaData() 175 | 176 | 177 | t_foo = Table( 178 | 'foo', metadata, 179 | Column('id', Integer, primary_key=True), 180 | Column('name', Text, nullable=False) 181 | ) 182 | """ 183 | ) 184 | 185 | 186 | def test_cli_invalid_engine_arg(db_path: Path, tmp_path: Path) -> None: 187 |
output_path = tmp_path / "outfile" 188 | 189 | # Expect exception: 190 | # TypeError: 'this_arg_does_not_exist' is an invalid keyword argument for Connection() 191 | with pytest.raises(subprocess.CalledProcessError) as exc_info: 192 | subprocess.run( 193 | [ 194 | "sqlacodegen", 195 | f"sqlite:///{db_path}", 196 | "--generator", 197 | "tables", 198 | "--engine-arg", 199 | 'connect_args={"this_arg_does_not_exist": 10}', 200 | "--outfile", 201 | str(output_path), 202 | ], 203 | check=True, 204 | capture_output=True, 205 | ) 206 | 207 | if sys.version_info < (3, 13): 208 | assert ( 209 | "'this_arg_does_not_exist' is an invalid keyword argument" 210 | in exc_info.value.stderr.decode() 211 | ) 212 | else: 213 | assert ( 214 | "got an unexpected keyword argument 'this_arg_does_not_exist'" 215 | in exc_info.value.stderr.decode() 216 | ) 217 | 218 | 219 | def test_main() -> None: 220 | expected_version = version("sqlacodegen") 221 | completed = subprocess.run( 222 | [sys.executable, "-m", "sqlacodegen", "--version"], 223 | stdout=subprocess.PIPE, 224 | check=True, 225 | ) 226 | assert completed.stdout.decode().strip() == expected_version 227 | -------------------------------------------------------------------------------- /tests/test_generator_dataclass.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | from _pytest.fixtures import FixtureRequest 5 | from sqlalchemy.dialects.postgresql import UUID 6 | from sqlalchemy.engine import Engine 7 | from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table 8 | from sqlalchemy.sql.expression import text 9 | from sqlalchemy.types import INTEGER, VARCHAR 10 | 11 | from sqlacodegen.generators import CodeGenerator, DataclassGenerator 12 | 13 | from .conftest import validate_code 14 | 15 | 16 | @pytest.fixture 17 | def generator( 18 | request: FixtureRequest, metadata: MetaData, engine: Engine 19 | ) -> CodeGenerator:
20 | options = getattr(request, "param", []) 21 | return DataclassGenerator(metadata, engine, options) 22 | 23 | 24 | def test_basic_class(generator: CodeGenerator) -> None: 25 | Table( 26 | "simple", 27 | generator.metadata, 28 | Column("id", INTEGER, primary_key=True), 29 | Column("name", VARCHAR(20)), 30 | ) 31 | 32 | validate_code( 33 | generator.generate(), 34 | """\ 35 | from typing import Optional 36 | 37 | from sqlalchemy import Integer, String 38 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ 39 | mapped_column 40 | 41 | class Base(MappedAsDataclass, DeclarativeBase): 42 | pass 43 | 44 | 45 | class Simple(Base): 46 | __tablename__ = 'simple' 47 | 48 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 49 | name: Mapped[Optional[str]] = mapped_column(String(20)) 50 | """, 51 | ) 52 | 53 | 54 | def test_mandatory_field_last(generator: CodeGenerator) -> None: 55 | Table( 56 | "simple", 57 | generator.metadata, 58 | Column("id", INTEGER, primary_key=True), 59 | Column("name", VARCHAR(20), server_default=text("foo")), 60 | Column("age", INTEGER, nullable=False), 61 | ) 62 | 63 | validate_code( 64 | generator.generate(), 65 | """\ 66 | from typing import Optional 67 | 68 | from sqlalchemy import Integer, String, text 69 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ 70 | mapped_column 71 | 72 | class Base(MappedAsDataclass, DeclarativeBase): 73 | pass 74 | 75 | 76 | class Simple(Base): 77 | __tablename__ = 'simple' 78 | 79 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 80 | age: Mapped[int] = mapped_column(Integer) 81 | name: Mapped[Optional[str]] = mapped_column(String(20), \ 82 | server_default=text('foo')) 83 | """, 84 | ) 85 | 86 | 87 | def test_onetomany_optional(generator: CodeGenerator) -> None: 88 | Table( 89 | "simple_items", 90 | generator.metadata, 91 | Column("id", INTEGER, primary_key=True), 92 | Column("container_id", INTEGER), 93 | ForeignKeyConstraint(["container_id"],
["simple_containers.id"]), 94 | ) 95 | Table( 96 | "simple_containers", 97 | generator.metadata, 98 | Column("id", INTEGER, primary_key=True), 99 | ) 100 | 101 | validate_code( 102 | generator.generate(), 103 | """\ 104 | from typing import Optional 105 | 106 | from sqlalchemy import ForeignKey, Integer 107 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ 108 | mapped_column, relationship 109 | 110 | class Base(MappedAsDataclass, DeclarativeBase): 111 | pass 112 | 113 | 114 | class SimpleContainers(Base): 115 | __tablename__ = 'simple_containers' 116 | 117 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 118 | 119 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 120 | back_populates='container') 121 | 122 | 123 | class SimpleItems(Base): 124 | __tablename__ = 'simple_items' 125 | 126 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 127 | container_id: Mapped[Optional[int]] = \ 128 | mapped_column(ForeignKey('simple_containers.id')) 129 | 130 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 131 | back_populates='simple_items') 132 | """, 133 | ) 134 | 135 | 136 | def test_manytomany(generator: CodeGenerator) -> None: 137 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True)) 138 | Table( 139 | "simple_containers", 140 | generator.metadata, 141 | Column("id", INTEGER, primary_key=True), 142 | ) 143 | Table( 144 | "container_items", 145 | generator.metadata, 146 | Column("item_id", INTEGER), 147 | Column("container_id", INTEGER), 148 | ForeignKeyConstraint(["item_id"], ["simple_items.id"]), 149 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), 150 | ) 151 | 152 | validate_code( 153 | generator.generate(), 154 | """\ 155 | from sqlalchemy import Column, ForeignKey, Integer, Table 156 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ 157 | mapped_column, relationship 158 | 159 | class
Base(MappedAsDataclass, DeclarativeBase): 160 | pass 161 | 162 | 163 | class SimpleContainers(Base): 164 | __tablename__ = 'simple_containers' 165 | 166 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 167 | 168 | item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 169 | secondary='container_items', back_populates='container') 170 | 171 | 172 | class SimpleItems(Base): 173 | __tablename__ = 'simple_items' 174 | 175 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 176 | 177 | container: Mapped[list['SimpleContainers']] = \ 178 | relationship('SimpleContainers', secondary='container_items', back_populates='item') 179 | 180 | 181 | t_container_items = Table( 182 | 'container_items', Base.metadata, 183 | Column('item_id', ForeignKey('simple_items.id')), 184 | Column('container_id', ForeignKey('simple_containers.id')) 185 | ) 186 | """, 187 | ) 188 | 189 | 190 | def test_named_foreign_key_constraints(generator: CodeGenerator) -> None: 191 | Table( 192 | "simple_items", 193 | generator.metadata, 194 | Column("id", INTEGER, primary_key=True), 195 | Column("container_id", INTEGER), 196 | ForeignKeyConstraint( 197 | ["container_id"], ["simple_containers.id"], name="foreignkeytest" 198 | ), 199 | ) 200 | Table( 201 | "simple_containers", 202 | generator.metadata, 203 | Column("id", INTEGER, primary_key=True), 204 | ) 205 | 206 | validate_code( 207 | generator.generate(), 208 | """\ 209 | from typing import Optional 210 | 211 | from sqlalchemy import ForeignKeyConstraint, Integer 212 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ 213 | mapped_column, relationship 214 | 215 | class Base(MappedAsDataclass, DeclarativeBase): 216 | pass 217 | 218 | 219 | class SimpleContainers(Base): 220 | __tablename__ = 'simple_containers' 221 | 222 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 223 | 224 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 225 | back_populates='container')
226 | 227 | 228 | class SimpleItems(Base): 229 | __tablename__ = 'simple_items' 230 | __table_args__ = ( 231 | ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ 232 | name='foreignkeytest'), 233 | ) 234 | 235 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 236 | container_id: Mapped[Optional[int]] = mapped_column(Integer) 237 | 238 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 239 | back_populates='simple_items') 240 | """, 241 | ) 242 | 243 | 244 | def test_uuid_type_annotation(generator: CodeGenerator) -> None: 245 | Table( 246 | "simple", 247 | generator.metadata, 248 | Column("id", UUID, primary_key=True), 249 | ) 250 | 251 | validate_code( 252 | generator.generate(), 253 | """\ 254 | from sqlalchemy import UUID 255 | from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ 256 | mapped_column 257 | import uuid 258 | 259 | class Base(MappedAsDataclass, DeclarativeBase): 260 | pass 261 | 262 | 263 | class Simple(Base): 264 | __tablename__ = 'simple' 265 | 266 | id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True) 267 | """, 268 | ) 269 | -------------------------------------------------------------------------------- /tests/test_generator_declarative.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | from _pytest.fixtures import FixtureRequest 5 | from sqlalchemy import PrimaryKeyConstraint 6 | from sqlalchemy.engine import Engine 7 | from sqlalchemy.schema import ( 8 | CheckConstraint, 9 | Column, 10 | ForeignKey, 11 | ForeignKeyConstraint, 12 | Index, 13 | MetaData, 14 | Table, 15 | UniqueConstraint, 16 | ) 17 | from sqlalchemy.sql.expression import text 18 | from sqlalchemy.types import ARRAY, INTEGER, VARCHAR, Text 19 | 20 | from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator 21 | 22 | from .conftest import validate_code 23 | 24 | 25 |
@pytest.fixture 26 | def generator( 27 | request: FixtureRequest, metadata: MetaData, engine: Engine 28 | ) -> CodeGenerator: 29 | options = getattr(request, "param", []) 30 | return DeclarativeGenerator(metadata, engine, options) 31 | 32 | 33 | def test_indexes(generator: CodeGenerator) -> None: 34 | simple_items = Table( 35 | "simple_items", 36 | generator.metadata, 37 | Column("id", INTEGER, primary_key=True), 38 | Column("number", INTEGER), 39 | Column("text", VARCHAR), 40 | ) 41 | simple_items.indexes.add(Index("idx_number", simple_items.c.number)) 42 | simple_items.indexes.add( 43 | Index("idx_text_number", simple_items.c.text, simple_items.c.number) 44 | ) 45 | simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) 46 | 47 | validate_code( 48 | generator.generate(), 49 | """\ 50 | from typing import Optional 51 | 52 | from sqlalchemy import Index, Integer, String 53 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 54 | 55 | class Base(DeclarativeBase): 56 | pass 57 | 58 | 59 | class SimpleItems(Base): 60 | __tablename__ = 'simple_items' 61 | __table_args__ = ( 62 | Index('idx_number', 'number'), 63 | Index('idx_text', 'text', unique=True), 64 | Index('idx_text_number', 'text', 'number') 65 | ) 66 | 67 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 68 | number: Mapped[Optional[int]] = mapped_column(Integer) 69 | text: Mapped[Optional[str]] = mapped_column(String) 70 | """, 71 | ) 72 | 73 | 74 | def test_constraints(generator: CodeGenerator) -> None: 75 | Table( 76 | "simple_items", 77 | generator.metadata, 78 | Column("id", INTEGER, primary_key=True), 79 | Column("number", INTEGER), 80 | CheckConstraint("number > 0"), 81 | UniqueConstraint("id", "number"), 82 | ) 83 | 84 | validate_code( 85 | generator.generate(), 86 | """\ 87 | from typing import Optional 88 | 89 | from sqlalchemy import CheckConstraint, Integer, UniqueConstraint 90 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 91 
| 92 | class Base(DeclarativeBase): 93 | pass 94 | 95 | 96 | class SimpleItems(Base): 97 | __tablename__ = 'simple_items' 98 | __table_args__ = ( 99 | CheckConstraint('number > 0'), 100 | UniqueConstraint('id', 'number') 101 | ) 102 | 103 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 104 | number: Mapped[Optional[int]] = mapped_column(Integer) 105 | """, 106 | ) 107 | 108 | 109 | def test_onetomany(generator: CodeGenerator) -> None: 110 | Table( 111 | "simple_items", 112 | generator.metadata, 113 | Column("id", INTEGER, primary_key=True), 114 | Column("container_id", INTEGER), 115 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), 116 | ) 117 | Table( 118 | "simple_containers", 119 | generator.metadata, 120 | Column("id", INTEGER, primary_key=True), 121 | ) 122 | 123 | validate_code( 124 | generator.generate(), 125 | """\ 126 | from typing import Optional 127 | 128 | from sqlalchemy import ForeignKey, Integer 129 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 130 | 131 | class Base(DeclarativeBase): 132 | pass 133 | 134 | 135 | class SimpleContainers(Base): 136 | __tablename__ = 'simple_containers' 137 | 138 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 139 | 140 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 141 | back_populates='container') 142 | 143 | 144 | class SimpleItems(Base): 145 | __tablename__ = 'simple_items' 146 | 147 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 148 | container_id: Mapped[Optional[int]] = \ 149 | mapped_column(ForeignKey('simple_containers.id')) 150 | 151 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 152 | back_populates='simple_items') 153 | """, 154 | ) 155 | 156 | 157 | def test_onetomany_selfref(generator: CodeGenerator) -> None: 158 | Table( 159 | "simple_items", 160 | generator.metadata, 161 | Column("id", INTEGER, primary_key=True), 162 | Column("parent_item_id", 
INTEGER), 163 | ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), 164 | ) 165 | 166 | validate_code( 167 | generator.generate(), 168 | """\ 169 | from typing import Optional 170 | 171 | from sqlalchemy import ForeignKey, Integer 172 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 173 | 174 | class Base(DeclarativeBase): 175 | pass 176 | 177 | 178 | class SimpleItems(Base): 179 | __tablename__ = 'simple_items' 180 | 181 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 182 | parent_item_id: Mapped[Optional[int]] = \ 183 | mapped_column(ForeignKey('simple_items.id')) 184 | 185 | parent_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', \ 186 | remote_side=[id], back_populates='parent_item_reverse') 187 | parent_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 188 | remote_side=[parent_item_id], back_populates='parent_item') 189 | """, 190 | ) 191 | 192 | 193 | def test_onetomany_selfref_multi(generator: CodeGenerator) -> None: 194 | Table( 195 | "simple_items", 196 | generator.metadata, 197 | Column("id", INTEGER, primary_key=True), 198 | Column("parent_item_id", INTEGER), 199 | Column("top_item_id", INTEGER), 200 | ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), 201 | ForeignKeyConstraint(["top_item_id"], ["simple_items.id"]), 202 | ) 203 | 204 | validate_code( 205 | generator.generate(), 206 | """\ 207 | from typing import Optional 208 | 209 | from sqlalchemy import ForeignKey, Integer 210 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 211 | 212 | class Base(DeclarativeBase): 213 | pass 214 | 215 | 216 | class SimpleItems(Base): 217 | __tablename__ = 'simple_items' 218 | 219 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 220 | parent_item_id: Mapped[Optional[int]] = \ 221 | mapped_column(ForeignKey('simple_items.id')) 222 | top_item_id: Mapped[Optional[int]] = mapped_column(ForeignKey('simple_items.id')) 
223 | 224 | parent_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', \ 225 | remote_side=[id], foreign_keys=[parent_item_id], back_populates='parent_item_reverse') 226 | parent_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 227 | remote_side=[parent_item_id], foreign_keys=[parent_item_id], \ 228 | back_populates='parent_item') 229 | top_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', remote_side=[id], \ 230 | foreign_keys=[top_item_id], back_populates='top_item_reverse') 231 | top_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 232 | remote_side=[top_item_id], foreign_keys=[top_item_id], back_populates='top_item') 233 | """, 234 | ) 235 | 236 | 237 | def test_onetomany_composite(generator: CodeGenerator) -> None: 238 | Table( 239 | "simple_items", 240 | generator.metadata, 241 | Column("id", INTEGER, primary_key=True), 242 | Column("container_id1", INTEGER), 243 | Column("container_id2", INTEGER), 244 | ForeignKeyConstraint( 245 | ["container_id1", "container_id2"], 246 | ["simple_containers.id1", "simple_containers.id2"], 247 | ondelete="CASCADE", 248 | onupdate="CASCADE", 249 | ), 250 | ) 251 | Table( 252 | "simple_containers", 253 | generator.metadata, 254 | Column("id1", INTEGER, primary_key=True), 255 | Column("id2", INTEGER, primary_key=True), 256 | ) 257 | 258 | validate_code( 259 | generator.generate(), 260 | """\ 261 | from typing import Optional 262 | 263 | from sqlalchemy import ForeignKeyConstraint, Integer 264 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 265 | 266 | class Base(DeclarativeBase): 267 | pass 268 | 269 | 270 | class SimpleContainers(Base): 271 | __tablename__ = 'simple_containers' 272 | 273 | id1: Mapped[int] = mapped_column(Integer, primary_key=True) 274 | id2: Mapped[int] = mapped_column(Integer, primary_key=True) 275 | 276 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 277 | 
back_populates='simple_containers') 278 | 279 | 280 | class SimpleItems(Base): 281 | __tablename__ = 'simple_items' 282 | __table_args__ = ( 283 | ForeignKeyConstraint(['container_id1', 'container_id2'], \ 284 | ['simple_containers.id1', 'simple_containers.id2'], ondelete='CASCADE', \ 285 | onupdate='CASCADE'), 286 | ) 287 | 288 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 289 | container_id1: Mapped[Optional[int]] = mapped_column(Integer) 290 | container_id2: Mapped[Optional[int]] = mapped_column(Integer) 291 | 292 | simple_containers: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 293 | back_populates='simple_items') 294 | """, 295 | ) 296 | 297 | 298 | def test_onetomany_multiref(generator: CodeGenerator) -> None: 299 | Table( 300 | "simple_items", 301 | generator.metadata, 302 | Column("id", INTEGER, primary_key=True), 303 | Column("parent_container_id", INTEGER), 304 | Column("top_container_id", INTEGER, nullable=False), 305 | ForeignKeyConstraint(["parent_container_id"], ["simple_containers.id"]), 306 | ForeignKeyConstraint(["top_container_id"], ["simple_containers.id"]), 307 | ) 308 | Table( 309 | "simple_containers", 310 | generator.metadata, 311 | Column("id", INTEGER, primary_key=True), 312 | ) 313 | 314 | validate_code( 315 | generator.generate(), 316 | """\ 317 | from typing import Optional 318 | 319 | from sqlalchemy import ForeignKey, Integer 320 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 321 | 322 | class Base(DeclarativeBase): 323 | pass 324 | 325 | 326 | class SimpleContainers(Base): 327 | __tablename__ = 'simple_containers' 328 | 329 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 330 | 331 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 332 | foreign_keys='[SimpleItems.parent_container_id]', back_populates='parent_container') 333 | simple_items_: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 334 | 
foreign_keys='[SimpleItems.top_container_id]', back_populates='top_container') 335 | 336 | 337 | class SimpleItems(Base): 338 | __tablename__ = 'simple_items' 339 | 340 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 341 | top_container_id: Mapped[int] = \ 342 | mapped_column(ForeignKey('simple_containers.id')) 343 | parent_container_id: Mapped[Optional[int]] = \ 344 | mapped_column(ForeignKey('simple_containers.id')) 345 | 346 | parent_container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 347 | foreign_keys=[parent_container_id], back_populates='simple_items') 348 | top_container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ 349 | foreign_keys=[top_container_id], back_populates='simple_items_') 350 | """, 351 | ) 352 | 353 | 354 | def test_onetoone(generator: CodeGenerator) -> None: 355 | Table( 356 | "simple_items", 357 | generator.metadata, 358 | Column("id", INTEGER, primary_key=True), 359 | Column("other_item_id", INTEGER), 360 | ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), 361 | UniqueConstraint("other_item_id"), 362 | ) 363 | Table( 364 | "other_items", 365 | generator.metadata, 366 | Column("id", INTEGER, primary_key=True), 367 | ) 368 | 369 | validate_code( 370 | generator.generate(), 371 | """\ 372 | from typing import Optional 373 | 374 | from sqlalchemy import ForeignKey, Integer 375 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 376 | 377 | class Base(DeclarativeBase): 378 | pass 379 | 380 | 381 | class OtherItems(Base): 382 | __tablename__ = 'other_items' 383 | 384 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 385 | 386 | simple_items: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', uselist=False, \ 387 | back_populates='other_item') 388 | 389 | 390 | class SimpleItems(Base): 391 | __tablename__ = 'simple_items' 392 | 393 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 394 | other_item_id: 
Mapped[Optional[int]] = \ 395 | mapped_column(ForeignKey('other_items.id'), unique=True) 396 | 397 | other_item: Mapped[Optional['OtherItems']] = relationship('OtherItems', \ 398 | back_populates='simple_items') 399 | """, 400 | ) 401 | 402 | 403 | def test_onetomany_noinflect(generator: CodeGenerator) -> None: 404 | Table( 405 | "oglkrogk", 406 | generator.metadata, 407 | Column("id", INTEGER, primary_key=True), 408 | Column("fehwiuhfiwID", INTEGER), 409 | ForeignKeyConstraint(["fehwiuhfiwID"], ["fehwiuhfiw.id"]), 410 | ) 411 | Table("fehwiuhfiw", generator.metadata, Column("id", INTEGER, primary_key=True)) 412 | 413 | validate_code( 414 | generator.generate(), 415 | """\ 416 | from typing import Optional 417 | 418 | from sqlalchemy import ForeignKey, Integer 419 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 420 | 421 | class Base(DeclarativeBase): 422 | pass 423 | 424 | 425 | class Fehwiuhfiw(Base): 426 | __tablename__ = 'fehwiuhfiw' 427 | 428 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 429 | 430 | oglkrogk: Mapped[list['Oglkrogk']] = relationship('Oglkrogk', \ 431 | back_populates='fehwiuhfiw') 432 | 433 | 434 | class Oglkrogk(Base): 435 | __tablename__ = 'oglkrogk' 436 | 437 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 438 | fehwiuhfiwID: Mapped[Optional[int]] = mapped_column(ForeignKey('fehwiuhfiw.id')) 439 | 440 | fehwiuhfiw: Mapped[Optional['Fehwiuhfiw']] = \ 441 | relationship('Fehwiuhfiw', back_populates='oglkrogk') 442 | """, 443 | ) 444 | 445 | 446 | def test_onetomany_conflicting_column(generator: CodeGenerator) -> None: 447 | Table( 448 | "simple_items", 449 | generator.metadata, 450 | Column("id", INTEGER, primary_key=True), 451 | Column("container_id", INTEGER), 452 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), 453 | ) 454 | Table( 455 | "simple_containers", 456 | generator.metadata, 457 | Column("id", INTEGER, primary_key=True), 458 | Column("relationship", 
Text), 459 | ) 460 | 461 | validate_code( 462 | generator.generate(), 463 | """\ 464 | from typing import Optional 465 | 466 | from sqlalchemy import ForeignKey, Integer, Text 467 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 468 | 469 | class Base(DeclarativeBase): 470 | pass 471 | 472 | 473 | class SimpleContainers(Base): 474 | __tablename__ = 'simple_containers' 475 | 476 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 477 | relationship_: Mapped[Optional[str]] = mapped_column('relationship', Text) 478 | 479 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 480 | back_populates='container') 481 | 482 | 483 | class SimpleItems(Base): 484 | __tablename__ = 'simple_items' 485 | 486 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 487 | container_id: Mapped[Optional[int]] = \ 488 | mapped_column(ForeignKey('simple_containers.id')) 489 | 490 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 491 | back_populates='simple_items') 492 | """, 493 | ) 494 | 495 | 496 | def test_onetomany_conflicting_relationship(generator: CodeGenerator) -> None: 497 | Table( 498 | "simple_items", 499 | generator.metadata, 500 | Column("id", INTEGER, primary_key=True), 501 | Column("relationship_id", INTEGER), 502 | ForeignKeyConstraint(["relationship_id"], ["relationship.id"]), 503 | ) 504 | Table("relationship", generator.metadata, Column("id", INTEGER, primary_key=True)) 505 | 506 | validate_code( 507 | generator.generate(), 508 | """\ 509 | from typing import Optional 510 | 511 | from sqlalchemy import ForeignKey, Integer 512 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 513 | 514 | class Base(DeclarativeBase): 515 | pass 516 | 517 | 518 | class Relationship(Base): 519 | __tablename__ = 'relationship' 520 | 521 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 522 | 523 | simple_items: Mapped[list['SimpleItems']] = 
relationship('SimpleItems', \ 524 | back_populates='relationship_') 525 | 526 | 527 | class SimpleItems(Base): 528 | __tablename__ = 'simple_items' 529 | 530 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 531 | relationship_id: Mapped[Optional[int]] = \ 532 | mapped_column(ForeignKey('relationship.id')) 533 | 534 | relationship_: Mapped[Optional['Relationship']] = relationship('Relationship', \ 535 | back_populates='simple_items') 536 | """, 537 | ) 538 | 539 | 540 | @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) 541 | def test_manytoone_nobidi(generator: CodeGenerator) -> None: 542 | Table( 543 | "simple_items", 544 | generator.metadata, 545 | Column("id", INTEGER, primary_key=True), 546 | Column("container_id", INTEGER), 547 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), 548 | ) 549 | Table( 550 | "simple_containers", 551 | generator.metadata, 552 | Column("id", INTEGER, primary_key=True), 553 | ) 554 | 555 | validate_code( 556 | generator.generate(), 557 | """\ 558 | from typing import Optional 559 | 560 | from sqlalchemy import ForeignKey, Integer 561 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 562 | 563 | class Base(DeclarativeBase): 564 | pass 565 | 566 | 567 | class SimpleContainers(Base): 568 | __tablename__ = 'simple_containers' 569 | 570 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 571 | 572 | 573 | class SimpleItems(Base): 574 | __tablename__ = 'simple_items' 575 | 576 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 577 | container_id: Mapped[Optional[int]] = \ 578 | mapped_column(ForeignKey('simple_containers.id')) 579 | 580 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers') 581 | """, 582 | ) 583 | 584 | 585 | def test_manytomany(generator: CodeGenerator) -> None: 586 | Table("left_table", generator.metadata, Column("id", INTEGER, primary_key=True)) 587 | Table( 588 | "right_table", 589 | 
generator.metadata, 590 | Column("id", INTEGER, primary_key=True), 591 | ) 592 | Table( 593 | "association_table", 594 | generator.metadata, 595 | Column("left_id", INTEGER), 596 | Column("right_id", INTEGER), 597 | ForeignKeyConstraint(["left_id"], ["left_table.id"]), 598 | ForeignKeyConstraint(["right_id"], ["right_table.id"]), 599 | ) 600 | 601 | validate_code( 602 | generator.generate(), 603 | """\ 604 | from sqlalchemy import Column, ForeignKey, Integer, Table 605 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 606 | 607 | class Base(DeclarativeBase): 608 | pass 609 | 610 | 611 | class LeftTable(Base): 612 | __tablename__ = 'left_table' 613 | 614 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 615 | 616 | right: Mapped[list['RightTable']] = relationship('RightTable', \ 617 | secondary='association_table', back_populates='left') 618 | 619 | 620 | class RightTable(Base): 621 | __tablename__ = 'right_table' 622 | 623 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 624 | 625 | left: Mapped[list['LeftTable']] = relationship('LeftTable', \ 626 | secondary='association_table', back_populates='right') 627 | 628 | 629 | t_association_table = Table( 630 | 'association_table', Base.metadata, 631 | Column('left_id', ForeignKey('left_table.id')), 632 | Column('right_id', ForeignKey('right_table.id')) 633 | ) 634 | """, 635 | ) 636 | 637 | 638 | @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) 639 | def test_manytomany_nobidi(generator: CodeGenerator) -> None: 640 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True)) 641 | Table( 642 | "simple_containers", 643 | generator.metadata, 644 | Column("id", INTEGER, primary_key=True), 645 | ) 646 | Table( 647 | "container_items", 648 | generator.metadata, 649 | Column("item_id", INTEGER), 650 | Column("container_id", INTEGER), 651 | ForeignKeyConstraint(["item_id"], ["simple_items.id"]), 652 | 
ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), 653 | ) 654 | 655 | validate_code( 656 | generator.generate(), 657 | """\ 658 | from sqlalchemy import Column, ForeignKey, Integer, Table 659 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 660 | 661 | class Base(DeclarativeBase): 662 | pass 663 | 664 | 665 | class SimpleContainers(Base): 666 | __tablename__ = 'simple_containers' 667 | 668 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 669 | 670 | item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 671 | secondary='container_items') 672 | 673 | 674 | class SimpleItems(Base): 675 | __tablename__ = 'simple_items' 676 | 677 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 678 | 679 | 680 | t_container_items = Table( 681 | 'container_items', Base.metadata, 682 | Column('item_id', ForeignKey('simple_items.id')), 683 | Column('container_id', ForeignKey('simple_containers.id')) 684 | ) 685 | """, 686 | ) 687 | 688 | 689 | def test_manytomany_selfref(generator: CodeGenerator) -> None: 690 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True)) 691 | Table( 692 | "child_items", 693 | generator.metadata, 694 | Column("parent_id", INTEGER), 695 | Column("child_id", INTEGER), 696 | ForeignKeyConstraint(["parent_id"], ["simple_items.id"]), 697 | ForeignKeyConstraint(["child_id"], ["simple_items.id"]), 698 | schema="otherschema", 699 | ) 700 | 701 | validate_code( 702 | generator.generate(), 703 | """\ 704 | from sqlalchemy import Column, ForeignKey, Integer, Table 705 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 706 | 707 | class Base(DeclarativeBase): 708 | pass 709 | 710 | 711 | class SimpleItems(Base): 712 | __tablename__ = 'simple_items' 713 | 714 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 715 | 716 | parent: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 717 | 
secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id \ 718 | == t_child_items.c.child_id, \ 719 | secondaryjoin=lambda: SimpleItems.id == \ 720 | t_child_items.c.parent_id, back_populates='child') 721 | child: Mapped[list['SimpleItems']] = \ 722 | relationship('SimpleItems', secondary='otherschema.child_items', \ 723 | primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ 724 | secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ 725 | back_populates='parent') 726 | 727 | 728 | t_child_items = Table( 729 | 'child_items', Base.metadata, 730 | Column('parent_id', ForeignKey('simple_items.id')), 731 | Column('child_id', ForeignKey('simple_items.id')), 732 | schema='otherschema' 733 | ) 734 | """, 735 | ) 736 | 737 | 738 | def test_manytomany_composite(generator: CodeGenerator) -> None: 739 | Table( 740 | "simple_items", 741 | generator.metadata, 742 | Column("id1", INTEGER, primary_key=True), 743 | Column("id2", INTEGER, primary_key=True), 744 | ) 745 | Table( 746 | "simple_containers", 747 | generator.metadata, 748 | Column("id1", INTEGER, primary_key=True), 749 | Column("id2", INTEGER, primary_key=True), 750 | ) 751 | Table( 752 | "container_items", 753 | generator.metadata, 754 | Column("item_id1", INTEGER), 755 | Column("item_id2", INTEGER), 756 | Column("container_id1", INTEGER), 757 | Column("container_id2", INTEGER), 758 | ForeignKeyConstraint( 759 | ["item_id1", "item_id2"], ["simple_items.id1", "simple_items.id2"] 760 | ), 761 | ForeignKeyConstraint( 762 | ["container_id1", "container_id2"], 763 | ["simple_containers.id1", "simple_containers.id2"], 764 | ), 765 | ) 766 | 767 | validate_code( 768 | generator.generate(), 769 | """\ 770 | from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table 771 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 772 | 773 | class Base(DeclarativeBase): 774 | pass 775 | 776 | 777 | class SimpleContainers(Base): 778 | __tablename__ = 
'simple_containers' 779 | 780 | id1: Mapped[int] = mapped_column(Integer, primary_key=True) 781 | id2: Mapped[int] = mapped_column(Integer, primary_key=True) 782 | 783 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 784 | secondary='container_items', back_populates='simple_containers') 785 | 786 | 787 | class SimpleItems(Base): 788 | __tablename__ = 'simple_items' 789 | 790 | id1: Mapped[int] = mapped_column(Integer, primary_key=True) 791 | id2: Mapped[int] = mapped_column(Integer, primary_key=True) 792 | 793 | simple_containers: Mapped[list['SimpleContainers']] = \ 794 | relationship('SimpleContainers', secondary='container_items', \ 795 | back_populates='simple_items') 796 | 797 | 798 | t_container_items = Table( 799 | 'container_items', Base.metadata, 800 | Column('item_id1', Integer), 801 | Column('item_id2', Integer), 802 | Column('container_id1', Integer), 803 | Column('container_id2', Integer), 804 | ForeignKeyConstraint(['container_id1', 'container_id2'], \ 805 | ['simple_containers.id1', 'simple_containers.id2']), 806 | ForeignKeyConstraint(['item_id1', 'item_id2'], \ 807 | ['simple_items.id1', 'simple_items.id2']) 808 | ) 809 | """, 810 | ) 811 | 812 | 813 | def test_joined_inheritance(generator: CodeGenerator) -> None: 814 | Table( 815 | "simple_sub_items", 816 | generator.metadata, 817 | Column("simple_items_id", INTEGER, primary_key=True), 818 | Column("data3", INTEGER), 819 | ForeignKeyConstraint(["simple_items_id"], ["simple_items.super_item_id"]), 820 | ) 821 | Table( 822 | "simple_super_items", 823 | generator.metadata, 824 | Column("id", INTEGER, primary_key=True), 825 | Column("data1", INTEGER), 826 | ) 827 | Table( 828 | "simple_items", 829 | generator.metadata, 830 | Column("super_item_id", INTEGER, primary_key=True), 831 | Column("data2", INTEGER), 832 | ForeignKeyConstraint(["super_item_id"], ["simple_super_items.id"]), 833 | ) 834 | 835 | validate_code( 836 | generator.generate(), 837 | """\ 838 | from typing 
import Optional 839 | 840 | from sqlalchemy import ForeignKey, Integer 841 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 842 | 843 | class Base(DeclarativeBase): 844 | pass 845 | 846 | 847 | class SimpleSuperItems(Base): 848 | __tablename__ = 'simple_super_items' 849 | 850 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 851 | data1: Mapped[Optional[int]] = mapped_column(Integer) 852 | 853 | 854 | class SimpleItems(SimpleSuperItems): 855 | __tablename__ = 'simple_items' 856 | 857 | super_item_id: Mapped[int] = mapped_column(ForeignKey('simple_super_items.id'), \ 858 | primary_key=True) 859 | data2: Mapped[Optional[int]] = mapped_column(Integer) 860 | 861 | 862 | class SimpleSubItems(SimpleItems): 863 | __tablename__ = 'simple_sub_items' 864 | 865 | simple_items_id: Mapped[int] = \ 866 | mapped_column(ForeignKey('simple_items.super_item_id'), primary_key=True) 867 | data3: Mapped[Optional[int]] = mapped_column(Integer) 868 | """, 869 | ) 870 | 871 | 872 | def test_joined_inheritance_same_table_name(generator: CodeGenerator) -> None: 873 | Table( 874 | "simple", 875 | generator.metadata, 876 | Column("id", INTEGER, primary_key=True), 877 | ) 878 | Table( 879 | "simple", 880 | generator.metadata, 881 | Column("id", INTEGER, ForeignKey("simple.id"), primary_key=True), 882 | schema="altschema", 883 | ) 884 | 885 | validate_code( 886 | generator.generate(), 887 | """\ 888 | from sqlalchemy import ForeignKey, Integer 889 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 890 | 891 | class Base(DeclarativeBase): 892 | pass 893 | 894 | 895 | class Simple(Base): 896 | __tablename__ = 'simple' 897 | 898 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 899 | 900 | 901 | class Simple_(Simple): 902 | __tablename__ = 'simple' 903 | __table_args__ = {'schema': 'altschema'} 904 | 905 | id: Mapped[int] = mapped_column(ForeignKey('simple.id'), primary_key=True) 906 | """, 907 | ) 908 | 909 | 910 | 
@pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) 911 | def test_use_inflect(generator: CodeGenerator) -> None: 912 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True)) 913 | 914 | Table("singular", generator.metadata, Column("id", INTEGER, primary_key=True)) 915 | 916 | validate_code( 917 | generator.generate(), 918 | """\ 919 | from sqlalchemy import Integer 920 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 921 | 922 | class Base(DeclarativeBase): 923 | pass 924 | 925 | 926 | class SimpleItem(Base): 927 | __tablename__ = 'simple_items' 928 | 929 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 930 | 931 | 932 | class Singular(Base): 933 | __tablename__ = 'singular' 934 | 935 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 936 | """, 937 | ) 938 | 939 | 940 | @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) 941 | @pytest.mark.parametrize( 942 | argnames=("table_name", "class_name", "relationship_name"), 943 | argvalues=[ 944 | ("manufacturers", "manufacturer", "manufacturer"), 945 | ("statuses", "status", "status"), 946 | ("studies", "study", "study"), 947 | ("moose", "moose", "moose"), 948 | ], 949 | ids=[ 950 | "test_inflect_manufacturer", 951 | "test_inflect_status", 952 | "test_inflect_study", 953 | "test_inflect_moose", 954 | ], 955 | ) 956 | def test_use_inflect_plural( 957 | generator: CodeGenerator, 958 | table_name: str, 959 | class_name: str, 960 | relationship_name: str, 961 | ) -> None: 962 | Table( 963 | "simple_items", 964 | generator.metadata, 965 | Column("id", INTEGER, primary_key=True), 966 | Column(f"{relationship_name}_id", INTEGER), 967 | ForeignKeyConstraint([f"{relationship_name}_id"], [f"{table_name}.id"]), 968 | UniqueConstraint(f"{relationship_name}_id"), 969 | ) 970 | Table(table_name, generator.metadata, Column("id", INTEGER, primary_key=True)) 971 | 972 | validate_code( 973 | generator.generate(), 974 | f"""\ 975 
| from typing import Optional 976 | 977 | from sqlalchemy import ForeignKey, Integer 978 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 979 | 980 | class Base(DeclarativeBase): 981 | pass 982 | 983 | 984 | class {class_name.capitalize()}(Base): 985 | __tablename__ = '{table_name}' 986 | 987 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 988 | 989 | simple_item: Mapped[Optional['SimpleItem']] = relationship('SimpleItem', uselist=False, \ 990 | back_populates='{relationship_name}') 991 | 992 | 993 | class SimpleItem(Base): 994 | __tablename__ = 'simple_items' 995 | 996 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 997 | {relationship_name}_id: Mapped[Optional[int]] = \ 998 | mapped_column(ForeignKey('{table_name}.id'), unique=True) 999 | 1000 | {relationship_name}: Mapped[Optional['{class_name.capitalize()}']] = \ 1001 | relationship('{class_name.capitalize()}', back_populates='simple_item') 1002 | """, 1003 | ) 1004 | 1005 | 1006 | @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) 1007 | def test_use_inflect_plural_double_pluralize(generator: CodeGenerator) -> None: 1008 | Table( 1009 | "users", 1010 | generator.metadata, 1011 | Column("users_id", INTEGER), 1012 | Column("groups_id", INTEGER), 1013 | ForeignKeyConstraint( 1014 | ["groups_id"], ["groups.groups_id"], name="fk_users_groups_id" 1015 | ), 1016 | PrimaryKeyConstraint("users_id", name="users_pkey"), 1017 | ) 1018 | 1019 | Table( 1020 | "groups", 1021 | generator.metadata, 1022 | Column("groups_id", INTEGER), 1023 | Column("group_name", Text(50), nullable=False), 1024 | PrimaryKeyConstraint("groups_id", name="groups_pkey"), 1025 | ) 1026 | 1027 | validate_code( 1028 | generator.generate(), 1029 | """\ 1030 | from typing import Optional 1031 | 1032 | from sqlalchemy import ForeignKeyConstraint, Integer, PrimaryKeyConstraint, Text 1033 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 1034 | 
1035 | class Base(DeclarativeBase): 1036 | pass 1037 | 1038 | 1039 | class Group(Base): 1040 | __tablename__ = 'groups' 1041 | __table_args__ = ( 1042 | PrimaryKeyConstraint('groups_id', name='groups_pkey'), 1043 | ) 1044 | 1045 | groups_id: Mapped[int] = mapped_column(Integer, primary_key=True) 1046 | group_name: Mapped[str] = mapped_column(Text(50)) 1047 | 1048 | users: Mapped[list['User']] = relationship('User', back_populates='group') 1049 | 1050 | 1051 | class User(Base): 1052 | __tablename__ = 'users' 1053 | __table_args__ = ( 1054 | ForeignKeyConstraint(['groups_id'], ['groups.groups_id'], name='fk_users_groups_id'), 1055 | PrimaryKeyConstraint('users_id', name='users_pkey') 1056 | ) 1057 | 1058 | users_id: Mapped[int] = mapped_column(Integer, primary_key=True) 1059 | groups_id: Mapped[Optional[int]] = mapped_column(Integer) 1060 | 1061 | group: Mapped[Optional['Group']] = relationship('Group', back_populates='users') 1062 | """, 1063 | ) 1064 | 1065 | 1066 | def test_table_kwargs(generator: CodeGenerator) -> None: 1067 | Table( 1068 | "simple_items", 1069 | generator.metadata, 1070 | Column("id", INTEGER, primary_key=True), 1071 | schema="testschema", 1072 | ) 1073 | 1074 | validate_code( 1075 | generator.generate(), 1076 | """\ 1077 | from sqlalchemy import Integer 1078 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1079 | 1080 | class Base(DeclarativeBase): 1081 | pass 1082 | 1083 | 1084 | class SimpleItems(Base): 1085 | __tablename__ = 'simple_items' 1086 | __table_args__ = {'schema': 'testschema'} 1087 | 1088 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1089 | """, 1090 | ) 1091 | 1092 | 1093 | def test_table_args_kwargs(generator: CodeGenerator) -> None: 1094 | simple_items = Table( 1095 | "simple_items", 1096 | generator.metadata, 1097 | Column("id", INTEGER, primary_key=True), 1098 | Column("name", VARCHAR), 1099 | schema="testschema", 1100 | ) 1101 | simple_items.indexes.add(Index("testidx", simple_items.c.id, 
simple_items.c.name)) 1102 | 1103 | validate_code( 1104 | generator.generate(), 1105 | """\ 1106 | from typing import Optional 1107 | 1108 | from sqlalchemy import Index, Integer, String 1109 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1110 | 1111 | class Base(DeclarativeBase): 1112 | pass 1113 | 1114 | 1115 | class SimpleItems(Base): 1116 | __tablename__ = 'simple_items' 1117 | __table_args__ = ( 1118 | Index('testidx', 'id', 'name'), 1119 | {'schema': 'testschema'} 1120 | ) 1121 | 1122 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1123 | name: Mapped[Optional[str]] = mapped_column(String) 1124 | """, 1125 | ) 1126 | 1127 | 1128 | def test_foreign_key_schema(generator: CodeGenerator) -> None: 1129 | Table( 1130 | "simple_items", 1131 | generator.metadata, 1132 | Column("id", INTEGER, primary_key=True), 1133 | Column("other_item_id", INTEGER), 1134 | ForeignKeyConstraint(["other_item_id"], ["otherschema.other_items.id"]), 1135 | ) 1136 | Table( 1137 | "other_items", 1138 | generator.metadata, 1139 | Column("id", INTEGER, primary_key=True), 1140 | schema="otherschema", 1141 | ) 1142 | 1143 | validate_code( 1144 | generator.generate(), 1145 | """\ 1146 | from typing import Optional 1147 | 1148 | from sqlalchemy import ForeignKey, Integer 1149 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 1150 | 1151 | class Base(DeclarativeBase): 1152 | pass 1153 | 1154 | 1155 | class OtherItems(Base): 1156 | __tablename__ = 'other_items' 1157 | __table_args__ = {'schema': 'otherschema'} 1158 | 1159 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1160 | 1161 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 1162 | back_populates='other_item') 1163 | 1164 | 1165 | class SimpleItems(Base): 1166 | __tablename__ = 'simple_items' 1167 | 1168 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1169 | other_item_id: Mapped[Optional[int]] = \ 1170 | 
mapped_column(ForeignKey('otherschema.other_items.id')) 1171 | 1172 | other_item: Mapped[Optional['OtherItems']] = relationship('OtherItems', \ 1173 | back_populates='simple_items') 1174 | """, 1175 | ) 1176 | 1177 | 1178 | def test_invalid_attribute_names(generator: CodeGenerator) -> None: 1179 | Table( 1180 | "simple-items", 1181 | generator.metadata, 1182 | Column("id-test", INTEGER, primary_key=True), 1183 | Column("4test", INTEGER), 1184 | Column("_4test", INTEGER), 1185 | Column("def", INTEGER), 1186 | ) 1187 | 1188 | validate_code( 1189 | generator.generate(), 1190 | """\ 1191 | from typing import Optional 1192 | 1193 | from sqlalchemy import Integer 1194 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1195 | 1196 | class Base(DeclarativeBase): 1197 | pass 1198 | 1199 | 1200 | class SimpleItems(Base): 1201 | __tablename__ = 'simple-items' 1202 | 1203 | id_test: Mapped[int] = mapped_column('id-test', Integer, primary_key=True) 1204 | _4test: Mapped[Optional[int]] = mapped_column('4test', Integer) 1205 | _4test_: Mapped[Optional[int]] = mapped_column('_4test', Integer) 1206 | def_: Mapped[Optional[int]] = mapped_column('def', Integer) 1207 | """, 1208 | ) 1209 | 1210 | 1211 | def test_pascal(generator: CodeGenerator) -> None: 1212 | Table( 1213 | "CustomerAPIPreference", 1214 | generator.metadata, 1215 | Column("id", INTEGER, primary_key=True), 1216 | ) 1217 | 1218 | validate_code( 1219 | generator.generate(), 1220 | """\ 1221 | from sqlalchemy import Integer 1222 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1223 | 1224 | class Base(DeclarativeBase): 1225 | pass 1226 | 1227 | 1228 | class CustomerAPIPreference(Base): 1229 | __tablename__ = 'CustomerAPIPreference' 1230 | 1231 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1232 | """, 1233 | ) 1234 | 1235 | 1236 | def test_underscore(generator: CodeGenerator) -> None: 1237 | Table( 1238 | "customer_api_preference", 1239 | generator.metadata, 1240 | 
Column("id", INTEGER, primary_key=True), 1241 | ) 1242 | 1243 | validate_code( 1244 | generator.generate(), 1245 | """\ 1246 | from sqlalchemy import Integer 1247 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1248 | 1249 | class Base(DeclarativeBase): 1250 | pass 1251 | 1252 | 1253 | class CustomerApiPreference(Base): 1254 | __tablename__ = 'customer_api_preference' 1255 | 1256 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1257 | """, 1258 | ) 1259 | 1260 | 1261 | def test_pascal_underscore(generator: CodeGenerator) -> None: 1262 | Table( 1263 | "customer_API_Preference", 1264 | generator.metadata, 1265 | Column("id", INTEGER, primary_key=True), 1266 | ) 1267 | 1268 | validate_code( 1269 | generator.generate(), 1270 | """\ 1271 | from sqlalchemy import Integer 1272 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1273 | 1274 | class Base(DeclarativeBase): 1275 | pass 1276 | 1277 | 1278 | class CustomerAPIPreference(Base): 1279 | __tablename__ = 'customer_API_Preference' 1280 | 1281 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1282 | """, 1283 | ) 1284 | 1285 | 1286 | def test_pascal_multiple_underscore(generator: CodeGenerator) -> None: 1287 | Table( 1288 | "customer_API__Preference", 1289 | generator.metadata, 1290 | Column("id", INTEGER, primary_key=True), 1291 | ) 1292 | 1293 | validate_code( 1294 | generator.generate(), 1295 | """\ 1296 | from sqlalchemy import Integer 1297 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1298 | 1299 | class Base(DeclarativeBase): 1300 | pass 1301 | 1302 | 1303 | class CustomerAPIPreference(Base): 1304 | __tablename__ = 'customer_API__Preference' 1305 | 1306 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1307 | """, 1308 | ) 1309 | 1310 | 1311 | @pytest.mark.parametrize( 1312 | "generator, nocomments", 1313 | [([], False), (["nocomments"], True)], 1314 | indirect=["generator"], 1315 | ) 1316 | def 
test_column_comment(generator: CodeGenerator, nocomments: bool) -> None: 1317 | Table( 1318 | "simple", 1319 | generator.metadata, 1320 | Column("id", INTEGER, primary_key=True, comment="this is a 'comment'"), 1321 | ) 1322 | 1323 | comment_part = "" if nocomments else ", comment=\"this is a 'comment'\"" 1324 | validate_code( 1325 | generator.generate(), 1326 | f"""\ 1327 | from sqlalchemy import Integer 1328 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1329 | 1330 | class Base(DeclarativeBase): 1331 | pass 1332 | 1333 | 1334 | class Simple(Base): 1335 | __tablename__ = 'simple' 1336 | 1337 | id: Mapped[int] = mapped_column(Integer, primary_key=True{comment_part}) 1338 | """, 1339 | ) 1340 | 1341 | 1342 | def test_table_comment(generator: CodeGenerator) -> None: 1343 | Table( 1344 | "simple", 1345 | generator.metadata, 1346 | Column("id", INTEGER, primary_key=True), 1347 | comment="this is a 'comment'", 1348 | ) 1349 | 1350 | validate_code( 1351 | generator.generate(), 1352 | """\ 1353 | from sqlalchemy import Integer 1354 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1355 | 1356 | class Base(DeclarativeBase): 1357 | pass 1358 | 1359 | 1360 | class Simple(Base): 1361 | __tablename__ = 'simple' 1362 | __table_args__ = {'comment': "this is a 'comment'"} 1363 | 1364 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1365 | """, 1366 | ) 1367 | 1368 | 1369 | def test_metadata_column(generator: CodeGenerator) -> None: 1370 | Table( 1371 | "simple", 1372 | generator.metadata, 1373 | Column("id", INTEGER, primary_key=True), 1374 | Column("metadata", VARCHAR), 1375 | ) 1376 | 1377 | validate_code( 1378 | generator.generate(), 1379 | """\ 1380 | from typing import Optional 1381 | 1382 | from sqlalchemy import Integer, String 1383 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1384 | 1385 | class Base(DeclarativeBase): 1386 | pass 1387 | 1388 | 1389 | class Simple(Base): 1390 | __tablename__ = 
'simple' 1391 | 1392 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1393 | metadata_: Mapped[Optional[str]] = mapped_column('metadata', String) 1394 | """, 1395 | ) 1396 | 1397 | 1398 | def test_invalid_variable_name_from_column(generator: CodeGenerator) -> None: 1399 | Table( 1400 | "simple", 1401 | generator.metadata, 1402 | Column(" id ", INTEGER, primary_key=True), 1403 | ) 1404 | 1405 | validate_code( 1406 | generator.generate(), 1407 | """\ 1408 | from sqlalchemy import Integer 1409 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1410 | 1411 | class Base(DeclarativeBase): 1412 | pass 1413 | 1414 | 1415 | class Simple(Base): 1416 | __tablename__ = 'simple' 1417 | 1418 | id: Mapped[int] = mapped_column(' id ', Integer, primary_key=True) 1419 | """, 1420 | ) 1421 | 1422 | 1423 | def test_only_tables(generator: CodeGenerator) -> None: 1424 | Table("simple", generator.metadata, Column("id", INTEGER)) 1425 | 1426 | validate_code( 1427 | generator.generate(), 1428 | """\ 1429 | from sqlalchemy import Column, Integer, MetaData, Table 1430 | 1431 | metadata = MetaData() 1432 | 1433 | 1434 | t_simple = Table( 1435 | 'simple', metadata, 1436 | Column('id', Integer) 1437 | ) 1438 | """, 1439 | ) 1440 | 1441 | 1442 | def test_named_constraints(generator: CodeGenerator) -> None: 1443 | Table( 1444 | "simple", 1445 | generator.metadata, 1446 | Column("id", INTEGER), 1447 | Column("text", VARCHAR), 1448 | CheckConstraint("id > 0", name="checktest"), 1449 | PrimaryKeyConstraint("id", name="primarytest"), 1450 | UniqueConstraint("text", name="uniquetest"), 1451 | ) 1452 | 1453 | validate_code( 1454 | generator.generate(), 1455 | """\ 1456 | from typing import Optional 1457 | 1458 | from sqlalchemy import CheckConstraint, Integer, PrimaryKeyConstraint, \ 1459 | String, UniqueConstraint 1460 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1461 | 1462 | class Base(DeclarativeBase): 1463 | pass 1464 | 1465 | 1466 | class 
Simple(Base): 1467 | __tablename__ = 'simple' 1468 | __table_args__ = ( 1469 | CheckConstraint('id > 0', name='checktest'), 1470 | PrimaryKeyConstraint('id', name='primarytest'), 1471 | UniqueConstraint('text', name='uniquetest') 1472 | ) 1473 | 1474 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1475 | text: Mapped[Optional[str]] = mapped_column(String) 1476 | """, 1477 | ) 1478 | 1479 | 1480 | def test_named_foreign_key_constraints(generator: CodeGenerator) -> None: 1481 | Table( 1482 | "simple_items", 1483 | generator.metadata, 1484 | Column("id", INTEGER, primary_key=True), 1485 | Column("container_id", INTEGER), 1486 | ForeignKeyConstraint( 1487 | ["container_id"], ["simple_containers.id"], name="foreignkeytest" 1488 | ), 1489 | ) 1490 | Table( 1491 | "simple_containers", 1492 | generator.metadata, 1493 | Column("id", INTEGER, primary_key=True), 1494 | ) 1495 | 1496 | validate_code( 1497 | generator.generate(), 1498 | """\ 1499 | from typing import Optional 1500 | 1501 | from sqlalchemy import ForeignKeyConstraint, Integer 1502 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship 1503 | 1504 | class Base(DeclarativeBase): 1505 | pass 1506 | 1507 | 1508 | class SimpleContainers(Base): 1509 | __tablename__ = 'simple_containers' 1510 | 1511 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1512 | 1513 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ 1514 | back_populates='container') 1515 | 1516 | 1517 | class SimpleItems(Base): 1518 | __tablename__ = 'simple_items' 1519 | __table_args__ = ( 1520 | ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ 1521 | name='foreignkeytest'), 1522 | ) 1523 | 1524 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1525 | container_id: Mapped[Optional[int]] = mapped_column(Integer) 1526 | 1527 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \ 1528 | back_populates='simple_items') 1529 | 
""", 1530 | ) 1531 | 1532 | 1533 | # @pytest.mark.xfail(strict=True) 1534 | def test_colname_import_conflict(generator: CodeGenerator) -> None: 1535 | Table( 1536 | "simple", 1537 | generator.metadata, 1538 | Column("id", INTEGER, primary_key=True), 1539 | Column("text", VARCHAR), 1540 | Column("textwithdefault", VARCHAR, server_default=text("'test'")), 1541 | ) 1542 | 1543 | validate_code( 1544 | generator.generate(), 1545 | """\ 1546 | from typing import Optional 1547 | 1548 | from sqlalchemy import Integer, String, text 1549 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1550 | 1551 | class Base(DeclarativeBase): 1552 | pass 1553 | 1554 | 1555 | class Simple(Base): 1556 | __tablename__ = 'simple' 1557 | 1558 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1559 | text_: Mapped[Optional[str]] = mapped_column('text', String) 1560 | textwithdefault: Mapped[Optional[str]] = mapped_column(String, \ 1561 | server_default=text("'test'")) 1562 | """, 1563 | ) 1564 | 1565 | 1566 | def test_table_with_arrays(generator: CodeGenerator) -> None: 1567 | Table( 1568 | "with_items", 1569 | generator.metadata, 1570 | Column("id", INTEGER, primary_key=True), 1571 | Column("int_items_not_optional", ARRAY(INTEGER()), nullable=False), 1572 | Column("str_matrix", ARRAY(VARCHAR(), dimensions=2)), 1573 | ) 1574 | 1575 | validate_code( 1576 | generator.generate(), 1577 | """\ 1578 | from typing import Optional 1579 | 1580 | from sqlalchemy import ARRAY, INTEGER, Integer, VARCHAR 1581 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 1582 | 1583 | class Base(DeclarativeBase): 1584 | pass 1585 | 1586 | 1587 | class WithItems(Base): 1588 | __tablename__ = 'with_items' 1589 | 1590 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 1591 | int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER())) 1592 | str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2)) 1593 | """, 1594 | 
) 1595 | -------------------------------------------------------------------------------- /tests/test_generator_sqlmodel.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | from _pytest.fixtures import FixtureRequest 5 | from sqlalchemy.engine import Engine 6 | from sqlalchemy.schema import ( 7 | CheckConstraint, 8 | Column, 9 | ForeignKeyConstraint, 10 | Index, 11 | MetaData, 12 | Table, 13 | UniqueConstraint, 14 | ) 15 | from sqlalchemy.types import INTEGER, VARCHAR 16 | 17 | from sqlacodegen.generators import CodeGenerator, SQLModelGenerator 18 | 19 | from .conftest import validate_code 20 | 21 | 22 | @pytest.fixture 23 | def generator( 24 | request: FixtureRequest, metadata: MetaData, engine: Engine 25 | ) -> CodeGenerator: 26 | options = getattr(request, "param", []) 27 | return SQLModelGenerator(metadata, engine, options) 28 | 29 | 30 | def test_indexes(generator: CodeGenerator) -> None: 31 | simple_items = Table( 32 | "item", 33 | generator.metadata, 34 | Column("id", INTEGER, primary_key=True), 35 | Column("number", INTEGER), 36 | Column("text", VARCHAR), 37 | ) 38 | simple_items.indexes.add(Index("idx_number", simple_items.c.number)) 39 | simple_items.indexes.add( 40 | Index("idx_text_number", simple_items.c.text, simple_items.c.number) 41 | ) 42 | simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) 43 | 44 | validate_code( 45 | generator.generate(), 46 | """\ 47 | from typing import Optional 48 | 49 | from sqlalchemy import Column, Index, Integer, String 50 | from sqlmodel import Field, SQLModel 51 | 52 | class Item(SQLModel, table=True): 53 | __table_args__ = ( 54 | Index('idx_number', 'number'), 55 | Index('idx_text', 'text', unique=True), 56 | Index('idx_text_number', 'text', 'number') 57 | ) 58 | 59 | id: Optional[int] = Field(default=None, sa_column=Column(\ 60 | 'id', Integer, primary_key=True)) 61 | number: Optional[int] = 
Field(default=None, sa_column=Column(\ 62 | 'number', Integer)) 63 | text: Optional[str] = Field(default=None, sa_column=Column(\ 64 | 'text', String)) 65 | """, 66 | ) 67 | 68 | 69 | def test_constraints(generator: CodeGenerator) -> None: 70 | Table( 71 | "simple_constraints", 72 | generator.metadata, 73 | Column("id", INTEGER, primary_key=True), 74 | Column("number", INTEGER), 75 | CheckConstraint("number > 0"), 76 | UniqueConstraint("id", "number"), 77 | ) 78 | 79 | validate_code( 80 | generator.generate(), 81 | """\ 82 | from typing import Optional 83 | 84 | from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint 85 | from sqlmodel import Field, SQLModel 86 | 87 | class SimpleConstraints(SQLModel, table=True): 88 | __tablename__ = 'simple_constraints' 89 | __table_args__ = ( 90 | CheckConstraint('number > 0'), 91 | UniqueConstraint('id', 'number') 92 | ) 93 | 94 | id: Optional[int] = Field(default=None, sa_column=Column(\ 95 | 'id', Integer, primary_key=True)) 96 | number: Optional[int] = Field(default=None, sa_column=Column(\ 97 | 'number', Integer)) 98 | """, 99 | ) 100 | 101 | 102 | def test_onetomany(generator: CodeGenerator) -> None: 103 | Table( 104 | "simple_goods", 105 | generator.metadata, 106 | Column("id", INTEGER, primary_key=True), 107 | Column("container_id", INTEGER), 108 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), 109 | ) 110 | Table( 111 | "simple_containers", 112 | generator.metadata, 113 | Column("id", INTEGER, primary_key=True), 114 | ) 115 | 116 | validate_code( 117 | generator.generate(), 118 | """\ 119 | from typing import Optional 120 | 121 | from sqlalchemy import Column, ForeignKey, Integer 122 | from sqlmodel import Field, Relationship, SQLModel 123 | 124 | class SimpleContainers(SQLModel, table=True): 125 | __tablename__ = 'simple_containers' 126 | 127 | id: Optional[int] = Field(default=None, sa_column=Column(\ 128 | 'id', Integer, primary_key=True)) 129 | 130 | simple_goods: 
list['SimpleGoods'] = Relationship(\ 131 | back_populates='container') 132 | 133 | 134 | class SimpleGoods(SQLModel, table=True): 135 | __tablename__ = 'simple_goods' 136 | 137 | id: Optional[int] = Field(default=None, sa_column=Column(\ 138 | 'id', Integer, primary_key=True)) 139 | container_id: Optional[int] = Field(default=None, sa_column=Column(\ 140 | 'container_id', ForeignKey('simple_containers.id'))) 141 | 142 | container: Optional['SimpleContainers'] = Relationship(\ 143 | back_populates='simple_goods') 144 | """, 145 | ) 146 | 147 | 148 | def test_onetoone(generator: CodeGenerator) -> None: 149 | Table( 150 | "simple_onetoone", 151 | generator.metadata, 152 | Column("id", INTEGER, primary_key=True), 153 | Column("other_item_id", INTEGER), 154 | ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), 155 | UniqueConstraint("other_item_id"), 156 | ) 157 | Table("other_items", generator.metadata, Column("id", INTEGER, primary_key=True)) 158 | 159 | validate_code( 160 | generator.generate(), 161 | """\ 162 | from typing import Optional 163 | 164 | from sqlalchemy import Column, ForeignKey, Integer 165 | from sqlmodel import Field, Relationship, SQLModel 166 | 167 | class OtherItems(SQLModel, table=True): 168 | __tablename__ = 'other_items' 169 | 170 | id: Optional[int] = Field(default=None, sa_column=Column(\ 171 | 'id', Integer, primary_key=True)) 172 | 173 | simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\ 174 | sa_relationship_kwargs={'uselist': False}, back_populates='other_item') 175 | 176 | 177 | class SimpleOnetoone(SQLModel, table=True): 178 | __tablename__ = 'simple_onetoone' 179 | 180 | id: Optional[int] = Field(default=None, sa_column=Column(\ 181 | 'id', Integer, primary_key=True)) 182 | other_item_id: Optional[int] = Field(default=None, sa_column=Column(\ 183 | 'other_item_id', ForeignKey('other_items.id'), unique=True)) 184 | 185 | other_item: Optional['OtherItems'] = Relationship(\ 186 | back_populates='simple_onetoone') 187 | 
""", 188 | ) 189 | -------------------------------------------------------------------------------- /tests/test_generator_tables.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from textwrap import dedent 4 | 5 | import pytest 6 | from _pytest.fixtures import FixtureRequest 7 | from sqlalchemy import TypeDecorator 8 | from sqlalchemy.dialects import mysql, postgresql 9 | from sqlalchemy.engine import Engine 10 | from sqlalchemy.schema import ( 11 | CheckConstraint, 12 | Column, 13 | Computed, 14 | ForeignKey, 15 | Identity, 16 | Index, 17 | MetaData, 18 | Table, 19 | UniqueConstraint, 20 | ) 21 | from sqlalchemy.sql.expression import text 22 | from sqlalchemy.sql.sqltypes import DateTime, NullType 23 | from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text 24 | 25 | from sqlacodegen.generators import CodeGenerator, TablesGenerator 26 | 27 | from .conftest import validate_code 28 | 29 | 30 | # This needs to be uppercased to trigger #315 31 | class TIMESTAMP_DECORATOR(TypeDecorator[DateTime]): 32 | impl = DateTime 33 | 34 | 35 | @pytest.fixture 36 | def generator( 37 | request: FixtureRequest, metadata: MetaData, engine: Engine 38 | ) -> CodeGenerator: 39 | options = getattr(request, "param", []) 40 | return TablesGenerator(metadata, engine, options) 41 | 42 | 43 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 44 | def test_fancy_coltypes(generator: CodeGenerator) -> None: 45 | from pgvector.sqlalchemy.vector import VECTOR 46 | 47 | Table( 48 | "simple_items", 49 | generator.metadata, 50 | Column("enum", postgresql.ENUM("A", "B", name="blah", schema="someschema")), 51 | Column("bool", postgresql.BOOLEAN), 52 | Column("vector", VECTOR(3)), 53 | Column("number", NUMERIC(10, asdecimal=False)), 54 | Column("timestamp", TIMESTAMP_DECORATOR()), 55 | schema="someschema", 56 | ) 57 | 58 | validate_code( 59 | generator.generate(), 60 | """\ 61 | from 
tests.test_generator_tables import TIMESTAMP_DECORATOR 62 | 63 | from pgvector.sqlalchemy.vector import VECTOR 64 | from sqlalchemy import Boolean, Column, Enum, MetaData, Numeric, Table 65 | 66 | metadata = MetaData() 67 | 68 | 69 | t_simple_items = Table( 70 | 'simple_items', metadata, 71 | Column('enum', Enum('A', 'B', name='blah', schema='someschema')), 72 | Column('bool', Boolean), 73 | Column('vector', VECTOR(3)), 74 | Column('number', Numeric(10, asdecimal=False)), 75 | Column('timestamp', TIMESTAMP_DECORATOR), 76 | schema='someschema' 77 | ) 78 | """, 79 | ) 80 | 81 | 82 | def test_boolean_detection(generator: CodeGenerator) -> None: 83 | Table( 84 | "simple_items", 85 | generator.metadata, 86 | Column("bool1", INTEGER), 87 | Column("bool2", SMALLINT), 88 | Column("bool3", mysql.TINYINT), 89 | CheckConstraint("simple_items.bool1 IN (0, 1)"), 90 | CheckConstraint("simple_items.bool2 IN (0, 1)"), 91 | CheckConstraint("simple_items.bool3 IN (0, 1)"), 92 | ) 93 | 94 | validate_code( 95 | generator.generate(), 96 | """\ 97 | from sqlalchemy import Boolean, Column, MetaData, Table 98 | 99 | metadata = MetaData() 100 | 101 | 102 | t_simple_items = Table( 103 | 'simple_items', metadata, 104 | Column('bool1', Boolean), 105 | Column('bool2', Boolean), 106 | Column('bool3', Boolean) 107 | ) 108 | """, 109 | ) 110 | 111 | 112 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 113 | def test_arrays(generator: CodeGenerator) -> None: 114 | Table( 115 | "simple_items", 116 | generator.metadata, 117 | Column("dp_array", postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53))), 118 | Column("int_array", postgresql.ARRAY(INTEGER)), 119 | ) 120 | 121 | validate_code( 122 | generator.generate(), 123 | """\ 124 | from sqlalchemy import ARRAY, Column, Double, Integer, MetaData, Table 125 | 126 | metadata = MetaData() 127 | 128 | 129 | t_simple_items = Table( 130 | 'simple_items', metadata, 131 | Column('dp_array', ARRAY(Double(precision=53))), 132 | 
Column('int_array', ARRAY(Integer())) 133 | ) 134 | """, 135 | ) 136 | 137 | 138 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 139 | def test_jsonb(generator: CodeGenerator) -> None: 140 | Table( 141 | "simple_items", 142 | generator.metadata, 143 | Column("jsonb", postgresql.JSONB(astext_type=Text(50))), 144 | ) 145 | 146 | validate_code( 147 | generator.generate(), 148 | """\ 149 | from sqlalchemy import Column, MetaData, Table, Text 150 | from sqlalchemy.dialects.postgresql import JSONB 151 | 152 | metadata = MetaData() 153 | 154 | 155 | t_simple_items = Table( 156 | 'simple_items', metadata, 157 | Column('jsonb', JSONB(astext_type=Text(length=50))) 158 | ) 159 | """, 160 | ) 161 | 162 | 163 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 164 | def test_jsonb_default(generator: CodeGenerator) -> None: 165 | Table("simple_items", generator.metadata, Column("jsonb", postgresql.JSONB)) 166 | 167 | validate_code( 168 | generator.generate(), 169 | """\ 170 | from sqlalchemy import Column, MetaData, Table 171 | from sqlalchemy.dialects.postgresql import JSONB 172 | 173 | metadata = MetaData() 174 | 175 | 176 | t_simple_items = Table( 177 | 'simple_items', metadata, 178 | Column('jsonb', JSONB) 179 | ) 180 | """, 181 | ) 182 | 183 | 184 | def test_enum_detection(generator: CodeGenerator) -> None: 185 | Table( 186 | "simple_items", 187 | generator.metadata, 188 | Column("enum", VARCHAR(255)), 189 | CheckConstraint(r"simple_items.enum IN ('A', '\'B', 'C')"), 190 | ) 191 | 192 | validate_code( 193 | generator.generate(), 194 | """\ 195 | from sqlalchemy import Column, Enum, MetaData, Table 196 | 197 | metadata = MetaData() 198 | 199 | 200 | t_simple_items = Table( 201 | 'simple_items', metadata, 202 | Column('enum', Enum('A', "\\\\'B", 'C')) 203 | ) 204 | """, 205 | ) 206 | 207 | 208 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 209 | def test_domain_text(generator: CodeGenerator) -> None: 
210 | Table( 211 | "simple_items", 212 | generator.metadata, 213 | Column( 214 | "postal_code", 215 | postgresql.DOMAIN( 216 | "us_postal_code", 217 | Text, 218 | constraint_name="valid_us_postal_code", 219 | not_null=False, 220 | check=text("VALUE ~ '^\\d{5}$' OR VALUE ~ '^\\d{5}-\\d{4}$'"), 221 | ), 222 | nullable=False, 223 | ), 224 | ) 225 | 226 | validate_code( 227 | generator.generate(), 228 | """\ 229 | from sqlalchemy import Column, MetaData, Table, Text, text 230 | from sqlalchemy.dialects.postgresql import DOMAIN 231 | 232 | metadata = MetaData() 233 | 234 | 235 | t_simple_items = Table( 236 | 'simple_items', metadata, 237 | Column('postal_code', DOMAIN('us_postal_code', Text(), \ 238 | constraint_name='valid_us_postal_code', not_null=False, \ 239 | check=text("VALUE ~ '^\\\\d{5}$' OR VALUE ~ '^\\\\d{5}-\\\\d{4}$'")), nullable=False) 240 | ) 241 | """, 242 | ) 243 | 244 | 245 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 246 | def test_domain_int(generator: CodeGenerator) -> None: 247 | Table( 248 | "simple_items", 249 | generator.metadata, 250 | Column( 251 | "n", 252 | postgresql.DOMAIN( 253 | "positive_int", 254 | INTEGER, 255 | constraint_name="positive", 256 | not_null=False, 257 | check=text("VALUE > 0"), 258 | ), 259 | nullable=False, 260 | ), 261 | ) 262 | 263 | validate_code( 264 | generator.generate(), 265 | """\ 266 | from sqlalchemy import Column, INTEGER, MetaData, Table, text 267 | from sqlalchemy.dialects.postgresql import DOMAIN 268 | 269 | metadata = MetaData() 270 | 271 | 272 | t_simple_items = Table( 273 | 'simple_items', metadata, 274 | Column('n', DOMAIN('positive_int', INTEGER(), \ 275 | constraint_name='positive', not_null=False, \ 276 | check=text('VALUE > 0')), nullable=False) 277 | ) 278 | """, 279 | ) 280 | 281 | 282 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 283 | def test_column_adaptation(generator: CodeGenerator) -> None: 284 | Table( 285 | "simple_items", 286 | 
generator.metadata, 287 | Column("id", postgresql.BIGINT), 288 | Column("length", postgresql.DOUBLE_PRECISION), 289 | ) 290 | 291 | validate_code( 292 | generator.generate(), 293 | """\ 294 | from sqlalchemy import BigInteger, Column, Double, MetaData, Table 295 | 296 | metadata = MetaData() 297 | 298 | 299 | t_simple_items = Table( 300 | 'simple_items', metadata, 301 | Column('id', BigInteger), 302 | Column('length', Double) 303 | ) 304 | """, 305 | ) 306 | 307 | 308 | @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) 309 | def test_mysql_column_types(generator: CodeGenerator) -> None: 310 | Table( 311 | "simple_items", 312 | generator.metadata, 313 | Column("id", mysql.INTEGER), 314 | Column("name", mysql.VARCHAR(255)), 315 | Column("double", mysql.DOUBLE(1, 2)), 316 | Column("set", mysql.SET("one", "two")), 317 | ) 318 | 319 | validate_code( 320 | generator.generate(), 321 | """\ 322 | from sqlalchemy import Column, Integer, MetaData, String, Table 323 | from sqlalchemy.dialects.mysql import DOUBLE, SET 324 | 325 | metadata = MetaData() 326 | 327 | 328 | t_simple_items = Table( 329 | 'simple_items', metadata, 330 | Column('id', Integer), 331 | Column('name', String(255)), 332 | Column('double', DOUBLE(1, 2)), 333 | Column('set', SET('one', 'two')) 334 | ) 335 | """, 336 | ) 337 | 338 | 339 | def test_constraints(generator: CodeGenerator) -> None: 340 | Table( 341 | "simple_items", 342 | generator.metadata, 343 | Column("id", INTEGER), 344 | Column("number", INTEGER), 345 | CheckConstraint("number > 0"), 346 | UniqueConstraint("id", "number"), 347 | ) 348 | 349 | validate_code( 350 | generator.generate(), 351 | """\ 352 | from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table, \ 353 | UniqueConstraint 354 | 355 | metadata = MetaData() 356 | 357 | 358 | t_simple_items = Table( 359 | 'simple_items', metadata, 360 | Column('id', Integer), 361 | Column('number', Integer), 362 | CheckConstraint('number > 0'), 363 | 
UniqueConstraint('id', 'number') 364 | ) 365 | """, 366 | ) 367 | 368 | 369 | def test_indexes(generator: CodeGenerator) -> None: 370 | simple_items = Table( 371 | "simple_items", 372 | generator.metadata, 373 | Column("id", INTEGER), 374 | Column("number", INTEGER), 375 | Column("text", VARCHAR), 376 | Index("ix_empty"), 377 | ) 378 | simple_items.indexes.add(Index("ix_number", simple_items.c.number)) 379 | simple_items.indexes.add( 380 | Index( 381 | "ix_text_number", 382 | simple_items.c.text, 383 | simple_items.c.number, 384 | unique=True, 385 | ) 386 | ) 387 | simple_items.indexes.add(Index("ix_text", simple_items.c.text, unique=True)) 388 | 389 | validate_code( 390 | generator.generate(), 391 | """\ 392 | from sqlalchemy import Column, Index, Integer, MetaData, String, Table 393 | 394 | metadata = MetaData() 395 | 396 | 397 | t_simple_items = Table( 398 | 'simple_items', metadata, 399 | Column('id', Integer), 400 | Column('number', Integer, index=True), 401 | Column('text', String, unique=True, index=True), 402 | Index('ix_empty'), 403 | Index('ix_text_number', 'text', 'number', unique=True) 404 | ) 405 | """, 406 | ) 407 | 408 | 409 | def test_table_comment(generator: CodeGenerator) -> None: 410 | Table( 411 | "simple", 412 | generator.metadata, 413 | Column("id", INTEGER, primary_key=True), 414 | comment="this is a 'comment'", 415 | ) 416 | 417 | validate_code( 418 | generator.generate(), 419 | """\ 420 | from sqlalchemy import Column, Integer, MetaData, Table 421 | 422 | metadata = MetaData() 423 | 424 | 425 | t_simple = Table( 426 | 'simple', metadata, 427 | Column('id', Integer, primary_key=True), 428 | comment="this is a 'comment'" 429 | ) 430 | """, 431 | ) 432 | 433 | 434 | def test_table_name_identifiers(generator: CodeGenerator) -> None: 435 | Table( 436 | "simple-items table", 437 | generator.metadata, 438 | Column("id", INTEGER, primary_key=True), 439 | ) 440 | 441 | validate_code( 442 | generator.generate(), 443 | """\ 444 | from sqlalchemy 
import Column, Integer, MetaData, Table 445 | 446 | metadata = MetaData() 447 | 448 | 449 | t_simple_items_table = Table( 450 | 'simple-items table', metadata, 451 | Column('id', Integer, primary_key=True) 452 | ) 453 | """, 454 | ) 455 | 456 | 457 | @pytest.mark.parametrize("generator", [["noindexes"]], indirect=True) 458 | def test_option_noindexes(generator: CodeGenerator) -> None: 459 | simple_items = Table( 460 | "simple_items", 461 | generator.metadata, 462 | Column("number", INTEGER), 463 | CheckConstraint("number > 2"), 464 | ) 465 | simple_items.indexes.add(Index("idx_number", simple_items.c.number)) 466 | 467 | validate_code( 468 | generator.generate(), 469 | """\ 470 | from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table 471 | 472 | metadata = MetaData() 473 | 474 | 475 | t_simple_items = Table( 476 | 'simple_items', metadata, 477 | Column('number', Integer), 478 | CheckConstraint('number > 2') 479 | ) 480 | """, 481 | ) 482 | 483 | 484 | @pytest.mark.parametrize("generator", [["noconstraints"]], indirect=True) 485 | def test_option_noconstraints(generator: CodeGenerator) -> None: 486 | simple_items = Table( 487 | "simple_items", 488 | generator.metadata, 489 | Column("number", INTEGER), 490 | CheckConstraint("number > 2"), 491 | ) 492 | simple_items.indexes.add(Index("ix_number", simple_items.c.number)) 493 | 494 | validate_code( 495 | generator.generate(), 496 | """\ 497 | from sqlalchemy import Column, Integer, MetaData, Table 498 | 499 | metadata = MetaData() 500 | 501 | 502 | t_simple_items = Table( 503 | 'simple_items', metadata, 504 | Column('number', Integer, index=True) 505 | ) 506 | """, 507 | ) 508 | 509 | 510 | @pytest.mark.parametrize("generator", [["nocomments"]], indirect=True) 511 | def test_option_nocomments(generator: CodeGenerator) -> None: 512 | Table( 513 | "simple", 514 | generator.metadata, 515 | Column("id", INTEGER, primary_key=True, comment="pk column comment"), 516 | comment="this is a 'comment'", 517 | ) 
518 | 519 | validate_code( 520 | generator.generate(), 521 | """\ 522 | from sqlalchemy import Column, Integer, MetaData, Table 523 | 524 | metadata = MetaData() 525 | 526 | 527 | t_simple = Table( 528 | 'simple', metadata, 529 | Column('id', Integer, primary_key=True) 530 | ) 531 | """, 532 | ) 533 | 534 | 535 | @pytest.mark.parametrize( 536 | "persisted, extra_args", 537 | [(None, ""), (False, ", persisted=False"), (True, ", persisted=True")], 538 | ) 539 | def test_computed_column( 540 | generator: CodeGenerator, persisted: bool | None, extra_args: str 541 | ) -> None: 542 | Table( 543 | "computed", 544 | generator.metadata, 545 | Column("id", INTEGER, primary_key=True), 546 | Column("computed", INTEGER, Computed("1 + 2", persisted=persisted)), 547 | ) 548 | 549 | validate_code( 550 | generator.generate(), 551 | f"""\ 552 | from sqlalchemy import Column, Computed, Integer, MetaData, Table 553 | 554 | metadata = MetaData() 555 | 556 | 557 | t_computed = Table( 558 | 'computed', metadata, 559 | Column('id', Integer, primary_key=True), 560 | Column('computed', Integer, Computed('1 + 2'{extra_args})) 561 | ) 562 | """, 563 | ) 564 | 565 | 566 | def test_schema(generator: CodeGenerator) -> None: 567 | Table( 568 | "simple_items", 569 | generator.metadata, 570 | Column("name", VARCHAR), 571 | schema="testschema", 572 | ) 573 | 574 | validate_code( 575 | generator.generate(), 576 | """\ 577 | from sqlalchemy import Column, MetaData, String, Table 578 | 579 | metadata = MetaData() 580 | 581 | 582 | t_simple_items = Table( 583 | 'simple_items', metadata, 584 | Column('name', String), 585 | schema='testschema' 586 | ) 587 | """, 588 | ) 589 | 590 | 591 | def test_foreign_key_options(generator: CodeGenerator) -> None: 592 | Table( 593 | "simple_items", 594 | generator.metadata, 595 | Column( 596 | "name", 597 | VARCHAR, 598 | ForeignKey( 599 | "simple_items.name", 600 | ondelete="CASCADE", 601 | onupdate="CASCADE", 602 | deferrable=True, 603 | initially="DEFERRED", 604 | 
), 605 | ), 606 | ) 607 | 608 | validate_code( 609 | generator.generate(), 610 | """\ 611 | from sqlalchemy import Column, ForeignKey, MetaData, String, Table 612 | 613 | metadata = MetaData() 614 | 615 | 616 | t_simple_items = Table( 617 | 'simple_items', metadata, 618 | Column('name', String, ForeignKey('simple_items.name', \ 619 | ondelete='CASCADE', onupdate='CASCADE', deferrable=True, initially='DEFERRED')) 620 | ) 621 | """, 622 | ) 623 | 624 | 625 | def test_pk_default(generator: CodeGenerator) -> None: 626 | Table( 627 | "simple_items", 628 | generator.metadata, 629 | Column( 630 | "id", 631 | INTEGER, 632 | primary_key=True, 633 | server_default=text("uuid_generate_v4()"), 634 | ), 635 | ) 636 | 637 | validate_code( 638 | generator.generate(), 639 | """\ 640 | from sqlalchemy import Column, Integer, MetaData, Table, text 641 | 642 | metadata = MetaData() 643 | 644 | 645 | t_simple_items = Table( 646 | 'simple_items', metadata, 647 | Column('id', Integer, primary_key=True, \ 648 | server_default=text('uuid_generate_v4()')) 649 | ) 650 | """, 651 | ) 652 | 653 | 654 | @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) 655 | def test_mysql_timestamp(generator: CodeGenerator) -> None: 656 | Table( 657 | "simple", 658 | generator.metadata, 659 | Column("id", INTEGER, primary_key=True), 660 | Column("timestamp", mysql.TIMESTAMP), 661 | ) 662 | 663 | validate_code( 664 | generator.generate(), 665 | """\ 666 | from sqlalchemy import Column, Integer, MetaData, TIMESTAMP, Table 667 | 668 | metadata = MetaData() 669 | 670 | 671 | t_simple = Table( 672 | 'simple', metadata, 673 | Column('id', Integer, primary_key=True), 674 | Column('timestamp', TIMESTAMP) 675 | ) 676 | """, 677 | ) 678 | 679 | 680 | @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) 681 | def test_mysql_integer_display_width(generator: CodeGenerator) -> None: 682 | Table( 683 | "simple_items", 684 | generator.metadata, 685 | Column("id", INTEGER, primary_key=True), 686 
| Column("number", mysql.INTEGER(11)), 687 | ) 688 | 689 | validate_code( 690 | generator.generate(), 691 | """\ 692 | from sqlalchemy import Column, Integer, MetaData, Table 693 | from sqlalchemy.dialects.mysql import INTEGER 694 | 695 | metadata = MetaData() 696 | 697 | 698 | t_simple_items = Table( 699 | 'simple_items', metadata, 700 | Column('id', Integer, primary_key=True), 701 | Column('number', INTEGER(11)) 702 | ) 703 | """, 704 | ) 705 | 706 | 707 | @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) 708 | def test_mysql_tinytext(generator: CodeGenerator) -> None: 709 | Table( 710 | "simple_items", 711 | generator.metadata, 712 | Column("id", INTEGER, primary_key=True), 713 | Column("my_tinytext", mysql.TINYTEXT), 714 | ) 715 | 716 | validate_code( 717 | generator.generate(), 718 | """\ 719 | from sqlalchemy import Column, Integer, MetaData, Table 720 | from sqlalchemy.dialects.mysql import TINYTEXT 721 | 722 | metadata = MetaData() 723 | 724 | 725 | t_simple_items = Table( 726 | 'simple_items', metadata, 727 | Column('id', Integer, primary_key=True), 728 | Column('my_tinytext', TINYTEXT) 729 | ) 730 | """, 731 | ) 732 | 733 | 734 | @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) 735 | def test_mysql_mediumtext(generator: CodeGenerator) -> None: 736 | Table( 737 | "simple_items", 738 | generator.metadata, 739 | Column("id", INTEGER, primary_key=True), 740 | Column("my_mediumtext", mysql.MEDIUMTEXT), 741 | ) 742 | 743 | validate_code( 744 | generator.generate(), 745 | """\ 746 | from sqlalchemy import Column, Integer, MetaData, Table 747 | from sqlalchemy.dialects.mysql import MEDIUMTEXT 748 | 749 | metadata = MetaData() 750 | 751 | 752 | t_simple_items = Table( 753 | 'simple_items', metadata, 754 | Column('id', Integer, primary_key=True), 755 | Column('my_mediumtext', MEDIUMTEXT) 756 | ) 757 | """, 758 | ) 759 | 760 | 761 | @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) 762 | def 
test_mysql_longtext(generator: CodeGenerator) -> None: 763 | Table( 764 | "simple_items", 765 | generator.metadata, 766 | Column("id", INTEGER, primary_key=True), 767 | Column("my_longtext", mysql.LONGTEXT), 768 | ) 769 | 770 | validate_code( 771 | generator.generate(), 772 | """\ 773 | from sqlalchemy import Column, Integer, MetaData, Table 774 | from sqlalchemy.dialects.mysql import LONGTEXT 775 | 776 | metadata = MetaData() 777 | 778 | 779 | t_simple_items = Table( 780 | 'simple_items', metadata, 781 | Column('id', Integer, primary_key=True), 782 | Column('my_longtext', LONGTEXT) 783 | ) 784 | """, 785 | ) 786 | 787 | 788 | def test_schema_boolean(generator: CodeGenerator) -> None: 789 | Table( 790 | "simple_items", 791 | generator.metadata, 792 | Column("bool1", INTEGER), 793 | CheckConstraint("testschema.simple_items.bool1 IN (0, 1)"), 794 | schema="testschema", 795 | ) 796 | 797 | validate_code( 798 | generator.generate(), 799 | """\ 800 | from sqlalchemy import Boolean, Column, MetaData, Table 801 | 802 | metadata = MetaData() 803 | 804 | 805 | t_simple_items = Table( 806 | 'simple_items', metadata, 807 | Column('bool1', Boolean), 808 | schema='testschema' 809 | ) 810 | """, 811 | ) 812 | 813 | 814 | def test_server_default_multiline(generator: CodeGenerator) -> None: 815 | Table( 816 | "simple_items", 817 | generator.metadata, 818 | Column( 819 | "id", 820 | INTEGER, 821 | primary_key=True, 822 | server_default=text( 823 | dedent( 824 | """\ 825 | /*Comment*/ 826 | /*Next line*/ 827 | something()""" 828 | ) 829 | ), 830 | ), 831 | ) 832 | 833 | validate_code( 834 | generator.generate(), 835 | """\ 836 | from sqlalchemy import Column, Integer, MetaData, Table, text 837 | 838 | metadata = MetaData() 839 | 840 | 841 | t_simple_items = Table( 842 | 'simple_items', metadata, 843 | Column('id', Integer, primary_key=True, server_default=\ 844 | text('/*Comment*/\\n/*Next line*/\\nsomething()')) 845 | ) 846 | """, 847 | ) 848 | 849 | 850 | def 
test_server_default_colon(generator: CodeGenerator) -> None: 851 | Table( 852 | "simple_items", 853 | generator.metadata, 854 | Column("problem", VARCHAR, server_default=text("':001'")), 855 | ) 856 | 857 | validate_code( 858 | generator.generate(), 859 | """\ 860 | from sqlalchemy import Column, MetaData, String, Table, text 861 | 862 | metadata = MetaData() 863 | 864 | 865 | t_simple_items = Table( 866 | 'simple_items', metadata, 867 | Column('problem', String, server_default=text("':001'")) 868 | ) 869 | """, 870 | ) 871 | 872 | 873 | def test_null_type(generator: CodeGenerator) -> None: 874 | Table( 875 | "simple_items", 876 | generator.metadata, 877 | Column("problem", NullType), 878 | ) 879 | 880 | validate_code( 881 | generator.generate(), 882 | """\ 883 | from sqlalchemy import Column, MetaData, Table 884 | from sqlalchemy.sql.sqltypes import NullType 885 | 886 | metadata = MetaData() 887 | 888 | 889 | t_simple_items = Table( 890 | 'simple_items', metadata, 891 | Column('problem', NullType) 892 | ) 893 | """, 894 | ) 895 | 896 | 897 | def test_identity_column(generator: CodeGenerator) -> None: 898 | Table( 899 | "simple_items", 900 | generator.metadata, 901 | Column( 902 | "id", 903 | INTEGER, 904 | primary_key=True, 905 | server_default=Identity(start=1, increment=2), 906 | ), 907 | ) 908 | 909 | validate_code( 910 | generator.generate(), 911 | """\ 912 | from sqlalchemy import Column, Identity, Integer, MetaData, Table 913 | 914 | metadata = MetaData() 915 | 916 | 917 | t_simple_items = Table( 918 | 'simple_items', metadata, 919 | Column('id', Integer, Identity(start=1, increment=2), primary_key=True) 920 | ) 921 | """, 922 | ) 923 | 924 | 925 | def test_multiline_column_comment(generator: CodeGenerator) -> None: 926 | Table( 927 | "simple_items", 928 | generator.metadata, 929 | Column("id", INTEGER, comment="This\nis a multi-line\ncomment"), 930 | ) 931 | 932 | validate_code( 933 | generator.generate(), 934 | """\ 935 | from sqlalchemy import Column, 
Integer, MetaData, Table 936 | 937 | metadata = MetaData() 938 | 939 | 940 | t_simple_items = Table( 941 | 'simple_items', metadata, 942 | Column('id', Integer, comment='This\\nis a multi-line\\ncomment') 943 | ) 944 | """, 945 | ) 946 | 947 | 948 | def test_multiline_table_comment(generator: CodeGenerator) -> None: 949 | Table( 950 | "simple_items", 951 | generator.metadata, 952 | Column("id", INTEGER), 953 | comment="This\nis a multi-line\ncomment", 954 | ) 955 | 956 | validate_code( 957 | generator.generate(), 958 | """\ 959 | from sqlalchemy import Column, Integer, MetaData, Table 960 | 961 | metadata = MetaData() 962 | 963 | 964 | t_simple_items = Table( 965 | 'simple_items', metadata, 966 | Column('id', Integer), 967 | comment='This\\nis a multi-line\\ncomment' 968 | ) 969 | """, 970 | ) 971 | 972 | 973 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 974 | def test_postgresql_sequence_standard_name(generator: CodeGenerator) -> None: 975 | Table( 976 | "simple_items", 977 | generator.metadata, 978 | Column( 979 | "id", 980 | INTEGER, 981 | primary_key=True, 982 | server_default=text("nextval('simple_items_id_seq'::regclass)"), 983 | ), 984 | ) 985 | 986 | validate_code( 987 | generator.generate(), 988 | """\ 989 | from sqlalchemy import Column, Integer, MetaData, Table 990 | 991 | metadata = MetaData() 992 | 993 | 994 | t_simple_items = Table( 995 | 'simple_items', metadata, 996 | Column('id', Integer, primary_key=True) 997 | ) 998 | """, 999 | ) 1000 | 1001 | 1002 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 1003 | def test_postgresql_sequence_nonstandard_name(generator: CodeGenerator) -> None: 1004 | Table( 1005 | "simple_items", 1006 | generator.metadata, 1007 | Column( 1008 | "id", 1009 | INTEGER, 1010 | primary_key=True, 1011 | server_default=text("nextval('test_seq'::regclass)"), 1012 | ), 1013 | ) 1014 | 1015 | validate_code( 1016 | generator.generate(), 1017 | """\ 1018 | from sqlalchemy import 
Column, Integer, MetaData, Sequence, Table 1019 | 1020 | metadata = MetaData() 1021 | 1022 | 1023 | t_simple_items = Table( 1024 | 'simple_items', metadata, 1025 | Column('id', Integer, Sequence('test_seq'), primary_key=True) 1026 | ) 1027 | """, 1028 | ) 1029 | 1030 | 1031 | @pytest.mark.parametrize( 1032 | "schemaname, seqname", 1033 | [ 1034 | pytest.param("myschema", "test_seq"), 1035 | pytest.param("myschema", '"test_seq"'), 1036 | pytest.param('"my.schema"', "test_seq"), 1037 | pytest.param('"my.schema"', '"test_seq"'), 1038 | ], 1039 | ) 1040 | @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) 1041 | def test_postgresql_sequence_with_schema( 1042 | generator: CodeGenerator, schemaname: str, seqname: str 1043 | ) -> None: 1044 | expected_schema = schemaname.strip('"') 1045 | Table( 1046 | "simple_items", 1047 | generator.metadata, 1048 | Column( 1049 | "id", 1050 | INTEGER, 1051 | primary_key=True, 1052 | server_default=text(f"nextval('{schemaname}.{seqname}'::regclass)"), 1053 | ), 1054 | schema=expected_schema, 1055 | ) 1056 | 1057 | validate_code( 1058 | generator.generate(), 1059 | f"""\ 1060 | from sqlalchemy import Column, Integer, MetaData, Sequence, Table 1061 | 1062 | metadata = MetaData() 1063 | 1064 | 1065 | t_simple_items = Table( 1066 | 'simple_items', metadata, 1067 | Column('id', Integer, Sequence('test_seq', \ 1068 | schema='{expected_schema}'), primary_key=True), 1069 | schema='{expected_schema}' 1070 | ) 1071 | """, 1072 | ) 1073 | --------------------------------------------------------------------------------