├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.yaml
│ ├── config.yml
│ └── features_request.yaml
├── pull_request_template.md
└── workflows
│ ├── publish.yml
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGES.rst
├── CONTRIBUTING.rst
├── LICENSE
├── README.rst
├── pyproject.toml
├── src
└── sqlacodegen
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli.py
│ ├── generators.py
│ ├── models.py
│ ├── py.typed
│ └── utils.py
└── tests
├── __init__.py
├── conftest.py
├── test_cli.py
├── test_generator_dataclass.py
├── test_generator_declarative.py
├── test_generator_sqlmodel.py
└── test_generator_tables.py
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: File a bug report
3 | labels: ["bug"]
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: >
8 | If you observed a crash in the project, or saw unexpected behavior in it, report
9 | your findings here.
10 | - type: checkboxes
11 | attributes:
12 | label: Things to check first
13 | options:
14 | - label: >
15 | I have searched the existing issues and didn't find my bug already reported
16 | there
17 | required: true
18 | - label: >
19 | I have checked that my bug is still present in the latest release
20 | required: true
21 | - type: input
22 | id: project-version
23 | attributes:
24 | label: Sqlacodegen version
25 | description: What version of Sqlacodegen were you running?
26 | validations:
27 | required: true
28 | - type: input
29 | id: sqlalchemy-version
30 | attributes:
31 | label: SQLAlchemy version
32 | description: What version of SQLAlchemy were you running?
33 | validations:
34 | required: true
35 | - type: dropdown
36 | id: rdbms
37 | attributes:
38 | label: RDBMS vendor
39 | description: >
40 | What RDBMS (relational database management system) did you run the tool against?
41 | options:
42 | - PostgreSQL
43 | - MySQL (or compatible)
44 | - SQLite
45 | - MSSQL
46 | - Oracle
47 | - DB2
48 | - Other
49 | - N/A
50 | validations:
51 | required: true
52 | - type: textarea
53 | id: what-happened
54 | attributes:
55 | label: What happened?
56 | description: >
57 | Unless you are reporting a crash, tell us what you expected to happen instead.
58 | validations:
59 | required: true
60 | - type: textarea
61 | id: schema
62 | attributes:
63 | label: Database schema for reproducing the bug
64 | description: >
65 | If applicable, paste the database schema (as a series of `CREATE TABLE` and
66 | other SQL commands) here.
67 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/features_request.yaml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Suggest a new feature
3 | labels: ["enhancement"]
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: >
8 | If you have thought of a new feature that would increase the usefulness of this
9 | project, please use this form to send us your idea.
10 | - type: checkboxes
11 | attributes:
12 | label: Things to check first
13 | options:
14 | - label: >
15 | I have searched the existing issues and didn't find my feature already
16 | requested there
17 | required: true
18 | - type: textarea
19 | id: feature
20 | attributes:
21 | label: Feature description
22 | description: >
23 | Describe the feature in detail. The more specific the description you can give,
24 | the easier it should be to implement this feature.
25 | validations:
26 | required: true
27 | - type: textarea
28 | id: usecase
29 | attributes:
30 | label: Use case
31 | description: >
32 | Explain why you need this feature, and why you think it would be useful to
33 | others too.
34 | validations:
35 | required: true
36 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 |
2 | ## Changes
3 |
4 | Fixes #.
5 |
6 |
7 |
8 | ## Checklist
9 |
10 | If this is a user-facing code change, like a bugfix or a new feature, please ensure that
11 | you've fulfilled the following conditions (where applicable):
12 |
13 | - [ ] You've added tests (in `tests/`) which would fail without your patch
14 | - [ ] You've added a new changelog entry (in `CHANGES.rst`).
15 |
16 | If this is a trivial change, like a typo fix or a code reformatting, then you can ignore
17 | these instructions.
18 |
19 | ### Updating the changelog
20 |
21 | If there are no entries after the last release, use `**UNRELEASED**` as the version.
22 | If, say, your patch fixes issue #123, the entry should look like this:
23 |
24 | ```
25 | - Fix big bad boo-boo in task groups
26 | (`#123 <https://github.com/agronholm/sqlacodegen/issues/123>`_; PR by @yourgithubaccount)
27 | ```
28 |
29 | If there's no issue linked, just link to your pull request instead by updating the
30 | changelog after you've created the PR.
31 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish packages to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - "[0-9]+.[0-9]+.[0-9]+"
7 | - "[0-9]+.[0-9]+.[0-9]+.post[0-9]+"
8 | - "[0-9]+.[0-9]+.[0-9]+[a-b][0-9]+"
9 | - "[0-9]+.[0-9]+.[0-9]+rc[0-9]+"
10 |
11 | jobs:
12 | build:
13 | name: Build the source tarball and the wheel
14 | runs-on: ubuntu-latest
15 | environment: release
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: 3.x
22 | - name: Install dependencies
23 | run: pip install build
24 | - name: Create packages
25 | run: python -m build
26 | - name: Archive packages
27 | uses: actions/upload-artifact@v4
28 | with:
29 | name: dist
30 | path: dist
31 |
32 | publish:
33 | name: Publish build artifacts to the PyPI
34 | needs: build
35 | runs-on: ubuntu-latest
36 | environment: release
37 | permissions:
38 | id-token: write
39 | steps:
40 | - name: Retrieve packages
41 | uses: actions/download-artifact@v4
42 | - name: Upload packages
43 | uses: pypa/gh-action-pypi-publish@release/v1
44 |
45 | release:
46 | name: Create a GitHub release
47 | needs: build
48 | runs-on: ubuntu-latest
49 | permissions:
50 | contents: write
51 | steps:
52 | - uses: actions/checkout@v4
53 | - id: changelog
54 | uses: agronholm/release-notes@v1
55 | with:
56 | path: CHANGES.rst
57 | - uses: ncipollo/release-action@v1
58 | with:
59 | body: ${{ steps.changelog.outputs.changelog }}
60 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test suite
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 |
8 | jobs:
9 | test:
10 | strategy:
11 | fail-fast: false
12 | matrix:
13 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | allow-prereleases: true
22 | cache: pip
23 | cache-dependency-path: pyproject.toml
24 | - name: Install dependencies
25 | run: pip install -e .[test]
26 | - name: Test with pytest
27 | run: coverage run -m pytest
28 | - name: Upload Coverage
29 | uses: coverallsapp/github-action@v2
30 | with:
31 | parallel: true
32 |
33 | coveralls:
34 | name: Finish Coveralls
35 | needs: test
36 | runs-on: ubuntu-latest
37 | steps:
38 | - name: Finished
39 | uses: coverallsapp/github-action@v2
40 | with:
41 | parallel-finished: true
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | *.pyc
3 | .project
4 | .pydevproject
5 | .coverage
6 | .settings
7 | .tox
8 | .idea
9 | .vscode
10 | .cache
11 | .pytest_cache
12 | .mypy_cache
13 | dist
14 | build
15 | venv*
16 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # This is the configuration file for pre-commit (https://pre-commit.com/).
2 | # To use:
3 | # * Install pre-commit (https://pre-commit.com/#installation)
4 | # * Copy this file as ".pre-commit-config.yaml"
5 | # * Run "pre-commit install".
6 | repos:
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v5.0.0
9 | hooks:
10 | - id: check-toml
11 | - id: check-yaml
12 | - id: debug-statements
13 | - id: end-of-file-fixer
14 | - id: mixed-line-ending
15 | args: [ "--fix=lf" ]
16 | - id: trailing-whitespace
17 |
18 | - repo: https://github.com/astral-sh/ruff-pre-commit
19 | rev: v0.11.12
20 | hooks:
21 | - id: ruff
22 | args: [--fix, --show-fixes]
23 | - id: ruff-format
24 |
25 | - repo: https://github.com/pre-commit/mirrors-mypy
26 | rev: v1.16.0
27 | hooks:
28 | - id: mypy
29 | additional_dependencies:
30 | - pytest
31 | - "SQLAlchemy >= 2.0.29"
32 |
33 | - repo: https://github.com/pre-commit/pygrep-hooks
34 | rev: v1.10.0
35 | hooks:
36 | - id: rst-backticks
37 | - id: rst-directive-colons
38 | - id: rst-inline-touching-normal
39 |
--------------------------------------------------------------------------------
/CHANGES.rst:
--------------------------------------------------------------------------------
1 | Version history
2 | ===============
3 |
4 | **UNRELEASED**
5 |
6 | - Type annotations for ARRAY column attributes now include the Python type of
7 | the array elements
8 | - Added support for specifying engine arguments via ``--engine-arg``
9 | (PR by @LajosCseppento)
10 | - Fixed incorrect package name used in ``importlib.metadata.version`` for
11 | ``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz)
12 | - Prevent double pluralization (PR by @dkratzert)
13 |
14 | **3.0.0**
15 |
16 | - Dropped support for Python 3.8
17 | - Changed nullable relationships to include ``Optional`` in their type annotations
18 | - Fixed SQLModel code generation
19 | - Fixed two rendering issues in ``ENUM`` columns when a non-default schema is used: an
20 | unwarranted positional argument and missing the ``schema`` argument
21 | - Fixed ``AttributeError`` when metadata contains user defined column types
22 | - Fixed ``AssertionError`` when metadata contains a column type that is a type decorator
23 | with an all-uppercase name
24 | - Fixed MySQL ``DOUBLE`` column types being rendered with the wrong arguments
25 |
26 | **3.0.0rc5**
27 |
28 | - Fixed pgvector support not working
29 |
30 | **3.0.0rc4**
31 |
32 | - Dropped support for Python 3.7
33 | - Dropped support for SQLAlchemy 1.x
34 | - Added support for the ``pgvector`` extension (with help from KellyRousselHoomano)
35 |
36 | **3.0.0rc3**
37 |
38 | - Added support for SQLAlchemy 2 (PR by rbuffat with help from mhauru)
39 | - Renamed ``--option`` to ``--options`` and made its values delimited by commas
40 | - Restored CIText and GeoAlchemy2 support (PR by stavvy-rotte)
41 |
42 | **3.0.0rc2**
43 |
44 | - Added support for generating SQLModel classes (PR by Andrii Khirilov)
45 | - Fixed code generation when a single-column index is unique or does not match the
46 | dialect's naming convention (PR by Leonardus Chen)
47 | - Fixed another problem where sequence schemas were not properly separated from the
48 | sequence name
49 | - Fixed invalid generated primary/secondaryjoin expressions in self-referential
50 | many-to-many relationships by using lambdas instead of strings
51 | - Fixed ``AttributeError`` when the declarative generator encounters a table name
52 | already in singular form when ``--option use_inflect`` is enabled
53 | - Increased minimum SQLAlchemy version to 1.4.36 to address issues with ``ForeignKey``
54 | and indexes, and to eliminate the PostgreSQL UUID column type annotation hack
55 |
56 | **3.0.0rc1**
57 |
58 | - Migrated all packaging/testing configuration to ``pyproject.toml``
59 | - Fixed unwarranted ``ForeignKey`` declarations appearing in column attributes when
60 | there are named, single column foreign key constraints (PR by Leonardus Chen)
61 | - Fixed ``KeyError`` when rendering an index without any columns
62 | - Fixed improper handling of schema prefixes in sequence names in server defaults
63 | - Fixed identically named tables from different schemas resulting in invalid generated
64 | code
65 | - Fixed imports caused by ``server_default`` conflicting with class attribute names
66 | - Worked around PostgreSQL UUID columns getting ``Any`` as the type annotation
67 |
68 | **3.0.0b3**
69 |
70 | - Dropped support for Python < 3.7
71 | - Dropped support for SQLAlchemy 1.3
72 | - Added a ``__main__`` module which can be used as an alternate entry point to the CLI
73 | - Added detection for sequence use in column defaults on PostgreSQL
74 | - Fixed ``sqlalchemy.exc.InvalidRequestError`` when encountering a column named
75 | "metadata" (regression from 2.0)
76 | - Fixed missing ``MetaData`` import with ``DeclarativeGenerator`` when only plain tables
77 | are generated
78 | - Fixed invalid data classes being generated due to some relationships having been
79 | rendered without a default value
80 | - Improved translation of column names into column attributes where the column name has
81 | whitespace at the beginning or end
82 | - Modified constraint and index rendering to add them explicitly instead of using
83 | shortcuts like ``unique=True``, ``index=True`` or ``primary=True`` when the constraint
84 | or index has a name that does not match the default naming convention
85 |
86 | **3.0.0b2**
87 |
88 | - Fixed ``IDENTITY`` columns not rendering properly when they are part of the primary
89 | key
90 |
91 | **3.0.0b1**
92 |
93 | **NOTE**: Both the API and the command line interface have been refactored in a
94 | backwards incompatible fashion. Notably, several command line options have been moved to
95 | specific generators and are no longer visible from ``sqlacodegen --help``. Their
96 | replacements are documented in the README.
97 |
98 | - Dropped support for Python < 3.6
99 | - Added support for Python 3.10
100 | - Added support for SQLAlchemy 1.4
101 | - Added support for bidirectional relationships (use ``--option nobidi`` to disable)
102 | - Added support for multiple schemas via ``--schemas``
103 | - Added support for ``IDENTITY`` columns
104 | - Disabled inflection during table/relationship name generation by default
105 | (use ``--option use_inflect`` to re-enable)
106 | - Refactored the old ``CodeGenerator`` class into separate generator classes, selectable
107 | via ``--generator``
108 | - Refactored several command line options into generator specific options:
109 |
110 | - ``--noindexes`` → ``--option noindexes``
111 | - ``--noconstraints`` → ``--option noconstraints``
112 | - ``--nocomments`` → ``--option nocomments``
113 | - ``--nojoined`` → ``--option nojoined`` (``declarative`` and ``dataclass`` generators
114 | only)
115 | - ``--noinflect`` → (now the default; use ``--option use_inflect`` instead)
116 | (``declarative`` and ``dataclass`` generators only)
117 | - Fixed missing import for ``JSONB`` ``astext_type`` argument
118 | - Fixed generated column or relationship names colliding with imports or each other
119 | - Fixed ``CompileError`` when encountering server defaults that contain colons (``:``)
120 |
121 | **2.3.0**
122 |
123 | - Added support for rendering computed columns
124 | - Fixed ``--nocomments`` not taking effect (fix proposed by AzuresYang)
125 | - Fixed handling of MySQL ``SET`` column types (and possibly others as well)
126 |
127 | **2.2.0**
128 |
129 | - Added support for rendering table comments (PR by David Hirschfeld)
130 | - Fixed bad identifier names being generated for plain tables (PR by softwarepk)
131 |
132 | **2.1.0**
133 |
134 | - Dropped support for Python 3.4
135 | - Dropped support for SQLAlchemy 0.8
136 | - Added support for Python 3.7 and 3.8
137 | - Added support for SQLAlchemy 1.3
138 | - Added support for column comments (requires SQLAlchemy 1.2+; based on PR by koalas8)
139 | - Fixed crash on unknown column types (``NullType``)
140 |
141 | **2.0.1**
142 |
143 | - Don't adapt dialect specific column types if they need special constructor arguments
144 | (thanks Nicholas Martin for the PR)
145 |
146 | **2.0.0**
147 |
148 | - Refactored code for better reuse
149 | - Dropped support for Python 2.6, 3.2 and 3.3
150 | - Dropped support for SQLAlchemy < 0.8
151 | - Worked around a bug regarding Enum on SQLAlchemy 1.2+ (``name`` was missing)
152 | - Added support for Geoalchemy2
153 | - Fixed invalid class names being generated (fixes #60; PR by Dan O'Huiginn)
154 | - Fixed array item types not being adapted or imported
155 | (fixes #46; thanks to Martin Glauer and Shawn Koschik for help)
156 | - Fixed attribute name of columns named ``metadata`` in mapped classes (fixes #62)
157 | - Fixed rendered column types being changed from the original (fixes #11)
158 | - Fixed server defaults which contain double quotes (fixes #7, #17, #28, #33, #36)
159 | - Fixed ``secondary=`` not taking into account the association table's schema name
160 | (fixes #30)
161 | - Sort models by foreign key dependencies instead of schema and name (fixes #15, #16)
162 |
163 | **1.1.6**
164 |
165 | - Fixed compatibility with SQLAlchemy 1.0
166 | - Added an option to only generate tables
167 |
168 | **1.1.5**
169 |
170 | - Fixed potential assignment of columns or relationships into invalid attribute names
171 | (fixes #10)
172 | - Fixed unique=True missing from unique Index declarations
173 | - Fixed several issues with server defaults
174 | - Fixed potential assignment of columns or relationships into invalid attribute names
175 | - Allowed pascal case for tables already using it
176 | - Switched from Mercurial to Git
177 |
178 | **1.1.4**
179 |
180 | - Fixed compatibility with SQLAlchemy 0.9.0
181 |
182 | **1.1.3**
183 |
184 | - Fixed compatibility with SQLAlchemy 0.8.3+
185 | - Migrated tests from nose to pytest
186 |
187 | **1.1.2**
188 |
189 | - Fixed non-default schema name not being present in ``__table_args__`` (fixes #2)
190 | - Fixed self referential foreign key causing column type to not be rendered
191 | - Fixed missing "deferrable" and "initially" keyword arguments in ForeignKey constructs
192 | - Fixed foreign key and check constraint handling with alternate schemas (fixes #3)
193 |
194 | **1.1.1**
195 |
196 | - Fixed TypeError when inflect could not determine the singular name of a table for a
197 | many-to-1 relationship
198 | - Fixed ``_IntegerType``, ``_StringType`` etc. being rendered instead of proper types on MySQL
199 |
200 | **1.1.0**
201 |
202 | - Added automatic detection of joined-table inheritance
203 | - Fixed missing class name prefix in primary/secondary joins in relationships
204 | - Instead of wildcard imports, generate explicit imports dynamically (fixes #1)
205 | - Use the inflect library to produce better guesses for table to class name conversion
206 | - Automatically detect Boolean columns based on CheckConstraints
207 | - Skip redundant CheckConstraints for Enum and Boolean columns
208 |
209 | **1.0.0**
210 |
211 | - Initial release
212 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | Contributing to sqlacodegen
2 | ===========================
3 |
4 | If you wish to contribute a fix or feature to sqlacodegen, please follow the following
5 | guidelines.
6 |
7 | When you make a pull request against the main sqlacodegen codebase, GitHub runs the
8 | sqlacodegen test suite against your modified code. Before making a pull request, you
9 | should ensure that the modified code passes tests locally. To that end, the use of tox_
10 | is recommended. The default tox run first runs ``pre-commit`` and then the actual test
11 | suite. To run the checks on all environments in parallel, invoke tox with ``tox -p``.
12 |
13 | To build the documentation, run ``tox -e docs`` which will generate a directory named
14 | ``build`` in which you may view the formatted HTML documentation.
15 |
16 | sqlacodegen uses pre-commit_ to perform several code style/quality checks. It is
17 | recommended to activate pre-commit_ on your local clone of the repository (using
18 | ``pre-commit install``) to ensure that your changes will pass the same checks on GitHub.
19 |
20 | .. _tox: https://tox.readthedocs.io/en/latest/install.html
21 | .. _pre-commit: https://pre-commit.com/#installation
22 |
23 | Making a pull request on Github
24 | -------------------------------
25 |
26 | To get your changes merged to the main codebase, you need a GitHub account.
27 |
28 | #. Fork the repository (if you don't have your own fork of it yet) by navigating to the
29 | `main sqlacodegen repository`_ and clicking on "Fork" near the top right corner.
30 | #. Clone the forked repository to your local machine with
31 | ``git clone git@github.com:yourusername/sqlacodegen.git``.
32 | #. Create a branch for your pull request, like ``git checkout -b myfixname``
33 | #. Make the desired changes to the code base.
34 | #. Commit your changes locally. If your changes close an existing issue, add the text
35 | ``Fixes #XXX.`` or ``Closes #XXX.`` to the commit message (where XXX is the issue
36 | number).
37 | #. Push the changeset(s) to your forked repository (``git push``)
38 | #. Navigate to the Pull requests page on the original repository (not your fork) and click
39 | "New pull request"
40 | #. Click on the text "compare across forks".
41 | #. Select your own fork as the head repository and then select the correct branch name.
42 | #. Click on "Create pull request".
43 |
44 | If you have trouble, consult the `pull request making guide`_ on opensource.com.
45 |
46 | .. _main sqlacodegen repository: https://github.com/agronholm/sqlacodegen
47 | .. _pull request making guide: https://opensource.com/article/19/7/create-pull-request-github
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This is the MIT license: http://www.opensource.org/licenses/mit-license.php
2 |
3 | Copyright (c) Alex Grönholm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this
6 | software and associated documentation files (the "Software"), to deal in the Software
7 | without restriction, including without limitation the rights to use, copy, modify, merge,
8 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
9 | to whom the Software is furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in all copies or
12 | substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
15 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
16 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
17 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | .. image:: https://github.com/agronholm/sqlacodegen/actions/workflows/test.yml/badge.svg
2 | :target: https://github.com/agronholm/sqlacodegen/actions/workflows/test.yml
3 | :alt: Build Status
4 | .. image:: https://coveralls.io/repos/github/agronholm/sqlacodegen/badge.svg?branch=master
5 | :target: https://coveralls.io/github/agronholm/sqlacodegen?branch=master
6 | :alt: Code Coverage
7 |
8 | This is a tool that reads the structure of an existing database and generates the
9 | appropriate SQLAlchemy model code, using the declarative style if possible.
10 |
11 | This tool was written as a replacement for `sqlautocode`_, which was suffering from
12 | several issues (including, but not limited to, incompatibility with Python 3 and the
13 | latest SQLAlchemy version).
14 |
15 | .. _sqlautocode: http://code.google.com/p/sqlautocode/
16 |
17 |
18 | Features
19 | ========
20 |
21 | * Supports SQLAlchemy 2.x
22 | * Produces declarative code that almost looks like it was hand written
23 | * Produces `PEP 8`_ compliant code
24 | * Accurately determines relationships, including many-to-many, one-to-one
25 | * Automatically detects joined table inheritance
26 | * Excellent test coverage
27 |
28 | .. _PEP 8: http://www.python.org/dev/peps/pep-0008/
29 |
30 |
31 | Installation
32 | ============
33 |
34 | To install, do::
35 |
36 | pip install sqlacodegen
37 |
38 | To include support for the PostgreSQL ``CITEXT`` extension type (which should be
39 | considered as tested only under a few environments) specify the ``citext`` extra::
40 |
41 | pip install sqlacodegen[citext]
42 |
43 |
44 | To include support for the PostgreSQL ``GEOMETRY``, ``GEOGRAPHY``, and ``RASTER`` types
45 | (which should be considered as tested only under a few environments) specify the
46 | ``geoalchemy2`` extra::
47 |
48 | pip install sqlacodegen[geoalchemy2]
49 |
50 | To include support for the PostgreSQL ``PGVECTOR`` extension type, specify the
51 | ``pgvector`` extra::
52 |
53 | pip install sqlacodegen[pgvector]
54 |
55 |
56 |
57 |
58 | Quickstart
59 | ==========
60 |
61 | At the minimum, you have to give sqlacodegen a database URL. The URL is passed directly
62 | to SQLAlchemy's `create_engine()`_ method so please refer to
63 | `SQLAlchemy's documentation`_ for instructions on how to construct a proper URL.
64 |
65 | Examples::
66 |
67 | sqlacodegen postgresql:///some_local_db
68 | sqlacodegen --generator tables mysql+pymysql://user:password@localhost/dbname
69 | sqlacodegen --generator dataclasses sqlite:///database.db
70 | # --engine-arg values are parsed with ast.literal_eval
71 | sqlacodegen oracle+oracledb://user:pass@127.0.0.1:1521/XE --engine-arg thick_mode=True
72 | sqlacodegen oracle+oracledb://user:pass@127.0.0.1:1521/XE --engine-arg thick_mode=True --engine-arg connect_args='{"user": "user", "dsn": "..."}'
73 |
74 | To see the list of generic options::
75 |
76 | sqlacodegen --help
77 |
78 | .. _create_engine(): http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
79 | .. _SQLAlchemy's documentation: http://docs.sqlalchemy.org/en/latest/core/engines.html
80 |
81 | Available generators
82 | ====================
83 |
84 | The selection of a generator determines the kind of model code that is generated.
85 |
86 | The following built-in generators are available:
87 |
88 | * ``tables`` (only generates ``Table`` objects, for those who don't want to use the ORM)
89 | * ``declarative`` (the default; generates classes inheriting from ``declarative_base()``)
90 | * ``dataclasses`` (generates dataclass-based models; v1.4+ only)
91 | * ``sqlmodels`` (generates model classes for SQLModel_)
92 |
93 | .. _SQLModel: https://sqlmodel.tiangolo.com/
94 |
95 | Generator-specific options
96 | ==========================
97 |
98 | The following options can be turned on by passing them using ``--options`` (multiple
99 | values must be delimited by commas, e.g. ``--options noconstraints,nobidi``):
100 |
101 | * ``tables``
102 |
103 | * ``noconstraints``: ignore constraints (foreign key, unique etc.)
104 | * ``nocomments``: ignore table/column comments
105 | * ``noindexes``: ignore indexes
106 |
107 | * ``declarative``
108 |
109 | * all the options from ``tables``
110 | * ``use_inflect``: use the ``inflect`` library when naming classes and relationships
111 | (turning plural names into singular; see below for details)
112 | * ``nojoined``: don't try to detect joined-table inheritance (see below for details)
113 | * ``nobidi``: generate relationships in a unidirectional fashion, so only the
114 | many-to-one or first side of many-to-many relationships gets a relationship
115 | attribute, as on v2.X
116 |
117 | * ``dataclasses``
118 |
119 | * all the options from ``declarative``
120 |
121 | * ``sqlmodels``
122 |
123 | * all the options from ``declarative``
124 |
125 | Model class generators
126 | ----------------------
127 |
128 | The code generators that generate classes try to generate model classes whenever
129 | possible. There are two circumstances in which a ``Table`` is generated instead:
130 |
131 | * the table has no primary key constraint (which is required by SQLAlchemy for every
132 | model class)
133 | * the table is an association table between two other tables (see below for the
134 | specifics)
135 |
136 | Model class naming logic
137 | ++++++++++++++++++++++++
138 |
139 | By default, table names are converted to valid PEP 8 compliant class names by replacing
140 | all characters unsuitable for Python identifiers with ``_``. Then, each valid part
141 | (separated by underscores) is title cased and then joined together, eliminating the
142 | underscores. So, ``example_name`` becomes ``ExampleName``.
143 |
144 | If the ``use_inflect`` option is used, the table name (which is assumed to be in
145 | English) is converted to singular form using the "inflect" library. For example,
146 | ``sales_invoices`` becomes ``SalesInvoice``. Since table names are not always in
147 | English, and the inflection process is far from perfect, inflection is disabled by
148 | default.
149 |
150 | Relationship detection logic
151 | ++++++++++++++++++++++++++++
152 |
153 | Relationships are detected based on existing foreign key constraints as follows:
154 |
155 | * **many-to-one**: a foreign key constraint exists on the table
156 | * **one-to-one**: same as **many-to-one**, but a unique constraint exists on the
157 | column(s) involved
158 | * **many-to-many**: (not implemented on the ``sqlmodels`` generator) an association table
159 | is found to exist between two tables
160 |
161 | A table is considered an association table if it satisfies all of the following
162 | conditions:
163 |
164 | #. has exactly two foreign key constraints
165 | #. all its columns are involved in said constraints
166 |
167 | Relationship naming logic
168 | +++++++++++++++++++++++++
169 |
170 | Relationships are typically named based on the table name of the opposite class.
171 | For example, if a class has a relationship to another class with the table named
172 | ``companies``, the relationship would be named ``companies`` (unless the ``use_inflect``
173 | option was enabled, in which case it would be named ``company`` in the case of a
174 | many-to-one or one-to-one relationship).
175 |
176 | A special case for single column many-to-one and one-to-one relationships, however, is
177 | if the column is named like ``employer_id``. Then the relationship is named ``employer``
178 | due to that ``_id`` suffix.
179 |
180 | For self referential relationships, the reverse side of the relationship will be named
181 | with the ``_reverse`` suffix appended to it.
182 |
183 | Customizing code generation logic
184 | =================================
185 |
186 | If the built-in generators with all their options don't quite do what you want, you can
187 | customize the logic by subclassing one of the existing code generator classes. Override
188 | whichever methods you need, and then add an `entry point`_ in the
189 | ``sqlacodegen.generators`` namespace that points to your new class. Once the entry point
190 | is in place (you typically have to install the project with ``pip install``), you can
191 | use ``--generator <entry-name>`` to invoke your custom code generator.
192 |
193 | For examples, you can look at sqlacodegen's own entry points in its `pyproject.toml`_.
194 |
195 | .. _entry point: https://setuptools.pypa.io/en/latest/userguide/entry_point.html
196 | .. _pyproject.toml: https://github.com/agronholm/sqlacodegen/blob/master/pyproject.toml
197 |
198 | Getting help
199 | ============
200 |
201 | If you have problems or other questions, you should start a discussion on the
202 | `sqlacodegen discussion forum`_. As an alternative, you could also try your luck on the
203 | sqlalchemy_ room on Gitter.
204 |
205 | .. _sqlacodegen discussion forum: https://github.com/agronholm/sqlacodegen/discussions/categories/q-a
206 | .. _sqlalchemy: https://app.gitter.im/#/room/#sqlalchemy_community:gitter.im
207 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools >= 64",
4 | "setuptools_scm[toml] >= 6.4"
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [project]
9 | name = "sqlacodegen"
10 | description = "Automatic model code generator for SQLAlchemy"
11 | readme = "README.rst"
12 | authors = [{name = "Alex Grönholm", email = "alex.gronholm@nextday.fi"}]
13 | keywords = ["sqlalchemy"]
14 | license = {text = "MIT"}
15 | classifiers = [
16 | "Development Status :: 5 - Production/Stable",
17 | "Intended Audience :: Developers",
18 | "License :: OSI Approved :: MIT License",
19 | "Environment :: Console",
20 | "Topic :: Database",
21 | "Topic :: Software Development :: Code Generators",
22 | "Programming Language :: Python",
23 | "Programming Language :: Python :: 3",
24 | "Programming Language :: Python :: 3.9",
25 | "Programming Language :: Python :: 3.10",
26 | "Programming Language :: Python :: 3.11",
27 | "Programming Language :: Python :: 3.12",
28 | "Programming Language :: Python :: 3.13",
29 | ]
30 | requires-python = ">=3.9"
31 | dependencies = [
32 | "SQLAlchemy >= 2.0.29",
33 | "inflect >= 4.0.0",
34 | "importlib_metadata; python_version < '3.10'",
35 | ]
36 | dynamic = ["version"]
37 |
38 | [project.urls]
39 | "Bug Tracker" = "https://github.com/agronholm/sqlacodegen/issues"
40 | "Source Code" = "https://github.com/agronholm/sqlacodegen"
41 |
42 | [project.optional-dependencies]
43 | test = [
44 | "sqlacodegen[sqlmodel,pgvector]",
45 | "pytest >= 7.4",
46 | "coverage >= 7",
47 | "psycopg[binary]",
48 | "mysql-connector-python",
49 | ]
50 | sqlmodel = ["sqlmodel >= 0.0.22"]
51 | citext = ["sqlalchemy-citext >= 1.7.0"]
52 | geoalchemy2 = ["geoalchemy2 >= 0.11.1"]
53 | pgvector = ["pgvector >= 0.2.4"]
54 |
55 | [project.entry-points."sqlacodegen.generators"]
56 | tables = "sqlacodegen.generators:TablesGenerator"
57 | declarative = "sqlacodegen.generators:DeclarativeGenerator"
58 | dataclasses = "sqlacodegen.generators:DataclassGenerator"
59 | sqlmodels = "sqlacodegen.generators:SQLModelGenerator"
60 |
61 | [project.scripts]
62 | sqlacodegen = "sqlacodegen.cli:main"
63 |
64 | [tool.setuptools_scm]
65 | version_scheme = "post-release"
66 | local_scheme = "dirty-tag"
67 |
68 | [tool.ruff]
69 | src = ["src"]
70 |
71 | [tool.ruff.lint]
72 | extend-select = [
73 | "I", # isort
74 | "ISC", # flake8-implicit-str-concat
75 | "PGH", # pygrep-hooks
76 | "RUF100", # unused noqa (yesqa)
77 | "UP", # pyupgrade
78 | "W", # pycodestyle warnings
79 | ]
80 |
81 | [tool.mypy]
82 | strict = true
83 | disable_error_code = "no-untyped-call"
84 |
85 | [tool.pytest.ini_options]
86 | addopts = "-rsfE --tb=short"
87 | testpaths = ["tests"]
88 |
# coverage.py only reads its settings from pyproject.toml when they live under the
# [tool.coverage.*] tables; plain [coverage.*] tables are silently ignored
[tool.coverage.run]
source = ["sqlacodegen"]
relative_files = true

[tool.coverage.report]
show_missing = true
95 |
96 | [tool.tox]
97 | env_list = ["py39", "py310", "py311", "py312", "py313"]
98 | skip_missing_interpreters = true
99 |
100 | [tool.tox.env_run_base]
101 | package = "editable"
102 | commands = [["python", "-m", "pytest", { replace = "posargs", extend = true }]]
103 | extras = ["test"]
104 |
--------------------------------------------------------------------------------
/src/sqlacodegen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agronholm/sqlacodegen/84dcc39d5cef1a840234624603857ab72ec333a0/src/sqlacodegen/__init__.py
--------------------------------------------------------------------------------
/src/sqlacodegen/__main__.py:
--------------------------------------------------------------------------------
"""Allow running the tool as ``python -m sqlacodegen``."""

from .cli import main

# This module only ever runs as the package's ``__main__`` entry point, so the
# command line interface is started unconditionally.
main()
4 |
--------------------------------------------------------------------------------
/src/sqlacodegen/cli.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import ast
5 | import sys
6 | from contextlib import ExitStack
7 | from typing import Any, TextIO
8 |
9 | from sqlalchemy.engine import create_engine
10 | from sqlalchemy.schema import MetaData
11 |
12 | try:
13 | import citext
14 | except ImportError:
15 | citext = None
16 |
17 | try:
18 | import geoalchemy2
19 | except ImportError:
20 | geoalchemy2 = None
21 |
22 | try:
23 | import pgvector.sqlalchemy
24 | except ImportError:
25 | pgvector = None
26 |
27 | if sys.version_info < (3, 10):
28 | from importlib_metadata import entry_points, version
29 | else:
30 | from importlib.metadata import entry_points, version
31 |
32 |
def _parse_engine_arg(arg_str: str) -> tuple[str, Any]:
    """
    Parse a single ``--engine-arg`` value of the form ``key=value``.

    The value is evaluated with :func:`ast.literal_eval`, so numbers, booleans, dicts
    etc. can be passed on to ``create_engine()``. The lowercase spellings ``true``
    and ``false`` (as used in the ``--help`` example) are accepted as booleans too.
    Any other value that is not a Python literal is passed on as a plain string.

    :param arg_str: the raw ``key=value`` string (only the first ``=`` separates the
        key from the value)
    :return: a ``(key, value)`` tuple, with surrounding whitespace removed from the key
    :raises argparse.ArgumentTypeError: if there is no ``=`` or the key is empty
    """
    key, sep, raw_value = arg_str.partition("=")
    key = key.strip()
    if not sep or not key:
        raise argparse.ArgumentTypeError("engine-arg must be in key=value format")

    try:
        return key, ast.literal_eval(raw_value)
    except Exception:
        # literal_eval() raises a variety of exception types for non-literal input;
        # all of them just mean "not a Python literal"
        pass

    # The --help text advertises "thick_mode=true", but literal_eval() only knows the
    # capitalized spelling; without this, "false" would be passed on as a (truthy)
    # string
    lowered = raw_value.strip().lower()
    if lowered in ("true", "false"):
        return key, lowered == "true"

    return key, raw_value
44 |
45 |
def _parse_engine_args(arg_list: list[str] | None) -> dict[str, Any]:
    """
    Turn the collected ``--engine-arg`` values into ``create_engine()`` keyword
    arguments.

    :param arg_list: the raw ``key=value`` strings, or ``None`` when the option was
        not given at all (argparse's default for ``action="append"``)
    :return: the keyword arguments; a key given more than once keeps its last value
    :raises argparse.ArgumentTypeError: if any value is not in ``key=value`` format
    """
    return dict(_parse_engine_arg(arg) for arg in arg_list or ())
53 |
54 |
def main() -> None:
    """
    Entry point of the ``sqlacodegen`` command line tool.

    Reflects the database given by the ``url`` argument into a new
    :class:`~sqlalchemy.schema.MetaData`, runs the selected generator on it and
    writes the resulting code to ``--outfile`` (or standard output).

    Informational messages are written to standard error so that they never get
    mixed into generated code written to standard output.
    """
    generators = {ep.name: ep for ep in entry_points(group="sqlacodegen.generators")}
    parser = argparse.ArgumentParser(
        description="Generates SQLAlchemy model code from an existing database."
    )
    parser.add_argument("url", nargs="?", help="SQLAlchemy url to the database")
    parser.add_argument(
        "--options", help="options (comma-delimited) passed to the generator class"
    )
    parser.add_argument(
        "--version", action="store_true", help="print the version number and exit"
    )
    parser.add_argument(
        "--schemas", help="load tables from the given schemas (comma-delimited)"
    )
    parser.add_argument(
        "--generator",
        choices=generators,
        default="declarative",
        help="generator class to use",
    )
    parser.add_argument(
        "--tables", help="tables to process (comma-delimited, default: all)"
    )
    parser.add_argument(
        "--noviews",
        action="store_true",
        help="ignore views (always true for sqlmodels generator)",
    )
    parser.add_argument(
        "--engine-arg",
        action="append",
        help=(
            "engine arguments in key=value format, e.g., "
            '--engine-arg=connect_args=\'{"user": "scott"}\' '
            "--engine-arg thick_mode=true or "
            '--engine-arg thick_mode=\'{"lib_dir": "/path"}\' '
            "(values are parsed with ast.literal_eval)"
        ),
    )
    parser.add_argument("--outfile", help="file to write output to (default: stdout)")
    args = parser.parse_args()

    if args.version:
        print(version("sqlacodegen"))
        return

    if not args.url:
        print("You must supply a url\n", file=sys.stderr)
        parser.print_help()
        return

    # The engine arguments are parsed after parse_args(), so an ArgumentTypeError
    # would otherwise escape as a traceback; report it as a regular usage error
    try:
        engine_args = _parse_engine_args(args.engine_arg)
    except argparse.ArgumentTypeError as exc:
        parser.error(str(exc))

    # These go to stderr because stdout may be receiving the generated code
    if citext:
        print(
            f"Using sqlalchemy-citext {version('sqlalchemy-citext')}", file=sys.stderr
        )

    if geoalchemy2:
        print(f"Using geoalchemy2 {version('geoalchemy2')}", file=sys.stderr)

    if pgvector:
        print(f"Using pgvector {version('pgvector')}", file=sys.stderr)

    # Use reflection to fill in the metadata
    engine = create_engine(args.url, **engine_args)
    metadata = MetaData()
    tables = args.tables.split(",") if args.tables else None
    schemas = args.schemas.split(",") if args.schemas else [None]
    options = set(args.options.split(",")) if args.options else set()

    # Instantiate the generator
    generator_class = generators[args.generator].load()
    generator = generator_class(metadata, engine, options)

    if not generator.views_supported:
        name = generator_class.__name__
        print(
            f"VIEW models will not be generated when using the '{name}' generator",
            file=sys.stderr,
        )

    for schema in schemas:
        metadata.reflect(
            engine, schema, (generator.views_supported and not args.noviews), tables
        )

    # Generate the code before touching the output file so that a failure during
    # generation doesn't leave behind a truncated or empty file
    code = generator.generate()

    # Write the generated model code to the specified file or standard output
    with ExitStack() as stack:
        outfile: TextIO
        if args.outfile:
            outfile = stack.enter_context(open(args.outfile, "w", encoding="utf-8"))
        else:
            outfile = sys.stdout

        outfile.write(code)
151 |
--------------------------------------------------------------------------------
/src/sqlacodegen/generators.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import inspect
4 | import re
5 | import sys
6 | from abc import ABCMeta, abstractmethod
7 | from collections import defaultdict
8 | from collections.abc import Collection, Iterable, Sequence
9 | from dataclasses import dataclass
10 | from importlib import import_module
11 | from inspect import Parameter
12 | from itertools import count
13 | from keyword import iskeyword
14 | from pprint import pformat
15 | from textwrap import indent
16 | from typing import Any, ClassVar, Literal, cast
17 |
18 | import inflect
19 | import sqlalchemy
20 | from sqlalchemy import (
21 | ARRAY,
22 | Boolean,
23 | CheckConstraint,
24 | Column,
25 | Computed,
26 | Constraint,
27 | DefaultClause,
28 | Enum,
29 | ForeignKey,
30 | ForeignKeyConstraint,
31 | Identity,
32 | Index,
33 | MetaData,
34 | PrimaryKeyConstraint,
35 | String,
36 | Table,
37 | Text,
38 | TypeDecorator,
39 | UniqueConstraint,
40 | )
41 | from sqlalchemy.dialects.postgresql import DOMAIN, JSONB
42 | from sqlalchemy.engine import Connection, Engine
43 | from sqlalchemy.exc import CompileError
44 | from sqlalchemy.sql.elements import TextClause
45 | from sqlalchemy.sql.type_api import UserDefinedType
46 | from sqlalchemy.types import TypeEngine
47 |
48 | from .models import (
49 | ColumnAttribute,
50 | JoinType,
51 | Model,
52 | ModelClass,
53 | RelationshipAttribute,
54 | RelationshipType,
55 | )
56 | from .utils import (
57 | decode_postgresql_sequence,
58 | get_column_names,
59 | get_common_fk_constraints,
60 | get_compiled_expression,
61 | get_constraint_sort_key,
62 | qualified_table_name,
63 | render_callable,
64 | uses_default_name,
65 | )
66 |
# Boolean-emulating check constraints such as "flag IN (0, 1)"; group 1 is the
# column reference, with a leading "qualifier." prefix skipped
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")

# A possibly qualified and quoted column reference such as '"tbl"."col"'; group 3
# is the bare column name
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')

# Enum-emulating check constraints such as "status IN ('a', 'b')"; group 1 is the
# column reference and group 2 the raw, comma separated list of items
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
70 | _re_enum_item = re.compile(r"'(.*?)(? bool:
109 | pass
110 |
111 | @abstractmethod
112 | def generate(self) -> str:
113 | """
114 | Generate the code for the given metadata.
115 | .. note:: May modify the metadata.
116 | """
117 |
118 |
@dataclass(eq=False)
class TablesGenerator(CodeGenerator):
    """
    Generate plain :class:`~sqlalchemy.Table` definitions attached to a module level
    ``metadata`` object.

    The ORM based generators subclass this and override the model generation and
    rendering hooks.
    """

    valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
    # Packages grouped with the standard library imports, in addition to those
    # detected via sys.modules in group_imports()
    builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | {
        "dataclasses"
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        super().__init__(metadata, bind, options)
        self.indentation: str = indentation
        # Package name -> names imported from it ("from package import a, b")
        self.imports: dict[str, set[str]] = defaultdict(set)
        # Modules imported as a whole ("import module")
        self.module_imports: set[str] = set()

    @property
    def views_supported(self) -> bool:
        return True

    def generate_base(self) -> None:
        """Set up the module level ``metadata = MetaData()`` declaration."""
        self.base = Base(
            literal_imports=[LiteralImport("sqlalchemy", "MetaData")],
            declarations=["metadata = MetaData()"],
            metadata_ref="metadata",
        )

    def generate(self) -> str:
        """
        Generate the code for the whole module.

        .. note:: Modifies the metadata: ignored tables are removed, and indexes,
            constraints and comments are stripped according to the options.
        """
        self.generate_base()

        sections: list[str] = []

        # Remove unwanted elements from the metadata
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if "noindexes" in self.options:
                table.indexes.clear()

            if "noconstraints" in self.options:
                table.constraints.clear()

            if "nocomments" in self.options:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended column
        # types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        # Generate the models
        models: list[Model] = self.generate_models()

        # Render module level variables
        variables = self.render_module_variables(models)
        if variables:
            sections.append(variables + "\n")

        # Render models
        rendered_models = self.render_models(models)
        if rendered_models:
            sections.append(rendered_models)

        # Imports are collected while generating and rendering the models, so they
        # are rendered last and put in front of everything else
        groups = self.group_imports()
        imports = "\n\n".join("\n".join(group) for group in groups)
        if imports:
            sections.insert(0, imports)

        return "\n\n".join(sections) + "\n"

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Register the base's imports plus everything the given models need."""
        for literal_import in self.base.literal_imports:
            self.add_literal_import(literal_import.pkgname, literal_import.name)

        for model in models:
            self.collect_imports_for_model(model)

    def collect_imports_for_model(self, model: Model) -> None:
        """Register the imports needed by a model's columns, constraints and indexes."""
        if model.__class__ is Model:
            self.add_import(Table)

        for column in model.table.c:
            self.collect_imports_for_column(column)

        for constraint in model.table.constraints:
            self.collect_imports_for_constraint(constraint)

        for index in model.table.indexes:
            self.collect_imports_for_constraint(index)

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the imports needed to render a column's type and defaults."""
        self.add_import(column.type)

        if isinstance(column.type, ARRAY):
            self.add_import(column.type.item_type.__class__)
        elif isinstance(column.type, JSONB):
            # The default astext_type is omitted by render_column_type()
            if (
                not isinstance(column.type.astext_type, Text)
                or column.type.astext_type.length is not None
            ):
                self.add_import(column.type.astext_type)
        elif isinstance(column.type, DOMAIN):
            self.add_import(column.type.data_type.__class__)

        if column.default:
            self.add_import(column.default)

        if column.server_default:
            if isinstance(column.server_default, (Computed, Identity)):
                self.add_import(column.server_default)
            elif isinstance(column.server_default, DefaultClause):
                self.add_literal_import("sqlalchemy", "text")

    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Register the import for a constraint or index class, but only if
        render_table() renders it explicitly rather than as a column argument.
        """
        if isinstance(constraint, Index):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "Index")
        elif isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
        elif isinstance(constraint, UniqueConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
        elif isinstance(constraint, ForeignKeyConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
        else:
            self.add_import(constraint)

    def add_import(self, obj: Any) -> None:
        """
        Register an import for ``obj`` if it is a class, or else for its class.

        Objects from the ``builtins`` module are ignored.
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        pkgname = type_.__module__

        # The column types have already been adapted towards generic types if possible,
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
        # use that rather than the generic sqlalchemy type as it might have different
        # constructor parameters.
        if pkgname.startswith("sqlalchemy.dialects."):
            dialect_pkgname = ".".join(pkgname.split(".")[0:3])
            dialect_pkg = import_module(dialect_pkgname)

            # Not every dialect package is guaranteed to define __all__
            if type_.__name__ in getattr(dialect_pkg, "__all__", ()):
                pkgname = dialect_pkgname
            elif type_.__name__ in dir(sqlalchemy):
                pkgname = "sqlalchemy"

        self.add_literal_import(pkgname, type_.__name__)

    def add_literal_import(self, pkgname: str, name: str) -> None:
        """Register ``from {pkgname} import {name}``."""
        self.imports.setdefault(pkgname, set()).add(name)

    def remove_literal_import(self, pkgname: str, name: str) -> None:
        """
        Unregister ``from {pkgname} import {name}`` if it was registered.

        The package entry is dropped once its last name is removed, so that no empty
        (and syntactically invalid) ``from {pkgname} import`` line gets rendered.
        """
        names = self.imports.get(pkgname)
        if names is not None:
            names.discard(name)
            if not names:
                del self.imports[pkgname]

    def add_module_import(self, pgkname: str) -> None:
        """Register ``import {pgkname}``."""
        self.module_imports.add(pgkname)

    def group_imports(self) -> list[list[str]]:
        """
        Render the registered imports as statements, grouped into ``__future__``,
        standard library and third party imports (in that order).

        Empty groups are left out.
        """
        future_imports: list[str] = []
        stdlib_imports: list[str] = []
        thirdparty_imports: list[str] = []

        for package in sorted(self.imports):
            names = self.imports[package]
            # An empty set would render the invalid "from package import "
            if not names:
                continue

            collection = thirdparty_imports
            if package == "__future__":
                collection = future_imports
            elif package in self.builtin_module_names:
                collection = stdlib_imports
            elif package in sys.modules:
                # Modules that weren't installed into site-packages (or Debian's
                # dist-packages) are considered part of the standard library.
                # Some modules have no __file__ at all.
                module_file = getattr(sys.modules[package], "__file__", None) or ""
                if not any(
                    marker in module_file
                    for marker in ("site-packages", "dist-packages")
                ):
                    collection = stdlib_imports

            collection.append(f"from {package} import {', '.join(sorted(names))}")

        for module in sorted(self.module_imports):
            thirdparty_imports.append(f"import {module}")

        return [
            group
            for group in (future_imports, stdlib_imports, thirdparty_imports)
            if group
        ]

    def generate_models(self) -> list[Model]:
        """Create a model per table (in dependency order), with imports and names."""
        models = [Model(table) for table in self.metadata.sorted_tables]

        # Collect the imports
        self.collect_imports(models)

        # Generate names for models, avoiding clashes with imported names
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models:
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return models

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Name the model ``t_<table name>`` (adjusted to not clash with others)."""
        preferred_name = f"t_{model.table.name}"
        model.name = self.find_free_name(preferred_name, global_names)

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module level declarations of the base.

        The base's table metadata declaration is added only if any of the models is
        not a :class:`ModelClass`.
        """
        # Work on a copy so that repeated rendering doesn't keep appending to the
        # base's own list
        declarations = list(self.base.declarations)

        if any(not isinstance(model, ModelClass) for model in models):
            if self.base.table_metadata_declaration is not None:
                declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_models(self, models: list[Model]) -> str:
        """Render every model as a ``name = Table(...)`` assignment."""
        rendered: list[str] = []
        for model in models:
            rendered_table = self.render_table(model.table)
            rendered.append(f"{model.name} = {rendered_table}")

        return "\n\n".join(rendered)

    def render_table(self, table: Table) -> str:
        """Render a ``Table(...)`` call with its columns, constraints and indexes."""
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        kwargs: dict[str, object] = {}
        for column in table.columns:
            args.append(self.render_column(column, True, is_table=True))

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                # These are expressed through render_column() arguments instead
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)):
                    if len(constraint.columns) == 1:
                        continue

            args.append(self.render_constraint(constraint))

        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        return render_callable("Table", *args, kwargs=kwargs, indentation="    ")

    def render_index(self, index: Index) -> str:
        """Render an ``Index(...)`` call."""
        extra_args = [repr(col.name) for col in index.columns]
        kwargs = {}
        if index.unique:
            kwargs["unique"] = True

        return render_callable("Index", repr(index.name), *extra_args, kwargs=kwargs)

    # TODO find better solution for is_table
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render a column definition.

        Single column, default-named primary key, unique, index and foreign key
        constraints are expressed as arguments of the column itself.

        .. note:: Sets ``column.unique`` / ``column.index`` to ``True`` when such a
            unique constraint/index is found.

        :param column: the column to render
        :param show_name: whether to render the column name as the first argument
        :param is_table: ``True`` to render ``Column(...)`` (with ``nullable=False``
            where needed), ``False`` to render ``mapped_column(...)``
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        # Keyword values are inserted verbatim into the rendered call, so strings
        # must be repr()'d like the other string-valued arguments here
        if column.key != column.name:
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if not column.nullable and not is_sole_pk and is_table:
            kwargs["nullable"] = False

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(cast(TextClause, column.server_default.arg).text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        return self.render_column_callable(is_table, *args, **kwargs)

    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """Render ``Column(...)`` for tables or ``mapped_column(...)`` otherwise."""
        if is_table:
            self.add_import(Column)
            return render_callable("Column", *args, kwargs=kwargs)
        else:
            return render_callable("mapped_column", *args, kwargs=kwargs)

    def render_column_type(self, coltype: object) -> str:
        """
        Render a column type instantiation.

        The constructor signature is inspected; attributes equal to the parameter's
        default are left out. Once a parameter is left out, all following ones are
        rendered as keyword arguments.
        """
        args = []
        kwargs: dict[str, Any] = {}
        sig = inspect.signature(coltype.__class__.__init__)
        defaults = {param.name: param.default for param in sig.parameters.values()}
        missing = object()
        use_kwargs = False
        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                use_kwargs = True
                continue

            value = getattr(coltype, param.name, missing)
            default = defaults.get(param.name, missing)
            if isinstance(value, TextClause):
                self.add_literal_import("sqlalchemy", "text")
                rendered_value = render_callable("text", repr(value.text))
            else:
                rendered_value = repr(value)

            if value is missing or value == default:
                use_kwargs = True
            elif use_kwargs:
                kwargs[param.name] = rendered_value
            else:
                args.append(rendered_value)

        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(coltype, vararg):
            varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)]
            args.extend(varargs_repr)

        # These arguments cannot be autodetected from the Enum initializer
        if isinstance(coltype, Enum):
            for colname in "name", "schema":
                if (value := getattr(coltype, colname)) is not None:
                    kwargs[colname] = repr(value)

        if isinstance(coltype, JSONB):
            # Remove astext_type if it's the default. It is absent from kwargs when it
            # was rendered positionally (e.g. after a non-default none_as_null).
            if (
                isinstance(coltype.astext_type, Text)
                and coltype.astext_type.length is None
            ):
                kwargs.pop("astext_type", None)

        if args or kwargs:
            return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
        else:
            return coltype.__class__.__name__

    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a foreign key, check, unique or primary key constraint.

        :raises TypeError: for any other kind of constraint
        """

        def add_fk_options(*opts: Any) -> None:
            args.extend(repr(opt) for opt in opts)
            for attr in "ondelete", "onupdate", "deferrable", "initially", "match":
                value = getattr(constraint, attr, None)
                if value:
                    kwargs[attr] = repr(value)

        args: list[str] = []
        kwargs: dict[str, Any] = {}
        if isinstance(constraint, ForeignKey):
            remote_column = (
                f"{constraint.column.table.fullname}.{constraint.column.name}"
            )
            add_fk_options(remote_column)
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            add_fk_options(local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)

    def should_ignore_table(self, table: Table) -> bool:
        """Return ``True`` for tables that should not be part of the output."""
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
        # tables
        return table.name in ("alembic_version", "migrate_version")

    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Generate an attribute name that does not clash with other local or global names.

        Invalid identifier characters are replaced with ``_``, a leading digit gets a
        ``_`` prefix and keywords (or ``metadata``) get a ``_`` suffix. On a clash,
        ``_``, then ``1``, ``2``, ... are appended.
        """
        name = name.strip()
        assert name, "Identifier cannot be empty"
        name = _re_invalid_identifier.sub("_", name)
        if name[0].isdigit():
            name = "_" + name
        elif iskeyword(name) or name == "metadata":
            name += "_"

        original = name
        for i in count():
            if name not in global_names and name not in local_names:
                break

            name = original + (str(i) if i else "_")

        return name

    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types.

        * A ``column IN (0, 1)`` check constraint turns the column into a
          :class:`~sqlalchemy.Boolean` and the constraint is dropped
        * A ``column IN (...)`` check constraint on a string column turns it into a
          non-native :class:`~sqlalchemy.Enum` and the constraint is dropped
        * Column types are adapted to more generic ones where possible
          (see :meth:`get_adapted_type`)
        * On PostgreSQL, server defaults recognized by ``decode_postgresql_sequence``
          are removed, and replaced with an explicit sequence if it does not use the
          default ``<table>_<column>_seq`` name
        """
        # Detect check constraints for boolean and enum columns
        for constraint in table.constraints.copy():
            if isinstance(constraint, CheckConstraint):
                sqltext = get_compiled_expression(constraint.sqltext, self.bind)

                # Turn any integer-like column with a CheckConstraint like
                # "column IN (0, 1)" into a Boolean. The expression might not refer to
                # a plain column (e.g. a function call); such constraints are kept.
                match = _re_boolean_check_constraint.match(sqltext)
                if match:
                    colname_match = _re_column_name.match(match.group(1))
                    if colname_match and colname_match.group(3) in table.c:
                        colname = colname_match.group(3)
                        table.constraints.remove(constraint)
                        table.c[colname].type = Boolean()
                        continue

                # Turn any string-type column with a CheckConstraint like
                # "column IN (...)" into an Enum
                match = _re_enum_check_constraint.match(sqltext)
                if match:
                    colname_match = _re_column_name.match(match.group(1))
                    if colname_match and colname_match.group(3) in table.c:
                        colname = colname_match.group(3)
                        items = match.group(2)
                        if isinstance(table.c[colname].type, String):
                            table.constraints.remove(constraint)
                            if not isinstance(table.c[colname].type, Enum):
                                options = _re_enum_item.findall(items)
                                table.c[colname].type = Enum(
                                    *options, native_enum=False
                                )

        for column in table.c:
            try:
                column.type = self.get_adapted_type(column.type)
            except CompileError:
                # Keep the reflected type if it cannot be compiled for this dialect
                pass

            # PostgreSQL specific fix: detect sequences from server_default
            if column.server_default and self.bind.dialect.name == "postgresql":
                if isinstance(column.server_default, DefaultClause) and isinstance(
                    column.server_default.arg, TextClause
                ):
                    schema, seqname = decode_postgresql_sequence(
                        column.server_default.arg
                    )
                    if seqname:
                        # Add an explicit sequence
                        if seqname != f"{column.table.name}_{column.name}_seq":
                            column.default = sqlalchemy.Sequence(seqname, schema=schema)

                        column.server_default = None

    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Walk the MRO of the type and return the most generic adaptation that still
        compiles to the same SQL as the original type on the bound dialect.

        :raises CompileError: if the original type cannot be compiled
        """
        compiled_type = coltype.compile(self.bind.engine.dialect)
        for supercls in coltype.__class__.__mro__:
            if not supercls.__name__.startswith("_") and hasattr(
                supercls, "__visit_name__"
            ):
                # Don't try to adapt UserDefinedType as it's not a proper column type
                if supercls is UserDefinedType or issubclass(supercls, TypeDecorator):
                    return coltype

                # Hack to fix adaptation of the Enum class which is broken since
                # SQLAlchemy 1.2
                kw = {}
                if supercls is Enum:
                    kw["name"] = coltype.name
                    if coltype.schema:
                        kw["schema"] = coltype.schema

                try:
                    new_coltype = coltype.adapt(supercls)
                except TypeError:
                    # If the adaptation fails, don't try again
                    break

                for key, value in kw.items():
                    setattr(new_coltype, key, value)

                if isinstance(coltype, ARRAY):
                    new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

                try:
                    # If the adapted column type does not render the same as the
                    # original, don't substitute it
                    if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
                        break
                except CompileError:
                    # If the adapted column type can't be compiled, don't substitute it
                    break

                # Stop on the first valid non-uppercase column type class
                coltype = new_coltype
                if supercls.__name__ != supercls.__name__.upper():
                    break

        return coltype
