├── .coveragerc ├── .github └── workflows │ ├── pre-commit_hooks.yaml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── alembic.ini ├── docs ├── api.md ├── config.md ├── css │ └── custom.css ├── examples.md ├── index.md └── quickstart.md ├── mkdocs.yml ├── pyproject.toml ├── pytest.ini ├── setup.py └── src ├── alembic_utils ├── __init__.py ├── depends.py ├── exceptions.py ├── experimental │ ├── __init__.py │ └── _collect_instances.py ├── on_entity_mixin.py ├── pg_extension.py ├── pg_function.py ├── pg_grant_table.py ├── pg_materialized_view.py ├── pg_policy.py ├── pg_trigger.py ├── pg_view.py ├── py.typed ├── replaceable_entity.py ├── reversible_op.py ├── simulate.py ├── statement.py └── testbase.py └── test ├── alembic_config ├── README ├── env.py └── script.py.mako ├── conftest.py ├── resources └── to_upper.sql ├── test_alembic_check.py ├── test_collect_instances.py ├── test_create_function.py ├── test_current.py ├── test_depends.py ├── test_duplicate_registration.py ├── test_generate_revision.py ├── test_include_filters.py ├── test_initializers.py ├── test_op_drop_cascade.py ├── test_pg_constraint_trigger.py ├── test_pg_extension.py ├── test_pg_function.py ├── test_pg_function_overloading.py ├── test_pg_grant_table.py ├── test_pg_grant_table_w_columns.py ├── test_pg_materialized_view.py ├── test_pg_policy.py ├── test_pg_trigger.py ├── test_pg_view.py ├── test_recreate_dropped.py ├── test_simulate_entity.py └── test_statement.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: no cover 4 | raise AssertionError 5 | raise NotImplementedError 6 | UnreachableException 7 | if __name__ == .__main__.: 8 | if TYPE_CHECKING: 9 | pass 10 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit_hooks.yaml: -------------------------------------------------------------------------------- 1 | 
name: pre-commit hooks 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: checkout alembic_utils 11 | uses: actions/checkout@v2 12 | 13 | - name: set up python 3.10 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: '3.10' 17 | 18 | - name: Install Poetry 19 | run: | 20 | curl -sSL https://install.python-poetry.org | python3 - 21 | export PATH="$HOME/.local/bin:$PATH" 22 | 23 | - name: Install Dependencies 24 | run: | 25 | poetry install 26 | 27 | - name: Run Pre-commit Hooks 28 | run: | 29 | poetry run pre-commit run --all-files 30 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ['3.9', '3.13'] 11 | postgres-version: ['13', '17'] 12 | 13 | services: 14 | 15 | postgres: 16 | image: postgres:${{ matrix.postgres-version }} 17 | env: 18 | POSTGRES_DB: alem_db 19 | POSTGRES_HOST: localhost 20 | POSTGRES_USER: alem_user 21 | POSTGRES_PASSWORD: password 22 | ports: 23 | - 5610:5432 24 | options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 25 | 26 | steps: 27 | - uses: actions/checkout@v1 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | 33 | - name: Install Poetry 34 | run: | 35 | curl -sSL https://install.python-poetry.org | python3 - 36 | export PATH="$HOME/.local/bin:$PATH" 37 | 38 | - name: Install Dependencies 39 | run: | 40 | poetry install 41 | 42 | - name: Test with Coverage 43 | run: | 44 | poetry run pytest --cov=alembic_utils src/test --cov-report=xml 45 | 46 | - name: Upload coverage to Codecov 47 | uses: codecov/codecov-action@v1 48 | with: 49 | 
token: ${{ secrets.CODECOV_TOKEN }} 50 | file: ./coverage.xml 51 | flags: unittests 52 | name: codecov-umbrella 53 | fail_ci_if_error: true 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # vim 10 | *.swp 11 | 12 | # mac 13 | .DS_Store 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # poetry 37 | poetry.lock 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # PyCharm project settings 127 | .idea 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.3.0 5 | hooks: 6 | - id: check-ast 7 | - id: check-added-large-files 8 | - id: check-case-conflict 9 | - id: check-merge-conflict 10 | - id: mixed-line-ending 11 | args: ['--fix=lf'] 12 | - id: trailing-whitespace 13 | 14 | - repo: https://github.com/pre-commit/mirrors-isort 15 | rev: v5.10.1 16 | hooks: 17 | - id: isort 18 | args: ['--multi-line=3', '--trailing-comma', '--force-grid-wrap=0', '--use-parentheses'] 19 | 20 | 21 | - repo: https://github.com/humitos/mirrors-autoflake 22 | rev: v1.1 23 | hooks: 24 | - id: autoflake 25 | args: ['--in-place', '--remove-all-unused-imports'] 26 | 27 | - repo: https://github.com/ambv/black 28 | rev: 22.6.0 29 | hooks: 30 | - id: black 31 | language_version: python3.10 32 | 33 | 34 | - repo: https://github.com/pre-commit/mirrors-mypy 
35 | rev: v0.961 36 | hooks: 37 | - id: mypy 38 | files: src/ 39 | args: ["--config-file=pyproject.toml"] 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Oliver Rice 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Alembic Utils 2 | 3 |

4 | 5 | Test Status 6 | 7 | 8 | Pre-commit Status 9 | 10 | 11 |

12 |

13 | License 14 | PyPI version 15 | 16 | Codestyle Black 17 | 18 | Download count 19 |

20 |

21 | Python version 22 | PostgreSQL version 23 |

24 | 25 | --- 26 | 27 | **Documentation**: https://olirice.github.io/alembic_utils 28 | 29 | **Source Code**: https://github.com/olirice/alembic_utils 30 | 31 | --- 32 | 33 | **Autogenerate Support for PostgreSQL Functions, Views, Materialized Views, Triggers, and Policies** 34 | 35 | [Alembic](https://alembic.sqlalchemy.org/en/latest/) is the de facto migration tool for use with [SQLAlchemy](https://www.sqlalchemy.org/). Without extensions, alembic can detect local changes to SQLAlchemy models and autogenerate a database migration or "revision" script. That revision can be applied to update the database's schema to match the SQLAlchemy model definitions. 36 | 37 | Alembic Utils is an extension to alembic that adds support for autogenerating a larger number of [PostgreSQL](https://www.postgresql.org/) entity types, including [functions](https://www.postgresql.org/docs/current/sql-createfunction.html), [views](https://www.postgresql.org/docs/current/sql-createview.html), [materialized views](https://www.postgresql.org/docs/current/sql-creatematerializedview.html), [triggers](https://www.postgresql.org/docs/current/sql-createtrigger.html), and [policies](https://www.postgresql.org/docs/current/sql-createpolicy.html). 38 | 39 | ### TL;DR 40 | 41 | Update alembic's `env.py` to register a function or view: 42 | 43 | ```python 44 | # migrations/env.py 45 | from alembic_utils.pg_function import PGFunction 46 | from alembic_utils.replaceable_entity import register_entities 47 | 48 | 49 | to_upper = PGFunction( 50 | schema='public', 51 | signature='to_upper(some_text text)', 52 | definition=""" 53 | RETURNS text as 54 | $$ 55 | SELECT upper(some_text) 56 | $$ language SQL; 57 | """ 58 | ) 59 | 60 | register_entities([to_upper]) 61 | ``` 62 | 63 | You're done!
64 | 65 | The next time you autogenerate a revision with 66 | ```shell 67 | alembic revision --autogenerate -m 'create to_upper' 68 | ``` 69 | Alembic will detect if your entities are new, updated, or removed & populate the revision's `upgrade` and `downgrade` sections automatically. 70 | 71 | For example: 72 | 73 | ```python 74 | """create to_upper 75 | 76 | Revision ID: 8efi0da3a4 77 | Revises: 78 | Create Date: 2020-04-22 09:24:25.556995 79 | """ 80 | from alembic import op 81 | import sqlalchemy as sa 82 | from alembic_utils.pg_function import PGFunction 83 | 84 | # revision identifiers, used by Alembic. 85 | revision = '8efi0da3a4' 86 | down_revision = None 87 | branch_labels = None 88 | depends_on = None 89 | 90 | 91 | def upgrade(): 92 | public_to_upper_6fa0de = PGFunction( 93 | schema="public", 94 | signature="to_upper(some_text text)", 95 | definition=""" 96 | returns text 97 | as 98 | $$ select upper(some_text) $$ language SQL; 99 | """ 100 | ) 101 | 102 | op.create_entity(public_to_upper_6fa0de) 103 | 104 | 105 | def downgrade(): 106 | public_to_upper_6fa0de = PGFunction( 107 | schema="public", 108 | signature="to_upper(some_text text)", 109 | definition="# Not Used" 110 | ) 111 | 112 | op.drop_entity(public_to_upper_6fa0de) 113 | ``` 114 | 115 | 116 | Visit the [quickstart guide](https://olirice.github.io/alembic_utils/quickstart/) for usage instructions. 117 | 118 |

—— ——

 119 | 120 | ### Contributing 121 | 122 | To run the tests 123 | ``` 124 | # install the package and its dev dependency group (defined as a Poetry group in pyproject.toml) 125 | poetry install 126 | 127 | # run the tests 128 | poetry run pytest src/test 129 | ``` 130 | 131 | To run the linters and automated formatters, and generally make use of the pre-commit checks: 132 | ``` 133 | pip install pre-commit 134 | pre-commit install 135 | 136 | # manually run 137 | pre-commit run --all-files 138 | ``` 139 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = src/test/alembic_config 6 | 7 | # template used to generate migration files 8 | # file_template = %%(rev)s_%%(slug)s 9 | 10 | # timezone to use when rendering the date 11 | # within the migration file as well as the filename. 12 | # string value is passed to dateutil.tz.gettz() 13 | # leave blank for localtime 14 | # timezone = 15 | 16 | # max length of characters to apply to the 17 | # "slug" field 18 | # truncate_slug_length = 40 19 | 20 | # set to 'true' to run the environment during 21 | # the 'revision' command, regardless of autogenerate 22 | # revision_environment = false 23 | 24 | # set to 'true' to allow .pyc and .pyo files without 25 | # a source .py file to be detected as revisions in the 26 | # versions/ directory 27 | # sourceless = false 28 | 29 | # version location specification; this defaults 30 | # to src/test/alembic_config/versions.
When using multiple version 31 | # directories, initial revisions must be specified with --version-path 32 | # version_locations = %(here)s/bar %(here)s/bat src/test/alembic_config/versions 33 | 34 | # the output encoding used when revision files 35 | # are written from script.py.mako 36 | # output_encoding = utf-8 37 | 38 | sqlalchemy.url = postgresql://alem_user:password@localhost:5680/alem_db 39 | 40 | 41 | [post_write_hooks] 42 | # post_write_hooks defines scripts or Python functions that are run 43 | # on newly generated revision scripts. See the documentation for further 44 | # detail and examples 45 | 46 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 47 | # hooks=black 48 | # black.type=console_scripts 49 | # black.entrypoint=black 50 | # black.options=-l 79 51 | 52 | # Logging configuration 53 | [loggers] 54 | keys = root,sqlalchemy,alembic,alembic_utils 55 | 56 | [handlers] 57 | keys = console 58 | 59 | [formatters] 60 | keys = generic 61 | 62 | [logger_root] 63 | level = WARN 64 | handlers = console 65 | qualname = 66 | 67 | [logger_sqlalchemy] 68 | level = WARN 69 | handlers = 70 | qualname = sqlalchemy.engine 71 | 72 | [logger_alembic] 73 | level = INFO 74 | handlers = 75 | qualname = alembic 76 | 77 | [logger_alembic_utils] 78 | level = INFO 79 | handlers = 80 | qualname = alembic_utils 81 | 82 | [handler_console] 83 | class = StreamHandler 84 | args = (sys.stderr,) 85 | level = NOTSET 86 | formatter = generic 87 | 88 | [formatter_generic] 89 | format = %(levelname)-5.5s [%(name)s] %(message)s 90 | datefmt = %H:%M:%S 91 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | ## API Reference 2 | 3 | ::: alembic_utils.replaceable_entity.register_entities 4 | :docstring: 5 | 6 | ```python 7 | # migrations/env.py 8 | 9 | from alembic_utils.replaceable_entity import register_entities 10 | 
from app.functions import my_function 11 | from app.views import my_view 12 | 13 | register_entities(entities=[my_function, my_view], exclude_schema=['audit']) 14 | ``` 15 | 16 | ::: alembic_utils.pg_function.PGFunction 17 | :docstring: 18 | 19 | ```python 20 | from alembic_utils.pg_function import PGFunction 21 | 22 | to_lower = PGFunction( 23 | schema="public", 24 | signature="to_lower(some_text text)", 25 | definition="returns text as $$ select lower(some_text) $$ language sql" 26 | ) 27 | ``` 28 | 29 | 30 | ::: alembic_utils.pg_view.PGView 31 | :docstring: 32 | 33 | 34 | ```python 35 | from alembic_utils.pg_view import PGView 36 | 37 | scifi_books = PGView( 38 | schema="public", 39 | signature="scifi_books", 40 | definition="select * from books where genre='scifi'" 41 | ) 42 | ``` 43 | 44 | ::: alembic_utils.pg_materialized_view.PGMaterializedView 45 | :docstring: 46 | 47 | 48 | ```python 49 | from alembic_utils.pg_materialized_view import PGMaterializedView 50 | 51 | scifi_books = PGMaterializedView( 52 | schema="public", 53 | signature="scifi_books", 54 | definition="select * from books where genre='scifi'", 55 | with_data=True 56 | ) 57 | ``` 58 | 59 | 60 | ::: alembic_utils.pg_trigger.PGTrigger 61 | :docstring: 62 | 63 | 64 | ```python 65 | from alembic_utils.pg_trigger import PGTrigger 66 | 67 | trigger = PGTrigger( 68 | schema="public", 69 | signature="lower_account_email", 70 | on_entity="public.account", 71 | definition=""" 72 | BEFORE INSERT ON public.account 73 | FOR EACH ROW EXECUTE FUNCTION public.downcase_email() 74 | """, 75 | ) 76 | ``` 77 | 78 | ::: alembic_utils.pg_extension.PGExtension 79 | :docstring: 80 | 81 | 82 | ```python 83 | from alembic_utils.pg_extension import PGExtension 84 | 85 | extension = PGExtension( 86 | schema="public", 87 | signature="uuid-ossp", 88 | ) 89 | ``` 90 | 91 | 92 | ::: alembic_utils.pg_policy.PGPolicy 93 | :docstring: 94 | 95 | 96 | ```python 97 | from alembic_utils.pg_policy import PGPolicy 98 | 99 | policy = PGPolicy(
100 | schema="public", 101 | signature="allow_read", 102 | on_entity="public.account", 103 | definition=""" 104 | AS PERMISSIVE 105 | FOR SELECT 106 | TO api_user 107 | USING (id = current_setting('api_current_user', true)::int) 108 | """, 109 | ) 110 | ``` 111 | 112 | 113 | ::: alembic_utils.pg_grant_table.PGGrantTable 114 | :docstring: 115 | 116 | 117 | ```python 118 | from alembic_utils.pg_grant_table import PGGrantTable 119 | 120 | grant = PGGrantTable( 121 | schema="public", 122 | table="account", 123 | columns=["id", "email"], 124 | role="anon_user", 125 | grant='SELECT', 126 | with_grant_option=False, 127 | ) 128 | ``` 129 | -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | ## Configuration 2 | 3 | 4 | Controlling output of an `--autogenerate` revision is limited to including or excluding objects. By default, all schemas in the target database are included. 5 | 6 | Selectively filtering `alembic_utils` objects can be achieved using the `include_name` and `include_object` callables in `env.py`. For more information, see [controlling-what-to-be-autogenerated](https://alembic.sqlalchemy.org/en/latest/autogenerate.html#controlling-what-to-be-autogenerated) 7 | 8 | 9 | ### Examples 10 | 11 | The following examples perform common filtering tasks. 12 | 13 | #### Whitelist of Schemas 14 | 15 | Only include objects from schemas named `public` and `bi`. 16 | 17 | ```python 18 | # env.py 19 | 20 | def include_name(name, type_, parent_names) -> bool: 21 | if type_ == "schema": 22 | return name in ["public", "bi"] 23 | return True 24 | 25 | context.configure( 26 | # ... 27 | include_schemas = True, 28 | include_name = include_name 29 | ) 30 | ``` 31 | 32 | #### Exclude PGFunctions 33 | 34 | Don't produce migrations for PGFunction entities.
35 | 36 | ```python 37 | # env.py 38 | 39 | def include_object(object, name, type_, reflected, compare_to) -> bool: 40 | if isinstance(object, PGFunction): 41 | return False 42 | return True 43 | 44 | context.configure( 45 | # ... 46 | include_object=include_object 47 | ) 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | div.autodoc-docstring { 2 | padding-left: 20px; 3 | margin-bottom: 30px; 4 | border-left: 5px solid rgba(230, 230, 230); 5 | } 6 | 7 | div.autodoc-members { 8 | padding-left: 20px; 9 | margin-bottom: 15px; 10 | } 11 | 12 | .codehilite code, .codehilite pre{color:#3F3F3F;background-color:#F7F7F7; 13 | overflow: auto; 14 | box-sizing: border-box; 15 | 16 | padding: 0.01em 16px; 17 | padding-top: 0.01em; 18 | padding-right-value: 16px; 19 | padding-bottom: 0.01em; 20 | padding-left-value: 16px; 21 | padding-left-ltr-source: physical; 22 | padding-left-rtl-source: physical; 23 | padding-right-ltr-source: physical; 24 | padding-right-rtl-source: physical; 25 | 26 | border-radius: 16px !important; 27 | border-top-left-radius: 16px; 28 | border-top-right-radius: 16px; 29 | border-bottom-right-radius: 16px; 30 | border-bottom-left-radius: 16px; 31 | 32 | border: 1px solid #CCC !important; 33 | border-top-width: 1px; 34 | border-right-width-value: 1px; 35 | border-right-width-ltr-source: physical; 36 | border-right-width-rtl-source: physical; 37 | border-bottom-width: 1px; 38 | border-left-width-value: 1px; 39 | border-left-width-ltr-source: physical; 40 | border-left-width-rtl-source: physical; 41 | border-top-style: solid; 42 | border-right-style-value: solid; 43 | border-right-style-ltr-source: physical; 44 | border-right-style-rtl-source: physical; 45 | border-bottom-style: solid; 46 | border-left-style-value: solid; 47 | border-left-style-ltr-source: physical; 48 | border-left-style-rtl-source: 
physical; 49 | border-top-color: #CCC; 50 | border-right-color-value: #CCC; 51 | border-right-color-ltr-source: physical; 52 | border-right-color-rtl-source: physical; 53 | border-bottom-color: #CCC; 54 | border-left-color-value: #CCC; 55 | border-left-color-ltr-source: physical; 56 | border-left-color-rtl-source: physical; 57 | -moz-border-top-colors: none; 58 | -moz-border-right-colors: none; 59 | -moz-border-bottom-colors: none; 60 | -moz-border-left-colors: none; 61 | border-image-source: none; 62 | border-image-slice: 100% 100% 100% 100%; 63 | border-image-width: 1 1 1 1; 64 | border-image-outset: 0 0 0 0; 65 | border-image-repeat: stretch stretch;} 66 | .codehilite .hll { background-color: #ffffcc } 67 | .codehilite .c { color: #999988; font-style: italic } /* Comment */ 68 | .codehilite .err { color: #a61717; background-color: #e3d2d2 } /* Error */ 69 | .codehilite .k { color: #000000; font-weight: bold } /* Keyword */ 70 | .codehilite .o { color: #000000; font-weight: bold } /* Operator */ 71 | .codehilite .cm { color: #999988; font-style: italic } /* Comment.Multiline */ 72 | .codehilite .cp { color: #999999; font-weight: bold; font-style: italic } /* Comment.Preproc */ 73 | .codehilite .c1 { color: #999988; font-style: italic } /* Comment.Single */ 74 | .codehilite .cs { color: #999999; font-weight: bold; font-style: italic } /* Comment.Special */ 75 | .codehilite .gd { color: #000000; background-color: #ffdddd } /* Generic.Deleted */ 76 | .codehilite .ge { color: #000000; font-style: italic } /* Generic.Emph */ 77 | .codehilite .gr { color: #aa0000 } /* Generic.Error */ 78 | .codehilite .gh { color: #999999 } /* Generic.Heading */ 79 | .codehilite .gi { color: #000000; background-color: #ddffdd } /* Generic.Inserted */ 80 | .codehilite .go { color: #888888 } /* Generic.Output */ 81 | .codehilite .gp { color: #555555 } /* Generic.Prompt */ 82 | .codehilite .gs { font-weight: bold } /* Generic.Strong */ 83 | .codehilite .gu { color: #aaaaaa } /* 
Generic.Subheading */ 84 | .codehilite .gt { color: #aa0000 } /* Generic.Traceback */ 85 | .codehilite .kc { color: #000000; font-weight: bold } /* Keyword.Constant */ 86 | .codehilite .kd { color: #000000; font-weight: bold } /* Keyword.Declaration */ 87 | .codehilite .kn { color: #000000; font-weight: bold } /* Keyword.Namespace */ 88 | .codehilite .kp { color: #000000; font-weight: bold } /* Keyword.Pseudo */ 89 | .codehilite .kr { color: #000000; font-weight: bold } /* Keyword.Reserved */ 90 | .codehilite .kt { color: #445588; font-weight: bold } /* Keyword.Type */ 91 | .codehilite .m { color: #009999 } /* Literal.Number */ 92 | .codehilite .s { color: #d01040 } /* Literal.String */ 93 | .codehilite .na { color: #008080 } /* Name.Attribute */ 94 | .codehilite .nb { color: #0086B3 } /* Name.Builtin */ 95 | .codehilite .nc { color: #445588; font-weight: bold } /* Name.Class */ 96 | .codehilite .no { color: #008080 } /* Name.Constant */ 97 | .codehilite .nd { color: #3c5d5d; font-weight: bold } /* Name.Decorator */ 98 | .codehilite .ni { color: #800080 } /* Name.Entity */ 99 | .codehilite .ne { color: #990000; font-weight: bold } /* Name.Exception */ 100 | .codehilite .nf { color: #990000; font-weight: bold } /* Name.Function */ 101 | .codehilite .nl { color: #990000; font-weight: bold } /* Name.Label */ 102 | .codehilite .nn { color: #555555 } /* Name.Namespace */ 103 | .codehilite .nt { color: #000080 } /* Name.Tag */ 104 | .codehilite .nv { color: #008080 } /* Name.Variable */ 105 | .codehilite .ow { color: #000000; font-weight: bold } /* Operator.Word */ 106 | .codehilite .w { color: #bbbbbb } /* Text.Whitespace */ 107 | .codehilite .mf { color: #009999 } /* Literal.Number.Float */ 108 | .codehilite .mh { color: #009999 } /* Literal.Number.Hex */ 109 | .codehilite .mi { color: #009999 } /* Literal.Number.Integer */ 110 | .codehilite .mo { color: #009999 } /* Literal.Number.Oct */ 111 | .codehilite .sb { color: #d01040 } /* Literal.String.Backtick */ 112 | 
.codehilite .sc { color: #d01040 } /* Literal.String.Char */ 113 | .codehilite .sd { color: #d01040 } /* Literal.String.Doc */ 114 | .codehilite .s2 { color: #d01040 } /* Literal.String.Double */ 115 | .codehilite .se { color: #d01040 } /* Literal.String.Escape */ 116 | .codehilite .sh { color: #d01040 } /* Literal.String.Heredoc */ 117 | .codehilite .si { color: #d01040 } /* Literal.String.Interpol */ 118 | .codehilite .sx { color: #d01040 } /* Literal.String.Other */ 119 | .codehilite .sr { color: #009926 } /* Literal.String.Regex */ 120 | .codehilite .s1 { color: #d01040 } /* Literal.String.Single */ 121 | .codehilite .ss { color: #990073 } /* Literal.String.Symbol */ 122 | .codehilite .bp { color: #999999 } /* Name.Builtin.Pseudo */ 123 | .codehilite .vc { color: #008080 } /* Name.Variable.Class */ 124 | .codehilite .vg { color: #008080 } /* Name.Variable.Global */ 125 | .codehilite .vi { color: #008080 } /* Name.Variable.Instance */ 126 | .codehilite .il { color: #009999 } /* Literal.Number.Integer.Long */ 127 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | ## Example Outputs 2 | 3 | ### Migration for newly created function 4 | ```python 5 | """create 6 | 7 | Revision ID: 1 8 | Revises: 9 | Create Date: 2020-04-22 09:24:25.556995 10 | """ 11 | from alembic import op 12 | import sqlalchemy as sa 13 | from alembic_utils.pg_function import PGFunction 14 | 15 | # revision identifiers, used by Alembic. 
16 | revision = '1' 17 | down_revision = None 18 | branch_labels = None 19 | depends_on = None 20 | 21 | 22 | def upgrade(): 23 | public_to_upper_6fa0de = PGFunction( 24 | schema="public", 25 | signature="to_upper(some_text text)", 26 | definition=""" 27 | returns text 28 | as 29 | $$ select upper(some_text) $$ language SQL; 30 | """ 31 | ) 32 | 33 | op.create_entity(public_to_upper_6fa0de) 34 | 35 | 36 | def downgrade(): 37 | public_to_upper_6fa0de = PGFunction( 38 | schema="public", 39 | signature="to_upper(some_text text)", 40 | definition="# Not Used" 41 | ) 42 | 43 | op.drop_entity(public_to_upper_6fa0de) 44 | ``` 45 | 46 | ### Migration for updated Function 47 | 48 | ```python 49 | """replace 50 | 51 | Revision ID: 2 52 | Revises: 1 53 | Create Date: 2020-04-22 09:24:25.679031 54 | """ 55 | from alembic import op 56 | import sqlalchemy as sa 57 | from alembic_utils.pg_function import PGFunction 58 | 59 | # revision identifiers, used by Alembic. 60 | revision = '2' 61 | down_revision = '1' 62 | branch_labels = None 63 | depends_on = None 64 | 65 | 66 | def upgrade(): 67 | public_to_upper_6fa0de = PGFunction( 68 | schema="public", 69 | signature="to_upper(some_text text)", 70 | definition=""" 71 | returns text 72 | as 73 | $$ select upper(some_text) || 'def' $$ language SQL; 74 | """ 75 | ) 76 | 77 | op.replace_entity(public_to_upper_6fa0de) 78 | 79 | 80 | def downgrade(): 81 | public_to_upper_6fa0de = PGFunction( 82 | schema="public", 83 | signature="to_upper(some_text text)", 84 | definition="""returns text 85 | LANGUAGE sql 86 | AS $function$ select upper(some_text) || 'abc' $function$""" 87 | ) 88 | 89 | op.replace_entity(public_to_upper_6fa0de) 90 | ``` 91 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Alembic Utils 2 | 3 |

