├── tests ├── types │ ├── __init__.py │ ├── test_ip_address.py │ ├── test_url.py │ ├── test_email.py │ ├── test_country.py │ ├── test_ltree.py │ ├── test_currency.py │ ├── encrypted │ │ └── test_padding.py │ ├── test_weekdays.py │ ├── test_uuid.py │ ├── test_locale.py │ ├── test_enriched_date_pendulum.py │ ├── test_color.py │ ├── test_tsvector.py │ ├── test_arrow.py │ ├── test_scalar_list.py │ └── test_enriched_datetime_pendulum.py ├── aggregate │ ├── __init__.py │ ├── test_with_ondelete_cascade.py │ ├── test_search_vectors.py │ ├── test_with_column_alias.py │ ├── test_o2m_o2m.py │ ├── test_m2m.py │ ├── test_custom_select_expressions.py │ ├── test_backrefs.py │ ├── test_simple_paths.py │ ├── test_o2m_m2m.py │ ├── test_m2m_m2m.py │ └── test_multiple_aggregates_per_class.py ├── functions │ ├── __init__.py │ ├── test_escape_like.py │ ├── test_naturally_equivalent.py │ ├── test_get_bind.py │ ├── test_is_loaded.py │ ├── test_quote.py │ ├── test_json_sql.py │ ├── test_jsonb_sql.py │ ├── test_table_name.py │ ├── test_get_hybrid_properties.py │ ├── test_get_column_key.py │ ├── test_identity.py │ ├── test_cast_if.py │ ├── test_has_changes.py │ ├── test_get_primary_keys.py │ ├── test_get_type.py │ ├── test_get_columns.py │ ├── test_non_indexed_foreign_keys.py │ ├── test_render.py │ └── test_get_tables.py ├── observes │ ├── __init__.py │ ├── test_o2o_o2o.py │ ├── test_dynamic_relationship.py │ └── test_o2o_o2o_o2o.py ├── primitives │ ├── __init__.py │ ├── test_currency.py │ └── test_country.py ├── relationships │ └── __init__.py ├── __init__.py ├── test_instrumented_list.py ├── generic_relationship │ ├── test_column_aliases.py │ ├── test_abstract_base_class.py │ ├── test_hybrid_properties.py │ ├── test_composite_keys.py │ └── __init__.py ├── test_case_insensitive_comparator.py ├── test_instant_defaults_listener.py ├── test_expressions.py ├── test_models.py ├── test_query_chain.py └── test_auto_delete_orphans.py ├── docs ├── license.rst ├── observers.rst ├── aggregates.rst ├── 
utility_classes.rst ├── models.rst ├── listeners.rst ├── view.rst ├── index.rst ├── testing.rst ├── range_data_types.rst ├── foreign_key_helpers.rst ├── database_helpers.rst ├── installation.rst ├── conf.py └── orm_helpers.rst ├── sqlalchemy_utils ├── types │ ├── encrypted │ │ └── __init__.py │ ├── scalar_coercible.py │ ├── enriched_datetime │ │ ├── __init__.py │ │ ├── pendulum_date.py │ │ ├── arrow_datetime.py │ │ ├── pendulum_datetime.py │ │ ├── enriched_date_type.py │ │ └── enriched_datetime_type.py │ ├── bit.py │ ├── email.py │ ├── ip_address.py │ ├── arrow.py │ ├── url.py │ ├── country.py │ ├── locale.py │ ├── __init__.py │ ├── currency.py │ ├── color.py │ ├── json.py │ ├── weekdays.py │ ├── scalar_list.py │ └── timezone.py ├── primitives │ ├── __init__.py │ ├── weekday.py │ ├── weekdays.py │ ├── country.py │ └── currency.py ├── exceptions.py ├── compat.py ├── utils.py ├── relationships │ └── chained_join.py ├── functions │ ├── __init__.py │ ├── render.py │ └── sort_query.py ├── expressions.py ├── operators.py ├── proxy_dict.py └── __init__.py ├── requirements ├── docs │ ├── pyproject.toml │ └── requirements.txt └── README.rst ├── .flake8 ├── MANIFEST.in ├── .github ├── dependabot.yaml └── workflows │ ├── lint.yaml │ └── test.yaml ├── .readthedocs.yaml ├── .editorconfig ├── .ci └── install_mssql.sh ├── ROADMAP.rst ├── .pre-commit-config.yaml ├── .gitignore ├── README.rst ├── LICENSE ├── RELEASE.md └── tox.ini /tests/types/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/aggregate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/functions/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/observes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/relationships/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | License 2 | ======= 3 | 4 | .. include:: ../LICENSE 5 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/encrypted/__init__.py: -------------------------------------------------------------------------------- 1 | # Module for encrypted type 2 | -------------------------------------------------------------------------------- /docs/observers.rst: -------------------------------------------------------------------------------- 1 | Observers 2 | ========= 3 | 4 | .. automodule:: sqlalchemy_utils.observer 5 | 6 | .. autofunction:: observes 7 | -------------------------------------------------------------------------------- /docs/aggregates.rst: -------------------------------------------------------------------------------- 1 | Aggregated attributes 2 | ===================== 3 | 4 | .. automodule:: sqlalchemy_utils.aggregates 5 | 6 | .. 
def assert_contains(clause, query):
    """Execute *query* and assert that *clause* appears in its string form.

    :param clause: text expected to occur in ``str(query)``.
    :param query: object exposing ``all()`` whose ``str()`` is inspected.
    :raises AssertionError: if *clause* is not a substring of ``str(query)``.
    """
    # Run the query first so that a statement which cannot be executed
    # fails before the textual comparison is made.
    query.all()
    rendered = str(query)
    assert clause in rendered, f"[{clause}] is not in [{query}]"
class ScalarCoercible:
    """Mixin for types that can coerce an assigned value.

    Subclasses provide the actual conversion by overriding ``_coerce``.
    """

    def _coerce(self, value):
        """Convert *value*; the base implementation always raises.

        :raises NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError()

    def coercion_listener(self, target, value, oldvalue, initiator):
        """Return *value* passed through ``_coerce``.

        ``target``, ``oldvalue`` and ``initiator`` are accepted but unused.
        NOTE(review): the signature looks like SQLAlchemy's attribute ``set``
        event -- confirm against where the listener is registered.
        """
        coerced = self._coerce(value)
        return coerced
class ImproperlyConfigured(Exception):
    """Raised when SQLAlchemy-Utils is improperly configured.

    This normally means a utility was used that depends on a library
    which is not installed.
    """
def get_sqlalchemy_version(version=None):
    """Extract the sqlalchemy version as a tuple of integers.

    Only the leading ``major[.minor[.patch]]`` digits are used, so a
    pre-release suffix such as ``b1`` is ignored.

    :param version: version string to parse. When omitted (``None``) the
        version of the installed ``sqlalchemy`` distribution is looked up.
    :returns: tuple of one to three integers, or ``()`` when *version* does
        not start with a digit.
    :raises importlib.metadata.PackageNotFoundError: if *version* is omitted
        and ``sqlalchemy`` is not installed.
    """
    if version is None:
        # Resolve the installed version at call time rather than in the
        # default argument, so importing this module has no metadata lookup
        # side effect and cannot fail when the distribution is missing.
        version = metadata('sqlalchemy')['Version']

    match = re.search(r'^(\d+)(?:\.(\d+)(?:\.(\d+))?)?', version)
    if match is None:
        return ()
    # Optional groups that did not take part in the match are None.
    return tuple(int(part) for part in match.groups() if part is not None)
class TestNaturallyEquivalent:
    """Tests for ``naturally_equivalent`` using the ``User`` fixture."""

    def test_returns_true_when_properties_match(self, User):
        # Two unsaved users with the same name.
        left = User(name='someone')
        right = User(name='someone')
        assert naturally_equivalent(left, right)

    def test_skips_primary_keys(self, User):
        # Same name, different ids: still expected to compare as equivalent.
        first = User(id=1, name='someone')
        second = User(id=2, name='someone')
        assert naturally_equivalent(first, second)
def str_coercible(cls):
    """Class decorator that makes ``str(obj)`` return ``obj.__unicode__()``.

    :param cls: class defining ``__unicode__``; it is modified in place.
    :returns: the same class, with ``__str__`` replaced.
    """

    def __str__(self):
        return self.__unicode__()

    cls.__str__ = __str__
    return cls


def is_sequence(value):
    """Return whether *value* is an iterable other than a ``str``."""
    return isinstance(value, Iterable) and not isinstance(value, str)


def starts_with(iterable, prefix):
    """
    Returns whether or not given iterable starts with given prefix.

    Only the first ``len(prefix)`` items of *iterable* are read, so long
    (or unbounded) iterables are not materialised in full. *prefix* may be
    any iterable, including one without ``len()``.

    :param iterable: iterable whose leading items are compared.
    :param prefix: iterable of the items expected at the start.
    :returns: ``True`` if the leading items equal those of *prefix*.
    """
    from itertools import islice

    expected = list(prefix)
    return list(islice(iterable, len(expected))) == expected
class TestInstrumentedList:
    """Tests for ``any()`` on a category's ``articles`` collection."""

    def test_any_returns_true_if_member_has_attr_defined(
        self, Category, Article
    ):
        category = Category()
        # One article without a name and one with a name.
        for article in (Article(), Article(name='some name')):
            category.articles.append(article)
        assert category.articles.any('name')

    def test_any_returns_false_if_no_member_has_attr_defined(
        self, Category, Article
    ):
        category = Category()
        unnamed = Article()
        category.articles.append(unnamed)
        assert not category.articles.any('name')
class TestIsLoaded:
    """Tests for ``is_loaded`` against the ``Article`` fixture model."""

    def test_loaded_property(self, Article):
        # ``id`` is passed to the constructor, so it counts as loaded.
        instance = Article(id=1)
        loaded = is_loaded(instance, 'id')
        assert loaded

    def test_unloaded_property(self, Article):
        # ``title`` is never assigned here.
        instance = Article(id=4)
        loaded = is_loaded(instance, 'title')
        assert not loaded
def chained_join(*relationships):
    """
    Return a chained Join object for given relationships.

    The chain starts at the table of the first relationship's target class,
    joined from the relationship's secondary table when it has one. Every
    further relationship is then joined on in order, going through its
    secondary table first when present.

    :param relationships: relationship attributes; each must expose a
        ``property`` with ``secondary``, ``mapper``, ``primaryjoin`` and
        ``secondaryjoin``.
    :raises IndexError: if no relationship is given.
    """
    first = relationships[0].property

    if first.secondary is None:
        joined = first.mapper.class_.__table__
    else:
        joined = first.secondary.join(
            first.mapper.class_.__table__, first.secondaryjoin
        )

    for relationship in relationships[1:]:
        step = relationship.property
        if step.secondary is None:
            on_clause = step.primaryjoin
        else:
            # Hop through the association table before reaching the target.
            joined = joined.join(step.secondary, step.primaryjoin)
            on_clause = step.secondaryjoin
        joined = joined.join(step.mapper.class_, on_clause)

    return joined