729 |
730 |
class DeclarativeGenerator(TablesGenerator):
    """Render the metadata as declarative ORM classes.

    Tables that have a primary key, and are not pure association tables, become
    mapped classes annotated with ``Mapped[...]``. Every other table is wrapped
    in a plain :class:`Model` and rendered with :meth:`render_table`.

    Options on top of those accepted by :class:`TablesGenerator`:

    * ``use_inflect``: singularize class names and pluralize collection-valued
      relationship names
    * ``nojoined``: don't detect joined-table inheritance
    * ``nobidi``: don't generate the reverse side of relationships
    """

    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
        "use_inflect",
        "nojoined",
        "nobidi",
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = " ",
        base_class_name: str = "Base",
    ):
        super().__init__(metadata, bind, options, indentation=indentation)
        # Name of the generated declarative base class that models inherit from
        self.base_class_name: str = base_class_name
        self.inflect_engine = inflect.engine()

    def generate_base(self) -> None:
        """Declare ``class <base_class_name>(DeclarativeBase)`` as the model base."""
        self.base = Base(
            literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")],
            declarations=[
                f"class {self.base_class_name}(DeclarativeBase):",
                f"{self.indentation}pass",
            ],
            metadata_ref=f"{self.base_class_name}.metadata",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect imports; add ``Mapped``/``mapped_column`` if any class is mapped."""
        # Materialize first: the base implementation iterates ``models`` too,
        # which would exhaust a one-shot iterator before the check below
        models = list(models)
        super().collect_imports(models)
        if any(isinstance(model, ModelClass) for model in models):
            self.add_literal_import("sqlalchemy.orm", "Mapped")
            self.add_literal_import("sqlalchemy.orm", "mapped_column")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect imports for ``model``, including ``relationship`` if it has any."""
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")

    def generate_models(self) -> list[Model]:
        """Build the models for all tables, link them and assign their names.

        :return: the models in ``metadata.sorted_tables`` order
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't
        # process them normally. Keys are schema-qualified names, like the keys
        # of models_by_table_name, so same-named tables in different schemas
        # don't share association tables.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns
            # are involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                owner_table = fk_constraints[0].elements[0].column.table
                links[qualified_table_name(owner_table)].append(model)
                continue

            # Only form model classes for tables that have a primary key and are
            # not association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model,
                    models_by_table_name,
                    links[qualified_table_name(model.table)],
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in models_by_table_name.values():
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                # Use the same sorted order as generate_relationships so the chosen
                # parent doesn't depend on the iteration order of the constraint
                # collection (which may be unordered)
                for constraint in sorted(
                    model.table.foreign_key_constraints, key=get_constraint_sort_key
                ):
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            self._set_parent_class(model, target)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())

    @staticmethod
    def _set_parent_class(child: ModelClass, parent: ModelClass) -> None:
        """Record ``child`` as a joined-inheritance subclass of ``parent``.

        The same inheritance link is detected in both generate_relationships and
        generate_models, so ``parent.children`` is checked by identity to avoid
        duplicates. Dataclass ``==`` would compare the models field by field.
        """
        child.parent_class = parent
        if not any(existing is child for existing in parent.children):
            parent.children.append(child)

    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """Create the relationships that originate from ``source``.

        Each foreign key of ``source`` yields a many-to-one (or one-to-one)
        relationship, plus its reverse side on the target unless ``nobidi`` is
        set. Each table in ``association_tables`` yields a many-to-many
        relationship. A foreign key covering exactly the primary key columns is
        treated as joined-table inheritance instead, unless ``nojoined`` is set.

        :return: the relationships appended to ``source.relationships``
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            if "nojoined" not in self.options:
                if set(get_column_names(constraint)) == pk_column_names:
                    self._set_parent_class(source, target)
                    continue

            # Add uselist=False to One-to-One relationships
            column_names = get_column_names(constraint)
            if any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == set(column_names)
                for c in constraint.table.constraints
            ):
                r_type = RelationshipType.ONE_TO_ONE
            else:
                r_type = RelationshipType.MANY_TO_ONE

            relationship = RelationshipAttribute(r_type, source, target, constraint)
            source.relationships.append(relationship)
            relationships.append(relationship)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                relationship.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # If the two tables share more than one foreign key constraint,
            # SQLAlchemy needs explicit foreign_keys to figure out which
            # column(s) it needs
            common_fk_constraints = get_common_fk_constraints(
                source.table, target.table
            )
            if len(common_fk_constraints) > 1:
                relationship.foreign_keys = [
                    source.get_column_attribute(key) for key in constraint.column_keys
                ]

            # Generate the opposite end of the relationship in the target class
            if "nobidi" not in self.options:
                if r_type is RelationshipType.MANY_TO_ONE:
                    r_type = RelationshipType.ONE_TO_MANY

                reverse_relationship = RelationshipAttribute(
                    r_type,
                    target,
                    source,
                    constraint,
                    foreign_keys=relationship.foreign_keys,
                    backref=relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    reverse_relationship.remote_side = [
                        source.get_column_attribute(colname)
                        for colname in constraint.column_keys
                    ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            relationship = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            source.relationships.append(relationship)
            relationships.append(relationship)

            # Generate the opposite end of the relationship in the target class
            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)

            # Add a primary/secondary join for self-referential many-to-many
            # relationships
            if source is target:
                both_relationships = [relationship]
                reverse_flags = [False, True]
                if reverse_relationship is not None:
                    both_relationships.append(reverse_relationship)

                for rel, reverse in zip(both_relationships, reverse_flags):
                    if not rel.association_table or not rel.constraint:
                        continue

                    # The reverse side swaps which FK is primary vs secondary
                    constraints = sorted(
                        rel.constraint.table.foreign_key_constraints,
                        key=get_constraint_sort_key,
                        reverse=reverse,
                    )
                    pri_pairs = zip(
                        get_column_names(constraints[0]), constraints[0].elements
                    )
                    sec_pairs = zip(
                        get_column_names(constraints[1]), constraints[1].elements
                    )
                    rel.primaryjoin = [
                        (rel.source, elem.column.name, rel.association_table, col)
                        for col, elem in pri_pairs
                    ]
                    rel.secondaryjoin = [
                        (rel.target, elem.column.name, rel.association_table, col)
                        for col, elem in sec_pairs
                    ]

        return relationships

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Assign a free name to ``model`` and, for mapped classes, its attributes.

        Class names are the table name in CamelCase (singularized with
        ``use_inflect``). Column and relationship attribute names must be free
        among ``global_names`` and the class's other attribute names.
        """
        if isinstance(model, ModelClass):
            preferred_name = _re_invalid_identifier.sub("_", model.table.name)
            preferred_name = "".join(
                part[:1].upper() + part[1:] for part in preferred_name.split("_")
            )
            if "use_inflect" in self.options:
                singular_name = self.inflect_engine.singular_noun(preferred_name)
                if singular_name:
                    preferred_name = singular_name

            model.name = self.find_free_name(preferred_name, global_names)

            # Fill in the names for column attributes
            local_names: set[str] = set()
            for column_attr in model.columns:
                self.generate_column_attr_name(column_attr, global_names, local_names)
                local_names.add(column_attr.name)

            # Fill in the names for relationship attributes
            for relationship in model.relationships:
                self.generate_relationship_name(relationship, global_names, local_names)
                local_names.add(relationship.name)
        else:
            super().generate_model_name(model, global_names)

    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Name ``column_attr`` after its column, avoiding taken names."""
        column_attr.name = self.find_free_name(
            column_attr.column.name, global_names, local_names
        )

    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Pick a free attribute name for ``relationship``.

        Self-referential reverse sides are named ``<forward name>_reverse``.
        Otherwise the target table name is used, or the FK column name minus a
        trailing ``_id`` for single-column keys. With ``use_inflect``, that name
        is pluralized for to-many and singularized for to-one relationships.
        """
        # Self referential reverse relationships
        preferred_name: str
        if (
            relationship.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        ):
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id",
            # use the preceding part as the relationship name
            if relationship.constraint:
                is_source = relationship.source.table is relationship.constraint.table
                if is_source or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    column_names = [c.name for c in relationship.constraint.columns]
                    if len(column_names) == 1 and column_names[0].endswith("_id"):
                        preferred_name = column_names[0][:-3]

            if "use_inflect" in self.options:
                inflected_name: str | Literal[False]
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    # singular_noun() is falsy when the name is not already plural
                    if not self.inflect_engine.singular_noun(preferred_name):
                        preferred_name = self.inflect_engine.plural_noun(preferred_name)
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)
                    if inflected_name:
                        preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )

    def render_models(self, models: list[Model]) -> str:
        """Render mapped classes as classes and other models as ``Table`` objects."""
        rendered: list[str] = []
        for model in models:
            if isinstance(model, ModelClass):
                rendered.append(self.render_class(model))
            else:
                rendered.append(f"{model.name} = {self.render_table(model.table)}")

        return "\n\n\n".join(rendered)

    def render_class(self, model: ModelClass) -> str:
        """Render the class statement, class variables, columns and relationships.

        Non-nullable columns are rendered before nullable ones.
        """
        sections: list[str] = []

        # Render class variables / special declarations
        class_vars: str = self.render_class_variables(model)
        if class_vars:
            sections.append(class_vars)

        # Render column attributes
        rendered_column_attributes: list[str] = []
        for nullable in (False, True):
            for column_attr in model.columns:
                if column_attr.column.nullable is nullable:
                    rendered_column_attributes.append(
                        self.render_column_attribute(column_attr)
                    )

        if rendered_column_attributes:
            sections.append("\n".join(rendered_column_attributes))

        # Render relationship attributes
        rendered_relationship_attributes: list[str] = [
            self.render_relationship(relationship)
            for relationship in model.relationships
        ]

        if rendered_relationship_attributes:
            sections.append("\n".join(rendered_relationship_attributes))

        declaration = self.render_class_declaration(model)
        rendered_sections = "\n\n".join(
            indent(section, self.indentation) for section in sections
        )
        return f"{declaration}\n{rendered_sections}"

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render ``class Name(Parent):`` with the inheritance parent or the base."""
        parent_class_name = (
            model.parent_class.name if model.parent_class else self.base_class_name
        )
        return f"class {model.name}({parent_class_name}):"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` and, if needed, ``__table_args__``."""
        variables = [f"__tablename__ = {model.table.name!r}"]

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_table_args(self, table: Table) -> str:
        """Render the ``__table_args__`` value, or ``""`` if none is needed.

        Default-named primary keys, default-named single-column FK/unique
        constraints and default-named single-column indexes are skipped here.
        """
        args: list[str] = []
        kwargs: dict[str, str] = {}

        # Render constraints
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                if (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                ):
                    continue

            args.append(self.render_constraint(constraint))

        # Render indexes; an unnamed index sorts first instead of breaking the sort
        for index in sorted(table.indexes, key=lambda i: i.name or ""):
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = table.schema

        if table.comment:
            kwargs["comment"] = table.comment

        if kwargs:
            formatted_kwargs = pformat(kwargs)
            if not args:
                return formatted_kwargs
            else:
                args.append(formatted_kwargs)

        if args:
            rendered_args = f",\n{self.indentation}".join(args)
            if len(args) == 1:
                # A one-element tuple needs a trailing comma
                rendered_args += ","

            return f"(\n{self.indentation}{rendered_args}\n)"
        else:
            return ""

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render ``name: Mapped[<python type>] = <rendered column>``.

        Nullable columns are wrapped in ``Optional[...]`` and ARRAY columns in
        one ``list[...]`` per dimension. Types whose Python type cannot be
        determined are annotated as ``Any``.
        """
        column = column_attr.column
        rendered_column = self.render_column(column, column_attr.name != column.name)

        def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
            # Return (opening wrappers, innermost type, closing brackets)
            column_type = column.type
            pre: list[str] = []
            post_size = 0
            if column.nullable:
                self.add_literal_import("typing", "Optional")
                pre.append("Optional[")
                post_size += 1

            if isinstance(column_type, ARRAY):
                dim = getattr(column_type, "dimensions", None) or 1
                pre.extend("list[" for _ in range(dim))
                post_size += dim

                column_type = column_type.item_type

            return "".join(pre), column_type, "]" * post_size

        def render_python_type(column_type: TypeEngine[Any]) -> str:
            # Reading python_type can itself raise NotImplementedError, so it
            # must be guarded too (the SQLModel generator does the same)
            try:
                python_type = column_type.python_type
            except NotImplementedError:
                python_type = None

            if python_type is not None:
                python_type_name = python_type.__name__
                python_type_module = python_type.__module__
                if python_type_module == "builtins":
                    return python_type_name

                try:
                    self.add_module_import(python_type_module)
                    return f"{python_type_module}.{python_type_name}"
                except NotImplementedError:
                    pass

            self.add_literal_import("typing", "Any")
            return "Any"

        pre, col_type, post = get_type_qualifiers()
        column_python_type = f"{pre}{render_python_type(col_type)}{post}"
        return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render ``name: Mapped[...] = relationship(...)`` for ``relationship``.

        To-one relationships are annotated with the target class, wrapped in
        ``Optional`` when any FK column is nullable. To-many relationships are
        annotated with ``list[...]`` of it.
        """

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Attributes of the source class are referenced directly, others as
            # a quoted "Class.attr" string
            rendered = []
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(repr(f"{attr.model.name}.{attr.name}"))

            return "[" + ", ".join(rendered) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            rendered = []
            render_as_string = False
            # Assume that column_attrs are all in relationship.source or none
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            if render_as_string:
                return "'[" + ", ".join(rendered) + "]'"
            else:
                return "[" + ", ".join(rendered) + "]"

        def render_join(terms: list[JoinType]) -> str:
            comparisons: list[str] = []
            for source, source_col, target, target_col in terms:
                # Plain tables (non-mapped models) expose their columns via ``.c``
                if target.__class__ is Model:
                    target_ref = f"{target.name}.c"
                else:
                    target_ref = target.name

                comparisons.append(
                    f"{source.name}.{source_col} == {target_ref}.{target_col}"
                )

            if len(comparisons) == 1:
                return f"lambda: {comparisons[0]}"

            # Composite keys: combine the comparisons inside one lambda, so the
            # relationship receives a single callable returning and_(...)
            # rather than and_() applied to separate lambdas
            self.add_literal_import("sqlalchemy", "and_")
            return f"lambda: and_({', '.join(comparisons)})"

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
            # The side whose FK lives on the other table would otherwise be a list
            if relationship.constraint.referred_table is relationship.source.table:
                kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        relationship_type: str
        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            relationship_type = f"list['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            relationship_type = f"'{relationship.target.name}'"
            if relationship.constraint and any(
                col.nullable for col in relationship.constraint.columns
            ):
                self.add_literal_import("typing", "Optional")
                relationship_type = f"Optional[{relationship_type}]"
        else:
            self.add_literal_import("typing", "Any")
            relationship_type = "Any"

        return (
            f"{relationship.name}: Mapped[{relationship_type}] "
            f"= {rendered_relationship}"
        )
1343 |
1344 |
class DataclassGenerator(DeclarativeGenerator):
    """Declarative generator whose base class also mixes in ``MappedAsDataclass``.

    ``quote_annotations`` and ``metadata_key`` are stored on the instance; they
    are not read anywhere in this class itself.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = " ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )
        self.metadata_key: str = metadata_key
        self.quote_annotations: bool = quote_annotations

    def generate_base(self) -> None:
        """Declare ``class <base>(MappedAsDataclass, DeclarativeBase)``."""
        imports = [
            LiteralImport("sqlalchemy.orm", "DeclarativeBase"),
            LiteralImport("sqlalchemy.orm", "MappedAsDataclass"),
        ]
        header = f"class {self.base_class_name}(MappedAsDataclass, DeclarativeBase):"
        body = f"{self.indentation}pass"
        self.base = Base(
            literal_imports=imports,
            declarations=[header, body],
            metadata_ref=f"{self.base_class_name}.metadata",
        )
1379 |
1380 |
class SQLModelGenerator(DeclarativeGenerator):
    """Render mapped tables as SQLModel ``table=True`` classes.

    Columns become ``Field(sa_column=Column(...))`` and relationships become
    ``Relationship(...)``. Only ``back_populates`` and ``uselist=False`` carry
    over from the declarative relationship rendering. Views are not supported.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = " ",
        base_class_name: str = "SQLModel",
    ):
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    @property
    def views_supported(self) -> bool:
        return False

    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """Render columns as ``Column(...)``, whether for a table or a class."""
        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)

    def generate_base(self) -> None:
        """Use no generated base: classes inherit from ``SQLModel`` directly."""
        self.base = Base(
            literal_imports=[],
            declarations=[],
            metadata_ref="",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect imports, swapping ``MetaData`` for SQLModel imports if needed."""
        # Materialize first: the base implementation iterates ``models`` too,
        # which would exhaust a one-shot iterator before the check below
        models = list(models)
        # Skip DeclarativeGenerator's Mapped/mapped_column imports
        super(DeclarativeGenerator, self).collect_imports(models)
        if any(isinstance(model, ModelClass) for model in models):
            self.remove_literal_import("sqlalchemy", "MetaData")
            self.add_literal_import("sqlmodel", "SQLModel")
            self.add_literal_import("sqlmodel", "Field")

    def collect_imports_for_model(self, model: Model) -> None:
        """Add ``Optional`` and ``Relationship`` imports when a class needs them."""
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
        if isinstance(model, ModelClass):
            if any(column_attr.column.nullable for column_attr in model.columns):
                self.add_literal_import("typing", "Optional")

            if model.relationships:
                self.add_literal_import("sqlmodel", "Relationship")

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Also import the column's Python type (or ``Any`` if unknown)."""
        super().collect_imports_for_column(column)
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
        else:
            self.add_import(python_type)

    def render_module_variables(self, models: list[Model]) -> str:
        """Render the table metadata declaration only if plain tables exist."""
        declarations: list[str] = []
        if any(not isinstance(model, ModelClass) for model in models):
            if self.base.table_metadata_declaration is not None:
                declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render ``class Name(Parent, table=True):``."""
        if model.parent_class:
            parent = model.parent_class.name
        else:
            parent = self.base_class_name

        superclass_part = f"({parent}, table=True)"
        return f"class {model.name}{superclass_part}:"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` only when it differs from the lowercased class
        name, plus ``__table_args__`` when needed."""
        variables = []

        if model.table.name != model.name.lower():
            variables.append(f"__tablename__ = {model.table.name!r}")

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render ``name: <type> = Field(sa_column=Column(...))``.

        Nullable columns and autoincrementing primary key columns get
        ``Optional[...]`` with ``default=None``.
        """
        column = column_attr.column
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            python_type_name = "Any"
        else:
            python_type_name = python_type.__name__

        kwargs: dict[str, Any] = {}
        if (
            column.autoincrement and column.name in column.table.primary_key
        ) or column.nullable:
            self.add_literal_import("typing", "Optional")
            kwargs["default"] = None
            python_type_name = f"Optional[{python_type_name}]"

        rendered_column = self.render_column(column, True)
        kwargs["sa_column"] = f"{rendered_column}"
        rendered_field = render_callable("Field", kwargs=kwargs)
        return f"{column_attr.name}: {python_type_name} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render ``name: <annotation> = Relationship(...)``.

        To-many relationships are annotated ``list['Target']``, all others
        ``Optional['Target']``.
        """
        # Reuse the declarative rendering and keep only the right-hand side
        rendered = super().render_relationship(relationship).partition(" = ")[2]
        args = self.render_relationship_args(rendered)
        kwargs: dict[str, Any] = {}
        annotation = repr(relationship.target.name)

        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            annotation = f"list[{annotation}]"
        else:
            self.add_literal_import("typing", "Optional")
            annotation = f"Optional[{annotation}]"

        rendered_field = render_callable("Relationship", *args, kwargs=kwargs)
        return f"{relationship.name}: {annotation} = {rendered_field}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """Extract the SQLModel ``Relationship`` arguments from a rendered call.

        ``arguments`` is a rendered call such as
        ``relationship('Target', uselist=False, back_populates='items')``.
        ``back_populates`` is passed through as-is and ``uselist=False`` becomes
        ``sa_relationship_kwargs={'uselist': False}``. Everything else is dropped.
        """
        # Drop the call's closing parenthesis, then strip whitespace around each
        # comma-separated piece rather than assuming exactly one leading space
        inner = arguments.strip()
        if inner.endswith(")"):
            inner = inner[:-1]

        rendered_args: list[str] = []
        for arg in (piece.strip() for piece in inner.split(",")):
            if "back_populates" in arg:
                rendered_args.append(arg)
            if "uselist=False" in arg:
                rendered_args.append("sa_relationship_kwargs={'uselist': False}")

        return rendered_args
1525 |
--------------------------------------------------------------------------------
/src/sqlacodegen/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass, field
4 | from enum import Enum, auto
5 | from typing import Any, Union
6 |
7 | from sqlalchemy.sql.schema import Column, ForeignKeyConstraint, Table
8 |
9 |
@dataclass
class Model:
    """Wrapper around a metadata table that the generators render.

    Instances of this base class are rendered as plain tables; ``ModelClass``
    is used for tables rendered as mapped classes.
    """

    table: Table
    # Identifier used in the generated code; assigned after construction
    name: str = field(default="", init=False)

    @property
    def schema(self) -> str | None:
        """Schema of the wrapped table, or ``None`` if it has no explicit schema."""
        wrapped = self.table
        return wrapped.schema
18 |
19 |
@dataclass
class ModelClass(Model):
    """A table rendered as a mapped class, with its attributes and inheritance links."""

    columns: list[ColumnAttribute] = field(default_factory=list)
    relationships: list[RelationshipAttribute] = field(default_factory=list)
    # Mapped class this one inherits from (joined-table inheritance), if any
    parent_class: ModelClass | None = None
    # Mapped classes that inherit from this one
    children: list[ModelClass] = field(default_factory=list)

    def get_column_attribute(self, column_name: str) -> ColumnAttribute:
        """Return the attribute that maps the column named ``column_name``.

        :raises LookupError: if no attribute maps a column by that name
        """
        found = next(
            (attr for attr in self.columns if attr.column.name == column_name), None
        )
        if found is None:
            raise LookupError(f"Cannot find column attribute for {column_name!r}")

        return found
33 |
34 |
class RelationshipType(Enum):
    """Cardinality of a relationship, seen from its source class."""

    # Scalar on both sides
    ONE_TO_ONE = auto()
    # Collection on the source side
    ONE_TO_MANY = auto()
    # Scalar on the source side (the source table holds the foreign key)
    MANY_TO_ONE = auto()
    # Collections on both sides, through an association table
    MANY_TO_MANY = auto()
40 |
41 |
@dataclass
class ColumnAttribute:
    """Attribute of a mapped class that maps a single table column."""

    model: ModelClass
    column: Column[Any]
    # Attribute name in the generated class; assigned after construction
    name: str = field(init=False, default="")

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(name={self.name!r}, type={self.column.type})"

    def __str__(self) -> str:
        # Rendered code refers to the attribute by its name
        return self.name
53 |
54 |
# A column reference in a join term: a column attribute or a bare column name
_JoinColumn = Union[ColumnAttribute, str]
# One equality term of a primaryjoin/secondaryjoin condition:
# (source model, source column, target model, target column)
JoinType = tuple[Model, _JoinColumn, Model, _JoinColumn]
56 |
57 |
@dataclass
class RelationshipAttribute:
    """A relationship attribute from a source mapped class to a target class."""

    type: RelationshipType
    source: ModelClass
    target: ModelClass
    # Foreign key constraint the relationship is based on, if any
    constraint: ForeignKeyConstraint | None = None
    # Association table for many-to-many relationships
    association_table: Model | None = None
    # Opposite side of a bidirectional relationship
    backref: RelationshipAttribute | None = None
    remote_side: list[ColumnAttribute] = field(default_factory=list)
    foreign_keys: list[ColumnAttribute] = field(default_factory=list)
    primaryjoin: list[JoinType] = field(default_factory=list)
    secondaryjoin: list[JoinType] = field(default_factory=list)
    # Attribute name in the generated class; assigned after construction
    name: str = field(init=False, default="")

    def __repr__(self) -> str:
        class_name = type(self).__name__
        details = f"name={self.name!r}, type={self.type}, target={self.target.name}"
        return f"{class_name}({details})"

    def __str__(self) -> str:
        # Rendered code refers to the attribute by its name
        return self.name
80 |
--------------------------------------------------------------------------------
/src/sqlacodegen/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agronholm/sqlacodegen/84dcc39d5cef1a840234624603857ab72ec333a0/src/sqlacodegen/py.typed
--------------------------------------------------------------------------------
/src/sqlacodegen/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 | from collections.abc import Mapping
5 | from typing import Any, Literal, cast
6 |
7 | from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
8 | from sqlalchemy.engine import Connection, Engine
9 | from sqlalchemy.sql import ClauseElement
10 | from sqlalchemy.sql.elements import TextClause
11 | from sqlalchemy.sql.schema import (
12 | CheckConstraint,
13 | ColumnCollectionConstraint,
14 | Constraint,
15 | ForeignKeyConstraint,
16 | Index,
17 | Table,
18 | )
19 |
20 | _re_postgresql_nextval_sequence = re.compile(r"nextval\('(.+)'::regclass\)")
21 | _re_postgresql_sequence_delimiter = re.compile(r'(.*?)([."]|$)')
22 |
23 |
def get_column_names(constraint: ColumnCollectionConstraint) -> list[str]:
    """Return the keys of the constraint's column collection, in order."""
    return [*constraint.columns.keys()]
26 |
27 |
def get_constraint_sort_key(constraint: Constraint) -> str:
    """
    Return a string key for ordering constraints deterministically.

    Check constraints sort by their SQL text prefixed with ``C``. Other
    column-based constraints sort by the first letter of their class name
    followed by the repr of their column names. Anything else sorts by its
    string form.

    """
    if isinstance(constraint, CheckConstraint):
        return f"C{constraint.sqltext}"

    if isinstance(constraint, ColumnCollectionConstraint):
        type_initial = constraint.__class__.__name__[0]
        return type_initial + repr(get_column_names(constraint))

    return str(constraint)
35 |
36 |
def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str:
    """Return the statement in a form where any placeholders have been filled in."""
    # literal_binds makes the compiler inline bound parameter values
    compiled = statement.compile(bind, compile_kwargs={"literal_binds": True})
    return str(compiled)
40 |
41 |
def get_common_fk_constraints(
    table1: Table, table2: Table
) -> set[ForeignKeyConstraint]:
    """
    Return a set of foreign key constraints the two tables have against each other.

    Both directions are included: constraints on *table1* whose first element
    references *table2*, and constraints on *table2* whose first element
    references *table1*.

    """

    def referencing(source: Table, target: Table) -> set[ForeignKeyConstraint]:
        # Only the first element is inspected; every element of one constraint
        # is expected to point at the same table
        return {
            constraint
            for constraint in source.constraints
            if isinstance(constraint, ForeignKeyConstraint)
            and constraint.elements[0].column.table == target
        }

    return referencing(table1, table2) | referencing(table2, table1)
60 |
61 |
def _get_column_naming_values(
    constraint: Index | ColumnCollectionConstraint,
) -> dict[str, Any]:
    """Return the ``column_*`` naming convention tokens for the constrained columns."""
    values: dict[str, Any] = {
        "column_0N_name": "".join(col.name for col in constraint.columns),
        "column_0_N_name": "_".join(col.name for col in constraint.columns),
        "column_0N_label": "".join(
            col.label(col.name).name for col in constraint.columns
        ),
        "column_0_N_label": "_".join(
            col.label(col.name).name for col in constraint.columns
        ),
        "column_0N_key": "".join(col.key for col in constraint.columns if col.key),
        "column_0_N_key": "_".join(col.key for col in constraint.columns if col.key),
    }
    if constraint.columns:
        first = constraint.columns.values()[0]
        values.update(
            {
                "column_0_name": first.name,
                "column_0_label": first.label(first.name).name,
                "column_0_key": first.key,
            }
        )

    return values


def _get_referred_naming_values(constraint: ForeignKeyConstraint) -> dict[str, Any]:
    """Return the ``referred_*`` naming convention tokens for a foreign key."""
    referred_columns = [fk.column for fk in constraint.elements]
    first = referred_columns[0]
    return {
        # SQLAlchemy fills this token with the bare table name (no schema), so
        # use .name rather than the Table object itself
        "referred_table_name": constraint.referred_table.name,
        "referred_column_0_name": first.name,
        "referred_column_0N_name": "".join(col.name for col in referred_columns),
        "referred_column_0_N_name": "_".join(col.name for col in referred_columns),
        "referred_column_0_label": first.label(first.name).name,
        "referred_column_0N_label": "".join(
            col.label(col.name).name for col in referred_columns
        ),
        "referred_column_0_N_label": "_".join(
            col.label(col.name).name for col in referred_columns
        ),
        "referred_column_0_key": first.key,
        "referred_column_0N_key": "".join(
            col.key for col in referred_columns if col.key
        ),
        "referred_column_0_N_key": "_".join(
            col.key for col in referred_columns if col.key
        ),
    }


def uses_default_name(constraint: Constraint | Index) -> bool:
    """
    Check whether a constraint's name is the one its naming convention would produce.

    The name is compared against the entry of the table's
    ``MetaData.naming_convention`` for the constraint's kind (``ix``, ``ck``,
    ``uq``, ``pk`` or ``fk``), expanded with the same ``%(token)s`` values
    SQLAlchemy supports.

    :param constraint: the constraint or index to check
    :return: ``True`` if the constraint is unnamed, is not attached to a table,
        or its name matches the convention; ``False`` if the name differs, no
        convention exists for its kind, or the convention uses a token not
        computed here
    :raises TypeError: if the constraint is of an unsupported type

    """
    if not constraint.name or constraint.table is None:
        return True

    table = constraint.table
    values: dict[str, Any] = {
        "table_name": table.name,
        "constraint_name": constraint.name,
    }
    if isinstance(constraint, (Index, ColumnCollectionConstraint)):
        values.update(_get_column_naming_values(constraint))

    key: Literal["fk", "pk", "ix", "ck", "uq"]
    if isinstance(constraint, Index):
        key = "ix"
    elif isinstance(constraint, CheckConstraint):
        key = "ck"
    elif isinstance(constraint, UniqueConstraint):
        key = "uq"
    elif isinstance(constraint, PrimaryKeyConstraint):
        key = "pk"
    elif isinstance(constraint, ForeignKeyConstraint):
        key = "fk"
        values.update(_get_referred_naming_values(constraint))
    else:
        raise TypeError(f"Unknown constraint type: {constraint.__class__.__qualname__}")

    try:
        convention = cast(
            Mapping[str, str],
            table.metadata.naming_convention,
        )[key]
        return constraint.name == (convention % values)
    except KeyError:
        # Either there is no convention for this kind of constraint, or the
        # convention refers to a token missing from ``values``
        return False
150 |
151 |
def render_callable(
    name: str,
    *args: object,
    kwargs: Mapping[str, object] | None = None,
    indentation: str = "",
) -> str:
    """
    Render a function call.

    :param name: name of the callable
    :param args: positional arguments
    :param kwargs: keyword arguments
    :param indentation: if given, each argument will be rendered on its own line with
        this value used as the indentation
    :return: the rendered call; ``name()`` when there are no arguments at all

    """
    if kwargs:
        args += tuple(f"{key}={value}" for key, value in kwargs.items())

    if not args:
        # Without this, a non-empty indentation would emit a whitespace-only
        # line between the parentheses
        return f"{name}()"

    if indentation:
        prefix = f"\n{indentation}"
        suffix = "\n"
        delimiter = f",\n{indentation}"
    else:
        prefix = suffix = ""
        delimiter = ", "

    rendered_args = delimiter.join(str(arg) for arg in args)
    return f"{name}({prefix}{rendered_args}{suffix})"
181 |
182 |
def qualified_table_name(table: Table) -> str:
    """Return ``schema.name`` when the table has a (non-empty) schema, else just its name."""
    return f"{table.schema}.{table.name}" if table.schema else str(table.name)
188 |
189 |
def decode_postgresql_sequence(clause: TextClause) -> tuple[str | None, str | None]:
    """
    Extract the schema and sequence name from a ``nextval('...'::regclass)`` default.

    Dots inside double-quoted identifiers are kept as part of the name, and the
    quote characters themselves are dropped.

    :return: ``(schema, sequence)``, where ``schema`` is ``None`` for an
        unqualified sequence; ``(None, None)`` if *clause* is not such an expression

    """
    nextval = _re_postgresql_nextval_sequence.match(clause.text)
    if nextval is None:
        return None, None

    schema: str | None = None
    sequence = ""
    quoted = False
    for part in _re_postgresql_sequence_delimiter.finditer(nextval.group(1)):
        text, delimiter = part.groups()
        sequence += text
        if delimiter == '"':
            quoted = not quoted
        elif delimiter == "." and quoted:
            sequence += "."
        elif delimiter == ".":
            # An unquoted dot separates the schema from the sequence name
            schema, sequence = sequence, ""

    return schema, sequence
209 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agronholm/sqlacodegen/84dcc39d5cef1a840234624603857ab72ec333a0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from textwrap import dedent
2 |
3 | import pytest
4 | from pytest import FixtureRequest
5 | from sqlalchemy.engine import Engine, create_engine
6 | from sqlalchemy.orm import clear_mappers, configure_mappers
7 | from sqlalchemy.schema import MetaData
8 |
9 |
@pytest.fixture
def engine(request: FixtureRequest) -> Engine:
    """
    Return an engine for the dialect passed via indirect parametrization.

    Tests that don't parametrize this fixture, or pass any other value, get an
    in-memory SQLite engine.
    """
    dialect = getattr(request, "param", None)
    if dialect == "postgresql":
        url = "postgresql+psycopg:///testdb"
    elif dialect == "mysql":
        url = "mysql+mysqlconnector://testdb"
    else:
        url = "sqlite:///:memory:"

    return create_engine(url)
19 |
20 |
@pytest.fixture
def metadata() -> MetaData:
    """Return a fresh, empty ``MetaData`` for each test."""
    fresh_metadata = MetaData()
    return fresh_metadata
24 |
25 |
def validate_code(generated_code: str, expected_code: str) -> None:
    """
    Assert that *generated_code* equals the dedented *expected_code*, then
    execute it and configure its mappers to check that it actually works.

    Mappers are cleared afterwards, even if execution or configuration fails.
    """
    assert generated_code == dedent(expected_code)
    try:
        exec(generated_code, {})
        configure_mappers()
    finally:
        clear_mappers()
34 |
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import sqlite3
4 | import subprocess
5 | import sys
6 | from importlib.metadata import version
7 | from pathlib import Path
8 |
9 | import pytest
10 |
# The postponed-annotations import line followed by one blank line
future_imports = "from __future__ import annotations" + "\n" * 2
12 |
13 |
@pytest.fixture
def db_path(tmp_path: Path) -> Path:
    """Create a SQLite database file holding a single ``foo`` table and return its path."""
    path = tmp_path / "test.db"
    # A sqlite3 connection used as a context manager only commits or rolls back.
    # It never closes the connection, so close it explicitly to release the
    # file handle.
    conn = sqlite3.connect(str(path))
    try:
        with conn:
            conn.execute(
                "CREATE TABLE foo (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL)"
            )
    finally:
        conn.close()

    return path
24 |
25 |
def test_cli_tables(db_path: Path, tmp_path: Path) -> None:
    """The ``tables`` generator writes a plain ``Table`` definition to the outfile."""
    output_path = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "tables",
        "--outfile",
        str(output_path),
    ]
    subprocess.run(command, check=True)

    expected = """\
from sqlalchemy import Column, Integer, MetaData, Table, Text

metadata = MetaData()


t_foo = Table(
    'foo', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', Text, nullable=False)
)
"""
    assert output_path.read_text() == expected
