├── .github └── workflows │ ├── cd.yml │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── pyproject.toml ├── sqlalchemy_schemadisplay ├── __init__.py ├── db_diagram.py ├── model_diagram.py └── utils.py └── tests ├── __init__.py ├── test_schema_graph.py ├── test_uml_graph.py └── utils.py /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | 2 | name: CD 3 | 4 | on: 5 | push: 6 | tags-ignore: 7 | - 1.* 8 | 9 | env: 10 | FORCE_COLOR: 1 11 | 12 | jobs: 13 | pre-commit: 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 90 16 | strategy: 17 | matrix: 18 | python: ["3.8"] 19 | steps: 20 | - uses: actions/checkout@v3.3.0 21 | - name: "Install dependencies" 22 | run: sudo apt-get -y install graphviz 23 | - name: Cache python dependencies 24 | id: cache-pip 25 | uses: actions/cache@v3.2.4 26 | with: 27 | path: ~/.cache/pip 28 | key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/setup.json') }} 29 | restore-keys: pip-${{ matrix.python }}-tests- 30 | - name: Set up Python 31 | uses: actions/setup-python@v4.5.0 32 | with: 33 | python-version: ${{ matrix.python }} 34 | - name: Make sure virtualevn>20 is installed, which will yield newer pip and possibility to pin pip version. 
35 | run: pip install "virtualenv>20" 36 | - name: Install Tox 37 | run: pip install tox 38 | - name: Run pre-commit in Tox 39 | run: tox -e pre-commit 40 | 41 | test: 42 | runs-on: "ubuntu-latest" 43 | strategy: 44 | matrix: 45 | python: ["3.8", "3.9", "3.10"] 46 | steps: 47 | - uses: actions/checkout@v3.3.0 48 | - name: "Install dependencies" 49 | run: sudo apt-get -y install graphviz 50 | - uses: actions/checkout@v3.3.0 51 | - name: Cache python dependencies 52 | id: cache-pip 53 | uses: actions/cache@v3.2.4 54 | with: 55 | path: ~/.cache/pip 56 | key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/pyproject.toml') }} 57 | restore-keys: pip-${{ matrix.python }}-tests- 58 | - name: Set up Python 59 | uses: actions/setup-python@v4.5.0 60 | with: 61 | python-version: ${{ matrix.python }} 62 | - name: Make sure virtualenv>20 is installed, which will yield newer pip and possibility to pin pip version. 63 | run: pip install "virtualenv>20" 64 | - name: Install Tox 65 | run: pip install tox 66 | - name: "Test with tox" 67 | run: | 68 | tox -e py${{ matrix.python }} -- tests/ --cov=./sqlalchemy_schemadisplay --cov-append --cov-report=xml --cov-report=term-missing 69 | 70 | release: 71 | if: "github.event_name == 'push' && startsWith(github.ref, 'refs/tags')" 72 | runs-on: "ubuntu-latest" 73 | needs: "test" 74 | steps: 75 | - uses: "actions/checkout@v3" 76 | 77 | - name: "Install dependencies" 78 | run: | 79 | python -m pip install --upgrade pip 80 | pip install build 81 | 82 | - name: "Build" 83 | run: | 84 | python -m build 85 | git status --ignored 86 | 87 | - name: "Publish" 88 | uses: "pypa/gh-action-pypi-publish@release/v1" 89 | with: 90 | user: "__token__" 91 | password: "${{ secrets.TEST_PYPI_API_TOKEN }}" 92 | repository_url: "https://test.pypi.org/legacy/" 93 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: 
-------------------------------------------------------------------------------- 1 | name: CI 2 |
on: [push, pull_request] 4 | 5 | env: 6 | FORCE_COLOR: 1 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | timeout-minutes: 90 12 | strategy: 13 | matrix: 14 | python: ["3.8"] 15 | steps: 16 | - uses: actions/checkout@v3.3.0 17 | - name: "Install dependencies" 18 | run: sudo apt-get -y install graphviz 19 | - name: Cache python dependencies 20 | id: cache-pip 21 | uses: actions/cache@v3.2.4 22 | with: 23 | path: ~/.cache/pip 24 | key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/pyproject.toml') }} 25 | restore-keys: pip-${{ matrix.python }}-tests- 26 | - name: Set up Python 27 | uses: actions/setup-python@v4.5.0 28 | with: 29 | python-version: ${{ matrix.python }} 30 | - name: Make sure virtualenv>20 is installed, which will yield newer pip and possibility to pin pip version. 31 | run: pip install "virtualenv>20" 32 | - name: Install Tox 33 | run: pip install tox 34 | - name: Run pre-commit in Tox 35 | run: tox -e pre-commit 36 | test: 37 | runs-on: "ubuntu-latest" 38 | strategy: 39 | matrix: 40 | python: ["3.8", "3.9", "3.10"] 41 | steps: 42 | - uses: actions/checkout@v3.3.0 43 | - name: "Install dependencies" 44 | run: sudo apt-get -y install graphviz 45 | - uses: actions/checkout@v3.3.0 46 | - name: Cache python dependencies 47 | id: cache-pip 48 | uses: actions/cache@v3.2.4 49 | with: 50 | path: ~/.cache/pip 51 | key: pip-${{ matrix.python }}-tests-${{ hashFiles('**/pyproject.toml') }} 52 | restore-keys: pip-${{ matrix.python }}-tests- 53 | - name: Set up Python 54 | uses: actions/setup-python@v4.5.0 55 | with: 56 | python-version: ${{ matrix.python }} 57 | - name: Make sure virtualenv>20 is installed, which will yield newer pip and possibility to pin pip version.
58 | run: pip install "virtualenv>20" 59 | - name: Install Tox 60 | run: pip install tox 61 | - name: "Test with tox" 62 | run: | 63 | tox -e py${{ matrix.python }} -- tests/ --cov=./sqlalchemy_schemadisplay --cov-append --cov-report=xml --cov-report=term-missing 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /.tox/ 3 | /.coverage 4 | /dist/ 5 | /htmlcov/ 6 | .vscode/* 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install pre-commit hooks via 2 | # pre-commit install 3 | exclude: > 4 | (?x)^(tests/) 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.1.0 9 | hooks: 10 | - id: check-json 11 | - id: check-yaml 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | 15 | - repo: https://github.com/asottile/pyupgrade 16 | rev: v2.31.1 17 | hooks: 18 | - id: pyupgrade 19 | args: [--py38-plus] 20 | 21 | - repo: https://github.com/pycqa/isort 22 | rev: 5.12.0 23 | hooks: 24 | - id: isort 25 | 26 | - repo: https://github.com/psf/black 27 | rev: 22.3.0 28 | hooks: 29 | - id: black 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is the MIT license: http://www.opensource.org/licenses/mit-license.php 2 | 3 | Copyright (c) 2010-2014 Ants Aasma and contributors . 4 | SQLAlchemy is a trademark of Michael Bayer.
5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 7 | software and associated documentation files (the "Software"), to deal in the Software 8 | without restriction, including without limitation the rights to use, copy, modify, merge, 9 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons 10 | to whom the Software is furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all copies or 13 | substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 16 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 17 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 18 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 19 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE *.ini *.rst 2 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | sqlalchemy_schemadisplay 2 | ======================== 3 | 4 | Turn SQLAlchemy DB Model into a graph. 5 | 6 | .. image:: https://img.shields.io/pypi/v/sqlalchemy_schemadisplay 7 | :alt: PyPI 8 | :target: https://pypi.org/project/sqlalchemy_schemadisplay 9 | 10 | 11 | See `SQLAlchemy wiki <https://github.com/sqlalchemy/sqlalchemy/wiki/SchemaDisplay>`_ for the previous version of this doc. 12 | 13 | Usage 14 | ===== 15 | 16 | You will need at least SQLAlchemy and pydot along with graphviz for this. Graphviz-cairo is highly recommended to get tolerable image quality.
If PIL and an image viewer are available, there are also functions to show the image automatically. Some features, specifically loading the list of tables from a database via a mapper and reflecting indexes, currently only work on PostgreSQL. 17 | 18 | This is an example of database entity diagram generation: 19 | 20 | .. code-block:: python 21 | 22 | from sqlalchemy import MetaData, create_engine 23 | from sqlalchemy_schemadisplay import create_schema_graph 24 | 25 | # create the pydot graph object by reflecting all tables through the engine 26 | graph = create_schema_graph( 27 | engine=create_engine('postgresql+psycopg2://user:pwd@host/database'), 28 | metadata=MetaData(), 29 | show_datatypes=False, # The image would get nasty big if we'd show the datatypes 30 | show_indexes=False, # ditto for indexes 31 | rankdir='LR', # From left to right (instead of top to bottom) 32 | concentrate=False # Don't try to join the relation lines together 33 | ) 34 | graph.write_png('dbschema.png') # write out the file 35 | 36 | 37 | And a UML class diagram from a model: 38 | 39 | ..
code-block:: python 40 | 41 | from myapp import model 42 | from sqlalchemy_schemadisplay import create_uml_graph 43 | from sqlalchemy.orm import class_mapper 44 | 45 | # lets find all the mappers in our model 46 | mappers = [model.__mapper__] 47 | for attr in dir(model): 48 | if attr[0] == '_': continue 49 | try: 50 | cls = getattr(model, attr) 51 | mappers.append(cls.property.entity) 52 | except: 53 | pass 54 | 55 | # pass them to the function and set some formatting options 56 | graph = create_uml_graph(mappers, 57 | show_operations=False, # not necessary in this case 58 | show_multiplicity_one=False # some people like to see the ones, some don't 59 | ) 60 | graph.write_png('schema.png') # write out the file 61 | 62 | 63 | Changelog 64 | ========= 65 | 66 | 2.0 - 2024-02-15 67 | ---------------- 68 | 69 | - Compatibility with SQLAlchemy 2.x [Jonathan Chico] 70 | 71 | - Python requirements changed to >= 3.8 [zlopez] 72 | 73 | 1.4 - 2024-02-15 74 | ---------------- 75 | 76 | Last release to support Python 2. 77 | 78 | - Limit SQLAlchemy dependency to < 2.0 to fix installation for Python 2 [abitrolly - Anatoli Babenia] 79 | 80 | - Set dir kwarg in Edge instantiation to 'both' in order to show arrow heads and arrow tails. 81 | [bkrn - Aaron Di Silvestro] 82 | 83 | - Add 'show_column_keys' kwarg to 'create_schema_graph' to allow a PK/FK suffix to be added to columns that are primary keys/foreign keys respectively [cchrysostomou - Constantine Chrysostomou] 84 | 85 | - Add 'restrict_tables' kwarg to 'create_schema_graph' to restrict the desired tables we want to generate via graphviz and show relationships for [cchrysostomou - Constantine Chrysostomou] 86 | 87 | 88 | 1.3 - 2016-01-27 89 | ---------------- 90 | 91 | - Fix warning about illegal attribute in uml generation by using correct 92 | attribute. 
93 | [electrocucaracha - Victor Morales] 94 | 95 | - Use MIT license 96 | [fschulze] 97 | 98 | 99 | 1.2 - 2014-03-02 100 | ---------------- 101 | 102 | - Compatibility with SQLAlchemy 0.9. 103 | [fschulze] 104 | 105 | - Compatibility with SQLAlchemy 0.8. 106 | [Surgo - Kosei Kitahara] 107 | 108 | - Leave tables out even when a foreign key points to them but they are not in 109 | the table list. 110 | [tiagosab - Tiago Saboga] 111 | 112 | 113 | 1.1 - 2011-10-12 114 | ---------------- 115 | 116 | - New option to skip inherited attributes. 117 | [nouri] 118 | 119 | - Quote class name, because some names like 'Node' confuse dot. 120 | [nouri - Daniel Nouri] 121 | 122 | 123 | 1.0 - 2011-01-07 124 | ---------------- 125 | 126 | - Initial release 127 | [fschulze - Florian Schulze] 128 | 129 | - Originally released as a recipe on the SQLAlchemy Wiki by Ants Aasma 130 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.4,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "sqlalchemy_schemadisplay" 7 | dynamic = ["version", "description"] 8 | authors = [{name = "Florian Schulze", email = "florian.schulze@gmx.net"}] 9 | readme = "README.rst" 10 | license = {file = "LICENSE"} 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "Development Status :: 5 - Production/Stable", 14 | "Intended Audience :: Developers", 15 | "License :: OSI Approved :: MIT License", 16 | "Programming Language :: Python", 17 | "Programming Language :: Python :: Implementation :: CPython", 18 | "Topic :: Database :: Front-Ends", 19 | "Operating System :: OS Independent" 20 | ] 21 | keywords = ["sqlalchemy", "graphviz", "pydot", "diagram", "schema", "uml"] 22 | requires-python = ">=3.8" 23 | dependencies = [ 24 | 'setuptools', 25 | 'sqlalchemy>=2.0,<3', 26 | 'pydot', 27 | 'Pillow' 28 | ] 29 | 30 | 
[project.urls] 31 | Documentation
= "https://github.com/fschulze/sqlalchemy_schemadisplay/blob/master/README.rst" 32 | Source = "https://github.com/fschulze/sqlalchemy_schemadisplay" 33 | 34 | [project.optional-dependencies] 35 | testing = [ 36 | "attrs>=17.4.0", 37 | "pgtest", 38 | "pytest", 39 | "pytest-cov", 40 | "coverage", 41 | "pytest-timeout", 42 | "pytest-regressions" 43 | ] 44 | 45 | pre-commit = [ 46 | "pre-commit", 47 | "tox>=3.23.0", 48 | "virtualenv>20" 49 | ] 50 | 51 | [tool.flit.module] 52 | name = "sqlalchemy_schemadisplay" 53 | 54 | [tool.flit.sdist] 55 | exclude = [ 56 | "docs/", 57 | "tests/", 58 | ] 59 | 60 | [tool.coverage.run] 61 | # Configuration of [coverage.py](https://coverage.readthedocs.io) 62 | # reporting which lines of this package are covered by tests 63 | source=["sqlalchemy_schemadisplay"] 64 | 65 | [tool.isort] 66 | skip = ["venv"] 67 | # Force imports to be sorted by module, independent of import type 68 | force_sort_within_sections = true 69 | # Group first party and local folder imports together 70 | no_lines_before = ["LOCALFOLDER"] 71 | 72 | # Configure isort to work without access to site-packages 73 | known_first_party = ["sqlalchemy_schemadisplay"] 74 | 75 | [tool.tox] 76 | legacy_tox_ini = """ 77 | [tox] 78 | envlist = pre-commit,py{3.8,3.9,3.10} 79 | requires = virtualenv >= 20 80 | isolated_build = True 81 | 82 | [testenv] 83 | commands = pytest {posargs} 84 | extras = testing 85 | 86 | [testenv:pre-commit] 87 | allowlist_externals = bash 88 | commands = bash -ec 'pre-commit run --all-files || ( git diff; git status; exit 1; )' 89 | extras = 90 | pre-commit 91 | testing 92 | 93 | [flake8] 94 | max-line-length = 140 95 | import-order-style = edited 96 | [pycodestyle] 97 | max-line-length = 140 98 | """ 99 | -------------------------------------------------------------------------------- /sqlalchemy_schemadisplay/__init__.py: -------------------------------------------------------------------------------- 1 | """Package for the generation of diagrams based
on SQLAlchemy ORM models and or the database itself""" 2 | from .db_diagram import create_schema_graph 3 | from .model_diagram import create_uml_graph 4 | from .utils import show_schema_graph, show_uml_graph 5 | 6 | __version__ = "2.0" 7 | 8 | __all__ = ( 9 | "create_schema_graph", 10 | "create_uml_graph", 11 | "show_schema_graph", 12 | "show_uml_graph", 13 | ) 14 | -------------------------------------------------------------------------------- /sqlalchemy_schemadisplay/db_diagram.py: -------------------------------------------------------------------------------- 1 | """Set of functions to generate the diagram of the actual database""" 2 | from typing import List, Union 3 | 4 | import pydot 5 | from sqlalchemy import Column, ForeignKeyConstraint, MetaData, Table, text 6 | from sqlalchemy.dialects.postgresql.base import PGDialect 7 | from sqlalchemy.engine import Engine 8 | 9 | 10 | def _render_table_html( 11 | table: Table, 12 | engine: Engine, 13 | show_indexes: bool, 14 | show_datatypes: bool, 15 | show_column_keys: bool, 16 | show_schema_name: bool, 17 | format_schema_name: dict, 18 | format_table_name: dict, 19 | ) -> str: 20 | # pylint: disable=too-many-locals,too-many-arguments 21 | """Create a rendering of a table in the database 22 | 23 | Args: 24 | table (sqlalchemy.Table): SqlAlchemy table which is going to be rendered. 25 | engine (sqlalchemy.engine.Engine): SqlAlchemy database engine to connect to the database. 26 | show_indexes (bool): Whether to display the index column in the table 27 | show_datatypes (bool): Whether to display the type of the columns in the table 28 | show_column_keys (bool): If true then add a PK/FK suffix to columns names that \ 29 | are primary and foreign keys 30 | show_schema_name (bool): If true, then prepend '.' to the table \ 31 | name resulting in '.' 
32 | format_schema_name (dict): If provided, allowed keys include: \ 33 | 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' \ 34 | as bools 35 | format_table_name (dict): If provided, allowed keys include: \ 36 | 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' as \ 37 | bools 38 | 39 | Returns: 40 | str: html string with the rendering of the table 41 | """ 42 | # add in (PK) OR (FK) suffixes to column names that are considered to be primary key or 43 | # foreign key 44 | use_column_key_attr = hasattr(ForeignKeyConstraint, "column_keys") 45 | # sqlalchemy > 1.0 uses column_keys to return list of strings for foreign keys,previously 46 | # was columns 47 | if show_column_keys: 48 | if use_column_key_attr: 49 | # sqlalchemy > 1.0 50 | fk_col_names = { 51 | h for f in table.foreign_key_constraints for h in f.columns.keys() 52 | } 53 | else: 54 | # sqlalchemy pre 1.0? 55 | fk_col_names = { 56 | h.name for f in table.foreign_keys for h in f.constraint.columns 57 | } 58 | # fk_col_names = set([h for f in table.foreign_key_constraints for h in f.columns.keys()]) 59 | pk_col_names = set(list(table.primary_key.columns.keys())) 60 | else: 61 | fk_col_names = set() 62 | pk_col_names = set() 63 | 64 | def format_col_type(col: Column) -> str: 65 | """Get the type of the column as a string 66 | 67 | Args: 68 | col (Column): SqlAlchemy column of the table that is being rendered 69 | 70 | Returns: 71 | str: column type 72 | """ 73 | try: 74 | return col.type.get_col_spec() 75 | except (AttributeError, NotImplementedError): 76 | return str(col.type) 77 | 78 | def format_col_str(col: Column) -> str: 79 | """Generate the column name so that it takes into account any possible suffix. 
80 | 81 | Args: 82 | col (sqlalchemy.Column): SqlAlchemy column of the table that is being rendered 83 | 84 | Returns: 85 | str: name of the column with the appropriate suffixes 86 | """ 87 | # add in (PK) OR (FK) suffixes to column names that are considered to be primary key 88 | # or foreign key 89 | suffix = ( 90 | "(FK)" 91 | if col.name in fk_col_names 92 | else "(PK)" 93 | if col.name in pk_col_names 94 | else "" 95 | ) 96 | if show_datatypes: 97 | return f"- {col.name + suffix} : {format_col_type(col)}" 98 | return f"- {col.name + suffix}" 99 | 100 | def format_name(obj_name: str, format_dict: Union[dict, None]) -> str: 101 | """Format the name of the object so that it is rendered differently. 102 | 103 | Args: 104 | obj_name (str): name of the object being rendered 105 | format_dict (Union[dict,None]): dictionary with the rendering options. \ 106 | If None nothing is done 107 | 108 | Returns: 109 | str: formatted name of the object 110 | """ 111 | # Check if format_dict was provided 112 | if format_dict is not None: 113 | # Should color be checked? 
114 | # Could use /^#([A-Fa-f0-9]{8}|[A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$/ 115 | _color = format_dict.get("color") if "color" in format_dict else "initial" 116 | _point_size = ( 117 | float(format_dict["fontsize"]) 118 | if "fontsize" in format_dict 119 | else "initial" 120 | ) 121 | _bold = "" if format_dict.get("bold") else "" 122 | _italic = "" if format_dict.get("italics") else "" 123 | _text = f'' 125 | _text += f'{_bold}{_italic}{obj_name}{"" if format_dict.get("italics") else ""}' 126 | _text += f'{"" if format_dict.get("bold") else ""}' 127 | 128 | return _text 129 | return obj_name 130 | 131 | schema_str = "" 132 | if show_schema_name and hasattr(table, "schema") and table.schema is not None: 133 | # Build string for schema name, empty if show_schema_name is False 134 | schema_str = format_name(table.schema, format_schema_name) 135 | table_str = format_name(table.name, format_table_name) 136 | 137 | # Assemble table header 138 | html = '<
' 140 | html += '' 141 | 142 | html += "".join( 143 | f'' 144 | for col in table.columns 145 | ) 146 | if isinstance(engine, Engine) and isinstance(engine.engine.dialect, PGDialect): 147 | # postgres engine doesn't reflect indexes 148 | with engine.connect() as connection: 149 | indexes = { 150 | key: value 151 | for key, value in connection.execute( 152 | text( 153 | f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table.name}'" 154 | ) 155 | ) 156 | } 157 | if indexes and show_indexes: 158 | html += '' 159 | for value in indexes.values(): 160 | i_label = "UNIQUE " if "UNIQUE" in value else "INDEX " 161 | i_label += value[value.index("(") :] 162 | html += f'' 163 | html += "
' 139 | html += f'{schema_str}{"." if show_schema_name else ""}{table_str}
{format_col_str(col)}
{i_label}
>" 164 | return html 165 | 166 | 167 | def create_schema_graph( 168 | engine: Engine, 169 | tables: List[Table] = None, 170 | metadata: MetaData = None, 171 | show_indexes: bool = True, 172 | show_datatypes: bool = True, 173 | font: str = "Bitstream-Vera Sans", 174 | concentrate: bool = True, 175 | relation_options: Union[dict, None] = None, 176 | rankdir: str = "TB", 177 | show_column_keys: bool = False, 178 | restrict_tables: Union[List[str], None] = None, 179 | show_schema_name: bool = False, 180 | format_schema_name: Union[dict, None] = None, 181 | format_table_name: Union[dict, None] = None, 182 | ) -> pydot.Dot: 183 | # pylint: disable=too-many-locals,too-many-arguments 184 | """Create a diagram for the database schema. 185 | 186 | Args: 187 | engine (sqlalchemy.engine.Engine): SqlAlchemy database engine to connect to the database. 188 | tables (List[sqlalchemy.Table], optional): SqlAlchemy database tables. Defaults to None. 189 | metadata (sqlalchemy.MetaData, optional): SqlAlchemy `MetaData` with reference to related \ 190 | tables. Defaults to None. 191 | show_indexes (bool, optional): Whether to display the index column in the table. \ 192 | Defaults to True. 193 | show_datatypes (bool, optional): Whether to display the type of the columns in the table. \ 194 | Defaults to True. 195 | font (str, optional): font to be used in the diagram. Defaults to "Bitstream-Vera Sans". 196 | concentrate (bool, optional): Specifies if multi-edges should be merged into a single edge \ 197 | & partially parallel edges to share overlapping path. Passed to `pydot.Dot` object. \ 198 | Defaults to True. 199 | relation_options (Union[dict, None], optional): kwargs passed to pydot.Edge init. \ 200 | Most attributes in pydot.EDGE_ATTRIBUTES are viable options. A few values are set \ 201 | programmatically. Defaults to None. 202 | rankdir (str, optional): Sets direction of graph layout. Passed to `pydot.Dot` object. 
\ 203 | Options are 'TB' (top to bottom), 'BT' (bottom to top), 'LR' (left to right), \ 204 | 'RL' (right to left). Defaults to 'TB'. 205 | show_column_keys (bool, optional): If true then add a PK/FK suffix to columns names that \ 206 | are primary and foreign keys. Defaults to False. 207 | restrict_tables (Union[List[str], optional): Restrict the graph to only consider tables \ 208 | whose name are defined `restrict_tables`. Defaults to None. 209 | show_schema_name (bool, optional): If true, then prepend '.' to the table \ 210 | name resulting in '.'. Defaults to False. 211 | format_schema_name (Union[dict, None], optional): If provided, allowed keys include: \ 212 | 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' \ 213 | as bools. Defaults to None. 214 | format_table_name (Union[dict, None], optional): If provided, allowed keys include: \ 215 | 'color' (hex color code incl #), 'fontsize' as a float, and 'bold' and 'italics' as \ 216 | bools. Defaults to None. 
217 | 218 | Raises: 219 | ValueError: One needs to specify either the metadata or the tables 220 | KeyError: raised when unexpected keys are given to `format_schema_name` or \ 221 | `format_table_name` 222 | Returns: 223 | pydot.Dot: pydot object with the schema of the database 224 | """ 225 | 226 | if not relation_options: 227 | relation_options = {} 228 | 229 | relation_kwargs = {"fontsize": "7.0", "dir": "both"} 230 | relation_kwargs.update(relation_options) 231 | 232 | if not metadata and not tables: 233 | raise ValueError("You need to specify at least tables or metadata") 234 | 235 | if metadata and not tables: 236 | metadata.reflect(bind=engine) 237 | tables = metadata.tables.values() 238 | 239 | _accepted_keys = {"color", "fontsize", "italics", "bold"} 240 | 241 | # check if unexpected keys were used in format_schema_name param 242 | if ( 243 | format_schema_name is not None 244 | and len(set(format_schema_name.keys()).difference(_accepted_keys)) > 0 245 | ): 246 | raise KeyError( 247 | "Unrecognized keys were used in dict provided for `format_schema_name` parameter" 248 | ) 249 | # check if unexpected keys were used in format_table_name param 250 | if ( 251 | format_table_name is not None 252 | and len(set(format_table_name.keys()).difference(_accepted_keys)) > 0 253 | ): 254 | raise KeyError( 255 | "Unrecognized keys were used in dict provided for `format_table_name` parameter" 256 | ) 257 | 258 | graph = pydot.Dot( 259 | prog="dot", 260 | mode="ipsep", 261 | overlap="ipsep", 262 | sep="0.01", 263 | concentrate=str(concentrate), 264 | rankdir=rankdir, 265 | ) 266 | if restrict_tables is None: 267 | restrict_tables = {t.name.lower() for t in tables} 268 | else: 269 | restrict_tables = {t.lower() for t in restrict_tables} 270 | tables = [t for t in tables if t.name.lower() in restrict_tables] 271 | for table in tables: 272 | 273 | graph.add_node( 274 | pydot.Node( 275 | str(table.name), 276 | shape="plaintext", 277 | label=_render_table_html( 278 | 
table=table, 279 | engine=engine, 280 | show_indexes=show_indexes, 281 | show_datatypes=show_datatypes, 282 | show_column_keys=show_column_keys, 283 | show_schema_name=show_schema_name, 284 | format_schema_name=format_schema_name, 285 | format_table_name=format_table_name, 286 | ), 287 | fontname=font, 288 | fontsize="7.0", 289 | ), 290 | ) 291 | 292 | for table in tables: 293 | for fk in table.foreign_keys: 294 | if fk.column.table not in tables: 295 | continue 296 | edge = [table.name, fk.column.table.name] 297 | is_inheritance = fk.parent.primary_key and fk.column.primary_key 298 | if is_inheritance: 299 | edge = edge[::-1] 300 | graph_edge = pydot.Edge( 301 | headlabel="+ %s" % fk.column.name, 302 | taillabel="+ %s" % fk.parent.name, 303 | arrowhead=is_inheritance and "none" or "odot", 304 | arrowtail=(fk.parent.primary_key or fk.parent.unique) 305 | and "empty" 306 | or "crow", 307 | fontname=font, 308 | # samehead=fk.column.name, sametail=fk.parent.name, 309 | *edge, 310 | **relation_kwargs, 311 | ) 312 | graph.add_edge(graph_edge) 313 | 314 | # not sure what this part is for, doesn't work with pydot 1.0.2 315 | # graph_edge.parent_graph = graph.parent_graph 316 | # if table.name not in [e.get_source() for e in graph.get_edge_list()]: 317 | # graph.edge_src_list.append(table.name) 318 | # if fk.column.table.name not in graph.edge_dst_list: 319 | # graph.edge_dst_list.append(fk.column.table.name) 320 | # graph.sorted_graph_elements.append(graph_edge) 321 | return graph 322 | -------------------------------------------------------------------------------- /sqlalchemy_schemadisplay/model_diagram.py: -------------------------------------------------------------------------------- 1 | """ 2 | Set of functions to generate the diagram related to the ORM models 3 | """ 4 | import types 5 | from typing import List 6 | 7 | import pydot 8 | from sqlalchemy.orm import Mapper, Relationship 9 | from sqlalchemy.orm.properties import RelationshipProperty 10 | 11 | 12 | def 
_mk_label( 13 | mapper: Mapper, 14 | show_operations: bool, 15 | show_attributes: bool, 16 | show_datatypes: bool, 17 | show_inherited: bool, 18 | bordersize: float, 19 | ) -> str: 20 | # pylint: disable=too-many-arguments 21 | """Generate the rendering of a given orm model. 22 | 23 | Args: 24 | mapper (sqlalchemy.orm.Mapper): mapper for the SqlAlchemy orm class. 25 | show_operations (bool): whether to show functions defined in the orm. 26 | show_attributes (bool): whether to show the attributes of the class. 27 | show_datatypes (bool): Whether to display the type of the columns in the model. 28 | show_inherited (bool): whether to show inherited columns. 29 | bordersize (float): thickness of the border lines in the diagram 30 | 31 | Returns: 32 | str: html string to render the orm model 33 | """ 34 | html = ( 35 | f'<
' 38 | 39 | def format_col(col): 40 | colstr = f"+{col.name}" 41 | if show_datatypes: 42 | colstr += f" : {col.type.__class__.__name__}" 43 | return colstr 44 | 45 | if show_attributes: 46 | if not show_inherited: 47 | cols = [c for c in mapper.columns if c.table == mapper.tables[0]] 48 | else: 49 | cols = mapper.columns 50 | html += '' % '
'.join( 51 | format_col(col) for col in cols 52 | ) 53 | else: 54 | _ = [ 55 | format_col(col) 56 | for col in sorted(mapper.columns, key=lambda col: not col.primary_key) 57 | ] 58 | if show_operations: 59 | html += '' % '
'.join( 60 | "%s(%s)" 61 | % ( 62 | name, 63 | ", ".join( 64 | default is _mk_label and (f"{arg}") or (f"{arg}={repr(default)}") 65 | for default, arg in zip( 66 | ( 67 | func.func_defaults 68 | and len(func.func_code.co_varnames) 69 | - 1 70 | - (len(func.func_defaults) or 0) 71 | or func.func_code.co_argcount - 1 72 | ) 73 | * [_mk_label] 74 | + list(func.func_defaults or []), 75 | func.func_code.co_varnames[1:], 76 | ) 77 | ), 78 | ) 79 | for name, func in mapper.class_.__dict__.items() 80 | if isinstance(func, types.FunctionType) 81 | and func.__module__ == mapper.class_.__module__ 82 | ) 83 | html += "
{mapper.class_.__name__}
%s
%s
>" 84 | return html 85 | 86 | 87 | def escape(name: str) -> str: 88 | """Set the name of the object between quotations to avoid reading errors 89 | 90 | Args: 91 | name (str): name of the object 92 | 93 | Returns: 94 | str: name of the object between quotations to avoid reading errors 95 | """ 96 | return f'"{name}"' 97 | 98 | 99 | def create_uml_graph( 100 | mappers: List[Mapper], 101 | show_operations: bool = True, 102 | show_attributes: bool = True, 103 | show_inherited: bool = True, 104 | show_multiplicity_one: bool = False, 105 | show_datatypes: bool = True, 106 | linewidth: float = 1.0, 107 | font: str = "Bitstream-Vera Sans", 108 | ) -> pydot.Dot: 109 | # pylint: disable=too-many-locals,too-many-arguments 110 | """Create rendering of the orm models associated with the database 111 | 112 | Args: 113 | mappers (List[sqlalchemy.orm.Mapper]): SqlAlchemy list of mappers of the orm classes. 114 | show_operations (bool, optional): whether to show functions defined in the orm. \ 115 | Defaults to True. 116 | show_attributes (bool, optional): whether to show the attributes of the class. \ 117 | Defaults to True. 118 | show_inherited (bool, optional): whether to show inherited columns. Defaults to True. 119 | show_multiplicity_one (bool, optional): whether to show the multiplicity as a float or \ 120 | integer. Defaults to False. 121 | show_datatypes (bool, optional): Whether to display the type of the columns in the model. \ 122 | Defaults to True. 123 | linewidth (float, optional): thickness of the lines in the diagram. Defaults to 1.0. 124 | font (str, optional): type of fond to be used for the diagram. \ 125 | Defaults to "Bitstream-Vera Sans". 126 | 127 | Returns: 128 | pydot.Dot: pydot object with the diagram for the orm models. 
129 | """ 130 | graph = pydot.Dot( 131 | prog="neato", 132 | mode="major", 133 | overlap="0", 134 | sep="0.01", 135 | dim="3", 136 | pack="True", 137 | ratio=".75", 138 | ) 139 | relations = set() 140 | for mapper in mappers: 141 | graph.add_node( 142 | pydot.Node( 143 | escape(mapper.class_.__name__), 144 | shape="plaintext", 145 | label=_mk_label( 146 | mapper, 147 | show_operations, 148 | show_attributes, 149 | show_datatypes, 150 | show_inherited, 151 | linewidth, 152 | ), 153 | fontname=font, 154 | fontsize="8.0", 155 | ) 156 | ) 157 | if mapper.inherits: 158 | graph.add_edge( 159 | pydot.Edge( 160 | escape(mapper.inherits.class_.__name__), 161 | escape(mapper.class_.__name__), 162 | arrowhead="none", 163 | arrowtail="empty", 164 | style="setlinewidth(%s)" % linewidth, 165 | arrowsize=str(linewidth), 166 | ), 167 | ) 168 | for loader in mapper.iterate_properties: 169 | if isinstance(loader, RelationshipProperty) and loader.mapper in mappers: 170 | if hasattr(loader, "reverse_property"): 171 | relations.add(frozenset([loader, loader.reverse_property])) 172 | else: 173 | relations.add(frozenset([loader])) 174 | 175 | def multiplicity_indicator(prop: Relationship) -> str: 176 | """Indicate the multiplicity of a given relationship 177 | 178 | Args: 179 | prop (sqlalchemy.orm.Relationship): relationship associated with this model 180 | 181 | Returns: 182 | str: string indicating the multiplicity of the relationship 183 | """ 184 | if prop.uselist: 185 | return " *" 186 | if hasattr(prop, "local_side"): 187 | cols = prop.local_side 188 | else: 189 | cols = prop.local_columns 190 | if any(col.nullable for col in cols): 191 | return " 0..1" 192 | if show_multiplicity_one: 193 | return " 1" 194 | return "" 195 | 196 | def calc_label(src: Relationship) -> str: 197 | """Generate the label for a given relationship 198 | 199 | Args: 200 | src (Relationship): relationship associated with this model 201 | 202 | Returns: 203 | str: relationship label 204 | """ 205 | return 
"+" + src.key + multiplicity_indicator(src) 206 | 207 | for relation in relations: 208 | # if len(loaders) > 2: 209 | # raise Exception("Warning: too many loaders for join %s" % join) 210 | args = {} 211 | 212 | if len(relation) == 2: 213 | src, dest = relation 214 | from_name = escape(src.parent.class_.__name__) 215 | to_name = escape(dest.parent.class_.__name__) 216 | 217 | args["headlabel"] = calc_label(src) 218 | 219 | args["taillabel"] = calc_label(dest) 220 | args["arrowtail"] = "none" 221 | args["arrowhead"] = "none" 222 | args["constraint"] = False 223 | else: 224 | (prop,) = relation 225 | from_name = escape(prop.parent.class_.__name__) 226 | to_name = escape(prop.mapper.class_.__name__) 227 | args["headlabel"] = f"+{prop.key}{multiplicity_indicator(prop)}" 228 | args["arrowtail"] = "none" 229 | args["arrowhead"] = "vee" 230 | 231 | graph.add_edge( 232 | pydot.Edge( 233 | from_name, 234 | to_name, 235 | fontname=font, 236 | fontsize="7.0", 237 | style=f"setlinewidth({linewidth})", 238 | arrowsize=str(linewidth), 239 | **args, 240 | ), 241 | ) 242 | 243 | return graph 244 | -------------------------------------------------------------------------------- /sqlalchemy_schemadisplay/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Set of functions to display the database diagrams generated. 3 | """ 4 | from io import StringIO 5 | 6 | from PIL import Image 7 | 8 | from sqlalchemy_schemadisplay import create_schema_graph, create_uml_graph 9 | 10 | 11 | def show_uml_graph(*args, **kwargs): 12 | """ 13 | Show the SQLAlchemy ORM diagram generated. 
14 | """ 15 | iostream = StringIO( 16 | create_uml_graph(*args, **kwargs).create_png() 17 | ) # pylint: disable=no-member 18 | Image.open(iostream).show(command=kwargs.get("command", "gwenview")) 19 | 20 | 21 | def show_schema_graph(*args, **kwargs): 22 | """ 23 | Show the database diagram generated 24 | """ 25 | iostream = StringIO( 26 | create_schema_graph(*args, **kwargs).create_png() 27 | ) # pylint: disable=no-member 28 | Image.open(iostream).show(command=kwargs.get("command", "gwenview")) 29 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fschulze/sqlalchemy_schemadisplay/5c17079bd9ad7a9a6832c01a1dd7a8c00a6613f8/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_schema_graph.py: -------------------------------------------------------------------------------- 1 | """Set of tests for database diagrams""" 2 | import pydot 3 | import pytest 4 | from sqlalchemy import Column, ForeignKey, MetaData, Table, create_engine, types 5 | 6 | import sqlalchemy_schemadisplay 7 | from .utils import parse_graph 8 | 9 | 10 | @pytest.fixture 11 | def metadata(request): 12 | engine = create_engine("sqlite:///:memory:") 13 | _metadata = MetaData() 14 | _metadata.reflect(engine) 15 | return _metadata 16 | 17 | 18 | @pytest.fixture 19 | def engine(): 20 | return create_engine("sqlite:///:memory:") 21 | 22 | 23 | def plain_result(**kw): 24 | if "metadata" in kw: 25 | kw["metadata"].create_all(kw['engine']) 26 | elif "tables" in kw: 27 | if len(kw["tables"]): 28 | kw["tables"][0].metadata.create_all(kw['engine']) 29 | return parse_graph(sqlalchemy_schemadisplay.create_schema_graph(**kw)) 30 | 31 | 32 | def test_no_args(engine): 33 | with pytest.raises(ValueError) as e: 34 | sqlalchemy_schemadisplay.create_schema_graph(engine=engine) 35 | assert 
e.value.args[0] == "You need to specify at least tables or metadata" 36 | 37 | 38 | def test_empty_db(metadata, engine): 39 | graph = sqlalchemy_schemadisplay.create_schema_graph(engine=engine, 40 | metadata=metadata) 41 | assert isinstance(graph, pydot.Graph) 42 | assert graph.create_plain() == b"graph 1 0 0\nstop\n" 43 | 44 | 45 | def test_empty_table(metadata, engine): 46 | foo = Table("foo", metadata, Column("id", types.Integer, primary_key=True)) 47 | result = plain_result(engine=engine, metadata=metadata) 48 | assert list(result.keys()) == ["1"] 49 | assert list(result["1"]["nodes"].keys()) == ["foo"] 50 | assert "- id : INTEGER" in result["1"]["nodes"]["foo"] 51 | 52 | 53 | def test_empty_table_with_key_suffix(metadata, engine): 54 | foo = Table("foo", metadata, Column("id", types.Integer, primary_key=True)) 55 | result = plain_result( 56 | engine=engine, 57 | metadata=metadata, 58 | show_column_keys=True, 59 | ) 60 | print(result) 61 | assert list(result.keys()) == ["1"] 62 | assert list(result["1"]["nodes"].keys()) == ["foo"] 63 | assert "- id(PK) : INTEGER" in result["1"]["nodes"]["foo"] 64 | 65 | 66 | def test_foreign_key(metadata, engine): 67 | foo = Table( 68 | "foo", 69 | metadata, 70 | Column("id", types.Integer, primary_key=True), 71 | ) 72 | bar = Table( 73 | "bar", 74 | metadata, 75 | Column("foo_id", types.Integer, ForeignKey(foo.c.id)), 76 | ) 77 | result = plain_result(engine=engine, metadata=metadata) 78 | assert list(result.keys()) == ["1"] 79 | assert sorted(result["1"]["nodes"].keys()) == ["bar", "foo"] 80 | assert "- id : INTEGER" in result["1"]["nodes"]["foo"] 81 | assert "- foo_id : INTEGER" in result["1"]["nodes"]["bar"] 82 | assert "edges" in result["1"] 83 | assert ("bar", "foo") in result["1"]["edges"] 84 | 85 | 86 | def test_foreign_key_with_key_suffix(metadata, engine): 87 | foo = Table( 88 | "foo", 89 | metadata, 90 | Column("id", types.Integer, primary_key=True), 91 | ) 92 | bar = Table( 93 | "bar", 94 | metadata, 95 | 
Column("foo_id", types.Integer, ForeignKey(foo.c.id)), 96 | ) 97 | result = plain_result(engine=engine, 98 | metadata=metadata, 99 | show_column_keys=True) 100 | assert list(result.keys()) == ["1"] 101 | assert sorted(result["1"]["nodes"].keys()) == ["bar", "foo"] 102 | assert "- id(PK) : INTEGER" in result["1"]["nodes"]["foo"] 103 | assert "- foo_id(FK) : INTEGER" in result["1"]["nodes"]["bar"] 104 | assert "edges" in result["1"] 105 | assert ("bar", "foo") in result["1"]["edges"] 106 | 107 | 108 | def test_table_filtering(engine, metadata): 109 | foo = Table( 110 | "foo", 111 | metadata, 112 | Column("id", types.Integer, primary_key=True), 113 | ) 114 | bar = Table( 115 | "bar", 116 | metadata, 117 | Column("foo_id", types.Integer, ForeignKey(foo.c.id)), 118 | ) 119 | result = plain_result(engine=engine, tables=[bar]) 120 | assert list(result.keys()) == ["1"] 121 | assert list(result["1"]["nodes"].keys()) == ["bar"] 122 | assert "- foo_id : INTEGER" in result["1"]["nodes"]["bar"] 123 | 124 | 125 | def test_table_rendering_without_schema(metadata, engine): 126 | foo = Table( 127 | "foo", 128 | metadata, 129 | Column("id", types.Integer, primary_key=True), 130 | ) 131 | bar = Table( 132 | "bar", 133 | metadata, 134 | Column("foo_id", types.Integer, ForeignKey(foo.c.id)), 135 | ) 136 | 137 | try: 138 | sqlalchemy_schemadisplay.create_schema_graph( 139 | engine=engine, metadata=metadata).create_png() 140 | except Exception as ex: 141 | assert ( 142 | False 143 | ), f"An exception of type {ex.__class__.__name__} was produced when attempting to render a png of the graph" 144 | 145 | 146 | def test_table_rendering_with_schema(metadata, engine): 147 | foo = Table("foo", 148 | metadata, 149 | Column("id", types.Integer, primary_key=True), 150 | schema="sch_foo") 151 | bar = Table( 152 | "bar", 153 | metadata, 154 | Column("foo_id", types.Integer, ForeignKey(foo.c.id)), 155 | schema="sch_bar", 156 | ) 157 | 158 | try: 159 | sqlalchemy_schemadisplay.create_schema_graph( 160 
| engine=engine, 161 | metadata=metadata, 162 | show_schema_name=True, 163 | ).create_png() 164 | except Exception as ex: 165 | assert ( 166 | False 167 | ), f"An exception of type {ex.__class__.__name__} was produced when attempting to render a png of the graph" 168 | 169 | 170 | def test_table_rendering_with_schema_and_formatting(metadata, engine): 171 | foo = Table("foo", 172 | metadata, 173 | Column("id", types.Integer, primary_key=True), 174 | schema="sch_foo") 175 | bar = Table( 176 | "bar", 177 | metadata, 178 | Column("foo_id", types.Integer, ForeignKey(foo.c.id)), 179 | schema="sch_bar", 180 | ) 181 | 182 | try: 183 | sqlalchemy_schemadisplay.create_schema_graph( 184 | engine=engine, 185 | metadata=metadata, 186 | show_schema_name=True, 187 | format_schema_name={ 188 | "fontsize": 8.0, 189 | "color": "#888888" 190 | }, 191 | format_table_name={ 192 | "bold": True, 193 | "fontsize": 10.0 194 | }, 195 | ).create_png() 196 | except Exception as ex: 197 | assert ( 198 | False 199 | ), f"An exception of type {ex.__class__.__name__} was produced when attempting to render a png of the graph" 200 | -------------------------------------------------------------------------------- /tests/test_uml_graph.py: -------------------------------------------------------------------------------- 1 | """Set of tests for the ORM diagrams""" 2 | import pytest 3 | from sqlalchemy import Column, ForeignKey, MetaData, types 4 | from sqlalchemy.ext.declarative import declarative_base 5 | from sqlalchemy.orm import class_mapper, relationship 6 | 7 | import sqlalchemy_schemadisplay 8 | from .utils import parse_graph 9 | 10 | 11 | @pytest.fixture 12 | def metadata(request): 13 | return MetaData("sqlite:///:memory:") 14 | 15 | 16 | @pytest.fixture 17 | def Base(request, metadata): 18 | return declarative_base(metadata=metadata) 19 | 20 | 21 | def plain_result(mapper, **kw): 22 | return parse_graph(sqlalchemy_schemadisplay.create_uml_graph(mapper, **kw)) 23 | 24 | 25 | def mappers(*args): 
26 | return [class_mapper(x) for x in args] 27 | 28 | 29 | def test_simple_class(Base, capsys): 30 | class Foo(Base): 31 | __tablename__ = "foo" 32 | id = Column(types.Integer, primary_key=True) 33 | 34 | result = plain_result(mappers(Foo)) 35 | assert list(result.keys()) == ["1"] 36 | assert list(result["1"]["nodes"].keys()) == ["Foo"] 37 | assert "+id : Integer" in result["1"]["nodes"]["Foo"] 38 | out, err = capsys.readouterr() 39 | assert out == "" 40 | assert err == "" 41 | 42 | 43 | def test_relation(Base): 44 | class Foo(Base): 45 | __tablename__ = "foo" 46 | id = Column(types.Integer, primary_key=True) 47 | 48 | class Bar(Base): 49 | __tablename__ = "bar" 50 | id = Column(types.Integer, primary_key=True) 51 | foo_id = Column(types.Integer, ForeignKey(Foo.id)) 52 | 53 | Foo.bars = relationship(Bar) 54 | graph = sqlalchemy_schemadisplay.create_uml_graph(mappers(Foo, Bar)) 55 | assert sorted(graph.obj_dict["nodes"].keys()) == ['"Bar"', '"Foo"'] 56 | assert "+id : Integer" in graph.obj_dict["nodes"]['"Foo"'][0][ 57 | "attributes"]["label"] 58 | assert ("+foo_id : Integer" 59 | in graph.obj_dict["nodes"]['"Bar"'][0]["attributes"]["label"]) 60 | assert "edges" in graph.obj_dict 61 | assert ('"Foo"', '"Bar"') in graph.obj_dict["edges"] 62 | assert (graph.obj_dict["edges"][( 63 | '"Foo"', '"Bar"')][0]["attributes"]["headlabel"] == "+bars *") 64 | 65 | 66 | def test_backref(Base): 67 | class Foo(Base): 68 | __tablename__ = "foo" 69 | id = Column(types.Integer, primary_key=True) 70 | 71 | class Bar(Base): 72 | __tablename__ = "bar" 73 | id = Column(types.Integer, primary_key=True) 74 | foo_id = Column(types.Integer, ForeignKey(Foo.id)) 75 | 76 | Foo.bars = relationship(Bar, backref="foo") 77 | graph = sqlalchemy_schemadisplay.create_uml_graph(mappers(Foo, Bar)) 78 | assert sorted(graph.obj_dict["nodes"].keys()) == ['"Bar"', '"Foo"'] 79 | assert "+id : Integer" in graph.obj_dict["nodes"]['"Foo"'][0][ 80 | "attributes"]["label"] 81 | assert ("+foo_id : Integer" 82 | in 
graph.obj_dict["nodes"]['"Bar"'][0]["attributes"]["label"]) 83 | assert "edges" in graph.obj_dict 84 | assert ('"Foo"', '"Bar"') in graph.obj_dict["edges"] 85 | assert ('"Bar"', '"Foo"') in graph.obj_dict["edges"] 86 | assert (graph.obj_dict["edges"][( 87 | '"Foo"', '"Bar"')][0]["attributes"]["headlabel"] == "+bars *") 88 | assert (graph.obj_dict["edges"][( 89 | '"Bar"', '"Foo"')][0]["attributes"]["headlabel"] == "+foo 0..1") 90 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | 3 | 4 | def parse_graph(graph): 5 | result = {} 6 | graph_bytes = graph.create_plain() 7 | sio = StringIO(graph_bytes.decode("utf-8")) 8 | graph = None 9 | for line in sio: 10 | line = line.strip() 11 | if not line: 12 | continue 13 | if line.startswith("graph"): 14 | parts = line.split(None, 4) 15 | graph = result.setdefault(parts[1], {"nodes": {}}) 16 | if len(parts) > 4: 17 | graph["options"] = parts[4] 18 | elif line.startswith("node"): 19 | parts = line.split(None, 6) 20 | graph["nodes"][parts[1]] = parts[6] 21 | elif line.startswith("edge"): 22 | parts = line.split(None, 3) 23 | graph.setdefault("edges", {})[(parts[1], parts[2])] = parts[3] 24 | elif line == "stop": 25 | graph = None 26 | else: 27 | raise ValueError("Don't know how to handle line:\n%s" % line) 28 | return result 29 | --------------------------------------------------------------------------------