BITs into database. 8 | """ 9 | 10 | impl = sa.types.BINARY 11 | 12 | cache_ok = True 13 | 14 | def __init__(self, length=1, **kwargs): 15 | self.length = length 16 | sa.types.TypeDecorator.__init__(self, **kwargs) 17 | 18 | def load_dialect_impl(self, dialect): 19 | # Use the native BIT type for drivers that has it. 20 | if dialect.name == 'postgresql': 21 | return dialect.type_descriptor(BIT(self.length)) 22 | elif dialect.name == 'sqlite': 23 | return dialect.type_descriptor(sa.String(self.length)) 24 | else: 25 | return dialect.type_descriptor(type(self.impl)(self.length)) 26 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | SQLAlchemy-Utils 2 | ================ 3 | 4 | |Build| |Version Status| |Downloads| 5 | 6 | Various utility functions, new data types and helpers for SQLAlchemy. 7 | 8 | 9 | Resources 10 | --------- 11 | 12 | - `Documentation `_ 13 | - `Issue Tracker `_ 14 | - `Code `_ 15 | 16 | .. |Build| image:: https://github.com/kvesteri/sqlalchemy-utils/actions/workflows/test.yaml/badge.svg?branch=master 17 | :target: https://github.com/kvesteri/sqlalchemy-utils/actions/workflows/test.yaml 18 | :alt: Tests 19 | .. |Version Status| image:: https://img.shields.io/pypi/v/SQLAlchemy-Utils.svg 20 | :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/ 21 | .. 
# The whole class is marked as using the 'postgresql_dsn' fixture.
@pytest.mark.usefixtures('postgresql_dsn')
class TestJSONSQL:

    # Each case is (value passed to json_sql, value expected back from the
    # database): scalars, dicts (flat, nested, empty), lists (filled, empty)
    # and a list holding a labelled scalar SELECT.
    @pytest.mark.parametrize(
        ('value', 'result'),
        (
            (1, 1),
            (14.14, 14.14),
            ({'a': 2, 'b': 'c'}, {'a': 2, 'b': 'c'}),
            (
                {'a': {'b': 'c'}},
                {'a': {'b': 'c'}}
            ),
            ({}, {}),
            ([1, 2], [1, 2]),
            ([], []),
            (
                [sa.select(sa.text('1')).label('alias')],
                [1]
            )
        )
    )
    def test_compiled_scalars(self, connection, value, result):
        """Select ``json_sql(value)`` and compare the first column to *result*."""
        assert result == (
            connection.execute(sa.select(json_sql(value))).fetchone()[0]
        )
class TestTableName:
    """Tests for ``table_name`` against the ``Building`` fixture model."""

    def test_class(self, Building):
        expected = 'building'
        assert table_name(Building) == expected
        # The name must still resolve once ``__tablename__`` is removed.
        del Building.__tablename__
        assert table_name(Building) == expected

    def test_attribute(self, Building):
        for attribute in (Building.id, Building.name):
            assert table_name(attribute) == 'building'

    def test_target(self, Building):
        instance = Building()
        assert table_name(instance) == 'building'
class PendulumDate(PendulumDateTime):
    """Date handling built on ``PendulumDateTime``, reduced to ``.date()``."""

    def __init__(self):
        """Check that the optional ``pendulum`` package was imported.

        :raises ImproperlyConfigured: if ``pendulum`` is not available.
        """
        if pendulum:
            return
        raise ImproperlyConfigured(
            "'pendulum' package is required to use 'PendulumDate'"
        )

    def _coerce(self, impl, value):
        """Return *value* as a ``pendulum.Date``; falsy values pass through."""
        if not value or isinstance(value, pendulum.Date):
            return value
        # Let the datetime coercion run first, then keep only the date part.
        return super()._coerce(impl, value).date()

    def process_result_value(self, impl, value, dialect):
        """Parse a truthy *value* via its ISO format and return the date part."""
        if not value:
            return value
        return pendulum.parse(value.isoformat()).date()

    def process_bind_param(self, impl, value, dialect):
        """Coerce a truthy *value*; falsy values are returned unchanged."""
        return self._coerce(impl, value) if value else value
| cast_if, 23 | get_bind, 24 | get_class_by_table, 25 | get_column_key, 26 | get_columns, 27 | get_declarative_base, 28 | get_hybrid_properties, 29 | get_mapper, 30 | get_primary_keys, 31 | get_tables, 32 | get_type, 33 | getdotattr, 34 | has_changes, 35 | identity, 36 | is_loaded, 37 | naturally_equivalent, 38 | quote, 39 | table_name, 40 | ) 41 | from .render import render_expression, render_statement # noqa 42 | from .sort_query import make_order_by_deterministic # noqa 43 | -------------------------------------------------------------------------------- /tests/generic_relationship/test_column_aliases.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import generic_relationship 5 | 6 | from . import GenericRelationshipTestCase 7 | 8 | 9 | @pytest.fixture 10 | def Building(Base): 11 | class Building(Base): 12 | __tablename__ = 'building' 13 | id = sa.Column(sa.Integer, primary_key=True) 14 | return Building 15 | 16 | 17 | @pytest.fixture 18 | def User(Base): 19 | class User(Base): 20 | __tablename__ = 'user' 21 | id = sa.Column(sa.Integer, primary_key=True) 22 | return User 23 | 24 | 25 | @pytest.fixture 26 | def Event(Base): 27 | class Event(Base): 28 | __tablename__ = 'event' 29 | id = sa.Column(sa.Integer, primary_key=True) 30 | 31 | object_type = sa.Column(sa.Unicode(255), name="objectType") 32 | object_id = sa.Column(sa.Integer, nullable=False) 33 | 34 | object = generic_relationship(object_type, object_id) 35 | return Event 36 | 37 | 38 | @pytest.fixture 39 | def init_models(Building, User, Event): 40 | pass 41 | 42 | 43 | class TestGenericRelationship(GenericRelationshipTestCase): 44 | pass 45 | -------------------------------------------------------------------------------- /tests/types/test_ip_address.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from 
sqlalchemy_utils.types import ip_address 5 | 6 | 7 | @pytest.fixture 8 | def Visitor(Base): 9 | class Visitor(Base): 10 | __tablename__ = 'document' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | ip_address = sa.Column(ip_address.IPAddressType) 13 | 14 | def __repr__(self): 15 | return 'Visitor(%r)' % self.id 16 | return Visitor 17 | 18 | 19 | @pytest.fixture 20 | def init_models(Visitor): 21 | pass 22 | 23 | 24 | @pytest.mark.skipif('ip_address.ip_address is None') 25 | class TestIPAddressType: 26 | def test_parameter_processing(self, session, Visitor): 27 | visitor = Visitor( 28 | ip_address='111.111.111.111' 29 | ) 30 | 31 | session.add(visitor) 32 | session.commit() 33 | 34 | visitor = session.query(Visitor).first() 35 | assert str(visitor.ip_address) == '111.111.111.111' 36 | 37 | def test_compilation(self, Visitor, session): 38 | query = sa.select(Visitor.ip_address) 39 | # the type should be cacheable and not throw exception 40 | session.execute(query) 41 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/enriched_datetime/arrow_datetime.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | from datetime import datetime 3 | 4 | from ...exceptions import ImproperlyConfigured 5 | 6 | arrow = None 7 | try: 8 | import arrow 9 | except ImportError: 10 | pass 11 | 12 | 13 | class ArrowDateTime: 14 | def __init__(self): 15 | if not arrow: 16 | raise ImproperlyConfigured( 17 | "'arrow' package is required to use 'ArrowDateTime'" 18 | ) 19 | 20 | def _coerce(self, impl, value): 21 | if isinstance(value, str): 22 | value = arrow.get(value) 23 | elif isinstance(value, Iterable): 24 | value = arrow.get(*value) 25 | elif isinstance(value, datetime): 26 | value = arrow.get(value) 27 | return value 28 | 29 | def process_bind_param(self, impl, value, dialect): 30 | if value: 31 | utc_val = self._coerce(impl, value).to('UTC') 32 | return 
utc_val.datetime if impl.timezone else utc_val.naive 33 | return value 34 | 35 | def process_result_value(self, impl, value, dialect): 36 | if value: 37 | return arrow.get(value) 38 | return value 39 | -------------------------------------------------------------------------------- /tests/types/test_url.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils.types import url 5 | 6 | 7 | @pytest.fixture 8 | def User(Base): 9 | class User(Base): 10 | __tablename__ = 'user' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | website = sa.Column(url.URLType) 13 | 14 | def __repr__(self): 15 | return 'User(%r)' % self.id 16 | return User 17 | 18 | 19 | @pytest.fixture 20 | def init_models(User): 21 | pass 22 | 23 | 24 | @pytest.mark.skipif('url.furl is None') 25 | class TestURLType: 26 | def test_color_parameter_processing(self, session, User): 27 | user = User( 28 | website=url.furl('www.example.com') 29 | ) 30 | 31 | session.add(user) 32 | session.commit() 33 | 34 | user = session.query(User).first() 35 | assert isinstance(user.website, url.furl) 36 | 37 | def test_scalar_attributes_get_coerced_to_objects(self, User): 38 | user = User(website='www.example.com') 39 | 40 | assert isinstance(user.website, url.furl) 41 | 42 | def test_compilation(self, User, session): 43 | query = sa.select(User.website) 44 | # the type should be cacheable and not throw exception 45 | session.execute(query) 46 | -------------------------------------------------------------------------------- /tests/functions/test_get_hybrid_properties.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | from sqlalchemy.ext.hybrid import hybrid_property 4 | 5 | from sqlalchemy_utils import get_hybrid_properties 6 | 7 | 8 | @pytest.fixture 9 | def Category(Base): 10 | class Category(Base): 11 | __tablename__ = 'category' 
12 | id = sa.Column(sa.Integer, primary_key=True) 13 | name = sa.Column(sa.Unicode(255)) 14 | 15 | @hybrid_property 16 | def lowercase_name(self): 17 | return self.name.lower() 18 | 19 | @lowercase_name.expression 20 | def lowercase_name(cls): 21 | return sa.func.lower(cls.name) 22 | return Category 23 | 24 | 25 | class TestGetHybridProperties: 26 | 27 | def test_declarative_model(self, Category): 28 | assert ( 29 | list(get_hybrid_properties(Category).keys()) == 30 | ['lowercase_name'] 31 | ) 32 | 33 | def test_mapper(self, Category): 34 | assert ( 35 | list(get_hybrid_properties(sa.inspect(Category)).keys()) == 36 | ['lowercase_name'] 37 | ) 38 | 39 | def test_aliased_class(self, Category): 40 | props = get_hybrid_properties(sa.orm.aliased(Category)) 41 | assert list(props.keys()) == ['lowercase_name'] 42 | -------------------------------------------------------------------------------- /requirements/README.rst: -------------------------------------------------------------------------------- 1 | ``requirements/`` 2 | ################# 3 | 4 | This directory contains the files that manage dependencies for the project. 5 | 6 | Each subdirectory in this directory contains a ``pyproject.toml`` file 7 | with purpose-specific dependencies listed. 8 | 9 | 10 | How it's used 11 | ============= 12 | 13 | Tox is configured to use the exported ``requirements.txt`` files as needed. 14 | In addition, Read the Docs is configured to use ``docs/requirements.txt``. 15 | This helps ensure reproducible testing, linting, and documentation builds. 16 | 17 | 18 | How it's updated 19 | ================ 20 | 21 | A tox label, ``update``, ensures that dependencies can be easily updated, 22 | and that ``requirements.txt`` files are consistently re-exported. 23 | 24 | This can be invoked by running: 25 | 26 | .. 
code-block:: 27 | 28 | tox run -m update 29 | 30 | 31 | How to add dependencies 32 | ======================= 33 | 34 | New dependencies can be added to a given subdirectory's ``pyproject.toml`` 35 | by either manually modifying the file, or by running a command like: 36 | 37 | .. code-block:: 38 | 39 | poetry add --lock --directory "requirements/$DIR" $DEPENDENCY_NAME 40 | 41 | Either way, the dependencies must be re-exported: 42 | 43 | .. code-block:: 44 | 45 | tox run -m update 46 | -------------------------------------------------------------------------------- /sqlalchemy_utils/primitives/weekday.py: -------------------------------------------------------------------------------- 1 | from functools import total_ordering 2 | 3 | from .. import i18n 4 | from ..utils import str_coercible 5 | 6 | 7 | @str_coercible 8 | @total_ordering 9 | class WeekDay: 10 | NUM_WEEK_DAYS = 7 11 | 12 | def __init__(self, index): 13 | if not (0 <= index < self.NUM_WEEK_DAYS): 14 | raise ValueError('index must be between 0 and %d' % self.NUM_WEEK_DAYS) 15 | self.index = index 16 | 17 | def __eq__(self, other): 18 | if isinstance(other, WeekDay): 19 | return self.index == other.index 20 | else: 21 | return NotImplemented 22 | 23 | def __hash__(self): 24 | return hash(self.index) 25 | 26 | def __lt__(self, other): 27 | return self.position < other.position 28 | 29 | def __repr__(self): 30 | return f'{self.__class__.__name__}({self.index!r})' 31 | 32 | def __unicode__(self): 33 | return self.name 34 | 35 | def get_name(self, width='wide', context='format'): 36 | names = i18n.babel.dates.get_day_names(width, context, i18n.get_locale()) 37 | return names[self.index] 38 | 39 | @property 40 | def name(self): 41 | return self.get_name() 42 | 43 | @property 44 | def position(self): 45 | return (self.index - i18n.get_locale().first_week_day) % self.NUM_WEEK_DAYS 46 | -------------------------------------------------------------------------------- /tests/functions/test_get_column_key.py: 
-------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | 6 | from sqlalchemy_utils import get_column_key 7 | 8 | 9 | @pytest.fixture 10 | def Building(Base): 11 | class Building(Base): 12 | __tablename__ = 'building' 13 | id = sa.Column(sa.Integer, primary_key=True) 14 | name = sa.Column('_name', sa.Unicode(255)) 15 | return Building 16 | 17 | 18 | @pytest.fixture 19 | def Movie(Base): 20 | class Movie(Base): 21 | __tablename__ = 'movie' 22 | id = sa.Column(sa.Integer, primary_key=True) 23 | return Movie 24 | 25 | 26 | class TestGetColumnKey: 27 | 28 | def test_supports_aliases(self, Building): 29 | assert ( 30 | get_column_key(Building, Building.__table__.c.id) == 31 | 'id' 32 | ) 33 | assert ( 34 | get_column_key(Building, Building.__table__.c._name) == 35 | 'name' 36 | ) 37 | 38 | def test_supports_vague_matching_of_column_objects(self, Building): 39 | column = copy(Building.__table__.c._name) 40 | assert get_column_key(Building, column) == 'name' 41 | 42 | def test_throws_value_error_for_unknown_column(self, Building, Movie): 43 | with pytest.raises(sa.orm.exc.UnmappedColumnError): 44 | get_column_key(Building, Movie.__table__.c.id) 45 | -------------------------------------------------------------------------------- /tests/types/test_email.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import EmailType 5 | 6 | 7 | @pytest.fixture 8 | def User(Base): 9 | class User(Base): 10 | __tablename__ = 'user' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | email = sa.Column(EmailType) 13 | short_email = sa.Column(EmailType(length=70)) 14 | 15 | def __repr__(self): 16 | return 'User(%r)' % self.id 17 | return User 18 | 19 | 20 | class TestEmailType: 21 | def test_saves_email_as_lowercased(self, session, User): 22 | user = 
User(email='Someone@example.com') 23 | 24 | session.add(user) 25 | session.commit() 26 | 27 | user = session.query(User).first() 28 | assert user.email == 'someone@example.com' 29 | 30 | def test_literal_param(self, session, User): 31 | clause = User.email == 'Someone@example.com' 32 | compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) 33 | assert compiled == '"user".email = lower(\'Someone@example.com\')' 34 | 35 | def test_custom_length(self, session, User): 36 | assert User.short_email.type.impl.length == 70 37 | 38 | def test_compilation(self, User, session): 39 | query = sa.select(User.email) 40 | # the type should be cacheable and not throw exception 41 | session.execute(query) 42 | -------------------------------------------------------------------------------- /tests/functions/test_identity.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils.functions import identity 5 | 6 | 7 | class IdentityTestCase: 8 | 9 | @pytest.fixture 10 | def init_models(self, Building): 11 | pass 12 | 13 | def test_for_transient_class_without_id(self, Building): 14 | assert identity(Building()) == (None, ) 15 | 16 | def test_for_transient_class_with_id(self, session, Building): 17 | building = Building(name='Some building') 18 | session.add(building) 19 | session.flush() 20 | 21 | assert identity(building) == (building.id, ) 22 | 23 | def test_identity_for_class(self, Building): 24 | assert identity(Building) == (Building.id, ) 25 | 26 | 27 | class TestIdentity(IdentityTestCase): 28 | 29 | @pytest.fixture 30 | def Building(self, Base): 31 | class Building(Base): 32 | __tablename__ = 'building' 33 | id = sa.Column(sa.Integer, primary_key=True) 34 | name = sa.Column(sa.Unicode(255)) 35 | return Building 36 | 37 | 38 | class TestIdentityWithColumnAlias(IdentityTestCase): 39 | 40 | @pytest.fixture 41 | def Building(self, Base): 42 | class Building(Base): 43 
| __tablename__ = 'building' 44 | id = sa.Column('_id', sa.Integer, primary_key=True) 45 | name = sa.Column(sa.Unicode(255)) 46 | return Building 47 | -------------------------------------------------------------------------------- /requirements/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.16 ; python_version >= "3.9" 2 | babel==2.17.0 ; python_version >= "3.9" 3 | certifi==2025.7.9 ; python_version >= "3.9" 4 | charset-normalizer==3.4.2 ; python_version >= "3.9" 5 | colorama==0.4.6 ; python_version >= "3.9" and sys_platform == "win32" 6 | docutils==0.21.2 ; python_version >= "3.9" 7 | idna==3.10 ; python_version >= "3.9" 8 | imagesize==1.4.1 ; python_version >= "3.9" 9 | importlib-metadata==8.7.0 ; python_version == "3.9" 10 | jinja2==3.1.6 ; python_version >= "3.9" 11 | markupsafe==3.0.2 ; python_version >= "3.9" 12 | packaging==25.0 ; python_version >= "3.9" 13 | pygments==2.19.2 ; python_version >= "3.9" 14 | requests==2.32.4 ; python_version >= "3.9" 15 | snowballstemmer==3.0.1 ; python_version >= "3.9" 16 | sphinx-rtd-theme==3.0.2 ; python_version >= "3.9" 17 | sphinx==7.4.7 ; python_version >= "3.9" 18 | sphinxcontrib-applehelp==2.0.0 ; python_version >= "3.9" 19 | sphinxcontrib-devhelp==2.0.0 ; python_version >= "3.9" 20 | sphinxcontrib-htmlhelp==2.1.0 ; python_version >= "3.9" 21 | sphinxcontrib-jquery==4.1 ; python_version >= "3.9" 22 | sphinxcontrib-jsmath==1.0.1 ; python_version >= "3.9" 23 | sphinxcontrib-qthelp==2.0.0 ; python_version >= "3.9" 24 | sphinxcontrib-serializinghtml==2.0.0 ; python_version >= "3.9" 25 | tomli==2.2.1 ; python_version >= "3.9" and python_version < "3.11" 26 | urllib3==2.6.0 ; python_version >= "3.9" 27 | zipp==3.23.0 ; python_version == "3.9" 28 | -------------------------------------------------------------------------------- /tests/generic_relationship/test_abstract_base_class.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | from sqlalchemy.ext.declarative import declared_attr 4 | 5 | from sqlalchemy_utils import generic_relationship 6 | 7 | from . import GenericRelationshipTestCase 8 | 9 | 10 | @pytest.fixture 11 | def Building(Base): 12 | class Building(Base): 13 | __tablename__ = 'building' 14 | id = sa.Column(sa.Integer, primary_key=True) 15 | return Building 16 | 17 | 18 | @pytest.fixture 19 | def User(Base): 20 | class User(Base): 21 | __tablename__ = 'user' 22 | id = sa.Column(sa.Integer, primary_key=True) 23 | return User 24 | 25 | 26 | @pytest.fixture 27 | def EventBase(Base): 28 | class EventBase(Base): 29 | __abstract__ = True 30 | 31 | object_type = sa.Column(sa.Unicode(255)) 32 | object_id = sa.Column(sa.Integer, nullable=False) 33 | 34 | @declared_attr 35 | def object(cls): 36 | return generic_relationship('object_type', 'object_id') 37 | return EventBase 38 | 39 | 40 | @pytest.fixture 41 | def Event(EventBase): 42 | class Event(EventBase): 43 | __tablename__ = 'event' 44 | id = sa.Column(sa.Integer, primary_key=True) 45 | return Event 46 | 47 | 48 | @pytest.fixture 49 | def init_models(Building, User, Event): 50 | pass 51 | 52 | 53 | class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): 54 | pass 55 | -------------------------------------------------------------------------------- /tests/functions/test_cast_if.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | import sqlalchemy.orm 4 | 5 | from sqlalchemy_utils import cast_if 6 | 7 | 8 | @pytest.fixture(scope='class') 9 | def base(): 10 | return sqlalchemy.orm.declarative_base() 11 | 12 | 13 | @pytest.fixture(scope='class') 14 | def article_cls(base): 15 | class Article(base): 16 | __tablename__ = 'article' 17 | id = sa.Column(sa.Integer, primary_key=True) 18 | name = sa.Column(sa.String) 
19 | name_synonym = sa.orm.synonym('name') 20 | 21 | return Article 22 | 23 | 24 | class TestCastIf: 25 | def test_column(self, article_cls): 26 | expr = article_cls.__table__.c.name 27 | assert cast_if(expr, sa.String) is expr 28 | 29 | def test_column_property(self, article_cls): 30 | expr = article_cls.name.property 31 | assert cast_if(expr, sa.String) is expr 32 | 33 | def test_instrumented_attribute(self, article_cls): 34 | expr = article_cls.name 35 | assert cast_if(expr, sa.String) is expr 36 | 37 | def test_synonym(self, article_cls): 38 | expr = article_cls.name_synonym 39 | assert cast_if(expr, sa.String) is expr 40 | 41 | def test_scalar_selectable(self, article_cls): 42 | expr = sa.select(article_cls.id).scalar_subquery() 43 | assert cast_if(expr, sa.Integer) is expr 44 | 45 | def test_scalar(self): 46 | assert cast_if('something', sa.String) == 'something' 47 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | This part of the documentation covers the installation of SQLAlchemy-Utils. 5 | 6 | Supported platforms 7 | ------------------- 8 | 9 | SQLAlchemy-Utils is currently tested against the following Python platforms: 10 | 11 | - CPython 3.9 12 | - CPython 3.10 13 | - CPython 3.11 14 | - CPython 3.12 15 | 16 | 17 | Installing an official release 18 | ------------------------------ 19 | 20 | You can install the most recent official SQLAlchemy-Utils version using 21 | pip_:: 22 | 23 | pip install sqlalchemy-utils 24 | 25 | .. _pip: https://pip.pypa.io/ 26 | 27 | Installing the development version 28 | ---------------------------------- 29 | 30 | To install the latest version of SQLAlchemy-Utils, you first need to obtain a 31 | copy of the source.
You can do that by cloning the git_ repository:: 32 | 33 | git clone https://github.com/kvesteri/sqlalchemy-utils.git 34 | 35 | Then you can install the source distribution using pip:: 36 | 37 | cd sqlalchemy-utils 38 | pip install -e . 39 | 40 | .. _git: https://git-scm.com/ 41 | 42 | Checking the installation 43 | ------------------------- 44 | 45 | To check that SQLAlchemy-Utils has been properly installed, type ``python`` 46 | from your shell. Then at the Python prompt, try to import SQLAlchemy-Utils, 47 | and check the installed version: 48 | 49 | .. parsed-literal:: 50 | 51 | >>> import sqlalchemy_utils 52 | >>> sqlalchemy_utils.__version__ 53 | |release| 54 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/email.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sa 2 | 3 | from ..operators import CaseInsensitiveComparator 4 | 5 | 6 | class EmailType(sa.types.TypeDecorator): 7 | """ 8 | Provides a way for storing emails in a lower case. 9 | 10 | Example:: 11 | 12 | 13 | from sqlalchemy_utils import EmailType 14 | 15 | 16 | class User(Base): 17 | __tablename__ = 'user' 18 | id = sa.Column(sa.Integer, primary_key=True) 19 | name = sa.Column(sa.Unicode(255)) 20 | email = sa.Column(EmailType) 21 | 22 | 23 | user = User() 24 | user.email = 'John.Smith@foo.com' 25 | user.name = 'John Smith' 26 | session.add(user) 27 | session.commit() 28 | # Notice - email in filter() is lowercase.
29 | user = (session.query(User) 30 | .filter(User.email == 'john.smith@foo.com') 31 | .one()) 32 | assert user.name == 'John Smith' 33 | """ 34 | 35 | impl = sa.Unicode 36 | comparator_factory = CaseInsensitiveComparator 37 | cache_ok = True 38 | 39 | def __init__(self, length=255, *args, **kwargs): 40 | super().__init__(length=length, *args, **kwargs) 41 | 42 | def process_bind_param(self, value, dialect): 43 | if value is not None: 44 | return value.lower() 45 | return value 46 | 47 | @property 48 | def python_type(self): 49 | return self.impl.type.python_type 50 | -------------------------------------------------------------------------------- /tests/functions/test_has_changes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import has_changes 5 | 6 | 7 | @pytest.fixture 8 | def Article(Base): 9 | class Article(Base): 10 | __tablename__ = 'article_translation' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | title = sa.Column(sa.String(100)) 13 | return Article 14 | 15 | 16 | class TestHasChangesWithStringAttr: 17 | def test_without_changed_attr(self, Article): 18 | article = Article() 19 | assert not has_changes(article, 'title') 20 | 21 | def test_with_changed_attr(self, Article): 22 | article = Article(title='Some title') 23 | assert has_changes(article, 'title') 24 | 25 | 26 | class TestHasChangesWithMultipleAttrs: 27 | def test_without_changed_attr(self, Article): 28 | article = Article() 29 | assert not has_changes(article, ['title']) 30 | 31 | def test_with_changed_attr(self, Article): 32 | article = Article(title='Some title') 33 | assert has_changes(article, ['title', 'id']) 34 | 35 | 36 | class TestHasChangesWithExclude: 37 | def test_without_changed_attr(self, Article): 38 | article = Article() 39 | assert not has_changes(article, exclude=['id']) 40 | 41 | def test_with_changed_attr(self, Article): 42 | article = Article(title='Some 
title') 43 | assert has_changes(article, exclude=['id']) 44 | assert not has_changes(article, exclude=['title']) 45 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/enriched_datetime/pendulum_datetime.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from ...exceptions import ImproperlyConfigured 4 | 5 | pendulum = None 6 | try: 7 | import pendulum 8 | except ImportError: 9 | pass 10 | 11 | 12 | class PendulumDateTime: 13 | def __init__(self): 14 | if not pendulum: 15 | raise ImproperlyConfigured( 16 | "'pendulum' package is required to use 'PendulumDateTime'" 17 | ) 18 | 19 | def _coerce(self, impl, value): 20 | if value is not None: 21 | if isinstance(value, pendulum.DateTime): 22 | pass 23 | elif isinstance(value, (int, float)): 24 | value = pendulum.from_timestamp(value) 25 | elif isinstance(value, str) and value.isdigit(): 26 | value = pendulum.from_timestamp(int(value)) 27 | elif isinstance(value, datetime.datetime): 28 | value = pendulum.DateTime.instance(value) 29 | else: 30 | value = pendulum.parse(value) 31 | return value 32 | 33 | def process_bind_param(self, impl, value, dialect): 34 | if value: 35 | return self._coerce(impl, value).in_tz('UTC').naive() 36 | return value 37 | 38 | def process_result_value(self, impl, value, dialect): 39 | if value: 40 | return pendulum.DateTime.instance( 41 | value.replace(tzinfo=datetime.timezone.utc) 42 | ) 43 | return value 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012, Konsta Vesterinen 2 | 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * The names of the contributors may not be used to endorse or promote products 16 | derived from this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY DIRECT, 22 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 23 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 26 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 27 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /tests/types/test_country.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import Country, CountryType, i18n # noqa 5 | 6 | 7 | @pytest.fixture 8 | def User(Base): 9 | class User(Base): 10 | __tablename__ = 'user' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | country = sa.Column(CountryType) 13 | 14 | def __repr__(self): 15 | return 'User(%r)' % self.id 16 | return User 17 | 18 | 19 | @pytest.fixture 20 | def init_models(User): 21 | pass 22 | 23 | 24 | @pytest.mark.skipif('i18n.babel is None') 25 | class TestCountryType: 26 | 27 | def test_parameter_processing(self, session, User): 28 | user = User( 29 | country=Country('FI') 30 | ) 31 | 32 | session.add(user) 33 | session.commit() 34 | 35 | user = session.query(User).first() 36 | assert user.country.name == 'Finland' 37 | 38 | def test_scalar_attributes_get_coerced_to_objects(self, User): 39 | user = User(country='FI') 40 | assert isinstance(user.country, Country) 41 | 42 | def test_literal_param(self, session, User): 43 | clause = User.country == 'FI' 44 | compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) 45 | assert compiled == '"user".country = \'FI\'' 46 | 47 | def test_compilation(self, User, session): 48 | query = sa.select(User.country) 49 | # the type should be cacheable and not throw exception 50 | session.execute(query) 51 | -------------------------------------------------------------------------------- /tests/types/test_ltree.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import Ltree, LtreeType 5 | 6 | 7 | @pytest.fixture 8 | def Section(Base): 9 | class Section(Base): 10 | __tablename__ = 'section' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | path = 
sa.Column(LtreeType) 13 | 14 | return Section 15 | 16 | 17 | @pytest.fixture 18 | def init_models(Section, connection): 19 | with connection.begin(): 20 | connection.execute(sa.text('CREATE EXTENSION IF NOT EXISTS ltree')) 21 | pass 22 | 23 | 24 | @pytest.mark.usefixtures('postgresql_dsn') 25 | class TestLTREE: 26 | def test_saves_path(self, session, Section): 27 | section = Section(path='1.1.2') 28 | 29 | session.add(section) 30 | session.commit() 31 | 32 | user = session.query(Section).first() 33 | assert user.path == '1.1.2' 34 | 35 | def test_scalar_attributes_get_coerced_to_objects(self, Section): 36 | section = Section(path='path.path') 37 | assert isinstance(section.path, Ltree) 38 | 39 | def test_literal_param(self, session, Section): 40 | clause = Section.path == 'path' 41 | compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) 42 | assert compiled == 'section.path = \'path\'' 43 | 44 | def test_compilation(self, Section, session): 45 | query = sa.select(Section.path) 46 | # the type should be cacheable and not throw exception 47 | session.execute(query) 48 | -------------------------------------------------------------------------------- /tests/functions/test_get_primary_keys.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import get_primary_keys 5 | 6 | try: 7 | from collections import OrderedDict 8 | except ImportError: 9 | from ordereddict import OrderedDict 10 | 11 | 12 | @pytest.fixture 13 | def Building(Base): 14 | class Building(Base): 15 | __tablename__ = 'building' 16 | id = sa.Column('_id', sa.Integer, primary_key=True) 17 | name = sa.Column('_name', sa.Unicode(255)) 18 | return Building 19 | 20 | 21 | class TestGetPrimaryKeys: 22 | 23 | def test_table(self, Building): 24 | assert get_primary_keys(Building.__table__) == OrderedDict({ 25 | '_id': Building.__table__.c._id 26 | }) 27 | 28 | def 
test_declarative_class(self, Building): 29 | assert get_primary_keys(Building) == OrderedDict({ 30 | 'id': Building.__table__.c._id 31 | }) 32 | 33 | def test_declarative_object(self, Building): 34 | assert get_primary_keys(Building()) == OrderedDict({ 35 | 'id': Building.__table__.c._id 36 | }) 37 | 38 | def test_class_alias(self, Building): 39 | alias = sa.orm.aliased(Building) 40 | assert get_primary_keys(alias) == OrderedDict({ 41 | 'id': Building.__table__.c._id 42 | }) 43 | 44 | def test_table_alias(self, Building): 45 | alias = sa.orm.aliased(Building.__table__) 46 | assert get_primary_keys(alias) == OrderedDict({ 47 | '_id': alias.c._id 48 | }) 49 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | 4 | def get_version() -> str: 5 | path = pathlib.Path(__file__).parent.parent / "sqlalchemy_utils/__init__.py" 6 | content = path.read_text() 7 | for line in content.splitlines(): 8 | if line.strip().startswith("__version__"): 9 | _, _, version_ = line.partition("=") 10 | return version_.strip("'\" ") 11 | 12 | 13 | # -- General configuration ----------------------------------------------------- 14 | 15 | # Add any Sphinx extension module names here, as strings. They can be extensions 16 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 17 | extensions = [ 18 | "sphinx.ext.autodoc", 19 | "sphinx.ext.doctest", 20 | "sphinx.ext.intersphinx", 21 | "sphinx.ext.todo", 22 | "sphinx.ext.coverage", 23 | "sphinx.ext.ifconfig", 24 | "sphinx.ext.viewcode", 25 | ] 26 | 27 | # Add any paths that contain templates here, relative to this directory. 28 | templates_path = ["_templates"] 29 | 30 | # The master toctree document. 31 | master_doc = "index" 32 | 33 | # General information about the project. 
class IPAddressType(ScalarCoercible, types.TypeDecorator):
    """
    Column type that stores IP addresses as unicode strings.

    Values are converted with ``str()`` when bound to a statement and parsed
    with :func:`ipaddress.ip_address` when loaded or assigned. Falsy values
    are turned into ``None`` in both directions.

    ::


        from sqlalchemy_utils import IPAddressType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            ip_address = sa.Column(IPAddressType)


        user = User()
        user.ip_address = '123.123.123.123'
        session.add(user)
        session.commit()

        user.ip_address  # IPAddress object
    """

    impl = types.Unicode(50)
    cache_ok = True

    def __init__(self, max_length=50, *args, **kwargs):
        """Replace the default implementation type with ``Unicode(max_length)``."""
        super().__init__(*args, **kwargs)
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        """Return ``str(value)``, or ``None`` for a falsy value."""
        if not value:
            return None
        return str(value)

    def process_result_value(self, value, dialect):
        """Parse the loaded string into an address object; falsy gives ``None``."""
        if not value:
            return None
        return ip_address(value)

    def _coerce(self, value):
        """Parse an assigned scalar into an address object; falsy gives ``None``."""
        if not value:
            return None
        return ip_address(value)

    @property
    def python_type(self):
        """Python type reported by the underlying implementation type."""
        underlying = self.impl.type
        return underlying.python_type
class EnrichedDateType(types.TypeDecorator, ScalarCoercible):
    """
    Date column type that delegates value handling to a date processor.

    Binding, result conversion and scalar coercion are all forwarded to an
    instance of ``date_processor``. The default processor is ``PendulumDate``,
    so pendulum is the only library supported out of the box.

    The constructor has no ``type`` parameter: any extra positional or keyword
    arguments are handed to ``TypeDecorator.__init__``.

    Example::


        from sqlalchemy_utils import EnrichedDateType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            # PendulumDate is the default date_processor
            birthday = sa.Column(EnrichedDateType())


        user = User()
        user.birthday = pendulum.datetime(year=1995, month=7, day=11)
        session.add(user)
        session.commit()
    """

    impl = types.Date
    cache_ok = True

    def __init__(self, date_processor=PendulumDate, *args, **kwargs):
        """
        :param date_processor: class that is instantiated without arguments
            and must provide ``_coerce``, ``process_bind_param`` and
            ``process_result_value``.
        """
        super().__init__(*args, **kwargs)
        self.date_object = date_processor()

    def _coerce(self, value):
        """Delegate scalar coercion to the processor."""
        return self.date_object._coerce(self.impl, value)

    def process_bind_param(self, value, dialect):
        """Delegate conversion of an outgoing value to the processor."""
        return self.date_object.process_bind_param(self.impl, value, dialect)

    def process_result_value(self, value, dialect):
        """Delegate conversion of a loaded value to the processor."""
        return self.date_object.process_result_value(self.impl, value, dialect)

    def process_literal_param(self, value, dialect):
        """Return the literal value unchanged (the processor is not consulted)."""
        return value

    @property
    def python_type(self):
        """Python type reported by the underlying implementation type."""
        return self.impl.type.python_type
32 | assert 'user.email = lower(?)' in str(query) 33 | 34 | def test_supports_in_(self, session, User): 35 | query = ( 36 | session.query(User) 37 | .filter(User.email.in_(['email@example.com', 'a'])) 38 | ) 39 | assert ( 40 | 'user.email IN (lower(?), lower(?))' 41 | in str(query) 42 | ) 43 | 44 | def test_supports_notin_(self, session, User): 45 | query = ( 46 | session.query(User) 47 | .filter(User.email.notin_(['email@example.com', 'a'])) 48 | ) 49 | assert ( 50 | 'user.email NOT IN (' 51 | in str(query) 52 | ) 53 | 54 | def test_does_not_apply_lower_to_types_that_are_already_lowercased( 55 | self, 56 | User 57 | ): 58 | assert str(User.email == User.email) == ( 59 | '"user".email = "user".email' 60 | ) 61 | -------------------------------------------------------------------------------- /tests/aggregate/test_with_ondelete_cascade.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils.aggregates import aggregated 5 | 6 | 7 | @pytest.fixture 8 | def Thread(Base): 9 | class Thread(Base): 10 | __tablename__ = 'thread' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | name = sa.Column(sa.Unicode(255)) 13 | 14 | @aggregated('comments', sa.Column(sa.Integer, default=0)) 15 | def comment_count(self): 16 | return sa.func.count('1') 17 | 18 | comments = sa.orm.relationship( 19 | 'Comment', 20 | passive_deletes=True, 21 | backref='thread' 22 | ) 23 | return Thread 24 | 25 | 26 | @pytest.fixture 27 | def Comment(Base): 28 | class Comment(Base): 29 | __tablename__ = 'comment' 30 | id = sa.Column(sa.Integer, primary_key=True) 31 | content = sa.Column(sa.Unicode(255)) 32 | thread_id = sa.Column( 33 | sa.Integer, 34 | sa.ForeignKey('thread.id', ondelete='CASCADE') 35 | ) 36 | return Comment 37 | 38 | 39 | @pytest.fixture 40 | def init_models(Thread, Comment): 41 | pass 42 | 43 | 44 | @pytest.mark.usefixtures('postgresql_dsn') 45 | class 
class ArrowType(EnrichedDateTimeType):
    """
    ArrowType provides way of saving Arrow_ objects into database. It
    automatically changes Arrow_ objects to datetime objects on the way in and
    datetime objects back to Arrow_ objects on the way out (when querying
    database). ArrowType needs Arrow_ library installed.

    .. _Arrow: https://github.com/arrow-py/arrow

    ::

        from datetime import datetime
        from sqlalchemy_utils import ArrowType
        import arrow


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            created_at = sa.Column(ArrowType)



        article = Article(created_at=arrow.utcnow())


    As you may expect all the arrow goodies come available:

    ::


        article.created_at = article.created_at.shift(hours=-1)

        article.created_at.humanize()
        # 'an hour ago'

    """

    cache_ok = True

    def __init__(self, *args, **kwargs):
        """
        Configure the column to use ``ArrowDateTime`` as its processor.

        :raises ImproperlyConfigured: if the ``arrow`` package could not be
            imported.
        """
        if not arrow:
            raise ImproperlyConfigured("'arrow' package is required to use 'ArrowType'")

        # The processor is passed positionally: it is the first parameter of
        # EnrichedDateTimeType.__init__, so passing it by keyword made any
        # positional argument collide with it and raise TypeError.
        super().__init__(ArrowDateTime, *args, **kwargs)
@pytest.fixture
def set_get_locale():
    """Make ``i18n.get_locale`` return the English locale for one test.

    The previous ``i18n.get_locale`` is put back afterwards so the
    module-level override does not leak into tests that did not request
    this fixture.
    """
    previous = i18n.get_locale
    i18n.get_locale = lambda: i18n.babel.Locale('en')
    yield
    i18n.get_locale = previous
@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """Render ``array_get(expr, index)`` as ``(expr)[index + 1]``.

    :param element: the ``array_get`` construct; it must hold exactly two
        clauses.
    :param compiler: the SQL compiler used to render the first clause.
    :param kw: compile-time keyword arguments, forwarded to the nested
        ``compiler.process`` call so options such as ``literal_binds`` also
        apply to the array expression.
    :raises Exception: if the number of clauses is not two, or if the second
        clause has no integer ``value`` attribute.
    """
    args = list(element.clauses)
    if len(args) != 2:
        raise Exception(
            "Function 'array_get' expects two arguments (%d given)." % len(args)
        )

    if not hasattr(args[1], 'value') or not isinstance(args[1].value, int):
        raise Exception('Second argument should be an integer.')
    # The subscript is the given index plus one (presumably translating a
    # zero-based index into a one-based SQL array subscript). The integer is
    # formatted directly instead of being wrapped in sa.text() and turned
    # back into a string.
    return f'({compiler.process(args[0], **kw)})[{args[1].value + 1}]'
_furl: https://github.com/gruns/furl 17 | 18 | :: 19 | 20 | from sqlalchemy_utils import URLType 21 | from furl import furl 22 | 23 | 24 | class User(Base): 25 | __tablename__ = 'user' 26 | 27 | id = sa.Column(sa.Integer, primary_key=True) 28 | website = sa.Column(URLType) 29 | 30 | 31 | user = User(website='www.example.com') 32 | 33 | # website is coerced to furl object, hence all nice furl operations 34 | # come available 35 | user.website.args['some_argument'] = '12' 36 | 37 | print user.website 38 | # www.example.com?some_argument=12 39 | """ 40 | 41 | impl = types.UnicodeText 42 | cache_ok = True 43 | 44 | def process_bind_param(self, value, dialect): 45 | if furl is not None and isinstance(value, furl): 46 | return str(value) 47 | 48 | if isinstance(value, str): 49 | return value 50 | 51 | def process_result_value(self, value, dialect): 52 | if furl is None: 53 | return value 54 | 55 | if value is not None: 56 | return furl(value) 57 | 58 | def _coerce(self, value): 59 | if furl is None: 60 | return value 61 | 62 | if value is not None and not isinstance(value, furl): 63 | return furl(value) 64 | return value 65 | 66 | @property 67 | def python_type(self): 68 | return self.impl.type.python_type 69 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/country.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import types 2 | 3 | from ..primitives import Country 4 | from .scalar_coercible import ScalarCoercible 5 | 6 | 7 | class CountryType(ScalarCoercible, types.TypeDecorator): 8 | """ 9 | Changes :class:`.Country` objects to a string representation on the way in 10 | and changes them back to :class:`.Country` objects on the way out. 11 | 12 | In order to use CountryType you need to install Babel_ first. 13 | 14 | .. 
@str_coercible
class WeekDays:
    """A set of ``WeekDay`` objects with a bit-string representation.

    Position ``i`` of the bit string is ``'1'`` when ``WeekDay(i)`` is in the
    set and ``'0'`` otherwise.
    """

    def __init__(self, bit_string_or_week_days):
        """
        :param bit_string_or_week_days: one of

            * a ``str`` of exactly ``WeekDay.NUM_WEEK_DAYS`` characters, each
              ``'0'`` or ``'1'``;
            * another ``WeekDays`` instance, whose days are copied;
            * any other iterable, whose items become the set of days.
        :raises ValueError: if a bit string has the wrong length or contains
            a character other than ``'0'`` and ``'1'``.
        """
        if isinstance(bit_string_or_week_days, str):
            self._days = set()

            if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS:
                raise ValueError(
                    'Bit string must be {} characters long.'.format(
                        WeekDay.NUM_WEEK_DAYS
                    )
                )

            for index, bit in enumerate(bit_string_or_week_days):
                if bit not in '01':
                    raise ValueError('Bit string may only contain zeroes and ones.')
                if bit == '1':
                    self._days.add(WeekDay(index))
        elif isinstance(bit_string_or_week_days, WeekDays):
            # Copy instead of sharing the set, so the two instances cannot
            # affect each other through the same mutable object.
            self._days = set(bit_string_or_week_days._days)
        else:
            self._days = set(bit_string_or_week_days)

    def __eq__(self, other):
        """Compare with another ``WeekDays`` (by days) or a ``str`` (by bit string)."""
        if isinstance(other, WeekDays):
            return self._days == other._days
        elif isinstance(other, str):
            return self.as_bit_string() == other
        else:
            return NotImplemented

    def __iter__(self):
        """Yield the contained days in sorted order."""
        yield from sorted(self._days)

    def __contains__(self, value):
        """Return whether ``value`` is one of the contained days."""
        return value in self._days

    def __repr__(self):
        return f'{self.__class__.__name__}({self.as_bit_string()!r})'

    def __unicode__(self):
        """Return the contained days, sorted, joined with ``', '``."""
        return ', '.join(str(day) for day in self)

    def as_bit_string(self):
        """Return the ``'0'``/``'1'`` string, one character per week day index."""
        return ''.join(
            '1' if WeekDay(index) in self._days else '0'
            for index in range(WeekDay.NUM_WEEK_DAYS)
        )
class TestInstantDefaultListener:
    """Column defaults must already be set on a freshly constructed object."""

    def test_assigns_defaults_on_object_construction(self, Article):
        """A scalar default is visible without flushing."""
        assert Article().name == 'Some article'

    def test_callables_as_defaults(self, Article):
        """A callable default has been invoked and its result stored."""
        created = Article().created_at
        assert isinstance(created, datetime)

    def test_override_default_with_setter_function(self, Article):
        """A value passed through a property setter is not overwritten."""
        instance = Article(byline='provided byline')
        assert instance.byline == 'provided byline'

    def test_handles_sequence_defaults(self, Document):
        """A sequence-backed column stays ``None``; other defaults still apply."""
        instance = Document()
        assert instance.title == 'Untitled'
        assert instance.id is None
get_columns(Building()).values()] == columns 43 | 44 | def test_mapper(self, Building, columns): 45 | assert [ 46 | c.name for c in get_columns(Building.__mapper__).values() 47 | ] == columns 48 | 49 | def test_class_alias(self, Building, columns): 50 | assert [ 51 | c.name for c in get_columns(sa.orm.aliased(Building)).values() 52 | ] == columns 53 | 54 | def test_table_alias(self, Building, columns): 55 | alias = sa.orm.aliased(Building.__table__) 56 | assert [c.name for c in get_columns(alias).values()] == columns 57 | -------------------------------------------------------------------------------- /tests/observes/test_dynamic_relationship.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | from sqlalchemy.orm import dynamic_loader 4 | 5 | from sqlalchemy_utils.observer import observes 6 | 7 | 8 | @pytest.fixture 9 | def Director(Base): 10 | class Director(Base): 11 | __tablename__ = 'director' 12 | id = sa.Column(sa.Integer, primary_key=True) 13 | name = sa.Column(sa.String) 14 | movies = dynamic_loader('Movie', back_populates='director') 15 | 16 | return Director 17 | 18 | 19 | @pytest.fixture 20 | def Movie(Base, Director): 21 | class Movie(Base): 22 | __tablename__ = 'movie' 23 | id = sa.Column(sa.Integer, primary_key=True) 24 | name = sa.Column(sa.String) 25 | director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id)) 26 | director = sa.orm.relationship(Director, back_populates='movies') 27 | director_name = sa.Column(sa.String) 28 | 29 | @observes('director') 30 | def director_observer(self, director): 31 | self.director_name = director.name 32 | 33 | return Movie 34 | 35 | 36 | @pytest.fixture 37 | def init_models(Director, Movie): 38 | pass 39 | 40 | 41 | @pytest.mark.usefixtures('postgresql_dsn') 42 | class TestObservesForDynamicRelationship: 43 | def test_add_observed_object(self, session, Director, Movie): 44 | steven = Director(name='Steven Spielberg') 45 | 
session.add(steven) 46 | jaws = Movie(name='Jaws', director=steven) 47 | session.add(jaws) 48 | session.commit() 49 | assert jaws.director_name == 'Steven Spielberg' 50 | 51 | def test_add_observed_object_from_backref(self, session, Director, Movie): 52 | jaws = Movie(name='Jaws') 53 | steven = Director(name='Steven Spielberg', movies=[jaws]) 54 | session.add(steven) 55 | session.add(jaws) 56 | session.commit() 57 | assert jaws.director_name == 'Steven Spielberg' 58 | -------------------------------------------------------------------------------- /tests/types/encrypted/test_padding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sqlalchemy_utils.types.encrypted.padding import ( 4 | InvalidPaddingError, 5 | PKCS5Padding 6 | ) 7 | 8 | 9 | class TestPkcs5Padding: 10 | def setup_method(self): 11 | self.BLOCK_SIZE = 8 12 | self.padder = PKCS5Padding(self.BLOCK_SIZE) 13 | 14 | def test_various_lengths_roundtrip(self): 15 | for number in range(0, 3 * self.BLOCK_SIZE): 16 | val = b'*' * number 17 | padded = self.padder.pad(val) 18 | unpadded = self.padder.unpad(padded) 19 | assert val == unpadded, 'Round trip error for length %d' % number 20 | 21 | def test_invalid_unpad(self): 22 | with pytest.raises(InvalidPaddingError): 23 | self.padder.unpad(None) 24 | with pytest.raises(InvalidPaddingError): 25 | self.padder.unpad(b'') 26 | with pytest.raises(InvalidPaddingError): 27 | self.padder.unpad(b'\01') 28 | with pytest.raises(InvalidPaddingError): 29 | self.padder.unpad((b'*' * (self.BLOCK_SIZE - 1)) + b'\00') 30 | with pytest.raises(InvalidPaddingError): 31 | self.padder.unpad((b'*' * self.BLOCK_SIZE) + b'\01') 32 | 33 | def test_pad_longer_than_block(self): 34 | with pytest.raises(InvalidPaddingError): 35 | self.padder.unpad( 36 | 'x' * (self.BLOCK_SIZE - 1) + 37 | chr(self.BLOCK_SIZE + 1) * (self.BLOCK_SIZE + 1) 38 | ) 39 | 40 | def test_incorrect_padding(self): 41 | # Hard-coded for blocksize of 8 42 | 
assert self.padder.unpad(b'1234\04\04\04\04') == b'1234' 43 | with pytest.raises(InvalidPaddingError): 44 | self.padder.unpad(b'1234\02\04\04\04') 45 | with pytest.raises(InvalidPaddingError): 46 | self.padder.unpad(b'1234\04\02\04\04') 47 | with pytest.raises(InvalidPaddingError): 48 | self.padder.unpad(b'1234\04\04\02\04') 49 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # SQLAlchemy-Utils Release Process 2 | 3 | This document outlines the process for creating a new release of SQLAlchemy-Utils. 4 | 5 | ## Prerequisites 6 | 7 | - Access to the GitHub repository with push permissions 8 | - PyPI account with API token for the project 9 | - `uv` package manager installed 10 | - `gh` CLI tool installed (optional, for GitHub releases) 11 | 12 | ## Release Process 13 | 14 | ### 1. Prepare the Release 15 | 16 | 1. **Update CHANGES.rst** 17 | - Move items from "Unreleased changes" section to a new version section 18 | - Follow the existing format: `X.Y.Z (YYYY-MM-DD)` 19 | - Add a new "Unreleased changes" section 20 | 21 | 2. **Update version in pyproject.toml** 22 | ```bash 23 | # Edit pyproject.toml and update the version field 24 | version = "X.Y.Z" 25 | ``` 26 | 27 | 3. **Commit the changes** 28 | ```bash 29 | git add CHANGES.rst pyproject.toml 30 | git commit -m "Bump version to X.Y.Z" 31 | git push origin master 32 | ``` 33 | 34 | ### 2. Create Git Tag 35 | 36 | ```bash 37 | git tag -a X.Y.Z -m "Release version X.Y.Z" 38 | git push origin X.Y.Z 39 | ``` 40 | 41 | ### 3. Build and Publish to PyPI 42 | 43 | 1. **Build the package** 44 | ```bash 45 | uv build 46 | ``` 47 | 48 | 2. 
**Publish to PyPI** 49 | ```bash 50 | uv publish --token pypi-your-api-token-here 51 | ``` 52 | 53 | Alternatively, set the token as an environment variable: 54 | ```bash 55 | export UV_PUBLISH_TOKEN="pypi-your-api-token-here" 56 | uv publish 57 | ``` 58 | 59 | ### 4. Create GitHub Release 60 | 61 | ```bash 62 | gh release create X.Y.Z --title "X.Y.Z" --notes-from-tag 63 | ``` 64 | 65 | Or create the release manually on GitHub using the tag. 66 | 67 | ## Version Numbering 68 | 69 | This project follows [Semantic Versioning](https://semver.org/): 70 | - **MAJOR** version for incompatible API changes 71 | - **MINOR** version for backwards-compatible functionality additions 72 | - **PATCH** version for backwards-compatible bug fixes 73 | -------------------------------------------------------------------------------- /tests/aggregate/test_search_vectors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import aggregated, TSVectorType 5 | 6 | 7 | def tsvector_reduce_concat(vectors): 8 | return sa.sql.expression.cast( 9 | sa.func.coalesce( 10 | sa.func.array_to_string(sa.func.array_agg(vectors), ' ') 11 | ), 12 | TSVectorType 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def Product(Base): 18 | class Product(Base): 19 | __tablename__ = 'product' 20 | id = sa.Column(sa.Integer, primary_key=True) 21 | name = sa.Column(sa.Unicode(255)) 22 | price = sa.Column(sa.Numeric) 23 | 24 | catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) 25 | return Product 26 | 27 | 28 | @pytest.fixture 29 | def Catalog(Base, Product): 30 | class Catalog(Base): 31 | __tablename__ = 'catalog' 32 | id = sa.Column(sa.Integer, primary_key=True) 33 | name = sa.Column(sa.Unicode(255)) 34 | 35 | @aggregated('products', sa.Column(TSVectorType)) 36 | def product_search_vector(self): 37 | return tsvector_reduce_concat( 38 | sa.func.to_tsvector(Product.name) 39 | ) 40 | 41 | products = 
sa.orm.relationship('Product', backref='catalog') 42 | return Catalog 43 | 44 | 45 | @pytest.fixture 46 | def init_models(Product, Catalog): 47 | pass 48 | 49 | 50 | @pytest.mark.usefixtures('postgresql_dsn') 51 | class TestSearchVectorAggregates: 52 | 53 | def test_assigns_aggregates_on_insert(self, session, Product, Catalog): 54 | catalog = Catalog( 55 | name='Some catalog' 56 | ) 57 | session.add(catalog) 58 | session.commit() 59 | product = Product( 60 | name='Product XYZ', 61 | catalog=catalog 62 | ) 63 | session.add(product) 64 | session.commit() 65 | session.refresh(catalog) 66 | assert catalog.product_search_vector == "'product':1 'xyz':2" 67 | -------------------------------------------------------------------------------- /tests/types/test_weekdays.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import i18n 5 | from sqlalchemy_utils.primitives import WeekDays 6 | from sqlalchemy_utils.types import WeekDaysType 7 | 8 | 9 | @pytest.fixture 10 | def Schedule(Base): 11 | class Schedule(Base): 12 | __tablename__ = 'schedule' 13 | id = sa.Column(sa.Integer, primary_key=True) 14 | working_days = sa.Column(WeekDaysType) 15 | 16 | def __repr__(self): 17 | return 'Schedule(%r)' % self.id 18 | return Schedule 19 | 20 | 21 | @pytest.fixture 22 | def init_models(Schedule): 23 | pass 24 | 25 | 26 | @pytest.fixture 27 | def set_get_locale(): 28 | i18n.get_locale = lambda: i18n.babel.Locale('en') 29 | 30 | 31 | @pytest.mark.usefixtures('set_get_locale') 32 | @pytest.mark.skipif('i18n.babel is None') 33 | class WeekDaysTypeTestCase: 34 | def test_color_parameter_processing(self, session, Schedule): 35 | schedule = Schedule( 36 | working_days=b'0001111' 37 | ) 38 | session.add(schedule) 39 | session.commit() 40 | 41 | schedule = session.query(Schedule).first() 42 | assert isinstance(schedule.working_days, WeekDays) 43 | 44 | def 
test_scalar_attributes_get_coerced_to_objects(self, Schedule): 45 | schedule = Schedule(working_days=b'1010101') 46 | 47 | assert isinstance(schedule.working_days, WeekDays) 48 | 49 | def test_compilation(self, Schedule, session): 50 | query = sa.select(Schedule.working_days) 51 | # the type should be cacheable and not throw exception 52 | session.execute(query) 53 | 54 | 55 | @pytest.mark.usefixtures('sqlite_memory_dsn') 56 | class TestWeekDaysTypeOnSQLite(WeekDaysTypeTestCase): 57 | pass 58 | 59 | 60 | @pytest.mark.usefixtures('postgresql_dsn') 61 | class TestWeekDaysTypeOnPostgres(WeekDaysTypeTestCase): 62 | pass 63 | 64 | 65 | @pytest.mark.usefixtures('mysql_dsn') 66 | class TestWeekDaysTypeOnMySQL(WeekDaysTypeTestCase): 67 | pass 68 | -------------------------------------------------------------------------------- /tests/primitives/test_currency.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sqlalchemy_utils import Currency, i18n 4 | 5 | 6 | @pytest.fixture 7 | def set_get_locale(): 8 | i18n.get_locale = lambda: i18n.babel.Locale('en') 9 | 10 | 11 | @pytest.mark.skipif('i18n.babel is None') 12 | @pytest.mark.usefixtures('set_get_locale') 13 | class TestCurrency: 14 | 15 | def test_init(self): 16 | assert Currency('USD') == Currency(Currency('USD')) 17 | 18 | def test_hashability(self): 19 | assert len({Currency('USD'), Currency('USD')}) == 1 20 | 21 | def test_invalid_currency_code(self): 22 | with pytest.raises(ValueError): 23 | Currency('Unknown code') 24 | 25 | def test_invalid_currency_code_type(self): 26 | with pytest.raises(TypeError): 27 | Currency(None) 28 | 29 | @pytest.mark.parametrize( 30 | ('code', 'name'), 31 | ( 32 | ('USD', 'US Dollar'), 33 | ('EUR', 'Euro') 34 | ) 35 | ) 36 | def test_name_property(self, code, name): 37 | assert Currency(code).name == name 38 | 39 | @pytest.mark.parametrize( 40 | ('code', 'symbol'), 41 | ( 42 | ('USD', '$'), 43 | ('EUR', '€') 44 | ) 45 | ) 46 
| def test_symbol_property(self, code, symbol): 47 | assert Currency(code).symbol == symbol 48 | 49 | def test_equality_operator(self): 50 | assert Currency('USD') == 'USD' 51 | assert 'USD' == Currency('USD') 52 | assert Currency('USD') == Currency('USD') 53 | 54 | def test_non_equality_operator(self): 55 | assert Currency('USD') != 'EUR' 56 | assert not (Currency('USD') != 'USD') 57 | 58 | def test_unicode(self): 59 | currency = Currency('USD') 60 | assert str(currency) == 'USD' 61 | 62 | def test_str(self): 63 | currency = Currency('USD') 64 | assert str(currency) == 'USD' 65 | 66 | def test_representation(self): 67 | currency = Currency('USD') 68 | assert repr(currency) == "Currency('USD')" 69 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/locale.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import types 2 | 3 | from ..exceptions import ImproperlyConfigured 4 | from .scalar_coercible import ScalarCoercible 5 | 6 | babel = None 7 | try: 8 | import babel 9 | except ImportError: 10 | pass 11 | 12 | 13 | class LocaleType(ScalarCoercible, types.TypeDecorator): 14 | """ 15 | LocaleType saves Babel_ Locale objects into database. The Locale objects 16 | are converted to string on the way in and back to object on the way out. 17 | 18 | In order to use LocaleType you need to install Babel_ first. 19 | 20 | .. 
_Babel: https://babel.pocoo.org/ 21 | 22 | :: 23 | 24 | 25 | from sqlalchemy_utils import LocaleType 26 | from babel import Locale 27 | 28 | 29 | class User(Base): 30 | __tablename__ = 'user' 31 | id = sa.Column(sa.Integer, autoincrement=True) 32 | name = sa.Column(sa.Unicode(50)) 33 | locale = sa.Column(LocaleType) 34 | 35 | 36 | user = User() 37 | user.locale = Locale('en_US') 38 | session.add(user) 39 | session.commit() 40 | 41 | 42 | Like many other types this type also supports scalar coercion: 43 | 44 | :: 45 | 46 | 47 | user.locale = 'de_DE' 48 | user.locale # Locale('de', territory='DE') 49 | 50 | """ 51 | 52 | impl = types.Unicode(10) 53 | 54 | cache_ok = True 55 | 56 | def __init__(self): 57 | if babel is None: 58 | raise ImproperlyConfigured('Babel packaged is required with LocaleType.') 59 | 60 | def process_bind_param(self, value, dialect): 61 | if isinstance(value, babel.Locale): 62 | return str(value) 63 | 64 | if isinstance(value, str): 65 | return value 66 | 67 | def process_result_value(self, value, dialect): 68 | if value is not None: 69 | return babel.Locale.parse(value) 70 | 71 | def _coerce(self, value): 72 | if value is not None and not isinstance(value, babel.Locale): 73 | return babel.Locale.parse(value) 74 | return value 75 | -------------------------------------------------------------------------------- /tests/types/test_uuid.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | from sqlalchemy.dialects import postgresql 6 | 7 | from sqlalchemy_utils import UUIDType 8 | 9 | 10 | @pytest.fixture 11 | def User(Base): 12 | class User(Base): 13 | __tablename__ = 'user' 14 | id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True) 15 | 16 | def __repr__(self): 17 | return 'User(%r)' % self.id 18 | return User 19 | 20 | 21 | @pytest.fixture 22 | def init_models(User): 23 | pass 24 | 25 | 26 | class TestUUIDType: 27 | def test_repr(self): 
28 | plain = UUIDType() 29 | assert repr(plain) == 'UUIDType()' 30 | 31 | text = UUIDType(binary=False) 32 | assert repr(text) == 'UUIDType(binary=False)' 33 | 34 | not_native = UUIDType(native=False) 35 | assert repr(not_native) == 'UUIDType(native=False)' 36 | 37 | def test_commit(self, session, User): 38 | obj = User() 39 | obj.id = uuid.uuid4().hex 40 | 41 | session.add(obj) 42 | session.commit() 43 | 44 | u = session.query(User).one() 45 | 46 | assert u.id == obj.id 47 | 48 | def test_coerce(self, User): 49 | obj = User() 50 | obj.id = identifier = uuid.uuid4().hex 51 | 52 | assert isinstance(obj.id, uuid.UUID) 53 | assert obj.id.hex == identifier 54 | 55 | obj.id = identifier = uuid.uuid4().bytes 56 | 57 | assert isinstance(obj.id, uuid.UUID) 58 | assert obj.id.bytes == identifier 59 | 60 | def test_compilation(self, User, session): 61 | query = sa.select(User.id) 62 | 63 | # the type should be cacheable and not throw exception 64 | session.execute(query) 65 | 66 | def test_literal_bind(self, User): 67 | expr = (User.id == 'b4e794d6-5750-4844-958c-fa382649719d').compile( 68 | dialect=postgresql.dialect(), 69 | compile_kwargs={'literal_binds': True} 70 | ) 71 | assert str(expr) == ( 72 | '''"user".id = \'b4e794d6-5750-4844-958c-fa382649719d\'''' 73 | ) 74 | -------------------------------------------------------------------------------- /tests/test_expressions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | from sqlalchemy.dialects import postgresql 4 | from sqlalchemy.orm import declarative_base 5 | 6 | from sqlalchemy_utils import Asterisk, row_to_json 7 | 8 | 9 | @pytest.fixture 10 | def assert_startswith(session): 11 | def assert_startswith(query, query_part): 12 | assert str( 13 | query.compile(dialect=postgresql.dialect()) 14 | ).startswith(query_part) 15 | # Check that query executes properly 16 | session.execute(query) 17 | return assert_startswith 18 | 19 | 20 | 
@pytest.fixture 21 | def Article(Base): 22 | class Article(Base): 23 | __tablename__ = 'article' 24 | id = sa.Column(sa.Integer, primary_key=True) 25 | name = sa.Column(sa.Unicode(255)) 26 | content = sa.Column(sa.UnicodeText) 27 | return Article 28 | 29 | 30 | class TestAsterisk: 31 | def test_with_table_object(self): 32 | Base = declarative_base() 33 | 34 | class Article(Base): 35 | __tablename__ = 'article' 36 | id = sa.Column(sa.Integer, primary_key=True) 37 | 38 | assert str(Asterisk(Article.__table__)) == 'article.*' 39 | 40 | def test_with_quoted_identifier(self): 41 | Base = declarative_base() 42 | 43 | class User(Base): 44 | __tablename__ = 'user' 45 | id = sa.Column(sa.Integer, primary_key=True) 46 | 47 | assert str(Asterisk(User.__table__).compile( 48 | dialect=postgresql.dialect() 49 | )) == '"user".*' 50 | 51 | 52 | class TestRowToJson: 53 | def test_compiler_with_default_dialect(self): 54 | assert str(row_to_json(sa.text('article.*'))) == ( 55 | 'row_to_json(article.*)' 56 | ) 57 | 58 | def test_compiler_with_postgresql(self): 59 | assert str(row_to_json(sa.text('article.*')).compile( 60 | dialect=postgresql.dialect() 61 | )) == 'row_to_json(article.*)' 62 | 63 | def test_type(self): 64 | assert isinstance( 65 | sa.func.row_to_json(sa.text('article.*')).type, 66 | postgresql.JSON 67 | ) 68 | -------------------------------------------------------------------------------- /tests/types/test_locale.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils.types import locale 5 | 6 | 7 | @pytest.fixture 8 | def User(Base): 9 | class User(Base): 10 | __tablename__ = 'user' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | locale = sa.Column(locale.LocaleType) 13 | 14 | def __repr__(self): 15 | return 'User(%r)' % self.id 16 | return User 17 | 18 | 19 | @pytest.fixture 20 | def init_models(User): 21 | pass 22 | 23 | 24 | 
@pytest.mark.skipif('locale.babel is None') 25 | class TestLocaleType: 26 | def test_parameter_processing(self, session, User): 27 | user = User( 28 | locale=locale.babel.Locale('fi') 29 | ) 30 | 31 | session.add(user) 32 | session.commit() 33 | 34 | user = session.query(User).first() 35 | 36 | def test_territory_parsing(self, session, User): 37 | ko_kr = locale.babel.Locale('ko', territory='KR') 38 | user = User(locale=ko_kr) 39 | session.add(user) 40 | session.commit() 41 | 42 | assert session.query(User.locale).first()[0] == ko_kr 43 | 44 | def test_coerce_territory_parsing(self, User): 45 | user = User() 46 | user.locale = 'ko_KR' 47 | assert user.locale == locale.babel.Locale('ko', territory='KR') 48 | 49 | def test_scalar_attributes_get_coerced_to_objects(self, User): 50 | user = User(locale='en_US') 51 | 52 | assert isinstance(user.locale, locale.babel.Locale) 53 | 54 | def test_unknown_locale_throws_exception(self, User): 55 | with pytest.raises(locale.babel.UnknownLocaleError): 56 | User(locale='unknown') 57 | 58 | def test_literal_param(self, session, User): 59 | clause = User.locale == 'en_US' 60 | compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) 61 | assert compiled == '"user".locale = \'en_US\'' 62 | 63 | def test_compilation(self, User, session): 64 | query = sa.select(User.locale) 65 | 66 | # the type should be cacheable and not throw exception 67 | session.execute(query) 68 | -------------------------------------------------------------------------------- /sqlalchemy_utils/operators.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sa 2 | 3 | 4 | def inspect_type(mixed): 5 | if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): 6 | return mixed.property.columns[0].type 7 | elif isinstance(mixed, sa.orm.ColumnProperty): 8 | return mixed.columns[0].type 9 | elif isinstance(mixed, sa.Column): 10 | return mixed.type 11 | 12 | 13 | def is_case_insensitive(mixed): 
14 | try: 15 | return isinstance(inspect_type(mixed).comparator, CaseInsensitiveComparator) 16 | except AttributeError: 17 | try: 18 | return issubclass( 19 | inspect_type(mixed).comparator_factory, CaseInsensitiveComparator 20 | ) 21 | except AttributeError: 22 | return False 23 | 24 | 25 | class CaseInsensitiveComparator(sa.Unicode.Comparator): 26 | @classmethod 27 | def lowercase_arg(cls, func): 28 | def operation(self, other, **kwargs): 29 | operator = getattr(sa.Unicode.Comparator, func) 30 | if other is None: 31 | return operator(self, other, **kwargs) 32 | if not is_case_insensitive(other): 33 | other = sa.func.lower(other) 34 | return operator(self, other, **kwargs) 35 | 36 | return operation 37 | 38 | def in_(self, other): 39 | if isinstance(other, list) or isinstance(other, tuple): 40 | other = map(sa.func.lower, other) 41 | return sa.Unicode.Comparator.in_(self, other) 42 | 43 | def notin_(self, other): 44 | if isinstance(other, list) or isinstance(other, tuple): 45 | other = map(sa.func.lower, other) 46 | return sa.Unicode.Comparator.notin_(self, other) 47 | 48 | 49 | string_operator_funcs = [ 50 | '__eq__', 51 | '__ne__', 52 | '__lt__', 53 | '__le__', 54 | '__gt__', 55 | '__ge__', 56 | 'concat', 57 | 'contains', 58 | 'ilike', 59 | 'like', 60 | 'notlike', 61 | 'notilike', 62 | 'startswith', 63 | 'endswith', 64 | ] 65 | 66 | for func in string_operator_funcs: 67 | setattr( 68 | CaseInsensitiveComparator, func, CaseInsensitiveComparator.lowercase_arg(func) 69 | ) 70 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList 4 | 5 | from .arrow import ArrowType # noqa 6 | from .choice import Choice, ChoiceType # noqa 7 | from .color import ColorType # noqa 8 | from .country import CountryType # noqa 9 
| from .currency import CurrencyType # noqa 10 | from .email import EmailType # noqa 11 | from .encrypted.encrypted_type import ( # noqa 12 | EncryptedType, 13 | StringEncryptedType, 14 | ) 15 | from .enriched_datetime.enriched_date_type import EnrichedDateType # noqa 16 | from .ip_address import IPAddressType # noqa 17 | from .json import JSONType # noqa 18 | from .locale import LocaleType # noqa 19 | from .ltree import LtreeType # noqa 20 | from .password import Password, PasswordType # noqa 21 | from .pg_composite import ( # noqa 22 | CompositeType, 23 | register_composites, 24 | remove_composite_listeners, 25 | ) 26 | from .phone_number import ( # noqa 27 | PhoneNumber, 28 | PhoneNumberParseException, 29 | PhoneNumberType, 30 | ) 31 | from .range import ( # noqa 32 | DateRangeType, 33 | DateTimeRangeType, 34 | Int8RangeType, 35 | IntRangeType, 36 | NumericRangeType, 37 | ) 38 | from .scalar_list import ScalarListException, ScalarListType # noqa 39 | from .timezone import TimezoneType # noqa 40 | from .ts_vector import TSVectorType # noqa 41 | from .url import URLType # noqa 42 | from .uuid import UUIDType # noqa 43 | from .weekdays import WeekDaysType # noqa 44 | 45 | from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType # noqa isort:skip 46 | 47 | 48 | class InstrumentedList(_InstrumentedList): 49 | """Enhanced version of SQLAlchemy InstrumentedList. 
Provides some 50 | additional functionality.""" 51 | 52 | def any(self, attr): 53 | return any(getattr(item, attr) for item in self) 54 | 55 | def all(self, attr): 56 | return all(getattr(item, attr) for item in self) 57 | 58 | 59 | def instrumented_list(f): 60 | @wraps(f) 61 | def wrapper(*args, **kwargs): 62 | return InstrumentedList([item for item in f(*args, **kwargs)]) 63 | 64 | return wrapper 65 | -------------------------------------------------------------------------------- /tests/types/test_enriched_date_pendulum.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | 6 | from sqlalchemy_utils.types.enriched_datetime import ( 7 | enriched_date_type, 8 | pendulum_date 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def User(Base): 14 | class User(Base): 15 | __tablename__ = 'users' 16 | id = sa.Column(sa.Integer, primary_key=True) 17 | birthday = sa.Column( 18 | enriched_date_type.EnrichedDateType( 19 | date_processor=pendulum_date.PendulumDate 20 | ) 21 | ) 22 | return User 23 | 24 | 25 | @pytest.fixture 26 | def init_models(User): 27 | pass 28 | 29 | 30 | @pytest.mark.skipif('pendulum_date.pendulum is None') 31 | class TestPendulumDateType: 32 | def test_parameter_processing(self, session, User): 33 | user = User( 34 | birthday=pendulum_date.pendulum.date(1995, 7, 11) 35 | ) 36 | 37 | session.add(user) 38 | session.commit() 39 | 40 | user = session.query(User).first() 41 | assert isinstance(user.birthday, date) 42 | 43 | def test_int_coercion(self, User): 44 | user = User( 45 | birthday=1367900664 46 | ) 47 | assert user.birthday.year == 2013 48 | 49 | def test_string_coercion(self, User): 50 | user = User( 51 | birthday='1367900664' 52 | ) 53 | assert user.birthday.year == 2013 54 | 55 | def test_utc(self, session, User): 56 | time = pendulum_date.pendulum.now("UTC") 57 | user = User(birthday=time) 58 | session.add(user) 59 | assert user.birthday 
== time 60 | session.commit() 61 | assert user.birthday == time.date() 62 | 63 | def test_literal_param(self, session, User): 64 | clause = User.birthday > date(2015, 1, 1) 65 | compiled = str(clause.compile(compile_kwargs={"literal_binds": True})) 66 | assert compiled == "users.birthday > '2015-01-01'" 67 | 68 | def test_compilation(self, User, session): 69 | query = sa.select(User.birthday) 70 | # the type should be cacheable and not throw exception 71 | session.execute(query) 72 | -------------------------------------------------------------------------------- /tests/types/test_color.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | from flexmock import flexmock 4 | 5 | from sqlalchemy_utils import ColorType, types # noqa 6 | 7 | 8 | @pytest.fixture 9 | def Document(Base): 10 | class Document(Base): 11 | __tablename__ = 'document' 12 | id = sa.Column(sa.Integer, primary_key=True) 13 | bg_color = sa.Column(ColorType) 14 | 15 | def __repr__(self): 16 | return 'Document(%r)' % self.id 17 | return Document 18 | 19 | 20 | @pytest.fixture 21 | def init_models(Document): 22 | pass 23 | 24 | 25 | @pytest.mark.skipif('types.color.python_colour_type is None') 26 | class TestColorType: 27 | def test_string_parameter_processing(self, session, Document): 28 | from colour import Color 29 | 30 | flexmock(ColorType).should_receive('_coerce').and_return( 31 | 'white' 32 | ) 33 | document = Document( 34 | bg_color='white' 35 | ) 36 | 37 | session.add(document) 38 | session.commit() 39 | 40 | document = session.query(Document).first() 41 | assert document.bg_color.hex == Color('white').hex 42 | 43 | def test_color_parameter_processing(self, session, Document): 44 | from colour import Color 45 | 46 | document = Document(bg_color=Color('white')) 47 | session.add(document) 48 | session.commit() 49 | 50 | document = session.query(Document).first() 51 | assert document.bg_color.hex == 
Color('white').hex 52 | 53 | def test_scalar_attributes_get_coerced_to_objects(self, Document): 54 | from colour import Color 55 | 56 | document = Document(bg_color='white') 57 | assert isinstance(document.bg_color, Color) 58 | 59 | def test_literal_param(self, session, Document): 60 | clause = Document.bg_color == 'white' 61 | compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) 62 | assert compiled == "document.bg_color = 'white'" 63 | 64 | def test_compilation(self, Document, session): 65 | query = sa.select(Document.bg_color) 66 | # the type should be cacheable and not throw exception 67 | session.execute(query) 68 | -------------------------------------------------------------------------------- /tests/aggregate/test_with_column_alias.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils.aggregates import aggregated 5 | 6 | 7 | @pytest.fixture 8 | def Thread(Base): 9 | class Thread(Base): 10 | __tablename__ = 'thread' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | 13 | @aggregated( 14 | 'comments', 15 | sa.Column('_comment_count', sa.Integer, default=0) 16 | ) 17 | def comment_count(self): 18 | return sa.func.count('1') 19 | 20 | comments = sa.orm.relationship('Comment', backref='thread') 21 | return Thread 22 | 23 | 24 | @pytest.fixture 25 | def Comment(Base): 26 | class Comment(Base): 27 | __tablename__ = 'comment' 28 | id = sa.Column(sa.Integer, primary_key=True) 29 | thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) 30 | return Comment 31 | 32 | 33 | @pytest.fixture 34 | def init_models(Thread, Comment): 35 | pass 36 | 37 | 38 | class TestAggregatedWithColumnAlias: 39 | 40 | def test_assigns_aggregates_on_insert(self, session, Thread, Comment): 41 | thread = Thread() 42 | session.add(thread) 43 | comment = Comment(thread=thread) 44 | session.add(comment) 45 | session.commit() 46 | session.refresh(thread) 47 | 
assert thread.comment_count == 1 48 | 49 | def test_assigns_aggregates_on_separate_insert( 50 | self, 51 | session, 52 | Thread, 53 | Comment 54 | ): 55 | thread = Thread() 56 | session.add(thread) 57 | session.commit() 58 | comment = Comment(thread=thread) 59 | session.add(comment) 60 | session.commit() 61 | session.refresh(thread) 62 | assert thread.comment_count == 1 63 | 64 | def test_assigns_aggregates_on_delete(self, session, Thread, Comment): 65 | thread = Thread() 66 | session.add(thread) 67 | session.commit() 68 | comment = Comment(thread=thread) 69 | session.add(comment) 70 | session.commit() 71 | session.delete(comment) 72 | session.commit() 73 | session.refresh(thread) 74 | assert thread.comment_count == 0 75 | -------------------------------------------------------------------------------- /tests/functions/test_non_indexed_foreign_keys.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | 6 | from sqlalchemy_utils.functions import non_indexed_foreign_keys 7 | 8 | 9 | class TestFindNonIndexedForeignKeys: 10 | 11 | @pytest.fixture 12 | def User(self, Base): 13 | class User(Base): 14 | __tablename__ = 'user' 15 | id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) 16 | name = sa.Column(sa.Unicode(255)) 17 | return User 18 | 19 | @pytest.fixture 20 | def Category(self, Base): 21 | class Category(Base): 22 | __tablename__ = 'category' 23 | id = sa.Column(sa.Integer, primary_key=True) 24 | name = sa.Column(sa.Unicode(255)) 25 | return Category 26 | 27 | @pytest.fixture 28 | def Article(self, Base, User, Category): 29 | class Article(Base): 30 | __tablename__ = 'article' 31 | id = sa.Column(sa.Integer, primary_key=True) 32 | name = sa.Column(sa.Unicode(255)) 33 | author_id = sa.Column( 34 | sa.Integer, sa.ForeignKey(User.id), index=True 35 | ) 36 | category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) 37 | 38 | category = 
sa.orm.relationship( 39 | Category, 40 | primaryjoin=category_id == Category.id, 41 | backref=sa.orm.backref( 42 | 'articles', 43 | ) 44 | ) 45 | return Article 46 | 47 | @pytest.fixture 48 | def init_models(self, User, Category, Article): 49 | pass 50 | 51 | def test_finds_all_non_indexed_fks(self, session, Base, engine): 52 | fks = non_indexed_foreign_keys(Base.metadata, engine) 53 | assert ( 54 | 'article' in 55 | fks 56 | ) 57 | column_names = list(chain( 58 | *( 59 | names for names in ( 60 | fk.columns.keys() 61 | for fk in fks['article'] 62 | ) 63 | ) 64 | )) 65 | assert 'category_id' in column_names 66 | assert 'author_id' not in column_names 67 | -------------------------------------------------------------------------------- /docs/orm_helpers.rst: -------------------------------------------------------------------------------- 1 | ORM helpers 2 | =========== 3 | 4 | 5 | cast_if 6 | ------- 7 | 8 | .. autofunction:: sqlalchemy_utils.functions.cast_if 9 | 10 | 11 | escape_like 12 | ----------- 13 | 14 | .. autofunction:: sqlalchemy_utils.functions.escape_like 15 | 16 | 17 | get_bind 18 | -------- 19 | 20 | .. autofunction:: sqlalchemy_utils.functions.get_bind 21 | 22 | 23 | get_class_by_table 24 | ------------------ 25 | 26 | .. autofunction:: sqlalchemy_utils.functions.get_class_by_table 27 | 28 | 29 | get_column_key 30 | -------------- 31 | 32 | .. autofunction:: sqlalchemy_utils.functions.get_column_key 33 | 34 | 35 | get_columns 36 | ----------- 37 | 38 | .. autofunction:: sqlalchemy_utils.functions.get_columns 39 | 40 | 41 | get_declarative_base 42 | -------------------- 43 | 44 | .. autofunction:: sqlalchemy_utils.functions.get_declarative_base 45 | 46 | 47 | get_hybrid_properties 48 | --------------------- 49 | 50 | .. autofunction:: sqlalchemy_utils.functions.get_hybrid_properties 51 | 52 | 53 | get_mapper 54 | ---------- 55 | 56 | .. 
autofunction:: sqlalchemy_utils.functions.get_mapper 57 | 58 | 59 | get_primary_keys 60 | ---------------- 61 | 62 | .. autofunction:: sqlalchemy_utils.functions.get_primary_keys 63 | 64 | 65 | get_tables 66 | ---------- 67 | 68 | .. autofunction:: sqlalchemy_utils.functions.get_tables 69 | 70 | 71 | get_type 72 | -------- 73 | 74 | .. autofunction:: sqlalchemy_utils.functions.get_type 75 | 76 | 77 | has_changes 78 | ----------- 79 | 80 | .. autofunction:: sqlalchemy_utils.functions.has_changes 81 | 82 | 83 | identity 84 | -------- 85 | 86 | .. autofunction:: sqlalchemy_utils.functions.identity 87 | 88 | 89 | is_loaded 90 | --------- 91 | 92 | .. autofunction:: sqlalchemy_utils.functions.is_loaded 93 | 94 | 95 | make_order_by_deterministic 96 | --------------------------- 97 | 98 | .. autofunction:: sqlalchemy_utils.functions.make_order_by_deterministic 99 | 100 | 101 | naturally_equivalent 102 | -------------------- 103 | 104 | .. autofunction:: sqlalchemy_utils.functions.naturally_equivalent 105 | 106 | 107 | quote 108 | ----- 109 | 110 | .. autofunction:: sqlalchemy_utils.functions.quote 111 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/currency.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import types 2 | 3 | from .. import i18n, ImproperlyConfigured 4 | from ..primitives import Currency 5 | from .scalar_coercible import ScalarCoercible 6 | 7 | 8 | class CurrencyType(ScalarCoercible, types.TypeDecorator): 9 | """ 10 | Changes :class:`.Currency` objects to a string representation on the way in 11 | and changes them back to :class:`.Currency` objects on the way out. 12 | 13 | In order to use CurrencyType you need to install Babel_ first. 14 | 15 | .. 
_Babel: https://babel.pocoo.org/ 16 | 17 | :: 18 | 19 | 20 | from sqlalchemy_utils import CurrencyType, Currency 21 | 22 | 23 | class User(Base): 24 | __tablename__ = 'user' 25 | id = sa.Column(sa.Integer, autoincrement=True) 26 | name = sa.Column(sa.Unicode(255)) 27 | currency = sa.Column(CurrencyType) 28 | 29 | 30 | user = User() 31 | user.currency = Currency('USD') 32 | session.add(user) 33 | session.commit() 34 | 35 | user.currency # Currency('USD') 36 | user.currency.name # US Dollar 37 | 38 | str(user.currency) # USD 39 | user.currency.symbol # $ 40 | 41 | 42 | 43 | CurrencyType is scalar coercible:: 44 | 45 | 46 | user.currency = 'USD' 47 | user.currency # Currency('USD') 48 | """ 49 | 50 | impl = types.String(3) 51 | python_type = Currency 52 | cache_ok = True 53 | 54 | def __init__(self, *args, **kwargs): 55 | if i18n.babel is None: 56 | raise ImproperlyConfigured( 57 | "'babel' package is required in order to use CurrencyType." 58 | ) 59 | 60 | super().__init__(*args, **kwargs) 61 | 62 | def process_bind_param(self, value, dialect): 63 | if isinstance(value, Currency): 64 | return value.code 65 | elif isinstance(value, str): 66 | return value 67 | 68 | def process_result_value(self, value, dialect): 69 | if value is not None: 70 | return Currency(value) 71 | 72 | def _coerce(self, value): 73 | if value is not None and not isinstance(value, Currency): 74 | return Currency(value) 75 | return value 76 | -------------------------------------------------------------------------------- /tests/aggregate/test_o2m_o2m.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | 6 | from sqlalchemy_utils.aggregates import aggregated 7 | 8 | 9 | @pytest.fixture 10 | def Catalog(Base): 11 | class Catalog(Base): 12 | __tablename__ = 'catalog' 13 | id = sa.Column(sa.Integer, primary_key=True) 14 | name = sa.Column(sa.Unicode(255)) 15 | 16 | @aggregated( 17 
| 'categories.products', 18 | sa.Column(sa.Integer, default=0) 19 | ) 20 | def product_count(self): 21 | return sa.func.count('1') 22 | 23 | categories = sa.orm.relationship('Category', backref='catalog') 24 | return Catalog 25 | 26 | 27 | @pytest.fixture 28 | def Category(Base): 29 | class Category(Base): 30 | __tablename__ = 'category' 31 | id = sa.Column(sa.Integer, primary_key=True) 32 | name = sa.Column(sa.Unicode(255)) 33 | 34 | catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) 35 | 36 | products = sa.orm.relationship('Product', backref='category') 37 | return Category 38 | 39 | 40 | @pytest.fixture 41 | def Product(Base): 42 | class Product(Base): 43 | __tablename__ = 'product' 44 | id = sa.Column(sa.Integer, primary_key=True) 45 | name = sa.Column(sa.Unicode(255)) 46 | price = sa.Column(sa.Numeric) 47 | 48 | category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) 49 | return Product 50 | 51 | 52 | @pytest.fixture 53 | def init_models(Catalog, Category, Product): 54 | pass 55 | 56 | 57 | @pytest.mark.usefixtures('postgresql_dsn') 58 | class TestAggregateOneToManyAndOneToMany: 59 | 60 | def test_assigns_aggregates(self, session, Category, Catalog, Product): 61 | category = Category(name='Some category') 62 | catalog = Catalog( 63 | categories=[category] 64 | ) 65 | catalog.name = 'Some catalog' 66 | session.add(catalog) 67 | session.commit() 68 | product = Product( 69 | name='Some product', 70 | price=Decimal('1000'), 71 | category=category 72 | ) 73 | session.add(product) 74 | session.commit() 75 | session.refresh(catalog) 76 | assert catalog.product_count == 1 77 | -------------------------------------------------------------------------------- /tests/generic_relationship/test_hybrid_properties.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | from sqlalchemy.ext.hybrid import hybrid_property 4 | 5 | from sqlalchemy_utils import generic_relationship 
@pytest.fixture
def Event(Base):
    """Return an ``Event`` model declaring two generic relationships.

    ``object`` is keyed by plain columns, while ``object_version`` uses a
    hybrid property as its type discriminator and a two-column key.
    """
    class Event(Base):
        __tablename__ = 'event'
        id = sa.Column(sa.Integer, primary_key=True)

        transaction_id = sa.Column(sa.Integer)

        # Discriminator / id pair used by the ``object`` relationship.
        object_type = sa.Column(sa.Unicode(255))
        object_id = sa.Column(sa.Integer, nullable=False)

        object = generic_relationship(
            object_type, object_id
        )

        # Instance-level value: the stored type name suffixed with 'History'.
        @hybrid_property
        def object_version_type(self):
            return self.object_type + 'History'

        # Class-level (SQL) counterpart of the same suffixing.
        @object_version_type.expression
        def object_version_type(cls):
            return sa.func.concat(cls.object_type, 'History')

        # Discriminator is the hybrid above; the key is the pair
        # (object_id, transaction_id).
        object_version = generic_relationship(
            object_version_type, (object_id, transaction_id)
        )
    return Event