55 |
56 |
def test_cli_declarative(db_path: Path, tmp_path: Path) -> None:
    """The ``declarative`` generator writes a ``DeclarativeBase`` model class."""
    output_path = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "declarative",
        "--outfile",
        str(output_path),
    ]
    subprocess.run(command, check=True)

    expected = """\
from sqlalchemy import Integer, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass


class Foo(Base):
    __tablename__ = 'foo'

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(Text)
"""
    assert output_path.read_text() == expected
88 |
89 |
def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None:
    """The ``dataclasses`` generator writes a ``MappedAsDataclass`` model class."""
    output_path = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "dataclasses",
        "--outfile",
        str(output_path),
    ]
    subprocess.run(command, check=True)

    expected = """\
from sqlalchemy import Integer, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class Base(MappedAsDataclass, DeclarativeBase):
    pass


class Foo(Base):
    __tablename__ = 'foo'

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(Text)
"""
    assert output_path.read_text() == expected
121 |
122 |
def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None:
    """The ``sqlmodels`` generator writes an ``SQLModel`` table class."""
    output_path = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "sqlmodels",
        "--outfile",
        str(output_path),
    ]
    subprocess.run(command, check=True)

    expected = """\
from typing import Optional

from sqlalchemy import Column, Integer, Text
from sqlmodel import Field, SQLModel

class Foo(SQLModel, table=True):
    id: Optional[int] = Field(default=None, sa_column=Column('id', Integer, \
primary_key=True))
    name: str = Field(sa_column=Column('name', Text))
"""
    assert output_path.read_text() == expected