4 | 5 | Test Status 6 | 7 | 8 | Pre-commit Status 9 | 10 | 11 |

12 |

13 | License 14 | PyPI version 15 | 16 | Codestyle Black 17 | 18 | Download count 19 |

20 |

21 | Python version 22 | PostgreSQL version 23 |

24 | 25 | ---- 26 | 27 | **Documentation**: https://olirice.github.io/alembic_utils 28 | 29 | **Source Code**: https://github.com/olirice/alembic_utils 30 | 31 | --- 32 | [Alembic](https://alembic.sqlalchemy.org/en/latest/) is the de facto migration tool for use with [SQLAlchemy](https://www.sqlalchemy.org/). Without extensions, alembic can detect local changes to SQLAlchemy models and autogenerate a database migration or "revision" script. That revision can be applied to update the database's schema to match the SQLAlchemy model definitions. 33 | 34 | Alembic Utils is an extension to alembic that adds support for autogenerating a larger number of [PostgreSQL](https://www.postgresql.org/) entity types, including [functions](https://www.postgresql.org/docs/current/sql-createfunction.html), [views](https://www.postgresql.org/docs/current/sql-createview.html), [materialized views](https://www.postgresql.org/docs/current/sql-creatematerializedview.html), [triggers](https://www.postgresql.org/docs/current/sql-createtrigger.html), and [policies](https://www.postgresql.org/docs/current/sql-createpolicy.html). 35 | 36 | 37 | Visit the [quickstart guide](quickstart.md) for usage instructions. 38 | 39 |

—— ——