@pytest.fixture
def Building(Base, incrementor):
    """Return a ``Building`` model with a composite (id, code) primary key.

    Each new instance bumps the shared ``incrementor`` and uses the bumped
    value for both key columns.
    """
    class Building(Base):
        __tablename__ = 'building'
        id = sa.Column(sa.Integer, primary_key=True)
        code = sa.Column(sa.Integer, primary_key=True)

        def __init__(self):
            next_value = incrementor.value + 1
            incrementor.value = next_value
            self.id = next_value
            self.code = next_value
    return Building
@pytest.fixture
def User(Base):
    """Return a ``User`` model with a many-to-many ``groups`` relationship.

    Also registers the ``user_group`` association table on
    ``Base.metadata``.
    """
    # Association table linking 'user.id' and 'group.id'.
    user_group = sa.Table(
        'user_group',
        Base.metadata,
        sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
        sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
    )

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Aggregated integer column (default 0) over the ``groups``
        # relationship; the select expression is COUNT('1').
        @aggregated('groups', sa.Column(sa.Integer, default=0))
        def group_count(self):
            return sa.func.count('1')

        # The backref gives ``Group`` a ``users`` collection.
        groups = sa.orm.relationship(
            'Group',
            backref='users',
            secondary=user_group
        )
    return User