151 |
152 |
def test_cli_engine_arg(db_path: Path, tmp_path: Path) -> None:
    """A valid ``--engine-arg`` is accepted and does not affect the generated code."""
    output_path = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "tables",
        "--engine-arg",
        'connect_args={"timeout": 10}',
        "--outfile",
        str(output_path),
    ]
    subprocess.run(command, check=True)

    expected = """\
from sqlalchemy import Column, Integer, MetaData, Table, Text

metadata = MetaData()


t_foo = Table(
    'foo', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', Text, nullable=False)
)
"""
    assert output_path.read_text() == expected
184 |
185 |
def test_cli_invalid_engine_arg(db_path: Path, tmp_path: Path) -> None:
    """An unknown DBAPI connect argument makes the CLI exit with an error."""
    output_path = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "tables",
        "--engine-arg",
        'connect_args={"this_arg_does_not_exist": 10}',
        "--outfile",
        str(output_path),
    ]

    # Expect exception:
    # TypeError: 'this_arg_does_not_exist' is an invalid keyword argument for Connection()
    with pytest.raises(subprocess.CalledProcessError) as exc_info:
        subprocess.run(command, check=True, capture_output=True)

    stderr = exc_info.value.stderr.decode()
    # The wording of the TypeError differs before and from Python 3.13 onwards
    if sys.version_info < (3, 13):
        expected_message = "'this_arg_does_not_exist' is an invalid keyword argument"
    else:
        expected_message = "got an unexpected keyword argument 'this_arg_does_not_exist'"

    assert expected_message in stderr