40 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | ## Quickstart 2 | 3 | ### Installation 4 | *Requirements* Python 3.9+ 5 | 6 | First, install alembic_utils 7 | ```shell 8 | $ pip install alembic_utils 9 | ``` 10 | 11 | Next, add "alembic_utils" to the logger keys in `alembic.ini` and add a logger configuration section for it. 12 | ``` 13 | ... 14 | [loggers] 15 | keys=root,sqlalchemy,alembic,alembic_utils 16 | 17 | [logger_alembic_utils] 18 | level = INFO 19 | handlers = 20 | qualname = alembic_utils 21 | ``` 22 | 23 | ### Reference 24 | 25 | Then add a function to your project 26 | ```python 27 | # my_function.py 28 | from alembic_utils.pg_function import PGFunction 29 | 30 | to_upper = PGFunction( 31 | schema='public', 32 | signature='to_upper(some_text text)', 33 | definition=""" 34 | RETURNS text as 35 | $$ 36 | SELECT upper(some_text) 37 | $$ language SQL; 38 | """ 39 | ) 40 | ``` 41 | 42 | and/or a view 43 | ```python 44 | # my_view.py 45 | from alembic_utils.pg_view import PGView 46 | 47 | first_view = PGView( 48 | schema="public", 49 | signature="first_view", 50 | definition="select * from information_schema.tables", 51 | ) 52 | 53 | ``` 54 | 55 | 56 | 57 | Finally, update the `env.py` in your migrations directory to register your entities with alembic_utils. 58 | 59 | ```python 60 | # migrations/env.py 61 | 62 | # Add these lines 63 | from alembic_utils.replaceable_entity import register_entities 64 | 65 | from my_function import to_upper 66 | from my_view import first_view 67 | 68 | register_entities([to_upper, first_view]) 69 | ``` 70 | 71 | You're done! 72 | 73 | The next time you autogenerate a revision with 74 | ```shell 75 | alembic revision --autogenerate -m 'some message' 76 | ``` 77 | Alembic will detect if your entities are new, updated, or removed & populate the revision's `upgrade` and `downgrade` sections automatically.
78 | 79 | For example outputs, check the [examples](examples.md). 80 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: alembic_utils 2 | site_description: Autogenerate Support for PostgreSQL Functions, Views, and More 3 | 4 | theme: 5 | name: 'readthedocs' 6 | 7 | extra_css: 8 | - css/custom.css 9 | 10 | repo_name: olirice/alembic_utils 11 | repo_url: https://github.com/olirice/alembic_utils 12 | site_url: https://olirice.github.io/alembic_utils 13 | 14 | nav: 15 | - Introduction: 'index.md' 16 | - Quickstart: 'quickstart.md' 17 | - API Reference: 'api.md' 18 | - Configuration: 'config.md' 19 | - Examples: 'examples.md' 20 | 21 | markdown_extensions: 22 | - codehilite 23 | - admonition 24 | - mkautodoc 25 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "alembic_utils" 3 | version = "0.8.8" 4 | description = "A sqlalchemy/alembic extension for migrating procedures and views" 5 | requires-python = ">=3.9" 6 | authors = [ 7 | {"name" = "Oliver Rice", "email" = "oliver@oliverrice.com"} 8 | ] 9 | license = "MIT" 10 | readme = "README.md" 11 | packages = [{include = "alembic_utils", from = "src"}] 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "License :: OSI Approved :: MIT License", 15 | "Intended Audience :: Developers", 16 | "Programming Language :: Python :: 3", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Programming Language :: Python :: 3.12", 21 | "Programming Language :: Python :: 3.13", 22 | "Programming Language :: SQL", 23 | ] 24 | dependencies = [ 25 | "alembic>=1.9", 26 | "flupy", 27 | "parse>=1.8.4", 28 | "sqlalchemy>=1.4", 29 |
"typing_extensions>=0.1.0", 30 | ] 31 | 32 | [build-system] 33 | requires = ["poetry-core>=2.0.0,<3.0.0"] 34 | build-backend = "poetry.core.masonry.api" 35 | 36 | [tool.poetry.urls] 37 | "PyPI" = "https://pypi.org/project/alembic-utils/" 38 | "GitHub" = "https://github.com/olirice/alembic_utils" 39 | 40 | [tool.poetry.group.dev.dependencies] 41 | black = "*" 42 | pylint = "*" 43 | pre-commit = "*" 44 | mypy = "*" 45 | psycopg2-binary = "*" 46 | pytest = "*" 47 | pytest-cov = "*" 48 | mkdocs = "*" 49 | 50 | [tool.poetry.group.nvim.dependencies] 51 | neovim = "*" 52 | python-language-server = "*" 53 | 54 | [tool.poetry.group.docs.dependencies] 55 | mkdocs = "*" 56 | pygments = "*" 57 | pymdown-extensions = "*" 58 | mkautodoc = "*" 59 | 60 | [tool.setuptools.packages.find] 61 | where = ["src"] 62 | 63 | [tool.setuptools.package-data] 64 | "*" = ["py.typed"] 65 | 66 | [tool.black] 67 | line-length = 100 68 | exclude = ''' 69 | /( 70 | \.git 71 | | \.hg 72 | | \.mypy_cache 73 | | \.tox 74 | | \.venv 75 | | _build 76 | | buck-out 77 | | build 78 | | dist 79 | )/ 80 | ''' 81 | 82 | [tool.isort] 83 | known_third_party = ["alembic", "flupy", "parse", "pytest", "setuptools", "sqlalchemy"] 84 | 85 | [tool.mypy] 86 | follow_imports = "skip" 87 | strict_optional = true 88 | warn_redundant_casts = true 89 | warn_unused_ignores = false 90 | disallow_any_generics = true 91 | check_untyped_defs = true 92 | no_implicit_reexport = true 93 | ignore_missing_imports = true 94 | # disallow_untyped_defs = true 95 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --cov=alembic_utils src/test 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from setuptools import find_packages, 
setup 5 | 6 | 7 | def get_version(package): 8 | """ 9 | Return package version as listed in `__version__` in `__init__.py`. 10 | """ 11 | with open(os.path.join("src", package, "__init__.py")) as f: 12 | return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) 13 | 14 | 15 | DEV_REQUIRES = [ 16 | "black", 17 | "pylint", 18 | "pre-commit", 19 | "mypy", 20 | "psycopg2-binary", 21 | "pytest", 22 | "pytest-cov", 23 | "mkdocs", 24 | ] 25 | 26 | setup( 27 | name="alembic_utils", 28 | version=get_version("alembic_utils"), 29 | author="Oliver Rice", 30 | author_email="oliver@oliverrice.com", 31 | license="MIT", 32 | description="A sqlalchemy/alembic extension for migrating procedures and views ", 33 | python_requires=">=3.7", 34 | packages=find_packages("src"), 35 | package_dir={"": "src"}, 36 | install_requires=[ 37 | "alembic>=1.9", 38 | "flupy", 39 | "parse>=1.8.4", 40 | "sqlalchemy>=1.4", 41 | "typing_extensions", 42 | ], 43 | extras_require={ 44 | "dev": DEV_REQUIRES, 45 | "nvim": ["neovim", "python-language-server"], 46 | "docs": ["mkdocs", "pygments", "pymdown-extensions", "mkautodoc"], 47 | }, 48 | package_data={"": ["py.typed"]}, 49 | include_package_data=True, 50 | classifiers=[ 51 | "Development Status :: 4 - Beta", 52 | "License :: OSI Approved :: MIT License", 53 | "Intended Audience :: Developers", 54 | "Programming Language :: Python :: 3", 55 | "Programming Language :: Python :: 3.8", 56 | "Programming Language :: Python :: 3.9", 57 | "Programming Language :: Python :: 3.10", 58 | "Programming Language :: Python :: 3.11", 59 | "Programming Language :: Python :: 3.12", 60 | "Programming Language :: SQL", 61 | ], 62 | ) 63 | -------------------------------------------------------------------------------- /src/alembic_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olirice/alembic_utils/ccf9be7b8097bd6df13886a40a8e22c847c10bc8/src/alembic_utils/__init__.py 
-------------------------------------------------------------------------------- /src/alembic_utils/depends.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from typing import Generator, List 4 | 5 | from sqlalchemy import exc as sqla_exc 6 | from sqlalchemy.orm import Session 7 | 8 | from alembic_utils.simulate import simulate_entity 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def solve_resolution_order(sess: Session, entities): 14 | """Solve for an entity resolution order that increases the probability that 15 | a migration will suceed if, for example, two new views are created and one 16 | refers to the other 17 | 18 | This strategy will only solve for simple cases 19 | """ 20 | 21 | resolved = [] 22 | 23 | # Resolve the entities with 0 dependencies first (faster) 24 | logger.info("Resolving entities with no dependencies") 25 | for entity in entities: 26 | try: 27 | with simulate_entity(sess, entity): 28 | resolved.append(entity) 29 | except (sqla_exc.ProgrammingError, sqla_exc.InternalError) as exc: 30 | continue 31 | 32 | # Resolve entities with possible dependencies 33 | logger.info("Resolving entities with dependencies. This may take a minute") 34 | for _ in range(len(entities)): 35 | n_resolved = len(resolved) 36 | 37 | for entity in entities: 38 | if entity in resolved: 39 | continue 40 | 41 | try: 42 | with simulate_entity(sess, entity, dependencies=resolved): 43 | resolved.append(entity) 44 | except (sqla_exc.ProgrammingError, sqla_exc.InternalError): 45 | continue 46 | 47 | if len(resolved) == n_resolved: 48 | # No new entities resolved in the last iteration. 
Exit 49 | break 50 | 51 | for entity in entities: 52 | if entity not in resolved: 53 | resolved.append(entity) 54 | 55 | return resolved 56 | 57 | 58 | @contextmanager 59 | def recreate_dropped(connection) -> Generator[Session, None, None]: 60 | """Recreate any dropped all ReplaceableEntities that were dropped within block 61 | 62 | This is useful for making cascading updates. For example, updating a table's column type when it has dependent views. 63 | 64 | def upgrade() -> None: 65 | 66 | my_view = PGView(...) 67 | 68 | with recreate_dropped(op.get_bind()) as conn: 69 | 70 | op.drop_entity(my_view) 71 | 72 | # change an integer column to a bigint 73 | op.alter_column( 74 | table_name="account", 75 | column_name="id", 76 | schema="public" 77 | type_=sa.BIGINT() 78 | existing_type=sa.Integer(), 79 | ) 80 | """ 81 | from alembic_utils.pg_function import PGFunction 82 | from alembic_utils.pg_materialized_view import PGMaterializedView 83 | from alembic_utils.pg_trigger import PGTrigger 84 | from alembic_utils.pg_view import PGView 85 | from alembic_utils.replaceable_entity import ReplaceableEntity 86 | 87 | # Do not include permissions here e.g. PGGrantTable. 
If columns granted to users are dropped, it will cause an error 88 | 89 | def collect_all_db_entities(sess: Session) -> List[ReplaceableEntity]: 90 | """Collect all entities from the database""" 91 | 92 | return [ 93 | *PGFunction.from_database(sess, "%"), 94 | *PGTrigger.from_database(sess, "%"), 95 | *PGView.from_database(sess, "%"), 96 | *PGMaterializedView.from_database(sess, "%"), 97 | ] 98 | 99 | sess = Session(bind=connection) 100 | 101 | # All existing entities, before the upgrade 102 | before = collect_all_db_entities(sess) 103 | 104 | # In the yield, do a 105 | # op.drop_entity(my_mat_view, cascade=True) 106 | # op.create_entity(my_mat_view) 107 | try: 108 | yield sess 109 | except: 110 | sess.rollback() 111 | raise 112 | 113 | # All existing entities, after the upgrade 114 | after = collect_all_db_entities(sess) 115 | after_identities = {x.identity for x in after} 116 | 117 | # Entities that were not impacted, or that we have "recovered" 118 | resolved = [] 119 | unresolved = [] 120 | 121 | # First, ignore the ones that were not impacted by the upgrade 122 | for ent in before: 123 | if ent.identity in after_identities: 124 | resolved.append(ent) 125 | else: 126 | unresolved.append(ent) 127 | 128 | # Attempt to find an acceptable order of creation for the unresolved entities 129 | ordered_unresolved = solve_resolution_order(sess, unresolved) 130 | 131 | # Attempt to recreate the missing entities in the specified order 132 | for ent in ordered_unresolved: 133 | sess.execute(ent.to_sql_statement_create()) 134 | 135 | # Sanity check that everything is now fine 136 | sanity_check = collect_all_db_entities(sess) 137 | # Fail and rollback if the sanity check is wrong 138 | try: 139 | assert len(before) == len(sanity_check) 140 | except: 141 | sess.rollback() 142 | raise 143 | 144 | # Close out the session 145 | sess.commit() 146 | -------------------------------------------------------------------------------- /src/alembic_utils/exceptions.py: 
-------------------------------------------------------------------------------- 1 | class AlembicUtilsException(Exception): 2 | """Base exception for AlembicUtils package""" 3 | 4 | 5 | class SQLParseFailure(AlembicUtilsException): 6 | """An entity could not be parsed""" 7 | 8 | 9 | class FailedToGenerateComparable(AlembicUtilsException): 10 | """Failed to generate a comparable entity""" 11 | 12 | 13 | class UnreachableException(AlembicUtilsException): 14 | """An exception no one should ever see""" 15 | 16 | 17 | class BadInputException(AlembicUtilsException): 18 | """Invalid user input""" 19 | -------------------------------------------------------------------------------- /src/alembic_utils/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from alembic_utils.experimental._collect_instances import ( 2 | collect_instances, 3 | collect_subclasses, 4 | ) 5 | 6 | __all__ = ["collect_instances", "collect_subclasses"] 7 | -------------------------------------------------------------------------------- /src/alembic_utils/experimental/_collect_instances.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from pathlib import Path 4 | from types import ModuleType 5 | from typing import Generator, List, Type, TypeVar 6 | 7 | from flupy import walk_files 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def walk_modules(module: ModuleType) -> Generator[ModuleType, None, None]: 13 | """Recursively yield python import paths to submodules in *module* 14 | 15 | Example: 16 | module_iter = walk_modules(alembic_utils) 17 | 18 | for module_path in module_iter: 19 | print(module_path) 20 | 21 | # alembic_utils.exceptions 22 | # alembic_utils.on_entity_mixin 23 | # ... 
24 | """ 25 | top_module = module 26 | top_path = Path(top_module.__path__[0]) 27 | top_path_absolute = top_path.resolve() 28 | 29 | directories = ( 30 | walk_files(str(top_path_absolute)) 31 | .filter(lambda x: x.endswith(".py")) 32 | .map(Path) 33 | .group_by(lambda x: x.parent) 34 | .map(lambda x: (x[0], x[1].collect())) 35 | ) 36 | 37 | for base_path, files in directories: 38 | if str(base_path / "__init__.py") in [str(x) for x in files]: 39 | for module_path in files: 40 | if "__init__.py" not in str(module_path): 41 | 42 | # Example: elt.settings 43 | module_import_path = str(module_path)[ 44 | len(str(top_path_absolute)) - len(top_module.__name__) : 45 | ].replace(os.path.sep, ".")[:-3] 46 | 47 | module = importlib.import_module(module_import_path) 48 | yield module 49 | 50 | 51 | def collect_instances(module: ModuleType, class_: Type[T]) -> List[T]: 52 | """Collect all instances of *class_* defined in *module* 53 | 54 | Note: Will import all submodules in *module*. Beware of import side effects 55 | """ 56 | 57 | found: List[T] = [] 58 | 59 | for module_ in walk_modules(module): 60 | 61 | for _, variable in module_.__dict__.items(): 62 | 63 | if isinstance(variable, class_): 64 | # Ensure variable is not a subclass 65 | if variable.__class__ == class_: 66 | found.append(variable) 67 | return found 68 | 69 | 70 | def collect_subclasses(module: ModuleType, class_: Type[T]) -> List[Type[T]]: 71 | """Collect all subclasses of *class_* currently imported or defined in *module* 72 | 73 | Note: Will import all submodules in *module*. Beware of import side effects 74 | """ 75 | 76 | found: List[Type[T]] = [] 77 | 78 | for module_ in walk_modules(module): 79 | 80 | for _, variable in module_.__dict__.items(): 81 | try: 82 | if issubclass(variable, class_) and not class_ == variable: 83 | found.append(variable) 84 | except TypeError: 85 | # argument 2 to issubclass must be a class .... 
86 | pass 87 | 88 | imported: List[Type[T]] = list(class_.__subclasses__()) # type: ignore 89 | 90 | return list(set(found + imported)) 91 | -------------------------------------------------------------------------------- /src/alembic_utils/on_entity_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from alembic_utils.statement import coerce_to_unquoted 4 | 5 | if TYPE_CHECKING: 6 | from alembic_utils.replaceable_entity import ReplaceableEntity 7 | 8 | _Base = ReplaceableEntity 9 | else: 10 | _Base = object 11 | 12 | 13 | class OnEntityMixin(_Base): 14 | """Mixin to ReplaceableEntity providing setup for entity types requiring an "ON" clause""" 15 | 16 | def __init__(self, schema: str, signature: str, definition: str, on_entity: str): 17 | super().__init__(schema=schema, signature=signature, definition=definition) 18 | 19 | if "." not in on_entity: 20 | on_entity = "public." + on_entity 21 | 22 | # Guarenteed to have a schema 23 | self.on_entity = coerce_to_unquoted(on_entity) 24 | 25 | @property 26 | def identity(self) -> str: 27 | """A string that consistently and globally identifies a function 28 | 29 | Overriding default to add the "on table" clause 30 | """ 31 | return f"{self.__class__.__name__}: {self.schema}.{self.signature} {self.on_entity}" 32 | 33 | def render_self_for_migration(self, omit_definition=False) -> str: 34 | """Render a string that is valid python code to reconstruct self in a migration""" 35 | var_name = self.to_variable_name() 36 | class_name = self.__class__.__name__ 37 | escaped_definition = self.definition if not omit_definition else "# not required for op" 38 | 39 | return f"""{var_name} = {class_name}( 40 | schema="{self.schema}", 41 | signature="{self.signature}", 42 | on_entity="{self.on_entity}", 43 | definition={repr(escaped_definition)} 44 | )\n""" 45 | 46 | def to_variable_name(self) -> str: 47 | """A deterministic variable name based on 
PGFunction's contents""" 48 | schema_name = self.schema.lower() 49 | object_name = self.signature.split("(")[0].strip().lower() 50 | _, _, unqualified_entity_name = self.on_entity.lower().partition(".") 51 | return f"{schema_name}_{unqualified_entity_name}_{object_name}" 52 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_extension.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | 3 | 4 | from typing import Generator 5 | 6 | from sqlalchemy import text as sql_text 7 | from sqlalchemy.sql.elements import TextClause 8 | 9 | from alembic_utils.replaceable_entity import ReplaceableEntity 10 | from alembic_utils.statement import coerce_to_unquoted, normalize_whitespace 11 | 12 | 13 | class PGExtension(ReplaceableEntity): 14 | """A PostgreSQL Extension compatible with `alembic revision --autogenerate` 15 | 16 | **Parameters:** 17 | 18 | * **schema** - *str*: A SQL schema name 19 | * **signature** - *str*: A PostgreSQL extension's name 20 | """ 21 | 22 | type_ = "extension" 23 | 24 | def __init__(self, schema: str, signature: str): 25 | self.schema: str = coerce_to_unquoted(normalize_whitespace(schema)) 26 | self.signature: str = coerce_to_unquoted(normalize_whitespace(signature)) 27 | # Include schema in definition since extensions can only exist once per 28 | # database and we want to detect schema changes and emit alter schema 29 | self.definition: str = f"{self.__class__.__name__}: {self.schema} {self.signature}" 30 | 31 | def to_sql_statement_create(self) -> TextClause: 32 | """Generates a SQL "create extension" statement""" 33 | return sql_text(f'CREATE EXTENSION "{self.signature}" WITH SCHEMA {self.literal_schema};') 34 | 35 | def to_sql_statement_drop(self, cascade=False) -> TextClause: 36 | """Generates a SQL "drop extension" statement""" 37 | cascade = "CASCADE" if cascade else "" 38 | return 
sql_text(f'DROP EXTENSION "{self.signature}" {cascade}') 39 | 40 | def to_sql_statement_create_or_replace(self) -> Generator[TextClause, None, None]: 41 | """Generates SQL equivalent to "create or replace" statement""" 42 | raise NotImplementedError() 43 | 44 | @property 45 | def identity(self) -> str: 46 | """A string that consistently and globally identifies an extension""" 47 | # Extensions may only be installed once per db, schema is not a 48 | # component of identity 49 | return f"{self.__class__.__name__}: {self.signature}" 50 | 51 | def render_self_for_migration(self, omit_definition=False) -> str: 52 | """Render a string that is valid python code to reconstruct self in a migration""" 53 | var_name = self.to_variable_name() 54 | class_name = self.__class__.__name__ 55 | 56 | return f"""{var_name} = {class_name}( 57 | schema="{self.schema}", 58 | signature="{self.signature}" 59 | )\n""" 60 | 61 | @classmethod 62 | def from_database(cls, sess, schema): 63 | """Get a list of all extensions defined in the db""" 64 | sql = sql_text( 65 | f""" 66 | select 67 | np.nspname schema_name, 68 | ext.extname extension_name 69 | from 70 | pg_extension ext 71 | join pg_namespace np 72 | on ext.extnamespace = np.oid 73 | where 74 | np.nspname not in ('pg_catalog') 75 | and np.nspname like :schema; 76 | """ 77 | ) 78 | rows = sess.execute(sql, {"schema": schema}).fetchall() 79 | db_exts = [cls(x[0], x[1]) for x in rows] 80 | return db_exts 81 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_function.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | from typing import List 3 | 4 | from parse import parse 5 | from sqlalchemy import text as sql_text 6 | 7 | from alembic_utils.exceptions import SQLParseFailure 8 | from alembic_utils.replaceable_entity import ReplaceableEntity 9 | from alembic_utils.statement import ( 10 | 
escape_colon_for_plpgsql, 11 | escape_colon_for_sql, 12 | normalize_whitespace, 13 | strip_terminating_semicolon, 14 | ) 15 | 16 | 17 | class PGFunction(ReplaceableEntity): 18 | """A PostgreSQL Function compatible with `alembic revision --autogenerate` 19 | 20 | **Parameters:** 21 | 22 | * **schema** - *str*: A SQL schema name 23 | * **signature** - *str*: A SQL function's call signature 24 | * **definition** - *str*: The remainig function body and identifiers 25 | """ 26 | 27 | type_ = "function" 28 | 29 | def __init__(self, schema: str, signature: str, definition: str): 30 | super().__init__(schema, signature, definition) 31 | # Detect if function uses plpgsql and update escaping rules to not escape ":=" 32 | is_plpgsql: bool = "language plpgsql" in normalize_whitespace(definition).lower().replace( 33 | "'", "" 34 | ) 35 | escaping_callable = escape_colon_for_plpgsql if is_plpgsql else escape_colon_for_sql 36 | # Override definition with correct escaping rules 37 | self.definition: str = escaping_callable(strip_terminating_semicolon(definition)) 38 | 39 | @classmethod 40 | def from_sql(cls, sql: str) -> "PGFunction": 41 | """Create an instance instance from a SQL string""" 42 | template = "create{}function{:s}{schema}.{signature}{:s}returns{:s}{definition}" 43 | result = parse(template, sql.strip(), case_sensitive=False) 44 | if result is not None: 45 | # remove possible quotes from signature 46 | raw_signature = result["signature"] 47 | signature = ( 48 | "".join(raw_signature.split('"', 2)) 49 | if raw_signature.startswith('"') 50 | else raw_signature 51 | ) 52 | return cls( 53 | schema=result["schema"], 54 | signature=signature, 55 | definition="returns " + result["definition"], 56 | ) 57 | raise SQLParseFailure(f'Failed to parse SQL into PGFunction """{sql}"""') 58 | 59 | @property 60 | def literal_signature(self) -> str: 61 | """Adds quoting around the functions name when emitting SQL statements 62 | 63 | e.g. 
64 | 'toUpper(text) returns text' -> '"toUpper"(text) returns text' 65 | """ 66 | # May already be quoted if loading from database or SQL file 67 | name, remainder = self.signature.split("(", 1) 68 | return '"' + name.strip() + '"(' + remainder 69 | 70 | def to_sql_statement_create(self): 71 | """Generates a SQL "create function" statement for PGFunction""" 72 | return sql_text( 73 | f"CREATE FUNCTION {self.literal_schema}.{self.literal_signature} {self.definition}" 74 | ) 75 | 76 | def to_sql_statement_drop(self, cascade=False): 77 | """Generates a SQL "drop function" statement for PGFunction""" 78 | cascade = "cascade" if cascade else "" 79 | template = "{function_name}({parameters})" 80 | result = parse(template, self.signature, case_sensitive=False) 81 | try: 82 | function_name = result["function_name"].strip() 83 | parameters_str = result["parameters"].strip() 84 | except TypeError: 85 | # Did not match, NoneType is not scriptable 86 | result = parse("{function_name}()", self.signature, case_sensitive=False) 87 | function_name = result["function_name"].strip() 88 | parameters_str = "" 89 | 90 | # NOTE: Will fail if a text field has a default and that deafult contains a comma... 
91 | parameters: List[str] = parameters_str.split(",") 92 | parameters = [x[: len(x.lower().split("default")[0])] for x in parameters] 93 | parameters = [x.strip() for x in parameters] 94 | drop_params = ", ".join(parameters) 95 | return sql_text( 96 | f'DROP FUNCTION {self.literal_schema}."{function_name}"({drop_params}) {cascade}' 97 | ) 98 | 99 | def to_sql_statement_create_or_replace(self): 100 | """Generates a SQL "create or replace function" statement for PGFunction""" 101 | yield sql_text( 102 | f"CREATE OR REPLACE FUNCTION {self.literal_schema}.{self.literal_signature} {self.definition}" 103 | ) 104 | 105 | @classmethod 106 | def from_database(cls, sess, schema): 107 | """Get a list of all functions defined in the db""" 108 | 109 | # Prior to postgres 11, pg_proc had different columns 110 | # https://github.com/olirice/alembic_utils/issues/12 111 | PG_GTE_11 = """ 112 | and p.prokind = 'f' 113 | """ 114 | 115 | PG_LT_11 = """ 116 | and not p.proisagg 117 | and not p.proiswindow 118 | """ 119 | 120 | # Retrieve the postgres server version e.g. 
90603 for 9.6.3 or 120003 for 12.3 121 | pg_version_str = sess.execute(sql_text("show server_version_num")).fetchone()[0] 122 | pg_version = int(pg_version_str) 123 | 124 | sql = sql_text( 125 | f""" 126 | with extension_functions as ( 127 | select 128 | objid as extension_function_oid 129 | from 130 | pg_depend 131 | where 132 | -- depends on an extension 133 | deptype='e' 134 | -- is a proc/function 135 | and classid = 'pg_proc'::regclass 136 | ) 137 | 138 | select 139 | n.nspname as function_schema, 140 | p.proname as function_name, 141 | pg_get_function_arguments(p.oid) as function_arguments, 142 | case 143 | when l.lanname = 'internal' then p.prosrc 144 | else pg_get_functiondef(p.oid) 145 | end as create_statement, 146 | t.typname as return_type, 147 | l.lanname as function_language 148 | from 149 | pg_proc p 150 | left join pg_namespace n on p.pronamespace = n.oid 151 | left join pg_language l on p.prolang = l.oid 152 | left join pg_type t on t.oid = p.prorettype 153 | left join extension_functions ef on p.oid = ef.extension_function_oid 154 | where 155 | n.nspname not in ('pg_catalog', 'information_schema') 156 | -- Filter out functions from extensions 157 | and ef.extension_function_oid is null 158 | and n.nspname = :schema 159 | """ 160 | + (PG_GTE_11 if pg_version >= 110000 else PG_LT_11) 161 | ) 162 | 163 | rows = sess.execute(sql, {"schema": schema}).fetchall() 164 | db_functions = [cls.from_sql(x[3]) for x in rows] 165 | 166 | for func in db_functions: 167 | assert func is not None 168 | 169 | return db_functions 170 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_grant_table.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import Generator, List, Optional, Union 4 | 5 | from flupy import flu 6 | from sqlalchemy import text as sql_text 7 | from sqlalchemy.orm import Session 8 | 
from sqlalchemy.sql.elements import TextClause 9 | 10 | from alembic_utils.exceptions import BadInputException 11 | from alembic_utils.replaceable_entity import ReplaceableEntity 12 | from alembic_utils.statement import coerce_to_quoted, coerce_to_unquoted 13 | 14 | 15 | class PGGrantTableChoice(str, Enum): 16 | # Applies at column level 17 | SELECT = "SELECT" 18 | INSERT = "INSERT" 19 | UPDATE = "UPDATE" 20 | REFERENCES = "REFERENCES" 21 | # Applies at table Level 22 | DELETE = "DELETE" 23 | TRUNCATE = "TRUNCATE" 24 | TRIGGER = "TRIGGER" 25 | 26 | def __str__(self) -> str: 27 | return str.__str__(self) 28 | 29 | def __repr__(self) -> str: 30 | return f"'{str.__str__(self)}'" 31 | 32 | 33 | C = PGGrantTableChoice 34 | 35 | 36 | @dataclass(frozen=True, eq=True, order=True) 37 | class SchemaTableRole: 38 | schema: str 39 | table: str 40 | role: str 41 | grant: PGGrantTableChoice 42 | with_grant_option: str # 'YES' or 'NO' 43 | 44 | 45 | @dataclass 46 | class PGGrantTable(ReplaceableEntity): 47 | """A PostgreSQL Grant Statement compatible with `alembic revision --autogenerate` 48 | 49 | PGGrantTable requires the role name being used to generate migrations to 50 | match the role name that executes migrations. 
51 | 52 | If your system does not meet that requirement, disable them by excluding PGGrantTable 53 | in `include_object` https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.environment.EnvironmentContext.configure.params.include_object 54 | 55 | **Parameters:** 56 | 57 | * **schema** - *str*: A SQL schema name 58 | * **table** - *str*: The table to grant access to 59 | * **columns** - *List[str]*: A list of column names on *table* to grant access to 60 | * **role** - *str*: The role to grant access to 61 | * **grant** - *Union[Grant, str]*: On of SELECT, INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER 62 | * **with_grant_option** - *bool*: Can the role grant access to other roles 63 | """ 64 | 65 | schema: str 66 | table: str 67 | columns: List[str] 68 | role: str 69 | grant: PGGrantTableChoice 70 | with_grant_option: bool 71 | 72 | type_ = "grant_table" 73 | 74 | def __init__( 75 | self, 76 | schema: str, 77 | table: str, 78 | role: str, 79 | grant: Union[PGGrantTableChoice, str], 80 | columns: Optional[List[str]] = None, 81 | with_grant_option=False, 82 | ): 83 | self.schema: str = coerce_to_unquoted(schema) 84 | self.table: str = coerce_to_unquoted(table) 85 | self.columns: List[str] = sorted(columns) if columns else [] 86 | self.role: str = coerce_to_unquoted(role) 87 | self.grant: PGGrantTableChoice = PGGrantTableChoice(grant) 88 | self.with_grant_option: bool = with_grant_option 89 | self.signature = self.identity 90 | 91 | if PGGrantTableChoice(self.grant) in {C.SELECT, C.INSERT, C.UPDATE, C.REFERENCES}: 92 | if len(self.columns) == 0: 93 | raise BadInputException( 94 | f"When grant type is {self.grant} a value must be provided for columns" 95 | ) 96 | else: 97 | if self.columns: 98 | raise BadInputException( 99 | f"When grant type is {self.grant} a value must not be provided for columns" 100 | ) 101 | 102 | @classmethod 103 | def from_sql(cls, sql: str) -> "PGGrantTable": 104 | raise NotImplementedError() 105 | 106 | @property 
107 | def identity(self) -> str: 108 | """A string that consistently and globally identifies a function""" 109 | # rows in information_schema.role_column_grants are uniquely identified by 110 | # the columns listed below + the grantor 111 | # be cautious when editing 112 | return f"{self.__class__.__name__}: {self.schema}.{self.table}.{self.role}.{self.grant}" 113 | 114 | @property 115 | def definition(self) -> str: # type: ignore 116 | return str(self) 117 | 118 | def to_variable_name(self) -> str: 119 | """A deterministic variable name based on PGFunction's contents""" 120 | schema_name = self.schema.lower() 121 | table_name = self.table.lower() 122 | role_name = self.role.lower() 123 | return f"{schema_name}_{table_name}_{role_name}_{str(self.grant)}".lower() 124 | 125 | def render_self_for_migration(self, omit_definition=False) -> str: 126 | """Render a string that is valid python code to reconstruct self in a migration""" 127 | var_name = self.to_variable_name() 128 | class_name = self.__class__.__name__ 129 | 130 | return f"""{var_name} = {self}\n""" 131 | 132 | @classmethod 133 | def from_database(cls, sess: Session, schema: str = "%"): 134 | # COLUMN LEVEL 135 | sql = sql_text( 136 | """ 137 | SELECT 138 | table_schema as schema, 139 | table_name, 140 | grantee as role_name, 141 | privilege_type as grant_option, 142 | is_grantable, 143 | column_name 144 | FROM 145 | information_schema.role_column_grants rcg 146 | -- Cant revoke from superusers so filter out those recs 147 | join pg_roles pr 148 | on rcg.grantee = pr.rolname 149 | WHERE 150 | not pr.rolsuper 151 | and grantor = CURRENT_USER 152 | and table_schema like :schema 153 | and privilege_type in ('SELECT', 'INSERT', 'UPDATE', 'REFERENCES') 154 | """ 155 | ) 156 | 157 | rows = sess.execute(sql, params={"schema": schema}).fetchall() 158 | grants = [] 159 | 160 | grouped = ( 161 | flu(rows) 162 | .group_by(lambda x: SchemaTableRole(*x[:5])) 163 | .map(lambda x: (x[0], x[1].map_item(5).collect())) 164 | 
.collect() 165 | ) 166 | for s_t_r, columns in grouped: 167 | grant = cls( 168 | schema=s_t_r.schema, 169 | table=s_t_r.table, 170 | role=s_t_r.role, 171 | grant=s_t_r.grant, 172 | with_grant_option=s_t_r.with_grant_option == "YES", 173 | columns=columns, 174 | ) 175 | grants.append(grant) 176 | 177 | # TABLE LEVEL 178 | sql = sql_text( 179 | """ 180 | SELECT 181 | table_schema as schema_name, 182 | table_name, 183 | grantee as role_name, 184 | privilege_type as grant_option, 185 | is_grantable 186 | FROM 187 | information_schema.role_table_grants rcg 188 | -- Cant revoke from superusers so filter out those recs 189 | join pg_roles pr 190 | on rcg.grantee = pr.rolname 191 | WHERE 192 | not pr.rolsuper 193 | and grantor = CURRENT_USER 194 | and table_schema like :schema 195 | and privilege_type in ('DELETE', 'TRUNCATE', 'TRIGGER') 196 | """ 197 | ) 198 | 199 | rows = sess.execute(sql, params={"schema": schema}).fetchall() 200 | 201 | for schema_name, table_name, role_name, grant_option, is_grantable in rows: 202 | grant = cls( 203 | schema=schema_name, 204 | table=table_name, 205 | role=role_name, 206 | grant=grant_option, 207 | with_grant_option=is_grantable == "YES", 208 | ) 209 | grants.append(grant) 210 | return grants 211 | 212 | def to_sql_statement_create(self) -> TextClause: 213 | """Generates a SQL "GRANT" statement (optionally column-scoped, optionally WITH GRANT OPTION)""" 214 | with_grant_option = " WITH GRANT OPTION" if self.with_grant_option else "" 215 | maybe_columns_clause = f'( {", ".join(self.columns)} )' if self.columns else "" 216 | return sql_text( 217 | f"GRANT {self.grant} {maybe_columns_clause} ON {self.literal_schema}.{coerce_to_quoted(self.table)} TO {coerce_to_quoted(self.role)} {with_grant_option}" 218 | ) 219 | 220 | def to_sql_statement_drop(self, cascade=False) -> TextClause: 221 | """Generates a SQL "REVOKE" statement for this grant's privilege on the table""" 222 | # cascade has no impact 223 | return sql_text( 224 | f"REVOKE {self.grant} ON {self.literal_schema}.{coerce_to_quoted(self.table)} FROM
{coerce_to_quoted(self.role)}" 225 | ) 226 | 227 | def to_sql_statement_create_or_replace(self) -> Generator[TextClause, None, None]: 228 | yield self.to_sql_statement_drop() 229 | yield self.to_sql_statement_create() 230 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_materialized_view.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | 3 | 4 | from typing import Generator 5 | 6 | from parse import parse 7 | from sqlalchemy import text as sql_text 8 | from sqlalchemy.sql.elements import TextClause 9 | 10 | from alembic_utils.exceptions import SQLParseFailure 11 | from alembic_utils.replaceable_entity import ReplaceableEntity 12 | from alembic_utils.statement import ( 13 | coerce_to_unquoted, 14 | normalize_whitespace, 15 | strip_terminating_semicolon, 16 | ) 17 | 18 | 19 | class PGMaterializedView(ReplaceableEntity): 20 | """A PostgreSQL Materialized View compatible with `alembic revision --autogenerate` 21 | 22 | Limitations: 23 | Materialized views may not have other views or materialized views that depend on them. 
24 | 25 | **Parameters:** 26 | 27 | * **schema** - *str*: A SQL schema name 28 | * **signature** - *str*: A SQL view's call signature 29 | * **definition** - *str*: The SQL select statement body of the view 30 | * **with_data** - *bool*: Should create and replace statements populate data 31 | """ 32 | 33 | type_ = "materialized_view" 34 | 35 | def __init__(self, schema: str, signature: str, definition: str, with_data: bool = True): 36 | self.schema: str = coerce_to_unquoted(normalize_whitespace(schema)) 37 | self.signature: str = coerce_to_unquoted(normalize_whitespace(signature)) 38 | self.definition: str = strip_terminating_semicolon(definition) 39 | self.with_data = with_data 40 | 41 | @classmethod 42 | def from_sql(cls, sql: str) -> "PGMaterializedView": 43 | """Create an instance from a SQL string""" 44 | 45 | # Strip optional semicolon and all whitespace from end of definition 46 | # because the "with data" clause is optional and the alternative would be to enumerate 47 | # every possibility in the templates 48 | sql = strip_terminating_semicolon(sql) 49 | 50 | templates = [ 51 | # Enumerate maybe semicolon endings 52 | "create{}materialized{}view{:s}{schema}.{signature}{:s}as{:s}{definition}{:s}with{:s}data", 53 | "create{}materialized{}view{:s}{schema}.{signature}{:s}as{:s}{definition}{}with{:s}{no_data}{:s}data", 54 | "create{}materialized{}view{:s}{schema}.{signature}{:s}as{:s}{definition}", 55 | ] 56 | 57 | for template in templates: 58 | result = parse(template, sql, case_sensitive=False) 59 | 60 | if result is not None: 61 | with_data = not "no_data" in result 62 | 63 | # If the signature includes column e.g. 
my_view (col1, col2, col3) remove them 64 | signature = result["signature"].split("(")[0] 65 | 66 | return cls( 67 | schema=result["schema"], 68 | # strip quote characters 69 | signature=signature.replace('"', ""), 70 | definition=result["definition"], 71 | with_data=with_data, 72 | ) 73 | 74 | raise SQLParseFailure(f'Failed to parse SQL into PGView """{sql}"""') 75 | 76 | def to_sql_statement_create(self) -> TextClause: 77 | """Generates a SQL "create view" statement""" 78 | 79 | # Remove possible semicolon from definition because we're adding a "WITH DATA" clause 80 | definition = self.definition.rstrip().rstrip(";") 81 | 82 | return sql_text( 83 | f'CREATE MATERIALIZED VIEW {self.literal_schema}."{self.signature}" AS {definition} WITH {"NO" if not self.with_data else ""} DATA;' 84 | ) 85 | 86 | def to_sql_statement_drop(self, cascade=False) -> TextClause: 87 | """Generates a SQL "drop view" statement""" 88 | cascade = "cascade" if cascade else "" 89 | return sql_text( 90 | f'DROP MATERIALIZED VIEW {self.literal_schema}."{self.signature}" {cascade}' 91 | ) 92 | 93 | def to_sql_statement_create_or_replace(self) -> Generator[TextClause, None, None]: 94 | """Generates a SQL "create or replace view" statement""" 95 | # Remove possible semicolon from definition because we're adding a "WITH DATA" clause 96 | definition = self.definition.rstrip().rstrip(";") 97 | 98 | yield sql_text( 99 | f"""DROP MATERIALIZED VIEW IF EXISTS {self.literal_schema}."{self.signature}"; """ 100 | ) 101 | yield sql_text( 102 | f"""CREATE MATERIALIZED VIEW {self.literal_schema}."{self.signature}" AS {definition} WITH {"NO" if not self.with_data else ""} DATA""" 103 | ) 104 | 105 | def render_self_for_migration(self, omit_definition=False) -> str: 106 | """Render a string that is valid python code to reconstruct self in a migration""" 107 | var_name = self.to_variable_name() 108 | class_name = self.__class__.__name__ 109 | escaped_definition = self.definition if not omit_definition else "# not 
required for op" 110 | 111 | return f"""{var_name} = {class_name}( 112 | schema="{self.schema}", 113 | signature="{self.signature}", 114 | definition={repr(escaped_definition)}, 115 | with_data={repr(self.with_data)} 116 | )\n\n""" 117 | 118 | @classmethod 119 | def from_database(cls, sess, schema): 120 | """Get a list of all functions defined in the db""" 121 | sql = sql_text( 122 | f""" 123 | select 124 | schemaname schema_name, 125 | matviewname view_name, 126 | definition, 127 | ispopulated is_populated 128 | from 129 | pg_matviews 130 | where 131 | schemaname not in ('pg_catalog', 'information_schema') 132 | and schemaname::text like '{schema}'; 133 | """ 134 | ) 135 | rows = sess.execute(sql).fetchall() 136 | db_views = [cls(x[0], x[1], x[2], with_data=x[3]) for x in rows] 137 | 138 | for view in db_views: 139 | assert view is not None 140 | 141 | return db_views 142 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_policy.py: -------------------------------------------------------------------------------- 1 | from parse import parse 2 | from sqlalchemy import text as sql_text 3 | 4 | from alembic_utils.exceptions import SQLParseFailure 5 | from alembic_utils.on_entity_mixin import OnEntityMixin 6 | from alembic_utils.replaceable_entity import ReplaceableEntity 7 | from alembic_utils.statement import coerce_to_quoted 8 | 9 | 10 | class PGPolicy(OnEntityMixin, ReplaceableEntity): 11 | """A PostgreSQL Policy compatible with `alembic revision --autogenerate` 12 | 13 | **Parameters:** 14 | 15 | * **schema** - *str*: A SQL schema name 16 | * **signature** - *str*: A SQL policy name and tablename, separated by "." 17 | * **definition** - *str*: The definition of the policy, incl. 
permissive, for, to, using, with check 18 | * **on_entity** - *str*: fully qualifed entity that the policy applies 19 | """ 20 | 21 | type_ = "policy" 22 | 23 | @classmethod 24 | def from_sql(cls, sql: str) -> "PGPolicy": 25 | """Create an instance instance from a SQL string""" 26 | 27 | template = "create policy{:s}{signature}{:s}on{:s}{on_entity}{:s}{definition}" 28 | result = parse(template, sql.strip(), case_sensitive=False) 29 | 30 | if result is not None: 31 | 32 | on_entity = result["on_entity"] 33 | if "." not in on_entity: 34 | schema = "public" 35 | on_entity = schema + "." + on_entity 36 | schema, _, _ = on_entity.partition(".") 37 | 38 | return cls( # type: ignore 39 | schema=schema, 40 | signature=result["signature"], 41 | definition=result["definition"], 42 | on_entity=on_entity, 43 | ) 44 | raise SQLParseFailure(f'Failed to parse SQL into PGPolicy """{sql}"""') 45 | 46 | def to_sql_statement_create(self): 47 | """Generates a SQL "create poicy" statement for PGPolicy""" 48 | 49 | return sql_text(f"CREATE POLICY {self.signature} on {self.on_entity} {self.definition}") 50 | 51 | def to_sql_statement_drop(self, cascade=False): 52 | """Generates a SQL "drop policy" statement for PGPolicy""" 53 | cascade = "cascade" if cascade else "" 54 | return sql_text(f"DROP POLICY {self.signature} on {self.on_entity} {cascade}") 55 | 56 | def to_sql_statement_create_or_replace(self): 57 | """Not implemented, postgres policies do not support replace.""" 58 | yield sql_text(f"DROP POLICY IF EXISTS {self.signature} on {self.on_entity};") 59 | yield sql_text(f"CREATE POLICY {self.signature} on {self.on_entity} {self.definition};") 60 | 61 | @classmethod 62 | def from_database(cls, connection, schema): 63 | """Get a list of all policies defined in the db""" 64 | sql = sql_text( 65 | f""" 66 | select 67 | schemaname, 68 | tablename, 69 | policyname, 70 | permissive, 71 | roles, 72 | cmd, 73 | qual, 74 | with_check 75 | from 76 | pg_policies 77 | where 78 | schemaname = 
'{schema}' 79 | """ 80 | ) 81 | rows = connection.execute(sql).fetchall() 82 | 83 | def get_definition(permissive, roles, cmd, qual, with_check): 84 | definition = "" 85 | if permissive is not None: 86 | definition += f"as {permissive} " 87 | if cmd is not None: 88 | definition += f"for {cmd} " 89 | if roles is not None: 90 | definition += f"to {', '.join(roles)} " 91 | if qual is not None: 92 | if qual[0] != "(": 93 | qual = f"({qual})" 94 | definition += f"using {qual} " 95 | if with_check is not None: 96 | if with_check[0] != "(": 97 | with_check = f"({with_check})" 98 | definition += f"with check {with_check} " 99 | return definition 100 | 101 | db_policies = [] 102 | for schema, table, policy_name, permissive, roles, cmd, qual, with_check in rows: 103 | definition = get_definition(permissive, roles, cmd, qual, with_check) 104 | 105 | schema = coerce_to_quoted(schema) 106 | table = coerce_to_quoted(table) 107 | policy_name = coerce_to_quoted(policy_name) 108 | policy = cls.from_sql(f"create policy {policy_name} on {schema}.{table} {definition}") 109 | db_policies.append(policy) 110 | 111 | for policy in db_policies: 112 | assert policy is not None 113 | 114 | return db_policies 115 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_trigger.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | 3 | from parse import parse 4 | from sqlalchemy import text as sql_text 5 | 6 | from alembic_utils.exceptions import SQLParseFailure 7 | from alembic_utils.on_entity_mixin import OnEntityMixin 8 | from alembic_utils.replaceable_entity import ReplaceableEntity 9 | 10 | 11 | class PGTrigger(OnEntityMixin, ReplaceableEntity): 12 | """A PostgreSQL Trigger compatible with `alembic revision --autogenerate` 13 | 14 | **Parameters:** 15 | 16 | * **schema** - *str*: A SQL schema name 17 | * **signature** - *str*: A SQL trigger's 
call signature 18 | * **definition** - *str*: The remainig trigger body and identifiers 19 | * **on_entity** - *str*: fully qualifed entity that the policy applies 20 | * **is_constraint** - *bool*: Is the trigger a constraint trigger 21 | 22 | Postgres Create Trigger Specification: 23 | 24 | CREATE [ CONSTRAINT ] TRIGGER name { BEFORE | AFTER | INSTEAD OF } { event [ OR ... ] } 25 | ON table 26 | [ FROM referenced_table_name ] 27 | [ NOT DEFERRABLE | [ DEFERRABLE ] { INITIALLY IMMEDIATE | INITIALLY DEFERRED } ] 28 | [ FOR [ EACH ] { ROW | STATEMENT } ] 29 | [ WHEN ( condition ) ] 30 | EXECUTE PROCEDURE function_name ( arguments ) 31 | """ 32 | 33 | type_ = "trigger" 34 | 35 | _templates = [ 36 | "create{:s}constraint{:s}trigger{:s}{signature}{:s}{event}{:s}ON{:s}{on_entity}{:s}{action}", 37 | "create{:s}trigger{:s}{signature}{:s}{event}{:s}ON{:s}{on_entity}{:s}{action}", 38 | ] 39 | 40 | def __init__( 41 | self, 42 | schema: str, 43 | signature: str, 44 | definition: str, 45 | on_entity: str, 46 | is_constraint: bool = False, 47 | ): 48 | super().__init__( 49 | schema=schema, signature=signature, definition=definition, on_entity=on_entity # type: ignore 50 | ) 51 | self.is_constraint = is_constraint 52 | 53 | def render_self_for_migration(self, omit_definition=False) -> str: 54 | """Render a string that is valid python code to reconstruct self in a migration""" 55 | var_name = self.to_variable_name() 56 | class_name = self.__class__.__name__ 57 | escaped_definition = self.definition if not omit_definition else "# not required for op" 58 | 59 | return f"""{var_name} = {class_name}( 60 | schema="{self.schema}", 61 | signature="{self.signature}", 62 | on_entity="{self.on_entity}", 63 | is_constraint={self.is_constraint}, 64 | definition={repr(escaped_definition)} 65 | )\n""" 66 | 67 | @property 68 | def identity(self) -> str: 69 | """A string that consistently and globally identifies a trigger""" 70 | return f"{self.__class__.__name__}: {self.schema}.{self.signature} 
{self.is_constraint} {self.on_entity}" 71 | 72 | @classmethod 73 | def from_sql(cls, sql: str) -> "PGTrigger": 74 | """Create an instance instance from a SQL string""" 75 | for template in cls._templates: 76 | result = parse(template, sql, case_sensitive=False) 77 | if result is not None: 78 | # remove possible quotes from signature 79 | signature = result["signature"] 80 | event = result["event"] 81 | on_entity = result["on_entity"] 82 | action = result["action"] 83 | is_constraint = "constraint" in template 84 | 85 | if "." not in on_entity: 86 | on_entity = "public" + "." + on_entity 87 | 88 | schema = on_entity.split(".")[0] 89 | 90 | definition_template = " {event} ON {on_entity} {action}" 91 | definition = definition_template.format( 92 | event=event, on_entity=on_entity, action=action 93 | ) 94 | 95 | return cls( 96 | schema=schema, 97 | signature=signature, 98 | on_entity=on_entity, 99 | definition=definition, 100 | is_constraint=is_constraint, 101 | ) 102 | raise SQLParseFailure(f'Failed to parse SQL into PGTrigger """{sql}"""') 103 | 104 | def to_sql_statement_create(self): 105 | """Generates a SQL "create trigger" statement for PGTrigger""" 106 | 107 | # We need to parse and replace the schema qualifier on the table for simulate_entity to 108 | # operate 109 | _def = self.definition 110 | _template = "{event}{:s}ON{:s}{on_entity}{:s}{action}" 111 | match = parse(_template, _def) 112 | if not match: 113 | raise SQLParseFailure(f'Failed to parse SQL into PGTrigger.definition """{_def}"""') 114 | 115 | event = match["event"] 116 | action = match["action"] 117 | 118 | # Ensure entity is qualified with schema 119 | on_entity = match["on_entity"] 120 | if "." 
in on_entity: 121 | _, _, on_entity = on_entity.partition(".") 122 | on_entity = f"{self.schema}.{on_entity}" 123 | 124 | # Re-render the definition ensuring the table is qualified with 125 | def_rendered = _template.replace("{:s}", " ").format( 126 | event=event, on_entity=on_entity, action=action 127 | ) 128 | 129 | return sql_text( 130 | f"CREATE{' CONSTRAINT ' if self.is_constraint else ' '}TRIGGER \"{self.signature}\" {def_rendered}" 131 | ) 132 | 133 | def to_sql_statement_drop(self, cascade=False): 134 | """Generates a SQL "drop trigger" statement for PGTrigger""" 135 | cascade = "cascade" if cascade else "" 136 | return sql_text(f'DROP TRIGGER "{self.signature}" ON {self.on_entity} {cascade}') 137 | 138 | def to_sql_statement_create_or_replace(self): 139 | """Generates a SQL "replace trigger" statement for PGTrigger""" 140 | yield sql_text(f'DROP TRIGGER IF EXISTS "{self.signature}" ON {self.on_entity};') 141 | yield self.to_sql_statement_create() 142 | 143 | @classmethod 144 | def from_database(cls, sess, schema): 145 | """Get a list of all triggers defined in the db""" 146 | 147 | sql = sql_text( 148 | """ 149 | select 150 | pc.relnamespace::regnamespace::text as table_schema, 151 | tgname trigger_name, 152 | pg_get_triggerdef(pgt.oid) definition 153 | from 154 | pg_trigger pgt 155 | inner join pg_class pc 156 | on pgt.tgrelid = pc.oid 157 | where 158 | not tgisinternal 159 | and pc.relnamespace::regnamespace::text like :schema 160 | """ 161 | ) 162 | rows = sess.execute(sql, {"schema": schema}).fetchall() 163 | 164 | db_triggers = [cls.from_sql(x[2]) for x in rows] 165 | 166 | for trig in db_triggers: 167 | assert trig is not None 168 | 169 | return db_triggers 170 | -------------------------------------------------------------------------------- /src/alembic_utils/pg_view.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | 3 | 4 | from typing import 
Generator 5 | 6 | from parse import parse 7 | from sqlalchemy import text as sql_text 8 | from sqlalchemy.sql.elements import TextClause 9 | 10 | from alembic_utils.exceptions import SQLParseFailure 11 | from alembic_utils.replaceable_entity import ReplaceableEntity 12 | from alembic_utils.statement import ( 13 | coerce_to_unquoted, 14 | normalize_whitespace, 15 | strip_terminating_semicolon, 16 | ) 17 | 18 | 19 | class PGView(ReplaceableEntity): 20 | """A PostgreSQL View compatible with `alembic revision --autogenerate` 21 | 22 | **Parameters:** 23 | 24 | * **schema** - *str*: A SQL schema name 25 | * **signature** - *str*: A SQL view's call signature 26 | * **definition** - *str*: The SQL select statement body of the view 27 | """ 28 | 29 | type_ = "view" 30 | 31 | def __init__(self, schema: str, signature: str, definition: str): 32 | self.schema: str = coerce_to_unquoted(normalize_whitespace(schema)) 33 | self.signature: str = coerce_to_unquoted(normalize_whitespace(signature)) 34 | self.definition: str = strip_terminating_semicolon(definition) 35 | 36 | @classmethod 37 | def from_sql(cls, sql: str) -> "PGView": 38 | """Create an instance from a SQL string""" 39 | template = "create{}view{:s}{schema}.{signature}{:s}as{:s}{definition}" 40 | result = parse(template, sql, case_sensitive=False) 41 | if result is not None: 42 | # If the signature includes column e.g. 
my_view (col1, col2, col3) remove them 43 | signature = result["signature"].split("(")[0] 44 | return cls( 45 | schema=result["schema"], 46 | # strip quote characters 47 | signature=signature.replace('"', ""), 48 | definition=strip_terminating_semicolon(result["definition"]), 49 | ) 50 | 51 | raise SQLParseFailure(f'Failed to parse SQL into PGView """{sql}"""') 52 | 53 | def to_sql_statement_create(self) -> TextClause: 54 | """Generates a SQL "create view" statement""" 55 | return sql_text( 56 | f'CREATE VIEW {self.literal_schema}."{self.signature}" AS {self.definition};' 57 | ) 58 | 59 | def to_sql_statement_drop(self, cascade=False) -> TextClause: 60 | """Generates a SQL "drop view" statement""" 61 | cascade = "cascade" if cascade else "" 62 | return sql_text(f'DROP VIEW {self.literal_schema}."{self.signature}" {cascade}') 63 | 64 | def to_sql_statement_create_or_replace(self) -> Generator[TextClause, None, None]: 65 | """Generates a SQL "create or replace view" statement 66 | 67 | If the initial "CREATE OR REPLACE" statement does not succeed, 68 | fails over onto "DROP VIEW" followed by "CREATE VIEW" 69 | """ 70 | yield sql_text( 71 | f""" 72 | do $$ 73 | begin 74 | CREATE OR REPLACE VIEW {self.literal_schema}."{self.signature}" AS {self.definition}; 75 | 76 | exception when others then 77 | DROP VIEW IF EXISTS {self.literal_schema}."{self.signature}"; 78 | 79 | CREATE VIEW {self.literal_schema}."{self.signature}" AS {self.definition}; 80 | end; 81 | $$ language 'plpgsql' 82 | """ 83 | ) 84 | 85 | @classmethod 86 | def from_database(cls, sess, schema): 87 | """Get a list of all functions defined in the db""" 88 | sql = sql_text( 89 | f""" 90 | select 91 | schemaname schema_name, 92 | viewname view_name, 93 | definition 94 | from 95 | pg_views 96 | where 97 | schemaname not in ('pg_catalog', 'information_schema') 98 | and schemaname::text like '{schema}'; 99 | """ 100 | ) 101 | rows = sess.execute(sql).fetchall() 102 | db_views = [cls(x[0], x[1], x[2]) for x in 
rows] 103 | 104 | for view in db_views: 105 | assert view is not None 106 | 107 | return db_views 108 | -------------------------------------------------------------------------------- /src/alembic_utils/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olirice/alembic_utils/ccf9be7b8097bd6df13886a40a8e22c847c10bc8/src/alembic_utils/py.typed -------------------------------------------------------------------------------- /src/alembic_utils/replaceable_entity.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | import logging 3 | from itertools import zip_longest 4 | from pathlib import Path 5 | from typing import ( 6 | Dict, 7 | Generator, 8 | Iterable, 9 | List, 10 | Optional, 11 | Set, 12 | Type, 13 | TypeVar, 14 | ) 15 | 16 | from alembic.autogenerate import comparators 17 | from alembic.autogenerate.api import AutogenContext 18 | from sqlalchemy.orm import Session 19 | from sqlalchemy.sql.elements import TextClause 20 | 21 | import alembic_utils 22 | from alembic_utils.depends import solve_resolution_order 23 | from alembic_utils.exceptions import UnreachableException 24 | from alembic_utils.experimental import collect_subclasses 25 | from alembic_utils.reversible_op import ( 26 | CreateOp, 27 | DropOp, 28 | ReplaceOp, 29 | ReversibleOp, 30 | ) 31 | from alembic_utils.simulate import simulate_entity 32 | from alembic_utils.statement import ( 33 | coerce_to_quoted, 34 | coerce_to_unquoted, 35 | escape_colon_for_sql, 36 | normalize_whitespace, 37 | strip_terminating_semicolon, 38 | ) 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | T = TypeVar("T", bound="ReplaceableEntity") 43 | 44 | 45 | class ReplaceableEntity: 46 | """A SQL Entity that can be replaced""" 47 | 48 | def __init__(self, schema: str, signature: str, definition: str): 49 | self.schema: str = 
coerce_to_unquoted(normalize_whitespace(schema)) 50 | self.signature: str = coerce_to_unquoted(normalize_whitespace(signature)) 51 | self.definition: str = escape_colon_for_sql(strip_terminating_semicolon(definition)) 52 | 53 | @property 54 | def type_(self) -> str: 55 | """In order to support calls to `run_name_filters` and 56 | `run_object_filters` on the AutogenContext object, each 57 | entity needs to have a named type. 58 | https://alembic.sqlalchemy.org/en/latest/api/autogenerate.html#alembic.autogenerate.api.AutogenContext.run_name_filters 59 | """ 60 | raise NotImplementedError() 61 | 62 | @classmethod 63 | def from_sql(cls: Type[T], sql: str) -> T: 64 | """Create an instance from a SQL string""" 65 | raise NotImplementedError() 66 | 67 | @property 68 | def literal_schema(self) -> str: 69 | """Wrap a schema name in literal quotes 70 | Useful for emitting SQL statements 71 | """ 72 | return coerce_to_quoted(self.schema) 73 | 74 | @classmethod 75 | def from_path(cls: Type[T], path: Path) -> T: 76 | """Create an instance instance from a SQL file path""" 77 | with path.open() as sql_file: 78 | sql = sql_file.read() 79 | return cls.from_sql(sql) 80 | 81 | @classmethod 82 | def from_database(cls, sess: Session, schema="%") -> List[T]: 83 | """Collect existing entities from the database for given schema""" 84 | raise NotImplementedError() 85 | 86 | def to_sql_statement_create(self) -> TextClause: 87 | """Generates a SQL "create function" statement for PGFunction""" 88 | raise NotImplementedError() 89 | 90 | def to_sql_statement_drop(self, cascade=False) -> TextClause: 91 | """Generates a SQL "drop function" statement for PGFunction""" 92 | raise NotImplementedError() 93 | 94 | def to_sql_statement_create_or_replace(self) -> Generator[TextClause, None, None]: 95 | """Generates a SQL "create or replace function" statement for PGFunction""" 96 | raise NotImplementedError() 97 | 98 | def get_database_definition( 99 | self: T, sess: Session, dependencies: 
Optional[List["ReplaceableEntity"]] = None 100 | ) -> T: # $Optional[T]: 101 | """Creates the entity in the database, retrieves its 'rendered' then rolls it back""" 102 | with simulate_entity(sess, self, dependencies) as sess: 103 | # Drop self 104 | sess.execute(self.to_sql_statement_drop()) 105 | 106 | # collect all remaining entities 107 | db_entities: List[T] = sorted( 108 | self.from_database(sess, schema=self.schema), key=lambda x: x.identity 109 | ) 110 | 111 | with simulate_entity(sess, self, dependencies) as sess: 112 | # collect all remaining entities 113 | all_w_self: List[T] = sorted( 114 | self.from_database(sess, schema=self.schema), key=lambda x: x.identity 115 | ) 116 | 117 | # Find "self" by diffing the before and after 118 | for without_self, with_self in zip_longest(db_entities, all_w_self): 119 | if without_self is None or without_self.identity != with_self.identity: 120 | return with_self 121 | 122 | raise UnreachableException() 123 | 124 | def render_self_for_migration(self, omit_definition=False) -> str: 125 | """Render a string that is valid python code to reconstruct self in a migration""" 126 | var_name = self.to_variable_name() 127 | class_name = self.__class__.__name__ 128 | escaped_definition = self.definition if not omit_definition else "# not required for op" 129 | 130 | return f"""{var_name} = {class_name}( 131 | schema="{self.schema}", 132 | signature="{self.signature}", 133 | definition={repr(escaped_definition)} 134 | )\n""" 135 | 136 | @classmethod 137 | def render_import_statement(cls) -> str: 138 | """Render a string that is valid python code to import current class""" 139 | module_path = cls.__module__ 140 | class_name = cls.__name__ 141 | return f"from {module_path} import {class_name}\nfrom sqlalchemy import text as sql_text" 142 | 143 | @property 144 | def identity(self) -> str: 145 | """A string that consistently and globally identifies a function""" 146 | return f"{self.__class__.__name__}: {self.schema}.{self.signature}" 
147 | 148 | def to_variable_name(self) -> str: 149 | """A deterministic variable name based on PGFunction's contents""" 150 | schema_name = self.schema.lower() 151 | object_name = self.signature.split("(")[0].strip().lower().replace("-", "_") 152 | return f"{schema_name}_{object_name}" 153 | 154 | _version_to_replace: Optional[T] = None # type: ignore 155 | 156 | def get_required_migration_op( 157 | self: T, sess: Session, dependencies: Optional[List["ReplaceableEntity"]] = None 158 | ) -> Optional[ReversibleOp]: 159 | """Get the migration operation required for autogenerate""" 160 | # All entities in the database for self's schema 161 | entities_in_database: List[T] = self.from_database(sess, schema=self.schema) 162 | 163 | db_def = self.get_database_definition(sess, dependencies=dependencies) 164 | 165 | for x in entities_in_database: 166 | 167 | if (db_def.identity, normalize_whitespace(db_def.definition)) == ( 168 | x.identity, 169 | normalize_whitespace(x.definition), 170 | ): 171 | return None 172 | 173 | if db_def.identity == x.identity: 174 | # Cache the currently live copy to render a RevertOp without hitting DB again 175 | self._version_to_replace = x 176 | return ReplaceOp(self) 177 | 178 | return CreateOp(self) 179 | 180 | 181 | class ReplaceableEntityRegistry: 182 | def __init__(self): 183 | self._entities: Dict[str, ReplaceableEntity] = {} 184 | self.schemas: Set[str] = set() 185 | self.exclude_schemas: Set[str] = set() 186 | self.entity_types: Set[Type[ReplaceableEntity]] = set() 187 | 188 | def clear(self): 189 | self._entities.clear() 190 | self.schemas.clear() 191 | self.exclude_schemas.clear() 192 | self.entity_types.clear() 193 | 194 | def register( 195 | self, 196 | entities: Iterable[ReplaceableEntity], 197 | schemas: Optional[List[str]] = None, 198 | exclude_schemas: Optional[Iterable[str]] = None, 199 | entity_types: Optional[Iterable[Type[ReplaceableEntity]]] = None, 200 | ) -> None: 201 | self._entities.update({e.identity: e for e in 
entities}) 202 | 203 | if schemas: 204 | self.schemas |= set(schemas) 205 | 206 | if exclude_schemas: 207 | self.exclude_schemas |= set(exclude_schemas) 208 | 209 | if entity_types: 210 | self.entity_types |= set(entity_types) 211 | 212 | @property 213 | def allowed_entity_types(self) -> Set[Type[ReplaceableEntity]]: 214 | if self.entity_types: 215 | return self.entity_types 216 | return set(collect_subclasses(alembic_utils, ReplaceableEntity)) 217 | 218 | def entities(self) -> List[ReplaceableEntity]: 219 | return list(self._entities.values()) 220 | 221 | 222 | registry = ReplaceableEntityRegistry() 223 | 224 | 225 | ################## 226 | # Event Listener # 227 | ################## 228 | 229 | 230 | def register_entities( 231 | entities: Iterable[ReplaceableEntity], 232 | schemas: Optional[List[str]] = None, 233 | exclude_schemas: Optional[Iterable[str]] = None, 234 | entity_types: Optional[Iterable[Type[ReplaceableEntity]]] = None, 235 | ) -> None: 236 | """Register entities to be monitored for changes when alembic is invoked with `revision --autogenerate`. 237 | 238 | **Parameters:** 239 | 240 | * **entities** - *List[ReplaceableEntity]*: A list of entities (PGFunction, PGView, etc) to monitor for revisions 241 | 242 | **Deprecated Parameters:** 243 | 244 | *Configure schema and object inclusion/exclusion with `include_name` and `include_object` in `env.py`. For more information see https://alembic.sqlalchemy.org/en/latest/autogenerate.html#controlling-what-to-be-autogenerated* 245 | 246 | 247 | * **schemas** - *Optional[List[str]]*: A list of SQL schema names to monitor. Note, schemas referenced in registered entities are automatically monitored. 248 | * **exclude_schemas** - *Optional[List[str]]*: A list of SQL schemas to ignore. Note, explicitly registered entities will still be monitored. 249 | * **entity_types** - *Optional[List[Type[ReplaceableEntity]]]*: A list of ReplaceableEntity classes to consider during migrations. 
Other entity types are ignored 250 | """ 251 | registry.register(entities, schemas, exclude_schemas, entity_types) 252 | 253 | 254 | @comparators.dispatch_for("schema") 255 | def compare_registered_entities( 256 | autogen_context: AutogenContext, 257 | upgrade_ops, 258 | schemas: List[Optional[str]], 259 | ): 260 | connection = autogen_context.connection 261 | 262 | entities = registry.entities() 263 | 264 | # https://alembic.sqlalchemy.org/en/latest/autogenerate.html#controlling-what-to-be-autogenerated 265 | # if EnvironmentContext.configure.include_schemas is True, all non-default "scehmas" should be included 266 | # pulled from the inspector 267 | include_schemas: bool = autogen_context.opts["include_schemas"] 268 | 269 | reflected_schemas = set( 270 | autogen_context.inspector.get_schema_names() if include_schemas else [] # type: ignore 271 | ) 272 | sqla_schemas: Set[Optional[str]] = set(schemas) 273 | manual_schemas = set(registry.schemas or set()) # Deprecated for remove in 0.6.0 274 | entity_schemas = {x.schema for x in entities} # from ReplaceableEntity instances 275 | all_schema_references = reflected_schemas | sqla_schemas | manual_schemas | entity_schemas # type: ignore 276 | 277 | # Remove excluded schemas 278 | observed_schemas: Set[str] = { 279 | schema_name 280 | for schema_name in all_schema_references 281 | if ( 282 | schema_name 283 | is not None 284 | not in ( 285 | registry.exclude_schemas or set() 286 | ) # user defined. 
Deprecated for remove in 0.6.0 287 | and schema_name not in {"information_schema", None} 288 | ) 289 | } 290 | 291 | # Solve resolution order 292 | transaction = connection.begin_nested() 293 | sess = Session(bind=connection) 294 | try: 295 | ordered_entities: List[ReplaceableEntity] = solve_resolution_order(sess, entities) 296 | finally: 297 | sess.rollback() 298 | 299 | # entities that are receiving a create or update op 300 | has_create_or_update_op: List[ReplaceableEntity] = [] 301 | 302 | # database rendered definitions for the entities we have a local instance for 303 | # Note: used for drops 304 | local_entities = [] 305 | 306 | # Required migration OPs, Create/Update/NoOp 307 | for entity in ordered_entities: 308 | logger.info( 309 | "Detecting required migration op %s %s", 310 | entity.__class__.__name__, 311 | entity.identity, 312 | ) 313 | 314 | if entity.__class__ not in registry.allowed_entity_types: 315 | continue 316 | 317 | if not include_entity(entity, autogen_context, reflected=False): 318 | logger.debug( 319 | "Ignoring local entity %s %s due to AutogenContext filters", 320 | entity.__class__.__name__, 321 | entity.identity, 322 | ) 323 | continue 324 | 325 | transaction = connection.begin_nested() 326 | sess = Session(bind=connection) 327 | try: 328 | maybe_op = entity.get_required_migration_op(sess, dependencies=has_create_or_update_op) 329 | 330 | local_db_def = entity.get_database_definition( 331 | sess, dependencies=has_create_or_update_op 332 | ) 333 | local_entities.append(local_db_def) 334 | 335 | if maybe_op: 336 | upgrade_ops.ops.append(maybe_op) 337 | has_create_or_update_op.append(entity) 338 | 339 | logger.info( 340 | "Detected %s op for %s %s", 341 | maybe_op.__class__.__name__, 342 | entity.__class__.__name__, 343 | entity.identity, 344 | ) 345 | else: 346 | logger.debug( 347 | "Detected NoOp op for %s %s", 348 | entity.__class__.__name__, 349 | entity.identity, 350 | ) 351 | 352 | finally: 353 | sess.rollback() 354 | 355 | # 
Required migration OPs, Drop 356 | # Start a parent transaction 357 | # Bind the session within the parent transaction 358 | transaction = connection.begin_nested() 359 | sess = Session(bind=connection) 360 | try: 361 | # All database entities currently live 362 | # Check if anything needs to drop 363 | subclasses = collect_subclasses(alembic_utils, ReplaceableEntity) 364 | for entity_class in subclasses: 365 | 366 | if entity_class not in registry.allowed_entity_types: 367 | continue 368 | 369 | # Entities within the schemas that are live 370 | for schema in observed_schemas: 371 | 372 | db_entities: List[ReplaceableEntity] = entity_class.from_database( 373 | sess, schema=schema 374 | ) 375 | 376 | # Check for functions that were deleted locally 377 | for db_entity in db_entities: 378 | 379 | if not include_entity(db_entity, autogen_context, reflected=True): 380 | logger.debug( 381 | "Ignoring remote entity %s %s due to AutogenContext filters", 382 | db_entity.__class__.__name__, 383 | db_entity.identity, 384 | ) 385 | continue 386 | 387 | for local_entity in local_entities: 388 | if db_entity.identity == local_entity.identity: 389 | break 390 | else: 391 | # No match was found locally 392 | # If the entity passes the filters, 393 | # we should create a DropOp 394 | upgrade_ops.ops.append(DropOp(db_entity)) 395 | logger.info( 396 | "Detected DropOp op for %s %s", 397 | db_entity.__class__.__name__, 398 | db_entity.identity, 399 | ) 400 | 401 | finally: 402 | sess.rollback() 403 | 404 | 405 | def include_entity( 406 | entity: ReplaceableEntity, autogen_context: AutogenContext, reflected: bool 407 | ) -> bool: 408 | """The functions on the AutogenContext object 409 | are described here: 410 | https://alembic.sqlalchemy.org/en/latest/api/autogenerate.html#alembic.autogenerate.api.AutogenContext.run_name_filters 411 | 412 | The meaning of the function parameters are explained in the corresponding 413 | definitions in the EnvironmentContext object: 414 | 
https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.environment.EnvironmentContext.configure.params.include_name 415 | 416 | This will only have an impact for projects which set include_object and/or include_name in the configuration 417 | of their Alembic env. 418 | """ 419 | name = f"{entity.schema}.{entity.signature}" 420 | parent_names = { 421 | "schema_name": entity.schema, 422 | # At the time of writing, the implementation of `run_name_filters` in Alembic assumes that every type of object 423 | # will either be a table or have a table_name in its `parent_names` dict. This is true for columns and indexes, 424 | # but not true for the type of objects supported in this library such as views as functions. Nevertheless, to avoid 425 | # a KeyError when calling `run_name_filters`, we have to set some value. 426 | "table_name": f"Not applicable for type {entity.type_}", 427 | } 428 | # According to the Alembic docs, the name filter is only relevant for reflected objects 429 | if reflected: 430 | name_result = autogen_context.run_name_filters(name, entity.type_, parent_names) 431 | else: 432 | name_result = True 433 | 434 | # Object filters should be applied to object from local metadata and to reflected objects 435 | object_result = autogen_context.run_object_filters( 436 | entity, name, entity.type_, reflected=reflected, compare_to=None 437 | ) 438 | return name_result and object_result 439 | -------------------------------------------------------------------------------- /src/alembic_utils/reversible_op.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Tuple, Type 2 | 3 | from alembic.autogenerate import renderers 4 | from alembic.operations import MigrateOperation, Operations 5 | from typing_extensions import Protocol 6 | 7 | from alembic_utils.exceptions import UnreachableException 8 | 9 | if TYPE_CHECKING: 10 | from alembic_utils.replaceable_entity import 
ReplaceableEntity 11 | 12 | 13 | class SupportsTarget(Protocol): 14 | def __init__(self, target: "ReplaceableEntity") -> None: 15 | pass 16 | 17 | 18 | class SupportsTargetCascade(Protocol): 19 | def __init__(self, target: "ReplaceableEntity", cascade: bool) -> None: 20 | pass 21 | 22 | 23 | class ReversibleOp(MigrateOperation): 24 | """A SQL operation that can be reversed""" 25 | 26 | def __init__(self, target: "ReplaceableEntity"): 27 | self.target = target 28 | 29 | @classmethod 30 | def invoke_for_target(cls: Type[SupportsTarget], operations, target: "ReplaceableEntity"): 31 | op = cls(target) 32 | return operations.invoke(op) 33 | 34 | @classmethod 35 | def invoke_for_target_optional_cascade( 36 | cls: Type[SupportsTargetCascade], operations, target: "ReplaceableEntity", cascade=False 37 | ): 38 | op = cls(target, cascade=cascade) # pylint: disable=unexpected-keyword-arg 39 | return operations.invoke(op) 40 | 41 | def reverse(self): 42 | raise NotImplementedError() 43 | 44 | 45 | ############## 46 | # Operations # 47 | ############## 48 | 49 | 50 | @Operations.register_operation("create_entity", "invoke_for_target") 51 | class CreateOp(ReversibleOp): 52 | def reverse(self): 53 | return DropOp(self.target) 54 | 55 | def to_diff_tuple(self) -> Tuple[Any, ...]: 56 | return "create_entity", self.target.identity, str(self.target.to_sql_statement_create()) 57 | 58 | 59 | @Operations.register_operation("drop_entity", "invoke_for_target_optional_cascade") 60 | class DropOp(ReversibleOp): 61 | def __init__(self, target: "ReplaceableEntity", cascade: bool = False) -> None: 62 | self.cascade = cascade 63 | super().__init__(target) 64 | 65 | def reverse(self): 66 | return CreateOp(self.target) 67 | 68 | def to_diff_tuple(self) -> Tuple[Any, ...]: 69 | return "drop_entity", self.target.identity 70 | 71 | 72 | @Operations.register_operation("replace_entity", "invoke_for_target") 73 | class ReplaceOp(ReversibleOp): 74 | def reverse(self): 75 | return RevertOp(self.target) 76 
| 77 | def to_diff_tuple(self) -> Tuple[Any, ...]: 78 | return ( 79 | "replace_or_revert_entity", 80 | self.target.identity, 81 | str([str(x) for x in self.target.to_sql_statement_create_or_replace()]), 82 | ) 83 | 84 | 85 | class RevertOp(ReversibleOp): 86 | # Revert is never in an upgrade, so no need to implement reverse 87 | 88 | def to_diff_tuple(self) -> Tuple[Any, ...]: 89 | return ( 90 | "replace_or_revert_entity", 91 | self.target.identity, 92 | str([str(x) for x in self.target.to_sql_statement_create_or_replace()]), 93 | ) 94 | 95 | 96 | ################### 97 | # Implementations # 98 | ################### 99 | 100 | 101 | @Operations.implementation_for(CreateOp) 102 | def create_entity(operations, operation): 103 | target: "ReplaceableEntity" = operation.target 104 | operations.execute(target.to_sql_statement_create()) 105 | 106 | 107 | @Operations.implementation_for(DropOp) 108 | def drop_entity(operations, operation): 109 | target: "ReplaceableEntity" = operation.target 110 | operations.execute(target.to_sql_statement_drop(cascade=operation.cascade)) 111 | 112 | 113 | @Operations.implementation_for(ReplaceOp) 114 | @Operations.implementation_for(RevertOp) 115 | def replace_or_revert_entity(operations, operation): 116 | target: "ReplaceableEntity" = operation.target 117 | for stmt in target.to_sql_statement_create_or_replace(): 118 | operations.execute(stmt) 119 | 120 | 121 | ########## 122 | # Render # 123 | ########## 124 | 125 | 126 | @renderers.dispatch_for(CreateOp) 127 | def render_create_entity(autogen_context, op): 128 | target = op.target 129 | autogen_context.imports.add(target.render_import_statement()) 130 | variable_name = target.to_variable_name() 131 | return target.render_self_for_migration() + f"op.create_entity({variable_name})\n" 132 | 133 | 134 | @renderers.dispatch_for(DropOp) 135 | def render_drop_entity(autogen_context, op): 136 | target = op.target 137 | autogen_context.imports.add(target.render_import_statement()) 138 | 
variable_name = target.to_variable_name() 139 | return ( 140 | target.render_self_for_migration(omit_definition=False) 141 | + f"op.drop_entity({variable_name})\n" 142 | ) 143 | 144 | 145 | @renderers.dispatch_for(ReplaceOp) 146 | def render_replace_entity(autogen_context, op): 147 | target = op.target 148 | autogen_context.imports.add(target.render_import_statement()) 149 | variable_name = target.to_variable_name() 150 | return target.render_self_for_migration() + f"op.replace_entity({variable_name})\n" 151 | 152 | 153 | @renderers.dispatch_for(RevertOp) 154 | def render_revert_entity(autogen_context, op): 155 | """Collect the entity definition currently live in the database and use its definition 156 | as the downgrade revert target""" 157 | # At the time is call is made, the engine is disconnected 158 | 159 | # We should never reach this call unless an update's revert is being rendered 160 | # In that case, get_required_migration_op has cached the database's liver version 161 | # as target._version_to_replace 162 | 163 | target = op.target 164 | autogen_context.imports.add(target.render_import_statement()) 165 | 166 | db_target = target._version_to_replace 167 | 168 | if db_target is None: 169 | raise UnreachableException 170 | 171 | variable_name = db_target.to_variable_name() 172 | return db_target.render_self_for_migration() + f"op.replace_entity({variable_name})" 173 | -------------------------------------------------------------------------------- /src/alembic_utils/simulate.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument,invalid-name,line-too-long 2 | import copy 3 | import logging 4 | from contextlib import ExitStack, contextmanager 5 | from typing import TYPE_CHECKING, List, Optional 6 | 7 | from sqlalchemy.orm import Session 8 | 9 | if TYPE_CHECKING: 10 | from alembic_utils.replaceable_entity import ReplaceableEntity 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 
16 | @contextmanager 17 | def simulate_entity( 18 | sess: Session, 19 | entity: "ReplaceableEntity", 20 | dependencies: Optional[List["ReplaceableEntity"]] = None, 21 | ): 22 | """Creates *entiity* in a transaction so postgres rendered definition 23 | can be retrieved 24 | """ 25 | 26 | # When simulating materialized view, don't populate them with data 27 | from alembic_utils.pg_materialized_view import PGMaterializedView 28 | 29 | if isinstance(entity, PGMaterializedView) and entity.with_data: 30 | entity = copy.deepcopy(entity) 31 | entity.with_data = False 32 | 33 | deps: List["ReplaceableEntity"] = dependencies or [] 34 | 35 | outer_transaction = sess.begin_nested() 36 | try: 37 | dependency_managers = [simulate_entity(sess, x) for x in deps] 38 | 39 | with ExitStack() as stack: 40 | # Setup all the possible deps 41 | for mgr in dependency_managers: 42 | stack.enter_context(mgr) 43 | 44 | did_drop = False 45 | inner_transaction = sess.begin_nested() 46 | try: 47 | sess.execute(entity.to_sql_statement_drop(cascade=True)) 48 | did_drop = True 49 | sess.execute(entity.to_sql_statement_create()) 50 | yield sess 51 | except: 52 | if did_drop: 53 | # The drop was successful, so either create was not, or the 54 | # error came from user code after the yield. 55 | # Anyway, we can exit now. 
56 | raise 57 | 58 | # Try again without the drop in case the drop raised 59 | # a does not exist error 60 | inner_transaction.rollback() 61 | inner_transaction = sess.begin_nested() 62 | sess.execute(entity.to_sql_statement_create()) 63 | yield sess 64 | finally: 65 | inner_transaction.rollback() 66 | finally: 67 | outer_transaction.rollback() 68 | -------------------------------------------------------------------------------- /src/alembic_utils/statement.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | 4 | def normalize_whitespace(text, base_whitespace: str = " ") -> str: 5 | """Convert all whitespace to *base_whitespace*""" 6 | return base_whitespace.join(text.split()).strip() 7 | 8 | 9 | def strip_terminating_semicolon(sql: str) -> str: 10 | """Removes terminating semicolon on a SQL statement if it exists""" 11 | return sql.strip().rstrip(";").strip() 12 | 13 | 14 | def strip_double_quotes(sql: str) -> str: 15 | """Removes starting and ending double quotes""" 16 | sql = sql.strip().rstrip('"') 17 | return sql.strip().lstrip('"').strip() 18 | 19 | 20 | def escape_colon_for_sql(sql: str) -> str: 21 | """Escapes colons for use in sqlalchemy.text""" 22 | holder = str(uuid4()) 23 | sql = sql.replace("::", holder) 24 | sql = sql.replace(":", r"\:") 25 | sql = sql.replace(holder, "::") 26 | return sql 27 | 28 | 29 | def escape_colon_for_plpgsql(sql: str) -> str: 30 | """Escapes colons for plpgsql for use in sqlalchemy.text""" 31 | holder1 = str(uuid4()) 32 | holder2 = str(uuid4()) 33 | holder3 = str(uuid4()) 34 | sql = sql.replace("::", holder1) 35 | sql = sql.replace(":=", holder2) 36 | sql = sql.replace(r"\:", holder3) 37 | 38 | sql = sql.replace(":", r"\:") 39 | 40 | sql = sql.replace(holder3, r"\:") 41 | sql = sql.replace(holder2, ":=") 42 | sql = sql.replace(holder1, "::") 43 | return sql 44 | 45 | 46 | def coerce_to_quoted(text: str) -> str: 47 | """Coerces schema and entity names to double 
quoted one 48 | 49 | Examples: 50 | coerce_to_quoted('"public"') => '"public"' 51 | coerce_to_quoted('public') => '"public"' 52 | coerce_to_quoted('public.table') => '"public"."table"' 53 | coerce_to_quoted('"public".table') => '"public"."table"' 54 | coerce_to_quoted('public."table"') => '"public"."table"' 55 | """ 56 | if "." in text: 57 | schema, _, name = text.partition(".") 58 | schema = f'"{strip_double_quotes(schema)}"' 59 | name = f'"{strip_double_quotes(name)}"' 60 | return f"{schema}.{name}" 61 | 62 | text = strip_double_quotes(text) 63 | return f'"{text}"' 64 | 65 | 66 | def coerce_to_unquoted(text: str) -> str: 67 | """Coerces schema and entity names to unquoted 68 | 69 | Examples: 70 | coerce_to_unquoted('"public"') => 'public' 71 | coerce_to_unquoted('public') => 'public' 72 | coerce_to_unquoted('public.table') => 'public.table' 73 | coerce_to_unquoted('"public".table') => 'public.table' 74 | """ 75 | return "".join(text.split('"')) 76 | -------------------------------------------------------------------------------- /src/alembic_utils/testbase.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unsupported-assignment-operation 2 | import contextlib 3 | import os 4 | from io import StringIO 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, NoReturn 7 | 8 | from alembic import command as alem_command 9 | from alembic.config import Config 10 | from sqlalchemy.engine import Engine 11 | 12 | REPO_ROOT = Path(os.path.abspath(os.path.dirname(__file__))).parent.parent.resolve() 13 | TEST_RESOURCE_ROOT = REPO_ROOT / "src" / "test" / "resources" 14 | TEST_VERSIONS_ROOT = REPO_ROOT / "src" / "test" / "alembic_config" / "versions" 15 | 16 | 17 | ALEMBIC_COMMAND_MAP: Dict[str, Callable[..., NoReturn]] = { 18 | "upgrade": alem_command.upgrade, 19 | "downgrade": alem_command.downgrade, 20 | "revision": alem_command.revision, 21 | "current": alem_command.current, 22 | "check": 
alem_command.check, 23 | } 24 | 25 | 26 | def build_alembic_config(engine: Engine) -> Config: 27 | """Populate alembic configuration from metadata and config file.""" 28 | path_to_alembic_ini = REPO_ROOT / "alembic.ini" 29 | 30 | alembic_cfg = Config(path_to_alembic_ini) 31 | 32 | # Make double sure alembic references the test database 33 | alembic_cfg.set_main_option("sqlalchemy.url", engine.url.render_as_string(hide_password=False)) 34 | 35 | alembic_cfg.set_main_option("script_location", str((Path("src") / "test" / "alembic_config"))) 36 | return alembic_cfg 37 | 38 | 39 | def run_alembic_command(engine: Engine, command: str, command_kwargs: Dict[str, Any]) -> str: 40 | command_func = ALEMBIC_COMMAND_MAP[command] 41 | 42 | stdout = StringIO() 43 | 44 | alembic_cfg = build_alembic_config(engine) 45 | with engine.begin() as connection: 46 | alembic_cfg.attributes["connection"] = connection 47 | with contextlib.redirect_stdout(stdout): 48 | command_func(alembic_cfg, **command_kwargs) 49 | return stdout.getvalue() 50 | -------------------------------------------------------------------------------- /src/test/alembic_config/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /src/test/alembic_config/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from alembic import context 4 | from sqlalchemy import MetaData, engine_from_config, pool 5 | 6 | from alembic_utils.replaceable_entity import ReplaceableEntity 7 | 8 | # this is the Alembic Config object, which provides 9 | # access to the values within the .ini file in use. 10 | config = context.config 11 | 12 | # Interpret the config file for Python logging. 13 | # This line sets up loggers basically. 
14 | fileConfig(config.config_file_name) 15 | 16 | # add your model's MetaData object here 17 | # for 'autogenerate' support 18 | # from myapp import mymodel 19 | # target_metadata = mymodel.Base.metadata 20 | target_metadata = MetaData() 21 | 22 | # other values from the config, defined by the needs of env.py, 23 | # can be acquired: 24 | # my_important_option = config.get_main_option("my_important_option") 25 | # ... etc. 26 | 27 | 28 | def include_object(object, name, type_, reflected, compare_to) -> bool: 29 | # Do not generate migrations for non-alembic_utils entities 30 | if isinstance(object, ReplaceableEntity): 31 | # In order to test the application if this filter within 32 | # the autogeneration logic, apply a simple filter that 33 | # unit tests can relate to. 34 | # 35 | # In a 'real' implementation, this could be for example 36 | # ignoring entities from particular schemas. 37 | return not "exclude_obj_" in name 38 | else: 39 | return False 40 | 41 | 42 | def include_name(name, type_, parent_names) -> bool: 43 | # In order to test the application if this filter within 44 | # the autogeneration logic, apply a simple filter that 45 | # unit tests can relate to 46 | return not "exclude_name_" in name if name else True 47 | 48 | 49 | def run_migrations_online(): 50 | """Run migrations in 'online' mode. 51 | 52 | In this scenario we need to create an Engine 53 | and associate a connection with the context. 
54 | 55 | """ 56 | connectable = engine_from_config( 57 | config.get_section(config.config_ini_section), 58 | prefix="sqlalchemy.", 59 | poolclass=pool.NullPool, 60 | ) 61 | 62 | with connectable.connect() as connection: 63 | context.configure( 64 | connection=connection, 65 | target_metadata=target_metadata, 66 | include_schemas=True, 67 | include_object=include_object, 68 | include_name=include_name, 69 | ) 70 | 71 | with context.begin_transaction(): 72 | context.run_migrations() 73 | 74 | 75 | run_migrations_online() 76 | -------------------------------------------------------------------------------- /src/test/alembic_config/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade(): 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade(): 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /src/test/conftest.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=redefined-outer-name,no-member 2 | 3 | import json 4 | import os 5 | import shutil 6 | import subprocess 7 | import time 8 | from typing import Generator 9 | 10 | import pytest 11 | from parse import parse 12 | from sqlalchemy import create_engine, text 13 | from sqlalchemy.engine import Engine 14 | from sqlalchemy.orm import Session, sessionmaker 15 | 16 | from alembic_utils.replaceable_entity import registry 17 | from alembic_utils.testbase import TEST_VERSIONS_ROOT 18 | 19 | PYTEST_DB = "postgresql://alem_user:password@localhost:5610/alem_db" 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def maybe_start_pg() -> Generator[None, None, None]: 24 | """Creates a postgres 12 docker container that can be connected 25 | to using the PYTEST_DB connection string""" 26 | 27 | container_name = "alembic_utils_pg" 28 | image = "postgres:13" 29 | 30 | connection_template = "postgresql://{user}:{pw}@{host}:{port:d}/{db}" 31 | conn_args = parse(connection_template, PYTEST_DB) 32 | 33 | # Don't attempt to instantiate a container if 34 | # we're on CI 35 | if "GITHUB_SHA" in os.environ: 36 | yield 37 | return 38 | 39 | try: 40 | is_running = ( 41 | subprocess.check_output( 42 | ["docker", "inspect", "-f", "{{.State.Running}}", container_name] 43 | ) 44 | .decode() 45 | .strip() 46 | == "true" 47 | ) 48 | except subprocess.CalledProcessError: 49 | # Can't inspect container if it isn't running 50 | is_running = False 51 | 52 | if is_running: 53 | yield 54 | 
return 55 | 56 | subprocess.call( 57 | [ 58 | "docker", 59 | "run", 60 | "--rm", 61 | "--name", 62 | container_name, 63 | "-p", 64 | f"{conn_args['port']}:5432", 65 | "-d", 66 | "-e", 67 | f"POSTGRES_DB={conn_args['db']}", 68 | "-e", 69 | f"POSTGRES_PASSWORD={conn_args['pw']}", 70 | "-e", 71 | f"POSTGRES_USER={conn_args['user']}", 72 | "--health-cmd", 73 | "pg_isready", 74 | "--health-interval", 75 | "3s", 76 | "--health-timeout", 77 | "3s", 78 | "--health-retries", 79 | "15", 80 | image, 81 | ] 82 | ) 83 | # Wait for postgres to become healthy 84 | for _ in range(10): 85 | out = subprocess.check_output(["docker", "inspect", container_name]) 86 | inspect_info = json.loads(out)[0] 87 | health_status = inspect_info["State"]["Health"]["Status"] 88 | if health_status == "healthy": 89 | break 90 | else: 91 | time.sleep(1) 92 | else: 93 | raise Exception("Could not reach postgres comtainer. Check docker installation") 94 | yield 95 | # subprocess.call(["docker", "stop", container_name]) 96 | return 97 | 98 | 99 | @pytest.fixture(scope="session") 100 | def raw_engine(maybe_start_pg: None) -> Generator[Engine, None, None]: 101 | """sqlalchemy engine fixture""" 102 | eng = create_engine(PYTEST_DB) 103 | yield eng 104 | eng.dispose() 105 | 106 | 107 | @pytest.fixture(scope="function") 108 | def engine(raw_engine: Engine) -> Generator[Engine, None, None]: 109 | """Engine that has been reset between tests""" 110 | 111 | def run_cleaners(): 112 | registry.clear() 113 | with raw_engine.begin() as connection: 114 | connection.execute(text("drop schema public cascade; create schema public;")) 115 | connection.execute(text('drop schema if exists "DEV" cascade; create schema "DEV";')) 116 | connection.execute(text('drop role if exists "anon_user"')) 117 | # Remove any migrations that were left behind 118 | TEST_VERSIONS_ROOT.mkdir(exist_ok=True, parents=True) 119 | shutil.rmtree(TEST_VERSIONS_ROOT) 120 | TEST_VERSIONS_ROOT.mkdir(exist_ok=True, parents=True) 121 | 122 | 
run_cleaners() 123 | 124 | yield raw_engine 125 | 126 | run_cleaners() 127 | 128 | 129 | @pytest.fixture(scope="function") 130 | def sess(engine) -> Generator[Session, None, None]: 131 | maker = sessionmaker(engine) 132 | sess = maker() 133 | yield sess 134 | sess.rollback() 135 | sess.expire_all() 136 | sess.close() 137 | -------------------------------------------------------------------------------- /src/test/resources/to_upper.sql: -------------------------------------------------------------------------------- 1 | CREATE OR REPLACE FUNCTION public.to_upper(some_text text) 2 | RETURNS TEXT AS 3 | $$ 4 | SELECT upper(some_text) 5 | $$ language SQL; 6 | -------------------------------------------------------------------------------- /src/test/test_alembic_check.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from alembic.util import AutogenerateDiffsDetected 3 | 4 | from alembic_utils.pg_view import PGView 5 | from alembic_utils.replaceable_entity import register_entities 6 | from alembic_utils.testbase import run_alembic_command 7 | 8 | TEST_VIEW_before = PGView( 9 | schema="public", 10 | signature="testExample", 11 | definition="select feature_name from information_schema.sql_features", 12 | ) 13 | TEST_VIEW_after = PGView( 14 | schema="public", 15 | signature="testExample", 16 | definition="select feature_name, is_supported from information_schema.sql_features", 17 | ) 18 | 19 | 20 | def test_check_diff_create(engine) -> None: 21 | register_entities([TEST_VIEW_before]) 22 | 23 | with pytest.raises(AutogenerateDiffsDetected) as e_info: 24 | run_alembic_command(engine, "check", {}) 25 | 26 | exp = ( 27 | "New upgrade operations detected: " 28 | "[('create_entity', 'PGView: public.testExample', " 29 | '\'CREATE VIEW "public"."testExample" AS select feature_name from information_schema.sql_features;\')]' 30 | ) 31 | assert e_info.value.args[0] == exp 32 | 33 | 34 | def test_check_diff_upgrade(engine) -> 
None: 35 | with engine.begin() as connection: 36 | connection.execute(TEST_VIEW_before.to_sql_statement_create()) 37 | 38 | register_entities([TEST_VIEW_after]) 39 | 40 | with pytest.raises(AutogenerateDiffsDetected) as e_info: 41 | run_alembic_command(engine, "check", {}) 42 | 43 | assert e_info.value.args[0].startswith( 44 | "New upgrade operations detected: [('replace_or_revert_entity', 'PGView: public.testExample'" 45 | ) 46 | 47 | 48 | def test_check_diff_drop(engine) -> None: 49 | with engine.begin() as connection: 50 | connection.execute(TEST_VIEW_before.to_sql_statement_create()) 51 | 52 | register_entities([]) 53 | 54 | with pytest.raises(AutogenerateDiffsDetected) as e_info: 55 | run_alembic_command(engine, "check", {}) 56 | 57 | exp = "New upgrade operations detected: [('drop_entity', 'PGView: public.testExample')]" 58 | 59 | assert e_info.value.args[0] == exp 60 | -------------------------------------------------------------------------------- /src/test/test_collect_instances.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import flupy 4 | from flupy import fluent 5 | 6 | import alembic_utils 7 | from alembic_utils.experimental._collect_instances import ( 8 | T, 9 | collect_instances, 10 | collect_subclasses, 11 | walk_modules, 12 | ) 13 | from alembic_utils.pg_view import PGView 14 | from alembic_utils.replaceable_entity import ReplaceableEntity 15 | 16 | 17 | def test_walk_modules() -> None: 18 | 19 | all_modules = [x for x in walk_modules(flupy)] 20 | assert fluent in all_modules 21 | 22 | 23 | def test_collect_instances() -> None: 24 | 25 | instances = collect_instances(alembic_utils, TypeVar) 26 | assert T in instances 27 | 28 | 29 | def test_collect_subclasses() -> None: 30 | class ImportedSubclass(ReplaceableEntity): 31 | ... 
32 | 33 | classes = collect_subclasses(alembic_utils, ReplaceableEntity) 34 | assert PGView in classes 35 | assert ImportedSubclass in classes 36 | -------------------------------------------------------------------------------- /src/test/test_create_function.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import text 2 | 3 | from alembic_utils.pg_function import PGFunction 4 | 5 | to_upper = PGFunction( 6 | schema="public", 7 | signature="to_upper(some_text text)", 8 | definition=""" 9 | returns text 10 | as 11 | $$ select upper(some_text) $$ language SQL; 12 | """, 13 | ) 14 | 15 | 16 | def test_create_and_drop(engine) -> None: 17 | """Test that the alembic current command does not error""" 18 | # Runs with no error 19 | up_sql = to_upper.to_sql_statement_create() 20 | down_sql = to_upper.to_sql_statement_drop() 21 | 22 | # Testing that the following two lines don't raise 23 | with engine.begin() as connection: 24 | connection.execute(up_sql) 25 | result = connection.execute(text("select public.to_upper('hello');")).fetchone() 26 | assert result[0] == "HELLO" 27 | connection.execute(down_sql) 28 | assert True 29 | -------------------------------------------------------------------------------- /src/test/test_current.py: -------------------------------------------------------------------------------- 1 | from alembic_utils.testbase import run_alembic_command 2 | 3 | 4 | def test_current(engine) -> None: 5 | """Test that the alembic current command does not erorr""" 6 | # Runs with no error 7 | output = run_alembic_command(engine, "current", {}) 8 | assert output == "" 9 | -------------------------------------------------------------------------------- /src/test/test_depends.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from alembic_utils.depends import solve_resolution_order 4 | from alembic_utils.pg_view import PGView 5 | from 
alembic_utils.replaceable_entity import register_entities 6 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 7 | 8 | # NAME_DEPENDENCIES 9 | 10 | A = PGView( 11 | schema="public", 12 | signature="A_view", 13 | definition="select 1 as one", 14 | ) 15 | 16 | B_A = PGView( 17 | schema="public", 18 | signature="B_view", 19 | definition='select * from public."A_view"', 20 | ) 21 | 22 | C_A = PGView( 23 | schema="public", 24 | signature="C_view", 25 | definition='select * from public."A_view"', 26 | ) 27 | 28 | D_B = PGView( 29 | schema="public", 30 | signature="D_view", 31 | definition='select * from public."B_view"', 32 | ) 33 | 34 | E_AD = PGView( 35 | schema="public", 36 | signature="E_view", 37 | definition='select av.one, dv.one as two from public."A_view" as av join public."D_view" as dv on true', 38 | ) 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "order", 43 | [ 44 | [A, B_A, C_A, D_B, E_AD], 45 | [C_A, A, B_A, E_AD, D_B], 46 | [B_A, C_A, A, D_B, E_AD], 47 | [B_A, E_AD, D_B, C_A, A], 48 | ], 49 | ) 50 | def test_solve_resolution_order(sess, order) -> None: 51 | solution = solve_resolution_order(sess, order) 52 | 53 | assert solution.index(A) < solution.index(B_A) 54 | assert solution.index(A) < solution.index(C_A) 55 | assert solution.index(B_A) < solution.index(D_B) 56 | assert solution.index(A) < solution.index(E_AD) 57 | assert solution.index(D_B) < solution.index(E_AD) 58 | 59 | 60 | def test_create_revision(engine) -> None: 61 | register_entities([B_A, E_AD, D_B, C_A, A], entity_types=[PGView]) 62 | 63 | output = run_alembic_command( 64 | engine=engine, 65 | command="revision", 66 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 67 | ) 68 | 69 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 70 | 71 | with migration_create_path.open() as migration_file: 72 | migration_contents = migration_file.read() 73 | 74 | assert migration_contents.count("op.create_entity") == 5 75 | assert 
migration_contents.count("op.drop_entity") == 5 76 | assert "op.replace_entity" not in migration_contents 77 | assert "from alembic_utils.pg_view import PGView" in migration_contents 78 | 79 | # Execute upgrade 80 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 81 | # Execute Downgrade 82 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 83 | -------------------------------------------------------------------------------- /src/test/test_duplicate_registration.py: -------------------------------------------------------------------------------- 1 | from alembic_utils.pg_function import PGFunction 2 | from alembic_utils.replaceable_entity import register_entities, registry 3 | from alembic_utils.testbase import run_alembic_command 4 | 5 | 6 | def to_upper(): 7 | return PGFunction( 8 | schema="public", 9 | signature="to_upper(some_text text)", 10 | definition=""" 11 | returns text 12 | as 13 | $$ select upper(some_text) || 'abc' $$ language SQL; 14 | """, 15 | ) 16 | 17 | 18 | def test_migration_create_function(engine) -> None: 19 | to_upper1 = to_upper() 20 | to_upper2 = to_upper() 21 | register_entities([to_upper1, to_upper2], entity_types=[PGFunction]) 22 | 23 | entities = registry.entities() 24 | assert len(entities) == 1 25 | assert entities[0] == to_upper2 26 | 27 | run_alembic_command( 28 | engine=engine, 29 | command="revision", 30 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "raise"}, 31 | ) 32 | -------------------------------------------------------------------------------- /src/test/test_generate_revision.py: -------------------------------------------------------------------------------- 1 | from alembic_utils.pg_function import PGFunction 2 | from alembic_utils.replaceable_entity import register_entities 3 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 4 | 5 | TO_UPPER = PGFunction( 6 | schema="public", 7 | 
signature="to_upper(some_text text)", 8 | definition=""" 9 | returns text 10 | as 11 | $$ select upper(some_text) || 'abc' $$ language SQL; 12 | """, 13 | ) 14 | 15 | 16 | def test_migration_create_function(engine) -> None: 17 | register_entities([TO_UPPER], entity_types=[PGFunction]) 18 | output = run_alembic_command( 19 | engine=engine, 20 | command="revision", 21 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 22 | ) 23 | 24 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 25 | 26 | with migration_create_path.open() as migration_file: 27 | migration_contents = migration_file.read() 28 | 29 | assert migration_contents.count("op.create_entity") == 1 30 | assert migration_contents.count("op.drop_entity") == 1 31 | assert migration_contents.count("from alembic_utils") == 1 32 | -------------------------------------------------------------------------------- /src/test/test_include_filters.py: -------------------------------------------------------------------------------- 1 | from alembic_utils.pg_function import PGFunction 2 | from alembic_utils.pg_view import PGView 3 | from alembic_utils.replaceable_entity import register_entities 4 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 5 | 6 | # The objects marked as "excluded" have names corresponding 7 | # to the filters in src/test/alembic_config/env.py 8 | 9 | IncludedView = PGView( 10 | schema="public", 11 | signature="A_view", 12 | definition="select 1 as one", 13 | ) 14 | 15 | ObjExcludedView = PGView( 16 | schema="public", 17 | signature="exclude_obj_view", 18 | definition="select 1 as one", 19 | ) 20 | 21 | ReflectedIncludedView = PGView( 22 | schema="public", 23 | signature="reflected_view", 24 | definition="select 1 as one", 25 | ) 26 | 27 | ReflectedExcludedView = PGView( 28 | schema="public", 29 | signature="exclude_name_reflected_view", 30 | definition="select 1 as one", 31 | ) 32 | 33 | 34 | FuncDef = """ 35 | returns text 36 | as 37 | 
$$ begin return upper(some_text) || 'abc'; end; $$ language PLPGSQL; 38 | """ 39 | 40 | IncludedFunc = PGFunction( 41 | schema="public", signature="toUpper(some_text text default 'my text!')", definition=FuncDef 42 | ) 43 | 44 | ObjExcludedFunc = PGFunction( 45 | schema="public", 46 | signature="exclude_obj_toUpper(some_text text default 'my text!')", 47 | definition=FuncDef, 48 | ) 49 | 50 | ReflectedIncludedFunc = PGFunction( 51 | schema="public", 52 | signature="reflected_toUpper(some_text text default 'my text!')", 53 | definition=FuncDef, 54 | ) 55 | 56 | ReflectedExcludedFunc = PGFunction( 57 | schema="public", 58 | signature="exclude_obj_reflected_toUpper(some_text text default 'my text!')", 59 | definition=FuncDef, 60 | ) 61 | 62 | reflected_entities = [ 63 | ReflectedIncludedView, 64 | ReflectedExcludedView, 65 | ReflectedIncludedFunc, 66 | ReflectedExcludedFunc, 67 | ] 68 | 69 | registered_entities = [ 70 | IncludedView, 71 | ObjExcludedView, 72 | IncludedFunc, 73 | ObjExcludedFunc, 74 | ] 75 | 76 | 77 | def test_create_revision_with_filters(engine) -> None: 78 | with engine.begin() as connection: 79 | for entity in reflected_entities: 80 | connection.execute(entity.to_sql_statement_create()) 81 | register_entities(registered_entities) 82 | 83 | run_alembic_command( 84 | engine=engine, 85 | command="revision", 86 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "filtered_upgrade"}, 87 | ) 88 | 89 | migration_create_path = TEST_VERSIONS_ROOT / "1_filtered_upgrade.py" 90 | 91 | with migration_create_path.open() as migration_file: 92 | migration_contents = migration_file.read() 93 | 94 | assert "op.create_entity(public_a_view)" in migration_contents 95 | assert "op.create_entity(public_toupper)" in migration_contents 96 | assert "op.drop_entity(public_reflected_view)" in migration_contents 97 | assert "op.drop_entity(public_reflected_toupper)" in migration_contents 98 | 99 | assert not "op.create_entity(public_exclude_obj_view)" in 
migration_contents 100 | assert not "op.create_entity(public_exclude_obj_toupper)" in migration_contents 101 | assert not "op.drop_entity(public_exclude_name_reflected_view)" in migration_contents 102 | assert not "op.drop_entity(public_exclude_obj_reflected_toupper)" in migration_contents 103 | 104 | # Execute upgrade 105 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 106 | # Execute Downgrade 107 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 108 | -------------------------------------------------------------------------------- /src/test/test_initializers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from alembic_utils.exceptions import SQLParseFailure 4 | from alembic_utils.pg_function import PGFunction 5 | from alembic_utils.testbase import TEST_RESOURCE_ROOT 6 | 7 | 8 | def test_pg_function_from_file() -> None: 9 | """Test that the alembic current command does not erorr""" 10 | # Runs with no error 11 | SQL_PATH = TEST_RESOURCE_ROOT / "to_upper.sql" 12 | func = PGFunction.from_path(SQL_PATH) 13 | assert func.schema == "public" 14 | 15 | 16 | def test_pg_function_from_sql_file_valid() -> None: 17 | SQL = """ 18 | CREATE OR REPLACE FUNCTION public.to_upper(some_text text) 19 | RETURNS TEXT AS 20 | $$ 21 | SELECT upper(some_text) 22 | $$ language SQL; 23 | """ 24 | 25 | func = PGFunction.from_sql(SQL) 26 | assert func.schema == "public" 27 | 28 | 29 | def test_pg_function_from_sql_file_invalid() -> None: 30 | SQL = """ 31 | NO VALID SQL TO BE FOUND HERE 32 | """ 33 | with pytest.raises(SQLParseFailure): 34 | PGFunction.from_sql(SQL) 35 | -------------------------------------------------------------------------------- /src/test/test_op_drop_cascade.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy.exc import InternalError 3 | 4 | from 
alembic_utils.pg_view import PGView 5 | from alembic_utils.replaceable_entity import register_entities 6 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 7 | 8 | A = PGView( 9 | schema="public", 10 | signature="A_view", 11 | definition="select 1::integer as one", 12 | ) 13 | 14 | B_A = PGView( 15 | schema="public", 16 | signature="B_view", 17 | definition='select * from public."A_view"', 18 | ) 19 | 20 | 21 | def test_drop_fails_without_cascade(engine) -> None: 22 | 23 | with engine.begin() as connection: 24 | connection.execute(A.to_sql_statement_create()) 25 | connection.execute(B_A.to_sql_statement_create()) 26 | 27 | register_entities([B_A], schemas=["DEV"], entity_types=[PGView]) 28 | 29 | output = run_alembic_command( 30 | engine=engine, 31 | command="revision", 32 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 33 | ) 34 | 35 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 36 | 37 | with migration_create_path.open() as migration_file: 38 | migration_contents = migration_file.read() 39 | 40 | assert "op.drop_entity" in migration_contents 41 | assert "op.create_entity" in migration_contents 42 | assert "from alembic_utils" in migration_contents 43 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 44 | 45 | with pytest.raises(InternalError): 46 | # sqlalchemy.exc.InternalError: (psycopg2.errors.DependentObjectsStillExist) cannot drop view "A_view" because other objects depend on it 47 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 48 | 49 | 50 | def test_drop_fails_with_cascade(engine, sess) -> None: 51 | 52 | with engine.begin() as connection: 53 | connection.execute(A.to_sql_statement_create()) 54 | connection.execute(B_A.to_sql_statement_create()) 55 | 56 | register_entities([B_A], schemas=["DEV"], entity_types=[PGView]) 57 | 58 | output = run_alembic_command( 59 | engine=engine, 60 | 
command="revision", 61 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 62 | ) 63 | 64 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 65 | 66 | with migration_create_path.open() as migration_file: 67 | migration_contents = migration_file.read() 68 | 69 | assert "op.drop_entity" in migration_contents 70 | assert "op.drop_entity(public_a_view)" in migration_contents 71 | 72 | migration_contents = migration_contents.replace( 73 | "op.drop_entity(public_a_view)", "op.drop_entity(public_a_view, cascade=True)" 74 | ) 75 | 76 | with migration_create_path.open("w") as migration_file: 77 | migration_file.write(migration_contents) 78 | 79 | assert "op.create_entity" in migration_contents 80 | assert "from alembic_utils" in migration_contents 81 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 82 | 83 | # Cascade drops *B_A* and succeeds 84 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 85 | 86 | # Make sure the drop ocurred 87 | all_views = PGView.from_database(sess, "public") 88 | assert len(all_views) == 0 89 | -------------------------------------------------------------------------------- /src/test/test_pg_constraint_trigger.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import text 3 | 4 | from alembic_utils.exceptions import SQLParseFailure 5 | from alembic_utils.pg_function import PGFunction 6 | from alembic_utils.pg_trigger import PGTrigger 7 | from alembic_utils.replaceable_entity import register_entities 8 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 9 | 10 | 11 | @pytest.fixture(scope="function") 12 | def sql_setup(engine): 13 | with engine.begin() as connection: 14 | connection.execute( 15 | text( 16 | """ 17 | create table public.account ( 18 | id serial primary key, 19 | email text not null 20 | ); 21 | """ 22 | ) 23 | ) 24 
| 25 | yield 26 | with engine.begin() as connection: 27 | connection.execute(text("drop table public.account cascade")) 28 | 29 | 30 | FUNC = PGFunction.from_sql( 31 | """create function public.downcase_email() returns trigger as $$ 32 | begin 33 | return new; 34 | end; 35 | $$ language plpgsql; 36 | """ 37 | ) 38 | 39 | TRIG = PGTrigger( 40 | schema="public", 41 | signature="lower_account_email", 42 | on_entity="public.account", 43 | is_constraint=True, 44 | definition=""" 45 | AFTER INSERT ON public.account 46 | FOR EACH ROW EXECUTE PROCEDURE public.downcase_email() 47 | """, 48 | ) 49 | 50 | 51 | def test_create_revision(sql_setup, engine) -> None: 52 | with engine.begin() as connection: 53 | connection.execute(FUNC.to_sql_statement_create()) 54 | 55 | register_entities([FUNC, TRIG], entity_types=[PGTrigger]) 56 | run_alembic_command( 57 | engine=engine, 58 | command="revision", 59 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 60 | ) 61 | 62 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 63 | 64 | with migration_create_path.open() as migration_file: 65 | migration_contents = migration_file.read() 66 | 67 | assert "op.create_entity" in migration_contents 68 | assert "op.drop_entity" in migration_contents 69 | assert "op.replace_entity" not in migration_contents 70 | assert "from alembic_utils.pg_trigger import PGTrigger" in migration_contents 71 | 72 | # Execute upgrade 73 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 74 | # Execute Downgrade 75 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 76 | 77 | 78 | def test_trig_update_revision(sql_setup, engine) -> None: 79 | with engine.begin() as connection: 80 | connection.execute(FUNC.to_sql_statement_create()) 81 | connection.execute(TRIG.to_sql_statement_create()) 82 | 83 | UPDATED_TRIG = PGTrigger( 84 | schema="public", 85 | signature="lower_account_email", 86 | 
on_entity="public.account", 87 | is_constraint=True, 88 | definition=""" 89 | AFTER INSERT OR UPDATE ON public.account 90 | FOR EACH ROW EXECUTE PROCEDURE public.downcase_email() 91 | """, 92 | ) 93 | 94 | register_entities([FUNC, UPDATED_TRIG], entity_types=[PGTrigger]) 95 | 96 | # Autogenerate a new migration 97 | # It should detect the change we made and produce a "replace_function" statement 98 | run_alembic_command( 99 | engine=engine, 100 | command="revision", 101 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 102 | ) 103 | 104 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 105 | 106 | with migration_replace_path.open() as migration_file: 107 | migration_contents = migration_file.read() 108 | 109 | assert "op.replace_entity" in migration_contents 110 | assert "op.create_entity" not in migration_contents 111 | assert "op.drop_entity" not in migration_contents 112 | assert "from alembic_utils.pg_trigger import PGTrigger" in migration_contents 113 | 114 | # Execute upgrade 115 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 116 | 117 | # Execute Downgrade 118 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 119 | 120 | 121 | def test_noop_revision(sql_setup, engine) -> None: 122 | with engine.begin() as connection: 123 | connection.execute(FUNC.to_sql_statement_create()) 124 | connection.execute(TRIG.to_sql_statement_create()) 125 | 126 | register_entities([FUNC, TRIG], entity_types=[PGTrigger]) 127 | 128 | output = run_alembic_command( 129 | engine=engine, 130 | command="revision", 131 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 132 | ) 133 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 134 | 135 | with migration_do_nothing_path.open() as migration_file: 136 | migration_contents = migration_file.read() 137 | 138 | assert "op.create_entity" not in migration_contents 
139 | assert "op.drop_entity" not in migration_contents 140 | assert "op.replace_entity" not in migration_contents 141 | assert "from alembic_utils" not in migration_contents 142 | 143 | # Execute upgrade 144 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 145 | # Execute Downgrade 146 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 147 | 148 | 149 | def test_drop(sql_setup, engine) -> None: 150 | # Manually create a SQL function 151 | with engine.begin() as connection: 152 | connection.execute(FUNC.to_sql_statement_create()) 153 | connection.execute(TRIG.to_sql_statement_create()) 154 | 155 | # Register no functions locally 156 | register_entities([], schemas=["public"], entity_types=[PGTrigger]) 157 | 158 | run_alembic_command( 159 | engine=engine, 160 | command="revision", 161 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 162 | ) 163 | 164 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 165 | 166 | with migration_create_path.open() as migration_file: 167 | migration_contents = migration_file.read() 168 | 169 | assert "op.drop_entity" in migration_contents 170 | assert "op.create_entity" in migration_contents 171 | assert "from alembic_utils" in migration_contents 172 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 173 | 174 | 175 | def test_unparsable() -> None: 176 | SQL = "create trigger lower_account_email faile fail fail" 177 | with pytest.raises(SQLParseFailure): 178 | PGTrigger.from_sql(SQL) 179 | 180 | 181 | def test_on_entity_schema_not_qualified() -> None: 182 | SQL = """create trigger lower_account_email 183 | AFTER INSERT ON account 184 | FOR EACH ROW EXECUTE PROCEDURE public.downcase_email() 185 | """ 186 | trigger = PGTrigger.from_sql(SQL) 187 | assert trigger.schema == "public" 188 | 189 | 190 | def test_fail_create_sql_statement_create(): 191 | trig = PGTrigger( 192 | 
schema="public", 193 | signature="lower_account_email", 194 | on_entity="public.sometab", 195 | definition="INVALID DEF", 196 | ) 197 | 198 | with pytest.raises(SQLParseFailure): 199 | trig.to_sql_statement_create() 200 | -------------------------------------------------------------------------------- /src/test/test_pg_extension.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from alembic_utils.pg_extension import PGExtension 4 | from alembic_utils.replaceable_entity import register_entities 5 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 6 | 7 | TEST_EXT = PGExtension(schema="public", signature="uuid-ossp") 8 | 9 | 10 | def test_create_revision(engine) -> None: 11 | register_entities([TEST_EXT], entity_types=[PGExtension]) 12 | 13 | output = run_alembic_command( 14 | engine=engine, 15 | command="revision", 16 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 17 | ) 18 | 19 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 20 | 21 | with migration_create_path.open() as migration_file: 22 | migration_contents = migration_file.read() 23 | 24 | assert "op.create_entity" in migration_contents 25 | assert "op.drop_entity" in migration_contents 26 | assert "op.replace_entity" not in migration_contents 27 | assert "from alembic_utils.pg_extension import PGExtension" in migration_contents 28 | 29 | # Execute upgrade 30 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 31 | # Execute Downgrade 32 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 33 | 34 | 35 | def test_create_or_replace_raises(): 36 | with pytest.raises(NotImplementedError): 37 | TEST_EXT.to_sql_statement_create_or_replace() 38 | 39 | 40 | def test_update_is_unreachable(engine) -> None: 41 | # Updates are not possible. 
The only parameter that may change is 42 | # schema, and that will result in a drop and create due to schema 43 | # scoping assumptions made for all other entities 44 | 45 | # Create the view outside of a revision 46 | with engine.begin() as connection: 47 | connection.execute(TEST_EXT.to_sql_statement_create()) 48 | 49 | UPDATED_TEST_EXT = PGExtension("DEV", TEST_EXT.signature) 50 | 51 | register_entities([UPDATED_TEST_EXT], schemas=["public", "DEV"], entity_types=[PGExtension]) 52 | 53 | # Autogenerate a new migration 54 | # It should detect the change we made and produce a "replace_function" statement 55 | output = run_alembic_command( 56 | engine=engine, 57 | command="revision", 58 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 59 | ) 60 | 61 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 62 | 63 | with migration_replace_path.open() as migration_file: 64 | migration_contents = migration_file.read() 65 | 66 | assert "op.replace_entity" not in migration_contents 67 | assert "from alembic_utils.pg_extension import PGExtension" in migration_contents 68 | 69 | 70 | def test_noop_revision(engine) -> None: 71 | # Create the view outside of a revision 72 | with engine.begin() as connection: 73 | connection.execute(TEST_EXT.to_sql_statement_create()) 74 | 75 | register_entities([TEST_EXT], entity_types=[PGExtension]) 76 | 77 | # Create a third migration without making changes. 
78 | # This should result in no create, drop or replace statements 79 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 80 | 81 | output = run_alembic_command( 82 | engine=engine, 83 | command="revision", 84 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 85 | ) 86 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 87 | 88 | with migration_do_nothing_path.open() as migration_file: 89 | migration_contents = migration_file.read() 90 | 91 | assert "op.create_entity" not in migration_contents 92 | assert "op.drop_entity" not in migration_contents 93 | assert "op.replace_entity" not in migration_contents 94 | assert "from alembic_utils" not in migration_contents 95 | 96 | # Execute upgrade 97 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 98 | # Execute Downgrade 99 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 100 | 101 | 102 | def test_drop_revision(engine) -> None: 103 | # Register no functions locally 104 | register_entities([], entity_types=[PGExtension]) 105 | 106 | # Manually create a SQL function 107 | with engine.begin() as connection: 108 | connection.execute(TEST_EXT.to_sql_statement_create()) 109 | 110 | output = run_alembic_command( 111 | engine=engine, 112 | command="revision", 113 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 114 | ) 115 | 116 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 117 | 118 | with migration_create_path.open() as migration_file: 119 | migration_contents = migration_file.read() 120 | 121 | assert "op.drop_entity" in migration_contents 122 | assert "op.create_entity" in migration_contents 123 | assert "from alembic_utils" in migration_contents 124 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 125 | 126 | # Execute upgrade 127 | 
run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 128 | # Execute Downgrade 129 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 130 | -------------------------------------------------------------------------------- /src/test/test_pg_function.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from sqlalchemy import text 4 | 5 | from alembic_utils.pg_function import PGFunction 6 | from alembic_utils.replaceable_entity import register_entities 7 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 8 | 9 | TO_UPPER = PGFunction( 10 | schema="public", 11 | signature="toUpper (some_text text default 'my text!')", 12 | definition=""" 13 | returns text 14 | as 15 | $$ begin return upper(some_text) || 'abc'; end; $$ language PLPGSQL; 16 | """, 17 | ) 18 | 19 | 20 | def test_trailing_whitespace_stripped(): 21 | sql_statements: List[str] = [ 22 | str(TO_UPPER.to_sql_statement_create()), 23 | str(next(iter(TO_UPPER.to_sql_statement_create_or_replace()))), 24 | str(TO_UPPER.to_sql_statement_drop()), 25 | ] 26 | 27 | for statement in sql_statements: 28 | print(statement) 29 | assert '"toUpper"' in statement 30 | assert not '"toUpper "' in statement 31 | 32 | 33 | def test_create_revision(engine) -> None: 34 | register_entities([TO_UPPER], entity_types=[PGFunction]) 35 | 36 | run_alembic_command( 37 | engine=engine, 38 | command="revision", 39 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 40 | ) 41 | 42 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 43 | 44 | with migration_create_path.open() as migration_file: 45 | migration_contents = migration_file.read() 46 | 47 | assert "op.create_entity" in migration_contents 48 | assert "op.drop_entity" in migration_contents 49 | assert "op.replace_entity" not in migration_contents 50 | assert "from 
alembic_utils.pg_function import PGFunction" in migration_contents 51 | 52 | # Execute upgrade 53 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 54 | # Execute Downgrade 55 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 56 | 57 | 58 | def test_update_revision(engine) -> None: 59 | with engine.begin() as connection: 60 | connection.execute(TO_UPPER.to_sql_statement_create()) 61 | 62 | # Update definition of TO_UPPER 63 | UPDATED_TO_UPPER = PGFunction( 64 | TO_UPPER.schema, 65 | TO_UPPER.signature, 66 | r'''returns text as 67 | $$ 68 | select upper(some_text) || 'def' -- """ \n \\ 69 | $$ language SQL immutable strict;''', 70 | ) 71 | 72 | register_entities([UPDATED_TO_UPPER], entity_types=[PGFunction]) 73 | 74 | # Autogenerate a new migration 75 | # It should detect the change we made and produce a "replace_function" statement 76 | run_alembic_command( 77 | engine=engine, 78 | command="revision", 79 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 80 | ) 81 | 82 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 83 | 84 | with migration_replace_path.open() as migration_file: 85 | migration_contents = migration_file.read() 86 | 87 | assert "op.replace_entity" in migration_contents 88 | assert "op.create_entity" not in migration_contents 89 | assert "op.drop_entity" not in migration_contents 90 | assert "from alembic_utils.pg_function import PGFunction" in migration_contents 91 | 92 | # Execute upgrade 93 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 94 | 95 | # Execute Downgrade 96 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 97 | 98 | 99 | def test_noop_revision(engine) -> None: 100 | with engine.begin() as connection: 101 | connection.execute(TO_UPPER.to_sql_statement_create()) 102 | 103 | register_entities([TO_UPPER], 
entity_types=[PGFunction]) 104 | 105 | output = run_alembic_command( 106 | engine=engine, 107 | command="revision", 108 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 109 | ) 110 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 111 | 112 | with migration_do_nothing_path.open() as migration_file: 113 | migration_contents = migration_file.read() 114 | 115 | assert "op.create_entity" not in migration_contents 116 | assert "op.drop_entity" not in migration_contents 117 | assert "op.replace_entity" not in migration_contents 118 | assert "from alembic_utils" not in migration_contents 119 | 120 | # Execute upgrade 121 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 122 | # Execute Downgrade 123 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 124 | 125 | 126 | def test_drop(engine) -> None: 127 | # Manually create a SQL function 128 | with engine.begin() as connection: 129 | connection.execute(TO_UPPER.to_sql_statement_create()) 130 | 131 | # Register no functions locally 132 | register_entities([], schemas=["public"], entity_types=[PGFunction]) 133 | 134 | run_alembic_command( 135 | engine=engine, 136 | command="revision", 137 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 138 | ) 139 | 140 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 141 | 142 | with migration_create_path.open() as migration_file: 143 | migration_contents = migration_file.read() 144 | 145 | assert "op.drop_entity" in migration_contents 146 | assert "op.create_entity" in migration_contents 147 | assert "from alembic_utils" in migration_contents 148 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 149 | 150 | # Execute upgrade 151 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 152 | # Execute Downgrade 153 | 
run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 154 | 155 | 156 | def test_has_no_parameters(engine) -> None: 157 | # Error was occuring in drop statement when function had no parameters 158 | # related to parameter parsing to drop default statements 159 | 160 | SIDE_EFFECT = PGFunction( 161 | schema="public", 162 | signature="side_effect()", 163 | definition=""" 164 | returns integer 165 | as 166 | $$ select 1; $$ language SQL; 167 | """, 168 | ) 169 | 170 | # Register no functions locally 171 | register_entities([SIDE_EFFECT], schemas=["public"], entity_types=[PGFunction]) 172 | 173 | run_alembic_command( 174 | engine=engine, 175 | command="revision", 176 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "no_arguments"}, 177 | ) 178 | 179 | migration_create_path = TEST_VERSIONS_ROOT / "1_no_arguments.py" 180 | 181 | with migration_create_path.open() as migration_file: 182 | migration_contents = migration_file.read() 183 | 184 | assert "op.drop_entity" in migration_contents 185 | 186 | # Execute upgrade 187 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 188 | # Execute Downgrade 189 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 190 | 191 | 192 | def test_ignores_extension_functions(engine) -> None: 193 | # Extensions contain functions and don't have local representations 194 | # Unless they are excluded, every autogenerate migration will produce 195 | # drop statements for those functions 196 | try: 197 | with engine.begin() as connection: 198 | connection.execute(text("create extension if not exists unaccent;")) 199 | register_entities([], schemas=["public"], entity_types=[PGFunction]) 200 | run_alembic_command( 201 | engine=engine, 202 | command="revision", 203 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "no_drops"}, 204 | ) 205 | 206 | migration_create_path = TEST_VERSIONS_ROOT / 
"1_no_drops.py" 207 | 208 | with migration_create_path.open() as migration_file: 209 | migration_contents = migration_file.read() 210 | 211 | assert "op.drop_entity" not in migration_contents 212 | finally: 213 | with engine.begin() as connection: 214 | connection.execute(text("drop extension if exists unaccent;")) 215 | 216 | 217 | def test_plpgsql_colon_esacpe(engine) -> None: 218 | # PGFunction.__init__ overrides colon escapes for plpgsql 219 | # because := should not be escaped for sqlalchemy.text 220 | # if := is escaped, an exception would be raised 221 | 222 | PLPGSQL_FUNC = PGFunction( 223 | schema="public", 224 | signature="some_func(some_text text)", 225 | definition=""" 226 | returns text 227 | as 228 | $$ 229 | declare 230 | copy_o_text text; 231 | begin 232 | copy_o_text := some_text; 233 | return copy_o_text; 234 | end; 235 | $$ language plpgsql 236 | """, 237 | ) 238 | 239 | register_entities([PLPGSQL_FUNC], entity_types=[PGFunction]) 240 | 241 | run_alembic_command( 242 | engine=engine, 243 | command="revision", 244 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 245 | ) 246 | 247 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 248 | 249 | with migration_create_path.open() as migration_file: 250 | migration_contents = migration_file.read() 251 | 252 | assert "op.create_entity" in migration_contents 253 | assert "op.drop_entity" in migration_contents 254 | assert "op.replace_entity" not in migration_contents 255 | assert "from alembic_utils.pg_function import PGFunction" in migration_contents 256 | 257 | # Execute upgrade 258 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 259 | # Execute Downgrade 260 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 261 | -------------------------------------------------------------------------------- /src/test/test_pg_function_overloading.py: 
-------------------------------------------------------------------------------- 1 | from alembic_utils.pg_function import PGFunction 2 | from alembic_utils.replaceable_entity import register_entities 3 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 4 | 5 | TO_FLOAT_FROM_INT = PGFunction( 6 | schema="public", 7 | signature="to_float(x integer)", 8 | definition=""" 9 | returns float 10 | as 11 | $$ select x::float $$ language SQL; 12 | """, 13 | ) 14 | 15 | TO_FLOAT_FROM_TEXT = PGFunction( 16 | schema="public", 17 | signature="to_float(x text)", 18 | definition=""" 19 | returns float 20 | as 21 | $$ select x::float $$ language SQL; 22 | """, 23 | ) 24 | 25 | 26 | def test_create_revision(engine) -> None: 27 | register_entities([TO_FLOAT_FROM_INT, TO_FLOAT_FROM_TEXT], entity_types=[PGFunction]) 28 | 29 | run_alembic_command( 30 | engine=engine, 31 | command="revision", 32 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 33 | ) 34 | 35 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 36 | 37 | with migration_create_path.open() as migration_file: 38 | migration_contents = migration_file.read() 39 | 40 | assert migration_contents.count("op.create_entity") == 2 41 | assert "op.drop_entity" in migration_contents 42 | assert "op.replace_entity" not in migration_contents 43 | assert "from alembic_utils.pg_function import PGFunction" in migration_contents 44 | 45 | # Execute upgrade 46 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 47 | # Execute Downgrade 48 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 49 | 50 | 51 | def test_update_revision(engine) -> None: 52 | with engine.begin() as connection: 53 | connection.execute(TO_FLOAT_FROM_INT.to_sql_statement_create()) 54 | connection.execute(TO_FLOAT_FROM_TEXT.to_sql_statement_create()) 55 | 56 | UPDATE = PGFunction( 57 | schema="public", 58 | 
signature="to_float(x integer)", 59 | definition=""" 60 | returns float 61 | as 62 | $$ select x::text::float $$ language SQL; 63 | """, 64 | ) 65 | 66 | register_entities([UPDATE, TO_FLOAT_FROM_TEXT], entity_types=[PGFunction]) 67 | 68 | # Autogenerate a new migration 69 | # It should detect the change we made and produce a "replace_function" statement 70 | run_alembic_command( 71 | engine=engine, 72 | command="revision", 73 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 74 | ) 75 | 76 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 77 | 78 | with migration_replace_path.open() as migration_file: 79 | migration_contents = migration_file.read() 80 | 81 | # One up and one down 82 | assert migration_contents.count("op.replace_entity") == 2 83 | assert "op.create_entity" not in migration_contents 84 | assert "op.drop_entity" not in migration_contents 85 | assert "from alembic_utils.pg_function import PGFunction" in migration_contents 86 | 87 | # Execute upgrade 88 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 89 | 90 | # Execute Downgrade 91 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 92 | -------------------------------------------------------------------------------- /src/test/test_pg_grant_table.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import text 3 | 4 | from alembic_utils.exceptions import BadInputException 5 | from alembic_utils.pg_grant_table import PGGrantTable, PGGrantTableChoice 6 | from alembic_utils.replaceable_entity import register_entities 7 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 8 | 9 | 10 | @pytest.fixture(scope="function") 11 | def sql_setup(engine): 12 | with engine.begin() as connection: 13 | connection.execute( 14 | text( 15 | """ 16 | create table public.account ( 17 | id serial primary 
key, 18 | email text not null 19 | ); 20 | create role anon_user 21 | """ 22 | ) 23 | ) 24 | 25 | yield 26 | with engine.begin() as connection: 27 | connection.execute(text("drop table public.account cascade")) 28 | 29 | 30 | TEST_GRANT = PGGrantTable( 31 | schema="public", 32 | table="account", 33 | role="anon_user", 34 | grant=PGGrantTableChoice.DELETE, 35 | with_grant_option=False, 36 | ) 37 | 38 | 39 | def test_repr(): 40 | go = PGGrantTableChoice("TRUNCATE") 41 | assert go.__repr__() == "'TRUNCATE'" 42 | 43 | 44 | def test_bad_input(): 45 | 46 | with pytest.raises(BadInputException): 47 | PGGrantTable( 48 | schema="public", 49 | table="account", 50 | role="anon_user", 51 | grant=PGGrantTableChoice.DELETE, 52 | columns=["id"], # columns not allowed for delete 53 | ) 54 | 55 | with pytest.raises(BadInputException): 56 | PGGrantTable( 57 | schema="public", 58 | table="account", 59 | role="anon_user", 60 | grant=PGGrantTableChoice.SELECT, 61 | # columns required for select 62 | ) 63 | 64 | 65 | def test_create_revision(sql_setup, engine) -> None: 66 | register_entities([TEST_GRANT], entity_types=[PGGrantTable]) 67 | run_alembic_command( 68 | engine=engine, 69 | command="revision", 70 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 71 | ) 72 | 73 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 74 | 75 | with migration_create_path.open() as migration_file: 76 | migration_contents = migration_file.read() 77 | 78 | assert "op.create_entity" in migration_contents 79 | assert "op.drop_entity" in migration_contents 80 | assert "op.replace_entity" not in migration_contents 81 | assert "from alembic_utils.pg_grant_table import PGGrantTable" in migration_contents 82 | 83 | # Execute upgrade 84 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 85 | # Execute Downgrade 86 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 87 | 88 | 89 | def 
test_replace_revision(sql_setup, engine) -> None: 90 | with engine.begin() as connection: 91 | connection.execute(TEST_GRANT.to_sql_statement_create()) 92 | 93 | UPDATED_GRANT = PGGrantTable( 94 | schema="public", 95 | table="account", 96 | role="anon_user", 97 | grant=PGGrantTableChoice.DELETE, 98 | with_grant_option=True, 99 | ) 100 | 101 | register_entities([UPDATED_GRANT], entity_types=[PGGrantTable]) 102 | run_alembic_command( 103 | engine=engine, 104 | command="revision", 105 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "update"}, 106 | ) 107 | 108 | migration_create_path = TEST_VERSIONS_ROOT / "2_update.py" 109 | 110 | with migration_create_path.open() as migration_file: 111 | migration_contents = migration_file.read() 112 | 113 | # Granting can not be done in place. 114 | assert "op.replace_entity" in migration_contents 115 | assert "op.create_entity" not in migration_contents 116 | assert "op.drop_entity" not in migration_contents 117 | assert "from alembic_utils.pg_grant_table import PGGrantTable" in migration_contents 118 | 119 | # Execute upgrade 120 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 121 | # Execute Downgrade 122 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 123 | 124 | 125 | def test_noop_revision(sql_setup, engine) -> None: 126 | # Create the view outside of a revision 127 | with engine.begin() as connection: 128 | connection.execute(TEST_GRANT.to_sql_statement_create()) 129 | 130 | register_entities([TEST_GRANT], entity_types=[PGGrantTable]) 131 | 132 | # Create a third migration without making changes. 
133 | # This should result in no create, drop or replace statements 134 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 135 | 136 | output = run_alembic_command( 137 | engine=engine, 138 | command="revision", 139 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 140 | ) 141 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 142 | 143 | with migration_do_nothing_path.open() as migration_file: 144 | migration_contents = migration_file.read() 145 | 146 | assert "op.create_entity" not in migration_contents 147 | assert "op.drop_entity" not in migration_contents 148 | assert "op.replace_entity" not in migration_contents 149 | assert "from alembic_utils" not in migration_contents 150 | 151 | # Execute upgrade 152 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 153 | # Execute Downgrade 154 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 155 | 156 | 157 | def test_drop_revision(sql_setup, engine) -> None: 158 | 159 | # Register no functions locally 160 | register_entities([], schemas=["public"], entity_types=[PGGrantTable]) 161 | 162 | # Manually create a SQL function 163 | with engine.begin() as connection: 164 | connection.execute(TEST_GRANT.to_sql_statement_create()) 165 | 166 | output = run_alembic_command( 167 | engine=engine, 168 | command="revision", 169 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 170 | ) 171 | 172 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 173 | 174 | with migration_create_path.open() as migration_file: 175 | migration_contents = migration_file.read() 176 | 177 | # import pdb; pdb.set_trace() 178 | 179 | assert "op.drop_entity" in migration_contents 180 | assert "op.create_entity" in migration_contents 181 | assert "from alembic_utils" in migration_contents 182 | assert migration_contents.index("op.drop_entity") < 
migration_contents.index("op.create_entity") 183 | 184 | # Execute upgrade 185 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 186 | # Execute Downgrade 187 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 188 | -------------------------------------------------------------------------------- /src/test/test_pg_grant_table_w_columns.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import text 3 | 4 | from alembic_utils.pg_grant_table import PGGrantTable, PGGrantTableChoice 5 | from alembic_utils.replaceable_entity import register_entities 6 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 7 | 8 | # TODO 9 | """ 10 | - Cant revoke permissions from superuser so they need to be filtered out 11 | - Multiple users may grant the same permsision to a role so we need a new 12 | parameter for grantor and a check that "CURENT_USER == grantor" 13 | """ 14 | 15 | 16 | @pytest.fixture(scope="function") 17 | def sql_setup(engine): 18 | with engine.begin() as connection: 19 | connection.execute( 20 | text( 21 | """ 22 | create table public.account ( 23 | id serial primary key, 24 | email text not null 25 | ); 26 | create role anon_user 27 | """ 28 | ) 29 | ) 30 | 31 | yield 32 | with engine.begin() as connection: 33 | connection.execute(text("drop table public.account cascade")) 34 | 35 | 36 | TEST_GRANT = PGGrantTable( 37 | schema="public", 38 | table="account", 39 | columns=["id", "email"], 40 | role="anon_user", 41 | grant=PGGrantTableChoice.SELECT, 42 | with_grant_option=False, 43 | ) 44 | 45 | 46 | def test_repr(): 47 | go = PGGrantTableChoice("SELECT") 48 | assert go.__repr__() == "'SELECT'" 49 | 50 | 51 | def test_create_revision(sql_setup, engine) -> None: 52 | register_entities([TEST_GRANT], entity_types=[PGGrantTable]) 53 | run_alembic_command( 54 | engine=engine, 55 | 
command="revision", 56 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 57 | ) 58 | 59 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 60 | 61 | with migration_create_path.open() as migration_file: 62 | migration_contents = migration_file.read() 63 | 64 | assert "op.create_entity" in migration_contents 65 | assert "op.drop_entity" in migration_contents 66 | assert "op.replace_entity" not in migration_contents 67 | assert "from alembic_utils.pg_grant_table import PGGrantTable" in migration_contents 68 | 69 | # Execute upgrade 70 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 71 | # Execute Downgrade 72 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 73 | 74 | 75 | def test_replace_revision(sql_setup, engine) -> None: 76 | with engine.begin() as connection: 77 | connection.execute(TEST_GRANT.to_sql_statement_create()) 78 | 79 | UPDATED_GRANT = PGGrantTable( 80 | schema="public", 81 | table="account", 82 | columns=["id"], 83 | role="anon_user", 84 | grant=PGGrantTableChoice.SELECT, 85 | with_grant_option=True, 86 | ) 87 | 88 | register_entities([UPDATED_GRANT], entity_types=[PGGrantTable]) 89 | run_alembic_command( 90 | engine=engine, 91 | command="revision", 92 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "update"}, 93 | ) 94 | 95 | migration_create_path = TEST_VERSIONS_ROOT / "2_update.py" 96 | 97 | with migration_create_path.open() as migration_file: 98 | migration_contents = migration_file.read() 99 | 100 | # Granting can not be done in place. 
101 | assert "op.replace_entity" in migration_contents 102 | assert "op.create_entity" not in migration_contents 103 | assert "op.drop_entity" not in migration_contents 104 | assert "from alembic_utils.pg_grant_table import PGGrantTable" in migration_contents 105 | 106 | # Execute upgrade 107 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 108 | # Execute Downgrade 109 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 110 | 111 | 112 | def test_noop_revision(sql_setup, engine) -> None: 113 | # Create the view outside of a revision 114 | with engine.begin() as connection: 115 | connection.execute(TEST_GRANT.to_sql_statement_create()) 116 | 117 | register_entities([TEST_GRANT], entity_types=[PGGrantTable]) 118 | 119 | # Create a third migration without making changes. 120 | # This should result in no create, drop or replace statements 121 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 122 | 123 | output = run_alembic_command( 124 | engine=engine, 125 | command="revision", 126 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 127 | ) 128 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 129 | 130 | with migration_do_nothing_path.open() as migration_file: 131 | migration_contents = migration_file.read() 132 | 133 | assert "op.create_entity" not in migration_contents 134 | assert "op.drop_entity" not in migration_contents 135 | assert "op.replace_entity" not in migration_contents 136 | assert "from alembic_utils" not in migration_contents 137 | 138 | # Execute upgrade 139 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 140 | # Execute Downgrade 141 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 142 | 143 | 144 | def test_drop_revision(sql_setup, engine) -> None: 145 | 146 | # Register no 
functions locally 147 | register_entities([], schemas=["public"], entity_types=[PGGrantTable]) 148 | 149 | # Manually create a SQL function 150 | with engine.begin() as connection: 151 | connection.execute(TEST_GRANT.to_sql_statement_create()) 152 | 153 | output = run_alembic_command( 154 | engine=engine, 155 | command="revision", 156 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 157 | ) 158 | 159 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 160 | 161 | with migration_create_path.open() as migration_file: 162 | migration_contents = migration_file.read() 163 | 164 | # import pdb; pdb.set_trace() 165 | 166 | assert "op.drop_entity" in migration_contents 167 | assert "op.create_entity" in migration_contents 168 | assert "from alembic_utils" in migration_contents 169 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 170 | 171 | # Execute upgrade 172 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 173 | # Execute Downgrade 174 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 175 | -------------------------------------------------------------------------------- /src/test/test_pg_materialized_view.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from alembic_utils.exceptions import SQLParseFailure 4 | from alembic_utils.pg_materialized_view import PGMaterializedView 5 | from alembic_utils.replaceable_entity import register_entities 6 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 7 | 8 | TEST_MAT_VIEW = PGMaterializedView( 9 | schema="DEV", 10 | signature="test_mat_view", 11 | definition="SELECT concat('https://something/', cast(x as text)) FROM generate_series(1,10) x", 12 | with_data=True, 13 | ) 14 | 15 | 16 | def test_unparsable_view() -> None: 17 | SQL = "create or replace materialized vew 
public.some_view as select 1 one;" 18 | with pytest.raises(SQLParseFailure): 19 | view = PGMaterializedView.from_sql(SQL) 20 | 21 | 22 | def test_parsable_body() -> None: 23 | SQL = "create materialized view public.some_view as select 1 one;" 24 | try: 25 | view = PGMaterializedView.from_sql(SQL) 26 | except SQLParseFailure: 27 | pytest.fail(f"Unexpected SQLParseFailure for view {SQL}") 28 | 29 | SQL = "create materialized view public.some_view as select 1 one with data;" 30 | try: 31 | view = PGMaterializedView.from_sql(SQL) 32 | assert view.with_data 33 | except SQLParseFailure: 34 | pytest.fail(f"Unexpected SQLParseFailure for view {SQL}") 35 | 36 | SQL = "create materialized view public.some_view as select 1 one with no data;" 37 | try: 38 | view = PGMaterializedView.from_sql(SQL) 39 | assert not view.with_data 40 | except SQLParseFailure: 41 | pytest.fail(f"Unexpected SQLParseFailure for view {SQL}") 42 | 43 | SQL = "create materialized view public.some_view(one) as select 1 one;" 44 | try: 45 | view = PGMaterializedView.from_sql(SQL) 46 | assert view.signature == "some_view" 47 | except SQLParseFailure: 48 | pytest.fail(f"Unexpected SQLParseFailure for view {SQL}") 49 | 50 | 51 | def test_create_revision(engine) -> None: 52 | register_entities([TEST_MAT_VIEW], entity_types=[PGMaterializedView]) 53 | 54 | output = run_alembic_command( 55 | engine=engine, 56 | command="revision", 57 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 58 | ) 59 | 60 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 61 | 62 | with migration_create_path.open() as migration_file: 63 | migration_contents = migration_file.read() 64 | 65 | assert "op.create_entity" in migration_contents 66 | assert "op.drop_entity" in migration_contents 67 | assert "op.replace_entity" not in migration_contents 68 | assert "from alembic_utils.pg_materialized_view import PGMaterializedView" in migration_contents 69 | 70 | # ensure colon was not quoted 71 | # 
https://github.com/olirice/alembic_utils/issues/95 72 | assert "https://" in migration_contents 73 | 74 | # Execute upgrade 75 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 76 | # Execute Downgrade 77 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 78 | 79 | 80 | def test_update_revision(engine) -> None: 81 | # Create the view outside of a revision 82 | with engine.begin() as connection: 83 | connection.execute(TEST_MAT_VIEW.to_sql_statement_create()) 84 | 85 | # Update definition of TO_UPPER 86 | UPDATED_TEST_MAT_VIEW = PGMaterializedView( 87 | TEST_MAT_VIEW.schema, 88 | TEST_MAT_VIEW.signature, 89 | """select *, TRUE as is_updated from pg_matviews""", 90 | with_data=TEST_MAT_VIEW.with_data, 91 | ) 92 | 93 | register_entities([UPDATED_TEST_MAT_VIEW], entity_types=[PGMaterializedView]) 94 | 95 | # Autogenerate a new migration 96 | # It should detect the change we made and produce a "replace_function" statement 97 | output = run_alembic_command( 98 | engine=engine, 99 | command="revision", 100 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 101 | ) 102 | 103 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 104 | 105 | with migration_replace_path.open() as migration_file: 106 | migration_contents = migration_file.read() 107 | 108 | assert "op.replace_entity" in migration_contents 109 | assert "op.create_entity" not in migration_contents 110 | assert "op.drop_entity" not in migration_contents 111 | assert "from alembic_utils.pg_materialized_view import PGMaterializedView" in migration_contents 112 | 113 | # Execute upgrade 114 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 115 | # Execute Downgrade 116 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 117 | 118 | 119 | def test_noop_revision(engine) -> None: 120 | # Create the view outside of 
a revision 121 | with engine.begin() as connection: 122 | connection.execute(TEST_MAT_VIEW.to_sql_statement_create()) 123 | 124 | register_entities([TEST_MAT_VIEW], entity_types=[PGMaterializedView]) 125 | 126 | # Create a third migration without making changes. 127 | # This should result in no create, drop or replace statements 128 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 129 | 130 | output = run_alembic_command( 131 | engine=engine, 132 | command="revision", 133 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 134 | ) 135 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 136 | 137 | with migration_do_nothing_path.open() as migration_file: 138 | migration_contents = migration_file.read() 139 | 140 | assert "op.create_entity" not in migration_contents 141 | assert "op.drop_entity" not in migration_contents 142 | assert "op.replace_entity" not in migration_contents 143 | assert "from alembic_utils" not in migration_contents 144 | 145 | # Execute upgrade 146 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 147 | # Execute Downgrade 148 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 149 | 150 | 151 | def test_drop_revision(engine) -> None: 152 | 153 | # Register no functions locally 154 | register_entities([], schemas=["DEV"], entity_types=[PGMaterializedView]) 155 | 156 | # Manually create a SQL function 157 | with engine.begin() as connection: 158 | connection.execute(TEST_MAT_VIEW.to_sql_statement_create()) 159 | 160 | output = run_alembic_command( 161 | engine=engine, 162 | command="revision", 163 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 164 | ) 165 | 166 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 167 | 168 | with migration_create_path.open() as migration_file: 169 | migration_contents = migration_file.read() 170 | 171 | # 
import pdb; pdb.set_trace() 172 | 173 | assert "op.drop_entity" in migration_contents 174 | assert "op.create_entity" in migration_contents 175 | assert "from alembic_utils" in migration_contents 176 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 177 | 178 | # Execute upgrade 179 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 180 | # Execute Downgrade 181 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 182 | -------------------------------------------------------------------------------- /src/test/test_pg_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | import pytest 4 | from sqlalchemy import text 5 | 6 | from alembic_utils.exceptions import SQLParseFailure 7 | from alembic_utils.pg_policy import PGPolicy 8 | from alembic_utils.replaceable_entity import register_entities 9 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 10 | 11 | TEST_POLICY = PGPolicy( 12 | schema="public", 13 | signature="some_policy", 14 | on_entity="some_tab", # schema omitted intentionally 15 | definition=""" 16 | for all 17 | to anon_user 18 | using (true) 19 | with check (true); 20 | """, 21 | ) 22 | 23 | 24 | @pytest.fixture() 25 | def schema_setup(engine) -> Generator[None, None, None]: 26 | with engine.begin() as connection: 27 | connection.execute( 28 | text( 29 | """ 30 | create table public.some_tab ( 31 | id serial primary key, 32 | name text 33 | ); 34 | 35 | create user anon_user; 36 | """ 37 | ) 38 | ) 39 | yield 40 | with engine.begin() as connection: 41 | connection.execute( 42 | text( 43 | """ 44 | drop table public.some_tab; 45 | drop user anon_user; 46 | """ 47 | ) 48 | ) 49 | 50 | 51 | def test_unparsable() -> None: 52 | sql = "create po mypol on public.sometab for all;" 53 | with pytest.raises(SQLParseFailure): 54 | 
PGPolicy.from_sql(sql) 55 | 56 | 57 | def test_parse_without_schema_on_entity() -> None: 58 | 59 | sql = "CREATE POLICY mypol on accOunt as PERMISSIVE for SELECT to account_creator using (true) wiTh CHECK (true)" 60 | 61 | policy = PGPolicy.from_sql(sql) 62 | assert policy.schema == "public" 63 | assert policy.signature == "mypol" 64 | assert policy.on_entity == "public.accOunt" 65 | assert ( 66 | policy.definition 67 | == "as PERMISSIVE for SELECT to account_creator using (true) wiTh CHECK (true)" 68 | ) 69 | 70 | 71 | def test_create_revision(engine, schema_setup) -> None: 72 | register_entities([TEST_POLICY], entity_types=[PGPolicy]) 73 | 74 | output = run_alembic_command( 75 | engine=engine, 76 | command="revision", 77 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 78 | ) 79 | 80 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 81 | 82 | with migration_create_path.open() as migration_file: 83 | migration_contents = migration_file.read() 84 | 85 | assert "op.create_entity" in migration_contents 86 | assert "op.drop_entity" in migration_contents 87 | assert "op.replace_entity" not in migration_contents 88 | assert "from alembic_utils.pg_policy import PGPolicy" in migration_contents 89 | 90 | # Execute upgrade 91 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 92 | # Execute Downgrade 93 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 94 | 95 | 96 | def test_update_revision(engine, schema_setup) -> None: 97 | # Create the view outside of a revision 98 | with engine.begin() as connection: 99 | connection.execute(TEST_POLICY.to_sql_statement_create()) 100 | 101 | # Update definition of TO_UPPER 102 | UPDATED_TEST_POLICY = PGPolicy( 103 | schema=TEST_POLICY.schema, 104 | signature=TEST_POLICY.signature, 105 | on_entity=TEST_POLICY.on_entity, 106 | definition=""" 107 | for update 108 | to anon_user 109 | using (true); 110 | """, 111 | ) 
112 | 113 | register_entities([UPDATED_TEST_POLICY], entity_types=[PGPolicy]) 114 | 115 | # Autogenerate a new migration 116 | # It should detect the change we made and produce a "replace_function" statement 117 | output = run_alembic_command( 118 | engine=engine, 119 | command="revision", 120 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 121 | ) 122 | 123 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 124 | 125 | with migration_replace_path.open() as migration_file: 126 | migration_contents = migration_file.read() 127 | 128 | assert "op.replace_entity" in migration_contents 129 | assert "op.create_entity" not in migration_contents 130 | assert "op.drop_entity" not in migration_contents 131 | assert "from alembic_utils.pg_policy import PGPolicy" in migration_contents 132 | 133 | # Execute upgrade 134 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 135 | # Execute Downgrade 136 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 137 | 138 | 139 | def test_noop_revision(engine, schema_setup) -> None: 140 | # Create the view outside of a revision 141 | with engine.begin() as connection: 142 | connection.execute(TEST_POLICY.to_sql_statement_create()) 143 | 144 | register_entities([TEST_POLICY], entity_types=[PGPolicy]) 145 | 146 | # Create a third migration without making changes. 
147 | # This should result in no create, drop or replace statements 148 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 149 | 150 | output = run_alembic_command( 151 | engine=engine, 152 | command="revision", 153 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 154 | ) 155 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 156 | 157 | with migration_do_nothing_path.open() as migration_file: 158 | migration_contents = migration_file.read() 159 | 160 | assert "op.create_entity" not in migration_contents 161 | assert "op.drop_entity" not in migration_contents 162 | assert "op.replace_entity" not in migration_contents 163 | assert "from alembic_utils" not in migration_contents 164 | assert "from alembic_utils.pg_policy import PGPolicy" not in migration_contents 165 | 166 | # Execute upgrade 167 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 168 | # Execute Downgrade 169 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 170 | 171 | 172 | def test_drop_revision(engine, schema_setup) -> None: 173 | 174 | # Register no functions locally 175 | register_entities([], schemas=["public"], entity_types=[PGPolicy]) 176 | 177 | # Manually create a SQL function 178 | with engine.begin() as connection: 179 | connection.execute(TEST_POLICY.to_sql_statement_create()) 180 | output = run_alembic_command( 181 | engine=engine, 182 | command="revision", 183 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 184 | ) 185 | 186 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 187 | 188 | with migration_create_path.open() as migration_file: 189 | migration_contents = migration_file.read() 190 | 191 | # import pdb; pdb.set_trace() 192 | 193 | assert "op.drop_entity" in migration_contents 194 | assert "op.create_entity" in migration_contents 195 | assert "from alembic_utils" in 
migration_contents 196 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 197 | 198 | # Execute upgrade 199 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 200 | # Execute Downgrade 201 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 202 | -------------------------------------------------------------------------------- /src/test/test_pg_trigger.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import text 3 | 4 | from alembic_utils.exceptions import SQLParseFailure 5 | from alembic_utils.pg_function import PGFunction 6 | from alembic_utils.pg_trigger import PGTrigger 7 | from alembic_utils.replaceable_entity import register_entities 8 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 9 | 10 | 11 | @pytest.fixture(scope="function") 12 | def sql_setup(engine): 13 | with engine.begin() as connection: 14 | connection.execute( 15 | text( 16 | """ 17 | create table public.account ( 18 | id serial primary key, 19 | email text not null 20 | ); 21 | """ 22 | ) 23 | ) 24 | 25 | yield 26 | with engine.begin() as connection: 27 | connection.execute(text("drop table public.account cascade")) 28 | 29 | 30 | FUNC = PGFunction.from_sql( 31 | """create function public.downcase_email() returns trigger as $$ 32 | begin 33 | return new; 34 | end; 35 | $$ language plpgsql; 36 | """ 37 | ) 38 | 39 | TRIG = PGTrigger( 40 | schema="public", 41 | signature="lower_account_EMAIL", 42 | on_entity="public.account", 43 | definition=""" 44 | BEFORE INSERT ON public.account 45 | FOR EACH ROW EXECUTE PROCEDURE public.downcase_email() 46 | """, 47 | ) 48 | 49 | 50 | def test_create_revision(sql_setup, engine) -> None: 51 | with engine.begin() as connection: 52 | connection.execute(FUNC.to_sql_statement_create()) 53 | 54 | register_entities([FUNC, TRIG], 
entity_types=[PGTrigger]) 55 | run_alembic_command( 56 | engine=engine, 57 | command="revision", 58 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 59 | ) 60 | 61 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 62 | 63 | with migration_create_path.open() as migration_file: 64 | migration_contents = migration_file.read() 65 | 66 | assert "op.create_entity" in migration_contents 67 | # Make sure #is_constraint flag was populated 68 | assert "is_constraint" in migration_contents 69 | assert "op.drop_entity" in migration_contents 70 | assert "op.replace_entity" not in migration_contents 71 | assert "from alembic_utils.pg_trigger import PGTrigger" in migration_contents 72 | 73 | # Execute upgrade 74 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 75 | # Execute Downgrade 76 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 77 | 78 | 79 | def test_trig_update_revision(sql_setup, engine) -> None: 80 | with engine.begin() as connection: 81 | connection.execute(FUNC.to_sql_statement_create()) 82 | connection.execute(TRIG.to_sql_statement_create()) 83 | 84 | UPDATED_TRIG = PGTrigger( 85 | schema=TRIG.schema, 86 | signature=TRIG.signature, 87 | on_entity=TRIG.on_entity, 88 | definition=""" 89 | AFTER INSERT ON public.account 90 | FOR EACH ROW EXECUTE PROCEDURE public.downcase_email() 91 | """, 92 | ) 93 | 94 | register_entities([FUNC, UPDATED_TRIG], entity_types=[PGTrigger]) 95 | 96 | # Autogenerate a new migration 97 | # It should detect the change we made and produce a "replace_function" statement 98 | run_alembic_command( 99 | engine=engine, 100 | command="revision", 101 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 102 | ) 103 | 104 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 105 | 106 | with migration_replace_path.open() as migration_file: 107 | migration_contents = migration_file.read() 108 
| 109 | assert "op.replace_entity" in migration_contents 110 | assert "op.create_entity" not in migration_contents 111 | assert "op.drop_entity" not in migration_contents 112 | assert "from alembic_utils.pg_trigger import PGTrigger" in migration_contents 113 | 114 | # Execute upgrade 115 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 116 | 117 | # Execute Downgrade 118 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 119 | 120 | 121 | def test_noop_revision(sql_setup, engine) -> None: 122 | with engine.begin() as connection: 123 | connection.execute(FUNC.to_sql_statement_create()) 124 | connection.execute(TRIG.to_sql_statement_create()) 125 | 126 | register_entities([FUNC, TRIG], entity_types=[PGTrigger]) 127 | 128 | output = run_alembic_command( 129 | engine=engine, 130 | command="revision", 131 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 132 | ) 133 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 134 | 135 | with migration_do_nothing_path.open() as migration_file: 136 | migration_contents = migration_file.read() 137 | 138 | assert "op.create_entity" not in migration_contents 139 | assert "op.drop_entity" not in migration_contents 140 | assert "op.replace_entity" not in migration_contents 141 | assert "from alembic_utils" not in migration_contents 142 | 143 | # Execute upgrade 144 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 145 | # Execute Downgrade 146 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 147 | 148 | 149 | def test_drop(sql_setup, engine) -> None: 150 | # Manually create a SQL function 151 | with engine.begin() as connection: 152 | connection.execute(FUNC.to_sql_statement_create()) 153 | connection.execute(TRIG.to_sql_statement_create()) 154 | 155 | # Register no functions locally 156 | register_entities([], 
schemas=["public"], entity_types=[PGTrigger]) 157 | 158 | run_alembic_command( 159 | engine=engine, 160 | command="revision", 161 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 162 | ) 163 | 164 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 165 | 166 | with migration_create_path.open() as migration_file: 167 | migration_contents = migration_file.read() 168 | 169 | assert "op.drop_entity" in migration_contents 170 | assert "op.create_entity" in migration_contents 171 | assert "from alembic_utils" in migration_contents 172 | assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") 173 | 174 | 175 | def test_unparsable() -> None: 176 | SQL = "create trigger lower_account_email faile fail fail" 177 | with pytest.raises(SQLParseFailure): 178 | PGTrigger.from_sql(SQL) 179 | 180 | 181 | def test_on_entity_schema_not_qualified() -> None: 182 | SQL = """create trigger lower_account_email 183 | AFTER INSERT ON account 184 | FOR EACH ROW EXECUTE PROCEDURE public.downcase_email() 185 | """ 186 | trigger = PGTrigger.from_sql(SQL) 187 | assert trigger.schema == "public" 188 | 189 | 190 | def test_fail_create_sql_statement_create(): 191 | trig = PGTrigger( 192 | schema=TRIG.schema, 193 | signature=TRIG.signature, 194 | on_entity=TRIG.on_entity, 195 | definition="INVALID DEF", 196 | ) 197 | 198 | with pytest.raises(SQLParseFailure): 199 | trig.to_sql_statement_create() 200 | -------------------------------------------------------------------------------- /src/test/test_pg_view.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy.exc import ProgrammingError 3 | 4 | from alembic_utils.exceptions import SQLParseFailure 5 | from alembic_utils.pg_view import PGView 6 | from alembic_utils.replaceable_entity import register_entities 7 | from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command 8 | 9 | TEST_VIEW = PGView( 
10 | schema="DEV", signature="testExample", definition="select *, FALSE as is_updated from pg_views" 11 | ) 12 | 13 | 14 | def test_unparsable_view() -> None: 15 | SQL = "create or replace vew public.some_view as select 1 one;" 16 | with pytest.raises(SQLParseFailure): 17 | view = PGView.from_sql(SQL) 18 | 19 | 20 | def test_parsable_body() -> None: 21 | SQL = "create or replace view public.some_view as select 1 one;" 22 | try: 23 | view = PGView.from_sql(SQL) 24 | except SQLParseFailure: 25 | pytest.fail(f"Unexpected SQLParseFailure for view {SQL}") 26 | 27 | SQL = "create view public.some_view(one) as select 1 one;" 28 | try: 29 | view = PGView.from_sql(SQL) 30 | assert view.signature == "some_view" 31 | except SQLParseFailure: 32 | pytest.fail(f"Unexpected SQLParseFailure for view {SQL}") 33 | 34 | 35 | def test_create_revision(engine) -> None: 36 | register_entities([TEST_VIEW], entity_types=[PGView]) 37 | 38 | output = run_alembic_command( 39 | engine=engine, 40 | command="revision", 41 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 42 | ) 43 | 44 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 45 | 46 | with migration_create_path.open() as migration_file: 47 | migration_contents = migration_file.read() 48 | 49 | assert "op.create_entity" in migration_contents 50 | assert "op.drop_entity" in migration_contents 51 | assert "op.replace_entity" not in migration_contents 52 | assert "from alembic_utils.pg_view import PGView" in migration_contents 53 | 54 | # Execute upgrade 55 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 56 | # Execute Downgrade 57 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 58 | 59 | 60 | def test_update_revision(engine) -> None: 61 | # Create the view outside of a revision 62 | with engine.begin() as connection: 63 | connection.execute(TEST_VIEW.to_sql_statement_create()) 64 | 65 | # Update definition of 
TO_UPPER 66 | UPDATED_TEST_VIEW = PGView( 67 | TEST_VIEW.schema, TEST_VIEW.signature, """select *, TRUE as is_updated from pg_views;""" 68 | ) 69 | 70 | register_entities([UPDATED_TEST_VIEW], entity_types=[PGView]) 71 | 72 | # Autogenerate a new migration 73 | # It should detect the change we made and produce a "replace_function" statement 74 | output = run_alembic_command( 75 | engine=engine, 76 | command="revision", 77 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 78 | ) 79 | 80 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 81 | 82 | with migration_replace_path.open() as migration_file: 83 | migration_contents = migration_file.read() 84 | 85 | assert "op.replace_entity" in migration_contents 86 | assert "op.create_entity" not in migration_contents 87 | assert "op.drop_entity" not in migration_contents 88 | assert "from alembic_utils.pg_view import PGView" in migration_contents 89 | 90 | assert "true" in migration_contents.lower() 91 | assert "false" in migration_contents.lower() 92 | assert migration_contents.lower().find("true") < migration_contents.lower().find("false") 93 | 94 | # Execute upgrade 95 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 96 | # Execute Downgrade 97 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 98 | 99 | 100 | def test_noop_revision(engine) -> None: 101 | # Create the view outside of a revision 102 | with engine.begin() as connection: 103 | connection.execute(TEST_VIEW.to_sql_statement_create()) 104 | 105 | register_entities([TEST_VIEW], entity_types=[PGView]) 106 | 107 | # Create a third migration without making changes. 
108 | # This should result in no create, drop or replace statements 109 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 110 | 111 | output = run_alembic_command( 112 | engine=engine, 113 | command="revision", 114 | command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, 115 | ) 116 | migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" 117 | 118 | with migration_do_nothing_path.open() as migration_file: 119 | migration_contents = migration_file.read() 120 | 121 | assert "op.create_entity" not in migration_contents 122 | assert "op.drop_entity" not in migration_contents 123 | assert "op.replace_entity" not in migration_contents 124 | assert "from alembic_utils" not in migration_contents 125 | 126 | # Execute upgrade 127 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 128 | # Execute Downgrade 129 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 130 | 131 | 132 | def test_drop_revision(engine) -> None: 133 | 134 | # Register no functions locally 135 | register_entities([], schemas=["DEV"], entity_types=[PGView]) 136 | 137 | # Manually create a SQL function 138 | with engine.begin() as connection: 139 | connection.execute(TEST_VIEW.to_sql_statement_create()) 140 | 141 | output = run_alembic_command( 142 | engine=engine, 143 | command="revision", 144 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, 145 | ) 146 | 147 | migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" 148 | 149 | with migration_create_path.open() as migration_file: 150 | migration_contents = migration_file.read() 151 | 152 | # import pdb; pdb.set_trace() 153 | 154 | assert "op.drop_entity" in migration_contents 155 | assert "op.create_entity" in migration_contents 156 | assert "from alembic_utils" in migration_contents 157 | assert migration_contents.index("op.drop_entity") < 
migration_contents.index("op.create_entity") 158 | 159 | # Execute upgrade 160 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 161 | # Execute Downgrade 162 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 163 | 164 | 165 | def test_update_create_or_replace_failover_to_drop_add(engine) -> None: 166 | # Create the view outside of a revision 167 | with engine.begin() as connection: 168 | connection.execute(TEST_VIEW.to_sql_statement_create()) 169 | 170 | # Update definition of TO_UPPER 171 | # deleted columns from the beginning of the view. 172 | # this will fail a create or replace statemnt 173 | # psycopg2.errors.InvalidTableDefinition) cannot drop columns from view 174 | # and should fail over to drop and then replace (in plpgsql of `create_or_replace_entity` method 175 | # on pgview 176 | 177 | UPDATED_TEST_VIEW = PGView( 178 | TEST_VIEW.schema, TEST_VIEW.signature, """select TRUE as is_updated from pg_views""" 179 | ) 180 | 181 | register_entities([UPDATED_TEST_VIEW], entity_types=[PGView]) 182 | 183 | # Autogenerate a new migration 184 | # It should detect the change we made and produce a "replace_function" statement 185 | output = run_alembic_command( 186 | engine=engine, 187 | command="revision", 188 | command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, 189 | ) 190 | 191 | migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" 192 | 193 | with migration_replace_path.open() as migration_file: 194 | migration_contents = migration_file.read() 195 | 196 | assert "op.replace_entity" in migration_contents 197 | assert "op.create_entity" not in migration_contents 198 | assert "op.drop_entity" not in migration_contents 199 | assert "from alembic_utils.pg_view import PGView" in migration_contents 200 | 201 | # Execute upgrade 202 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 203 | # Execute Downgrade 
204 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 205 | 206 | 207 | def test_attempt_revision_on_unparsable(engine) -> None: 208 | BROKEN_VIEW = PGView(schema="public", signature="broken_view", definition="NOPE;") 209 | register_entities([BROKEN_VIEW], entity_types=[PGView]) 210 | 211 | # Reraise of psycopg2.errors.SyntaxError 212 | with pytest.raises(ProgrammingError): 213 | run_alembic_command( 214 | engine=engine, 215 | command="revision", 216 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 217 | ) 218 | 219 | 220 | def test_create_revision_with_url_w_colon(engine) -> None: 221 | """Ensure no regression where views escape colons 222 | More info at: https://github.com/olirice/alembic_utils/issues/58 223 | """ 224 | url = "https://something/" 225 | query = f"SELECT concat('{url}', v::text) FROM generate_series(1,2) x(v)" 226 | some_view = PGView(schema="public", signature="exa", definition=query) 227 | register_entities([some_view], entity_types=[PGView]) 228 | 229 | output = run_alembic_command( 230 | engine=engine, 231 | command="revision", 232 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 233 | ) 234 | 235 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 236 | 237 | with migration_create_path.open() as migration_file: 238 | migration_contents = migration_file.read() 239 | 240 | assert url in migration_contents 241 | assert "op.create_entity" in migration_contents 242 | assert "op.drop_entity" in migration_contents 243 | assert "op.replace_entity" not in migration_contents 244 | assert "from alembic_utils.pg_view import PGView" in migration_contents 245 | 246 | # Execute upgrade 247 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 248 | # Execute Downgrade 249 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 250 | 251 | 252 | def 
test_view_contains_colon(engine) -> None: 253 | TEST_SEMI_VIEW = PGView( 254 | schema="public", 255 | signature="sample", 256 | definition="select ':' as myfield, '1'::int as othi", 257 | ) 258 | # NOTE: if a definition contains something that looks like a bind parameter e.g. :a 259 | # an exception is raised. This test confirms that non-bind-parameter usage of colon 260 | # is a non-issue 261 | 262 | register_entities([TEST_SEMI_VIEW], entity_types=[PGView]) 263 | 264 | output = run_alembic_command( 265 | engine=engine, 266 | command="revision", 267 | command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, 268 | ) 269 | 270 | migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" 271 | 272 | with migration_create_path.open() as migration_file: 273 | migration_contents = migration_file.read() 274 | 275 | assert "op.create_entity" in migration_contents 276 | assert "op.drop_entity" in migration_contents 277 | assert "op.replace_entity" not in migration_contents 278 | assert "from alembic_utils.pg_view import PGView" in migration_contents 279 | 280 | # Execute upgrade 281 | run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) 282 | # Execute Downgrade 283 | run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) 284 | -------------------------------------------------------------------------------- /src/test/test_recreate_dropped.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy.orm import Session 3 | 4 | from alembic_utils.depends import recreate_dropped 5 | from alembic_utils.pg_view import PGView 6 | 7 | TEST_ROOT_BIGINT = PGView( 8 | schema="public", signature="root", definition="select 1::bigint as some_val" 9 | ) 10 | 11 | TEST_ROOT_INT = PGView(schema="public", signature="root", definition="select 1::int as some_val") 12 | 13 | TEST_DEPENDENT = PGView(schema="public", signature="branch", 
definition="select * from public.root") 14 | 15 | 16 | def test_fails_without_defering(sess: Session) -> None: 17 | 18 | # Create the original view 19 | sess.execute(TEST_ROOT_BIGINT.to_sql_statement_create()) 20 | # Create the view that depends on it 21 | sess.execute(TEST_DEPENDENT.to_sql_statement_create()) 22 | 23 | # Try to update a column type of the base view from undeneath 24 | # the dependent view 25 | with pytest.raises(Exception): 26 | sess.execute(TEST_ROOT_INT.to_sql_statement_create_or_replace()) 27 | 28 | 29 | def test_succeeds_when_defering(engine) -> None: 30 | 31 | with engine.begin() as connection: 32 | # Create the original view 33 | connection.execute(TEST_ROOT_BIGINT.to_sql_statement_create()) 34 | # Create the view that depends on it 35 | connection.execute(TEST_DEPENDENT.to_sql_statement_create()) 36 | 37 | # Try to update a column type of the base view from undeneath 38 | # the dependent view 39 | with recreate_dropped(connection=engine) as sess: 40 | sess.execute(TEST_ROOT_INT.to_sql_statement_drop(cascade=True)) 41 | sess.execute(TEST_ROOT_INT.to_sql_statement_create()) 42 | 43 | 44 | def test_fails_gracefully_on_bad_user_statement(engine) -> None: 45 | with engine.begin() as connection: 46 | # Create the original view 47 | connection.execute(TEST_ROOT_BIGINT.to_sql_statement_create()) 48 | # Create the view that depends on it 49 | connection.execute(TEST_DEPENDENT.to_sql_statement_create()) 50 | 51 | # Execute a failing statement in the session 52 | with pytest.raises(Exception): 53 | with recreate_dropped(connection=engine) as sess: 54 | sess.execute(TEST_ROOT_INT.to_sql_statement_create()) 55 | 56 | 57 | def test_fails_if_user_creates_new_entity(engine) -> None: 58 | with engine.begin() as connection: 59 | # Create the original view 60 | connection.execute(TEST_ROOT_BIGINT.to_sql_statement_create()) 61 | 62 | # User creates a brand new entity 63 | with pytest.raises(Exception): 64 | with recreate_dropped(connection=connection) as sess: 
65 | connection.execute(TEST_DEPENDENT.to_sql_statement_create()) 66 | -------------------------------------------------------------------------------- /src/test/test_simulate_entity.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sqlalchemy import text 3 | from sqlalchemy.exc import DataError 4 | from sqlalchemy.orm import Session 5 | 6 | from alembic_utils.pg_view import PGView 7 | from alembic_utils.simulate import simulate_entity 8 | 9 | TEST_VIEW = PGView( 10 | schema="public", 11 | signature="tview", 12 | definition="select *, FALSE as is_updated from pg_views", 13 | ) 14 | 15 | 16 | def test_simulate_entity_shows_user_code_error(sess: Session) -> None: 17 | sess.execute(TEST_VIEW.to_sql_statement_create()) 18 | 19 | with pytest.raises(DataError): 20 | with simulate_entity(sess, TEST_VIEW): 21 | # Raises a sql error 22 | sess.execute(text("select 1/0")).fetchone() 23 | 24 | # Confirm context manager exited gracefully 25 | assert True 26 | -------------------------------------------------------------------------------- /src/test/test_statement.py: -------------------------------------------------------------------------------- 1 | from alembic_utils.statement import coerce_to_quoted, coerce_to_unquoted 2 | 3 | 4 | def test_coerce_to_quoted() -> None: 5 | assert coerce_to_quoted('"public"') == '"public"' 6 | assert coerce_to_quoted("public") == '"public"' 7 | assert coerce_to_quoted("public.table") == '"public"."table"' 8 | assert coerce_to_quoted('"public".table') == '"public"."table"' 9 | assert coerce_to_quoted('public."table"') == '"public"."table"' 10 | 11 | 12 | def test_coerce_to_unquoted() -> None: 13 | assert coerce_to_unquoted('"public"') == "public" 14 | assert coerce_to_unquoted("public") == "public" 15 | assert coerce_to_unquoted("public.table") == "public.table" 16 | assert coerce_to_unquoted('"public".table') == "public.table" 17 | 
--------------------------------------------------------------------------------