def render_expression(expression, bind, stream=None):
    """Generate a SQL expression from the passed python expression.

    The expression is run with ``exec`` in the namespace of a calling frame:
    that frame's globals, plus a copy of its locals to which the name
    ``engine`` (the object returned by ``create_mock_engine(bind, stream)``)
    is added. Calling frames are tried from the nearest one outwards and the
    first frame in which the expression runs without raising wins.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    :param expression:
        Python source to execute, e.g. ``'User.__table__.create(engine)'``.
    :param bind: A SQLAlchemy engine or bind URL.
    :param stream:
        Render all DDL operations to the stream. A new ``io.StringIO`` is
        created when omitted.
    :returns: The stream that was written to.
    :raises ValueError:
        If the expression raised in every calling frame. The error from the
        last frame tried is attached as ``__cause__``.
    """

    # Create a stream if not present.
    if stream is None:
        stream = io.StringIO()

    engine = create_mock_engine(bind, stream)

    # Navigate the stack and find the calling frame that allows the
    # expression to execute. Failures are expected here (most frames do not
    # define the names the expression uses), so they are remembered rather
    # than raised immediately.
    last_error = None
    for frame_info in inspect.stack()[1:]:
        frame = frame_info[0]
        namespace = dict(frame.f_locals)
        namespace['engine'] = engine
        try:
            exec(expression, frame.f_globals, namespace)
        except Exception as exc:
            last_error = exc
        else:
            break
    else:
        # No frame could run the expression: surface the underlying reason
        # instead of discarding it.
        raise ValueError(
            'Not a valid python expression', engine
        ) from last_error

    return stream