217 |
218 |
def test_main() -> None:
    """``python -m sqlacodegen --version`` prints the installed distribution version."""
    expected_version = version("sqlacodegen")
    command = [sys.executable, "-m", "sqlacodegen", "--version"]
    completed = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    printed_version = completed.stdout.decode().strip()
    assert printed_version == expected_version
227 |
--------------------------------------------------------------------------------
/tests/test_generator_dataclass.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 | from _pytest.fixtures import FixtureRequest
5 | from sqlalchemy.dialects.postgresql import UUID
6 | from sqlalchemy.engine import Engine
7 | from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table
8 | from sqlalchemy.sql.expression import text
9 | from sqlalchemy.types import INTEGER, VARCHAR
10 |
11 | from sqlacodegen.generators import CodeGenerator, DataclassGenerator
12 |
13 | from .conftest import validate_code
14 |
15 |
@pytest.fixture
def generator(
    request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
    """Build a dataclass generator; options may be passed via indirect parametrization."""
    return DataclassGenerator(metadata, engine, getattr(request, "param", []))
22 |
23 |
def test_basic_class(generator: CodeGenerator) -> None:
    """A single table becomes a ``MappedAsDataclass`` model; nullable columns are Optional."""
    Table(
        "simple",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("name", VARCHAR(20)),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import Integer, String
        from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column

        class Base(MappedAsDataclass, DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            name: Mapped[Optional[str]] = mapped_column(String(20))
        """
    validate_code(generator.generate(), expected)
52 |
53 |
def test_mandatory_field_last(generator: CodeGenerator) -> None:
    """A non-nullable column without a default is emitted before the optional one."""
    Table(
        "simple",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("name", VARCHAR(20), server_default=text("foo")),
        Column("age", INTEGER, nullable=False),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import Integer, String, text
        from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column

        class Base(MappedAsDataclass, DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            age: Mapped[int] = mapped_column(Integer)
            name: Mapped[Optional[str]] = mapped_column(String(20), \
server_default=text('foo'))
        """
    validate_code(generator.generate(), expected)
85 |
86 |
def test_onetomany_optional(generator: CodeGenerator) -> None:
    """A nullable foreign key yields an Optional many-to-one and a list back-reference."""
    meta = generator.metadata
    Table(
        "simple_items",
        meta,
        Column("id", INTEGER, primary_key=True),
        Column("container_id", INTEGER),
        ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
    )
    Table(
        "simple_containers",
        meta,
        Column("id", INTEGER, primary_key=True),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKey, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column, relationship

        class Base(MappedAsDataclass, DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='container')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            container_id: Mapped[Optional[int]] = \
mapped_column(ForeignKey('simple_containers.id'))

            container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
back_populates='simple_items')
        """
    validate_code(generator.generate(), expected)
134 |
135 |
def test_manytomany(generator: CodeGenerator) -> None:
    """An association table yields a pair of ``secondary`` relationships plus a Table."""
    meta = generator.metadata
    Table("simple_items", meta, Column("id", INTEGER, primary_key=True))
    Table(
        "simple_containers",
        meta,
        Column("id", INTEGER, primary_key=True),
    )
    Table(
        "container_items",
        meta,
        Column("item_id", INTEGER),
        Column("container_id", INTEGER),
        ForeignKeyConstraint(["item_id"], ["simple_items.id"]),
        ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
    )

    expected = """\
        from sqlalchemy import Column, ForeignKey, Integer, Table
        from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column, relationship

        class Base(MappedAsDataclass, DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
secondary='container_items', back_populates='container')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            container: Mapped[list['SimpleContainers']] = \
relationship('SimpleContainers', secondary='container_items', back_populates='item')


        t_container_items = Table(
            'container_items', Base.metadata,
            Column('item_id', ForeignKey('simple_items.id')),
            Column('container_id', ForeignKey('simple_containers.id'))
        )
        """
    validate_code(generator.generate(), expected)
188 |
189 |
def test_named_foreign_key_constraints(generator: CodeGenerator) -> None:
    """An explicitly named foreign key is rendered in ``__table_args__`` with its name."""
    meta = generator.metadata
    named_fk = ForeignKeyConstraint(
        ["container_id"], ["simple_containers.id"], name="foreignkeytest"
    )
    Table(
        "simple_items",
        meta,
        Column("id", INTEGER, primary_key=True),
        Column("container_id", INTEGER),
        named_fk,
    )
    Table(
        "simple_containers",
        meta,
        Column("id", INTEGER, primary_key=True),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKeyConstraint, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column, relationship

        class Base(MappedAsDataclass, DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='container')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'
            __table_args__ = (
                ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \
name='foreignkeytest'),
            )

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            container_id: Mapped[Optional[int]] = mapped_column(Integer)

            container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
back_populates='simple_items')
        """
    validate_code(generator.generate(), expected)
242 |
243 |
def test_uuid_type_annotation(generator: CodeGenerator) -> None:
    """A UUID column is annotated as ``uuid.UUID`` and ``uuid`` gets imported."""
    Table(
        "simple",
        generator.metadata,
        Column("id", UUID, primary_key=True),
    )

    expected = """\
        from sqlalchemy import UUID
        from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column
        import uuid

        class Base(MappedAsDataclass, DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'

            id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
        """
    validate_code(generator.generate(), expected)
269 |
--------------------------------------------------------------------------------
/tests/test_generator_declarative.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 | from _pytest.fixtures import FixtureRequest
5 | from sqlalchemy import PrimaryKeyConstraint
6 | from sqlalchemy.engine import Engine
7 | from sqlalchemy.schema import (
8 | CheckConstraint,
9 | Column,
10 | ForeignKey,
11 | ForeignKeyConstraint,
12 | Index,
13 | MetaData,
14 | Table,
15 | UniqueConstraint,
16 | )
17 | from sqlalchemy.sql.expression import text
18 | from sqlalchemy.types import ARRAY, INTEGER, VARCHAR, Text
19 |
20 | from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator
21 |
22 | from .conftest import validate_code
23 |
24 |
@pytest.fixture
def generator(
    request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
    """Build a declarative generator; options may be passed via indirect parametrization."""
    return DeclarativeGenerator(metadata, engine, getattr(request, "param", []))
31 |
32 |
def test_indexes(generator: CodeGenerator) -> None:
    """Named indexes, including a unique one, are rendered in ``__table_args__``."""
    simple_items = Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("number", INTEGER),
        Column("text", VARCHAR),
    )
    cols = simple_items.c
    simple_items.indexes.add(Index("idx_number", cols.number))
    simple_items.indexes.add(Index("idx_text_number", cols.text, cols.number))
    simple_items.indexes.add(Index("idx_text", cols.text, unique=True))

    expected = """\
        from typing import Optional

        from sqlalchemy import Index, Integer, String
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class SimpleItems(Base):
            __tablename__ = 'simple_items'
            __table_args__ = (
                Index('idx_number', 'number'),
                Index('idx_text', 'text', unique=True),
                Index('idx_text_number', 'text', 'number')
            )

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            number: Mapped[Optional[int]] = mapped_column(Integer)
            text: Mapped[Optional[str]] = mapped_column(String)
        """
    validate_code(generator.generate(), expected)
72 |
73 |
def test_constraints(generator: CodeGenerator) -> None:
    """Check and unique constraints are rendered in ``__table_args__``."""
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("number", INTEGER),
        CheckConstraint("number > 0"),
        UniqueConstraint("id", "number"),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import CheckConstraint, Integer, UniqueConstraint
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class SimpleItems(Base):
            __tablename__ = 'simple_items'
            __table_args__ = (
                CheckConstraint('number > 0'),
                UniqueConstraint('id', 'number')
            )

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            number: Mapped[Optional[int]] = mapped_column(Integer)
        """
    validate_code(generator.generate(), expected)
107 |
108 |
def test_onetomany(generator: CodeGenerator) -> None:
    """A simple foreign key yields a many-to-one and a list-typed back-reference."""
    meta = generator.metadata
    Table(
        "simple_items",
        meta,
        Column("id", INTEGER, primary_key=True),
        Column("container_id", INTEGER),
        ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
    )
    Table(
        "simple_containers",
        meta,
        Column("id", INTEGER, primary_key=True),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKey, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

        class Base(DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='container')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            container_id: Mapped[Optional[int]] = \
mapped_column(ForeignKey('simple_containers.id'))

            container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
back_populates='simple_items')
        """
    validate_code(generator.generate(), expected)
155 |
156 |
def test_onetomany_selfref(generator: CodeGenerator) -> None:
    """A self-referential foreign key yields both directions with ``remote_side``."""
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("parent_item_id", INTEGER),
        ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKey, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

        class Base(DeclarativeBase):
            pass


        class SimpleItems(Base):
            __tablename__ = 'simple_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            parent_item_id: Mapped[Optional[int]] = \
mapped_column(ForeignKey('simple_items.id'))

            parent_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', \
remote_side=[id], back_populates='parent_item_reverse')
            parent_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
remote_side=[parent_item_id], back_populates='parent_item')
        """
    validate_code(generator.generate(), expected)
191 |
192 |
def test_onetomany_selfref_multi(generator: CodeGenerator) -> None:
    """Two self-referential foreign keys each get relationships with explicit ``foreign_keys``."""
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("parent_item_id", INTEGER),
        Column("top_item_id", INTEGER),
        ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]),
        ForeignKeyConstraint(["top_item_id"], ["simple_items.id"]),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKey, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

        class Base(DeclarativeBase):
            pass


        class SimpleItems(Base):
            __tablename__ = 'simple_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            parent_item_id: Mapped[Optional[int]] = \
mapped_column(ForeignKey('simple_items.id'))
            top_item_id: Mapped[Optional[int]] = mapped_column(ForeignKey('simple_items.id'))

            parent_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', \
remote_side=[id], foreign_keys=[parent_item_id], back_populates='parent_item_reverse')
            parent_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
remote_side=[parent_item_id], foreign_keys=[parent_item_id], \
back_populates='parent_item')
            top_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', remote_side=[id], \
foreign_keys=[top_item_id], back_populates='top_item_reverse')
            top_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
remote_side=[top_item_id], foreign_keys=[top_item_id], back_populates='top_item')
        """
    validate_code(generator.generate(), expected)
235 |
236 |
def test_onetomany_composite(generator: CodeGenerator) -> None:
    """A composite foreign key with cascades is rendered as a ``ForeignKeyConstraint``."""
    meta = generator.metadata
    composite_fk = ForeignKeyConstraint(
        ["container_id1", "container_id2"],
        ["simple_containers.id1", "simple_containers.id2"],
        ondelete="CASCADE",
        onupdate="CASCADE",
    )
    Table(
        "simple_items",
        meta,
        Column("id", INTEGER, primary_key=True),
        Column("container_id1", INTEGER),
        Column("container_id2", INTEGER),
        composite_fk,
    )
    Table(
        "simple_containers",
        meta,
        Column("id1", INTEGER, primary_key=True),
        Column("id2", INTEGER, primary_key=True),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKeyConstraint, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

        class Base(DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id1: Mapped[int] = mapped_column(Integer, primary_key=True)
            id2: Mapped[int] = mapped_column(Integer, primary_key=True)

            simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='simple_containers')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'
            __table_args__ = (
                ForeignKeyConstraint(['container_id1', 'container_id2'], \
['simple_containers.id1', 'simple_containers.id2'], ondelete='CASCADE', \
onupdate='CASCADE'),
            )

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            container_id1: Mapped[Optional[int]] = mapped_column(Integer)
            container_id2: Mapped[Optional[int]] = mapped_column(Integer)

            simple_containers: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
back_populates='simple_items')
        """
    validate_code(generator.generate(), expected)
296 |
297 |
def test_onetomany_multiref(generator: CodeGenerator) -> None:
    """Two foreign keys to the same table get disambiguating ``foreign_keys`` arguments."""
    meta = generator.metadata
    Table(
        "simple_items",
        meta,
        Column("id", INTEGER, primary_key=True),
        Column("parent_container_id", INTEGER),
        Column("top_container_id", INTEGER, nullable=False),
        ForeignKeyConstraint(["parent_container_id"], ["simple_containers.id"]),
        ForeignKeyConstraint(["top_container_id"], ["simple_containers.id"]),
    )
    Table(
        "simple_containers",
        meta,
        Column("id", INTEGER, primary_key=True),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKey, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

        class Base(DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
foreign_keys='[SimpleItems.parent_container_id]', back_populates='parent_container')
            simple_items_: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
foreign_keys='[SimpleItems.top_container_id]', back_populates='top_container')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            top_container_id: Mapped[int] = \
mapped_column(ForeignKey('simple_containers.id'))
            parent_container_id: Mapped[Optional[int]] = \
mapped_column(ForeignKey('simple_containers.id'))

            parent_container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
foreign_keys=[parent_container_id], back_populates='simple_items')
            top_container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \
foreign_keys=[top_container_id], back_populates='simple_items_')
        """
    validate_code(generator.generate(), expected)
352 |
353 |
354 | def test_onetoone(generator: CodeGenerator) -> None:  # A FK column that is also unique: expects a one-to-one pair, with uselist=False plus an Optional scalar on the OtherItems side, and unique=True on the FK column.
355 | Table(
356 | "simple_items",
357 | generator.metadata,
358 | Column("id", INTEGER, primary_key=True),
359 | Column("other_item_id", INTEGER),
360 | ForeignKeyConstraint(["other_item_id"], ["other_items.id"]),
361 | UniqueConstraint("other_item_id"),  # single-column unique constraint on the FK; expected to surface as unique=True in mapped_column
362 | )
363 | Table(
364 | "other_items",
365 | generator.metadata,
366 | Column("id", INTEGER, primary_key=True),
367 | )
368 |
369 | validate_code(
370 | generator.generate(),
371 | """\
372 | from typing import Optional
373 |
374 | from sqlalchemy import ForeignKey, Integer
375 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
376 |
377 | class Base(DeclarativeBase):
378 | pass
379 |
380 |
381 | class OtherItems(Base):
382 | __tablename__ = 'other_items'
383 |
384 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
385 |
386 | simple_items: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', uselist=False, \
387 | back_populates='other_item')
388 |
389 |
390 | class SimpleItems(Base):
391 | __tablename__ = 'simple_items'
392 |
393 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
394 | other_item_id: Mapped[Optional[int]] = \
395 | mapped_column(ForeignKey('other_items.id'), unique=True)
396 |
397 | other_item: Mapped[Optional['OtherItems']] = relationship('OtherItems', \
398 | back_populates='simple_items')
399 | """,
400 | )
401 |
402 |
403 | def test_onetomany_noinflect(generator: CodeGenerator) -> None:  # No generator options here (no use_inflect): the non-word table names are only capitalized into class names, the FK column fehwiuhfiwID is kept verbatim, and the relationship is expected as 'fehwiuhfiw'.
404 | Table(
405 | "oglkrogk",
406 | generator.metadata,
407 | Column("id", INTEGER, primary_key=True),
408 | Column("fehwiuhfiwID", INTEGER),
409 | ForeignKeyConstraint(["fehwiuhfiwID"], ["fehwiuhfiw.id"]),
410 | )
411 | Table("fehwiuhfiw", generator.metadata, Column("id", INTEGER, primary_key=True))
412 |
413 | validate_code(
414 | generator.generate(),
415 | """\
416 | from typing import Optional
417 |
418 | from sqlalchemy import ForeignKey, Integer
419 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
420 |
421 | class Base(DeclarativeBase):
422 | pass
423 |
424 |
425 | class Fehwiuhfiw(Base):
426 | __tablename__ = 'fehwiuhfiw'
427 |
428 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
429 |
430 | oglkrogk: Mapped[list['Oglkrogk']] = relationship('Oglkrogk', \
431 | back_populates='fehwiuhfiw')
432 |
433 |
434 | class Oglkrogk(Base):
435 | __tablename__ = 'oglkrogk'
436 |
437 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
438 | fehwiuhfiwID: Mapped[Optional[int]] = mapped_column(ForeignKey('fehwiuhfiw.id'))
439 |
440 | fehwiuhfiw: Mapped[Optional['Fehwiuhfiw']] = \
441 | relationship('Fehwiuhfiw', back_populates='oglkrogk')
442 | """,
443 | )
444 |
445 |
446 | def test_onetomany_conflicting_column(generator: CodeGenerator) -> None:  # A column literally named "relationship": expected as attribute relationship_ (DB name passed explicitly), so it does not clash with the relationship() call used in the same class body.
447 | Table(
448 | "simple_items",
449 | generator.metadata,
450 | Column("id", INTEGER, primary_key=True),
451 | Column("container_id", INTEGER),
452 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
453 | )
454 | Table(
455 | "simple_containers",
456 | generator.metadata,
457 | Column("id", INTEGER, primary_key=True),
458 | Column("relationship", Text),  # the conflicting column name
459 | )
460 |
461 | validate_code(
462 | generator.generate(),
463 | """\
464 | from typing import Optional
465 |
466 | from sqlalchemy import ForeignKey, Integer, Text
467 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
468 |
469 | class Base(DeclarativeBase):
470 | pass
471 |
472 |
473 | class SimpleContainers(Base):
474 | __tablename__ = 'simple_containers'
475 |
476 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
477 | relationship_: Mapped[Optional[str]] = mapped_column('relationship', Text)
478 |
479 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
480 | back_populates='container')
481 |
482 |
483 | class SimpleItems(Base):
484 | __tablename__ = 'simple_items'
485 |
486 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
487 | container_id: Mapped[Optional[int]] = \
488 | mapped_column(ForeignKey('simple_containers.id'))
489 |
490 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
491 | back_populates='simple_items')
492 | """,
493 | )
494 |
495 |
496 | def test_onetomany_conflicting_relationship(generator: CodeGenerator) -> None:  # A referenced table named "relationship": the many-to-one attribute on SimpleItems is expected as relationship_, and the reverse side back-populates that suffixed name.
497 | Table(
498 | "simple_items",
499 | generator.metadata,
500 | Column("id", INTEGER, primary_key=True),
501 | Column("relationship_id", INTEGER),
502 | ForeignKeyConstraint(["relationship_id"], ["relationship.id"]),
503 | )
504 | Table("relationship", generator.metadata, Column("id", INTEGER, primary_key=True))
505 |
506 | validate_code(
507 | generator.generate(),
508 | """\
509 | from typing import Optional
510 |
511 | from sqlalchemy import ForeignKey, Integer
512 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
513 |
514 | class Base(DeclarativeBase):
515 | pass
516 |
517 |
518 | class Relationship(Base):
519 | __tablename__ = 'relationship'
520 |
521 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
522 |
523 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
524 | back_populates='relationship_')
525 |
526 |
527 | class SimpleItems(Base):
528 | __tablename__ = 'simple_items'
529 |
530 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
531 | relationship_id: Mapped[Optional[int]] = \
532 | mapped_column(ForeignKey('relationship.id'))
533 |
534 | relationship_: Mapped[Optional['Relationship']] = relationship('Relationship', \
535 | back_populates='simple_items')
536 | """,
537 | )
538 |
539 |
540 | @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True)  # indirect: the ["nobidi"] option list is handed to the generator fixture
541 | def test_manytoone_nobidi(generator: CodeGenerator) -> None:  # With nobidi: only the many-to-one side is generated; no back_populates, and no reverse collection on SimpleContainers.
542 | Table(
543 | "simple_items",
544 | generator.metadata,
545 | Column("id", INTEGER, primary_key=True),
546 | Column("container_id", INTEGER),
547 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
548 | )
549 | Table(
550 | "simple_containers",
551 | generator.metadata,
552 | Column("id", INTEGER, primary_key=True),
553 | )
554 |
555 | validate_code(
556 | generator.generate(),
557 | """\
558 | from typing import Optional
559 |
560 | from sqlalchemy import ForeignKey, Integer
561 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
562 |
563 | class Base(DeclarativeBase):
564 | pass
565 |
566 |
567 | class SimpleContainers(Base):
568 | __tablename__ = 'simple_containers'
569 |
570 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
571 |
572 |
573 | class SimpleItems(Base):
574 | __tablename__ = 'simple_items'
575 |
576 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
577 | container_id: Mapped[Optional[int]] = \
578 | mapped_column(ForeignKey('simple_containers.id'))
579 |
580 | container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers')
581 | """,
582 | )
583 |
584 |
585 | def test_manytomany(generator: CodeGenerator) -> None:  # association_table has only two FK columns and no primary key: it is expected as a plain Table (t_association_table), used as secondary= by a bidirectional right/left relationship pair.
586 | Table("left_table", generator.metadata, Column("id", INTEGER, primary_key=True))
587 | Table(
588 | "right_table",
589 | generator.metadata,
590 | Column("id", INTEGER, primary_key=True),
591 | )
592 | Table(
593 | "association_table",
594 | generator.metadata,
595 | Column("left_id", INTEGER),
596 | Column("right_id", INTEGER),
597 | ForeignKeyConstraint(["left_id"], ["left_table.id"]),
598 | ForeignKeyConstraint(["right_id"], ["right_table.id"]),
599 | )
600 |
601 | validate_code(
602 | generator.generate(),
603 | """\
604 | from sqlalchemy import Column, ForeignKey, Integer, Table
605 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
606 |
607 | class Base(DeclarativeBase):
608 | pass
609 |
610 |
611 | class LeftTable(Base):
612 | __tablename__ = 'left_table'
613 |
614 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
615 |
616 | right: Mapped[list['RightTable']] = relationship('RightTable', \
617 | secondary='association_table', back_populates='left')
618 |
619 |
620 | class RightTable(Base):
621 | __tablename__ = 'right_table'
622 |
623 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
624 |
625 | left: Mapped[list['LeftTable']] = relationship('LeftTable', \
626 | secondary='association_table', back_populates='right')
627 |
628 |
629 | t_association_table = Table(
630 | 'association_table', Base.metadata,
631 | Column('left_id', ForeignKey('left_table.id')),
632 | Column('right_id', ForeignKey('right_table.id'))
633 | )
634 | """,
635 | )
636 |
637 |
638 | @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True)  # indirect: the ["nobidi"] option list is handed to the generator fixture
639 | def test_manytomany_nobidi(generator: CodeGenerator) -> None:  # With nobidi: a single many-to-many relationship ('item') on SimpleContainers with no back_populates, and nothing on SimpleItems.
640 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True))
641 | Table(
642 | "simple_containers",
643 | generator.metadata,
644 | Column("id", INTEGER, primary_key=True),
645 | )
646 | Table(
647 | "container_items",
648 | generator.metadata,
649 | Column("item_id", INTEGER),
650 | Column("container_id", INTEGER),
651 | ForeignKeyConstraint(["item_id"], ["simple_items.id"]),
652 | ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
653 | )
654 |
655 | validate_code(
656 | generator.generate(),
657 | """\
658 | from sqlalchemy import Column, ForeignKey, Integer, Table
659 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
660 |
661 | class Base(DeclarativeBase):
662 | pass
663 |
664 |
665 | class SimpleContainers(Base):
666 | __tablename__ = 'simple_containers'
667 |
668 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
669 |
670 | item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
671 | secondary='container_items')
672 |
673 |
674 | class SimpleItems(Base):
675 | __tablename__ = 'simple_items'
676 |
677 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
678 |
679 |
680 | t_container_items = Table(
681 | 'container_items', Base.metadata,
682 | Column('item_id', ForeignKey('simple_items.id')),
683 | Column('container_id', ForeignKey('simple_containers.id'))
684 | )
685 | """,
686 | )
687 |
688 |
689 | def test_manytomany_selfref(generator: CodeGenerator) -> None:  # Self-referential many-to-many through an association table in another schema: expects a schema-qualified secondary= and explicit primaryjoin/secondaryjoin lambdas on both the parent and child relationships.
690 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True))
691 | Table(
692 | "child_items",
693 | generator.metadata,
694 | Column("parent_id", INTEGER),
695 | Column("child_id", INTEGER),
696 | ForeignKeyConstraint(["parent_id"], ["simple_items.id"]),
697 | ForeignKeyConstraint(["child_id"], ["simple_items.id"]),  # both FKs target the same table, so the join conditions must be spelled out
698 | schema="otherschema",  # association table lives in a different schema from simple_items
699 | )
700 |
701 | validate_code(
702 | generator.generate(),
703 | """\
704 | from sqlalchemy import Column, ForeignKey, Integer, Table
705 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
706 |
707 | class Base(DeclarativeBase):
708 | pass
709 |
710 |
711 | class SimpleItems(Base):
712 | __tablename__ = 'simple_items'
713 |
714 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
715 |
716 | parent: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
717 | secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id \
718 | == t_child_items.c.child_id, \
719 | secondaryjoin=lambda: SimpleItems.id == \
720 | t_child_items.c.parent_id, back_populates='child')
721 | child: Mapped[list['SimpleItems']] = \
722 | relationship('SimpleItems', secondary='otherschema.child_items', \
723 | primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \
724 | secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \
725 | back_populates='parent')
726 |
727 |
728 | t_child_items = Table(
729 | 'child_items', Base.metadata,
730 | Column('parent_id', ForeignKey('simple_items.id')),
731 | Column('child_id', ForeignKey('simple_items.id')),
732 | schema='otherschema'
733 | )
734 | """,
735 | )
736 |
737 |
738 | def test_manytomany_composite(generator: CodeGenerator) -> None:  # Association table with two-column FKs: expects ForeignKeyConstraint entries in the Table (columns stay plain Integer) instead of per-column ForeignKey().
739 | Table(
740 | "simple_items",
741 | generator.metadata,
742 | Column("id1", INTEGER, primary_key=True),
743 | Column("id2", INTEGER, primary_key=True),
744 | )
745 | Table(
746 | "simple_containers",
747 | generator.metadata,
748 | Column("id1", INTEGER, primary_key=True),
749 | Column("id2", INTEGER, primary_key=True),
750 | )
751 | Table(
752 | "container_items",
753 | generator.metadata,
754 | Column("item_id1", INTEGER),
755 | Column("item_id2", INTEGER),
756 | Column("container_id1", INTEGER),
757 | Column("container_id2", INTEGER),
758 | ForeignKeyConstraint(  # NOTE: declared first, but the expected Table lists the simple_containers constraint before this one
759 | ["item_id1", "item_id2"], ["simple_items.id1", "simple_items.id2"]
760 | ),
761 | ForeignKeyConstraint(
762 | ["container_id1", "container_id2"],
763 | ["simple_containers.id1", "simple_containers.id2"],
764 | ),
765 | )
766 |
767 | validate_code(
768 | generator.generate(),
769 | """\
770 | from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table
771 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
772 |
773 | class Base(DeclarativeBase):
774 | pass
775 |
776 |
777 | class SimpleContainers(Base):
778 | __tablename__ = 'simple_containers'
779 |
780 | id1: Mapped[int] = mapped_column(Integer, primary_key=True)
781 | id2: Mapped[int] = mapped_column(Integer, primary_key=True)
782 |
783 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
784 | secondary='container_items', back_populates='simple_containers')
785 |
786 |
787 | class SimpleItems(Base):
788 | __tablename__ = 'simple_items'
789 |
790 | id1: Mapped[int] = mapped_column(Integer, primary_key=True)
791 | id2: Mapped[int] = mapped_column(Integer, primary_key=True)
792 |
793 | simple_containers: Mapped[list['SimpleContainers']] = \
794 | relationship('SimpleContainers', secondary='container_items', \
795 | back_populates='simple_items')
796 |
797 |
798 | t_container_items = Table(
799 | 'container_items', Base.metadata,
800 | Column('item_id1', Integer),
801 | Column('item_id2', Integer),
802 | Column('container_id1', Integer),
803 | Column('container_id2', Integer),
804 | ForeignKeyConstraint(['container_id1', 'container_id2'], \
805 | ['simple_containers.id1', 'simple_containers.id2']),
806 | ForeignKeyConstraint(['item_id1', 'item_id2'], \
807 | ['simple_items.id1', 'simple_items.id2'])
808 | )
809 | """,
810 | )
811 |
812 |
813 | def test_joined_inheritance(generator: CodeGenerator) -> None:  # PK columns that are FKs to another table's PK: expects the chain SimpleSuperItems <- SimpleItems <- SimpleSubItems, emitted parent-first even though the tables are declared in a different order.
814 | Table(
815 | "simple_sub_items",
816 | generator.metadata,
817 | Column("simple_items_id", INTEGER, primary_key=True),
818 | Column("data3", INTEGER),
819 | ForeignKeyConstraint(["simple_items_id"], ["simple_items.super_item_id"]),  # targets simple_items' PK, which is itself an FK to simple_super_items
820 | )
821 | Table(
822 | "simple_super_items",
823 | generator.metadata,
824 | Column("id", INTEGER, primary_key=True),
825 | Column("data1", INTEGER),
826 | )
827 | Table(
828 | "simple_items",
829 | generator.metadata,
830 | Column("super_item_id", INTEGER, primary_key=True),
831 | Column("data2", INTEGER),
832 | ForeignKeyConstraint(["super_item_id"], ["simple_super_items.id"]),
833 | )
834 |
835 | validate_code(
836 | generator.generate(),
837 | """\
838 | from typing import Optional
839 |
840 | from sqlalchemy import ForeignKey, Integer
841 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
842 |
843 | class Base(DeclarativeBase):
844 | pass
845 |
846 |
847 | class SimpleSuperItems(Base):
848 | __tablename__ = 'simple_super_items'
849 |
850 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
851 | data1: Mapped[Optional[int]] = mapped_column(Integer)
852 |
853 |
854 | class SimpleItems(SimpleSuperItems):
855 | __tablename__ = 'simple_items'
856 |
857 | super_item_id: Mapped[int] = mapped_column(ForeignKey('simple_super_items.id'), \
858 | primary_key=True)
859 | data2: Mapped[Optional[int]] = mapped_column(Integer)
860 |
861 |
862 | class SimpleSubItems(SimpleItems):
863 | __tablename__ = 'simple_sub_items'
864 |
865 | simple_items_id: Mapped[int] = \
866 | mapped_column(ForeignKey('simple_items.super_item_id'), primary_key=True)
867 | data3: Mapped[Optional[int]] = mapped_column(Integer)
868 | """,
869 | )
870 |
871 |
872 | def test_joined_inheritance_same_table_name(generator: CodeGenerator) -> None:  # Table name "simple" used twice, once with no schema and once in altschema: the inheriting class is expected as Simple_ with the schema in __table_args__.
873 | Table(
874 | "simple",
875 | generator.metadata,
876 | Column("id", INTEGER, primary_key=True),
877 | )
878 | Table(
879 | "simple",
880 | generator.metadata,
881 | Column("id", INTEGER, ForeignKey("simple.id"), primary_key=True),  # PK that is also an FK to the unschema'd "simple" table
882 | schema="altschema",
883 | )
884 |
885 | validate_code(
886 | generator.generate(),
887 | """\
888 | from sqlalchemy import ForeignKey, Integer
889 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
890 |
891 | class Base(DeclarativeBase):
892 | pass
893 |
894 |
895 | class Simple(Base):
896 | __tablename__ = 'simple'
897 |
898 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
899 |
900 |
901 | class Simple_(Simple):
902 | __tablename__ = 'simple'
903 | __table_args__ = {'schema': 'altschema'}
904 |
905 | id: Mapped[int] = mapped_column(ForeignKey('simple.id'), primary_key=True)
906 | """,
907 | )
908 |
909 |
910 | @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True)  # indirect: the ["use_inflect"] option list is handed to the generator fixture
911 | def test_use_inflect(generator: CodeGenerator) -> None:  # With use_inflect: the plural table name simple_items becomes class SimpleItem, while the already-singular "singular" stays Singular.
912 | Table("simple_items", generator.metadata, Column("id", INTEGER, primary_key=True))
913 |
914 | Table("singular", generator.metadata, Column("id", INTEGER, primary_key=True))
915 |
916 | validate_code(
917 | generator.generate(),
918 | """\
919 | from sqlalchemy import Integer
920 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
921 |
922 | class Base(DeclarativeBase):
923 | pass
924 |
925 |
926 | class SimpleItem(Base):
927 | __tablename__ = 'simple_items'
928 |
929 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
930 |
931 |
932 | class Singular(Base):
933 | __tablename__ = 'singular'
934 |
935 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
936 | """,
937 | )
938 |
939 |
940 | @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True)  # indirect: the ["use_inflect"] option list is handed to the generator fixture
941 | @pytest.mark.parametrize(  # irregular plurals: (table name, expected singular class name before capitalize(), expected relationship / FK-column prefix)
942 | argnames=("table_name", "class_name", "relationship_name"),
943 | argvalues=[
944 | ("manufacturers", "manufacturer", "manufacturer"),
945 | ("statuses", "status", "status"),
946 | ("studies", "study", "study"),
947 | ("moose", "moose", "moose"),
948 | ],
949 | ids=[
950 | "test_inflect_manufacturer",
951 | "test_inflect_status",
952 | "test_inflect_study",
953 | "test_inflect_moose",
954 | ],
955 | )
956 | def test_use_inflect_plural(  # With use_inflect: each plural table name is expected to singularize into the class name, in a one-to-one relationship with simple_items (class SimpleItem).
957 | generator: CodeGenerator,
958 | table_name: str,
959 | class_name: str,
960 | relationship_name: str,
961 | ) -> None:
962 | Table(
963 | "simple_items",
964 | generator.metadata,
965 | Column("id", INTEGER, primary_key=True),
966 | Column(f"{relationship_name}_id", INTEGER),
967 | ForeignKeyConstraint([f"{relationship_name}_id"], [f"{table_name}.id"]),
968 | UniqueConstraint(f"{relationship_name}_id"),  # unique FK: expected as one-to-one (uselist=False on the referenced side)
969 | )
970 | Table(table_name, generator.metadata, Column("id", INTEGER, primary_key=True))
971 |
972 | validate_code(
973 | generator.generate(),
974 | f"""\
975 | from typing import Optional
976 |
977 | from sqlalchemy import ForeignKey, Integer
978 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
979 |
980 | class Base(DeclarativeBase):
981 | pass
982 |
983 |
984 | class {class_name.capitalize()}(Base):
985 | __tablename__ = '{table_name}'
986 |
987 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
988 |
989 | simple_item: Mapped[Optional['SimpleItem']] = relationship('SimpleItem', uselist=False, \
990 | back_populates='{relationship_name}')
991 |
992 |
993 | class SimpleItem(Base):
994 | __tablename__ = 'simple_items'
995 |
996 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
997 | {relationship_name}_id: Mapped[Optional[int]] = \
998 | mapped_column(ForeignKey('{table_name}.id'), unique=True)
999 |
1000 | {relationship_name}: Mapped[Optional['{class_name.capitalize()}']] = \
1001 | relationship('{class_name.capitalize()}', back_populates='simple_item')
1002 | """,
1003 | )
1004 |
1005 |
1006 | @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True)  # indirect: the ["use_inflect"] option list is handed to the generator fixture
1007 | def test_use_inflect_plural_double_pluralize(generator: CodeGenerator) -> None:  # With use_inflect: users/groups become User/Group, and the reverse collection on Group is expected as 'users' (per the test name, not pluralized a second time); named PK/FK constraints stay in __table_args__.
1008 | Table(
1009 | "users",
1010 | generator.metadata,
1011 | Column("users_id", INTEGER),
1012 | Column("groups_id", INTEGER),
1013 | ForeignKeyConstraint(
1014 | ["groups_id"], ["groups.groups_id"], name="fk_users_groups_id"
1015 | ),
1016 | PrimaryKeyConstraint("users_id", name="users_pkey"),
1017 | )
1018 |
1019 | Table(
1020 | "groups",
1021 | generator.metadata,
1022 | Column("groups_id", INTEGER),
1023 | Column("group_name", Text(50), nullable=False),
1024 | PrimaryKeyConstraint("groups_id", name="groups_pkey"),
1025 | )
1026 |
1027 | validate_code(
1028 | generator.generate(),
1029 | """\
1030 | from typing import Optional
1031 |
1032 | from sqlalchemy import ForeignKeyConstraint, Integer, PrimaryKeyConstraint, Text
1033 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
1034 |
1035 | class Base(DeclarativeBase):
1036 | pass
1037 |
1038 |
1039 | class Group(Base):
1040 | __tablename__ = 'groups'
1041 | __table_args__ = (
1042 | PrimaryKeyConstraint('groups_id', name='groups_pkey'),
1043 | )
1044 |
1045 | groups_id: Mapped[int] = mapped_column(Integer, primary_key=True)
1046 | group_name: Mapped[str] = mapped_column(Text(50))
1047 |
1048 | users: Mapped[list['User']] = relationship('User', back_populates='group')
1049 |
1050 |
1051 | class User(Base):
1052 | __tablename__ = 'users'
1053 | __table_args__ = (
1054 | ForeignKeyConstraint(['groups_id'], ['groups.groups_id'], name='fk_users_groups_id'),
1055 | PrimaryKeyConstraint('users_id', name='users_pkey')
1056 | )
1057 |
1058 | users_id: Mapped[int] = mapped_column(Integer, primary_key=True)
1059 | groups_id: Mapped[Optional[int]] = mapped_column(Integer)
1060 |
1061 | group: Mapped[Optional['Group']] = relationship('Group', back_populates='users')
1062 | """,
1063 | )
1064 |
1065 |
1066 | def test_table_kwargs(generator: CodeGenerator) -> None:  # The schema is the only table-level option: __table_args__ is expected as a plain dict.
1067 | Table(
1068 | "simple_items",
1069 | generator.metadata,
1070 | Column("id", INTEGER, primary_key=True),
1071 | schema="testschema",
1072 | )
1073 |
1074 | validate_code(
1075 | generator.generate(),
1076 | """\
1077 | from sqlalchemy import Integer
1078 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1079 |
1080 | class Base(DeclarativeBase):
1081 | pass
1082 |
1083 |
1084 | class SimpleItems(Base):
1085 | __tablename__ = 'simple_items'
1086 | __table_args__ = {'schema': 'testschema'}
1087 |
1088 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1089 | """,
1090 | )
1091 |
1092 |
1093 | def test_table_args_kwargs(generator: CodeGenerator) -> None:  # An index plus a schema: __table_args__ is expected as a tuple with the Index first and the kwargs dict as the last element.
1094 | simple_items = Table(
1095 | "simple_items",
1096 | generator.metadata,
1097 | Column("id", INTEGER, primary_key=True),
1098 | Column("name", VARCHAR),
1099 | schema="testschema",
1100 | )
1101 | simple_items.indexes.add(Index("testidx", simple_items.c.id, simple_items.c.name))  # attach a two-column index after the Table is built
1102 |
1103 | validate_code(
1104 | generator.generate(),
1105 | """\
1106 | from typing import Optional
1107 |
1108 | from sqlalchemy import Index, Integer, String
1109 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1110 |
1111 | class Base(DeclarativeBase):
1112 | pass
1113 |
1114 |
1115 | class SimpleItems(Base):
1116 | __tablename__ = 'simple_items'
1117 | __table_args__ = (
1118 | Index('testidx', 'id', 'name'),
1119 | {'schema': 'testschema'}
1120 | )
1121 |
1122 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1123 | name: Mapped[Optional[str]] = mapped_column(String)
1124 | """,
1125 | )
1126 |
1127 |
1128 | def test_foreign_key_schema(generator: CodeGenerator) -> None:  # FK to a table in another schema: expects the schema-qualified target in ForeignKey() and {'schema': 'otherschema'} on the referenced class.
1129 | Table(
1130 | "simple_items",
1131 | generator.metadata,
1132 | Column("id", INTEGER, primary_key=True),
1133 | Column("other_item_id", INTEGER),
1134 | ForeignKeyConstraint(["other_item_id"], ["otherschema.other_items.id"]),
1135 | )
1136 | Table(
1137 | "other_items",
1138 | generator.metadata,
1139 | Column("id", INTEGER, primary_key=True),
1140 | schema="otherschema",
1141 | )
1142 |
1143 | validate_code(
1144 | generator.generate(),
1145 | """\
1146 | from typing import Optional
1147 |
1148 | from sqlalchemy import ForeignKey, Integer
1149 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
1150 |
1151 | class Base(DeclarativeBase):
1152 | pass
1153 |
1154 |
1155 | class OtherItems(Base):
1156 | __tablename__ = 'other_items'
1157 | __table_args__ = {'schema': 'otherschema'}
1158 |
1159 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1160 |
1161 | simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
1162 | back_populates='other_item')
1163 |
1164 |
1165 | class SimpleItems(Base):
1166 | __tablename__ = 'simple_items'
1167 |
1168 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1169 | other_item_id: Mapped[Optional[int]] = \
1170 | mapped_column(ForeignKey('otherschema.other_items.id'))
1171 |
1172 | other_item: Mapped[Optional['OtherItems']] = relationship('OtherItems', \
1173 | back_populates='simple_items')
1174 | """,
1175 | )
1176 |
1177 |
1178 | def test_invalid_attribute_names(generator: CodeGenerator) -> None:  # Names that are not valid Python identifiers, or are keywords: expects sanitized attribute names, with the original column name passed as mapped_column's first argument.
1179 | Table(
1180 | "simple-items",
1181 | generator.metadata,
1182 | Column("id-test", INTEGER, primary_key=True),
1183 | Column("4test", INTEGER),
1184 | Column("_4test", INTEGER),  # collides with the sanitized form of "4test", so it is expected as _4test_
1185 | Column("def", INTEGER),  # Python keyword; expected as def_
1186 | )
1187 |
1188 | validate_code(
1189 | generator.generate(),
1190 | """\
1191 | from typing import Optional
1192 |
1193 | from sqlalchemy import Integer
1194 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1195 |
1196 | class Base(DeclarativeBase):
1197 | pass
1198 |
1199 |
1200 | class SimpleItems(Base):
1201 | __tablename__ = 'simple-items'
1202 |
1203 | id_test: Mapped[int] = mapped_column('id-test', Integer, primary_key=True)
1204 | _4test: Mapped[Optional[int]] = mapped_column('4test', Integer)
1205 | _4test_: Mapped[Optional[int]] = mapped_column('_4test', Integer)
1206 | def_: Mapped[Optional[int]] = mapped_column('def', Integer)
1207 | """,
1208 | )
1209 |
1210 |
1211 | def test_pascal(generator: CodeGenerator) -> None:  # An already-PascalCase table name (including the all-caps run "API") is expected to be kept unchanged as the class name.
1212 | Table(
1213 | "CustomerAPIPreference",
1214 | generator.metadata,
1215 | Column("id", INTEGER, primary_key=True),
1216 | )
1217 |
1218 | validate_code(
1219 | generator.generate(),
1220 | """\
1221 | from sqlalchemy import Integer
1222 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1223 |
1224 | class Base(DeclarativeBase):
1225 | pass
1226 |
1227 |
1228 | class CustomerAPIPreference(Base):
1229 | __tablename__ = 'CustomerAPIPreference'
1230 |
1231 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1232 | """,
1233 | )
1234 |
1235 |
1236 | def test_underscore(generator: CodeGenerator) -> None:  # An all-lowercase snake_case table name is expected as CustomerApiPreference: each segment capitalized and the underscores dropped.
1237 | Table(
1238 | "customer_api_preference",
1239 | generator.metadata,
1240 | Column("id", INTEGER, primary_key=True),
1241 | )
1242 |
1243 | validate_code(
1244 | generator.generate(),
1245 | """\
1246 | from sqlalchemy import Integer
1247 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1248 |
1249 | class Base(DeclarativeBase):
1250 | pass
1251 |
1252 |
1253 | class CustomerApiPreference(Base):
1254 | __tablename__ = 'customer_api_preference'
1255 |
1256 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1257 | """,
1258 | )
1259 |
1260 |
1261 | def test_pascal_underscore(generator: CodeGenerator) -> None:  # Snake-case segments that already contain capitals: the existing capitalization ("API", "Preference") is expected to be preserved, giving CustomerAPIPreference.
1262 | Table(
1263 | "customer_API_Preference",
1264 | generator.metadata,
1265 | Column("id", INTEGER, primary_key=True),
1266 | )
1267 |
1268 | validate_code(
1269 | generator.generate(),
1270 | """\
1271 | from sqlalchemy import Integer
1272 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1273 |
1274 | class Base(DeclarativeBase):
1275 | pass
1276 |
1277 |
1278 | class CustomerAPIPreference(Base):
1279 | __tablename__ = 'customer_API_Preference'
1280 |
1281 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1282 | """,
1283 | )
1284 |
1285 |
1286 | def test_pascal_multiple_underscore(generator: CodeGenerator) -> None:  # A double underscore between segments is expected to yield the same class name as a single one (CustomerAPIPreference).
1287 | Table(
1288 | "customer_API__Preference",
1289 | generator.metadata,
1290 | Column("id", INTEGER, primary_key=True),
1291 | )
1292 |
1293 | validate_code(
1294 | generator.generate(),
1295 | """\
1296 | from sqlalchemy import Integer
1297 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1298 |
1299 | class Base(DeclarativeBase):
1300 | pass
1301 |
1302 |
1303 | class CustomerAPIPreference(Base):
1304 | __tablename__ = 'customer_API__Preference'
1305 |
1306 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1307 | """,
1308 | )
1309 |
1310 |
1311 | @pytest.mark.parametrize(  # run once with no options and once with "nocomments"; only "generator" is routed to the fixture (indirect)
1312 | "generator, nocomments",
1313 | [([], False), (["nocomments"], True)],
1314 | indirect=["generator"],
1315 | )
1316 | def test_column_comment(generator: CodeGenerator, nocomments: bool) -> None:  # A column comment is expected as comment= on mapped_column, except when the nocomments option is set.
1317 | Table(
1318 | "simple",
1319 | generator.metadata,
1320 | Column("id", INTEGER, primary_key=True, comment="this is a 'comment'"),
1321 | )
1322 |
1323 | comment_part = "" if nocomments else ", comment=\"this is a 'comment'\""  # expected mapped_column suffix; empty when comments are suppressed
1324 | validate_code(
1325 | generator.generate(),
1326 | f"""\
1327 | from sqlalchemy import Integer
1328 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1329 |
1330 | class Base(DeclarativeBase):
1331 | pass
1332 |
1333 |
1334 | class Simple(Base):
1335 | __tablename__ = 'simple'
1336 |
1337 | id: Mapped[int] = mapped_column(Integer, primary_key=True{comment_part})
1338 | """,
1339 | )
1340 |
1341 |
1342 | def test_table_comment(generator: CodeGenerator) -> None:  # A table comment is expected in dict-form __table_args__ (double-quoted because the text contains single quotes).
1343 | Table(
1344 | "simple",
1345 | generator.metadata,
1346 | Column("id", INTEGER, primary_key=True),
1347 | comment="this is a 'comment'",
1348 | )
1349 |
1350 | validate_code(
1351 | generator.generate(),
1352 | """\
1353 | from sqlalchemy import Integer
1354 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1355 |
1356 | class Base(DeclarativeBase):
1357 | pass
1358 |
1359 |
1360 | class Simple(Base):
1361 | __tablename__ = 'simple'
1362 | __table_args__ = {'comment': "this is a 'comment'"}
1363 |
1364 | id: Mapped[int] = mapped_column(Integer, primary_key=True)
1365 | """,
1366 | )
1367 |
1368 |
def test_metadata_column(generator: CodeGenerator) -> None:
    """A column named ``metadata`` is mapped to the attribute ``metadata_``.

    The expected output passes the real column name explicitly to
    ``mapped_column`` so the database column is still called ``metadata``.
    """
    Table(
        "simple",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("metadata", VARCHAR),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import Integer, String
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            metadata_: Mapped[Optional[str]] = mapped_column('metadata', String)
        """
    validate_code(generator.generate(), expected)
1396 |
1397 |
def test_invalid_variable_name_from_column(generator: CodeGenerator) -> None:
    """A column name with surrounding spaces yields a cleaned-up attribute name.

    The attribute is expected to be ``id`` while the literal column name
    ``' id '`` is passed to ``mapped_column``.
    """
    Table(
        "simple",
        generator.metadata,
        Column(" id ", INTEGER, primary_key=True),
    )

    expected = """\
        from sqlalchemy import Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'

            id: Mapped[int] = mapped_column(' id ', Integer, primary_key=True)
        """
    validate_code(generator.generate(), expected)
1421 |
1422 |
def test_only_tables(generator: CodeGenerator) -> None:
    """A table without a primary key is rendered as a plain ``Table`` object."""
    Table("simple", generator.metadata, Column("id", INTEGER))

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table

        metadata = MetaData()


        t_simple = Table(
            'simple', metadata,
            Column('id', Integer)
        )
        """
    validate_code(generator.generate(), expected)
1440 |
1441 |
def test_named_constraints(generator: CodeGenerator) -> None:
    """Named check/primary-key/unique constraints go into ``__table_args__``."""
    Table(
        "simple",
        generator.metadata,
        Column("id", INTEGER),
        Column("text", VARCHAR),
        CheckConstraint("id > 0", name="checktest"),
        PrimaryKeyConstraint("id", name="primarytest"),
        UniqueConstraint("text", name="uniquetest"),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import CheckConstraint, Integer, PrimaryKeyConstraint, \
String, UniqueConstraint
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'
            __table_args__ = (
                CheckConstraint('id > 0', name='checktest'),
                PrimaryKeyConstraint('id', name='primarytest'),
                UniqueConstraint('text', name='uniquetest')
            )

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            text: Mapped[Optional[str]] = mapped_column(String)
        """
    validate_code(generator.generate(), expected)
1478 |
1479 |
def test_named_foreign_key_constraints(generator: CodeGenerator) -> None:
    """A named FK constraint is kept and still produces a bidirectional relationship."""
    fk = ForeignKeyConstraint(
        ["container_id"], ["simple_containers.id"], name="foreignkeytest"
    )
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("container_id", INTEGER),
        fk,
    )
    Table(
        "simple_containers",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ForeignKeyConstraint, Integer
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

        class Base(DeclarativeBase):
            pass


        class SimpleContainers(Base):
            __tablename__ = 'simple_containers'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)

            simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='container')


        class SimpleItems(Base):
            __tablename__ = 'simple_items'
            __table_args__ = (
                ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \
name='foreignkeytest'),
            )

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            container_id: Mapped[Optional[int]] = mapped_column(Integer)

            container: Mapped[Optional['SimpleContainers']] = relationship('SimpleContainers', \
back_populates='simple_items')
        """
    validate_code(generator.generate(), expected)
1531 |
1532 |
def test_colname_import_conflict(generator: CodeGenerator) -> None:
    """A column whose name clashes with an imported helper is renamed.

    The generated module imports ``text`` (needed for the ``server_default``),
    so the ``text`` column becomes the attribute ``text_`` and passes its real
    column name explicitly to ``mapped_column``.
    """
    Table(
        "simple",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("text", VARCHAR),
        Column("textwithdefault", VARCHAR, server_default=text("'test'")),
    )

    validate_code(
        generator.generate(),
        """\
        from typing import Optional

        from sqlalchemy import Integer, String, text
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class Simple(Base):
            __tablename__ = 'simple'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            text_: Mapped[Optional[str]] = mapped_column('text', String)
            textwithdefault: Mapped[Optional[str]] = mapped_column(String, \
server_default=text("'test'"))
        """,
    )
1564 |
1565 |
def test_table_with_arrays(generator: CodeGenerator) -> None:
    """ARRAY columns map to nested ``list`` annotations matching their dimensions.

    A non-nullable array becomes ``Mapped[list[int]]``; a nullable 2-D array
    becomes ``Mapped[Optional[list[list[str]]]]``.
    """
    int_array = ARRAY(INTEGER())
    str_matrix = ARRAY(VARCHAR(), dimensions=2)
    Table(
        "with_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("int_items_not_optional", int_array, nullable=False),
        Column("str_matrix", str_matrix),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import ARRAY, INTEGER, Integer, VARCHAR
        from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

        class Base(DeclarativeBase):
            pass


        class WithItems(Base):
            __tablename__ = 'with_items'

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER()))
            str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2))
        """
    validate_code(generator.generate(), expected)
1595 |
--------------------------------------------------------------------------------
/tests/test_generator_sqlmodel.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 | from _pytest.fixtures import FixtureRequest
5 | from sqlalchemy.engine import Engine
6 | from sqlalchemy.schema import (
7 | CheckConstraint,
8 | Column,
9 | ForeignKeyConstraint,
10 | Index,
11 | MetaData,
12 | Table,
13 | UniqueConstraint,
14 | )
15 | from sqlalchemy.types import INTEGER, VARCHAR
16 |
17 | from sqlacodegen.generators import CodeGenerator, SQLModelGenerator
18 |
19 | from .conftest import validate_code
20 |
21 |
@pytest.fixture
def generator(
    request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
    """Return a ``SQLModelGenerator`` bound to the shared metadata and engine.

    Generator options are taken from indirect parametrization
    (``request.param``); an unparametrized test gets an empty option list.
    """
    if hasattr(request, "param"):
        options = request.param
    else:
        options = []
    return SQLModelGenerator(metadata, engine, options)
28 |
29 |
def test_indexes(generator: CodeGenerator) -> None:
    """Indexes attached after table creation are emitted in ``__table_args__``.

    The expected output lists them sorted by name and renders every column
    through ``Field(sa_column=Column(...))``.
    """
    item_table = Table(
        "item",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("number", INTEGER),
        Column("text", VARCHAR),
    )
    cols = item_table.c
    for index in (
        Index("idx_number", cols.number),
        Index("idx_text_number", cols.text, cols.number),
        Index("idx_text", cols.text, unique=True),
    ):
        item_table.indexes.add(index)

    expected = """\
        from typing import Optional

        from sqlalchemy import Column, Index, Integer, String
        from sqlmodel import Field, SQLModel

        class Item(SQLModel, table=True):
            __table_args__ = (
                Index('idx_number', 'number'),
                Index('idx_text', 'text', unique=True),
                Index('idx_text_number', 'text', 'number')
            )

            id: Optional[int] = Field(default=None, sa_column=Column(\
'id', Integer, primary_key=True))
            number: Optional[int] = Field(default=None, sa_column=Column(\
'number', Integer))
            text: Optional[str] = Field(default=None, sa_column=Column(\
'text', String))
        """
    validate_code(generator.generate(), expected)
67 |
68 |
def test_constraints(generator: CodeGenerator) -> None:
    """Unnamed check and unique constraints appear in ``__table_args__``."""
    check = CheckConstraint("number > 0")
    unique = UniqueConstraint("id", "number")
    Table(
        "simple_constraints",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("number", INTEGER),
        check,
        unique,
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint
        from sqlmodel import Field, SQLModel

        class SimpleConstraints(SQLModel, table=True):
            __tablename__ = 'simple_constraints'
            __table_args__ = (
                CheckConstraint('number > 0'),
                UniqueConstraint('id', 'number')
            )

            id: Optional[int] = Field(default=None, sa_column=Column(\
'id', Integer, primary_key=True))
            number: Optional[int] = Field(default=None, sa_column=Column(\
'number', Integer))
        """
    validate_code(generator.generate(), expected)
100 |
101 |
def test_onetomany(generator: CodeGenerator) -> None:
    """A plain FK produces a list-valued ``Relationship`` on the parent side.

    The child side gets an ``Optional`` scalar relationship, and both sides
    are linked through ``back_populates``.
    """
    Table(
        "simple_containers",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
    )
    Table(
        "simple_goods",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("container_id", INTEGER),
        ForeignKeyConstraint(["container_id"], ["simple_containers.id"]),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import Column, ForeignKey, Integer
        from sqlmodel import Field, Relationship, SQLModel

        class SimpleContainers(SQLModel, table=True):
            __tablename__ = 'simple_containers'

            id: Optional[int] = Field(default=None, sa_column=Column(\
'id', Integer, primary_key=True))

            simple_goods: list['SimpleGoods'] = Relationship(\
back_populates='container')


        class SimpleGoods(SQLModel, table=True):
            __tablename__ = 'simple_goods'

            id: Optional[int] = Field(default=None, sa_column=Column(\
'id', Integer, primary_key=True))
            container_id: Optional[int] = Field(default=None, sa_column=Column(\
'container_id', ForeignKey('simple_containers.id')))

            container: Optional['SimpleContainers'] = Relationship(\
back_populates='simple_goods')
        """
    validate_code(generator.generate(), expected)
146 |
147 |
def test_onetoone(generator: CodeGenerator) -> None:
    """A unique FK column is rendered as a one-to-one relationship.

    The parent side is expected to pass ``{'uselist': False}`` via
    ``sa_relationship_kwargs`` and the FK column carries ``unique=True``.
    """
    Table("other_items", generator.metadata, Column("id", INTEGER, primary_key=True))
    Table(
        "simple_onetoone",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("other_item_id", INTEGER),
        ForeignKeyConstraint(["other_item_id"], ["other_items.id"]),
        UniqueConstraint("other_item_id"),
    )

    expected = """\
        from typing import Optional

        from sqlalchemy import Column, ForeignKey, Integer
        from sqlmodel import Field, Relationship, SQLModel

        class OtherItems(SQLModel, table=True):
            __tablename__ = 'other_items'

            id: Optional[int] = Field(default=None, sa_column=Column(\
'id', Integer, primary_key=True))

            simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\
sa_relationship_kwargs={'uselist': False}, back_populates='other_item')


        class SimpleOnetoone(SQLModel, table=True):
            __tablename__ = 'simple_onetoone'

            id: Optional[int] = Field(default=None, sa_column=Column(\
'id', Integer, primary_key=True))
            other_item_id: Optional[int] = Field(default=None, sa_column=Column(\
'other_item_id', ForeignKey('other_items.id'), unique=True))

            other_item: Optional['OtherItems'] = Relationship(\
back_populates='simple_onetoone')
        """
    validate_code(generator.generate(), expected)
189 |
--------------------------------------------------------------------------------
/tests/test_generator_tables.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from textwrap import dedent
4 |
5 | import pytest
6 | from _pytest.fixtures import FixtureRequest
7 | from sqlalchemy import TypeDecorator
8 | from sqlalchemy.dialects import mysql, postgresql
9 | from sqlalchemy.engine import Engine
10 | from sqlalchemy.schema import (
11 | CheckConstraint,
12 | Column,
13 | Computed,
14 | ForeignKey,
15 | Identity,
16 | Index,
17 | MetaData,
18 | Table,
19 | UniqueConstraint,
20 | )
21 | from sqlalchemy.sql.expression import text
22 | from sqlalchemy.sql.sqltypes import DateTime, NullType
23 | from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text
24 |
25 | from sqlacodegen.generators import CodeGenerator, TablesGenerator
26 |
27 | from .conftest import validate_code
28 |
29 |
# This needs to be uppercased to trigger #315
class TIMESTAMP_DECORATOR(TypeDecorator[DateTime]):
    """Minimal custom column type that wraps :class:`DateTime`.

    The all-caps class name is deliberate (see the comment above); generated
    code is expected to reference this class by that exact name.
    """

    impl = DateTime
    # The decorator carries no per-instance state, so it is safe to take part
    # in SQLAlchemy's statement cache key. Leaving ``cache_ok`` unset makes
    # SQLAlchemy emit a warning whenever a cache key is generated for it.
    cache_ok = True
33 |
34 |
@pytest.fixture
def generator(
    request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
    """Return a ``TablesGenerator`` bound to the shared metadata and engine.

    Generator options are taken from indirect parametrization
    (``request.param``); an unparametrized test gets an empty option list.
    """
    if hasattr(request, "param"):
        options = request.param
    else:
        options = []
    return TablesGenerator(metadata, engine, options)
41 |
42 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_fancy_coltypes(generator: CodeGenerator) -> None:
    """Dialect and third-party types are rendered with the right imports.

    PostgreSQL ENUM/BOOLEAN become generic ``Enum``/``Boolean``, pgvector's
    ``VECTOR`` and the local ``TIMESTAMP_DECORATOR`` are imported from their
    own modules, and ``NUMERIC`` keeps its ``asdecimal`` argument.
    """
    from pgvector.sqlalchemy.vector import VECTOR

    columns = [
        Column("enum", postgresql.ENUM("A", "B", name="blah", schema="someschema")),
        Column("bool", postgresql.BOOLEAN),
        Column("vector", VECTOR(3)),
        Column("number", NUMERIC(10, asdecimal=False)),
        Column("timestamp", TIMESTAMP_DECORATOR()),
    ]
    Table("simple_items", generator.metadata, *columns, schema="someschema")

    expected = """\
        from tests.test_generator_tables import TIMESTAMP_DECORATOR

        from pgvector.sqlalchemy.vector import VECTOR
        from sqlalchemy import Boolean, Column, Enum, MetaData, Numeric, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('enum', Enum('A', 'B', name='blah', schema='someschema')),
            Column('bool', Boolean),
            Column('vector', VECTOR(3)),
            Column('number', Numeric(10, asdecimal=False)),
            Column('timestamp', TIMESTAMP_DECORATOR),
            schema='someschema'
        )
        """
    validate_code(generator.generate(), expected)
80 |
81 |
def test_boolean_detection(generator: CodeGenerator) -> None:
    """Integer columns constrained to ``IN (0, 1)`` are rendered as ``Boolean``."""
    Table(
        "simple_items",
        generator.metadata,
        Column("bool1", INTEGER),
        Column("bool2", SMALLINT),
        Column("bool3", mysql.TINYINT),
        *(
            CheckConstraint(f"simple_items.{colname} IN (0, 1)")
            for colname in ("bool1", "bool2", "bool3")
        ),
    )

    expected = """\
        from sqlalchemy import Boolean, Column, MetaData, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('bool1', Boolean),
            Column('bool2', Boolean),
            Column('bool3', Boolean)
        )
        """
    validate_code(generator.generate(), expected)
110 |
111 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_arrays(generator: CodeGenerator) -> None:
    """PostgreSQL ARRAY columns render as generic ``ARRAY`` of adapted item types."""
    double_items = postgresql.DOUBLE_PRECISION(precision=53)
    Table(
        "simple_items",
        generator.metadata,
        Column("dp_array", postgresql.ARRAY(double_items)),
        Column("int_array", postgresql.ARRAY(INTEGER)),
    )

    expected = """\
        from sqlalchemy import ARRAY, Column, Double, Integer, MetaData, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('dp_array', ARRAY(Double(precision=53))),
            Column('int_array', ARRAY(Integer()))
        )
        """
    validate_code(generator.generate(), expected)
136 |
137 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_jsonb(generator: CodeGenerator) -> None:
    """A non-default ``astext_type`` on JSONB is rendered with keyword arguments."""
    jsonb_type = postgresql.JSONB(astext_type=Text(50))
    Table("simple_items", generator.metadata, Column("jsonb", jsonb_type))

    expected = """\
        from sqlalchemy import Column, MetaData, Table, Text
        from sqlalchemy.dialects.postgresql import JSONB

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('jsonb', JSONB(astext_type=Text(length=50)))
        )
        """
    validate_code(generator.generate(), expected)
161 |
162 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_jsonb_default(generator: CodeGenerator) -> None:
    """A JSONB column with default settings renders as bare ``JSONB``."""
    jsonb_column = Column("jsonb", postgresql.JSONB)
    Table("simple_items", generator.metadata, jsonb_column)

    expected = """\
        from sqlalchemy import Column, MetaData, Table
        from sqlalchemy.dialects.postgresql import JSONB

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('jsonb', JSONB)
        )
        """
    validate_code(generator.generate(), expected)
182 |
183 |
def test_enum_detection(generator: CodeGenerator) -> None:
    """A VARCHAR restricted by an ``IN (...)`` check becomes an ``Enum``.

    The escaped quote in the second value must survive into the generated
    literal.
    """
    allowed_values = CheckConstraint(r"simple_items.enum IN ('A', '\'B', 'C')")
    Table(
        "simple_items",
        generator.metadata,
        Column("enum", VARCHAR(255)),
        allowed_values,
    )

    expected = """\
        from sqlalchemy import Column, Enum, MetaData, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('enum', Enum('A', "\\\\'B", 'C'))
        )
        """
    validate_code(generator.generate(), expected)
206 |
207 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_domain_text(generator: CodeGenerator) -> None:
    """A text-based PostgreSQL DOMAIN renders with its name, base type and check."""
    postal_code_domain = postgresql.DOMAIN(
        "us_postal_code",
        Text,
        constraint_name="valid_us_postal_code",
        not_null=False,
        check=text("VALUE ~ '^\\d{5}$' OR VALUE ~ '^\\d{5}-\\d{4}$'"),
    )
    Table(
        "simple_items",
        generator.metadata,
        Column("postal_code", postal_code_domain, nullable=False),
    )

    expected = """\
        from sqlalchemy import Column, MetaData, Table, Text, text
        from sqlalchemy.dialects.postgresql import DOMAIN

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('postal_code', DOMAIN('us_postal_code', Text(), \
constraint_name='valid_us_postal_code', not_null=False, \
check=text("VALUE ~ '^\\\\d{5}$' OR VALUE ~ '^\\\\d{5}-\\\\d{4}$'")), nullable=False)
        )
        """
    validate_code(generator.generate(), expected)
243 |
244 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_domain_int(generator: CodeGenerator) -> None:
    """An integer PostgreSQL DOMAIN keeps the upper-case ``INTEGER()`` base type."""
    positive_int = postgresql.DOMAIN(
        "positive_int",
        INTEGER,
        constraint_name="positive",
        not_null=False,
        check=text("VALUE > 0"),
    )
    Table(
        "simple_items",
        generator.metadata,
        Column("n", positive_int, nullable=False),
    )

    expected = """\
        from sqlalchemy import Column, INTEGER, MetaData, Table, text
        from sqlalchemy.dialects.postgresql import DOMAIN

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('n', DOMAIN('positive_int', INTEGER(), \
constraint_name='positive', not_null=False, \
check=text('VALUE > 0')), nullable=False)
        )
        """
    validate_code(generator.generate(), expected)
280 |
281 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_column_adaptation(generator: CodeGenerator) -> None:
    """PostgreSQL BIGINT/DOUBLE_PRECISION adapt to generic ``BigInteger``/``Double``."""
    id_column = Column("id", postgresql.BIGINT)
    length_column = Column("length", postgresql.DOUBLE_PRECISION)
    Table("simple_items", generator.metadata, id_column, length_column)

    expected = """\
        from sqlalchemy import BigInteger, Column, Double, MetaData, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', BigInteger),
            Column('length', Double)
        )
        """
    validate_code(generator.generate(), expected)
306 |
307 |
@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_column_types(generator: CodeGenerator) -> None:
    """MySQL INTEGER/VARCHAR adapt to generic types; DOUBLE/SET stay dialect-specific."""
    columns = (
        Column("id", mysql.INTEGER),
        Column("name", mysql.VARCHAR(255)),
        Column("double", mysql.DOUBLE(1, 2)),
        Column("set", mysql.SET("one", "two")),
    )
    Table("simple_items", generator.metadata, *columns)

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, String, Table
        from sqlalchemy.dialects.mysql import DOUBLE, SET

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer),
            Column('name', String(255)),
            Column('double', DOUBLE(1, 2)),
            Column('set', SET('one', 'two'))
        )
        """
    validate_code(generator.generate(), expected)