@pytest.mark.usefixtures('postgresql_dsn')
class TestLazyEvaluatedSelectExpressionsForAggregates:
    """``net_worth`` aggregate built from ``sa.func.sum(Product.price)``."""

    def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
        # Persist an empty catalog first.
        some_catalog = Catalog(name='Some catalog')
        session.add(some_catalog)
        session.commit()

        # Insert one product into it in a separate transaction.
        some_product = Product(
            name='Some product',
            price=Decimal('1000'),
            catalog=some_catalog,
        )
        session.add(some_product)
        session.commit()

        session.refresh(some_catalog)
        assert some_catalog.net_worth == Decimal('1000')

    def test_assigns_aggregates_on_update(self, session, Product, Catalog):
        # Persist an empty catalog first.
        some_catalog = Catalog(name='Some catalog')
        session.add(some_catalog)
        session.commit()

        # Insert one product into it in a separate transaction.
        some_product = Product(
            name='Some product',
            price=Decimal('1000'),
            catalog=some_catalog,
        )
        session.add(some_product)
        session.commit()

        # Changing the price must be reflected in the stored aggregate.
        some_product.price = Decimal('500')
        session.commit()

        session.refresh(some_catalog)
        assert some_catalog.net_worth == Decimal('500')
class ColorType(ScalarCoercible, types.TypeDecorator):
    """
    Column type that stores ``Color`` objects (from the colour_ package).

    Values are written to the database as strings (the attribute named by
    ``STORE_FORMAT``, i.e. ``Color.hex``) and turned back into ``Color``
    objects when rows are loaded.

    ::


        from colour import Color
        from sqlalchemy_utils import ColorType


        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(50))
            background_color = sa.Column(ColorType)


        document = Document()
        document.background_color = Color('#F5F5F5')
        session.add(document)
        session.commit()


    Querying the database returns Color objects:

    ::

        document = session.query(Document).first()

        document.background_color.hex
        # '#f5f5f5'


    .. _colour: https://github.com/vaab/colour
    """

    # Name of the ``Color`` attribute used as the stored representation.
    STORE_FORMAT = 'hex'
    impl = types.Unicode(20)
    python_type = python_colour_type
    cache_ok = True

    def __init__(self, max_length=20, *args, **kwargs):
        # The optional dependency is checked lazily, at type construction.
        if colour is None:
            raise ImproperlyConfigured(
                "'colour' package is required to use 'ColorType'"
            )

        super().__init__(*args, **kwargs)
        # Replace the class-level impl with one of the requested length.
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Falsy values and non-Color values are passed through untouched.
        if not value or not isinstance(value, colour.Color):
            return value
        stored = getattr(value, self.STORE_FORMAT)
        return str(stored)

    def process_result_value(self, value, dialect):
        # Falsy database values (None, '') are returned as they are.
        if not value:
            return value
        return colour.Color(value)

    def _coerce(self, value):
        # Scalar coercion: anything that is not None and not already a
        # Color is handed to the Color constructor.
        if value is None or isinstance(value, colour.Color):
            return value
        return colour.Color(value)