337 |
338 |
def test_constraints(generator: CodeGenerator) -> None:
    """Unnamed check and multi-column unique constraints are emitted as Table args."""
    constraints = (CheckConstraint("number > 0"), UniqueConstraint("id", "number"))
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER),
        Column("number", INTEGER),
        *constraints,
    )

    expected = """\
        from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table, \
UniqueConstraint

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer),
            Column('number', Integer),
            CheckConstraint('number > 0'),
            UniqueConstraint('id', 'number')
        )
        """
    validate_code(generator.generate(), expected)
367 |
368 |
def test_indexes(generator: CodeGenerator) -> None:
    """Single-column indexes fold into ``index=``/``unique=`` column kwargs.

    Multi-column and column-less indexes stay as explicit ``Index`` objects.
    """
    items = Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER),
        Column("number", INTEGER),
        Column("text", VARCHAR),
        Index("ix_empty"),
    )
    cols = items.c
    for index in (
        Index("ix_number", cols.number),
        Index("ix_text_number", cols.text, cols.number, unique=True),
        Index("ix_text", cols.text, unique=True),
    ):
        items.indexes.add(index)

    expected = """\
        from sqlalchemy import Column, Index, Integer, MetaData, String, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer),
            Column('number', Integer, index=True),
            Column('text', String, unique=True, index=True),
            Index('ix_empty'),
            Index('ix_text_number', 'text', 'number', unique=True)
        )
        """
    validate_code(generator.generate(), expected)
407 |
408 |
def test_table_comment(generator: CodeGenerator) -> None:
    """A table comment is passed as the ``comment=`` keyword of ``Table``."""
    table_comment = "this is a 'comment'"
    Table(
        "simple",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        comment=table_comment,
    )

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table

        metadata = MetaData()


        t_simple = Table(
            'simple', metadata,
            Column('id', Integer, primary_key=True),
            comment="this is a 'comment'"
        )
        """
    validate_code(generator.generate(), expected)
432 |
433 |
def test_table_name_identifiers(generator: CodeGenerator) -> None:
    """Characters invalid in identifiers become underscores in the variable name.

    ``simple-items table`` must yield ``t_simple_items_table`` while the
    literal table name is preserved.
    """
    Table(
        "simple-items table",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
    )

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table

        metadata = MetaData()


        t_simple_items_table = Table(
            'simple-items table', metadata,
            Column('id', Integer, primary_key=True)
        )
        """
    validate_code(generator.generate(), expected)
455 |
456 |
@pytest.mark.parametrize("generator", [["noindexes"]], indirect=True)
def test_option_noindexes(generator: CodeGenerator) -> None:
    """With ``noindexes`` the index is dropped but the check constraint remains."""
    table = Table(
        "simple_items",
        generator.metadata,
        Column("number", INTEGER),
        CheckConstraint("number > 2"),
    )
    number_index = Index("idx_number", table.c.number)
    table.indexes.add(number_index)

    expected = """\
        from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('number', Integer),
            CheckConstraint('number > 2')
        )
        """
    validate_code(generator.generate(), expected)
482 |
483 |
@pytest.mark.parametrize("generator", [["noconstraints"]], indirect=True)
def test_option_noconstraints(generator: CodeGenerator) -> None:
    """With ``noconstraints`` the check is dropped but the index is still rendered."""
    table = Table(
        "simple_items",
        generator.metadata,
        Column("number", INTEGER),
        CheckConstraint("number > 2"),
    )
    number_index = Index("ix_number", table.c.number)
    table.indexes.add(number_index)

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('number', Integer, index=True)
        )
        """
    validate_code(generator.generate(), expected)
508 |
509 |
@pytest.mark.parametrize("generator", [["nocomments"]], indirect=True)
def test_option_nocomments(generator: CodeGenerator) -> None:
    """With ``nocomments`` both column and table comments are omitted."""
    pk = Column("id", INTEGER, primary_key=True, comment="pk column comment")
    Table("simple", generator.metadata, pk, comment="this is a 'comment'")

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table

        metadata = MetaData()


        t_simple = Table(
            'simple', metadata,
            Column('id', Integer, primary_key=True)
        )
        """
    validate_code(generator.generate(), expected)
533 |
534 |
@pytest.mark.parametrize(
    "persisted, extra_args",
    [(None, ""), (False, ", persisted=False"), (True, ", persisted=True")],
)
def test_computed_column(
    generator: CodeGenerator, persisted: bool | None, extra_args: str
) -> None:
    """``Computed`` renders ``persisted=`` only when it was set explicitly."""
    computed_expr = Computed("1 + 2", persisted=persisted)
    Table(
        "computed",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("computed", INTEGER, computed_expr),
    )

    expected = f"""\
        from sqlalchemy import Column, Computed, Integer, MetaData, Table

        metadata = MetaData()


        t_computed = Table(
            'computed', metadata,
            Column('id', Integer, primary_key=True),
            Column('computed', Integer, Computed('1 + 2'{extra_args}))
        )
        """
    validate_code(generator.generate(), expected)
564 |
565 |
def test_schema(generator: CodeGenerator) -> None:
    """A table's schema is rendered as the ``schema=`` keyword of ``Table``."""
    name_column = Column("name", VARCHAR)
    Table("simple_items", generator.metadata, name_column, schema="testschema")

    expected = """\
        from sqlalchemy import Column, MetaData, String, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('name', String),
            schema='testschema'
        )
        """
    validate_code(generator.generate(), expected)
589 |
590 |
def test_foreign_key_options(generator: CodeGenerator) -> None:
    """All ``ForeignKey`` options (ondelete/onupdate/deferrable/initially) are kept."""
    self_reference = ForeignKey(
        "simple_items.name",
        ondelete="CASCADE",
        onupdate="CASCADE",
        deferrable=True,
        initially="DEFERRED",
    )
    Table(
        "simple_items",
        generator.metadata,
        Column("name", VARCHAR, self_reference),
    )

    expected = """\
        from sqlalchemy import Column, ForeignKey, MetaData, String, Table

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('name', String, ForeignKey('simple_items.name', \
ondelete='CASCADE', onupdate='CASCADE', deferrable=True, initially='DEFERRED'))
        )
        """
    validate_code(generator.generate(), expected)
623 |
624 |
def test_pk_default(generator: CodeGenerator) -> None:
    """A primary key with a server default renders ``server_default=text(...)``."""
    uuid_default = text("uuid_generate_v4()")
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True, server_default=uuid_default),
    )

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table, text

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer, primary_key=True, \
server_default=text('uuid_generate_v4()'))
        )
        """
    validate_code(generator.generate(), expected)
652 |
653 |
@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_timestamp(generator: CodeGenerator) -> None:
    """MySQL TIMESTAMP is rendered with the generic ``sqlalchemy.TIMESTAMP`` import."""
    columns = [
        Column("id", INTEGER, primary_key=True),
        Column("timestamp", mysql.TIMESTAMP),
    ]
    Table("simple", generator.metadata, *columns)

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, TIMESTAMP, Table

        metadata = MetaData()


        t_simple = Table(
            'simple', metadata,
            Column('id', Integer, primary_key=True),
            Column('timestamp', TIMESTAMP)
        )
        """
    validate_code(generator.generate(), expected)
678 |
679 |
@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_integer_display_width(generator: CodeGenerator) -> None:
    """A MySQL INTEGER with display width keeps the dialect-specific type.

    The plain INTEGER primary key still adapts to generic ``Integer``.
    """
    columns = [
        Column("id", INTEGER, primary_key=True),
        Column("number", mysql.INTEGER(11)),
    ]
    Table("simple_items", generator.metadata, *columns)

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table
        from sqlalchemy.dialects.mysql import INTEGER

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer, primary_key=True),
            Column('number', INTEGER(11))
        )
        """
    validate_code(generator.generate(), expected)
705 |
706 |
@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_tinytext(generator: CodeGenerator) -> None:
    """MySQL TINYTEXT is imported from ``sqlalchemy.dialects.mysql``."""
    columns = [
        Column("id", INTEGER, primary_key=True),
        Column("my_tinytext", mysql.TINYTEXT),
    ]
    Table("simple_items", generator.metadata, *columns)

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table
        from sqlalchemy.dialects.mysql import TINYTEXT

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer, primary_key=True),
            Column('my_tinytext', TINYTEXT)
        )
        """
    validate_code(generator.generate(), expected)
732 |
733 |
@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_mediumtext(generator: CodeGenerator) -> None:
    """MySQL MEDIUMTEXT is imported from ``sqlalchemy.dialects.mysql``."""
    columns = [
        Column("id", INTEGER, primary_key=True),
        Column("my_mediumtext", mysql.MEDIUMTEXT),
    ]
    Table("simple_items", generator.metadata, *columns)

    expected = """\
        from sqlalchemy import Column, Integer, MetaData, Table
        from sqlalchemy.dialects.mysql import MEDIUMTEXT

        metadata = MetaData()


        t_simple_items = Table(
            'simple_items', metadata,
            Column('id', Integer, primary_key=True),
            Column('my_mediumtext', MEDIUMTEXT)
        )
        """
    validate_code(generator.generate(), expected)
759 |
760 |
@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_longtext(generator: CodeGenerator) -> None:
    """A ``mysql.LONGTEXT`` column renders with the MySQL dialect import."""
    expected = """\
    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.mysql import LONGTEXT

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, primary_key=True),
        Column('my_longtext', LONGTEXT)
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, primary_key=True),
        Column("my_longtext", mysql.LONGTEXT),
    )

    validate_code(generator.generate(), expected)
786 |
787 |
def test_schema_boolean(generator: CodeGenerator) -> None:
    """An INTEGER limited by ``IN (0, 1)`` in a schema-qualified table is a Boolean.

    The CHECK constraint references the column through its schema-qualified
    name. The expected output renders the column as ``Boolean``, emits no
    separate CHECK constraint, and keeps ``schema='testschema'``.
    """
    expected = """\
    from sqlalchemy import Boolean, Column, MetaData, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('bool1', Boolean),
        schema='testschema'
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column("bool1", INTEGER),
        CheckConstraint("testschema.simple_items.bool1 IN (0, 1)"),
        schema="testschema",
    )

    validate_code(generator.generate(), expected)
812 |
813 |
def test_server_default_multiline(generator: CodeGenerator) -> None:
    """A multi-line ``text()`` server default collapses to one escaped literal.

    The SQL fragment spans three lines. The generated ``text()`` call is
    expected to hold it as a single string with escaped newlines, all on the
    same source line as the ``Column``.
    """
    # dedent() strips the indentation this literal picks up from the test body.
    default_sql = dedent(
        """\
        /*Comment*/
        /*Next line*/
        something()"""
    )
    # The trailing backslash joins the two physical lines of the expected
    # Column, because the generator emits it as one line.
    expected = """\
    from sqlalchemy import Column, Integer, MetaData, Table, text

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, primary_key=True, server_default=\
text('/*Comment*/\\n/*Next line*/\\nsomething()'))
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column(
            "id",
            INTEGER,
            primary_key=True,
            server_default=text(default_sql),
        ),
    )

    validate_code(generator.generate(), expected)
848 |
849 |
def test_server_default_colon(generator: CodeGenerator) -> None:
    """A ``text()`` server default containing ``':001'`` is reproduced verbatim.

    The literal contains single quotes, so the generated ``text()`` argument is
    expected in double quotes. NOTE(review): the ``:001`` fragment resembles a
    ``text()`` bind-parameter token, which is presumably what this case guards
    against.
    """
    expected = """\
    from sqlalchemy import Column, MetaData, String, Table, text

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('problem', String, server_default=text("':001'"))
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column("problem", VARCHAR, server_default=text("':001'")),
    )

    validate_code(generator.generate(), expected)
871 |
872 |
def test_null_type(generator: CodeGenerator) -> None:
    """A ``NullType`` column is rendered with an import from ``sqlalchemy.sql.sqltypes``."""
    expected = """\
    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.sql.sqltypes import NullType

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('problem', NullType)
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column("problem", NullType),
    )

    validate_code(generator.generate(), expected)
895 |
896 |
def test_identity_column(generator: CodeGenerator) -> None:
    """An ``Identity`` server default is rendered as a positional Column argument.

    The identity options (``start=1, increment=2``) are expected to appear in
    an ``Identity(...)`` argument placed before ``primary_key=True``.
    """
    expected = """\
    from sqlalchemy import Column, Identity, Integer, MetaData, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, Identity(start=1, increment=2), primary_key=True)
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column(
            "id",
            INTEGER,
            primary_key=True,
            server_default=Identity(start=1, increment=2),
        ),
    )

    validate_code(generator.generate(), expected)
923 |
924 |
def test_multiline_column_comment(generator: CodeGenerator) -> None:
    """Newlines in a column comment are emitted as escapes in one string literal."""
    expected = """\
    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, comment='This\\nis a multi-line\\ncomment')
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER, comment="This\nis a multi-line\ncomment"),
    )

    validate_code(generator.generate(), expected)
946 |
947 |
def test_multiline_table_comment(generator: CodeGenerator) -> None:
    """Newlines in a table comment are emitted as escapes in one string literal."""
    expected = """\
    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer),
        comment='This\\nis a multi-line\\ncomment'
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column("id", INTEGER),
        comment="This\nis a multi-line\ncomment",
    )

    validate_code(generator.generate(), expected)
971 |
972 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_postgresql_sequence_standard_name(generator: CodeGenerator) -> None:
    """A ``nextval`` default on ``simple_items_id_seq`` is omitted from the output.

    The expected code contains neither a ``Sequence`` nor a ``server_default``
    for the ``id`` column.
    """
    expected = """\
    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, primary_key=True)
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column(
            "id",
            INTEGER,
            primary_key=True,
            server_default=text("nextval('simple_items_id_seq'::regclass)"),
        ),
    )

    validate_code(generator.generate(), expected)
1000 |
1001 |
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_postgresql_sequence_nonstandard_name(generator: CodeGenerator) -> None:
    """A ``nextval`` default on a custom sequence name becomes an explicit ``Sequence``.

    ``test_seq`` does not follow the ``<table>_<column>_seq`` pattern used by
    the standard-name test, so ``Sequence('test_seq')`` is expected in the
    generated ``Column``.
    """
    expected = """\
    from sqlalchemy import Column, Integer, MetaData, Sequence, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, Sequence('test_seq'), primary_key=True)
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column(
            "id",
            INTEGER,
            primary_key=True,
            server_default=text("nextval('test_seq'::regclass)"),
        ),
    )

    validate_code(generator.generate(), expected)
1029 |
1030 |
@pytest.mark.parametrize(
    "schemaname, seqname",
    [
        pytest.param("myschema", "test_seq"),
        pytest.param("myschema", '"test_seq"'),
        pytest.param('"my.schema"', "test_seq"),
        pytest.param('"my.schema"', '"test_seq"'),
    ],
)
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_postgresql_sequence_with_schema(
    generator: CodeGenerator, schemaname: str, seqname: str
) -> None:
    """A schema-qualified ``nextval`` default becomes ``Sequence(..., schema=...)``.

    Both the schema and the sequence name may be double-quoted in the SQL.
    The expected output uses the unquoted names in every parametrized case:
    ``'test_seq'`` for the sequence, and the stripped schema for both the
    ``Sequence`` and the ``Table``.
    """
    # The table lives in the unquoted schema that the nextval() call names.
    expected_schema = schemaname.strip('"')
    nextval_sql = f"nextval('{schemaname}.{seqname}'::regclass)"
    # The trailing backslash joins the Column onto one physical line, which
    # is how the generator emits it.
    expected = f"""\
    from sqlalchemy import Column, Integer, MetaData, Sequence, Table

    metadata = MetaData()


    t_simple_items = Table(
        'simple_items', metadata,
        Column('id', Integer, Sequence('test_seq', \
schema='{expected_schema}'), primary_key=True),
        schema='{expected_schema}'
    )
    """
    Table(
        "simple_items",
        generator.metadata,
        Column(
            "id",
            INTEGER,
            primary_key=True,
            server_default=text(nextval_sql),
        ),
        schema=expected_schema,
    )

    validate_code(generator.generate(), expected)
1073 |
--------------------------------------------------------------------------------