class JSONType(sa.types.TypeDecorator):
    """
    Column type for storing JSON-serializable data structures.

    On PostgreSQL the column is rendered with a 'json' type; on every other
    dialect it falls back to unicode text and the value is (de)serialized
    with the ``json`` module.

    ::


        from sqlalchemy_utils import JSONType


        class Product(Base):
            __tablename__ = 'product'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(50))
            details = sa.Column(JSONType)


        product = Product()
        product.details = {
            'color': 'red',
            'type': 'car',
            'max-speed': '400 mph'
        }
        session.add(product)
        session.commit()
    """

    impl = sa.UnicodeText
    hashable = False
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def load_dialect_impl(self, dialect):
        # Non-PostgreSQL dialects use the plain text implementation.
        if dialect.name != 'postgresql':
            return dialect.type_descriptor(self.impl)

        # Use the native JSON type when SQLAlchemy provides it, otherwise
        # the locally defined 'json' column type.
        postgres_type = JSON() if has_postgres_json else PostgresJSONType()
        return dialect.type_descriptor(postgres_type)

    def process_bind_param(self, value, dialect):
        # The native PostgreSQL JSON type does its own serialization.
        handled_natively = dialect.name == 'postgresql' and has_postgres_json
        if handled_natively or value is None:
            return value
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        # PostgreSQL results are returned as received; None stays None.
        if dialect.name == 'postgresql' or value is None:
            return value
        return json.loads(value)
@pytest.fixture
def Thread(Base):
    """Return a ``Thread`` model that keeps a count of its comments."""
    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Aggregated integer column (default 0) over the ``comments``
        # relationship; the select expression is COUNT('1').
        @aggregated('comments', sa.Column(sa.Integer, default=0))
        def comment_count(self):
            return sa.func.count('1')

        # Declared on the parent side; the backref gives ``Comment`` a
        # ``thread`` attribute.
        comments = sa.orm.relationship('Comment', backref='thread')
    return Thread
class WeekDaysType(types.TypeDecorator, ScalarCoercible):
    """
    WeekDaysType offers way of saving WeekDays objects into database. The
    WeekDays objects are converted to bit strings on the way in and back to
    WeekDays objects on the way out.

    In order to use WeekDaysType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import WeekDaysType, WeekDays
        from babel import Locale


        class Schedule(Base):
            __tablename__ = 'schedule'
            id = sa.Column(sa.Integer, autoincrement=True)
            working_days = sa.Column(WeekDaysType)


        schedule = Schedule()
        schedule.working_days = WeekDays('0001111')
        session.add(schedule)
        session.commit()

        print(schedule.working_days)  # Thursday, Friday, Saturday, Sunday


    WeekDaysType also supports scalar coercion:

    ::


        schedule.working_days = '1110000'
        schedule.working_days  # WeekDays object

    """

    impl = BitType(WeekDay.NUM_WEEK_DAYS)

    cache_ok = True

    def __init__(self, *args, **kwargs):
        # Babel is an optional dependency; fail at type construction when it
        # is missing.
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required to use 'WeekDaysType'"
            )

        super().__init__(*args, **kwargs)

    @property
    def comparator_factory(self):
        # Delegate comparison operators to the underlying implementation
        # type.
        return self.impl.comparator_factory

    def process_bind_param(self, value, dialect):
        """Convert a ``WeekDays`` value to its bit string for the database.

        ``None`` is passed through unchanged on every dialect. On MySQL the
        bit string is sent as UTF-8 encoded bytes.
        """
        # NULL must stay NULL: without this guard the MySQL branch below
        # would call bytes(None, 'utf8') and raise TypeError.
        if value is None:
            return None

        if isinstance(value, WeekDays):
            value = value.as_bit_string()

        if dialect.name == 'mysql':
            return bytes(value, 'utf8')
        return value

    def process_result_value(self, value, dialect):
        """Wrap a non-NULL database value in ``WeekDays``."""
        if value is None:
            return None
        return WeekDays(value)

    def _coerce(self, value):
        # Scalar coercion: anything that is not None and not already a
        # WeekDays instance is handed to the WeekDays constructor.
        if value is not None and not isinstance(value, WeekDays):
            return WeekDays(value)
        return value
def make_order_by_deterministic(query):
    """
    Return ``query`` with a deterministic ORDER BY (if it isn't already).

    The first ORDER BY clause is inspected. When it is a column for which
    ``has_unique_index`` is true, the query is returned unchanged. Otherwise
    the primary key columns of the query's base table are appended to the
    ORDER BY, using the direction (ascending / descending) of that first
    clause, or ascending when the query has no ORDER BY at all.

    Consider a User model with three fields: id (primary key), favorite color
    and email (unique).::


        from sqlalchemy_utils import make_order_by_deterministic


        query = session.query(User).order_by(User.favorite_color)

        query = make_order_by_deterministic(query)
        print(query)  # 'SELECT ... ORDER BY "user".favorite_color, "user".id'


        query = session.query(User).order_by(User.email)

        query = make_order_by_deterministic(query)
        print(query)  # 'SELECT ... ORDER BY "user".email'


        query = session.query(User).order_by(User.id)

        query = make_order_by_deterministic(query)
        print(query)  # 'SELECT ... ORDER BY "user".id'


    .. versionadded:: 0.27.1
    """
    direction = sa.asc
    leading_column = None

    try:
        clauses = query._order_by_clauses
    except AttributeError:  # SQLAlchemy <1.4
        clauses = query._order_by

    if clauses:
        leading = clauses[0]
        # Unwrap label references to reach the labelled element.
        if isinstance(leading, sa.sql.elements._label_reference):
            leading = leading.element
        if isinstance(leading, sa.sql.expression.UnaryExpression):
            # asc()/desc() wrappers: remember the direction and look at the
            # wrapped expression.
            if leading.modifier == sa.sql.operators.desc_op:
                direction = sa.desc
            leading_column = list(leading.get_children())[0]
        else:
            leading_column = leading

    # Skip queries that are ordered by an already deterministic column
    if isinstance(leading_column, sa.Column):
        try:
            already_deterministic = has_unique_index(leading_column)
        except TypeError:
            already_deterministic = False
        if already_deterministic:
            return query

    first_entity = _get_query_compile_state(query)._entities[0]
    base_table = get_tables(first_entity)[0]
    primary_key_ordering = [
        direction(col) for col in base_table.c if col.primary_key
    ]
    return query.order_by(*primary_key_ordering)
class ProxyDict:
    """Dict-like view over a relationship collection of ``parent``.

    Children are addressed by the value of one of their attributes
    (``mapping_attr.key``). Looked-up children are remembered in
    ``self.cache``; a cached ``None`` records that a lookup found nothing.

    The collection is expected to be query-like: ``keys`` and ``fetch`` call
    ``.values(...)`` and ``.filter_by(...)`` on it (presumably a
    ``lazy='dynamic'`` relationship -- confirm against callers).
    """

    def __init__(self, parent, collection_name, mapping_attr):
        """
        :param parent: Object owning the collection.
        :param collection_name: Name of the collection attribute on parent.
        :param mapping_attr:
            Attribute of the child class used as the dictionary key; its
            ``class_`` and ``key`` are read.
        """
        self.parent = parent
        self.collection_name = collection_name
        self.child_class = mapping_attr.class_
        self.key_name = mapping_attr.key
        self.cache = {}

    @property
    def collection(self):
        """The live collection, read from the parent on every access."""
        return getattr(self.parent, self.collection_name)

    def keys(self):
        """Return the key attribute value of every child in the collection."""
        descriptor = getattr(self.child_class, self.key_name)
        return [x[0] for x in self.collection.values(descriptor)]

    def __contains__(self, key):
        if key in self.cache:
            return self.cache[key] is not None
        return self.fetch(key) is not None

    def has_key(self, key):
        return self.__contains__(key)

    def fetch(self, key):
        """Query the collection for ``key`` and cache the outcome.

        The query is only issued when the parent is attached to a session
        and has an identity; otherwise ``None`` is returned and nothing is
        cached.
        """
        session = sa.orm.object_session(self.parent)
        if session and sa.orm.util.has_identity(self.parent):
            obj = self.collection.filter_by(**{self.key_name: key}).first()
            self.cache[key] = obj
            return obj
        return None

    def create_new_instance(self, key):
        """Create a child for ``key``, append it to the collection, cache it."""
        value = self.child_class(**{self.key_name: key})
        self.collection.append(value)
        self.cache[key] = value
        return value

    def __getitem__(self, key):
        """Return the child for ``key``, creating one when none is found."""
        if key in self.cache:
            if self.cache[key] is not None:
                return self.cache[key]
        else:
            value = self.fetch(key)
            if value:
                return value

        return self.create_new_instance(key)

    def __setitem__(self, key, value):
        """Replace the child stored under ``key`` with ``value``.

        The previous child is looked up without going through
        ``__getitem__``: that method never raises ``KeyError`` for a missing
        key but creates and appends a brand new child, which would then be
        removed again right away.
        """
        if key in self.cache:
            existing = self.cache[key]
        else:
            existing = self.fetch(key)

        if existing is not None:
            try:
                self.collection.remove(existing)
            except KeyError:
                # Kept from the original implementation, which tolerated
                # KeyError while dropping the previous value.
                pass

        self.collection.append(value)
        self.cache[key] = value
'expire', expire_proxy_dicts) 83 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: [push, pull_request, workflow_dispatch] 3 | jobs: 4 | test: 5 | strategy: 6 | fail-fast: false 7 | matrix: 8 | os: ['ubuntu-latest'] 9 | python-version: 10 | - '3.9' 11 | - '3.10' 12 | - '3.11' 13 | - '3.12' 14 | - '3.13' 15 | tox_env: ['sqlalchemy14', 'sqlalchemy2'] 16 | runs-on: ${{ matrix.os }} 17 | services: 18 | postgres: 19 | image: postgres:latest 20 | env: 21 | POSTGRES_USER: postgres 22 | POSTGRES_PASSWORD: postgres 23 | POSTGRES_DB: sqlalchemy_utils_test 24 | # Set health checks to wait until PostgreSQL has started 25 | options: >- 26 | --health-cmd pg_isready 27 | --health-interval 10s 28 | --health-timeout 5s 29 | --health-retries 5 30 | ports: 31 | - 5432:5432 32 | mysql: 33 | image: mysql:latest 34 | env: 35 | MYSQL_ALLOW_EMPTY_PASSWORD: yes 36 | MYSQL_DATABASE: sqlalchemy_utils_test 37 | ports: 38 | - 3306:3306 39 | mssql: 40 | image: mcr.microsoft.com/mssql/server:2017-latest 41 | env: 42 | ACCEPT_EULA: Y 43 | SA_PASSWORD: Strong_Passw0rd 44 | # Set health checks to wait until SQL Server has started 45 | options: >- 46 | --health-cmd "/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P ${SA_PASSWORD} -Q 'SELECT 1;' -b" 47 | --health-interval 10s 48 | --health-timeout 5s 49 | --health-retries 5 50 | ports: 51 | - 1433:1433 52 | steps: 53 | - uses: actions/checkout@v6 54 | - name: Set up Python 55 | uses: actions/setup-python@v6 56 | with: 57 | python-version: ${{ matrix.python-version }} 58 | - name: Install MS SQL stuff 59 | run: bash .ci/install_mssql.sh 60 | - name: Add hstore extension to the sqlalchemy_utils_test database 61 | env: 62 | PGHOST: localhost 63 | PGPASSWORD: postgres 64 | PGPORT: 5432 65 | run: psql -U postgres -d sqlalchemy_utils_test -c 'CREATE EXTENSION hstore;' 66 | - name: Install 
tox 67 | run: pip install tox 68 | - name: Run tests 69 | env: 70 | SQLALCHEMY_UTILS_TEST_POSTGRESQL_PASSWORD: postgres 71 | run: tox -e ${{ matrix.tox_env }} 72 | 73 | docs: 74 | runs-on: 'ubuntu-latest' 75 | steps: 76 | - uses: actions/checkout@v6 77 | - name: Set up Python 78 | uses: actions/setup-python@v6 79 | with: 80 | python-version: '3.13' 81 | - name: Install tox 82 | run: pip install tox 83 | - name: Build documentation 84 | run: tox -e docs 85 | -------------------------------------------------------------------------------- /tests/functions/test_get_tables.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import get_tables 5 | from sqlalchemy_utils.functions.orm import _get_query_compile_state 6 | 7 | 8 | @pytest.fixture 9 | def TextItem(Base): 10 | class TextItem(Base): 11 | __tablename__ = 'text_item' 12 | id = sa.Column(sa.Integer, primary_key=True) 13 | name = sa.Column(sa.Unicode(255)) 14 | type = sa.Column(sa.Unicode(255)) 15 | 16 | __mapper_args__ = { 17 | 'polymorphic_on': type, 18 | 'with_polymorphic': '*' 19 | } 20 | return TextItem 21 | 22 | 23 | @pytest.fixture 24 | def Article(TextItem): 25 | class Article(TextItem): 26 | __tablename__ = 'article' 27 | id = sa.Column( 28 | sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True 29 | ) 30 | __mapper_args__ = { 31 | 'polymorphic_identity': 'article' 32 | } 33 | return Article 34 | 35 | 36 | @pytest.fixture 37 | def init_models(TextItem, Article): 38 | pass 39 | 40 | 41 | class TestGetTables: 42 | 43 | def test_child_class_using_join_table_inheritance(self, TextItem, Article): 44 | assert get_tables(Article) == [ 45 | TextItem.__table__, 46 | Article.__table__ 47 | ] 48 | 49 | def test_entity_using_with_polymorphic(self, TextItem, Article): 50 | assert get_tables(TextItem) == [ 51 | TextItem.__table__, 52 | Article.__table__ 53 | ] 54 | 55 | def test_instrumented_attribute(self, 
TextItem): 56 | assert get_tables(TextItem.name) == [ 57 | TextItem.__table__, 58 | ] 59 | 60 | def test_polymorphic_instrumented_attribute(self, TextItem, Article): 61 | assert get_tables(Article.id) == [ 62 | TextItem.__table__, 63 | Article.__table__ 64 | ] 65 | 66 | def test_column(self, Article): 67 | assert get_tables(Article.__table__.c.id) == [ 68 | Article.__table__ 69 | ] 70 | 71 | def test_mapper_entity_with_class(self, session, TextItem, Article): 72 | query = session.query(Article) 73 | entity = _get_query_compile_state(query)._entities[0] 74 | assert get_tables(entity) == [TextItem.__table__, Article.__table__] 75 | 76 | def test_mapper_entity_with_mapper(self, session, TextItem, Article): 77 | query = session.query(sa.inspect(Article)) 78 | entity = _get_query_compile_state(query)._entities[0] 79 | assert get_tables(entity) == [TextItem.__table__, Article.__table__] 80 | 81 | def test_column_entity(self, session, TextItem, Article): 82 | query = session.query(Article.id) 83 | entity = _get_query_compile_state(query)._entities[0] 84 | assert get_tables(entity) == [TextItem.__table__, Article.__table__] 85 | -------------------------------------------------------------------------------- /tests/aggregate/test_o2m_m2m.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | import sqlalchemy.orm 4 | 5 | from sqlalchemy_utils import aggregated 6 | 7 | 8 | @pytest.fixture 9 | def Category(Base): 10 | class Category(Base): 11 | __tablename__ = 'category' 12 | id = sa.Column(sa.Integer, primary_key=True) 13 | name = sa.Column(sa.Unicode(255)) 14 | return Category 15 | 16 | 17 | @pytest.fixture 18 | def Catalog(Base, Category): 19 | class Catalog(Base): 20 | __tablename__ = 'catalog' 21 | id = sa.Column(sa.Integer, primary_key=True) 22 | name = sa.Column(sa.Unicode(255)) 23 | 24 | @aggregated( 25 | 'products.categories', 26 | sa.Column(sa.Integer, default=0) 27 | ) 28 | def 
category_count(self): 29 | return sa.func.count(sa.distinct(Category.id)) 30 | return Catalog 31 | 32 | 33 | @pytest.fixture 34 | def Product(Base, Catalog, Category): 35 | product_categories = sa.Table( 36 | 'category_product', 37 | Base.metadata, 38 | sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), 39 | sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) 40 | ) 41 | 42 | class Product(Base): 43 | __tablename__ = 'product' 44 | id = sa.Column(sa.Integer, primary_key=True) 45 | name = sa.Column(sa.Unicode(255)) 46 | price = sa.Column(sa.Numeric) 47 | 48 | catalog_id = sa.Column( 49 | sa.Integer, sa.ForeignKey('catalog.id') 50 | ) 51 | 52 | catalog = sa.orm.relationship( 53 | Catalog, 54 | backref='products' 55 | ) 56 | 57 | categories = sa.orm.relationship( 58 | Category, 59 | backref='products', 60 | secondary=product_categories 61 | ) 62 | return Product 63 | 64 | 65 | @pytest.fixture 66 | def init_models(Category, Catalog, Product): 67 | pass 68 | 69 | 70 | @pytest.mark.usefixtures('postgresql_dsn') 71 | class TestAggregateOneToManyAndManyToMany: 72 | 73 | def test_insert(self, session, Category, Catalog, Product): 74 | category = Category() 75 | session.add(category) 76 | products = [ 77 | Product(categories=[category]), 78 | Product(categories=[category]) 79 | ] 80 | [session.add(product) for product in products] 81 | catalog = Catalog(products=products) 82 | session.add(catalog) 83 | products2 = [ 84 | Product(categories=[category]), 85 | Product(categories=[category]) 86 | ] 87 | [session.add(product) for product in products2] 88 | catalog2 = Catalog(products=products2) 89 | session.add(catalog2) 90 | session.commit() 91 | assert catalog.category_count == 1 92 | assert catalog2.category_count == 1 93 | -------------------------------------------------------------------------------- /sqlalchemy_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregates import 
aggregated # noqa 2 | from .asserts import ( # noqa 3 | assert_max_length, 4 | assert_max_value, 5 | assert_min_value, 6 | assert_non_nullable, 7 | assert_nullable, 8 | ) 9 | from .exceptions import ImproperlyConfigured # noqa 10 | from .expressions import Asterisk, row_to_json # noqa 11 | from .functions import ( # noqa 12 | cast_if, 13 | create_database, 14 | create_mock_engine, 15 | database_exists, 16 | dependent_objects, 17 | drop_database, 18 | escape_like, 19 | get_bind, 20 | get_class_by_table, 21 | get_column_key, 22 | get_columns, 23 | get_declarative_base, 24 | get_fk_constraint_for_columns, 25 | get_hybrid_properties, 26 | get_mapper, 27 | get_primary_keys, 28 | get_referencing_foreign_keys, 29 | get_tables, 30 | get_type, 31 | group_foreign_keys, 32 | has_changes, 33 | has_index, 34 | has_unique_index, 35 | identity, 36 | is_loaded, 37 | json_sql, 38 | jsonb_sql, 39 | merge_references, 40 | mock_engine, 41 | naturally_equivalent, 42 | render_expression, 43 | render_statement, 44 | table_name, 45 | ) 46 | from .generic import generic_relationship # noqa 47 | from .i18n import TranslationHybrid # noqa 48 | from .listeners import ( # noqa 49 | auto_delete_orphans, 50 | coercion_listener, 51 | force_auto_coercion, 52 | force_instant_defaults, 53 | ) 54 | from .models import generic_repr, Timestamp # noqa 55 | from .observer import observes # noqa 56 | from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa 57 | from .proxy_dict import proxy_dict, ProxyDict # noqa 58 | from .query_chain import QueryChain # noqa 59 | from .types import ( # noqa 60 | ArrowType, 61 | Choice, 62 | ChoiceType, 63 | ColorType, 64 | CompositeType, 65 | CountryType, 66 | CurrencyType, 67 | DateRangeType, 68 | DateTimeRangeType, 69 | EmailType, 70 | EncryptedType, 71 | EnrichedDateTimeType, 72 | EnrichedDateType, 73 | instrumented_list, 74 | InstrumentedList, 75 | Int8RangeType, 76 | IntRangeType, 77 | IPAddressType, 78 | JSONType, 79 | LocaleType, 80 | LtreeType, 
81 | NumericRangeType, 82 | Password, 83 | PasswordType, 84 | PhoneNumber, 85 | PhoneNumberParseException, 86 | PhoneNumberType, 87 | register_composites, 88 | remove_composite_listeners, 89 | ScalarListException, 90 | ScalarListType, 91 | StringEncryptedType, 92 | TimezoneType, 93 | TSVectorType, 94 | URLType, 95 | UUIDType, 96 | WeekDaysType, 97 | ) 98 | from .view import ( # noqa 99 | create_materialized_view, 100 | create_view, 101 | refresh_materialized_view, 102 | ) 103 | 104 | __version__ = '0.42.0' 105 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | py{3.9, 3.10, 3.11, 3.12, 3.13}-sqlalchemy{14, 2} 4 | ruff 5 | docs 6 | labels = 7 | update = update-{requirements, pre-commit} 8 | 9 | [testenv] 10 | commands = 11 | py.test sqlalchemy_utils --cov=sqlalchemy_utils --cov-report=xml tests 12 | deps = 13 | .[test_all] 14 | pytest-cov 15 | sqlalchemy14: SQLAlchemy>=1.4,<1.5 16 | sqlalchemy2: SQLAlchemy>=2 17 | ; It's sometimes necessary to test against specific sqlalchemy versions 18 | ; to verify bug or feature behavior before or after a specific version. 19 | ; 20 | ; You can choose the sqlalchemy version using this command syntax: 21 | ; $ tox -e py39-sqlalchemy1_4_xx 22 | ; 23 | ; sqlalchemy 1.4.19 through 1.4.23 have JSON-related TypeErrors. See #543. 24 | sqlalchemy1_4_19: SQLAlchemy==1.4.19 25 | ; sqlalchemy <1.4.28 threw a DeprecationWarning when copying a URL. See #573. 26 | sqlalchemy1_4_27: SQLAlchemy==1.4.27 27 | ; sqlalchemy 1.4.30 introduced UUID literal quoting. See #580. 
28 | sqlalchemy1_4_29: SQLAlchemy==1.4.29 29 | setenv = 30 | SQLALCHEMY_WARN_20 = true 31 | passenv = 32 | SQLALCHEMY_UTILS_TEST_* 33 | recreate = True 34 | 35 | [testenv:ruff] 36 | skip_install = True 37 | recreate = False 38 | deps = pre-commit 39 | commands = 40 | pre-commit run --all ruff-check 41 | pre-commit run --all ruff-format 42 | 43 | [testenv:docs] 44 | base_python = py3.13 45 | recreate = False 46 | deps = -r requirements/docs/requirements.txt 47 | commands = sphinx-build docs/ build/docs 48 | 49 | [pytest] 50 | filterwarnings = 51 | error 52 | ; Ignore ResourceWarnings caused by unclosed sockets in pg8000. 53 | ; These are caught and re-raised as pytest.PytestUnraisableExceptionWarnings. 54 | ignore:Exception ignored:pytest.PytestUnraisableExceptionWarning 55 | ; Ignore DeprecationWarnings caused by pg8000. 56 | ignore:distutils Version classes are deprecated.:DeprecationWarning 57 | ; Ignore warnings about passlib's use of the crypt module (Python 3.11). 58 | ignore:'crypt' is deprecated and slated for removal:DeprecationWarning 59 | ; Ignore Python 3.12 UTC deprecation warnings caused by pendulum. 
60 | ignore:datetime.datetime.utcfromtimestamp:DeprecationWarning 61 | 62 | 63 | [testenv:update-{requirements, pre-commit}] 64 | base_python = py3.13 65 | recreate = true 66 | description = Update tool dependency versions 67 | skip_install = true 68 | deps = 69 | requirements: poetry 70 | requirements: poetry-plugin-export 71 | pre-commit: pre-commit 72 | commands = 73 | # Update requirements files 74 | requirements: poetry update --directory="requirements/docs" --lock 75 | requirements: poetry export --directory="requirements/docs" --output="requirements.txt" --without-hashes 76 | 77 | # Update pre-commit hook versions 78 | pre-commit: pre-commit autoupdate 79 | 80 | # Run pre-commit immediately, but ignore its exit code 81 | pre-commit: - pre-commit run -a 82 | -------------------------------------------------------------------------------- /tests/types/test_arrow.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | from dateutil import tz 6 | 7 | from sqlalchemy_utils.types import arrow 8 | 9 | 10 | @pytest.fixture 11 | def Article(Base): 12 | class Article(Base): 13 | __tablename__ = 'article' 14 | id = sa.Column(sa.Integer, primary_key=True) 15 | created_at = sa.Column(arrow.ArrowType) 16 | published_at = sa.Column(arrow.ArrowType(timezone=True)) 17 | published_at_dt = sa.Column(sa.DateTime(timezone=True)) 18 | return Article 19 | 20 | 21 | @pytest.fixture 22 | def init_models(Article): 23 | pass 24 | 25 | 26 | @pytest.mark.skipif('arrow.arrow is None') 27 | class TestArrowDateTimeType: 28 | def test_parameter_processing(self, session, Article): 29 | article = Article( 30 | created_at=arrow.arrow.get(datetime(2000, 11, 1)) 31 | ) 32 | 33 | session.add(article) 34 | session.commit() 35 | 36 | article = session.query(Article).first() 37 | assert article.created_at.datetime 38 | 39 | def test_string_coercion(self, Article): 40 | article = 
Article( 41 | created_at='2013-01-01' 42 | ) 43 | assert article.created_at.year == 2013 44 | 45 | def test_utc(self, session, Article): 46 | time = arrow.arrow.utcnow() 47 | article = Article(created_at=time) 48 | session.add(article) 49 | assert article.created_at == time 50 | session.commit() 51 | assert article.created_at == time 52 | 53 | def test_other_tz(self, session, Article): 54 | time = arrow.arrow.utcnow() 55 | local = time.to('US/Pacific') 56 | article = Article(created_at=local) 57 | session.add(article) 58 | assert article.created_at == time == local 59 | session.commit() 60 | assert article.created_at == time 61 | 62 | def test_literal_param(self, session, Article): 63 | clause = Article.created_at > datetime(2015, 1, 1) 64 | compiled = str(clause.compile(compile_kwargs={"literal_binds": True})) 65 | assert compiled == "article.created_at > '2015-01-01 00:00:00'" 66 | 67 | @pytest.mark.usefixtures('postgresql_dsn') 68 | def test_timezone(self, session, Article): 69 | timezone = tz.gettz('Europe/Stockholm') 70 | dt = arrow.arrow.get(datetime(2015, 1, 1, 15, 30, 45), timezone) 71 | article = Article(published_at=dt, published_at_dt=dt.datetime) 72 | 73 | session.add(article) 74 | session.commit() 75 | session.expunge_all() 76 | 77 | item = session.query(Article).one() 78 | assert item.published_at.datetime == item.published_at_dt 79 | assert item.published_at.to(timezone) == dt 80 | 81 | def test_compilation(self, Article, session): 82 | query = sa.select(Article.created_at) 83 | # the type should be cacheable and not throw exception 84 | session.execute(query) 85 | -------------------------------------------------------------------------------- /tests/types/test_scalar_list.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | import sqlalchemy.exc 4 | 5 | from sqlalchemy_utils import ScalarListType 6 | 7 | 8 | class TestScalarIntegerList: 9 | 10 | @pytest.fixture 11 | 
def User(self, Base): 12 | class User(Base): 13 | __tablename__ = 'user' 14 | id = sa.Column(sa.Integer, primary_key=True) 15 | some_list = sa.Column(ScalarListType(int)) 16 | 17 | def __repr__(self): 18 | return 'User(%r)' % self.id 19 | 20 | return User 21 | 22 | @pytest.fixture 23 | def init_models(self, User): 24 | pass 25 | 26 | def test_save_integer_list(self, session, User): 27 | user = User( 28 | some_list=[1, 2, 3, 4] 29 | ) 30 | 31 | session.add(user) 32 | session.commit() 33 | 34 | user = session.query(User).first() 35 | assert user.some_list == [1, 2, 3, 4] 36 | 37 | 38 | class TestScalarUnicodeList: 39 | 40 | @pytest.fixture 41 | def User(self, Base): 42 | class User(Base): 43 | __tablename__ = 'user' 44 | id = sa.Column(sa.Integer, primary_key=True) 45 | some_list = sa.Column(ScalarListType(str)) 46 | 47 | def __repr__(self): 48 | return 'User(%r)' % self.id 49 | 50 | return User 51 | 52 | @pytest.fixture 53 | def init_models(self, User): 54 | pass 55 | 56 | def test_throws_exception_if_using_separator_in_list_values( 57 | self, 58 | session, 59 | User 60 | ): 61 | user = User( 62 | some_list=[','] 63 | ) 64 | 65 | session.add(user) 66 | with pytest.raises(sa.exc.StatementError) as db_err: 67 | session.commit() 68 | assert ( 69 | "List values can't contain string ',' (its being used as " 70 | "separator. 
If you wish for scalar list values to contain " 71 | "these strings, use a different separator string.)" 72 | ) in str(db_err.value) 73 | 74 | def test_save_unicode_list(self, session, User): 75 | user = User( 76 | some_list=['1', '2', '3', '4'] 77 | ) 78 | 79 | session.add(user) 80 | session.commit() 81 | 82 | user = session.query(User).first() 83 | assert user.some_list == ['1', '2', '3', '4'] 84 | 85 | def test_save_and_retrieve_empty_list(self, session, User): 86 | user = User( 87 | some_list=[] 88 | ) 89 | 90 | session.add(user) 91 | session.commit() 92 | 93 | user = session.query(User).first() 94 | assert user.some_list == [] 95 | 96 | def test_compilation(self, User, session): 97 | query = sa.select(User.some_list) 98 | # the type should be cacheable and not throw exception 99 | session.execute(query) 100 | -------------------------------------------------------------------------------- /tests/types/test_enriched_datetime_pendulum.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | 6 | from sqlalchemy_utils.types.enriched_datetime import ( 7 | enriched_datetime_type, 8 | pendulum_datetime 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def User(Base): 14 | class User(Base): 15 | __tablename__ = 'users' 16 | id = sa.Column(sa.Integer, primary_key=True) 17 | created_at = sa.Column( 18 | enriched_datetime_type.EnrichedDateTimeType( 19 | datetime_processor=pendulum_datetime.PendulumDateTime, 20 | )) 21 | return User 22 | 23 | 24 | @pytest.fixture 25 | def init_models(User): 26 | pass 27 | 28 | 29 | @pytest.mark.skipif('pendulum_datetime.pendulum is None') 30 | class TestPendulumDateTimeType: 31 | 32 | def test_parameter_processing(self, session, User): 33 | user = User( 34 | created_at=pendulum_datetime.pendulum.datetime(1995, 7, 11) 35 | ) 36 | 37 | session.add(user) 38 | session.commit() 39 | 40 | user = session.query(User).first() 41 | assert 
isinstance(user.created_at, datetime) 42 | 43 | def test_int_coercion(self, User): 44 | user = User( 45 | created_at=1367900664 46 | ) 47 | assert user.created_at.year == 2013 48 | 49 | def test_float_coercion(self, User): 50 | user = User( 51 | created_at=1367900664.0 52 | ) 53 | assert user.created_at.year == 2013 54 | 55 | def test_string_coercion(self, User): 56 | user = User( 57 | created_at='1367900664' 58 | ) 59 | assert user.created_at.year == 2013 60 | 61 | def test_utc(self, session, User): 62 | time = pendulum_datetime.pendulum.now("UTC") 63 | user = User(created_at=time) 64 | session.add(user) 65 | assert user.created_at == time 66 | session.commit() 67 | assert user.created_at == time 68 | 69 | @pytest.mark.usefixtures('postgresql_dsn') 70 | def test_utc_postgres(self, session, User): 71 | time = pendulum_datetime.pendulum.now("UTC") 72 | user = User(created_at=time) 73 | session.add(user) 74 | assert user.created_at == time 75 | session.commit() 76 | assert user.created_at == time 77 | 78 | def test_other_tz(self, session, User): 79 | time = pendulum_datetime.pendulum.now("UTC") 80 | local = time.in_tz('Asia/Tokyo') 81 | user = User(created_at=local) 82 | session.add(user) 83 | assert user.created_at == time == local 84 | session.commit() 85 | assert user.created_at == time 86 | 87 | def test_literal_param(self, session, User): 88 | clause = User.created_at > datetime(2015, 1, 1) 89 | compiled = str(clause.compile(compile_kwargs={"literal_binds": True})) 90 | assert compiled == "users.created_at > '2015-01-01 00:00:00'" 91 | -------------------------------------------------------------------------------- /tests/aggregate/test_m2m_m2m.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | import sqlalchemy.orm 4 | 5 | from sqlalchemy_utils import aggregated 6 | 7 | 8 | @pytest.fixture 9 | def Category(Base): 10 | class Category(Base): 11 | __tablename__ = 'category' 12 | id 
= sa.Column(sa.Integer, primary_key=True) 13 | name = sa.Column(sa.Unicode(255)) 14 | return Category 15 | 16 | 17 | @pytest.fixture 18 | def Catalog(Base, Category): 19 | class Catalog(Base): 20 | __tablename__ = 'catalog' 21 | id = sa.Column(sa.Integer, primary_key=True) 22 | name = sa.Column(sa.Unicode(255)) 23 | 24 | @aggregated( 25 | 'products.categories', 26 | sa.Column(sa.Integer, default=0) 27 | ) 28 | def category_count(self): 29 | return sa.func.count(sa.distinct(Category.id)) 30 | return Catalog 31 | 32 | 33 | @pytest.fixture 34 | def Product(Base, Catalog, Category): 35 | catalog_products = sa.Table( 36 | 'catalog_product', 37 | Base.metadata, 38 | sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), 39 | sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) 40 | ) 41 | 42 | product_categories = sa.Table( 43 | 'category_product', 44 | Base.metadata, 45 | sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), 46 | sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) 47 | ) 48 | 49 | class Product(Base): 50 | __tablename__ = 'product' 51 | id = sa.Column(sa.Integer, primary_key=True) 52 | name = sa.Column(sa.Unicode(255)) 53 | price = sa.Column(sa.Numeric) 54 | 55 | catalog_id = sa.Column( 56 | sa.Integer, sa.ForeignKey('catalog.id') 57 | ) 58 | 59 | catalogs = sa.orm.relationship( 60 | Catalog, 61 | backref='products', 62 | secondary=catalog_products 63 | ) 64 | 65 | categories = sa.orm.relationship( 66 | Category, 67 | backref='products', 68 | secondary=product_categories 69 | ) 70 | return Product 71 | 72 | 73 | @pytest.fixture 74 | def init_models(Category, Catalog, Product): 75 | pass 76 | 77 | 78 | @pytest.mark.usefixtures('postgresql_dsn') 79 | class TestAggregateManyToManyAndManyToMany: 80 | 81 | def test_insert(self, session, Product, Category, Catalog): 82 | category = Category() 83 | session.add(category) 84 | products = [ 85 | Product(categories=[category]), 86 | Product(categories=[category]) 
87 | ] 88 | [session.add(product) for product in products] 89 | catalog = Catalog(products=products) 90 | session.add(catalog) 91 | catalog2 = Catalog(products=products) 92 | session.add(catalog2) 93 | session.commit() 94 | assert catalog.category_count == 1 95 | assert catalog2.category_count == 1 96 | -------------------------------------------------------------------------------- /sqlalchemy_utils/types/scalar_list.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sa 2 | from sqlalchemy import types 3 | 4 | 5 | class ScalarListException(Exception): 6 | pass 7 | 8 | 9 | class ScalarListType(types.TypeDecorator): 10 | """ 11 | ScalarListType provides a convenient way of saving multiple scalar 12 | values in one column. It behaves like a list on the Python side and is 13 | saved as a comma-separated string in the database (custom separators 14 | can also be used). 15 | 16 | Example :: 17 | 18 | 19 | from sqlalchemy_utils import ScalarListType 20 | 21 | 22 | class User(Base): 23 | __tablename__ = 'user' 24 | id = sa.Column(sa.Integer, autoincrement=True) 25 | hobbies = sa.Column(ScalarListType()) 26 | 27 | 28 | user = User() 29 | user.hobbies = ['football', 'ice_hockey'] 30 | session.commit() 31 | 32 | 33 | You can easily set up integer lists too: 34 | 35 | :: 36 | 37 | 38 | from sqlalchemy_utils import ScalarListType 39 | 40 | 41 | class Player(Base): 42 | __tablename__ = 'player' 43 | id = sa.Column(sa.Integer, autoincrement=True) 44 | points = sa.Column(ScalarListType(int)) 45 | 46 | 47 | player = Player() 48 | player.points = [11, 12, 8, 80] 49 | session.commit() 50 | 51 | 52 | ScalarListType is always stored as text. 
To use an array field on 53 | a PostgreSQL database, use the variant construct:: 54 | 55 | from sqlalchemy_utils import ScalarListType 56 | 57 | 58 | class Player(Base): 59 | __tablename__ = 'player' 60 | id = sa.Column(sa.Integer, autoincrement=True) 61 | points = sa.Column( 62 | ARRAY(Integer).with_variant(ScalarListType(int), 'sqlite') 63 | ) 64 | 65 | 66 | """ 67 | 68 | impl = sa.UnicodeText() 69 | 70 | cache_ok = True 71 | 72 | def __init__(self, coerce_func=str, separator=','): 73 | self.separator = str(separator) 74 | self.coerce_func = coerce_func 75 | 76 | def process_bind_param(self, value, dialect): 77 | # Convert the list of values to a single separator-joined string 78 | # Example (default separator): [1, 2, 3, 4] -> '1,2,3,4' 79 | if value is not None: 80 | if any(self.separator in str(item) for item in value): 81 | raise ScalarListException( 82 | "List values can't contain string '%s' (its being used as " 83 | 'separator. If you wish for scalar list values to contain ' 84 | 'these strings, use a different separator string.)' % self.separator 85 | ) 86 | return self.separator.join(map(str, value)) 87 | 88 | def process_result_value(self, value, dialect): 89 | if value is not None: 90 | if value == '': 91 | return [] 92 | # coerce each value 93 | return list(map(self.coerce_func, value.split(self.separator))) 94 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | 3 | import pytest 4 | import sqlalchemy as sa 5 | import sqlalchemy.orm 6 | 7 | from sqlalchemy_utils import generic_repr, Timestamp 8 | 9 | 10 | class TestTimestamp: 11 | @pytest.fixture 12 | def Article(self, Base): 13 | class Article(Base, Timestamp): 14 | __tablename__ = 'article' 15 | id = sa.Column(sa.Integer, primary_key=True) 16 | name = sa.Column(sa.Unicode(255), default='Some article') 17 | return Article 18 | 19 | def 
test_created(self, session, Article): 20 | then = datetime.now(tz=timezone.utc).replace(tzinfo=None) 21 | article = Article() 22 | 23 | session.add(article) 24 | session.commit() 25 | 26 | assert article.created >= then 27 | assert article.created <= datetime.now(tz=timezone.utc).replace(tzinfo=None) 28 | 29 | def test_updated(self, session, Article): 30 | article = Article() 31 | 32 | session.add(article) 33 | session.commit() 34 | 35 | then = datetime.now(tz=timezone.utc).replace(tzinfo=None) 36 | article.name = "Something" 37 | 38 | session.commit() 39 | 40 | assert article.updated >= then 41 | assert article.updated <= datetime.now(tz=timezone.utc).replace(tzinfo=None) 42 | 43 | 44 | class TestGenericRepr: 45 | @pytest.fixture 46 | def Article(self, Base): 47 | class Article(Base): 48 | __tablename__ = 'article' 49 | id = sa.Column(sa.Integer, primary_key=True) 50 | name = sa.Column(sa.Unicode(255), default='Some article') 51 | return Article 52 | 53 | def test_repr(self, Article): 54 | """Representation of a basic model.""" 55 | Article = generic_repr(Article) 56 | article = Article(id=1, name='Foo') 57 | expected_repr = "Article(id=1, name='Foo')" 58 | actual_repr = repr(article) 59 | 60 | assert actual_repr == expected_repr 61 | 62 | def test_repr_partial(self, Article): 63 | """Representation of a basic model with selected fields.""" 64 | Article = generic_repr('id')(Article) 65 | article = Article(id=1, name='Foo') 66 | expected_repr = 'Article(id=1)' 67 | actual_repr = repr(article) 68 | 69 | assert actual_repr == expected_repr 70 | 71 | def test_not_loaded(self, session, Article): 72 | """:py:func:`~sqlalchemy_utils.models.generic_repr` doesn't force 73 | execution of additional queries if some fields are not loaded and 74 | instead represents them as "<not loaded>". 
75 | """ 76 | Article = generic_repr(Article) 77 | article = Article(name='Foo') 78 | session.add(article) 79 | session.commit() 80 | 81 | article = session.query(Article).options(sa.orm.defer(Article.name)).one() 82 | actual_repr = repr(article) 83 | 84 | expected_repr = f'Article(id={article.id}, name=<not loaded>)' 85 | assert actual_repr == expected_repr 86 | -------------------------------------------------------------------------------- /tests/generic_relationship/__init__.py: -------------------------------------------------------------------------------- 1 | class GenericRelationshipTestCase: 2 | def test_set_as_none(self, Event): 3 | event = Event() 4 | event.object = None 5 | assert event.object is None 6 | 7 | def test_set_manual_and_get(self, session, User, Event): 8 | user = User() 9 | 10 | session.add(user) 11 | session.commit() 12 | 13 | event = Event() 14 | event.object_id = user.id 15 | event.object_type = str(type(user).__name__) 16 | 17 | assert event.object is None 18 | 19 | session.add(event) 20 | session.commit() 21 | 22 | assert event.object == user 23 | 24 | def test_set_and_get(self, session, User, Event): 25 | user = User() 26 | 27 | session.add(user) 28 | session.commit() 29 | 30 | event = Event(object=user) 31 | 32 | assert event.object_id == user.id 33 | assert event.object_type == type(user).__name__ 34 | 35 | session.add(event) 36 | session.commit() 37 | 38 | assert event.object == user 39 | 40 | def test_compare_instance(self, session, User, Event): 41 | user1 = User() 42 | user2 = User() 43 | 44 | session.add_all([user1, user2]) 45 | session.commit() 46 | 47 | event = Event(object=user1) 48 | 49 | session.add(event) 50 | session.commit() 51 | 52 | assert event.object == user1 53 | assert event.object != user2 54 | 55 | def test_compare_query(self, session, User, Event): 56 | user1 = User() 57 | user2 = User() 58 | 59 | session.add_all([user1, user2]) 60 | session.commit() 61 | 62 | event = Event(object=user1) 63 | 64 | session.add(event) 65 | 
session.commit() 66 | 67 | q = session.query(Event) 68 | assert q.filter_by(object=user1).first() is not None 69 | assert q.filter_by(object=user2).first() is None 70 | assert q.filter(Event.object == user2).first() is None 71 | 72 | def test_compare_not_query(self, session, User, Event): 73 | user1 = User() 74 | user2 = User() 75 | 76 | session.add_all([user1, user2]) 77 | session.commit() 78 | 79 | event = Event(object=user1) 80 | 81 | session.add(event) 82 | session.commit() 83 | 84 | q = session.query(Event) 85 | assert q.filter(Event.object != user2).first() is not None 86 | 87 | def test_compare_type(self, session, User, Event): 88 | user1 = User() 89 | user2 = User() 90 | 91 | session.add_all([user1, user2]) 92 | session.commit() 93 | 94 | event1 = Event(object=user1) 95 | event2 = Event(object=user2) 96 | 97 | session.add_all([event1, event2]) 98 | session.commit() 99 | 100 | statement = Event.object.is_type(User) 101 | q = session.query(Event).filter(statement) 102 | assert q.first() is not None 103 | -------------------------------------------------------------------------------- /tests/observes/test_o2o_o2o_o2o.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils.observer import observes 5 | 6 | 7 | @pytest.fixture 8 | def Catalog(Base): 9 | class Catalog(Base): 10 | __tablename__ = 'catalog' 11 | id = sa.Column(sa.Integer, primary_key=True) 12 | product_price = sa.Column(sa.Integer) 13 | 14 | @observes('category.sub_category.product') 15 | def product_observer(self, product): 16 | self.product_price = product.price if product else None 17 | 18 | category = sa.orm.relationship( 19 | 'Category', 20 | uselist=False, 21 | backref='catalog' 22 | ) 23 | return Catalog 24 | 25 | 26 | @pytest.fixture 27 | def Category(Base): 28 | class Category(Base): 29 | __tablename__ = 'category' 30 | id = sa.Column(sa.Integer, primary_key=True) 31 | catalog_id = 
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToOneToOneToOne:
    """Observer checks along ``category.sub_category.product``."""

    def test_simple_insert(self, catalog):
        """The flushed catalog carries the price of its leaf product."""
        observed_price = catalog.product_price
        assert observed_price == 123

    def test_replace_leaf_object(self, catalog, session, Product):
        """Swapping the leaf product updates the observed price on flush."""
        replacement = Product(price=44)
        leaf_owner = catalog.category.sub_category
        leaf_owner.product = replacement
        session.flush()
        assert catalog.product_price == 44

    def test_delete_leaf_object(self, catalog, session):
        """Deleting the leaf product resets the observed price to None."""
        leaf = catalog.category.sub_category.product
        session.delete(leaf)
        session.flush()
        assert catalog.product_price is None
class TestAggregateValueGenerationForSimpleModelPaths:
    """Check ``Thread.comment_count`` and ``Thread.last_comment_id``.

    Every test commits and then refreshes the thread before asserting, so
    the assertions read the values reloaded by ``session.refresh``.
    """

    def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
        """Thread and comment added in the same commit."""
        thread = Thread()
        thread.name = 'some article name'
        session.add(thread)
        comment = Comment(content='Some content', thread=thread)
        session.add(comment)
        session.commit()
        session.refresh(thread)
        assert thread.comment_count == 1
        assert thread.last_comment_id == comment.id

    def test_assigns_aggregates_on_separate_insert(
        self,
        session,
        Thread,
        Comment
    ):
        """Comment added in a later commit than its thread."""
        thread = Thread()
        thread.name = 'some article name'
        session.add(thread)
        session.commit()
        comment = Comment(content='Some content', thread=thread)
        session.add(comment)
        session.commit()
        session.refresh(thread)
        assert thread.comment_count == 1
        # Compare against the generated key rather than a hard-coded 1 so
        # the check does not depend on the backend's sequence state.
        assert thread.last_comment_id == comment.id

    def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
        """Deleting the only comment resets both aggregates."""
        thread = Thread()
        thread.name = 'some article name'
        session.add(thread)
        session.commit()
        comment = Comment(content='Some content', thread=thread)
        session.add(comment)
        session.commit()
        session.delete(comment)
        session.commit()
        session.refresh(thread)
        assert thread.comment_count == 0
        assert thread.last_comment_id is None
Country.validate(code) 43 | 44 | def test_validate_with_invalid_code(self): 45 | with pytest.raises(ValueError) as e: 46 | Country.validate('SomeUnknownCode') 47 | assert str(e.value) == ( 48 | 'Could not convert string to country code: SomeUnknownCode' 49 | ) 50 | 51 | def test_equality_operator(self): 52 | assert Country('FI') == 'FI' 53 | assert 'FI' == Country('FI') 54 | assert Country('FI') == Country('FI') 55 | 56 | def test_non_equality_operator(self): 57 | assert Country('FI') != 'sv' 58 | assert not (Country('FI') != 'FI') 59 | 60 | @pytest.mark.parametrize( 61 | 'op, code_left, code_right, is_', 62 | [ 63 | (operator.lt, 'ES', 'FI', True), 64 | (operator.lt, 'FI', 'ES', False), 65 | (operator.lt, 'ES', 'ES', False), 66 | 67 | (operator.le, 'ES', 'FI', True), 68 | (operator.le, 'FI', 'ES', False), 69 | (operator.le, 'ES', 'ES', True), 70 | 71 | (operator.ge, 'ES', 'FI', False), 72 | (operator.ge, 'FI', 'ES', True), 73 | (operator.ge, 'ES', 'ES', True), 74 | 75 | (operator.gt, 'ES', 'FI', False), 76 | (operator.gt, 'FI', 'ES', True), 77 | (operator.gt, 'ES', 'ES', False), 78 | ] 79 | ) 80 | def test_ordering(self, op, code_left, code_right, is_): 81 | country_left = Country(code_left) 82 | country_right = Country(code_right) 83 | assert op(country_left, country_right) is is_ 84 | assert op(country_left, code_right) is is_ 85 | assert op(code_left, country_right) is is_ 86 | 87 | def test_hash(self): 88 | assert hash(Country('FI')) == hash('FI') 89 | 90 | def test_repr(self): 91 | assert repr(Country('FI')) == "Country('FI')" 92 | 93 | def test_str(self): 94 | country = Country('FI') 95 | assert str(country) == 'Finland' 96 | -------------------------------------------------------------------------------- /tests/test_query_chain.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sqlalchemy as sa 3 | 4 | from sqlalchemy_utils import QueryChain 5 | 6 | 7 | @pytest.fixture 8 | def User(Base): 9 | 
class TestQueryChain:
    """Tests for ``QueryChain`` iteration, slicing, counting and repr.

    The ``chain`` fixture concatenates the user, article and blog post
    queries in that order: 2 users, 4 articles and 3 posts, 9 rows in all.
    """

    def test_iter(self, chain):
        """Iterating the chain yields every row of every query."""
        assert len(list(chain)) == 9

    def test_iter_with_limit(self, chain, users, articles):
        """A limit spanning two queries takes rows from both, in order."""
        c = chain.limit(4)
        objects = list(c)
        assert users == objects[0:2]
        assert articles[0:2] == objects[2:]

    def test_iter_with_offset(self, chain, articles, posts):
        """An offset of 3 skips both users and the first article."""
        c = chain.offset(3)
        objects = list(c)
        assert articles[1:] + posts == objects

    def test_iter_with_limit_and_offset(self, chain, articles, posts):
        """Offset and limit combine across query boundaries."""
        c = chain.offset(3).limit(4)
        objects = list(c)
        assert articles[1:] + posts[0:1] == objects

    def test_iter_with_offset_spanning_multiple_queries(self, chain, posts):
        """An offset of 7 skips the first two queries entirely."""
        c = chain.offset(7)
        objects = list(c)
        assert posts[1:] == objects

    def test_repr(self, chain):
        """The repr embeds the chain's ``id`` in hexadecimal."""
        # The expected literal must contain the ``%x`` placeholder: an
        # empty format string raises TypeError when given an argument.
        assert repr(chain) == '<QueryChain at 0x%x>' % id(chain)

    def test_getitem_with_slice(self, chain):
        """Slicing sets ``_offset`` and leaves ``_limit`` unset."""
        c = chain[1:]
        assert c._offset == 1
        assert c._limit is None

    def test_getitem_with_single_key(self, chain, articles):
        """Index 2 is the first article, after the two users."""
        article = chain[2]
        assert article == articles[0]

    def test_count(self, chain):
        """``count()`` returns the total number of rows."""
        assert chain.count() == 9
27 | 28 | """ 29 | 30 | impl = types.Unicode(50) 31 | python_type = None 32 | cache_ok = True 33 | 34 | def __init__(self, backend='dateutil'): 35 | self.backend = backend 36 | if backend == 'dateutil': 37 | try: 38 | from dateutil.tz import tzfile 39 | from dateutil.zoneinfo import get_zonefile_instance 40 | 41 | self.python_type = tzfile 42 | self._to = get_zonefile_instance().zones.get 43 | self._from = lambda x: str(x._filename) 44 | 45 | except ImportError: 46 | raise ImproperlyConfigured( 47 | "'python-dateutil' is required to use the " 48 | "'dateutil' backend for 'TimezoneType'" 49 | ) 50 | 51 | elif backend == 'pytz': 52 | try: 53 | from pytz import timezone 54 | from pytz.tzinfo import BaseTzInfo 55 | 56 | self.python_type = BaseTzInfo 57 | self._to = timezone 58 | self._from = str 59 | 60 | except ImportError: 61 | raise ImproperlyConfigured( 62 | "'pytz' is required to use the 'pytz' backend for 'TimezoneType'" 63 | ) 64 | 65 | elif backend == 'zoneinfo': 66 | import zoneinfo 67 | 68 | self.python_type = zoneinfo.ZoneInfo 69 | self._to = zoneinfo.ZoneInfo 70 | self._from = str 71 | 72 | else: 73 | raise ImproperlyConfigured( 74 | "'pytz', 'dateutil' or 'zoneinfo' are the backends " 75 | "supported for 'TimezoneType'" 76 | ) 77 | 78 | def _coerce(self, value): 79 | if value is not None and not isinstance(value, self.python_type): 80 | obj = self._to(value) 81 | if obj is None: 82 | raise ValueError("unknown time zone '%s'" % value) 83 | return obj 84 | return value 85 | 86 | def process_bind_param(self, value, dialect): 87 | return self._from(self._coerce(value)) if value else None 88 | 89 | def process_result_value(self, value, dialect): 90 | return self._to(value) if value else None 91 | -------------------------------------------------------------------------------- /sqlalchemy_utils/primitives/country.py: -------------------------------------------------------------------------------- 1 | from functools import total_ordering 2 | 3 | from .. 
import i18n 4 | from ..utils import str_coercible 5 | 6 | 7 | @total_ordering 8 | @str_coercible 9 | class Country: 10 | """ 11 | Country class wraps a 2 to 3 letter country code. It provides various 12 | convenience properties and methods. 13 | 14 | :: 15 | 16 | from babel import Locale 17 | from sqlalchemy_utils import Country, i18n 18 | 19 | 20 | # First lets add a locale getter for testing purposes 21 | i18n.get_locale = lambda: Locale('en') 22 | 23 | 24 | Country('FI').name # Finland 25 | Country('FI').code # FI 26 | 27 | Country(Country('FI')).code # 'FI' 28 | 29 | Country always validates the given code if you use at least the optional 30 | dependency list 'babel', otherwise no validation are performed. 31 | 32 | :: 33 | 34 | Country(None) # raises TypeError 35 | 36 | Country('UnknownCode') # raises ValueError 37 | 38 | 39 | Country supports equality operators. 40 | 41 | :: 42 | 43 | Country('FI') == Country('FI') 44 | Country('FI') != Country('US') 45 | 46 | 47 | Country objects are hashable. 
@str_coercible
class Currency:
    """
    Currency class wraps a 3-letter currency code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Currency, i18n


        # First let's add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')


        Currency('USD').name  # US Dollar
        Currency('USD').symbol  # $

        Currency(Currency('USD')).code  # 'USD'

    Constructing a Currency requires the optional 'babel' dependency
    (ImproperlyConfigured is raised without it), and the given code is
    always validated. Calling ``Currency.validate`` directly performs no
    validation when babel is not available.

    ::

        Currency(None)  # raises TypeError

        Currency('UnknownCode')  # raises ValueError


    Currency supports equality operators.

    ::

        Currency('USD') == Currency('USD')
        Currency('USD') != Currency('EUR')


    Currencies are hashable.


    ::

        len(set([Currency('USD'), Currency('USD')]))  # 1


    """

    def __init__(self, code):
        """
        :param code:
            Either a currency code string or another Currency instance
            whose code is copied.
        :raises ImproperlyConfigured: if babel is not installed.
        :raises ValueError: if the string is not a known currency code.
        :raises TypeError: if ``code`` is neither a string nor a Currency.
        """
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use Currency class."
            )
        if isinstance(code, Currency):
            # Copy the wrapped code string, not the Currency object itself;
            # otherwise ``.code`` and ``repr()`` would expose a nested
            # Currency instead of the plain code.
            self.code = code.code
        elif isinstance(code, str):
            self.validate(code)
            self.code = code
        else:
            raise TypeError(
                'First argument given to Currency constructor should be '
                'either an instance of Currency or valid three letter '
                'currency code.'
            )

    @classmethod
    def validate(cls, code):
        """Raise ValueError if ``code`` is not a key of babel's currencies."""
        try:
            i18n.babel.Locale('en').currencies[code]
        except KeyError:
            raise ValueError(f"'{code}' is not valid currency code.")
        except AttributeError:
            # As babel is optional, we may raise an AttributeError accessing it
            pass

    @property
    def symbol(self):
        """Currency symbol for the locale returned by ``i18n.get_locale``."""
        return i18n.babel.numbers.get_currency_symbol(self.code, i18n.get_locale())

    @property
    def name(self):
        """Currency name in the locale returned by ``i18n.get_locale``."""
        return i18n.get_locale().currencies[self.code]

    def __eq__(self, other):
        if isinstance(other, Currency):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash(self.code)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        return self.code
class TestAutoDeleteOrphans:
    """A tag disappears once no entry references it any more."""

    @pytest.fixture
    def init_models(self, Entry, Tag):
        """Request the model fixtures; there is nothing else to set up."""

    def test_orphan_deletion(self, session, Entry, Tag):
        """Tag count drops only when a tag loses its last entry."""
        first, second, third = (Entry() for _ in range(3))
        t1, t2, t3, t4 = (Tag(label) for label in ('t1', 't2', 't3', 't4'))

        first.tags.extend([t1, t2])
        second.tags.extend([t2, t3])
        third.tags.extend([t4])
        session.add_all([first, second, third])

        def remaining():
            return session.query(Tag).count()

        assert remaining() == 4
        # t2 is still attached to the first entry, so it survives.
        second.tags.remove(t2)
        assert remaining() == 4
        first.tags.remove(t2)
        assert remaining() == 3
        first.tags.remove(t1)
        assert remaining() == 2
--------------------------------------------------------------------------------