├── tests
├── types
│ ├── __init__.py
│ ├── test_ip_address.py
│ ├── test_url.py
│ ├── test_email.py
│ ├── test_country.py
│ ├── test_ltree.py
│ ├── test_currency.py
│ ├── encrypted
│ │ └── test_padding.py
│ ├── test_weekdays.py
│ ├── test_uuid.py
│ ├── test_locale.py
│ ├── test_enriched_date_pendulum.py
│ ├── test_color.py
│ ├── test_tsvector.py
│ ├── test_arrow.py
│ ├── test_scalar_list.py
│ └── test_enriched_datetime_pendulum.py
├── aggregate
│ ├── __init__.py
│ ├── test_with_ondelete_cascade.py
│ ├── test_search_vectors.py
│ ├── test_with_column_alias.py
│ ├── test_o2m_o2m.py
│ ├── test_m2m.py
│ ├── test_custom_select_expressions.py
│ ├── test_backrefs.py
│ ├── test_simple_paths.py
│ ├── test_o2m_m2m.py
│ ├── test_m2m_m2m.py
│ └── test_multiple_aggregates_per_class.py
├── functions
│ ├── __init__.py
│ ├── test_escape_like.py
│ ├── test_naturally_equivalent.py
│ ├── test_get_bind.py
│ ├── test_is_loaded.py
│ ├── test_quote.py
│ ├── test_json_sql.py
│ ├── test_jsonb_sql.py
│ ├── test_table_name.py
│ ├── test_get_hybrid_properties.py
│ ├── test_get_column_key.py
│ ├── test_identity.py
│ ├── test_cast_if.py
│ ├── test_has_changes.py
│ ├── test_get_primary_keys.py
│ ├── test_get_type.py
│ ├── test_get_columns.py
│ ├── test_non_indexed_foreign_keys.py
│ ├── test_render.py
│ └── test_get_tables.py
├── observes
│ ├── __init__.py
│ ├── test_o2o_o2o.py
│ ├── test_dynamic_relationship.py
│ └── test_o2o_o2o_o2o.py
├── primitives
│ ├── __init__.py
│ ├── test_currency.py
│ └── test_country.py
├── relationships
│ └── __init__.py
├── __init__.py
├── test_instrumented_list.py
├── generic_relationship
│ ├── test_column_aliases.py
│ ├── test_abstract_base_class.py
│ ├── test_hybrid_properties.py
│ ├── test_composite_keys.py
│ └── __init__.py
├── test_case_insensitive_comparator.py
├── test_instant_defaults_listener.py
├── test_expressions.py
├── test_models.py
├── test_query_chain.py
└── test_auto_delete_orphans.py
├── docs
├── license.rst
├── observers.rst
├── aggregates.rst
├── utility_classes.rst
├── models.rst
├── listeners.rst
├── view.rst
├── index.rst
├── testing.rst
├── range_data_types.rst
├── foreign_key_helpers.rst
├── database_helpers.rst
├── installation.rst
├── conf.py
└── orm_helpers.rst
├── sqlalchemy_utils
├── types
│ ├── encrypted
│ │ └── __init__.py
│ ├── scalar_coercible.py
│ ├── enriched_datetime
│ │ ├── __init__.py
│ │ ├── pendulum_date.py
│ │ ├── arrow_datetime.py
│ │ ├── pendulum_datetime.py
│ │ ├── enriched_date_type.py
│ │ └── enriched_datetime_type.py
│ ├── bit.py
│ ├── email.py
│ ├── ip_address.py
│ ├── arrow.py
│ ├── url.py
│ ├── country.py
│ ├── locale.py
│ ├── __init__.py
│ ├── currency.py
│ ├── color.py
│ ├── json.py
│ ├── weekdays.py
│ ├── scalar_list.py
│ └── timezone.py
├── primitives
│ ├── __init__.py
│ ├── weekday.py
│ ├── weekdays.py
│ ├── country.py
│ └── currency.py
├── exceptions.py
├── compat.py
├── utils.py
├── relationships
│ └── chained_join.py
├── functions
│ ├── __init__.py
│ ├── render.py
│ └── sort_query.py
├── expressions.py
├── operators.py
├── proxy_dict.py
└── __init__.py
├── requirements
├── docs
│ ├── pyproject.toml
│ └── requirements.txt
└── README.rst
├── .flake8
├── MANIFEST.in
├── .github
├── dependabot.yaml
└── workflows
│ ├── lint.yaml
│ └── test.yaml
├── .readthedocs.yaml
├── .editorconfig
├── .ci
└── install_mssql.sh
├── ROADMAP.rst
├── .pre-commit-config.yaml
├── .gitignore
├── README.rst
├── LICENSE
├── RELEASE.md
└── tox.ini
/tests/types/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/aggregate/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/functions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/observes/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/primitives/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/relationships/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/license.rst:
--------------------------------------------------------------------------------
1 | License
2 | =======
3 |
4 | .. include:: ../LICENSE
5 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/encrypted/__init__.py:
--------------------------------------------------------------------------------
1 | # Module for encrypted type
2 |
--------------------------------------------------------------------------------
/docs/observers.rst:
--------------------------------------------------------------------------------
1 | Observers
2 | =========
3 |
4 | .. automodule:: sqlalchemy_utils.observer
5 |
6 | .. autofunction:: observes
7 |
--------------------------------------------------------------------------------
/docs/aggregates.rst:
--------------------------------------------------------------------------------
1 | Aggregated attributes
2 | =====================
3 |
4 | .. automodule:: sqlalchemy_utils.aggregates
5 |
6 | .. autofunction:: aggregated
7 |
--------------------------------------------------------------------------------
/requirements/docs/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | package-mode = false
3 |
4 | [tool.poetry.dependencies]
5 | python = ">=3.9"
6 | sphinx = "*"
7 | sphinx-rtd-theme = "*"
8 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
def assert_contains(clause, query):
    """Execute ``query`` and assert that its SQL text contains ``clause``.

    ``query.all()`` runs first, so a query that cannot execute errors out
    before the textual check is made.

    :param clause: SQL fragment expected somewhere in ``str(query)``.
    :param query: query-like object exposing ``all()`` and ``__str__``.
    """
    query.all()  # prove the query actually executes
    rendered = str(query)
    assert clause in rendered, f"[{clause}] is not in [{query}]"
5 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | extend-ignore = E203
4 | exclude =
5 | .eggs,
6 | .git,
7 | .tox,
8 | .venv,
9 | build,
10 | docs/conf.py
11 |
--------------------------------------------------------------------------------
/tests/functions/test_escape_like.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy_utils import escape_like
2 |
3 |
class TestEscapeLike:
    """Tests for ``escape_like``."""

    def test_escapes_wildcards(self):
        raw = '_*%'
        # Each of '_', '*' and '%' gets prefixed with '*' (the escape char),
        # which is why the lone '*' in the input becomes '**'.
        expected = '*_***%'
        assert escape_like(raw) == expected
7 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include CHANGES.rst LICENSE README.rst
2 | recursive-include tests *
3 | recursive-exclude tests *.pyc
4 | recursive-include docs *
5 | recursive-exclude docs *.pyc
6 | prune docs/_build
7 | exclude docs/_themes/.git
8 |
--------------------------------------------------------------------------------
/docs/utility_classes.rst:
--------------------------------------------------------------------------------
1 | Utility classes
2 | ===============
3 |
4 | QueryChain
5 | ----------
6 |
7 | .. automodule:: sqlalchemy_utils.query_chain
8 |
9 | API
10 | ---
11 |
12 | .. autoclass:: QueryChain
13 | :members:
14 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/primitives/__init__.py:
--------------------------------------------------------------------------------
1 | from .country import Country # noqa
2 | from .currency import Currency # noqa
3 | from .ltree import Ltree # noqa
4 | from .weekday import WeekDay # noqa
5 | from .weekdays import WeekDays # noqa
6 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/scalar_coercible.py:
--------------------------------------------------------------------------------
class ScalarCoercible:
    """Mixin for types that coerce values assigned to mapped attributes.

    Subclasses provide ``_coerce``. ``coercion_listener`` just forwards the
    incoming value to it; its ``(target, value, oldvalue, initiator)``
    signature matches a SQLAlchemy attribute ``set`` event listener
    (presumably how it gets registered — see the caller).
    """

    def _coerce(self, value):
        # Abstract hook: concrete types must override this.
        raise NotImplementedError()

    def coercion_listener(self, target, value, oldvalue, initiator):
        coerced = self._coerce(value)
        return coerced
7 |
--------------------------------------------------------------------------------
/.github/dependabot.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions"
4 | directory: "/"
5 | schedule:
6 | interval: "monthly"
7 | groups:
8 | github-actions:
9 | patterns:
10 | - "*"
11 |
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | Model mixins
2 | ============
3 |
4 |
5 | Timestamp
6 | ---------
7 |
8 | .. autoclass:: sqlalchemy_utils.models.Timestamp
9 |
10 |
11 | generic_repr
12 | ------------
13 |
14 | .. autofunction:: sqlalchemy_utils.models.generic_repr
15 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/enriched_datetime/__init__.py:
--------------------------------------------------------------------------------
1 | # Module for enriched date, datetime type
2 | from .arrow_datetime import ArrowDateTime # noqa
3 | from .pendulum_date import PendulumDate # noqa
4 | from .pendulum_datetime import PendulumDateTime # noqa
5 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-24.04"
5 | tools:
6 | python: "3.13"
7 |
8 | python:
9 | install:
10 | - requirements: "requirements/docs/requirements.txt"
11 |
12 | sphinx:
13 | configuration: "docs/conf.py"
14 | fail_on_warning: false
15 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/exceptions.py:
--------------------------------------------------------------------------------
1 | """
2 | Global SQLAlchemy-Utils exception classes.
3 | """
4 |
5 |
class ImproperlyConfigured(Exception):
    """
    Raised when SQLAlchemy-Utils is used without something it needs —
    typically a utility whose optional third-party dependency is not
    installed.
    """
11 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig helps developers define and maintain consistent
2 | # coding styles between different editors and IDEs
3 | # editorconfig.org
4 |
5 | root = true
6 |
7 |
8 | [*]
9 | indent_style = space
10 | end_of_line = lf
11 | charset = utf-8
12 | trim_trailing_whitespace = true
13 | insert_final_newline = true
14 | indent_size = 4
15 |
16 | [*.yaml]
17 | indent_size = 2
18 |
--------------------------------------------------------------------------------
/.ci/install_mssql.sh:
--------------------------------------------------------------------------------
#!/bin/sh

# Register Microsoft's apt repository (signing key + package list) only when
# microsoft-prod.list is absent -- presumably because some runner images
# already ship that repository preconfigured.
if ! test -f /etc/apt/sources.list.d/microsoft-prod.list; then
    curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -
    sudo sh -c "curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -r -s)/prod.list > /etc/apt/sources.list.d/mssql-release.list"
fi

# Install the Microsoft ODBC Driver 18 plus unixODBC; ACCEPT_EULA=Y accepts
# the driver's licence non-interactively.
sudo apt-get update
sudo ACCEPT_EULA=Y apt-get -y install msodbcsql18 unixodbc
10 |
--------------------------------------------------------------------------------
/docs/listeners.rst:
--------------------------------------------------------------------------------
1 | Listeners
2 | =========
3 |
4 |
5 | .. module:: sqlalchemy_utils.listeners
6 |
7 | Automatic data coercion
8 | -----------------------
9 |
10 | .. autofunction:: force_auto_coercion
11 |
12 |
13 | Instant defaults
14 | ----------------
15 |
16 | .. autofunction:: force_instant_defaults
17 |
18 |
19 | Many-to-many orphan deletion
20 | ----------------------------
21 |
22 | .. autofunction:: auto_delete_orphans
23 |
--------------------------------------------------------------------------------
/docs/view.rst:
--------------------------------------------------------------------------------
1 | View utilities
2 | ==============
3 |
4 |
5 | create_view
6 | -----------
7 |
8 | .. autofunction:: sqlalchemy_utils.view.create_view
9 |
10 |
11 | create_materialized_view
12 | ------------------------
13 |
14 | .. autofunction:: sqlalchemy_utils.view.create_materialized_view
15 |
16 |
17 | refresh_materialized_view
18 | -------------------------
19 |
20 | .. autofunction:: sqlalchemy_utils.view.refresh_materialized_view
21 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/compat.py:
--------------------------------------------------------------------------------
1 | import re
2 | from importlib.metadata import metadata
3 |
4 |
def get_sqlalchemy_version(version=None):
    """Extract the sqlalchemy version as a tuple of integers.

    :param version:
        Version string to parse, e.g. ``'2.0.31'``. When omitted (``None``)
        the version of the installed ``sqlalchemy`` distribution is read from
        its package metadata at call time.
    :return:
        Tuple of up to three integers (major, minor, micro). Components that
        are absent from ``version`` are left out, and anything after the
        numeric prefix (pre-release/local suffixes) is ignored. An empty
        tuple is returned when ``version`` does not start with a digit.
    """
    if version is None:
        # Resolved here rather than in the default argument so that merely
        # defining this function does not require package metadata, and the
        # value is not frozen at import time.
        version = metadata('sqlalchemy')['Version']

    match = re.search(r'^(\d+)(?:\.(\d+)(?:\.(\d+))?)?', version)
    if match is None:
        return ()
    return tuple(int(v) for v in match.groups() if v is not None)
13 |
--------------------------------------------------------------------------------
/ROADMAP.rst:
--------------------------------------------------------------------------------
1 | * Add efficient pagination support
2 | https://www.depesz.com/2011/05/20/pagination-with-fixed-order/
3 | https://stackoverflow.com/questions/6618366/improving-offset-performance-in-postgresql
4 | * Generic file model
5 | https://github.com/jpvanhal/silo
6 | * Query to Postgres JSON converter
7 | https://hashrocket.com/blog/posts/faster-json-generation-with-postgresql
8 | * Postgres Cube datatype:
9 | https://www.postgresql.org/docs/current/cube.html
10 |
--------------------------------------------------------------------------------
/tests/functions/test_naturally_equivalent.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy_utils.functions import naturally_equivalent
2 |
3 |
class TestNaturallyEquivalent:
    """Tests for ``naturally_equivalent``."""

    def test_returns_true_when_properties_match(self, User):
        first = User(name='someone')
        second = User(name='someone')
        assert naturally_equivalent(first, second)

    def test_skips_primary_keys(self, User):
        # The ids differ, yet the objects must still compare as equivalent.
        first = User(id=1, name='someone')
        second = User(id=2, name='someone')
        assert naturally_equivalent(first, second)
14 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Python linting
2 | on: [push, pull_request]
3 | jobs:
4 | test:
5 | name: Lint
6 | runs-on: ubuntu-latest
7 | strategy:
8 | matrix:
9 | python-version: ['3.13']
10 | steps:
11 | - uses: actions/checkout@v6
12 | - uses: actions/setup-python@v6
13 | with:
14 | python-version: ${{ matrix.python-version }}
15 | - name: Install tox
16 | run: pip install tox
17 | - name: Run linting
18 | run: tox -e ruff
19 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | SQLAlchemy-Utils
2 | ================
3 |
4 |
5 | SQLAlchemy-Utils provides custom data types and various utility functions for SQLAlchemy.
6 |
7 |
8 | .. toctree::
9 | :maxdepth: 2
10 |
11 | installation
12 | listeners
13 | data_types
14 | range_data_types
15 | aggregates
16 | observers
17 | internationalization
18 | generic_relationship
19 | database_helpers
20 | foreign_key_helpers
21 | orm_helpers
22 | utility_classes
23 | models
24 | view
25 | testing
26 | license
27 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 |
3 |
def str_coercible(cls):
    """Class decorator: make ``str(obj)`` delegate to ``obj.__unicode__()``.

    The class is modified in place and returned.
    """

    def __str__(self):
        unicode_method = self.__unicode__
        return unicode_method()

    setattr(cls, '__str__', __str__)
    return cls
10 |
11 |
def is_sequence(value):
    """Return whether ``value`` is iterable but not a ``str``."""
    if isinstance(value, str):
        return False
    return isinstance(value, Iterable)
14 |
15 |
def starts_with(iterable, prefix):
    """
    Returns whether or not given iterable starts with given prefix.

    Only the first ``len(prefix)`` items of ``iterable`` are consumed, so
    ``iterable`` may be a long or even unbounded iterator. ``prefix`` may be
    any finite iterable; it does not need to support ``len()``.
    """
    from itertools import islice

    expected = list(prefix)
    # Pull just enough items to compare instead of materialising everything.
    return list(islice(iterable, len(expected))) == expected
21 |
--------------------------------------------------------------------------------
/docs/testing.rst:
--------------------------------------------------------------------------------
1 | Testing
2 | =======
3 |
4 | .. automodule:: sqlalchemy_utils.asserts
5 |
6 |
7 | assert_min_value
8 | ----------------
9 |
10 | .. autofunction:: assert_min_value
11 |
12 | assert_max_length
13 | -----------------
14 |
15 | .. autofunction:: assert_max_length
16 |
17 | assert_max_value
18 | ----------------
19 |
20 | .. autofunction:: assert_max_value
21 |
22 | assert_nullable
23 | ---------------
24 |
25 | .. autofunction:: assert_nullable
26 |
27 | assert_non_nullable
28 | -------------------
29 |
30 | .. autofunction:: assert_non_nullable
31 |
--------------------------------------------------------------------------------
/docs/range_data_types.rst:
--------------------------------------------------------------------------------
1 | Range data types
2 | ================
3 |
4 | .. automodule:: sqlalchemy_utils.types.range
5 |
6 |
7 | DateRangeType
8 | -------------
9 |
10 | .. autoclass:: DateRangeType
11 |
12 |
13 | DateTimeRangeType
14 | -----------------
15 |
16 | .. autoclass:: DateTimeRangeType
17 |
18 |
19 | IntRangeType
20 | ------------
21 |
22 | .. autoclass:: IntRangeType
23 |
24 |
25 | NumericRangeType
26 | ----------------
27 |
28 | .. autoclass:: NumericRangeType
29 |
30 |
31 | RangeComparator
32 | ---------------
33 |
34 | .. autoclass:: RangeComparator
35 | :members:
36 |
--------------------------------------------------------------------------------
/tests/test_instrumented_list.py:
--------------------------------------------------------------------------------
class TestInstrumentedList:
    """Tests for ``any`` on the ``Category.articles`` collection."""

    def test_any_returns_true_if_member_has_attr_defined(
        self, Category, Article
    ):
        category = Category()
        articles = category.articles
        articles.append(Article())
        articles.append(Article(name='some name'))
        # One member with ``name`` set is enough.
        assert articles.any('name')

    def test_any_returns_false_if_no_member_has_attr_defined(
        self, Category, Article
    ):
        category = Category()
        articles = category.articles
        articles.append(Article())
        assert not articles.any('name')
20 |
--------------------------------------------------------------------------------
/tests/functions/test_get_bind.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from sqlalchemy_utils import get_bind
4 |
5 |
class TestGetBind:
    """``get_bind`` resolves sessions, connections and attached objects."""

    def test_with_session(self, session, connection):
        bind = get_bind(session)
        assert bind == connection

    def test_with_connection(self, session, connection):
        bind = get_bind(connection)
        assert bind == connection

    def test_with_model_object(self, session, connection, Article):
        article = Article()
        session.add(article)
        bind = get_bind(article)
        assert bind == connection

    def test_with_unknown_type(self):
        # Anything that is not a recognised bind source is rejected.
        with pytest.raises(TypeError):
            get_bind(None)
21 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Run `pre-commit autoupdate` to update the hook versions.
2 |
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: check-yaml
8 | - id: end-of-file-fixer
9 | - id: trailing-whitespace
10 |
11 | - repo: https://github.com/astral-sh/ruff-pre-commit
12 | rev: v0.12.2
13 | hooks:
14 | - id: ruff-check
15 | args: [--fix]
16 | - id: ruff-format
17 |
18 | - repo: https://github.com/asottile/pyupgrade
19 | rev: v3.20.0
20 | hooks:
21 | - id: pyupgrade
22 | name: "Enforce Python 3.9+ idioms"
23 | args: ["--py39-plus"]
24 |
--------------------------------------------------------------------------------
/tests/functions/test_is_loaded.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import is_loaded
5 |
6 |
@pytest.fixture
def Article(Base):
    """Model fixture whose ``title`` column is declared ``deferred``."""

    class Article(Base):
        __tablename__ = 'article_translation'

        id = sa.Column(sa.Integer, primary_key=True)
        # Deferred: not loaded until explicitly accessed.
        title = sa.orm.deferred(sa.Column(sa.String(100)))

    return Article
14 |
15 |
class TestIsLoaded:
    """Tests for ``is_loaded`` on freshly constructed instances."""

    def test_loaded_property(self, Article):
        # ``id`` was passed to the constructor, so it counts as loaded.
        assert is_loaded(Article(id=1), 'id')

    def test_unloaded_property(self, Article):
        # ``title`` is deferred and never assigned.
        assert not is_loaded(Article(id=4), 'title')
25 |
--------------------------------------------------------------------------------
/tests/functions/test_quote.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy.dialects import postgresql
2 |
3 | from sqlalchemy_utils.functions import quote
4 |
5 |
class TestQuote:
    """``quote`` accepts a connection, a session or a dialect."""

    def test_quote_with_preserved_keyword(self, connection, session):
        # ``order`` is a reserved word, so it must come back double-quoted.
        for target in (connection, session, postgresql.dialect()):
            assert quote(target, 'order') == '"order"'

    def test_quote_with_non_preserved_keyword(self, connection, session):
        # An ordinary identifier is returned untouched.
        for target in (connection, session, postgresql.dialect()):
            assert quote(target, 'some_order') == 'some_order'
20 |
--------------------------------------------------------------------------------
/docs/foreign_key_helpers.rst:
--------------------------------------------------------------------------------
1 | Foreign key helpers
2 | ===================
3 |
4 |
5 | dependent_objects
6 | -----------------
7 |
8 | .. autofunction:: sqlalchemy_utils.functions.dependent_objects
9 |
10 |
11 | get_referencing_foreign_keys
12 | ----------------------------
13 |
14 | .. autofunction:: sqlalchemy_utils.functions.get_referencing_foreign_keys
15 |
16 |
17 | group_foreign_keys
18 | ------------------
19 |
20 | .. autofunction:: sqlalchemy_utils.functions.group_foreign_keys
21 |
22 |
23 | merge_references
24 | ----------------
25 |
26 | .. autofunction:: sqlalchemy_utils.functions.merge_references
27 |
28 |
29 | non_indexed_foreign_keys
30 | ------------------------
31 |
32 | .. autofunction:: sqlalchemy_utils.functions.non_indexed_foreign_keys
33 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/relationships/chained_join.py:
--------------------------------------------------------------------------------
def chained_join(*relationships):
    """
    Return a chained Join object for given relationships.

    The first relationship contributes its target table, preceded by its
    ``secondary`` (association) table when it has one. Every following
    relationship is joined onto the running join: through its ``secondary``
    table when it has one, otherwise directly on its ``primaryjoin``.

    :param relationships: one or more relationship attributes (objects
        exposing ``.property``), in traversal order.
    """
    first = relationships[0].property
    if first.secondary is None:
        joined = first.mapper.class_.__table__
    else:
        joined = first.secondary.join(
            first.mapper.class_.__table__, first.secondaryjoin
        )

    for attribute in relationships[1:]:
        prop = attribute.property
        if prop.secondary is None:
            joined = joined.join(prop.mapper.class_, prop.primaryjoin)
            continue
        # Hop through the association table first, then on to the target.
        joined = joined.join(prop.secondary, prop.primaryjoin)
        joined = joined.join(prop.mapper.class_, prop.secondaryjoin)
    return joined
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 | __pycache__
3 |
4 | # C extensions
5 | *.so
6 |
7 | # Packages
8 | *.egg
9 | *.egg-info
10 | dist
11 | build
12 | eggs
13 | parts
14 | bin
15 | var
16 | sdist
17 | develop-eggs
18 | .installed.cfg
19 | .cache
20 | .eggs
21 | lib
22 | lib64
23 | docs/_build
24 |
25 | # Installer logs
26 | pip-log.txt
27 |
28 | # Unit test / coverage reports
29 | .coverage
30 | .tox
31 | coverage.xml
32 | nosetests.xml
33 |
34 | # Translations
35 | *.mo
36 |
37 | # Mr Developer
38 | .mr.developer.cfg
39 | .project
40 | .pydevproject
41 |
42 | # vim
43 | [._]*.s[a-w][a-z]
44 | [._]s[a-w][a-z]
45 | *.un~
46 | Session.vim
47 | .netrwhist
48 | *~
49 |
50 | # Sublime Text
51 | *.sublime-*
52 |
53 | # JetBrains
54 | .idea/
55 |
56 | # Environments
57 | .env
58 | .venv
59 | env/
60 | venv/
61 | ENV/
62 | env.bak/
63 | venv.bak/
64 |
65 | poetry.lock
66 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/bit.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 | from sqlalchemy.dialects.postgresql import BIT
3 |
4 |
class BitType(sa.types.TypeDecorator):
    """
    BitType offers a way of saving BITs into the database.

    PostgreSQL gets its native ``BIT(length)``, SQLite stores the value as
    ``String(length)``, and every other dialect uses the ``impl`` type
    (``BINARY`` by default) sized to ``length``.
    """

    impl = sa.types.BINARY

    cache_ok = True

    def __init__(self, length=1, **kwargs):
        self.length = length
        sa.types.TypeDecorator.__init__(self, **kwargs)

    def load_dialect_impl(self, dialect):
        name = dialect.name
        # Use the native BIT type for drivers that have it.
        if name == 'postgresql':
            chosen = BIT(self.length)
        elif name == 'sqlite':
            chosen = sa.String(self.length)
        else:
            chosen = type(self.impl)(self.length)
        return dialect.type_descriptor(chosen)
26 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | SQLAlchemy-Utils
2 | ================
3 |
4 | |Build| |Version Status| |Downloads|
5 |
6 | Various utility functions, new data types and helpers for SQLAlchemy.
7 |
8 |
9 | Resources
10 | ---------
11 |
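12 | - `Documentation <https://sqlalchemy-utils.readthedocs.io/>`_
13 | - `Issue Tracker <https://github.com/kvesteri/sqlalchemy-utils/issues>`_
14 | - `Code <https://github.com/kvesteri/sqlalchemy-utils/>`_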
15 |
16 | .. |Build| image:: https://github.com/kvesteri/sqlalchemy-utils/actions/workflows/test.yaml/badge.svg?branch=master
17 | :target: https://github.com/kvesteri/sqlalchemy-utils/actions/workflows/test.yaml
18 | :alt: Tests
19 | .. |Version Status| image:: https://img.shields.io/pypi/v/SQLAlchemy-Utils.svg
20 | :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/
21 | .. |Downloads| image:: https://img.shields.io/pypi/dm/SQLAlchemy-Utils.svg
22 | :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/
23 |
--------------------------------------------------------------------------------
/tests/functions/test_json_sql.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import json_sql
5 |
6 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestJSONSQL:
    """Values built with ``json_sql`` evaluate back to equal Python values."""

    _CASES = (
        (1, 1),
        (14.14, 14.14),
        ({'a': 2, 'b': 'c'}, {'a': 2, 'b': 'c'}),
        ({'a': {'b': 'c'}}, {'a': {'b': 'c'}}),
        ({}, {}),
        ([1, 2], [1, 2]),
        ([], []),
        # A labelled subquery inside a list evaluates to its selected value.
        ([sa.select(sa.text('1')).label('alias')], [1]),
    )

    @pytest.mark.parametrize(('value', 'result'), _CASES)
    def test_compiled_scalars(self, connection, value, result):
        row = connection.execute(sa.select(json_sql(value))).fetchone()
        assert row[0] == result
33 |
--------------------------------------------------------------------------------
/tests/functions/test_jsonb_sql.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import jsonb_sql
5 |
6 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestJSONBSQL:
    """Values built with ``jsonb_sql`` evaluate back to equal Python values."""

    _CASES = (
        (1, 1),
        (14.14, 14.14),
        ({'a': 2, 'b': 'c'}, {'a': 2, 'b': 'c'}),
        ({'a': {'b': 'c'}}, {'a': {'b': 'c'}}),
        ({}, {}),
        ([1, 2], [1, 2]),
        ([], []),
        # A labelled subquery inside a list evaluates to its selected value.
        ([sa.select(sa.text('1')).label('alias')], [1]),
    )

    @pytest.mark.parametrize(('value', 'result'), _CASES)
    def test_compiled_scalars(self, connection, value, result):
        row = connection.execute(sa.select(jsonb_sql(value))).fetchone()
        assert row[0] == result
33 |
--------------------------------------------------------------------------------
/tests/functions/test_table_name.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import table_name
5 |
6 |
@pytest.fixture
def Building(Base):
    """Minimal ``building`` model: integer primary key plus a name column."""

    class Building(Base):
        __tablename__ = 'building'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

    return Building
14 |
15 |
@pytest.fixture
def init_models(Base):
    """Intentional no-op.

    Presumably shadows a conftest fixture of the same name so that this
    module sets up nothing beyond what its own fixtures declare.
    """
19 |
20 |
class TestTableName:
    """``table_name`` works for classes, attributes and instances."""

    def test_class(self, Building):
        assert table_name(Building) == 'building'
        # The name must still resolve once ``__tablename__`` is gone.
        del Building.__tablename__
        assert table_name(Building) == 'building'

    def test_attribute(self, Building):
        for attribute in (Building.id, Building.name):
            assert table_name(attribute) == 'building'

    def test_target(self, Building):
        instance = Building()
        assert table_name(instance) == 'building'
34 |
--------------------------------------------------------------------------------
/docs/database_helpers.rst:
--------------------------------------------------------------------------------
1 | Database helpers
2 | ================
3 |
4 |
5 | database_exists
6 | ---------------
7 |
8 | .. autofunction:: sqlalchemy_utils.functions.database_exists
9 |
10 |
11 | create_database
12 | ---------------
13 |
14 | .. autofunction:: sqlalchemy_utils.functions.create_database
15 |
16 |
17 | drop_database
18 | -------------
19 |
20 | .. autofunction:: sqlalchemy_utils.functions.drop_database
21 |
22 |
23 | has_index
24 | ---------
25 |
26 | .. autofunction:: sqlalchemy_utils.functions.has_index
27 |
28 |
29 | has_unique_index
30 | ----------------
31 |
32 | .. autofunction:: sqlalchemy_utils.functions.has_unique_index
33 |
34 |
35 | json_sql
36 | --------
37 |
38 | .. autofunction:: sqlalchemy_utils.functions.json_sql
39 |
40 |
41 | render_expression
42 | -----------------
43 |
44 | .. autofunction:: sqlalchemy_utils.functions.render_expression
45 |
46 |
47 | render_statement
48 | ----------------
49 |
50 | .. autofunction:: sqlalchemy_utils.functions.render_statement
51 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/enriched_datetime/pendulum_date.py:
--------------------------------------------------------------------------------
1 | from ...exceptions import ImproperlyConfigured
2 | from .pendulum_datetime import PendulumDateTime
3 |
4 | pendulum = None
5 | try:
6 | import pendulum
7 | except ImportError:
8 | pass
9 |
10 |
class PendulumDate(PendulumDateTime):
    """
    Coerce values to ``pendulum.Date`` objects.

    Values that are not already dates are first coerced by
    :class:`PendulumDateTime` and then truncated to their date part.

    :raises ImproperlyConfigured: on construction if ``pendulum`` is not
        installed.
    """

    def __init__(self):
        if not pendulum:
            raise ImproperlyConfigured(
                "'pendulum' package is required to use 'PendulumDate'"
            )

    def _coerce(self, impl, value):
        if value:
            if isinstance(value, pendulum.DateTime):
                # pendulum.DateTime subclasses pendulum.Date, so without this
                # branch it would pass the check below and keep its time part.
                value = value.date()
            elif not isinstance(value, pendulum.Date):
                value = super()._coerce(impl, value).date()
        return value

    def process_result_value(self, impl, value, dialect):
        # Turn the driver's value into a pendulum Date via its ISO string.
        if value:
            return pendulum.parse(value.isoformat()).date()
        return value

    def process_bind_param(self, impl, value, dialect):
        if value:
            return self._coerce(impl, value)
        return value
33 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .database import ( # noqa
2 | create_database,
3 | database_exists,
4 | drop_database,
5 | escape_like,
6 | has_index,
7 | has_unique_index,
8 | is_auto_assigned_date_column,
9 | json_sql,
10 | jsonb_sql,
11 | )
12 | from .foreign_keys import ( # noqa
13 | dependent_objects,
14 | get_fk_constraint_for_columns,
15 | get_referencing_foreign_keys,
16 | group_foreign_keys,
17 | merge_references,
18 | non_indexed_foreign_keys,
19 | )
20 | from .mock import create_mock_engine, mock_engine # noqa
21 | from .orm import ( # noqa
22 | cast_if,
23 | get_bind,
24 | get_class_by_table,
25 | get_column_key,
26 | get_columns,
27 | get_declarative_base,
28 | get_hybrid_properties,
29 | get_mapper,
30 | get_primary_keys,
31 | get_tables,
32 | get_type,
33 | getdotattr,
34 | has_changes,
35 | identity,
36 | is_loaded,
37 | naturally_equivalent,
38 | quote,
39 | table_name,
40 | )
41 | from .render import render_expression, render_statement # noqa
42 | from .sort_query import make_order_by_deterministic # noqa
43 |
--------------------------------------------------------------------------------
/tests/generic_relationship/test_column_aliases.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import generic_relationship
5 |
6 | from . import GenericRelationshipTestCase
7 |
8 |
@pytest.fixture
def Building(Base):
    """Mapped ``Building`` model, one of the generic relationship targets."""
    class Building(Base):
        __tablename__ = 'building'
        id = sa.Column(sa.Integer, primary_key=True)

    return Building


@pytest.fixture
def User(Base):
    """Mapped ``User`` model, another generic relationship target."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)

    return User


@pytest.fixture
def Event(Base):
    """``Event`` model whose discriminator attribute uses a column alias.

    The ``object_type`` attribute is stored in a database column named
    ``objectType``, so the attribute key differs from the column name.
    """
    class Event(Base):
        __tablename__ = 'event'
        id = sa.Column(sa.Integer, primary_key=True)

        # Attribute key and database column name intentionally differ.
        object_type = sa.Column(sa.Unicode(255), name="objectType")
        object_id = sa.Column(sa.Integer, nullable=False)

        object = generic_relationship(object_type, object_id)

    return Event


@pytest.fixture
def init_models(Building, User, Event):
    """Depend on every model fixture; the body itself does nothing."""
41 |
42 |
class TestGenericRelationship(GenericRelationshipTestCase):
    """Inherits every test from GenericRelationshipTestCase unchanged.

    Presumably those shared tests pick up this module's model fixtures.
    """
45 |
--------------------------------------------------------------------------------
/tests/types/test_ip_address.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.types import ip_address
5 |
6 |
@pytest.fixture
def Visitor(Base):
    """Model with a single IPAddressType column."""
    class Visitor(Base):
        __tablename__ = 'document'
        id = sa.Column(sa.Integer, primary_key=True)
        # The right-hand ``ip_address`` is the imported module: the class
        # attribute of the same name does not exist yet when this runs.
        ip_address = sa.Column(ip_address.IPAddressType)

        def __repr__(self):
            return 'Visitor(%r)' % self.id

    return Visitor


@pytest.fixture
def init_models(Visitor):
    """Depend on Visitor; the body itself does nothing."""
22 |
23 |
# NOTE(review): the ip_address type module imports ``ip_address`` from the
# stdlib ``ipaddress`` module, so this skip condition appears never to fire.
@pytest.mark.skipif('ip_address.ip_address is None')
class TestIPAddressType:
    """Round-trip and statement-compilation checks for IPAddressType."""

    def test_parameter_processing(self, session, Visitor):
        created = Visitor(ip_address='111.111.111.111')
        session.add(created)
        session.commit()

        stored = session.query(Visitor).first()
        assert str(stored.ip_address) == '111.111.111.111'

    def test_compilation(self, Visitor, session):
        # Executing the select shows the type is cacheable and compiles.
        session.execute(sa.select(Visitor.ip_address))
41 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/enriched_datetime/arrow_datetime.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 | from datetime import datetime
3 |
4 | from ...exceptions import ImproperlyConfigured
5 |
6 | arrow = None
7 | try:
8 | import arrow
9 | except ImportError:
10 | pass
11 |
12 |
class ArrowDateTime:
    """Arrow-backed datetime conversion helper.

    Requires the optional ``arrow`` package; this module binds ``arrow`` to
    None when the import fails.
    """

    def __init__(self):
        if not arrow:
            raise ImproperlyConfigured(
                "'arrow' package is required to use 'ArrowDateTime'"
            )

    def _coerce(self, impl, value):
        """Turn strings, iterables and datetimes into Arrow objects.

        Any other value is returned unchanged.
        """
        # Strings are iterable as well, so they must be checked first.
        if isinstance(value, str):
            return arrow.get(value)
        if isinstance(value, Iterable):
            # Items are unpacked as positional arguments to arrow.get().
            return arrow.get(*value)
        if isinstance(value, datetime):
            return arrow.get(value)
        return value

    def process_bind_param(self, impl, value, dialect):
        """Convert *value* to UTC: aware if ``impl.timezone``, otherwise naive."""
        if not value:
            return value
        in_utc = self._coerce(impl, value).to('UTC')
        if impl.timezone:
            return in_utc.datetime
        return in_utc.naive

    def process_result_value(self, impl, value, dialect):
        """Wrap a database value in an Arrow object; falsy values pass through."""
        if not value:
            return value
        return arrow.get(value)
39 |
--------------------------------------------------------------------------------
/tests/types/test_url.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.types import url
5 |
6 |
@pytest.fixture
def User(Base):
    """Model with a single URLType ``website`` column."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        website = sa.Column(url.URLType)

        def __repr__(self):
            return 'User(%r)' % self.id

    return User


@pytest.fixture
def init_models(User):
    """Depend on User; the body itself does nothing."""
22 |
23 |
@pytest.mark.skipif('url.furl is None')
class TestURLType:
    """URLType tests.

    Skipped when ``url.furl`` is None (presumably the optional furl package
    is not installed).
    """

    def test_parameter_processing(self, session, User):
        """A furl value survives a round trip through the database."""
        user = User(website=url.furl('www.example.com'))
        session.add(user)
        session.commit()

        user = session.query(User).first()
        assert isinstance(user.website, url.furl)

    def test_scalar_attributes_get_coerced_to_objects(self, User):
        """Assigning a plain string yields a furl object on the attribute."""
        user = User(website='www.example.com')
        assert isinstance(user.website, url.furl)

    def test_compilation(self, User, session):
        # The type should be cacheable and compile without raising.
        session.execute(sa.select(User.website))
46 |
--------------------------------------------------------------------------------
/tests/functions/test_get_hybrid_properties.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.ext.hybrid import hybrid_property
4 |
5 | from sqlalchemy_utils import get_hybrid_properties
6 |
7 |
@pytest.fixture
def Category(Base):
    """Model exposing one hybrid property, ``lowercase_name``."""
    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @hybrid_property
        def lowercase_name(self):
            # Instance level: plain Python string lowering.
            return self.name.lower()

        @lowercase_name.expression
        def lowercase_name(cls):
            # Class level: SQL lower() applied to the column.
            return sa.func.lower(cls.name)

    return Category
23 |
24 |
class TestGetHybridProperties:
    """get_hybrid_properties accepts classes, mappers and aliased classes."""

    def test_declarative_model(self, Category):
        props = get_hybrid_properties(Category)
        assert list(props.keys()) == ['lowercase_name']

    def test_mapper(self, Category):
        mapper = sa.inspect(Category)
        props = get_hybrid_properties(mapper)
        assert list(props.keys()) == ['lowercase_name']

    def test_aliased_class(self, Category):
        alias = sa.orm.aliased(Category)
        assert list(get_hybrid_properties(alias).keys()) == ['lowercase_name']
42 |
--------------------------------------------------------------------------------
/requirements/README.rst:
--------------------------------------------------------------------------------
1 | ``requirements/``
2 | #################
3 |
4 | This directory contains the files that manage dependencies for the project.
5 |
6 | Each subdirectory in this directory contains a ``pyproject.toml`` file
7 | with purpose-specific dependencies listed.
8 |
9 |
10 | How it's used
11 | =============
12 |
13 | Tox is configured to use the exported ``requirements.txt`` files as needed.
14 | In addition, Read the Docs is configured to use ``docs/requirements.txt``.
15 | This helps ensure reproducible testing, linting, and documentation builds.
16 |
17 |
18 | How it's updated
19 | ================
20 |
21 | A tox label, ``update``, ensures that dependencies can be easily updated,
22 | and that ``requirements.txt`` files are consistently re-exported.
23 |
24 | This can be invoked by running:
25 |
26 | .. code-block::
27 |
28 | tox run -m update
29 |
30 |
31 | How to add dependencies
32 | =======================
33 |
34 | New dependencies can be added to a given subdirectory's ``pyproject.toml``
35 | by either manually modifying the file, or by running a command like:
36 |
37 | .. code-block::
38 |
39 | poetry add --lock --directory "requirements/$DIR" $DEPENDENCY_NAME
40 |
41 | Either way, the dependencies must be re-exported:
42 |
43 | .. code-block::
44 |
45 | tox run -m update
46 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/primitives/weekday.py:
--------------------------------------------------------------------------------
1 | from functools import total_ordering
2 |
3 | from .. import i18n
4 | from ..utils import str_coercible
5 |
6 |
@str_coercible
@total_ordering
class WeekDay:
    """A day of the week identified by a zero-based ``index``.

    Equality and hashing use ``index``. Ordering uses ``position``, the
    day's offset from the current locale's first week day, so sort order
    depends on ``i18n.get_locale()``.
    """

    # Number of valid indices; ``index`` must lie in ``range(NUM_WEEK_DAYS)``.
    NUM_WEEK_DAYS = 7

    def __init__(self, index):
        """
        :param index: Day index in ``range(NUM_WEEK_DAYS)``.
        :raises ValueError: If *index* is out of range.
        """
        if not (0 <= index < self.NUM_WEEK_DAYS):
            # The upper bound is exclusive, so report the largest valid index.
            raise ValueError(
                'index must be between 0 and %d' % (self.NUM_WEEK_DAYS - 1)
            )
        self.index = index

    def __eq__(self, other):
        if not isinstance(other, WeekDay):
            return NotImplemented
        return self.index == other.index

    def __hash__(self):
        # Consistent with __eq__, which also compares only ``index``.
        return hash(self.index)

    def __lt__(self, other):
        # Return NotImplemented (instead of failing on ``other.position``)
        # so Python raises its usual TypeError for unrelated operands.
        if not isinstance(other, WeekDay):
            return NotImplemented
        return self.position < other.position

    def __repr__(self):
        return f'{self.__class__.__name__}({self.index!r})'

    def __unicode__(self):
        return self.name

    def get_name(self, width='wide', context='format'):
        """Return the localized name of this day for the current locale.

        :param width: Name width passed to Babel's ``get_day_names``.
        :param context: Name context passed to Babel's ``get_day_names``.
        """
        names = i18n.babel.dates.get_day_names(width, context, i18n.get_locale())
        return names[self.index]

    @property
    def name(self):
        """Localized day name using the default width and context."""
        return self.get_name()

    @property
    def position(self):
        """Offset of this day from the locale's first week day (0..6)."""
        return (self.index - i18n.get_locale().first_week_day) % self.NUM_WEEK_DAYS
46 |
--------------------------------------------------------------------------------
/tests/functions/test_get_column_key.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils import get_column_key
7 |
8 |
@pytest.fixture
def Building(Base):
    """Model whose ``name`` attribute maps to a column named ``_name``."""
    class Building(Base):
        __tablename__ = 'building'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column('_name', sa.Unicode(255))

    return Building


@pytest.fixture
def Movie(Base):
    """Separate model whose columns are not mapped on Building."""
    class Movie(Base):
        __tablename__ = 'movie'
        id = sa.Column(sa.Integer, primary_key=True)

    return Movie
24 |
25 |
class TestGetColumnKey:
    """get_column_key maps table columns back to mapped attribute keys."""

    def test_supports_aliases(self, Building):
        table = Building.__table__
        assert get_column_key(Building, table.c.id) == 'id'
        # Column ``_name`` is mapped under the attribute key ``name``.
        assert get_column_key(Building, table.c._name) == 'name'

    def test_supports_vague_matching_of_column_objects(self, Building):
        # A copied column is a distinct object but must still resolve.
        column = copy(Building.__table__.c._name)
        assert get_column_key(Building, column) == 'name'

    def test_raises_unmapped_column_error_for_unknown_column(
        self, Building, Movie
    ):
        """A column from another model raises UnmappedColumnError."""
        with pytest.raises(sa.orm.exc.UnmappedColumnError):
            get_column_key(Building, Movie.__table__.c.id)
45 |
--------------------------------------------------------------------------------
/tests/types/test_email.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import EmailType
5 |
6 |
@pytest.fixture
def User(Base):
    """Model with a default-length and a custom-length EmailType column."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        email = sa.Column(EmailType)
        short_email = sa.Column(EmailType(length=70))

        def __repr__(self):
            return 'User(%r)' % self.id

    return User
18 |
19 |
class TestEmailType:
    """EmailType lowercasing, literal rendering and length configuration."""

    def test_saves_email_as_lowercased(self, session, User):
        created = User(email='Someone@example.com')
        session.add(created)
        session.commit()

        stored = session.query(User).first()
        assert stored.email == 'someone@example.com'

    def test_literal_param(self, session, User):
        clause = User.email == 'Someone@example.com'
        rendered = clause.compile(compile_kwargs={'literal_binds': True})
        assert str(rendered) == '"user".email = lower(\'Someone@example.com\')'

    def test_custom_length(self, session, User):
        impl = User.short_email.type.impl
        assert impl.length == 70

    def test_compilation(self, User, session):
        # Executing the select shows the type is cacheable and compiles.
        session.execute(sa.select(User.email))
42 |
--------------------------------------------------------------------------------
/tests/functions/test_identity.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.functions import identity
5 |
6 |
class IdentityTestCase:
    """Shared identity() tests; subclasses supply a ``Building`` fixture."""

    @pytest.fixture
    def init_models(self, Building):
        """Depend on Building; the body itself does nothing."""

    def test_for_transient_class_without_id(self, Building):
        # No primary key value has been assigned yet.
        assert identity(Building()) == (None,)

    def test_for_transient_class_with_id(self, session, Building):
        building = Building(name='Some building')
        session.add(building)
        # Flushing assigns the primary key without committing.
        session.flush()

        assert identity(building) == (building.id,)

    def test_identity_for_class(self, Building):
        assert identity(Building) == (Building.id,)
25 |
26 |
class TestIdentity(IdentityTestCase):
    """Identity tests where the primary key column is named ``id``."""

    @pytest.fixture
    def Building(self, Base):
        class Building(Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        return Building
36 |
37 |
class TestIdentityWithColumnAlias(IdentityTestCase):
    """Identity tests where attribute ``id`` maps to column ``_id``."""

    @pytest.fixture
    def Building(self, Base):
        class Building(Base):
            __tablename__ = 'building'
            id = sa.Column('_id', sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        return Building
47 |
--------------------------------------------------------------------------------
/requirements/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | alabaster==0.7.16 ; python_version >= "3.9"
2 | babel==2.17.0 ; python_version >= "3.9"
3 | certifi==2025.7.9 ; python_version >= "3.9"
4 | charset-normalizer==3.4.2 ; python_version >= "3.9"
5 | colorama==0.4.6 ; python_version >= "3.9" and sys_platform == "win32"
6 | docutils==0.21.2 ; python_version >= "3.9"
7 | idna==3.10 ; python_version >= "3.9"
8 | imagesize==1.4.1 ; python_version >= "3.9"
9 | importlib-metadata==8.7.0 ; python_version == "3.9"
10 | jinja2==3.1.6 ; python_version >= "3.9"
11 | markupsafe==3.0.2 ; python_version >= "3.9"
12 | packaging==25.0 ; python_version >= "3.9"
13 | pygments==2.19.2 ; python_version >= "3.9"
14 | requests==2.32.4 ; python_version >= "3.9"
15 | snowballstemmer==3.0.1 ; python_version >= "3.9"
16 | sphinx-rtd-theme==3.0.2 ; python_version >= "3.9"
17 | sphinx==7.4.7 ; python_version >= "3.9"
18 | sphinxcontrib-applehelp==2.0.0 ; python_version >= "3.9"
19 | sphinxcontrib-devhelp==2.0.0 ; python_version >= "3.9"
20 | sphinxcontrib-htmlhelp==2.1.0 ; python_version >= "3.9"
21 | sphinxcontrib-jquery==4.1 ; python_version >= "3.9"
22 | sphinxcontrib-jsmath==1.0.1 ; python_version >= "3.9"
23 | sphinxcontrib-qthelp==2.0.0 ; python_version >= "3.9"
24 | sphinxcontrib-serializinghtml==2.0.0 ; python_version >= "3.9"
25 | tomli==2.2.1 ; python_version >= "3.9" and python_version < "3.11"
26 | urllib3==2.6.0 ; python_version >= "3.9"
27 | zipp==3.23.0 ; python_version == "3.9"
28 |
--------------------------------------------------------------------------------
/tests/generic_relationship/test_abstract_base_class.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.ext.declarative import declared_attr
4 |
5 | from sqlalchemy_utils import generic_relationship
6 |
7 | from . import GenericRelationshipTestCase
8 |
9 |
@pytest.fixture
def Building(Base):
    """Mapped ``Building`` model, one of the generic relationship targets."""
    class Building(Base):
        __tablename__ = 'building'
        id = sa.Column(sa.Integer, primary_key=True)

    return Building


@pytest.fixture
def User(Base):
    """Mapped ``User`` model, another generic relationship target."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)

    return User


@pytest.fixture
def EventBase(Base):
    """Abstract base declaring the generic relationship and its columns.

    ``object`` is created through ``declared_attr`` and refers to its
    columns by name rather than by Column object.
    """
    class EventBase(Base):
        __abstract__ = True

        object_type = sa.Column(sa.Unicode(255))
        object_id = sa.Column(sa.Integer, nullable=False)

        @declared_attr
        def object(cls):
            return generic_relationship('object_type', 'object_id')

    return EventBase


@pytest.fixture
def Event(EventBase):
    """Concrete model inheriting the relationship from EventBase."""
    class Event(EventBase):
        __tablename__ = 'event'
        id = sa.Column(sa.Integer, primary_key=True)

    return Event


@pytest.fixture
def init_models(Building, User, Event):
    """Depend on every model fixture; the body itself does nothing."""
51 |
52 |
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
    """Inherits every test from GenericRelationshipTestCase unchanged.

    Presumably those shared tests pick up this module's model fixtures.
    """
55 |
--------------------------------------------------------------------------------
/tests/functions/test_cast_if.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | import sqlalchemy.orm
4 |
5 | from sqlalchemy_utils import cast_if
6 |
7 |
@pytest.fixture(scope='class')
def base():
    """Declarative base, created once per test class."""
    return sqlalchemy.orm.declarative_base()


@pytest.fixture(scope='class')
def article_cls(base):
    """Class-scoped Article model with a synonym for ``name``."""
    class Article(base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String)
        name_synonym = sa.orm.synonym('name')

    return Article
22 |
23 |
class TestCastIf:
    """Each case asserts that cast_if returns its input unchanged."""

    def test_column(self, article_cls):
        column = article_cls.__table__.c.name
        assert cast_if(column, sa.String) is column

    def test_column_property(self, article_cls):
        prop = article_cls.name.property
        assert cast_if(prop, sa.String) is prop

    def test_instrumented_attribute(self, article_cls):
        attribute = article_cls.name
        assert cast_if(attribute, sa.String) is attribute

    def test_synonym(self, article_cls):
        synonym = article_cls.name_synonym
        assert cast_if(synonym, sa.String) is synonym

    def test_scalar_selectable(self, article_cls):
        subquery = sa.select(article_cls.id).scalar_subquery()
        assert cast_if(subquery, sa.Integer) is subquery

    def test_scalar(self):
        assert cast_if('something', sa.String) == 'something'
47 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | This part of the documentation covers the installation of SQLAlchemy-Utils.
5 |
6 | Supported platforms
7 | -------------------
8 |
9 | SQLAlchemy-Utils is currently tested against the following Python platforms:
10 |
11 | - CPython 3.9
12 | - CPython 3.10
13 | - CPython 3.11
14 | - CPython 3.12
15 |
16 |
17 | Installing an official release
18 | ------------------------------
19 |
20 | You can install the most recent official SQLAlchemy-Utils version using
21 | pip_::
22 |
23 | pip install sqlalchemy-utils
24 |
25 | .. _pip: https://pip.pypa.io/
26 |
27 | Installing the development version
28 | ----------------------------------
29 |
30 | To install the latest version of SQLAlchemy-Utils, you first need to obtain a
31 | copy of the source. You can do that by cloning the git_ repository::
32 |
33 | git clone https://github.com/kvesteri/sqlalchemy-utils.git
34 |
35 | Then you can install the source distribution using pip::
36 |
37 | cd sqlalchemy-utils
38 | pip install -e .
39 |
40 | .. _git: https://git-scm.com/
41 |
42 | Checking the installation
43 | -------------------------
44 |
45 | To check that SQLAlchemy-Utils has been properly installed, type ``python``
46 | from your shell. Then at the Python prompt, try to import SQLAlchemy-Utils,
47 | and check the installed version:
48 |
49 | .. parsed-literal::
50 |
51 | >>> import sqlalchemy_utils
52 | >>> sqlalchemy_utils.__version__
53 | |release|
54 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/email.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 |
3 | from ..operators import CaseInsensitiveComparator
4 |
5 |
class EmailType(sa.types.TypeDecorator):
    """
    Provides a way for storing emails in a lower case.

    Values are lowercased when bound to a statement, and comparisons use
    ``CaseInsensitiveComparator``.

    Example::


        from sqlalchemy_utils import EmailType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            email = sa.Column(EmailType)


        user = User()
        user.email = 'John.Smith@foo.com'
        user.name = 'John Smith'
        session.add(user)
        session.commit()
        # Notice - email in filter() is lowercase.
        user = (session.query(User)
                .filter(User.email == 'john.smith@foo.com')
                .one())
        assert user.name == 'John Smith'
    """

    impl = sa.Unicode
    comparator_factory = CaseInsensitiveComparator
    cache_ok = True

    def __init__(self, length=255, *args, **kwargs):
        # ``length`` (default 255) is forwarded to the wrapped Unicode type.
        super().__init__(length=length, *args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Lowercase *value* on its way to the database; None stays None."""
        if value is not None:
            return value.lower()
        return value

    @property
    def python_type(self):
        # ``self.impl`` is the wrapped Unicode type *instance*; ask it
        # directly. A plain TypeEngine has no ``type`` attribute, so going
        # through ``self.impl.type`` would raise AttributeError.
        return self.impl.python_type
50 |
--------------------------------------------------------------------------------
/tests/functions/test_has_changes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import has_changes
5 |
6 |
@pytest.fixture
def Article(Base):
    """Model with a single ``title`` column whose changes are tracked."""
    class Article(Base):
        __tablename__ = 'article_translation'
        id = sa.Column(sa.Integer, primary_key=True)
        title = sa.Column(sa.String(100))

    return Article
14 |
15 |
class TestHasChangesWithStringAttr:
    """has_changes called with a single attribute name."""

    def test_without_changed_attr(self, Article):
        assert not has_changes(Article(), 'title')

    def test_with_changed_attr(self, Article):
        assert has_changes(Article(title='Some title'), 'title')
24 |
25 |
class TestHasChangesWithMultipleAttrs:
    """has_changes called with a list of attribute names."""

    def test_without_changed_attr(self, Article):
        assert not has_changes(Article(), ['title'])

    def test_with_changed_attr(self, Article):
        # Only ``title`` changed, which is enough for a positive result.
        assert has_changes(Article(title='Some title'), ['title', 'id'])
34 |
35 |
class TestHasChangesWithExclude:
    """has_changes with the ``exclude`` keyword."""

    def test_without_changed_attr(self, Article):
        assert not has_changes(Article(), exclude=['id'])

    def test_with_changed_attr(self, Article):
        article = Article(title='Some title')
        assert has_changes(article, exclude=['id'])
        # Excluding the only changed attribute hides the change.
        assert not has_changes(article, exclude=['title'])
45 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/enriched_datetime/pendulum_datetime.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from ...exceptions import ImproperlyConfigured
4 |
5 | pendulum = None
6 | try:
7 | import pendulum
8 | except ImportError:
9 | pass
10 |
11 |
class PendulumDateTime:
    """Pendulum-backed datetime conversion helper.

    Requires the optional ``pendulum`` package; this module binds
    ``pendulum`` to None when the import fails.
    """

    def __init__(self):
        if not pendulum:
            raise ImproperlyConfigured(
                "'pendulum' package is required to use 'PendulumDateTime'"
            )

    def _coerce(self, impl, value):
        """Coerce *value* to a ``pendulum.DateTime``.

        None and existing pendulum datetimes are returned as-is. Numbers and
        all-digit strings go through ``pendulum.from_timestamp``. Stdlib
        datetimes are wrapped. Anything else is handed to ``pendulum.parse``.
        """
        if value is None or isinstance(value, pendulum.DateTime):
            return value
        if isinstance(value, (int, float)):
            return pendulum.from_timestamp(value)
        if isinstance(value, str) and value.isdigit():
            return pendulum.from_timestamp(int(value))
        if isinstance(value, datetime.datetime):
            return pendulum.DateTime.instance(value)
        return pendulum.parse(value)

    def process_bind_param(self, impl, value, dialect):
        """Convert *value* to a naive datetime expressed in UTC."""
        if not value:
            return value
        return self._coerce(impl, value).in_tz('UTC').naive()

    def process_result_value(self, impl, value, dialect):
        """Return the database value as a ``pendulum.DateTime`` in UTC.

        The tzinfo of *value* is replaced with UTC, not converted.
        """
        # NOTE(review): an aware value carrying a non-UTC offset would have
        # that offset overwritten here; confirm this matches how
        # process_bind_param stores values for timezone-aware columns.
        if not value:
            return value
        as_utc = value.replace(tzinfo=datetime.timezone.utc)
        return pendulum.DateTime.instance(as_utc)
44 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012, Konsta Vesterinen
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | * The names of the contributors may not be used to endorse or promote products
16 | derived from this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY DIRECT,
22 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
26 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
27 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/tests/types/test_country.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import Country, CountryType, i18n # noqa
5 |
6 |
@pytest.fixture
def User(Base):
    """Model with a single CountryType column."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        country = sa.Column(CountryType)

        def __repr__(self):
            return 'User(%r)' % self.id

    return User


@pytest.fixture
def init_models(User):
    """Depend on User; the body itself does nothing."""
22 |
23 |
@pytest.mark.skipif('i18n.babel is None')
class TestCountryType:
    """CountryType round trips, coercion and literal rendering."""

    def test_parameter_processing(self, session, User):
        created = User(country=Country('FI'))
        session.add(created)
        session.commit()

        stored = session.query(User).first()
        assert stored.country.name == 'Finland'

    def test_scalar_attributes_get_coerced_to_objects(self, User):
        assert isinstance(User(country='FI').country, Country)

    def test_literal_param(self, session, User):
        clause = User.country == 'FI'
        rendered = clause.compile(compile_kwargs={'literal_binds': True})
        assert str(rendered) == '"user".country = \'FI\''

    def test_compilation(self, User, session):
        # Executing the select shows the type is cacheable and compiles.
        session.execute(sa.select(User.country))
51 |
--------------------------------------------------------------------------------
/tests/types/test_ltree.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import Ltree, LtreeType
5 |
6 |
@pytest.fixture
def Section(Base):
    """Model with a single LtreeType ``path`` column."""
    class Section(Base):
        __tablename__ = 'section'
        id = sa.Column(sa.Integer, primary_key=True)
        path = sa.Column(LtreeType)

    return Section


@pytest.fixture
def init_models(Section, connection):
    """Ensure the PostgreSQL ltree extension exists; also depends on Section."""
    with connection.begin():
        connection.execute(sa.text('CREATE EXTENSION IF NOT EXISTS ltree'))
22 |
23 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestLTREE:
    """LtreeType tests, run against the PostgreSQL DSN fixture."""

    def test_saves_path(self, session, Section):
        created = Section(path='1.1.2')
        session.add(created)
        session.commit()

        stored = session.query(Section).first()
        assert stored.path == '1.1.2'

    def test_scalar_attributes_get_coerced_to_objects(self, Section):
        assert isinstance(Section(path='path.path').path, Ltree)

    def test_literal_param(self, session, Section):
        clause = Section.path == 'path'
        rendered = clause.compile(compile_kwargs={'literal_binds': True})
        assert str(rendered) == 'section.path = \'path\''

    def test_compilation(self, Section, session):
        # Executing the select shows the type is cacheable and compiles.
        session.execute(sa.select(Section.path))
48 |
--------------------------------------------------------------------------------
/tests/functions/test_get_primary_keys.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import get_primary_keys
5 |
6 | try:
7 | from collections import OrderedDict
8 | except ImportError:
9 | from ordereddict import OrderedDict
10 |
11 |
@pytest.fixture
def Building(Base):
    """Model whose attribute keys differ from its column names."""
    class Building(Base):
        __tablename__ = 'building'
        # Attribute ``id`` -> column ``_id``; attribute ``name`` -> ``_name``.
        id = sa.Column('_id', sa.Integer, primary_key=True)
        name = sa.Column('_name', sa.Unicode(255))

    return Building
19 |
20 |
class TestGetPrimaryKeys:
    """Tables are keyed by column name; mapped objects by attribute key."""

    def test_table(self, Building):
        table = Building.__table__
        assert get_primary_keys(table) == OrderedDict({'_id': table.c._id})

    def test_declarative_class(self, Building):
        expected = OrderedDict({'id': Building.__table__.c._id})
        assert get_primary_keys(Building) == expected

    def test_declarative_object(self, Building):
        expected = OrderedDict({'id': Building.__table__.c._id})
        assert get_primary_keys(Building()) == expected

    def test_class_alias(self, Building):
        alias = sa.orm.aliased(Building)
        expected = OrderedDict({'id': Building.__table__.c._id})
        assert get_primary_keys(alias) == expected

    def test_table_alias(self, Building):
        alias = sa.orm.aliased(Building.__table__)
        assert get_primary_keys(alias) == OrderedDict({'_id': alias.c._id})
49 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 |
def get_version(path=None) -> str:
    """Return the ``__version__`` string declared in the package source.

    The file is parsed as text rather than imported.

    :param path: file to read; defaults to ``sqlalchemy_utils/__init__.py``
        two directories above this configuration file.
    :returns: the version string with surrounding quotes/spaces removed.
    :raises RuntimeError: if no ``__version__ = ...`` assignment is found.
    """
    if path is None:
        path = pathlib.Path(__file__).parent.parent / "sqlalchemy_utils/__init__.py"
    content = pathlib.Path(path).read_text(encoding="utf-8")
    for line in content.splitlines():
        name, sep, value = line.partition("=")
        # Compare the whole target name (ignoring an optional annotation) so
        # names that merely start with "__version__", such as
        # "__version_info__", are not picked up.
        if sep and name.split(":", 1)[0].strip() == "__version__":
            return value.strip("'\" ")
    raise RuntimeError(f"Could not find __version__ in {path}")
11 |
12 |
# -- General configuration -----------------------------------------------------

# Sphinx extensions to enable; all of them ship with Sphinx (``sphinx.ext.*``).
extensions = [
    f"sphinx.ext.{extension}"
    for extension in (
        "autodoc",
        "doctest",
        "intersphinx",
        "todo",
        "coverage",
        "ifconfig",
        "viewcode",
    )
]

# Extra template directories, relative to this directory.
templates_path = ["_templates"]

# Document that holds the root toctree.
master_doc = "index"
32 |
# -- Project information -------------------------------------------------------
project = "SQLAlchemy-Utils"
copyright = "2013-2022, Konsta Vesterinen"

# The short version and the full release string are the same value, read
# from the package source by get_version().
release = version = get_version()

# Patterns (relative to the source directory) of files and directories to
# ignore when looking for source files.
exclude_patterns = ["_build"]

# HTML output theme.
html_theme = "sphinx_rtd_theme"
48 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/ip_address.py:
--------------------------------------------------------------------------------
1 | from ipaddress import ip_address
2 |
3 | from sqlalchemy import types
4 |
5 | from .scalar_coercible import ScalarCoercible
6 |
7 |
class IPAddressType(ScalarCoercible, types.TypeDecorator):
    """
    Changes IP address objects (as returned by ``ipaddress.ip_address``, i.e.
    ``IPv4Address`` / ``IPv6Address``) to a string representation on the way
    in and changes them back to IP address objects on the way out.

    ::


        from sqlalchemy_utils import IPAddressType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            ip_address = sa.Column(IPAddressType)


        user = User()
        user.ip_address = '123.123.123.123'
        session.add(user)
        session.commit()

        user.ip_address  # IPv4Address('123.123.123.123')

    :param max_length: length of the underlying ``Unicode`` column
        (defaults to 50).
    """

    impl = types.Unicode(50)
    cache_ok = True

    def __init__(self, max_length=50, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the class-level default impl with one of the requested length.
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Falsy values (None, '') are stored as NULL.
        return str(value) if value else None

    def process_result_value(self, value, dialect):
        return ip_address(value) if value else None

    def _coerce(self, value):
        # Values assigned to the attribute are converted with
        # ipaddress.ip_address, which raises ValueError for invalid input.
        return ip_address(value) if value else None

    @property
    def python_type(self):
        # ``self.impl`` is the ``Unicode`` type instance itself, so ask it
        # directly instead of going through a nonexistent ``.type`` attribute.
        return self.impl.python_type
53 |
--------------------------------------------------------------------------------
/tests/functions/test_get_type.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | import sqlalchemy.orm
4 |
5 | from sqlalchemy_utils import get_type
6 |
7 |
@pytest.fixture
def User(Base):
    """Minimal ``user`` model; ``Article.author`` points at it."""

    class User(Base):
        __tablename__ = 'user'

        id = sa.Column(sa.Integer, primary_key=True)

    return User
14 |
15 |
@pytest.fixture
def Article(Base, User):
    """Article with a relationship to ``User`` and a calculated column."""

    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)

        author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
        author = sa.orm.relationship(User)

        # SQL expression coalesce(id, 1) exposed as a mapped attribute.
        some_property = sa.orm.column_property(sa.func.coalesce(id, 1))

    return Article
29 |
30 |
class TestGetType:
    """``get_type`` resolves the SQL type, or the related class for
    relationships, of many kinds of ORM/SQL objects."""

    def test_instrumented_attribute(self, Article):
        resolved = get_type(Article.id)
        assert isinstance(resolved, sa.Integer)

    def test_column_property(self, Article):
        resolved = get_type(Article.id.property)
        assert isinstance(resolved, sa.Integer)

    def test_column(self, Article):
        resolved = get_type(Article.__table__.c.id)
        assert isinstance(resolved, sa.Integer)

    def test_calculated_column_property(self, Article):
        resolved = get_type(Article.some_property)
        assert isinstance(resolved, sa.Integer)

    def test_relationship_property(self, Article, User):
        related_class = get_type(Article.author)
        assert related_class == User

    def test_scalar_select(self, Article):
        subquery = sa.select(Article.id).scalar_subquery()
        resolved = get_type(subquery)
        assert isinstance(resolved, sa.Integer)
51 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/enriched_datetime/enriched_date_type.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from ..scalar_coercible import ScalarCoercible
4 | from .pendulum_date import PendulumDate
5 |
6 |
class EnrichedDateType(types.TypeDecorator, ScalarCoercible):
    """
    Supported for pendulum only.

    Conversion is delegated to a date processor object; by default this is
    ``PendulumDate``.

    Example::


        from sqlalchemy_utils import EnrichedDateType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            birthday = sa.Column(EnrichedDateType())


        user = User()
        user.birthday = pendulum.datetime(year=1995, month=7, day=11)
        session.add(user)
        session.commit()

    :param date_processor: class instantiated (with no arguments) to perform
        coercion and bind/result conversion. Remaining arguments are
        forwarded to ``TypeDecorator.__init__``.
    """

    impl = types.Date
    cache_ok = True

    def __init__(self, date_processor=PendulumDate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.date_object = date_processor()

    def _coerce(self, value):
        return self.date_object._coerce(self.impl, value)

    def process_bind_param(self, value, dialect):
        return self.date_object.process_bind_param(self.impl, value, dialect)

    def process_result_value(self, value, dialect):
        return self.date_object.process_result_value(self.impl, value, dialect)

    def process_literal_param(self, value, dialect):
        # Literal values are passed through unchanged.
        return value

    @property
    def python_type(self):
        # ``self.impl`` is the ``Date`` type instance; delegate to it directly
        # instead of going through a nonexistent ``.type`` attribute.
        return self.impl.python_type
52 |
--------------------------------------------------------------------------------
/tests/test_case_insensitive_comparator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import EmailType
5 |
6 |
@pytest.fixture
def User(Base):
    """``user`` model with an ``EmailType`` column.

    ``__repr__`` now reports the correct class name; it previously said
    ``Building`` (copy-paste error).
    """

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        email = sa.Column(EmailType)

        def __repr__(self):
            return 'User(%r)' % self.id

    return User
17 |
18 |
@pytest.fixture
def init_models(User):
    """Depend on the ``User`` fixture so that model is declared.

    NOTE(review): presumably consumed by the shared conftest; confirm.
    """
22 |
23 |
class TestCaseInsensitiveComparator:
    """Comparisons against ``EmailType`` columns wrap values in ``lower()``."""

    def test_supports_equals(self, session, User):
        query = session.query(User).filter(User.email == 'email@example.com')
        assert 'user.email = lower(?)' in str(query)

    def test_supports_in_(self, session, User):
        candidates = ['email@example.com', 'a']
        query = session.query(User).filter(User.email.in_(candidates))
        assert 'user.email IN (lower(?), lower(?))' in str(query)

    def test_supports_notin_(self, session, User):
        candidates = ['email@example.com', 'a']
        query = session.query(User).filter(User.email.notin_(candidates))
        assert 'user.email NOT IN (' in str(query)

    def test_does_not_apply_lower_to_types_that_are_already_lowercased(
        self,
        User
    ):
        expression = User.email == User.email
        assert str(expression) == '"user".email = "user".email'
61 |
--------------------------------------------------------------------------------
/tests/aggregate/test_with_ondelete_cascade.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.aggregates import aggregated
5 |
6 |
@pytest.fixture
def Thread(Base):
    """Thread with an aggregated ``comment_count`` over ``comments``.

    ``passive_deletes=True`` on the relationship leaves removal of child
    comments to the database.
    """

    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('comments', sa.Column(sa.Integer, default=0))
        def comment_count(self):
            return sa.func.count('1')

        comments = sa.orm.relationship(
            'Comment', backref='thread', passive_deletes=True
        )

    return Thread
24 |
25 |
@pytest.fixture
def Comment(Base):
    """Comment whose ``thread_id`` foreign key is declared ON DELETE CASCADE."""

    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        content = sa.Column(sa.Unicode(255))
        thread_id = sa.Column(
            sa.Integer, sa.ForeignKey('thread.id', ondelete='CASCADE')
        )

    return Comment
37 |
38 |
@pytest.fixture
def init_models(Thread, Comment):
    """Depend on the ``Thread`` and ``Comment`` fixtures so both models are
    declared.

    NOTE(review): presumably consumed by the shared conftest; confirm.
    """
42 |
43 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateValueGenerationWithCascadeDelete:
    """Deleting a thread whose comments are removed by the database cascade
    must not fail while aggregated values are being maintained."""

    def test_something(self, session, Thread, Comment):
        thread = Thread(name='some article name')
        session.add(thread)
        session.add(Comment(content='Some content', thread=thread))
        session.commit()

        # Reload state from the database before deleting the parent.
        session.expire_all()
        session.delete(thread)
        session.commit()
57 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/arrow.py:
--------------------------------------------------------------------------------
1 | from ..exceptions import ImproperlyConfigured
2 | from .enriched_datetime import ArrowDateTime
3 | from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType
4 |
5 | arrow = None
6 | try:
7 | import arrow
8 | except ImportError:
9 | pass
10 |
11 |
class ArrowType(EnrichedDateTimeType):
    """
    ArrowType provides way of saving Arrow_ objects into database. It
    automatically changes Arrow_ objects to datetime objects on the way in and
    datetime objects back to Arrow_ objects on the way out (when querying
    database). ArrowType needs Arrow_ library installed.

    .. _Arrow: https://github.com/arrow-py/arrow

    ::

        from sqlalchemy_utils import ArrowType
        import arrow


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            created_at = sa.Column(ArrowType)



        article = Article(created_at=arrow.utcnow())


    As you may expect all the arrow goodies come available:

    ::


        article.created_at = article.created_at.shift(hours=-1)

        article.created_at.humanize()
        # 'an hour ago'

    Positional and keyword arguments are forwarded to the underlying
    ``DateTime`` type via ``TypeDecorator``.

    :raises ImproperlyConfigured: if the ``arrow`` package is not installed.
    """

    cache_ok = True

    def __init__(self, *args, **kwargs):
        if arrow is None:
            raise ImproperlyConfigured("'arrow' package is required to use 'ArrowType'")

        # Pass the processor positionally: with the previous keyword form
        # any positional argument was bound to ``datetime_processor`` as
        # well and raised "got multiple values" TypeError.
        super().__init__(ArrowDateTime, *args, **kwargs)
58 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/enriched_datetime/enriched_datetime_type.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from ..scalar_coercible import ScalarCoercible
4 | from .pendulum_datetime import PendulumDateTime
5 |
6 |
class EnrichedDateTimeType(types.TypeDecorator, ScalarCoercible):
    """
    Supported for arrow and pendulum.

    Conversion is delegated to a datetime processor object; by default this
    is ``PendulumDateTime``.

    Example::


        from sqlalchemy_utils import EnrichedDateTimeType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            created_at = sa.Column(EnrichedDateTimeType())
            # For arrow, pass the arrow processor explicitly:
            # from sqlalchemy_utils.types.enriched_datetime import ArrowDateTime
            # created_at = sa.Column(
            #     EnrichedDateTimeType(datetime_processor=ArrowDateTime)
            # )


        user = User()
        user.created_at = pendulum.now()
        session.add(user)
        session.commit()

    :param datetime_processor: class instantiated (with no arguments) to
        perform coercion and bind/result conversion. Remaining arguments are
        forwarded to ``TypeDecorator.__init__``.
    """

    impl = types.DateTime
    cache_ok = True

    def __init__(self, datetime_processor=PendulumDateTime, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dt_object = datetime_processor()

    def _coerce(self, value):
        return self.dt_object._coerce(self.impl, value)

    def process_bind_param(self, value, dialect):
        return self.dt_object.process_bind_param(self.impl, value, dialect)

    def process_result_value(self, value, dialect):
        return self.dt_object.process_result_value(self.impl, value, dialect)

    def process_literal_param(self, value, dialect):
        # Literal values are passed through unchanged.
        return value

    @property
    def python_type(self):
        # ``self.impl`` is the ``DateTime`` type instance; delegate to it
        # directly instead of going through a nonexistent ``.type`` attribute.
        return self.impl.python_type
53 |
--------------------------------------------------------------------------------
/tests/types/test_currency.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import Currency, CurrencyType, i18n
5 |
6 |
@pytest.fixture
def set_get_locale():
    """Make ``i18n.get_locale`` return the English locale during a test.

    The previous ``get_locale`` is restored on teardown so the override does
    not leak into tests that run afterwards.
    """
    original_get_locale = i18n.get_locale
    i18n.get_locale = lambda: i18n.babel.Locale('en')
    yield
    i18n.get_locale = original_get_locale
10 |
11 |
@pytest.fixture
def User(Base):
    """``user`` model with a ``CurrencyType`` column."""

    class User(Base):
        __tablename__ = 'user'

        id = sa.Column(sa.Integer, primary_key=True)
        currency = sa.Column(CurrencyType)

        def __repr__(self):
            return f'User({self.id!r})'

    return User
23 |
24 |
@pytest.fixture
def init_models(User):
    """Depend on the ``User`` fixture so that model is declared.

    NOTE(review): presumably consumed by the shared conftest; confirm.
    """
28 |
29 |
@pytest.mark.skipif('i18n.babel is None')
class TestCurrencyType:
    """Round-tripping, coercion and compilation of ``CurrencyType`` values."""

    def test_parameter_processing(self, session, User, set_get_locale):
        session.add(User(currency=Currency('USD')))
        session.commit()

        fetched = session.query(User).first()
        assert fetched.currency.name == 'US Dollar'

    def test_scalar_attributes_get_coerced_to_objects(
        self,
        User,
        set_get_locale
    ):
        assert isinstance(User(currency='USD').currency, Currency)

    def test_literal_param(self, session, User):
        expression = User.currency == 'USD'
        rendered = expression.compile(compile_kwargs={'literal_binds': True})
        assert str(rendered) == '"user".currency = \'USD\''

    def test_compilation(self, User, session):
        # The type should be cacheable; executing must not raise.
        session.execute(sa.select(User.currency))
60 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/expressions.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 | from sqlalchemy.dialects import postgresql
3 | from sqlalchemy.ext.compiler import compiles
4 | from sqlalchemy.sql.expression import ColumnElement, FunctionElement
5 | from sqlalchemy.sql.functions import GenericFunction
6 |
7 | from .functions.orm import quote
8 |
9 |
class array_get(FunctionElement):
    """``array_get(array, index)``: element access with a zero-based index.

    Compiled by :func:`compile_array_get` to ``(array)[index + 1]``.
    """

    name = 'array_get'
    # NOTE: intentionally no ``inherit_cache = True``. The index value is
    # rendered inline into the SQL string, so a compiled-statement cache
    # keyed on structure alone would likely reuse the wrong index for
    # different values.


@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """Render ``array_get(arr, n)`` as ``(arr)[n + 1]``.

    :raises TypeError: if not exactly two arguments are given, or if the
        second argument is not a literal (non-bool) integer.
    """
    args = list(element.clauses)
    if len(args) != 2:
        raise TypeError(
            "Function 'array_get' expects two arguments (%d given)." % len(args)
        )

    index = getattr(args[1], 'value', None)
    # bool is a subclass of int but is not a meaningful array index.
    if not isinstance(index, int) or isinstance(index, bool):
        raise TypeError('Second argument should be an integer.')
    # Shift the zero-based index to SQL's one-based array subscripts.
    return f'({compiler.process(args[0])})[{index + 1}]'
25 |
26 |
class row_to_json(GenericFunction):
    """``row_to_json(...)`` with a ``postgresql.JSON`` result type."""

    name = 'row_to_json'
    type = postgresql.JSON


@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """Render as ``row_to_json(<arguments>)`` on PostgreSQL."""
    arguments = compiler.process(element.clauses)
    return '%s(%s)' % (element.name, arguments)
35 |
36 |
class json_array_length(GenericFunction):
    """``json_array_length(...)`` with an ``Integer`` result type."""

    name = 'json_array_length'
    type = sa.Integer


@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """Render as ``json_array_length(<arguments>)`` on PostgreSQL."""
    arguments = compiler.process(element.clauses)
    return '%s(%s)' % (element.name, arguments)
45 |
46 |
class Asterisk(ColumnElement):
    """Column element that renders as ``<selectable>.*``."""

    def __init__(self, selectable):
        self.selectable = selectable


@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """Render the dialect-quoted selectable name followed by ``.*``."""
    qualified = quote(compiler.dialect, element.selectable.name)
    return f'{qualified}.*'
55 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/url.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from .scalar_coercible import ScalarCoercible
4 |
5 | furl = None
6 | try:
7 | from furl import furl
8 | except ImportError:
9 | pass
10 |
11 |
class URLType(ScalarCoercible, types.TypeDecorator):
    """
    URLType stores furl_ objects into database.

    .. _furl: https://github.com/gruns/furl

    ::

        from sqlalchemy_utils import URLType
        from furl import furl


        class User(Base):
            __tablename__ = 'user'

            id = sa.Column(sa.Integer, primary_key=True)
            website = sa.Column(URLType)


        user = User(website='www.example.com')

        # website is coerced to furl object, hence all nice furl operations
        # come available
        user.website.args['some_argument'] = '12'

        print(user.website)
        # www.example.com?some_argument=12

    When furl is not installed, values are stored and returned as plain
    strings.
    """

    impl = types.UnicodeText
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if furl is not None and isinstance(value, furl):
            return str(value)

        if isinstance(value, str):
            return value
        # Any other value (including None) is bound as NULL.
        return None

    def process_result_value(self, value, dialect):
        if furl is None:
            return value

        if value is not None:
            return furl(value)
        return None

    def _coerce(self, value):
        if furl is None:
            return value

        if value is not None and not isinstance(value, furl):
            return furl(value)
        return value

    @property
    def python_type(self):
        # ``self.impl`` is the ``UnicodeText`` type instance; delegate to it
        # directly instead of going through a nonexistent ``.type`` attribute.
        return self.impl.python_type
69 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/country.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from ..primitives import Country
4 | from .scalar_coercible import ScalarCoercible
5 |
6 |
class CountryType(ScalarCoercible, types.TypeDecorator):
    """
    Changes :class:`.Country` objects to a string representation on the way in
    and changes them back to :class:`.Country` objects on the way out.

    In order to use CountryType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import CountryType, Country


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            country = sa.Column(CountryType)


        user = User()
        user.country = Country('FI')
        session.add(user)
        session.commit()

        user.country  # Country('FI')
        user.country.name  # Finland

        print(user.country)  # Finland


    CountryType is scalar coercible::


        user.country = 'US'
        user.country  # Country('US')

    The column stores the two-character country code.
    """

    impl = types.String(2)
    python_type = Country
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if isinstance(value, Country):
            return value.code

        if isinstance(value, str):
            return value
        # Any other value (including None) is bound as NULL.
        return None

    def process_result_value(self, value, dialect):
        if value is not None:
            return Country(value)
        return None

    def _coerce(self, value):
        if value is not None and not isinstance(value, Country):
            return Country(value)
        return value
66 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/primitives/weekdays.py:
--------------------------------------------------------------------------------
1 | from ..utils import str_coercible
2 | from .weekday import WeekDay
3 |
4 |
@str_coercible
class WeekDays:
    """A set of :class:`WeekDay` values.

    Built from a bit string (one ``'0'``/``'1'`` character per weekday,
    position ``i`` meaning ``WeekDay(i)``), from another :class:`WeekDays`
    instance, or from any iterable of :class:`WeekDay` objects.
    """

    def __init__(self, bit_string_or_week_days):
        source = bit_string_or_week_days
        if isinstance(source, WeekDays):
            # Shares (does not copy) the other instance's day set.
            self._days = source._days
        elif not isinstance(source, str):
            self._days = set(source)
        else:
            self._days = self._parse_bit_string(source)

    @staticmethod
    def _parse_bit_string(bits):
        # Validate the length first, then each character, collecting the days
        # marked with '1'.
        expected_length = WeekDay.NUM_WEEK_DAYS
        if len(bits) != expected_length:
            raise ValueError(
                f'Bit string must be {expected_length} characters long.'
            )
        selected = set()
        for position, flag in enumerate(bits):
            if flag not in '01':
                raise ValueError('Bit string may only contain zeroes and ones.')
            if flag == '1':
                selected.add(WeekDay(position))
        return selected

    def __eq__(self, other):
        if isinstance(other, WeekDays):
            return self._days == other._days
        if isinstance(other, str):
            return self.as_bit_string() == other
        return NotImplemented

    def __iter__(self):
        for day in sorted(self._days):
            yield day

    def __contains__(self, value):
        return value in self._days

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.as_bit_string())

    def __unicode__(self):
        return ', '.join(map(str, self))

    def as_bit_string(self):
        flags = []
        for index in range(WeekDay.NUM_WEEK_DAYS):
            flags.append('1' if WeekDay(index) in self._days else '0')
        return ''.join(flags)
53 |
--------------------------------------------------------------------------------
/tests/observes/test_o2o_o2o.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.observer import observes
5 |
6 |
@pytest.fixture
def Device(Base):
    """Device model; ``SalesInvoice`` copies its ``name``."""

    class Device(Base):
        __tablename__ = 'device'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String)

    return Device
14 |
15 |
@pytest.fixture
def Order(Base):
    """Order with a required foreign key to ``device``."""

    class Order(Base):
        __tablename__ = 'order'
        id = sa.Column(sa.Integer, primary_key=True)

        # DB column is named 'device'; no explicit type is given, so it
        # comes from the foreign key target.
        device_id = sa.Column('device', sa.ForeignKey('device.id'), nullable=False)
        device = sa.orm.relationship('Device', backref='orders')

    return Order
27 |
28 |
@pytest.fixture
def SalesInvoice(Base):
    """Invoice (one per order) whose observer writes ``order.device.name``
    into ``device_name``."""

    class SalesInvoice(Base):
        __tablename__ = 'sales_invoice'
        id = sa.Column(sa.Integer, primary_key=True)
        order_id = sa.Column('order', sa.ForeignKey('order.id'), nullable=False)
        order = sa.orm.relationship(
            'Order', backref=sa.orm.backref('invoice', uselist=False)
        )
        device_name = sa.Column(sa.String)

        @observes('order.device')
        def process_device(self, device):
            self.device_name = device.name

    return SalesInvoice
53 |
54 |
@pytest.fixture
def init_models(Device, Order, SalesInvoice):
    """Depend on the Device, Order and SalesInvoice fixtures so all three
    models are declared.

    NOTE(review): presumably consumed by the shared conftest; confirm.
    """
58 |
59 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToManyToOneToMany:
    """Observer behaviour along the ``order.device`` path."""

    def test_observable_root_obj_is_none(self, session, Device, Order):
        # The order has no invoice, i.e. the observing root object is absent;
        # flushing must still succeed.
        session.add(Order(device=Device(name='Something')))
        session.flush()
67 |
--------------------------------------------------------------------------------
/tests/test_instant_defaults_listener.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils.listeners import force_instant_defaults
7 |
8 | force_instant_defaults()
9 |
10 |
@pytest.fixture
def Article(Base):
    """Article with a scalar default, a callable default and a default on a
    column that is exposed through a Python property."""

    class Article(Base):
        __tablename__ = 'article'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255), default='Some article')
        created_at = sa.Column(sa.DateTime, default=datetime.now)
        _byline = sa.Column(sa.Unicode(255), default='Default byline')

        @property
        def byline(self):
            return self._byline

        @byline.setter
        def byline(self, value):
            self._byline = value

    return Article
29 |
30 |
@pytest.fixture
def Document(Base):
    """Document whose primary key is backed by a sequence."""

    class Document(Base):
        __tablename__ = 'document'

        id = sa.Column(sa.Integer, sa.Sequence('document_id_seq'), primary_key=True)
        title = sa.Column(sa.Unicode(255), default='Untitled')

    return Document
43 |
44 |
class TestInstantDefaultListener:
    """Column defaults are visible immediately after object construction."""

    def test_assigns_defaults_on_object_construction(self, Article):
        assert Article().name == 'Some article'

    def test_callables_as_defaults(self, Article):
        assert isinstance(Article().created_at, datetime)

    def test_override_default_with_setter_function(self, Article):
        assert Article(byline='provided byline').byline == 'provided byline'

    def test_handles_sequence_defaults(self, Document):
        document = Document()
        assert document.title == 'Untitled'
        # The sequence default is not applied at construction time.
        assert document.id is None
62 |
--------------------------------------------------------------------------------
/tests/functions/test_get_columns.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import get_columns
5 |
6 |
@pytest.fixture
def Building(Base):
    """Model whose attribute keys differ from its column names."""

    class Building(Base):
        __tablename__ = 'building'

        id = sa.Column('_id', sa.Integer, primary_key=True)
        name = sa.Column('_name', sa.Unicode(255))

    return Building
14 |
15 |
@pytest.fixture
def columns():
    """Column names of ``building`` in declaration order."""
    expected_names = ['_id', '_name']
    return expected_names
19 |
20 |
class TestGetColumns:
    """``get_columns`` accepts tables, attributes, column properties,
    columns, classes, instances, mappers and aliases."""

    def test_table(self, Building, columns):
        names = [column.name for column in get_columns(Building.__table__)]
        assert names == columns

    def test_instrumented_attribute(self, Building):
        expected = [Building.__table__.c._id]
        assert get_columns(Building.id) == expected

    def test_column_property(self, Building):
        expected = [Building.__table__.c._id]
        assert get_columns(Building.id.property) == expected

    def test_column(self, Building):
        column = Building.__table__.c._id
        assert get_columns(column) == [column]

    def test_declarative_class(self, Building, columns):
        names = [column.name for column in get_columns(Building).values()]
        assert names == columns

    def test_declarative_object(self, Building, columns):
        names = [column.name for column in get_columns(Building()).values()]
        assert names == columns

    def test_mapper(self, Building, columns):
        mapped = get_columns(Building.__mapper__)
        assert [column.name for column in mapped.values()] == columns

    def test_class_alias(self, Building, columns):
        aliased_columns = get_columns(sa.orm.aliased(Building))
        assert [column.name for column in aliased_columns.values()] == columns

    def test_table_alias(self, Building, columns):
        table_alias = sa.orm.aliased(Building.__table__)
        names = [column.name for column in get_columns(table_alias).values()]
        assert names == columns
57 |
--------------------------------------------------------------------------------
/tests/observes/test_dynamic_relationship.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.orm import dynamic_loader
4 |
5 | from sqlalchemy_utils.observer import observes
6 |
7 |
@pytest.fixture
def Director(Base):
    """Director whose ``movies`` collection is a ``dynamic_loader``."""

    class Director(Base):
        __tablename__ = 'director'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String)
        movies = dynamic_loader('Movie', back_populates='director')

    return Director
17 |
18 |
@pytest.fixture
def Movie(Base, Director):
    """Movie whose observer copies ``director.name`` into ``director_name``."""

    class Movie(Base):
        __tablename__ = 'movie'

        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String)
        director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id))
        director = sa.orm.relationship(Director, back_populates='movies')
        director_name = sa.Column(sa.String)

        @observes('director')
        def director_observer(self, director):
            self.director_name = director.name

    return Movie
34 |
35 |
@pytest.fixture
def init_models(Director, Movie):
    """Depend on the ``Director`` and ``Movie`` fixtures so both models are
    declared.

    NOTE(review): presumably consumed by the shared conftest; confirm.
    """
39 |
40 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForDynamicRelationship:
    """The observer fires whether the link is set from the movie side or
    from the director's dynamic collection."""

    def test_add_observed_object(self, session, Director, Movie):
        director = Director(name='Steven Spielberg')
        session.add(director)
        movie = Movie(name='Jaws', director=director)
        session.add(movie)
        session.commit()

        assert movie.director_name == 'Steven Spielberg'

    def test_add_observed_object_from_backref(self, session, Director, Movie):
        movie = Movie(name='Jaws')
        director = Director(name='Steven Spielberg', movies=[movie])
        session.add(director)
        session.add(movie)
        session.commit()

        assert movie.director_name == 'Steven Spielberg'
58 |
--------------------------------------------------------------------------------
/tests/types/encrypted/test_padding.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from sqlalchemy_utils.types.encrypted.padding import (
4 | InvalidPaddingError,
5 | PKCS5Padding
6 | )
7 |
8 |
class TestPkcs5Padding:
    """Round-trip and validation tests for ``PKCS5Padding`` with a block size
    of 8 bytes."""

    def setup_method(self):
        self.BLOCK_SIZE = 8
        self.padder = PKCS5Padding(self.BLOCK_SIZE)

    def test_various_lengths_roundtrip(self):
        # Covers empty input, partial blocks and exact block multiples.
        for number in range(0, 3 * self.BLOCK_SIZE):
            val = b'*' * number
            padded = self.padder.pad(val)
            unpadded = self.padder.unpad(padded)
            assert val == unpadded, 'Round trip error for length %d' % number

    def test_invalid_unpad(self):
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(None)
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(b'')
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(b'\01')
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad((b'*' * (self.BLOCK_SIZE - 1)) + b'\00')
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad((b'*' * self.BLOCK_SIZE) + b'\01')

    def test_pad_longer_than_block(self):
        # The pad byte (BLOCK_SIZE + 1) exceeds the block size even though
        # that many matching pad bytes are present. Use bytes, like every
        # other test here, so the pad-length check is what gets exercised.
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(
                b'x' * (self.BLOCK_SIZE - 1) +
                bytes([self.BLOCK_SIZE + 1]) * (self.BLOCK_SIZE + 1)
            )

    def test_incorrect_padding(self):
        # Hard-coded for blocksize of 8
        assert self.padder.unpad(b'1234\04\04\04\04') == b'1234'
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(b'1234\02\04\04\04')
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(b'1234\04\02\04\04')
        with pytest.raises(InvalidPaddingError):
            self.padder.unpad(b'1234\04\04\02\04')
49 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # SQLAlchemy-Utils Release Process
2 |
3 | This document outlines the process for creating a new release of SQLAlchemy-Utils.
4 |
5 | ## Prerequisites
6 |
7 | - Access to the GitHub repository with push permissions
8 | - PyPI account with API token for the project
9 | - `uv` package manager installed
10 | - `gh` CLI tool installed (optional, for GitHub releases)
11 |
12 | ## Release Process
13 |
14 | ### 1. Prepare the Release
15 |
16 | 1. **Update CHANGES.rst**
17 | - Move items from "Unreleased changes" section to a new version section
18 | - Follow the existing format: `X.Y.Z (YYYY-MM-DD)`
19 | - Add a new "Unreleased changes" section
20 |
21 | 2. **Update version in pyproject.toml**
22 | ```bash
23 | # Edit pyproject.toml and update the version field
24 | version = "X.Y.Z"
25 | ```
26 |
27 | 3. **Commit the changes**
28 | ```bash
29 | git add CHANGES.rst pyproject.toml
30 | git commit -m "Bump version to X.Y.Z"
31 | git push origin master
32 | ```
33 |
34 | ### 2. Create Git Tag
35 |
36 | ```bash
37 | git tag -a X.Y.Z -m "Release version X.Y.Z"
38 | git push origin X.Y.Z
39 | ```
40 |
41 | ### 3. Build and Publish to PyPI
42 |
43 | 1. **Build the package**
44 | ```bash
45 | uv build
46 | ```
47 |
48 | 2. **Publish to PyPI**
49 | ```bash
50 | uv publish --token pypi-your-api-token-here
51 | ```
52 |
53 | Alternatively, set the token as an environment variable, which keeps it out of your shell history:
54 | ```bash
55 | export UV_PUBLISH_TOKEN="pypi-your-api-token-here"
56 | uv publish
57 | ```
58 |
59 | ### 4. Create GitHub Release
60 |
61 | ```bash
62 | gh release create X.Y.Z --title "X.Y.Z" --notes-from-tag
63 | ```
64 |
65 | Or create the release manually on GitHub using the tag.
66 |
67 | ## Version Numbering
68 |
69 | This project follows [Semantic Versioning](https://semver.org/):
70 | - **MAJOR** version for incompatible API changes
71 | - **MINOR** version for backwards-compatible functionality additions
72 | - **PATCH** version for backwards-compatible bug fixes
73 |
--------------------------------------------------------------------------------
/tests/aggregate/test_search_vectors.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import aggregated, TSVectorType
5 |
6 |
def tsvector_reduce_concat(vectors):
    """Build an SQL expression that concatenates aggregated tsvectors.

    The vectors are collected with ``array_agg``, joined into one
    space-separated string and cast back to :class:`TSVectorType`.
    """
    collected = sa.func.array_agg(vectors)
    joined = sa.func.array_to_string(collected, ' ')
    # NOTE(review): COALESCE with a single argument returns that argument
    # unchanged; a fallback such as '' may have been intended — confirm.
    return sa.sql.expression.cast(sa.func.coalesce(joined), TSVectorType)
14 |
15 |
@pytest.fixture
def Product(Base):
    """Declare the ``Product`` model on the test ``Base``."""
    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        # FK to catalog.id; the ``catalog`` attribute itself comes from the
        # backref declared on Catalog.products in this module.
        catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
    return Product
26 |
27 |
@pytest.fixture
def Catalog(Base, Product):
    """Declare ``Catalog`` with an aggregated product search vector.

    Depends on the ``Product`` fixture because the aggregate expression
    references ``Product.name`` directly.
    """
    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Aggregated over the ``products`` relationship into a TSVectorType
        # column: the concatenated ``to_tsvector`` of every product name.
        @aggregated('products', sa.Column(TSVectorType))
        def product_search_vector(self):
            return tsvector_reduce_concat(
                sa.func.to_tsvector(Product.name)
            )

        products = sa.orm.relationship('Product', backref='catalog')
    return Catalog
43 |
44 |
@pytest.fixture
def init_models(Product, Catalog):
    """Request the model fixtures so both classes are declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
48 |
49 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestSearchVectorAggregates:
    """Aggregated tsvector columns (run against PostgreSQL)."""

    def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
        # Persist the catalog first so the product arrives in a later flush.
        catalog = Catalog(name='Some catalog')
        session.add(catalog)
        session.commit()

        session.add(Product(name='Product XYZ', catalog=catalog))
        session.commit()

        session.refresh(catalog)
        assert catalog.product_search_vector == "'product':1 'xyz':2"
67 |
--------------------------------------------------------------------------------
/tests/types/test_weekdays.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import i18n
5 | from sqlalchemy_utils.primitives import WeekDays
6 | from sqlalchemy_utils.types import WeekDaysType
7 |
8 |
@pytest.fixture
def Schedule(Base):
    """Declare a ``Schedule`` model with a ``WeekDaysType`` column."""
    class Schedule(Base):
        __tablename__ = 'schedule'
        id = sa.Column(sa.Integer, primary_key=True)
        working_days = sa.Column(WeekDaysType)

        def __repr__(self):
            return 'Schedule(%r)' % self.id
    return Schedule
19 |
20 |
@pytest.fixture
def init_models(Schedule):
    """Request the ``Schedule`` fixture so the model is declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
24 |
25 |
@pytest.fixture
def set_get_locale(monkeypatch):
    """Force ``i18n.get_locale`` to return an English locale for one test.

    ``monkeypatch`` restores the original ``i18n.get_locale`` at teardown,
    so the override does not leak into tests that never requested it.
    """
    monkeypatch.setattr(i18n, 'get_locale', lambda: i18n.babel.Locale('en'))
29 |
30 |
@pytest.mark.usefixtures('set_get_locale')
@pytest.mark.skipif('i18n.babel is None')
class WeekDaysTypeTestCase:
    """Backend-independent tests for :class:`WeekDaysType`.

    The class name lacks the ``Test`` prefix, so under pytest's default
    ``python_classes`` setting it runs only through its per-backend
    subclasses.
    """

    def test_parameter_processing(self, session, Schedule):
        schedule = Schedule(
            working_days=b'0001111'
        )
        session.add(schedule)
        session.commit()

        # The stored value must come back as a WeekDays object.
        schedule = session.query(Schedule).first()
        assert isinstance(schedule.working_days, WeekDays)

    def test_scalar_attributes_get_coerced_to_objects(self, Schedule):
        schedule = Schedule(working_days=b'1010101')

        assert isinstance(schedule.working_days, WeekDays)

    def test_compilation(self, Schedule, session):
        query = sa.select(Schedule.working_days)
        # the type should be cacheable and not throw exception
        session.execute(query)
53 |
54 |
@pytest.mark.usefixtures('sqlite_memory_dsn')
class TestWeekDaysTypeOnSQLite(WeekDaysTypeTestCase):
    """Run the shared WeekDaysType tests with the ``sqlite_memory_dsn`` fixture."""
58 |
59 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestWeekDaysTypeOnPostgres(WeekDaysTypeTestCase):
    """Run the shared WeekDaysType tests with the ``postgresql_dsn`` fixture."""
63 |
64 |
@pytest.mark.usefixtures('mysql_dsn')
class TestWeekDaysTypeOnMySQL(WeekDaysTypeTestCase):
    """Run the shared WeekDaysType tests with the ``mysql_dsn`` fixture."""
68 |
--------------------------------------------------------------------------------
/tests/primitives/test_currency.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from sqlalchemy_utils import Currency, i18n
4 |
5 |
@pytest.fixture
def set_get_locale(monkeypatch):
    """Force ``i18n.get_locale`` to return an English locale for one test.

    ``monkeypatch`` restores the original ``i18n.get_locale`` at teardown,
    so the override does not leak into tests that never requested it.
    """
    monkeypatch.setattr(i18n, 'get_locale', lambda: i18n.babel.Locale('en'))
9 |
10 |
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCurrency:
    """Behaviour of the :class:`Currency` primitive (requires Babel)."""

    def test_init(self):
        # Wrapping an existing Currency must yield an equal object.
        expected = Currency('USD')
        assert expected == Currency(Currency('USD'))

    def test_hashability(self):
        # Equal currencies collapse to a single set member.
        currencies = {Currency('USD'), Currency('USD')}
        assert len(currencies) == 1

    def test_invalid_currency_code(self):
        with pytest.raises(ValueError):
            Currency('Unknown code')

    def test_invalid_currency_code_type(self):
        with pytest.raises(TypeError):
            Currency(None)

    @pytest.mark.parametrize(
        'code, name',
        [
            ('USD', 'US Dollar'),
            ('EUR', 'Euro'),
        ],
    )
    def test_name_property(self, code, name):
        assert Currency(code).name == name

    @pytest.mark.parametrize(
        'code, symbol',
        [
            ('USD', '$'),
            ('EUR', '€'),
        ],
    )
    def test_symbol_property(self, code, symbol):
        assert Currency(code).symbol == symbol

    def test_equality_operator(self):
        # Equality must hold against plain codes from either side.
        assert Currency('USD') == 'USD'
        assert 'USD' == Currency('USD')
        assert Currency('USD') == Currency('USD')

    def test_non_equality_operator(self):
        usd = Currency('USD')
        assert usd != 'EUR'
        assert not (usd != 'USD')

    def test_unicode(self):
        # Same check as test_str; kept for historical parity.
        assert str(Currency('USD')) == 'USD'

    def test_str(self):
        assert str(Currency('USD')) == 'USD'

    def test_representation(self):
        assert repr(Currency('USD')) == "Currency('USD')"
69 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/locale.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from ..exceptions import ImproperlyConfigured
4 | from .scalar_coercible import ScalarCoercible
5 |
6 | babel = None
7 | try:
8 | import babel
9 | except ImportError:
10 | pass
11 |
12 |
class LocaleType(ScalarCoercible, types.TypeDecorator):
    """
    LocaleType saves Babel_ Locale objects into database. The Locale objects
    are converted to string on the way in and back to object on the way out.

    In order to use LocaleType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import LocaleType
        from babel import Locale


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(50))
            locale = sa.Column(LocaleType)


        user = User()
        user.locale = Locale('en_US')
        session.add(user)
        session.commit()


    Like many other types this type also supports scalar coercion:

    ::


        user.locale = 'de_DE'
        user.locale  # Locale('de', territory='DE')

    """

    # Locale identifiers (e.g. 'en_US') are stored as short unicode strings.
    impl = types.Unicode(10)

    cache_ok = True

    def __init__(self):
        """Create the type.

        :raises ImproperlyConfigured: if Babel is not installed.
        """
        if babel is None:
            raise ImproperlyConfigured('Babel package is required with LocaleType.')
        # Let TypeDecorator perform its own initialisation (it sets up
        # ``self.impl``), as CurrencyType in this package also does.
        super().__init__()

    def process_bind_param(self, value, dialect):
        """Convert a Locale to its string form; pass strings through.

        Any other value (including ``None``) is bound as ``None``.
        """
        if isinstance(value, babel.Locale):
            return str(value)

        if isinstance(value, str):
            return value

    def process_result_value(self, value, dialect):
        """Parse a stored identifier back into a ``babel.Locale``."""
        if value is not None:
            return babel.Locale.parse(value)

    def _coerce(self, value):
        """Parse non-Locale values assigned to the attribute; ``None`` passes."""
        if value is not None and not isinstance(value, babel.Locale):
            return babel.Locale.parse(value)
        return value
75 |
--------------------------------------------------------------------------------
/tests/types/test_uuid.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 | from sqlalchemy.dialects import postgresql
6 |
7 | from sqlalchemy_utils import UUIDType
8 |
9 |
@pytest.fixture
def User(Base):
    """Declare a ``User`` model whose primary key is a ``UUIDType``.

    New rows default to a random ``uuid.uuid4()`` identifier.
    """
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True)

        def __repr__(self):
            return 'User(%r)' % self.id
    return User
19 |
20 |
@pytest.fixture
def init_models(User):
    """Request the ``User`` fixture so the model is declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
24 |
25 |
class TestUUIDType:
    """Round-tripping, coercion and compilation of :class:`UUIDType`."""

    def test_repr(self):
        # Only non-default options should appear in the repr.
        cases = [
            ({}, 'UUIDType()'),
            ({'binary': False}, 'UUIDType(binary=False)'),
            ({'native': False}, 'UUIDType(native=False)'),
        ]
        for options, expected in cases:
            assert repr(UUIDType(**options)) == expected

    def test_commit(self, session, User):
        user = User()
        user.id = uuid.uuid4().hex

        session.add(user)
        session.commit()

        fetched = session.query(User).one()
        assert fetched.id == user.id

    def test_coerce(self, User):
        user = User()
        # Both hex strings and raw bytes must be coerced into uuid.UUID.
        for representation in ('hex', 'bytes'):
            identifier = getattr(uuid.uuid4(), representation)
            user.id = identifier
            assert isinstance(user.id, uuid.UUID)
            assert getattr(user.id, representation) == identifier

    def test_compilation(self, User, session):
        # the type should be cacheable and not throw exception
        session.execute(sa.select(User.id))

    def test_literal_bind(self, User):
        clause = User.id == 'b4e794d6-5750-4844-958c-fa382649719d'
        compiled = clause.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={'literal_binds': True}
        )
        assert str(compiled) == (
            '''"user".id = \'b4e794d6-5750-4844-958c-fa382649719d\''''
        )
74 |
--------------------------------------------------------------------------------
/tests/test_expressions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.dialects import postgresql
4 | from sqlalchemy.orm import declarative_base
5 |
6 | from sqlalchemy_utils import Asterisk, row_to_json
7 |
8 |
@pytest.fixture
def assert_startswith(session):
    """Return a helper checking a query's PostgreSQL SQL prefix.

    The helper also executes the query to make sure it actually runs.
    """
    def _assert_startswith(query, query_part):
        sql = str(query.compile(dialect=postgresql.dialect()))
        assert sql.startswith(query_part)
        # Check that query executes properly
        session.execute(query)

    return _assert_startswith
17 | return assert_startswith
18 |
19 |
@pytest.fixture
def Article(Base):
    """Declare a simple ``Article`` model on the shared test ``Base``."""
    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        content = sa.Column(sa.UnicodeText)
    return Article
27 | return Article
28 |
29 |
class TestAsterisk:
    """Compilation of the :class:`Asterisk` expression."""

    def test_with_table_object(self):
        Base = declarative_base()

        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)

        expression = Asterisk(Article.__table__)
        assert str(expression) == 'article.*'

    def test_with_quoted_identifier(self):
        Base = declarative_base()

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)

        # The PostgreSQL dialect is expected to quote the table name here.
        compiled = Asterisk(User.__table__).compile(
            dialect=postgresql.dialect()
        )
        assert str(compiled) == '"user".*'
50 |
51 |
class TestRowToJson:
    """Compilation and typing of ``row_to_json``."""

    def test_compiler_with_default_dialect(self):
        expression = row_to_json(sa.text('article.*'))
        assert str(expression) == 'row_to_json(article.*)'

    def test_compiler_with_postgresql(self):
        compiled = row_to_json(sa.text('article.*')).compile(
            dialect=postgresql.dialect()
        )
        assert str(compiled) == 'row_to_json(article.*)'

    def test_type(self):
        # The generic ``sa.func.row_to_json`` must report a PostgreSQL JSON type.
        result_type = sa.func.row_to_json(sa.text('article.*')).type
        assert isinstance(result_type, postgresql.JSON)
67 | )
68 |
--------------------------------------------------------------------------------
/tests/types/test_locale.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.types import locale
5 |
6 |
@pytest.fixture
def User(Base):
    """Declare a ``User`` model with a ``LocaleType`` column."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        locale = sa.Column(locale.LocaleType)

        def __repr__(self):
            return 'User(%r)' % self.id
    return User
17 |
18 |
@pytest.fixture
def init_models(User):
    """Request the ``User`` fixture so the model is declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
22 |
23 |
@pytest.mark.skipif('locale.babel is None')
class TestLocaleType:
    """Round-tripping and coercion for :class:`LocaleType` (requires Babel)."""

    def test_parameter_processing(self, session, User):
        user = User(
            locale=locale.babel.Locale('fi')
        )

        session.add(user)
        session.commit()

        user = session.query(User).first()
        # The stored identifier must load back as an equal Locale object.
        assert user.locale == locale.babel.Locale('fi')

    def test_territory_parsing(self, session, User):
        ko_kr = locale.babel.Locale('ko', territory='KR')
        user = User(locale=ko_kr)
        session.add(user)
        session.commit()

        assert session.query(User.locale).first()[0] == ko_kr

    def test_coerce_territory_parsing(self, User):
        user = User()
        user.locale = 'ko_KR'
        assert user.locale == locale.babel.Locale('ko', territory='KR')

    def test_scalar_attributes_get_coerced_to_objects(self, User):
        user = User(locale='en_US')

        assert isinstance(user.locale, locale.babel.Locale)

    def test_unknown_locale_throws_exception(self, User):
        with pytest.raises(locale.babel.UnknownLocaleError):
            User(locale='unknown')

    def test_literal_param(self, session, User):
        clause = User.locale == 'en_US'
        compiled = str(clause.compile(compile_kwargs={'literal_binds': True}))
        assert compiled == '"user".locale = \'en_US\''

    def test_compilation(self, User, session):
        query = sa.select(User.locale)

        # the type should be cacheable and not throw exception
        session.execute(query)
68 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/operators.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 |
3 |
def inspect_type(mixed):
    """Return the SQL type behind *mixed*, or ``None`` if it is unsupported.

    Accepts an ORM instrumented attribute, a ``ColumnProperty`` or a Core
    ``Column``. For the ORM forms, the type of the first column is used.
    """
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        columns = mixed.property.columns
    elif isinstance(mixed, sa.orm.ColumnProperty):
        columns = mixed.columns
    elif isinstance(mixed, sa.Column):
        return mixed.type
    else:
        return None
    return columns[0].type
11 |
12 |
def is_case_insensitive(mixed):
    """Tell whether *mixed* compares through a CaseInsensitiveComparator.

    The type's ``comparator`` instance is checked first, then its
    ``comparator_factory`` class. Anything lacking either attribute
    (including values ``inspect_type`` cannot resolve) yields ``False``.
    """
    try:
        return isinstance(
            inspect_type(mixed).comparator, CaseInsensitiveComparator
        )
    except AttributeError:
        pass

    try:
        return issubclass(
            inspect_type(mixed).comparator_factory, CaseInsensitiveComparator
        )
    except AttributeError:
        return False
23 |
24 |
class CaseInsensitiveComparator(sa.Unicode.Comparator):
    """Unicode comparator that wraps the right-hand operand in ``lower()``.

    ``None`` operands and operands that are themselves case-insensitive
    (see :func:`is_case_insensitive`) are passed through unchanged.
    """

    @classmethod
    def lowercase_arg(cls, func):
        """Build a method delegating to ``sa.Unicode.Comparator.<func>``.

        Extra positional and keyword arguments (e.g. ``escape`` for
        ``like``) are forwarded unchanged.
        """
        def operation(self, other, *args, **kwargs):
            operator = getattr(sa.Unicode.Comparator, func)
            if other is None:
                return operator(self, other, *args, **kwargs)
            if not is_case_insensitive(other):
                other = sa.func.lower(other)
            return operator(self, other, *args, **kwargs)

        return operation

    def in_(self, other):
        """``IN`` with every literal list/tuple element lower-cased."""
        if isinstance(other, (list, tuple)):
            other = [sa.func.lower(item) for item in other]
        return sa.Unicode.Comparator.in_(self, other)

    def notin_(self, other):
        """``NOT IN`` with every literal list/tuple element lower-cased."""
        if isinstance(other, (list, tuple)):
            other = [sa.func.lower(item) for item in other]
        return sa.Unicode.Comparator.notin_(self, other)

    # Since SQLAlchemy 1.4, ``not_in`` is the primary method and ``notin_``
    # merely an alias to it, so overriding ``notin_`` alone does not affect
    # ``column.not_in(...)``; route the modern name through the override too.
    not_in = notin_


string_operator_funcs = [
    '__eq__',
    '__ne__',
    '__lt__',
    '__le__',
    '__gt__',
    '__ge__',
    'concat',
    'contains',
    'ilike',
    'like',
    'notlike',
    'notilike',
    # SQLAlchemy 1.4+ spellings. The legacy names above are aliases bound to
    # these functions on the base class, so each name must be wrapped.
    'not_like',
    'not_ilike',
    'startswith',
    'endswith',
]

for func in string_operator_funcs:
    setattr(
        CaseInsensitiveComparator, func, CaseInsensitiveComparator.lowercase_arg(func)
    )
70 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 |
3 | from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
4 |
5 | from .arrow import ArrowType # noqa
6 | from .choice import Choice, ChoiceType # noqa
7 | from .color import ColorType # noqa
8 | from .country import CountryType # noqa
9 | from .currency import CurrencyType # noqa
10 | from .email import EmailType # noqa
11 | from .encrypted.encrypted_type import ( # noqa
12 | EncryptedType,
13 | StringEncryptedType,
14 | )
15 | from .enriched_datetime.enriched_date_type import EnrichedDateType # noqa
16 | from .ip_address import IPAddressType # noqa
17 | from .json import JSONType # noqa
18 | from .locale import LocaleType # noqa
19 | from .ltree import LtreeType # noqa
20 | from .password import Password, PasswordType # noqa
21 | from .pg_composite import ( # noqa
22 | CompositeType,
23 | register_composites,
24 | remove_composite_listeners,
25 | )
26 | from .phone_number import ( # noqa
27 | PhoneNumber,
28 | PhoneNumberParseException,
29 | PhoneNumberType,
30 | )
31 | from .range import ( # noqa
32 | DateRangeType,
33 | DateTimeRangeType,
34 | Int8RangeType,
35 | IntRangeType,
36 | NumericRangeType,
37 | )
38 | from .scalar_list import ScalarListException, ScalarListType # noqa
39 | from .timezone import TimezoneType # noqa
40 | from .ts_vector import TSVectorType # noqa
41 | from .url import URLType # noqa
42 | from .uuid import UUIDType # noqa
43 | from .weekdays import WeekDaysType # noqa
44 |
45 | from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType # noqa isort:skip
46 |
47 |
class InstrumentedList(_InstrumentedList):
    """SQLAlchemy ``InstrumentedList`` with attribute-wide ``any``/``all``.

    Both helpers read *attr* from each item and combine the values with
    the builtin ``any``/``all``.
    """

    def any(self, attr):
        """Return True if ``getattr(item, attr)`` is truthy for some item."""
        values = (getattr(item, attr) for item in self)
        return any(values)

    def all(self, attr):
        """Return True if ``getattr(item, attr)`` is truthy for every item."""
        values = (getattr(item, attr) for item in self)
        return all(values)
57 |
58 |
def instrumented_list(f):
    """Decorate *f* so its returned iterable becomes an :class:`InstrumentedList`.

    The wrapper preserves *f*'s metadata via ``functools.wraps``.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        # The list constructor consumes any iterable directly, so the former
        # intermediate ``[item for item in ...]`` copy was redundant.
        return InstrumentedList(f(*args, **kwargs))

    return wrapper
65 |
--------------------------------------------------------------------------------
/tests/types/test_enriched_date_pendulum.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils.types.enriched_datetime import (
7 | enriched_date_type,
8 | pendulum_date
9 | )
10 |
11 |
@pytest.fixture
def User(Base):
    """Declare a ``User`` model with a Pendulum-backed ``birthday`` column."""
    class User(Base):
        __tablename__ = 'users'
        id = sa.Column(sa.Integer, primary_key=True)
        # EnrichedDateType delegates value handling to the given processor.
        birthday = sa.Column(
            enriched_date_type.EnrichedDateType(
                date_processor=pendulum_date.PendulumDate
            )
        )
    return User
23 |
24 |
@pytest.fixture
def init_models(User):
    """Request the ``User`` fixture so the model is declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
28 |
29 |
@pytest.mark.skipif('pendulum_date.pendulum is None')
class TestPendulumDateType:
    """EnrichedDateType used with the Pendulum date processor."""

    def test_parameter_processing(self, session, User):
        birthday = pendulum_date.pendulum.date(1995, 7, 11)
        session.add(User(birthday=birthday))
        session.commit()

        loaded = session.query(User).first()
        assert isinstance(loaded.birthday, date)

    def test_int_coercion(self, User):
        # A Unix timestamp (as int) is coerced into a date.
        assert User(birthday=1367900664).birthday.year == 2013

    def test_string_coercion(self, User):
        # The same timestamp given as a string is coerced as well.
        assert User(birthday='1367900664').birthday.year == 2013

    def test_utc(self, session, User):
        now = pendulum_date.pendulum.now("UTC")
        user = User(birthday=now)
        session.add(user)
        assert user.birthday == now
        session.commit()
        # After the round trip only the date part is kept.
        assert user.birthday == now.date()

    def test_literal_param(self, session, User):
        cutoff = date(2015, 1, 1)
        compiled = (User.birthday > cutoff).compile(
            compile_kwargs={"literal_binds": True}
        )
        assert str(compiled) == "users.birthday > '2015-01-01'"

    def test_compilation(self, User, session):
        # the type should be cacheable and not throw exception
        session.execute(sa.select(User.birthday))
72 |
--------------------------------------------------------------------------------
/tests/types/test_color.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from flexmock import flexmock
4 |
5 | from sqlalchemy_utils import ColorType, types # noqa
6 |
7 |
@pytest.fixture
def Document(Base):
    """Declare a ``Document`` model with a ``ColorType`` background colour."""
    class Document(Base):
        __tablename__ = 'document'
        id = sa.Column(sa.Integer, primary_key=True)
        bg_color = sa.Column(ColorType)

        def __repr__(self):
            return 'Document(%r)' % self.id
    return Document
18 |
19 |
@pytest.fixture
def init_models(Document):
    """Request the ``Document`` fixture so the model is declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
23 |
24 |
@pytest.mark.skipif('types.color.python_colour_type is None')
class TestColorType:
    """Storage and coercion for :class:`ColorType` (requires ``colour``)."""

    def test_string_parameter_processing(self, session, Document):
        from colour import Color

        # Stub coercion so the attribute keeps the raw string 'white'.
        flexmock(ColorType).should_receive('_coerce').and_return('white')
        session.add(Document(bg_color='white'))
        session.commit()

        stored = session.query(Document).first()
        assert stored.bg_color.hex == Color('white').hex

    def test_color_parameter_processing(self, session, Document):
        from colour import Color

        session.add(Document(bg_color=Color('white')))
        session.commit()

        stored = session.query(Document).first()
        assert stored.bg_color.hex == Color('white').hex

    def test_scalar_attributes_get_coerced_to_objects(self, Document):
        from colour import Color

        assert isinstance(Document(bg_color='white').bg_color, Color)

    def test_literal_param(self, session, Document):
        compiled = (Document.bg_color == 'white').compile(
            compile_kwargs={'literal_binds': True}
        )
        assert str(compiled) == "document.bg_color = 'white'"

    def test_compilation(self, Document, session):
        # the type should be cacheable and not throw exception
        session.execute(sa.select(Document.bg_color))
68 |
--------------------------------------------------------------------------------
/tests/aggregate/test_with_column_alias.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.aggregates import aggregated
5 |
6 |
@pytest.fixture
def Thread(Base):
    """Declare a ``Thread`` model whose comment count is an aggregate.

    The aggregate is stored in a column explicitly named
    ``_comment_count``, while the tests read it as ``comment_count``.
    """
    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)

        # Aggregated over ``comments`` with COUNT; the column defaults to 0.
        @aggregated(
            'comments',
            sa.Column('_comment_count', sa.Integer, default=0)
        )
        def comment_count(self):
            return sa.func.count('1')

        comments = sa.orm.relationship('Comment', backref='thread')
    return Thread
22 |
23 |
@pytest.fixture
def Comment(Base):
    """Declare a ``Comment`` model linked to ``thread.id``.

    The ``thread`` attribute comes from the backref on Thread.comments.
    """
    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
    return Comment
31 |
32 |
@pytest.fixture
def init_models(Thread, Comment):
    """Request the model fixtures so both classes are declared on ``Base``.

    NOTE(review): presumably a shared conftest fixture depends on
    ``init_models`` before creating tables — confirm in conftest.py.
    """
    pass
36 |
37 |
class TestAggregatedWithColumnAlias:
    """``aggregated`` must keep working when the column has an explicit name."""

    def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
        # Thread and comment are committed together.
        thread = Thread()
        session.add(thread)
        session.add(Comment(thread=thread))
        session.commit()

        session.refresh(thread)
        assert thread.comment_count == 1

    def test_assigns_aggregates_on_separate_insert(self, session, Thread, Comment):
        # The thread is committed before its comment exists.
        thread = Thread()
        session.add(thread)
        session.commit()

        session.add(Comment(thread=thread))
        session.commit()

        session.refresh(thread)
        assert thread.comment_count == 1

    def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
        thread = Thread()
        session.add(thread)
        session.commit()

        comment = Comment(thread=thread)
        session.add(comment)
        session.commit()

        # Removing the only comment must bring the count back to zero.
        session.delete(comment)
        session.commit()

        session.refresh(thread)
        assert thread.comment_count == 0
75 |
--------------------------------------------------------------------------------
/tests/functions/test_non_indexed_foreign_keys.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils.functions import non_indexed_foreign_keys
7 |
8 |
class TestFindNonIndexedForeignKeys:
    """``non_indexed_foreign_keys`` must report FK columns lacking an index."""

    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
        return User

    @pytest.fixture
    def Category(self, Base):
        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
        return Category

    @pytest.fixture
    def Article(self, Base, User, Category):
        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            # Indexed FK: must NOT be reported.
            author_id = sa.Column(
                sa.Integer, sa.ForeignKey(User.id), index=True
            )
            # Non-indexed FK: must be reported.
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

            category = sa.orm.relationship(
                Category,
                primaryjoin=category_id == Category.id,
                backref=sa.orm.backref(
                    'articles',
                )
            )
        return Article

    @pytest.fixture
    def init_models(self, User, Category, Article):
        pass

    def test_finds_all_non_indexed_fks(self, session, Base, engine):
        # NOTE(review): ``session`` is unused in the body; presumably it is
        # requested for its setup side effects (table creation) — confirm.
        fks = non_indexed_foreign_keys(Base.metadata, engine)
        assert 'article' in fks

        # Flatten the column names of every reported constraint on article.
        column_names = list(chain.from_iterable(
            fk.columns.keys() for fk in fks['article']
        ))
        assert 'category_id' in column_names
        assert 'author_id' not in column_names
67 |
--------------------------------------------------------------------------------
/docs/orm_helpers.rst:
--------------------------------------------------------------------------------
1 | ORM helpers
2 | ===========
3 |
4 |
5 | cast_if
6 | -------
7 |
8 | .. autofunction:: sqlalchemy_utils.functions.cast_if
9 |
10 |
11 | escape_like
12 | -----------
13 |
14 | .. autofunction:: sqlalchemy_utils.functions.escape_like
15 |
16 |
17 | get_bind
18 | --------
19 |
20 | .. autofunction:: sqlalchemy_utils.functions.get_bind
21 |
22 |
23 | get_class_by_table
24 | ------------------
25 |
26 | .. autofunction:: sqlalchemy_utils.functions.get_class_by_table
27 |
28 |
29 | get_column_key
30 | --------------
31 |
32 | .. autofunction:: sqlalchemy_utils.functions.get_column_key
33 |
34 |
35 | get_columns
36 | -----------
37 |
38 | .. autofunction:: sqlalchemy_utils.functions.get_columns
39 |
40 |
41 | get_declarative_base
42 | --------------------
43 |
44 | .. autofunction:: sqlalchemy_utils.functions.get_declarative_base
45 |
46 |
47 | get_hybrid_properties
48 | ---------------------
49 |
50 | .. autofunction:: sqlalchemy_utils.functions.get_hybrid_properties
51 |
52 |
53 | get_mapper
54 | ----------
55 |
56 | .. autofunction:: sqlalchemy_utils.functions.get_mapper
57 |
58 |
59 | get_primary_keys
60 | ----------------
61 |
62 | .. autofunction:: sqlalchemy_utils.functions.get_primary_keys
63 |
64 |
65 | get_tables
66 | ----------
67 |
68 | .. autofunction:: sqlalchemy_utils.functions.get_tables
69 |
70 |
71 | get_type
72 | --------
73 |
74 | .. autofunction:: sqlalchemy_utils.functions.get_type
75 |
76 |
77 | has_changes
78 | -----------
79 |
80 | .. autofunction:: sqlalchemy_utils.functions.has_changes
81 |
82 |
83 | identity
84 | --------
85 |
86 | .. autofunction:: sqlalchemy_utils.functions.identity
87 |
88 |
89 | is_loaded
90 | ---------
91 |
92 | .. autofunction:: sqlalchemy_utils.functions.is_loaded
93 |
94 |
95 | make_order_by_deterministic
96 | ---------------------------
97 |
98 | .. autofunction:: sqlalchemy_utils.functions.make_order_by_deterministic
99 |
100 |
101 | naturally_equivalent
102 | --------------------
103 |
104 | .. autofunction:: sqlalchemy_utils.functions.naturally_equivalent
105 |
106 |
107 | quote
108 | -----
109 |
110 | .. autofunction:: sqlalchemy_utils.functions.quote
111 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/currency.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from .. import i18n, ImproperlyConfigured
4 | from ..primitives import Currency
5 | from .scalar_coercible import ScalarCoercible
6 |
7 |
class CurrencyType(ScalarCoercible, types.TypeDecorator):
    """
    Changes :class:`.Currency` objects to a string representation on the way in
    and changes them back to :class:`.Currency` objects on the way out.

    In order to use CurrencyType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import CurrencyType, Currency


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            currency = sa.Column(CurrencyType)


        user = User()
        user.currency = Currency('USD')
        session.add(user)
        session.commit()

        user.currency  # Currency('USD')
        user.currency.name  # US Dollar

        str(user.currency)  # USD
        user.currency.symbol  # $



    CurrencyType is scalar coercible::


        user.currency = 'USD'
        user.currency  # Currency('USD')
    """

    # ISO 4217 codes are three characters long, e.g. 'USD'.
    impl = types.String(3)
    python_type = Currency
    cache_ok = True

    def __init__(self, *args, **kwargs):
        """Create the type; extra arguments are forwarded to TypeDecorator.

        :raises ImproperlyConfigured: if Babel is not installed.
        """
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use CurrencyType."
            )

        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Store a Currency as its code; pass plain strings through.

        Any other value (including ``None``) is bound as ``None``.
        """
        if isinstance(value, Currency):
            return value.code
        elif isinstance(value, str):
            return value

    def process_result_value(self, value, dialect):
        """Wrap a stored code in a :class:`Currency`."""
        if value is not None:
            return Currency(value)

    def _coerce(self, value):
        """Wrap non-Currency values assigned to the attribute; ``None`` passes."""
        if value is not None and not isinstance(value, Currency):
            return Currency(value)
        return value
76 |
--------------------------------------------------------------------------------
/tests/aggregate/test_o2m_o2m.py:
--------------------------------------------------------------------------------
1 | from decimal import Decimal
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils.aggregates import aggregated
7 |
8 |
@pytest.fixture
def Catalog(Base):
    """Catalog model whose ``product_count`` aggregates over the two-hop
    ``categories.products`` relationship path."""
    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Number of products reachable through this catalog's categories,
        # kept in an Integer column that defaults to 0.
        @aggregated(
            'categories.products',
            sa.Column(sa.Integer, default=0)
        )
        def product_count(self):
            return sa.func.count('1')

        # The backref also defines Category.catalog.
        categories = sa.orm.relationship('Category', backref='catalog')
    return Catalog
25 |
26 |
@pytest.fixture
def Category(Base):
    """Category model: the middle hop between Catalog and Product."""
    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))

        # The backref also defines Product.category.
        products = sa.orm.relationship('Product', backref='category')
    return Category
38 |
39 |
@pytest.fixture
def Product(Base):
    """Product model: the leaf rows counted by Catalog.product_count."""
    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
    return Product
50 |
51 |
@pytest.fixture
def init_models(Catalog, Category, Product):
    """Request each model fixture so every mapped class gets defined."""
55 |
56 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateOneToManyAndOneToMany:
    """Aggregate counted across Catalog -> categories -> products."""

    def test_assigns_aggregates(self, session, Category, Catalog, Product):
        some_category = Category(name='Some category')
        some_catalog = Catalog(categories=[some_category])
        some_catalog.name = 'Some catalog'
        session.add(some_catalog)
        session.commit()

        some_product = Product(
            name='Some product',
            price=Decimal('1000'),
            category=some_category,
        )
        session.add(some_product)
        session.commit()

        # Re-read the catalog so the stored aggregate value is visible.
        session.refresh(some_catalog)
        assert some_catalog.product_count == 1
77 |
--------------------------------------------------------------------------------
/tests/generic_relationship/test_hybrid_properties.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.ext.hybrid import hybrid_property
4 |
5 | from sqlalchemy_utils import generic_relationship
6 |
7 |
@pytest.fixture
def User(Base):
    """Minimal User model used as a generic_relationship target."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
    return User
14 |
15 |
@pytest.fixture
def UserHistory(Base):
    """History model keyed by the composite primary key
    ``(id, transaction_id)``."""
    class UserHistory(Base):
        __tablename__ = 'user_history'
        id = sa.Column(sa.Integer, primary_key=True)

        transaction_id = sa.Column(sa.Integer, primary_key=True)
    return UserHistory
24 |
25 |
@pytest.fixture
def Event(Base):
    """Event model with two generic relationships.

    ``object`` is keyed by ``object_type``/``object_id``. ``object_version``
    is keyed by a hybrid-property discriminator and the composite
    ``(object_id, transaction_id)``.
    """
    class Event(Base):
        __tablename__ = 'event'
        id = sa.Column(sa.Integer, primary_key=True)

        transaction_id = sa.Column(sa.Integer)

        object_type = sa.Column(sa.Unicode(255))
        object_id = sa.Column(sa.Integer, nullable=False)

        object = generic_relationship(
            object_type, object_id
        )

        # Target type name with 'History' appended (e.g. 'User' ->
        # 'UserHistory'); the SQL expression builds the same string via
        # CONCAT.
        @hybrid_property
        def object_version_type(self):
            return self.object_type + 'History'

        @object_version_type.expression
        def object_version_type(cls):
            return sa.func.concat(cls.object_type, 'History')

        object_version = generic_relationship(
            object_version_type, (object_id, transaction_id)
        )
    return Event
53 |
54 |
@pytest.fixture
def init_models(User, UserHistory, Event):
    """Request each model fixture so every mapped class gets defined."""
58 |
59 |
class TestGenericRelationship:
    """Generic relationship whose discriminator is a hybrid property."""

    def test_set_manual_and_get(self, session, User, UserHistory, Event):
        target = User(id=1)
        target_history = UserHistory(id=1, transaction_id=1)
        session.add(target)
        session.add(target_history)
        session.commit()

        evt = Event(transaction_id=1)
        evt.object_id = target.id
        evt.object_type = type(target).__name__
        # Not resolvable before the event is added to the session.
        assert evt.object is None

        session.add(evt)
        session.commit()

        assert evt.object == target
        assert evt.object_version == target_history
79 |
--------------------------------------------------------------------------------
/tests/generic_relationship/test_composite_keys.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import generic_relationship
5 |
6 | from ..generic_relationship import GenericRelationshipTestCase
7 |
8 |
@pytest.fixture
def incrementor():
    """Shared mutable counter handed to model constructors."""

    class Incrementor:
        # Starting value; constructors increment before reading it.
        value = 1

    counter = Incrementor()
    return counter
14 |
15 |
@pytest.fixture
def Building(Base, incrementor):
    """Building model keyed by a composite ``(id, code)`` primary key."""

    class Building(Base):
        __tablename__ = 'building'
        id = sa.Column(sa.Integer, primary_key=True)
        code = sa.Column(sa.Integer, primary_key=True)

        def __init__(self):
            # Take the next counter value and use it for both key columns.
            incrementor.value += 1
            self.id = self.code = incrementor.value

    return Building
28 |
29 |
@pytest.fixture
def User(Base, incrementor):
    """User model keyed by a composite ``(id, code)`` primary key."""

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        code = sa.Column(sa.Integer, primary_key=True)

        def __init__(self):
            # Take the next counter value and use it for both key columns.
            incrementor.value += 1
            self.id = self.code = incrementor.value

    return User
42 |
43 |
@pytest.fixture
def Event(Base):
    """Event model whose ``object`` generic relationship targets rows by the
    composite key ``(object_id, object_code)``."""
    class Event(Base):
        __tablename__ = 'event'
        id = sa.Column(sa.Integer, primary_key=True)

        object_type = sa.Column(sa.Unicode(255))
        object_id = sa.Column(sa.Integer, nullable=False)
        object_code = sa.Column(sa.Integer, nullable=False)

        object = generic_relationship(
            object_type, (object_id, object_code)
        )
    return Event
58 |
59 |
@pytest.fixture
def init_models(Building, User, Event):
    """Request each model fixture so every mapped class gets defined."""
63 |
64 |
class TestGenericRelationship(GenericRelationshipTestCase):
    """Generic relationship pointing at composite-primary-key models."""

    def test_set_manual_and_get(self, session, Event, User):
        owner = User()

        session.add(owner)
        session.commit()

        evt = Event()
        evt.object_id = owner.id
        evt.object_type = type(owner).__name__
        evt.object_code = owner.code

        # Not resolvable before the event is added to the session.
        assert evt.object is None

        session.add(evt)
        session.commit()

        assert evt.object == owner
84 |
--------------------------------------------------------------------------------
/tests/aggregate/test_m2m.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.aggregates import aggregated
5 |
6 |
@pytest.fixture
def User(Base):
    """User model with a many-to-many ``groups`` relationship and a
    ``group_count`` aggregate over it."""
    # Association table linking users and groups.
    user_group = sa.Table(
        'user_group',
        Base.metadata,
        sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
        sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
    )

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Number of related groups, in an Integer column defaulting to 0.
        @aggregated('groups', sa.Column(sa.Integer, default=0))
        def group_count(self):
            return sa.func.count('1')

        # The backref also defines Group.users.
        groups = sa.orm.relationship(
            'Group',
            backref='users',
            secondary=user_group
        )
    return User
31 |
32 |
@pytest.fixture
def Group(Base):
    """Group model on the other side of the user_group association."""
    class Group(Base):
        __tablename__ = 'group'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
    return Group
40 |
41 |
@pytest.fixture
def init_models(User, Group):
    """Request each model fixture so every mapped class gets defined."""
45 |
46 |
@pytest.mark.skip
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregatesWithManyToManyRelationships:
    """User.group_count over a many-to-many path (currently skipped)."""

    def test_assigns_aggregates_on_insert(self, session, User, Group):
        member = User(name='John Matrix')
        session.add(member)
        session.commit()

        team = Group(name='Some group', users=[member])
        session.add(team)
        session.commit()

        session.refresh(member)
        assert member.group_count == 1

    def test_updates_aggregates_on_delete(self, session, User, Group):
        member = User(name='John Matrix')
        session.add(member)
        session.commit()

        team = Group(name='Some group', users=[member])
        session.add(team)
        session.commit()
        session.refresh(member)

        # Detach the user from every group and check the count drops.
        member.groups = []
        session.commit()
        session.refresh(member)
        assert member.group_count == 0
83 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/functions/render.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import io
3 |
4 | import sqlalchemy as sa
5 |
6 | from .mock import create_mock_engine
7 | from .orm import _get_query_compile_state
8 |
9 |
def render_expression(expression, bind, stream=None):
    """Generate a SQL expression from the passed python expression.

    ``expression`` is executed with ``exec`` using the globals and a copy of
    the locals of the nearest calling frame in which it runs without
    raising. Inside the expression, the name ``engine`` refers to a mock
    engine that renders the operations it receives into ``stream``.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    :param expression: Python source to execute, e.g.
        ``'User.__table__.create(engine)'``.
    :param bind: A SQLAlchemy engine or bind URL.
    :param stream: Render all DDL operations to the stream. A new
        :class:`io.StringIO` is created when omitted.
    :return: The stream the operations were rendered to.
    :raises ValueError: If the expression raises in every calling frame.
    """
    if stream is None:
        stream = io.StringIO()

    engine = create_mock_engine(bind, stream)

    # Navigate the stack and find the calling frame that allows the
    # expression to execute. context=0 skips loading source lines, which
    # are never used here.
    stack = inspect.stack(context=0)
    try:
        for frame_info in stack[1:]:
            frame = frame_info.frame
            local = dict(frame.f_locals)
            local['engine'] = engine
            try:
                exec(expression, frame.f_globals, local)
            except Exception:
                # Best effort: the names used by the expression may only
                # be available in an outer frame, so keep searching.
                continue
            break
        else:
            raise ValueError('Not a valid python expression', engine)
    finally:
        # The stack holds this function's own frame; drop it explicitly to
        # avoid a reference cycle (see the inspect module documentation).
        del stack

    return stream
47 |
48 |
def render_statement(statement, bind=None):
    """
    Render ``statement`` as an SQL string with bound parameters inlined.

    The statement is executed against a mock engine that writes the SQL it
    receives into an in-memory stream, whose contents are returned.

    :param statement:
        A SQLAlchemy ORM ``Query`` or a Core statement such as the result
        of ``select()``.
    :param bind:
        Optional SQLAlchemy bind. When omitted, a ``Query`` uses the bind its
        session resolves for the query's first mapper, and any other
        statement uses its own ``bind`` attribute.
    :return: The rendered SQL text.
    """
    is_orm_query = isinstance(statement, sa.orm.query.Query)

    if is_orm_query:
        if bind is None:
            session = statement.session
            bind = session.get_bind(
                _get_query_compile_state(statement)._mapper_zero()
            )
        # Render the underlying Core select rather than the Query wrapper.
        statement = statement.statement
    elif bind is None:
        bind = statement.bind

    output = io.StringIO()
    mock = create_mock_engine(bind.engine, stream=output)
    mock.execute(statement)

    return output.getvalue()
76 |
--------------------------------------------------------------------------------
/tests/aggregate/test_custom_select_expressions.py:
--------------------------------------------------------------------------------
1 | from decimal import Decimal
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils.aggregates import aggregated
7 |
8 |
@pytest.fixture
def Product(Base):
    """Product model whose ``price`` is summed by Catalog.net_worth."""
    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
    return Product
19 |
20 |
@pytest.fixture
def Catalog(Base, Product):
    """Catalog model with an aggregate built from a custom select
    expression (the sum of related product prices)."""
    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # SUM(product.price) over this catalog's products, stored in a
        # Numeric column defaulting to 0.
        @aggregated('products', sa.Column(sa.Numeric, default=0))
        def net_worth(self):
            return sa.func.sum(Product.price)

        # The backref also defines Product.catalog.
        products = sa.orm.relationship('Product', backref='catalog')
    return Catalog
34 |
35 |
@pytest.fixture
def init_models(Product, Catalog):
    """Request each model fixture so every mapped class gets defined."""
39 |
40 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestLazyEvaluatedSelectExpressionsForAggregates:
    """Catalog.net_worth tracks inserts and updates of product prices."""

    def test_assigns_aggregates_on_insert(self, session, Product, Catalog):
        shop = Catalog(name='Some catalog')
        session.add(shop)
        session.commit()

        item = Product(
            name='Some product',
            price=Decimal('1000'),
            catalog=shop,
        )
        session.add(item)
        session.commit()

        session.refresh(shop)
        assert shop.net_worth == Decimal('1000')

    def test_assigns_aggregates_on_update(self, session, Product, Catalog):
        shop = Catalog(name='Some catalog')
        session.add(shop)
        session.commit()

        item = Product(
            name='Some product',
            price=Decimal('1000'),
            catalog=shop,
        )
        session.add(item)
        session.commit()

        # Changing an existing price must be reflected in the aggregate.
        item.price = Decimal('500')
        session.commit()

        session.refresh(shop)
        assert shop.net_worth == Decimal('500')
77 |
--------------------------------------------------------------------------------
/tests/types/test_tsvector.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.dialects.postgresql import TSVECTOR
4 |
5 | from sqlalchemy_utils import TSVectorType
6 |
7 |
@pytest.fixture
def User(Base):
    """User model with a TSVector over ``name`` using the Finnish config."""

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        search_index = sa.Column(
            TSVectorType(name, regconfig='pg_catalog.finnish')
        )

        def __repr__(self):
            return f'User({self.id!r})'

    return User
21 |
22 |
@pytest.fixture
def init_models(User):
    """Request the User fixture so the mapped class gets defined."""
26 |
27 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestTSVector:
    """Schema, reflection and operator behaviour of TSVectorType."""

    def test_generates_table(self, User):
        assert 'search_index' in User.__table__.c

    @pytest.mark.usefixtures('session')
    def test_type_reflection(self, engine):
        reflected = sa.schema.Table(
            'user',
            sa.schema.MetaData(),
            autoload_with=engine,
        )
        assert isinstance(reflected.c['search_index'].type, TSVECTOR)

    def test_catalog_and_columns_as_args(self):
        tsvector = TSVectorType('name', 'age', regconfig='pg_catalog.simple')
        assert tsvector.columns == ('name', 'age')
        assert tsvector.options['regconfig'] == 'pg_catalog.simple'

    def test_catalog_passed_to_match(self, connection, User):
        compiled = User.search_index.match('something').compile(connection)
        assert 'pg_catalog.finnish' in str(compiled)

    def test_concat(self, User):
        combined = User.search_index | User.search_index
        expected = '"user".search_index || "user".search_index'
        assert str(combined) == expected

    def test_match_concatenation(self, session, User):
        combined = User.search_index | User.search_index
        compiled = combined.match('something').compile(session.bind)
        assert 'pg_catalog.finnish' in str(compiled)

    def test_match_with_catalog(self, connection, User):
        matcher = User.search_index.match(
            'something',
            postgresql_regconfig='pg_catalog.simple',
        )
        assert 'pg_catalog.simple' in str(matcher.compile(connection))
70 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/color.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from ..exceptions import ImproperlyConfigured
4 | from .scalar_coercible import ScalarCoercible
5 |
# The optional 'colour' package backs ColorType; fall back to None when it
# cannot be imported.
try:
    import colour
except (ImportError, AttributeError):
    colour = None

# None when colour is missing or does not expose a Color class.
python_colour_type = getattr(colour, 'Color', None)
13 |
14 |
class ColorType(ScalarCoercible, types.TypeDecorator):
    """
    ColorType provides a way for saving Color (from colour_ package) objects
    into database. ColorType saves Color objects as strings on the way in and
    converts them back to objects when querying the database.

    ::


        from colour import Color
        from sqlalchemy_utils import ColorType


        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            background_color = sa.Column(ColorType)


        document = Document()
        document.background_color = Color('#F5F5F5')
        session.add(document)
        session.commit()


    Querying the database returns Color objects:

    ::

        document = session.query(Document).first()

        document.background_color.hex
        # '#f5f5f5'


    .. _colour: https://github.com/vaab/colour
    """

    STORE_FORMAT = 'hex'
    impl = types.Unicode(20)
    python_type = python_colour_type
    cache_ok = True

    def __init__(self, max_length=20, *args, **kwargs):
        """
        :param max_length: Length of the underlying ``Unicode`` column.
        :raises ImproperlyConfigured: If the colour package is missing.
        """
        # Fail if colour is not found.
        if colour is None:
            raise ImproperlyConfigured(
                "'colour' package is required to use 'ColorType'"
            )

        super().__init__(*args, **kwargs)
        # Keep the constructor argument under its parameter name so that
        # SQLAlchemy's cache key (cache_ok = True) reflects it.
        self.max_length = max_length
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        """Store a truthy Color as ``str`` of its STORE_FORMAT attribute
        (``hex``). Any other value is passed through unchanged."""
        if value and isinstance(value, colour.Color):
            return str(getattr(value, self.STORE_FORMAT))
        return value

    def process_result_value(self, value, dialect):
        """Wrap a non-empty stored string in a Color. Falsy values (None,
        '') are returned unchanged."""
        if value:
            return colour.Color(value)
        return value

    def _coerce(self, value):
        """Turn an assigned non-Color value into a Color."""
        if value is not None and not isinstance(value, colour.Color):
            return colour.Color(value)
        return value
81 | return value
82 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/json.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import sqlalchemy as sa
4 | from sqlalchemy.dialects.postgresql.base import ischema_names
5 |
try:
    from sqlalchemy.dialects.postgresql import JSON

    has_postgres_json = True
except ImportError:

    class PostgresJSONType(sa.types.UserDefinedType):
        """
        JSON type for postgresql, used as a fallback when SQLAlchemy does
        not provide ``sqlalchemy.dialects.postgresql.JSON``.
        """

        def get_col_spec(self):
            return 'json'

    # Map the 'json' type name in the PostgreSQL dialect's ischema_names
    # (the name -> type lookup used for reflection) to the fallback type.
    ischema_names['json'] = PostgresJSONType
    has_postgres_json = False
22 |
23 |
class JSONType(sa.types.TypeDecorator):
    """
    JSONType offers way of saving JSON data structures to database. On
    PostgreSQL the underlying implementation of this data type is 'json' while
    on other databases it's simply 'text'.

    ::


        from sqlalchemy_utils import JSONType


        class Product(Base):
            __tablename__ = 'product'
            id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            details = sa.Column(JSONType)


        product = Product()
        product.details = {
            'color': 'red',
            'type': 'car',
            'max-speed': '400 mph'
        }
        session.add(product)
        session.commit()
    """

    impl = sa.UnicodeText
    hashable = False
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def load_dialect_impl(self, dialect):
        """Pick the column type: 'json' on PostgreSQL, ``impl`` elsewhere."""
        if dialect.name == 'postgresql':
            # Use the native JSON type.
            if has_postgres_json:
                return dialect.type_descriptor(JSON())
            else:
                return dialect.type_descriptor(PostgresJSONType())
        else:
            return dialect.type_descriptor(self.impl)

    def process_bind_param(self, value, dialect):
        """Serialize ``value`` with ``json.dumps`` unless it goes to the
        native PostgreSQL JSON type, which receives the Python value as-is.
        ``None`` is never serialized."""
        if dialect.name == 'postgresql' and has_postgres_json:
            return value
        if value is not None:
            value = json.dumps(value)
        return value

    def process_result_value(self, value, dialect):
        """Deserialize stored JSON text with ``json.loads``.

        On PostgreSQL the value is returned untouched; presumably it has
        already been decoded by the JSON type / driver -- verify when using
        the fallback type.
        """
        if dialect.name == 'postgresql':
            return value
        if value is not None:
            value = json.loads(value)
        return value
82 |
--------------------------------------------------------------------------------
/tests/aggregate/test_backrefs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.aggregates import aggregated
5 |
6 |
@pytest.fixture
def Thread(Base):
    """Thread model aggregating its comment count. It defines no
    ``comments`` relationship itself; that comes from Comment's backref."""
    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Number of related comments, in an Integer column defaulting to 0.
        @aggregated('comments', sa.Column(sa.Integer, default=0))
        def comment_count(self):
            return sa.func.count('1')
    return Thread
18 |
19 |
@pytest.fixture
def Comment(Base, Thread):
    """Comment model that declares the relationship to Thread and, via the
    backref, creates ``Thread.comments``."""
    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        content = sa.Column(sa.Unicode(255))
        thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))

        thread = sa.orm.relationship(Thread, backref='comments')
    return Comment
30 |
31 |
@pytest.fixture
def init_models(Thread, Comment):
    """Request each model fixture so every mapped class gets defined."""
35 |
36 |
class TestAggregateValueGenerationWithBackrefs:
    """comment_count when the relationship is declared through a backref."""

    def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
        topic = Thread()
        topic.name = 'some article name'
        session.add(topic)
        reply = Comment(content='Some content', thread=topic)
        session.add(reply)
        # Thread and comment are flushed together in one commit.
        session.commit()
        session.refresh(topic)
        assert topic.comment_count == 1

    def test_assigns_aggregates_on_separate_insert(
        self,
        session,
        Thread,
        Comment
    ):
        topic = Thread()
        topic.name = 'some article name'
        session.add(topic)
        session.commit()

        reply = Comment(content='Some content', thread=topic)
        session.add(reply)
        session.commit()

        session.refresh(topic)
        assert topic.comment_count == 1

    def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
        topic = Thread()
        topic.name = 'some article name'
        session.add(topic)
        session.commit()

        reply = Comment(content='Some content', thread=topic)
        session.add(reply)
        session.commit()

        session.delete(reply)
        session.commit()

        session.refresh(topic)
        assert topic.comment_count == 0
77 |
--------------------------------------------------------------------------------
/tests/aggregate/test_simple_paths.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.aggregates import aggregated
5 |
6 |
@pytest.fixture
def Thread(Base):
    """Thread model that owns the ``comments`` relationship and aggregates
    its comment count."""
    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        # Number of related comments, in an Integer column defaulting to 0.
        @aggregated('comments', sa.Column(sa.Integer, default=0))
        def comment_count(self):
            return sa.func.count('1')

        # The backref also defines Comment.thread.
        comments = sa.orm.relationship('Comment', backref='thread')
    return Thread
20 |
21 |
@pytest.fixture
def Comment(Base):
    """Comment model; its ``thread`` attribute comes from Thread's
    backref."""
    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        content = sa.Column(sa.Unicode(255))
        thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))
    return Comment
30 |
31 |
@pytest.fixture
def init_models(Thread, Comment):
    """Request each model fixture so every mapped class gets defined."""
35 |
36 |
class TestAggregateValueGenerationForSimpleModelPaths:
    """comment_count for a direct one-to-many relationship path."""

    def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
        parent = Thread()
        parent.name = 'some article name'
        session.add(parent)
        child = Comment(content='Some content', thread=parent)
        session.add(child)
        session.commit()

        session.refresh(parent)
        assert parent.comment_count == 1

    def test_assigns_aggregates_on_separate_insert(
        self,
        session,
        Thread,
        Comment
    ):
        parent = Thread()
        parent.name = 'some article name'
        session.add(parent)
        session.commit()

        child = Comment(content='Some content', thread=parent)
        session.add(child)
        session.commit()

        session.refresh(parent)
        assert parent.comment_count == 1

    def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
        parent = Thread()
        parent.name = 'some article name'
        session.add(parent)
        session.commit()

        child = Comment(content='Some content', thread=parent)
        session.add(child)
        session.commit()

        # Removing the only comment should bring the count back to zero.
        session.delete(child)
        session.commit()

        session.refresh(parent)
        assert parent.comment_count == 0
77 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/weekdays.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from .. import i18n
4 | from ..exceptions import ImproperlyConfigured
5 | from ..primitives import WeekDay, WeekDays
6 | from .bit import BitType
7 | from .scalar_coercible import ScalarCoercible
8 |
9 |
class WeekDaysType(types.TypeDecorator, ScalarCoercible):
    """
    WeekDaysType offers way of saving WeekDays objects into database. The
    WeekDays objects are converted to bit strings on the way in and back to
    WeekDays objects on the way out.

    In order to use WeekDaysType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import WeekDaysType, WeekDays
        from babel import Locale


        class Schedule(Base):
            __tablename__ = 'schedule'
            id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
            working_days = sa.Column(WeekDaysType)


        schedule = Schedule()
        schedule.working_days = WeekDays('0001111')
        session.add(schedule)
        session.commit()

        print(schedule.working_days)  # Thursday, Friday, Saturday, Sunday


    WeekDaysType also supports scalar coercion:

    ::


        schedule.working_days = '1110000'
        schedule.working_days  # WeekDays object

    """

    impl = BitType(WeekDay.NUM_WEEK_DAYS)

    cache_ok = True

    def __init__(self, *args, **kwargs):
        # Refuse to build the type at all when babel is unavailable.
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required to use 'WeekDaysType'"
            )

        super().__init__(*args, **kwargs)

    @property
    def comparator_factory(self):
        # Delegate comparison operators to the underlying BitType impl.
        return self.impl.comparator_factory

    def process_bind_param(self, value, dialect):
        """
        Convert a value for storage.

        A :class:`WeekDays` is stored as its bit string. On MySQL the bit
        string is encoded to UTF-8 bytes. ``None`` is always bound as
        ``None`` (NULL).
        """
        if isinstance(value, WeekDays):
            value = value.as_bit_string()

        # Guard against None: bytes(None, 'utf8') raises TypeError, which
        # previously made NULL inserts fail on MySQL.
        if dialect.name == 'mysql' and value is not None:
            return bytes(value, 'utf8')
        return value

    def process_result_value(self, value, dialect):
        """Wrap a non-NULL stored bit string in a :class:`WeekDays`."""
        if value is not None:
            return WeekDays(value)

    def _coerce(self, value):
        """Turn an assigned non-WeekDays value into a :class:`WeekDays`."""
        if value is not None and not isinstance(value, WeekDays):
            return WeekDays(value)
        return value
83 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/functions/sort_query.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 |
3 | from .database import has_unique_index
4 | from .orm import _get_query_compile_state, get_tables
5 |
6 |
def make_order_by_deterministic(query):
    """
    Make query order by deterministic (if it isn't already). Order by is
    considered deterministic if it contains a column with a unique index
    (either it is a primary key or has a unique index). Ordering queries in
    a nondeterministic manner is often a design flaw.

    Only the first ORDER BY clause is inspected. If it is not a uniquely
    indexed column, the primary key columns of the first table of the
    query's first entity are appended to the ORDER BY. They use the same
    direction as that first clause (DESC if it is ``desc()``, otherwise
    ASC; ASC when the query has no ORDER BY at all).

    Consider a User model with three fields: id (primary key), favorite color
    and email (unique).::


        from sqlalchemy_utils import make_order_by_deterministic


        query = session.query(User).order_by(User.favorite_color)

        query = make_order_by_deterministic(query)
        print(query)  # 'SELECT ... ORDER BY "user".favorite_color, "user".id'


        query = session.query(User).order_by(User.email)

        query = make_order_by_deterministic(query)
        print(query)  # 'SELECT ... ORDER BY "user".email'


        query = session.query(User).order_by(User.id)

        query = make_order_by_deterministic(query)
        print(query)  # 'SELECT ... ORDER BY "user".id'


    .. versionadded:: 0.27.1

    :param query: SQLAlchemy ORM query to make deterministic.
    :return: ``query`` itself when it is already deterministic, otherwise a
        new query with the primary key ordering appended.
    """
    order_by_func = sa.asc

    try:
        order_by_clauses = query._order_by_clauses
    except AttributeError:  # SQLAlchemy <1.4
        order_by_clauses = query._order_by

    if not order_by_clauses:
        column = None
    else:
        order_by = order_by_clauses[0]
        # Ordering by a label reference: inspect the labelled element.
        if isinstance(order_by, sa.sql.elements._label_reference):
            order_by = order_by.element
        if isinstance(order_by, sa.sql.expression.UnaryExpression):
            # asc()/desc() wrapper: remember the direction and look at the
            # wrapped expression.
            if order_by.modifier == sa.sql.operators.desc_op:
                order_by_func = sa.desc
            column = list(order_by.get_children())[0]
        else:
            column = order_by

    # Skip queries that are ordered by an already deterministic column
    if isinstance(column, sa.Column):
        try:
            if has_unique_index(column):
                return query
        except TypeError:
            # has_unique_index rejects some columns (presumably ones not
            # attached to a table); fall back to adding the primary key.
            pass

    base_table = get_tables(_get_query_compile_state(query)._entities[0])[0]
    primary_key_ordering = [
        order_by_func(c) for c in base_table.c if c.primary_key
    ]
    return query.order_by(*primary_key_ordering)
73 |
--------------------------------------------------------------------------------
/tests/functions/test_render.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.functions import (
5 | mock_engine,
6 | render_expression,
7 | render_statement
8 | )
9 |
10 |
class TestRender:
    """render_statement / render_expression / mock_engine output checks."""

    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
        return User

    @pytest.fixture
    def init_models(self, User):
        """Request the User fixture so the mapped class gets defined."""

    def test_render_orm_query(self, session, User):
        rendered = render_statement(session.query(User).filter_by(id=3))

        for fragment in (
            'SELECT user.id, user.name',
            'FROM user',
            'WHERE user.id = 3',
        ):
            assert fragment in rendered

    def test_render_statement(self, session, User):
        select_stmt = User.__table__.select().where(User.id == 3)
        rendered = render_statement(select_stmt, bind=session.bind)

        for fragment in (
            'SELECT user.id, user.name',
            'FROM user',
            'WHERE user.id = 3',
        ):
            assert fragment in rendered

    def test_render_statement_without_mapper(self, session):
        rendered = render_statement(
            sa.select(sa.text('1')), bind=session.bind
        )

        assert 'SELECT 1' in rendered

    def test_render_ddl(self, engine, User):
        # The expression is resolved against this frame's locals (User).
        output = render_expression('User.__table__.create(engine)', engine)

        ddl = output.getvalue()

        assert 'CREATE TABLE user' in ddl
        assert 'PRIMARY KEY' in ddl

    def test_render_mock_ddl(self, engine, User):
        # TODO: mock_engine doesn't seem to work with locally scoped variables.
        self.engine = engine
        with mock_engine('self.engine') as output:
            User.__table__.create(self.engine)

        ddl = output.getvalue()

        assert 'CREATE TABLE user' in ddl
        assert 'PRIMARY KEY' in ddl

    def test_render_statement_hangul(self, User, session):
        select_stmt = User.__table__.select().where(
            (User.id == 3)
            & (User.name == "한글")
        )
        rendered = render_statement(select_stmt, bind=session.bind)

        for fragment in (
            "SELECT user.id, user.name",
            "FROM user",
            "user.id = 3",
            "user.name = '한글'",
        ):
            assert fragment in rendered
78 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/proxy_dict.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 |
3 |
class ProxyDict:
    """Dict-like access to the children of ``parent.<collection_name>``.

    Children are keyed by the value of the child attribute described by
    ``mapping_attr`` (its ``class_`` is the child class and its ``key`` is the
    attribute name). The collection must support ``filter_by``,
    ``with_entities``, ``append`` and ``remove``. Presumably it is a
    ``lazy='dynamic'`` relationship; confirm at call sites.

    Lookups are memoized in ``self.cache``. A cached ``None`` records that a
    fetch found no matching child. Reading a key that has no child creates a
    new child with that key and appends it to the collection.
    """

    def __init__(self, parent, collection_name, mapping_attr):
        self.parent = parent
        self.collection_name = collection_name
        self.child_class = mapping_attr.class_
        self.key_name = mapping_attr.key
        self.cache = {}

    @property
    def collection(self):
        """The proxied relationship collection, read fresh from the parent."""
        return getattr(self.parent, self.collection_name)

    def keys(self):
        """Return the key values of all children in the collection."""
        descriptor = getattr(self.child_class, self.key_name)
        # Query.values() is deprecated since SQLAlchemy 1.4;
        # with_entities() is its documented replacement.
        return [row[0] for row in self.collection.with_entities(descriptor)]

    def __contains__(self, key):
        if key in self.cache:
            return self.cache[key] is not None
        return self.fetch(key) is not None

    def has_key(self, key):
        """Alias of ``key in self`` (kept for backward compatibility)."""
        return self.__contains__(key)

    def fetch(self, key):
        """Load the child with the given key from the database.

        The query only runs when the parent is in a session and has a
        database identity. In that case the result (possibly ``None``) is
        cached and returned. Otherwise nothing is cached and ``None`` is
        returned.
        """
        session = sa.orm.object_session(self.parent)
        if session and sa.orm.util.has_identity(self.parent):
            obj = self.collection.filter_by(**{self.key_name: key}).first()
            self.cache[key] = obj
            return obj
        return None

    def create_new_instance(self, key):
        """Create a child with the given key, append it and cache it."""
        value = self.child_class(**{self.key_name: key})
        self.collection.append(value)
        self.cache[key] = value
        return value

    def __getitem__(self, key):
        if key in self.cache:
            if self.cache[key] is not None:
                return self.cache[key]
        else:
            value = self.fetch(key)
            # Identity check: a model defining __bool__/__len__ must not be
            # mistaken for "missing" and duplicated.
            if value is not None:
                return value

        return self.create_new_instance(key)

    def _lookup_existing(self, key):
        """Return the current child for ``key`` (cached or fetched), or None.

        Unlike ``self[key]``, this never creates a new child.
        """
        if key in self.cache:
            return self.cache[key]
        return self.fetch(key)

    def __setitem__(self, key, value):
        # Look the old child up without ``self[key]``: that path creates and
        # appends a throwaway child when the key is absent, which could
        # cascade a spurious object into the session.
        existing = self._lookup_existing(key)
        if existing is not None:
            self.collection.remove(existing)
        self.collection.append(value)
        self.cache[key] = value
60 |
61 |
def proxy_dict(parent, collection_name, mapping_attr):
    """Return the ProxyDict for ``parent.<collection_name>``, creating it once.

    Instances are memoized per collection name in ``parent._proxy_dicts``.
    That attribute is created on first use.
    """
    if not hasattr(parent, '_proxy_dicts'):
        parent._proxy_dicts = {}
    registry = parent._proxy_dicts

    if collection_name not in registry:
        registry[collection_name] = ProxyDict(
            parent, collection_name, mapping_attr
        )
    return registry[collection_name]
75 |
76 |
def expire_proxy_dicts(target, context):
    """Event handler: drop every cached ProxyDict held by ``target``.

    Objects that never had a proxy dict are left untouched. ``context`` is
    accepted for the event signature and not used.
    """
    if not hasattr(target, '_proxy_dicts'):
        return
    target._proxy_dicts = {}
80 |
81 |
82 | sa.event.listen(sa.orm.Mapper, 'expire', expire_proxy_dicts)
83 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | on: [push, pull_request, workflow_dispatch]
3 | jobs:
4 | test:
5 | strategy:
6 | fail-fast: false
7 | matrix:
8 | os: ['ubuntu-latest']
9 | python-version:
10 | - '3.9'
11 | - '3.10'
12 | - '3.11'
13 | - '3.12'
14 | - '3.13'
15 | tox_env: ['sqlalchemy14', 'sqlalchemy2']
16 | runs-on: ${{ matrix.os }}
17 | services:
18 | postgres:
19 | image: postgres:latest
20 | env:
21 | POSTGRES_USER: postgres
22 | POSTGRES_PASSWORD: postgres
23 | POSTGRES_DB: sqlalchemy_utils_test
24 | # Set health checks to wait until PostgreSQL has started
25 | options: >-
26 | --health-cmd pg_isready
27 | --health-interval 10s
28 | --health-timeout 5s
29 | --health-retries 5
30 | ports:
31 | - 5432:5432
32 | mysql:
33 | image: mysql:latest
34 | env:
35 | MYSQL_ALLOW_EMPTY_PASSWORD: yes
36 | MYSQL_DATABASE: sqlalchemy_utils_test
37 | ports:
38 | - 3306:3306
39 | mssql:
40 | image: mcr.microsoft.com/mssql/server:2017-latest
41 | env:
42 | ACCEPT_EULA: Y
43 | SA_PASSWORD: Strong_Passw0rd
44 | # Set health checks to wait until SQL Server has started
45 | options: >-
46 | --health-cmd "/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P ${SA_PASSWORD} -Q 'SELECT 1;' -b"
47 | --health-interval 10s
48 | --health-timeout 5s
49 | --health-retries 5
50 | ports:
51 | - 1433:1433
52 | steps:
53 | - uses: actions/checkout@v6
54 | - name: Set up Python
55 | uses: actions/setup-python@v6
56 | with:
57 | python-version: ${{ matrix.python-version }}
58 | - name: Install MS SQL stuff
59 | run: bash .ci/install_mssql.sh
60 | - name: Add hstore extension to the sqlalchemy_utils_test database
61 | env:
62 | PGHOST: localhost
63 | PGPASSWORD: postgres
64 | PGPORT: 5432
65 | run: psql -U postgres -d sqlalchemy_utils_test -c 'CREATE EXTENSION hstore;'
66 | - name: Install tox
67 | run: pip install tox
68 | - name: Run tests
69 | env:
70 | SQLALCHEMY_UTILS_TEST_POSTGRESQL_PASSWORD: postgres
71 | run: tox -e ${{ matrix.tox_env }}
72 |
73 | docs:
74 | runs-on: 'ubuntu-latest'
75 | steps:
76 | - uses: actions/checkout@v6
77 | - name: Set up Python
78 | uses: actions/setup-python@v6
79 | with:
80 | python-version: '3.13'
81 | - name: Install tox
82 | run: pip install tox
83 | - name: Build documentation
84 | run: tox -e docs
85 |
--------------------------------------------------------------------------------
/tests/functions/test_get_tables.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import get_tables
5 | from sqlalchemy_utils.functions.orm import _get_query_compile_state
6 |
7 |
@pytest.fixture
def TextItem(Base):
    """Polymorphic base model of a joined-table inheritance hierarchy.

    ``with_polymorphic='*'`` is why ``get_tables(TextItem)`` is expected to
    include the ``article`` table too (see
    ``test_entity_using_with_polymorphic``).
    """
    class TextItem(Base):
        __tablename__ = 'text_item'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        # Discriminator column used by ``polymorphic_on``.
        type = sa.Column(sa.Unicode(255))

        __mapper_args__ = {
            'polymorphic_on': type,
            'with_polymorphic': '*'
        }
    return TextItem
21 |
22 |
@pytest.fixture
def Article(TextItem):
    """Joined-table subclass of ``TextItem`` stored in table ``article``."""
    class Article(TextItem):
        __tablename__ = 'article'
        # Primary key doubles as the foreign key to the parent row.
        id = sa.Column(
            sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
        )
        __mapper_args__ = {
            'polymorphic_identity': 'article'
        }
    return Article
34 |
35 |
@pytest.fixture
def init_models(TextItem, Article):
    """Request both models so the whole inheritance hierarchy is declared."""
39 |
40 |
class TestGetTables:
    """get_tables() for mapped classes, attributes, columns and query entities."""

    @staticmethod
    def _first_entity(query):
        """Return the first entity of ``query``'s compile state."""
        return _get_query_compile_state(query)._entities[0]

    def test_child_class_using_join_table_inheritance(self, TextItem, Article):
        expected = [TextItem.__table__, Article.__table__]
        assert get_tables(Article) == expected

    def test_entity_using_with_polymorphic(self, TextItem, Article):
        expected = [TextItem.__table__, Article.__table__]
        assert get_tables(TextItem) == expected

    def test_instrumented_attribute(self, TextItem):
        expected = [TextItem.__table__]
        assert get_tables(TextItem.name) == expected

    def test_polymorphic_instrumented_attribute(self, TextItem, Article):
        expected = [TextItem.__table__, Article.__table__]
        assert get_tables(Article.id) == expected

    def test_column(self, Article):
        expected = [Article.__table__]
        assert get_tables(Article.__table__.c.id) == expected

    def test_mapper_entity_with_class(self, session, TextItem, Article):
        query = session.query(Article)
        entity = self._first_entity(query)
        assert get_tables(entity) == [TextItem.__table__, Article.__table__]

    def test_mapper_entity_with_mapper(self, session, TextItem, Article):
        query = session.query(sa.inspect(Article))
        entity = self._first_entity(query)
        assert get_tables(entity) == [TextItem.__table__, Article.__table__]

    def test_column_entity(self, session, TextItem, Article):
        query = session.query(Article.id)
        entity = self._first_entity(query)
        assert get_tables(entity) == [TextItem.__table__, Article.__table__]
85 |
--------------------------------------------------------------------------------
/tests/aggregate/test_o2m_m2m.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | import sqlalchemy.orm
4 |
5 | from sqlalchemy_utils import aggregated
6 |
7 |
@pytest.fixture
def Category(Base):
    """Category model; linked to products through ``category_product``."""
    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
    return Category
15 |
16 |
@pytest.fixture
def Catalog(Base, Category):
    """Catalog model with an aggregated ``category_count`` column.

    The aggregate follows ``products.categories``. ``products`` is the backref
    declared by the ``Product`` fixture. It counts *distinct* category ids, so
    a category shared by several products counts once.
    """
    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated(
            'products.categories',
            sa.Column(sa.Integer, default=0)
        )
        def category_count(self):
            return sa.func.count(sa.distinct(Category.id))
    return Catalog
31 |
32 |
@pytest.fixture
def Product(Base, Catalog, Category):
    """Product: many-to-one to Catalog, many-to-many to Category.

    Both relationships declare ``backref='products'``, which provides
    ``Catalog.products`` and ``Category.products``.
    """
    # Association table for the product <-> category many-to-many link.
    product_categories = sa.Table(
        'category_product',
        Base.metadata,
        sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
        sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
    )

    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        catalog_id = sa.Column(
            sa.Integer, sa.ForeignKey('catalog.id')
        )

        catalog = sa.orm.relationship(
            Catalog,
            backref='products'
        )

        categories = sa.orm.relationship(
            Category,
            backref='products',
            secondary=product_categories
        )
    return Product
63 |
64 |
@pytest.fixture
def init_models(Category, Catalog, Product):
    """Request every model so all tables and relationships are declared."""
68 |
69 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateOneToManyAndManyToMany:

    def test_insert(self, session, Category, Catalog, Product):
        """Each catalog's ``category_count`` counts distinct categories.

        All four products (two per catalog) share one category, so both
        aggregates must be 1, not the number of products.
        """
        category = Category()
        session.add(category)

        # add_all instead of a list comprehension used only for side effects.
        products = [Product(categories=[category]) for _ in range(2)]
        session.add_all(products)
        catalog = Catalog(products=products)
        session.add(catalog)

        products2 = [Product(categories=[category]) for _ in range(2)]
        session.add_all(products2)
        catalog2 = Catalog(products=products2)
        session.add(catalog2)

        session.commit()
        assert catalog.category_count == 1
        assert catalog2.category_count == 1
93 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregates import aggregated # noqa
2 | from .asserts import ( # noqa
3 | assert_max_length,
4 | assert_max_value,
5 | assert_min_value,
6 | assert_non_nullable,
7 | assert_nullable,
8 | )
9 | from .exceptions import ImproperlyConfigured # noqa
10 | from .expressions import Asterisk, row_to_json # noqa
11 | from .functions import ( # noqa
12 | cast_if,
13 | create_database,
14 | create_mock_engine,
15 | database_exists,
16 | dependent_objects,
17 | drop_database,
18 | escape_like,
19 | get_bind,
20 | get_class_by_table,
21 | get_column_key,
22 | get_columns,
23 | get_declarative_base,
24 | get_fk_constraint_for_columns,
25 | get_hybrid_properties,
26 | get_mapper,
27 | get_primary_keys,
28 | get_referencing_foreign_keys,
29 | get_tables,
30 | get_type,
31 | group_foreign_keys,
32 | has_changes,
33 | has_index,
34 | has_unique_index,
35 | identity,
36 | is_loaded,
37 | json_sql,
38 | jsonb_sql,
39 | merge_references,
40 | mock_engine,
41 | naturally_equivalent,
42 | render_expression,
43 | render_statement,
44 | table_name,
45 | )
46 | from .generic import generic_relationship # noqa
47 | from .i18n import TranslationHybrid # noqa
48 | from .listeners import ( # noqa
49 | auto_delete_orphans,
50 | coercion_listener,
51 | force_auto_coercion,
52 | force_instant_defaults,
53 | )
54 | from .models import generic_repr, Timestamp # noqa
55 | from .observer import observes # noqa
56 | from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa
57 | from .proxy_dict import proxy_dict, ProxyDict # noqa
58 | from .query_chain import QueryChain # noqa
59 | from .types import ( # noqa
60 | ArrowType,
61 | Choice,
62 | ChoiceType,
63 | ColorType,
64 | CompositeType,
65 | CountryType,
66 | CurrencyType,
67 | DateRangeType,
68 | DateTimeRangeType,
69 | EmailType,
70 | EncryptedType,
71 | EnrichedDateTimeType,
72 | EnrichedDateType,
73 | instrumented_list,
74 | InstrumentedList,
75 | Int8RangeType,
76 | IntRangeType,
77 | IPAddressType,
78 | JSONType,
79 | LocaleType,
80 | LtreeType,
81 | NumericRangeType,
82 | Password,
83 | PasswordType,
84 | PhoneNumber,
85 | PhoneNumberParseException,
86 | PhoneNumberType,
87 | register_composites,
88 | remove_composite_listeners,
89 | ScalarListException,
90 | ScalarListType,
91 | StringEncryptedType,
92 | TimezoneType,
93 | TSVectorType,
94 | URLType,
95 | UUIDType,
96 | WeekDaysType,
97 | )
98 | from .view import ( # noqa
99 | create_materialized_view,
100 | create_view,
101 | refresh_materialized_view,
102 | )
103 |
# Version string of the sqlalchemy_utils package.
__version__ = '0.42.0'
105 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist =
3 | py{3.9, 3.10, 3.11, 3.12, 3.13}-sqlalchemy{14, 2}
4 | ruff
5 | docs
6 | labels =
7 | update = update-{requirements, pre-commit}
8 |
9 | [testenv]
10 | commands =
11 | py.test sqlalchemy_utils --cov=sqlalchemy_utils --cov-report=xml tests
12 | deps =
13 | .[test_all]
14 | pytest-cov
15 | sqlalchemy14: SQLAlchemy>=1.4,<1.5
16 | sqlalchemy2: SQLAlchemy>=2
17 | ; It's sometimes necessary to test against specific sqlalchemy versions
18 | ; to verify bug or feature behavior before or after a specific version.
19 | ;
20 | ; You can choose the sqlalchemy version using this command syntax:
21 | ; $ tox -e py39-sqlalchemy1_4_xx
22 | ;
23 | ; sqlalchemy 1.4.19 through 1.4.23 have JSON-related TypeErrors. See #543.
24 | sqlalchemy1_4_19: SQLAlchemy==1.4.19
25 | ; sqlalchemy <1.4.28 threw a DeprecationWarning when copying a URL. See #573.
26 | sqlalchemy1_4_27: SQLAlchemy==1.4.27
27 | ; sqlalchemy 1.4.30 introduced UUID literal quoting. See #580.
28 | sqlalchemy1_4_29: SQLAlchemy==1.4.29
29 | setenv =
30 | SQLALCHEMY_WARN_20 = true
31 | passenv =
32 | SQLALCHEMY_UTILS_TEST_*
33 | recreate = True
34 |
35 | [testenv:ruff]
36 | skip_install = True
37 | recreate = False
38 | deps = pre-commit
39 | commands =
40 | pre-commit run --all ruff-check
41 | pre-commit run --all ruff-format
42 |
43 | [testenv:docs]
44 | base_python = py3.13
45 | recreate = False
46 | deps = -r requirements/docs/requirements.txt
47 | commands = sphinx-build docs/ build/docs
48 |
49 | [pytest]
50 | filterwarnings =
51 | error
52 | ; Ignore ResourceWarnings caused by unclosed sockets in pg8000.
53 | ; These are caught and re-raised as pytest.PytestUnraisableExceptionWarnings.
54 | ignore:Exception ignored:pytest.PytestUnraisableExceptionWarning
55 | ; Ignore DeprecationWarnings caused by pg8000.
56 | ignore:distutils Version classes are deprecated.:DeprecationWarning
57 | ; Ignore warnings about passlib's use of the crypt module (Python 3.11).
58 | ignore:'crypt' is deprecated and slated for removal:DeprecationWarning
59 | ; Ignore Python 3.12 UTC deprecation warnings caused by pendulum.
60 | ignore:datetime.datetime.utcfromtimestamp:DeprecationWarning
61 |
62 |
63 | [testenv:update-{requirements, pre-commit}]
64 | base_python = py3.13
65 | recreate = true
66 | description = Update tool dependency versions
67 | skip_install = true
68 | deps =
69 | requirements: poetry
70 | requirements: poetry-plugin-export
71 | pre-commit: pre-commit
72 | commands =
73 | # Update requirements files
74 | requirements: poetry update --directory="requirements/docs" --lock
75 | requirements: poetry export --directory="requirements/docs" --output="requirements.txt" --without-hashes
76 |
77 | # Update pre-commit hook versions
78 | pre-commit: pre-commit autoupdate
79 |
80 | # Run pre-commit immediately, but ignore its exit code
81 | pre-commit: - pre-commit run -a
82 |
--------------------------------------------------------------------------------
/tests/types/test_arrow.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 | from dateutil import tz
6 |
7 | from sqlalchemy_utils.types import arrow
8 |
9 |
@pytest.fixture
def Article(Base):
    """Model exercising ArrowType with and without timezone support.

    ``published_at_dt`` is a plain timezone-aware DateTime used as a reference
    value by ``test_timezone``.
    """
    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)
        created_at = sa.Column(arrow.ArrowType)
        published_at = sa.Column(arrow.ArrowType(timezone=True))
        published_at_dt = sa.Column(sa.DateTime(timezone=True))
    return Article
19 |
20 |
@pytest.fixture
def init_models(Article):
    """Request ``Article`` so the model is declared for these tests."""
24 |
25 |
@pytest.mark.skipif('arrow.arrow is None')
class TestArrowDateTimeType:
    """ArrowType coercion, round-tripping and SQL compilation."""

    def test_parameter_processing(self, session, Article):
        stored = Article(created_at=arrow.arrow.get(datetime(2000, 11, 1)))
        session.add(stored)
        session.commit()

        loaded = session.query(Article).first()
        assert loaded.created_at.datetime

    def test_string_coercion(self, Article):
        assert Article(created_at='2013-01-01').created_at.year == 2013

    def test_utc(self, session, Article):
        now = arrow.arrow.utcnow()
        article = Article(created_at=now)
        session.add(article)
        assert article.created_at == now
        session.commit()
        assert article.created_at == now

    def test_other_tz(self, session, Article):
        now = arrow.arrow.utcnow()
        pacific = now.to('US/Pacific')
        article = Article(created_at=pacific)
        session.add(article)
        assert article.created_at == now == pacific
        session.commit()
        assert article.created_at == now

    def test_literal_param(self, session, Article):
        clause = Article.created_at > datetime(2015, 1, 1)
        sql = str(clause.compile(compile_kwargs={'literal_binds': True}))
        assert sql == "article.created_at > '2015-01-01 00:00:00'"

    @pytest.mark.usefixtures('postgresql_dsn')
    def test_timezone(self, session, Article):
        stockholm = tz.gettz('Europe/Stockholm')
        moment = arrow.arrow.get(datetime(2015, 1, 1, 15, 30, 45), stockholm)
        session.add(Article(published_at=moment, published_at_dt=moment.datetime))
        session.commit()
        session.expunge_all()

        row = session.query(Article).one()
        assert row.published_at.datetime == row.published_at_dt
        assert row.published_at.to(stockholm) == moment

    def test_compilation(self, Article, session):
        # the type should be cacheable and not throw exception
        session.execute(sa.select(Article.created_at))
85 |
--------------------------------------------------------------------------------
/tests/types/test_scalar_list.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | import sqlalchemy.exc
4 |
5 | from sqlalchemy_utils import ScalarListType
6 |
7 |
class TestScalarIntegerList:
    """ScalarListType(int): integers round-trip through the text column."""

    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            some_list = sa.Column(ScalarListType(int))

            def __repr__(self):
                return f'User({self.id!r})'

        return User

    @pytest.fixture
    def init_models(self, User):
        """Depend on ``User`` so the model is declared for these tests."""

    def test_save_integer_list(self, session, User):
        stored = User(some_list=[1, 2, 3, 4])
        session.add(stored)
        session.commit()

        loaded = session.query(User).first()
        assert loaded.some_list == [1, 2, 3, 4]
36 |
37 |
class TestScalarUnicodeList:
    """ScalarListType(str): round-trips, empty lists and separator checks."""

    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            some_list = sa.Column(ScalarListType(str))

            def __repr__(self):
                return f'User({self.id!r})'

        return User

    @pytest.fixture
    def init_models(self, User):
        """Depend on ``User`` so the model is declared for these tests."""

    def test_throws_exception_if_using_separator_in_list_values(
        self,
        session,
        User
    ):
        session.add(User(some_list=[',']))

        with pytest.raises(sa.exc.StatementError) as excinfo:
            session.commit()

        expected_message = (
            "List values can't contain string ',' (its being used as "
            "separator. If you wish for scalar list values to contain "
            "these strings, use a different separator string.)"
        )
        assert expected_message in str(excinfo.value)

    def test_save_unicode_list(self, session, User):
        stored = User(some_list=['1', '2', '3', '4'])
        session.add(stored)
        session.commit()

        loaded = session.query(User).first()
        assert loaded.some_list == ['1', '2', '3', '4']

    def test_save_and_retrieve_empty_list(self, session, User):
        stored = User(some_list=[])
        session.add(stored)
        session.commit()

        loaded = session.query(User).first()
        assert loaded.some_list == []

    def test_compilation(self, User, session):
        # the type should be cacheable and not throw exception
        session.execute(sa.select(User.some_list))
100 |
--------------------------------------------------------------------------------
/tests/types/test_enriched_datetime_pendulum.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 |
6 | from sqlalchemy_utils.types.enriched_datetime import (
7 | enriched_datetime_type,
8 | pendulum_datetime
9 | )
10 |
11 |
@pytest.fixture
def User(Base):
    """Model whose ``created_at`` uses EnrichedDateTimeType with pendulum."""
    class User(Base):
        __tablename__ = 'users'
        id = sa.Column(sa.Integer, primary_key=True)
        created_at = sa.Column(
            enriched_datetime_type.EnrichedDateTimeType(
                datetime_processor=pendulum_datetime.PendulumDateTime,
            ))
    return User
22 |
23 |
@pytest.fixture
def init_models(User):
    """Request ``User`` so the model is declared for these tests."""
27 |
28 |
@pytest.mark.skipif('pendulum_datetime.pendulum is None')
class TestPendulumDateTimeType:
    """EnrichedDateTimeType backed by pendulum: coercion and round-trips."""

    @staticmethod
    def _assert_utc_roundtrip(session, User):
        """Store a UTC pendulum timestamp and check it survives add + commit."""
        moment = pendulum_datetime.pendulum.now('UTC')
        user = User(created_at=moment)
        session.add(user)
        assert user.created_at == moment
        session.commit()
        assert user.created_at == moment

    def test_parameter_processing(self, session, User):
        stored = User(created_at=pendulum_datetime.pendulum.datetime(1995, 7, 11))
        session.add(stored)
        session.commit()

        loaded = session.query(User).first()
        assert isinstance(loaded.created_at, datetime)

    def test_int_coercion(self, User):
        assert User(created_at=1367900664).created_at.year == 2013

    def test_float_coercion(self, User):
        assert User(created_at=1367900664.0).created_at.year == 2013

    def test_string_coercion(self, User):
        assert User(created_at='1367900664').created_at.year == 2013

    def test_utc(self, session, User):
        self._assert_utc_roundtrip(session, User)

    @pytest.mark.usefixtures('postgresql_dsn')
    def test_utc_postgres(self, session, User):
        self._assert_utc_roundtrip(session, User)

    def test_other_tz(self, session, User):
        utc_now = pendulum_datetime.pendulum.now('UTC')
        tokyo = utc_now.in_tz('Asia/Tokyo')
        user = User(created_at=tokyo)
        session.add(user)
        assert user.created_at == utc_now == tokyo
        session.commit()
        assert user.created_at == utc_now

    def test_literal_param(self, session, User):
        clause = User.created_at > datetime(2015, 1, 1)
        sql = str(clause.compile(compile_kwargs={'literal_binds': True}))
        assert sql == "users.created_at > '2015-01-01 00:00:00'"
91 |
--------------------------------------------------------------------------------
/tests/aggregate/test_m2m_m2m.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | import sqlalchemy.orm
4 |
5 | from sqlalchemy_utils import aggregated
6 |
7 |
@pytest.fixture
def Category(Base):
    """Category model; linked to products through ``category_product``."""
    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
    return Category
15 |
16 |
@pytest.fixture
def Catalog(Base, Category):
    """Catalog model with an aggregated ``category_count`` column.

    The aggregate follows ``products.categories``. ``products`` is the backref
    declared by the ``Product`` fixture. It counts *distinct* category ids, so
    a category reached through several products counts once.
    """
    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated(
            'products.categories',
            sa.Column(sa.Integer, default=0)
        )
        def category_count(self):
            return sa.func.count(sa.distinct(Category.id))
    return Catalog
31 |
32 |
@pytest.fixture
def Product(Base, Catalog, Category):
    """Product: many-to-many to both Catalog and Category.

    Both relationships go through association tables and declare
    ``backref='products'``.
    """
    # Association table for the catalog <-> product many-to-many link.
    catalog_products = sa.Table(
        'catalog_product',
        Base.metadata,
        sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')),
        sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
    )

    # Association table for the product <-> category many-to-many link.
    product_categories = sa.Table(
        'category_product',
        Base.metadata,
        sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')),
        sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id'))
    )

    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        # NOTE(review): not used by the ``catalogs`` relationship, which goes
        # through ``catalog_product``; looks like a leftover. Confirm before
        # removing.
        catalog_id = sa.Column(
            sa.Integer, sa.ForeignKey('catalog.id')
        )

        catalogs = sa.orm.relationship(
            Catalog,
            backref='products',
            secondary=catalog_products
        )

        categories = sa.orm.relationship(
            Category,
            backref='products',
            secondary=product_categories
        )
    return Product
71 |
72 |
@pytest.fixture
def init_models(Category, Catalog, Product):
    """Request every model so all tables and relationships are declared."""
76 |
77 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestAggregateManyToManyAndManyToMany:

    def test_insert(self, session, Product, Category, Catalog):
        """Two catalogs sharing the same products each count one category.

        Both products share one category and the aggregate counts distinct
        category ids, so each catalog's ``category_count`` must be 1.
        """
        category = Category()
        session.add(category)
        # add_all instead of a list comprehension used only for side effects.
        products = [Product(categories=[category]) for _ in range(2)]
        session.add_all(products)
        catalog = Catalog(products=products)
        session.add(catalog)
        catalog2 = Catalog(products=products)
        session.add(catalog2)
        session.commit()
        assert catalog.category_count == 1
        assert catalog2.category_count == 1
96 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/scalar_list.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 | from sqlalchemy import types
3 |
4 |
class ScalarListException(Exception):
    """Raised by ScalarListType when a list item contains the separator."""
7 |
8 |
class ScalarListType(types.TypeDecorator):
    """
    ScalarListType provides a convenient way to save multiple scalar values
    in one column. It works like a list on the Python side and stores the
    values as a comma-separated string in the database (a custom separator
    string can also be used).

    Example ::


        from sqlalchemy_utils import ScalarListType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            hobbies = sa.Column(ScalarListType())


        user = User()
        user.hobbies = ['football', 'ice_hockey']
        session.add(user)
        session.commit()


    You can easily set up integer lists too:

    ::


        from sqlalchemy_utils import ScalarListType


        class Player(Base):
            __tablename__ = 'player'
            id = sa.Column(sa.Integer, primary_key=True)
            points = sa.Column(ScalarListType(int))


        player = Player()
        player.points = [11, 12, 8, 80]
        session.add(player)
        session.commit()


    Items are converted with ``str()`` when saved and with ``coerce_func``
    when loaded. No item may contain the separator string; saving such a list
    raises :class:`ScalarListException`. An empty list is stored as an empty
    string.

    ScalarListType is always stored as text. To use an array field on
    PostgreSQL database use variant construct::

        from sqlalchemy.dialects.postgresql import ARRAY
        from sqlalchemy_utils import ScalarListType


        class Player(Base):
            __tablename__ = 'player'
            id = sa.Column(sa.Integer, primary_key=True)
            points = sa.Column(
                ARRAY(sa.Integer).with_variant(ScalarListType(int), 'sqlite')
            )


    """

    impl = sa.UnicodeText()

    cache_ok = True

    def __init__(self, coerce_func=str, separator=','):
        """
        :param coerce_func: callable applied to every item when values are
            loaded from the database (default ``str``).
        :param separator: string placed between items in the stored text;
            it is converted with ``str()``.
        """
        # Let TypeDecorator set up ``self.impl`` as it does for any subclass.
        super().__init__()
        self.separator = str(separator)
        self.coerce_func = coerce_func

    def process_bind_param(self, value, dialect):
        """Serialize a list into a single separator-joined string.

        Example (default separator): ``[1, 2, 3, 4]`` -> ``'1,2,3,4'``.
        ``None`` is passed through unchanged.

        :raises ScalarListException: if any item contains the separator.
        """
        if value is None:
            return None
        # Stringify once so one-shot iterables (e.g. generators) are not
        # exhausted by the separator check before being joined.
        items = [str(item) for item in value]
        if any(self.separator in item for item in items):
            raise ScalarListException(
                "List values can't contain string '%s' (its being used as "
                'separator. If you wish for scalar list values to contain '
                'these strings, use a different separator string.)' % self.separator
            )
        return self.separator.join(items)

    def process_result_value(self, value, dialect):
        """Split the stored string back into a list, applying ``coerce_func``.

        ``None`` stays ``None``; an empty string becomes ``[]``.
        """
        if value is None:
            return None
        if value == '':
            return []
        return [self.coerce_func(item) for item in value.split(self.separator)]
94 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 |
3 | import pytest
4 | import sqlalchemy as sa
5 | import sqlalchemy.orm
6 |
7 | from sqlalchemy_utils import generic_repr, Timestamp
8 |
9 |
class TestTimestamp:
    """The Timestamp mixin fills ``created`` and ``updated`` automatically."""

    @pytest.fixture
    def Article(self, Base):
        class Article(Base, Timestamp):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255), default='Some article')
        return Article

    @staticmethod
    def _naive_utcnow():
        """Current UTC time with tzinfo stripped, for comparing timestamps."""
        return datetime.now(tz=timezone.utc).replace(tzinfo=None)

    def test_created(self, session, Article):
        lower_bound = self._naive_utcnow()
        article = Article()

        session.add(article)
        session.commit()

        assert lower_bound <= article.created <= self._naive_utcnow()

    def test_updated(self, session, Article):
        article = Article()

        session.add(article)
        session.commit()

        lower_bound = self._naive_utcnow()
        article.name = 'Something'

        session.commit()

        assert lower_bound <= article.updated <= self._naive_utcnow()
42 |
43 |
class TestGenericRepr:
    """generic_repr renders mapped columns as ``Class(col=value, ...)``."""

    @pytest.fixture
    def Article(self, Base):
        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255), default='Some article')
        return Article

    def test_repr(self, Article):
        """Representation of a basic model."""
        Article = generic_repr(Article)
        article = Article(id=1, name='Foo')
        expected_repr = "Article(id=1, name='Foo')"
        actual_repr = repr(article)

        assert actual_repr == expected_repr

    def test_repr_partial(self, Article):
        """Representation of a basic model with selected fields."""
        Article = generic_repr('id')(Article)
        article = Article(id=1, name='Foo')
        expected_repr = 'Article(id=1)'
        actual_repr = repr(article)

        assert actual_repr == expected_repr

    def test_not_loaded(self, session, Article):
        """:py:func:`~sqlalchemy_utils.models.generic_repr` doesn't force
        execution of additional queries if some fields are not loaded and
        instead represents them as "<not loaded>".
        """
        Article = generic_repr(Article)
        article = Article(name='Foo')
        session.add(article)
        session.commit()

        # ``name`` is deferred, so it is unloaded when repr() runs.
        article = session.query(Article).options(sa.orm.defer(Article.name)).one()
        actual_repr = repr(article)

        expected_repr = f'Article(id={article.id}, name=<not loaded>)'
        assert actual_repr == expected_repr
86 |
--------------------------------------------------------------------------------
/tests/generic_relationship/__init__.py:
--------------------------------------------------------------------------------
class GenericRelationshipTestCase:
    """Shared tests for a generic relationship exposed as ``Event.object``.

    The ``User`` and ``Event`` fixtures are not defined here; presumably they
    are supplied by the test modules that reuse this class (verify per module).
    """

    @staticmethod
    def _persist(session, *instances):
        # Add the given instances to the session, then commit.
        session.add_all(instances)
        session.commit()

    def test_set_as_none(self, Event):
        event = Event()
        event.object = None
        assert event.object is None

    def test_set_manual_and_get(self, session, User, Event):
        user = User()
        self._persist(session, user)

        event = Event()
        event.object_id = user.id
        event.object_type = type(user).__name__

        # Before the event is persisted, the relationship reads as None.
        assert event.object is None

        self._persist(session, event)

        assert event.object == user

    def test_set_and_get(self, session, User, Event):
        user = User()
        self._persist(session, user)

        event = Event(object=user)

        # Assigning the object populates both discriminator columns.
        assert event.object_id == user.id
        assert event.object_type == type(user).__name__

        self._persist(session, event)

        assert event.object == user

    def test_compare_instance(self, session, User, Event):
        first, second = User(), User()
        self._persist(session, first, second)

        event = Event(object=first)
        self._persist(session, event)

        assert event.object == first
        assert event.object != second

    def test_compare_query(self, session, User, Event):
        first, second = User(), User()
        self._persist(session, first, second)

        event = Event(object=first)
        self._persist(session, event)

        query = session.query(Event)
        assert query.filter_by(object=first).first() is not None
        assert query.filter_by(object=second).first() is None
        assert query.filter(Event.object == second).first() is None

    def test_compare_not_query(self, session, User, Event):
        first, second = User(), User()
        self._persist(session, first, second)

        event = Event(object=first)
        self._persist(session, event)

        query = session.query(Event)
        assert query.filter(Event.object != second).first() is not None

    def test_compare_type(self, session, User, Event):
        first, second = User(), User()
        self._persist(session, first, second)

        events = [Event(object=first), Event(object=second)]
        self._persist(session, *events)

        query = session.query(Event).filter(Event.object.is_type(User))
        assert query.first() is not None
103 |
--------------------------------------------------------------------------------
/tests/observes/test_o2o_o2o_o2o.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.observer import observes
5 |
6 |
@pytest.fixture
def Catalog(Base):
    """Root model that copies the price of the product at the end of
    ``category.sub_category.product`` into ``product_price``."""
    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        product_price = sa.Column(sa.Integer)

        @observes('category.sub_category.product')
        def product_observer(self, product):
            # No product at the end of the path means no price.
            if not product:
                self.product_price = None
            else:
                self.product_price = product.price

        category = sa.orm.relationship(
            'Category',
            backref='catalog',
            uselist=False,
        )

    return Catalog
24 |
25 |
@pytest.fixture
def Category(Base):
    """Second link of the chain: references a catalog, holds one sub-category."""
    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))

        sub_category = sa.orm.relationship(
            'SubCategory',
            backref='category',
            uselist=False,
        )

    return Category
39 |
40 |
@pytest.fixture
def SubCategory(Base):
    """Third link of the chain: references a category, holds one product."""
    class SubCategory(Base):
        __tablename__ = 'sub_category'
        id = sa.Column(sa.Integer, primary_key=True)
        category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
        product = sa.orm.relationship(
            'Product',
            backref='sub_category',
            uselist=False,
        )

    return SubCategory
53 |
54 |
@pytest.fixture
def Product(Base):
    """Leaf model of the chain; carries the observed ``price``."""
    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        price = sa.Column(sa.Integer)

        sub_category_id = sa.Column(sa.Integer, sa.ForeignKey('sub_category.id'))

    return Product
66 |
67 |
@pytest.fixture
def init_models(Catalog, Category, SubCategory, Product):
    """Request every model fixture so that all four classes get declared."""
71 |
72 |
@pytest.fixture
def catalog(session, Catalog, Category, SubCategory, Product):
    """A flushed catalog whose nested product is priced at 123."""
    root = Catalog(
        category=Category(
            sub_category=SubCategory(product=Product(price=123))
        )
    )
    session.add(root)
    session.flush()
    return root
81 |
82 |
@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToOneToOneToOne:
    """The observer fires across three chained one-to-one relationships."""

    def test_simple_insert(self, catalog):
        assert catalog.product_price == 123

    def test_replace_leaf_object(self, catalog, session, Product):
        replacement = Product(price=44)
        catalog.category.sub_category.product = replacement
        session.flush()
        assert catalog.product_price == 44

    def test_delete_leaf_object(self, catalog, session):
        leaf = catalog.category.sub_category.product
        session.delete(leaf)
        session.flush()
        assert catalog.product_price is None
99 |
--------------------------------------------------------------------------------
/tests/aggregate/test_multiple_aggregates_per_class.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils.aggregates import aggregated
5 |
6 |
@pytest.fixture
def Comment(Base):
    """Child model; each comment references its thread through ``thread_id``."""
    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        content = sa.Column(sa.Unicode(255))
        thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id'))

    return Comment
15 |
16 |
@pytest.fixture
def Thread(Base, Comment):
    """Parent model with two aggregates over ``comments``: the number of
    comments and the highest comment id, plus a view-only ``last_comment``."""
    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('comments', sa.Column(sa.Integer, default=0))
        def comment_count(self):
            return sa.func.count('1')

        @aggregated('comments', sa.Column(sa.Integer))
        def last_comment_id(self):
            return sa.func.max(Comment.id)

        comments = sa.orm.relationship('Comment', backref='thread')

    # Attached after the class body because it needs ``Thread.last_comment_id``
    # as the explicit foreign key of the join.
    last_comment = sa.orm.relationship(
        'Comment',
        primaryjoin='Thread.last_comment_id == Comment.id',
        foreign_keys=[Thread.last_comment_id],
        viewonly=True,
    )
    Thread.last_comment = last_comment
    return Thread
47 |
48 |
@pytest.fixture
def init_models(Comment, Thread):
    """Request both model fixtures so that ``Comment`` and ``Thread`` exist."""
52 |
53 |
class TestAggregateValueGenerationForSimpleModelPaths:
    """Both aggregates on ``Thread`` track inserts and deletes of comments."""

    def test_assigns_aggregates_on_insert(self, session, Thread, Comment):
        thread = Thread()
        thread.name = 'some article name'
        session.add(thread)
        comment = Comment(content='Some content', thread=thread)
        session.add(comment)
        session.commit()
        session.refresh(thread)
        assert thread.comment_count == 1
        assert thread.last_comment_id == comment.id

    def test_assigns_aggregates_on_separate_insert(
        self,
        session,
        Thread,
        Comment
    ):
        thread = Thread()
        thread.name = 'some article name'
        session.add(thread)
        session.commit()
        comment = Comment(content='Some content', thread=thread)
        session.add(comment)
        session.commit()
        session.refresh(thread)
        assert thread.comment_count == 1
        # Compare with the generated key instead of assuming the id sequence
        # starts at 1, which depends on database and fixture state.
        assert thread.last_comment_id == comment.id

    def test_assigns_aggregates_on_delete(self, session, Thread, Comment):
        thread = Thread()
        thread.name = 'some article name'
        session.add(thread)
        session.commit()
        comment = Comment(content='Some content', thread=thread)
        session.add(comment)
        session.commit()
        session.delete(comment)
        session.commit()
        session.refresh(thread)
        assert thread.comment_count == 0
        assert thread.last_comment_id is None
97 |
--------------------------------------------------------------------------------
/tests/primitives/test_country.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | import pytest
4 |
5 | from sqlalchemy_utils import Country, i18n
6 |
7 |
@pytest.fixture
def set_get_locale(monkeypatch):
    """Make ``i18n.get_locale`` return the English babel locale for one test.

    ``monkeypatch`` restores the original getter when the test finishes, so
    the override does not leak into tests that run afterwards.
    """
    monkeypatch.setattr(i18n, 'get_locale', lambda: i18n.babel.Locale('en'))
11 |
12 |
@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCountry:
    """Construction, validation, comparison and rendering of ``Country``."""

    def test_init(self):
        wrapped = Country(Country('FI'))
        assert Country('FI') == wrapped

    def test_constructor_with_wrong_type(self):
        with pytest.raises(TypeError) as exc_info:
            Country(None)
        message = str(exc_info.value)
        assert message == (
            "Country() argument must be a string or a country, not 'NoneType'"
        )

    def test_constructor_with_invalid_code(self):
        with pytest.raises(ValueError) as exc_info:
            Country('SomeUnknownCode')
        message = str(exc_info.value)
        assert message == (
            'Could not convert string to country code: SomeUnknownCode'
        )

    @pytest.mark.parametrize('code', ('FI', 'US'))
    def test_validate_with_valid_codes(self, code):
        # Passing means validate() did not raise.
        Country.validate(code)

    def test_validate_with_invalid_code(self):
        with pytest.raises(ValueError) as exc_info:
            Country.validate('SomeUnknownCode')
        message = str(exc_info.value)
        assert message == (
            'Could not convert string to country code: SomeUnknownCode'
        )

    def test_equality_operator(self):
        finland = Country('FI')
        assert finland == 'FI'
        assert 'FI' == finland
        assert finland == Country('FI')

    def test_non_equality_operator(self):
        finland = Country('FI')
        assert finland != 'sv'
        assert not (finland != 'FI')

    @pytest.mark.parametrize(
        'op, code_left, code_right, is_',
        [
            (operator.lt, 'ES', 'FI', True),
            (operator.lt, 'FI', 'ES', False),
            (operator.lt, 'ES', 'ES', False),

            (operator.le, 'ES', 'FI', True),
            (operator.le, 'FI', 'ES', False),
            (operator.le, 'ES', 'ES', True),

            (operator.ge, 'ES', 'FI', False),
            (operator.ge, 'FI', 'ES', True),
            (operator.ge, 'ES', 'ES', True),

            (operator.gt, 'ES', 'FI', False),
            (operator.gt, 'FI', 'ES', True),
            (operator.gt, 'ES', 'ES', False),
        ]
    )
    def test_ordering(self, op, code_left, code_right, is_):
        country_left = Country(code_left)
        country_right = Country(code_right)
        # Country vs Country, Country vs str, and str vs Country (reflected).
        operand_pairs = (
            (country_left, country_right),
            (country_left, code_right),
            (code_left, country_right),
        )
        for left, right in operand_pairs:
            assert op(left, right) is is_

    def test_hash(self):
        assert hash(Country('FI')) == hash('FI')

    def test_repr(self):
        assert repr(Country('FI')) == "Country('FI')"

    def test_str(self):
        assert str(Country('FI')) == 'Finland'
96 |
--------------------------------------------------------------------------------
/tests/test_query_chain.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 |
4 | from sqlalchemy_utils import QueryChain
5 |
6 |
@pytest.fixture
def User(Base):
    """Minimal mapped model holding only a primary key."""
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)

    return User
13 |
14 |
@pytest.fixture
def Article(Base):
    """Minimal mapped model holding only a primary key."""
    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)

    return Article
21 |
22 |
@pytest.fixture
def BlogPost(Base):
    """Minimal mapped model holding only a primary key."""
    class BlogPost(Base):
        __tablename__ = 'blog_post'
        id = sa.Column(sa.Integer, primary_key=True)

    return BlogPost
29 |
30 |
@pytest.fixture
def init_models(User, Article, BlogPost):
    """Request the three model fixtures so that each class gets declared."""
34 |
35 |
@pytest.fixture
def users(session, User):
    """Two committed ``User`` rows."""
    created = [User() for _ in range(2)]
    session.add_all(created)
    session.commit()
    return created
42 |
43 |
@pytest.fixture
def articles(session, Article):
    """Four committed ``Article`` rows."""
    created = [Article() for _ in range(4)]
    session.add_all(created)
    session.commit()
    return created
50 |
51 |
@pytest.fixture
def posts(session, BlogPost):
    """Three committed ``BlogPost`` rows."""
    created = [BlogPost() for _ in range(3)]
    session.add_all(created)
    session.commit()
    return created
58 |
59 |
@pytest.fixture
def chain(session, users, articles, posts, User, Article, BlogPost):
    """A QueryChain over users, then articles, then blog posts, each by id."""
    ordered_queries = [
        session.query(model).order_by('id')
        for model in (User, Article, BlogPost)
    ]
    return QueryChain(ordered_queries)
69 |
70 |
class TestQueryChain:
    """Iteration, slicing, limit/offset, repr and count of a QueryChain."""

    def test_iter(self, chain):
        assert len(list(chain)) == 9

    def test_iter_with_limit(self, chain, users, articles):
        c = chain.limit(4)
        objects = list(c)
        assert users == objects[0:2]
        assert articles[0:2] == objects[2:]

    def test_iter_with_offset(self, chain, articles, posts):
        c = chain.offset(3)
        objects = list(c)
        assert articles[1:] + posts == objects

    def test_iter_with_limit_and_offset(self, chain, articles, posts):
        c = chain.offset(3).limit(4)
        objects = list(c)
        assert articles[1:] + posts[0:1] == objects

    def test_iter_with_offset_spanning_multiple_queries(self, chain, posts):
        c = chain.offset(7)
        objects = list(c)
        assert posts[1:] == objects

    def test_repr(self, chain):
        # The repr embeds the object's id in hexadecimal. An empty format
        # string here would raise TypeError instead of comparing.
        assert repr(chain) == '<QueryChain at 0x%x>' % id(chain)

    def test_getitem_with_slice(self, chain):
        c = chain[1:]
        assert c._offset == 1
        assert c._limit is None

    def test_getitem_with_single_key(self, chain, articles):
        article = chain[2]
        assert article == articles[0]

    def test_count(self, chain):
        assert chain.count() == 9
111 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/types/timezone.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import types
2 |
3 | from ..exceptions import ImproperlyConfigured
4 | from .scalar_coercible import ScalarCoercible
5 |
6 |
class TimezoneType(ScalarCoercible, types.TypeDecorator):
    """
    TimezoneType provides a way for saving timezone objects into the database.
    TimezoneType saves timezone objects as strings on the way in and converts
    them back to objects when querying the database.

    ::

        from sqlalchemy_utils import TimezoneType

        class User(Base):
            __tablename__ = 'user'

            # Pass backend='pytz' to change it to use pytz. Other values:
            # 'dateutil' (default), and 'zoneinfo'.
            timezone = sa.Column(TimezoneType(backend='pytz'))

    :param backend:
        Whether to use 'dateutil', 'pytz' or 'zoneinfo' for timezones.
        'zoneinfo' uses the standard library module (Python 3.9+).

    :raises ImproperlyConfigured:
        If ``backend`` is not a supported name, or the library behind the
        chosen backend cannot be imported.
    """

    impl = types.Unicode(50)
    python_type = None
    cache_ok = True

    def __init__(self, backend='dateutil'):
        self.backend = backend
        # Every branch configures:
        #   python_type -- class of the timezone objects this backend yields
        #   _to         -- converts a stored string into a timezone object
        #   _from       -- converts a timezone object into the stored string
        if backend == 'dateutil':
            try:
                from dateutil.tz import tzfile
                from dateutil.zoneinfo import get_zonefile_instance
            except ImportError as err:
                raise ImproperlyConfigured(
                    "'python-dateutil' is required to use the "
                    "'dateutil' backend for 'TimezoneType'"
                ) from err

            self.python_type = tzfile
            self._to = get_zonefile_instance().zones.get
            self._from = lambda x: str(x._filename)

        elif backend == 'pytz':
            try:
                from pytz import timezone
                from pytz.tzinfo import BaseTzInfo
            except ImportError as err:
                raise ImproperlyConfigured(
                    "'pytz' is required to use the 'pytz' backend for 'TimezoneType'"
                ) from err

            self.python_type = BaseTzInfo
            self._to = timezone
            self._from = str

        elif backend == 'zoneinfo':
            try:
                import zoneinfo
            except ImportError as err:
                # zoneinfo is absent before Python 3.9; report that the same
                # way as a missing dependency for the other backends.
                raise ImproperlyConfigured(
                    "'zoneinfo' (Python 3.9+) is required to use the "
                    "'zoneinfo' backend for 'TimezoneType'"
                ) from err

            self.python_type = zoneinfo.ZoneInfo
            self._to = zoneinfo.ZoneInfo
            self._from = str

        else:
            raise ImproperlyConfigured(
                "'pytz', 'dateutil' or 'zoneinfo' are the backends "
                "supported for 'TimezoneType'"
            )

    def _coerce(self, value):
        """Convert a timezone name into an instance of ``python_type``.

        ``None`` and values that already are ``python_type`` instances are
        returned unchanged. Exceptions raised by the backend lookup itself
        propagate as-is.

        :raises ValueError: if the backend lookup returns ``None``.
        """
        if value is not None and not isinstance(value, self.python_type):
            obj = self._to(value)
            if obj is None:
                # The dateutil lookup (``zones.get``) reports an unknown name
                # by returning None instead of raising.
                raise ValueError("unknown time zone '%s'" % (value,))
            return obj
        return value

    def process_bind_param(self, value, dialect):
        """Serialize a timezone (object or name) to its string; falsy -> None."""
        return self._from(self._coerce(value)) if value else None

    def process_result_value(self, value, dialect):
        """Rebuild a timezone object from the stored string; falsy -> None."""
        return self._to(value) if value else None
91 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/primitives/country.py:
--------------------------------------------------------------------------------
1 | from functools import total_ordering
2 |
3 | from .. import i18n
4 | from ..utils import str_coercible
5 |
6 |
@total_ordering
@str_coercible
class Country:
    """
    Country class wraps a 2 to 3 letter country code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Country, i18n


        # First lets add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')


        Country('FI').name  # Finland
        Country('FI').code  # FI

        Country(Country('FI')).code  # 'FI'

    Country validates the given code when the optional 'babel' dependency
    is installed; otherwise no validation is performed.

    ::

        Country(None)  # raises TypeError

        Country('UnknownCode')  # raises ValueError


    Country supports equality and ordering operators, both against other
    Country objects and against plain code strings (ordering is by code).

    ::

        Country('FI') == Country('FI')
        Country('FI') != Country('US')
        Country('ES') < Country('FI')


    Country objects are hashable.


    ::

        assert hash(Country('FI')) == hash('FI')

    """

    def __init__(self, code_or_country):
        if isinstance(code_or_country, Country):
            self.code = code_or_country.code
        elif isinstance(code_or_country, str):
            self.validate(code_or_country)
            self.code = code_or_country
        else:
            raise TypeError(
                "Country() argument must be a string or a country, not '{}'".format(
                    type(code_or_country).__name__
                )
            )

    @property
    def name(self):
        """Territory name for ``code`` in the locale from ``i18n.get_locale()``."""
        return i18n.get_locale().territories[self.code]

    @classmethod
    def validate(cls, code):
        """Check that ``code`` is a territory known to babel's English locale.

        Does nothing when the optional babel dependency is not installed.

        :raises ValueError: if babel is available and ``code`` is unknown.
        """
        # babel is optional; without it there is nothing to validate against.
        if i18n.babel is None:
            return
        try:
            i18n.babel.Locale('en').territories[code]
        except KeyError:
            raise ValueError(
                f'Could not convert string to country code: {code}'
            ) from None

    def __eq__(self, other):
        if isinstance(other, Country):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __hash__(self):
        # Hash like the bare code so it stays consistent with ``== code``.
        return hash(self.code)

    def __ne__(self, other):
        return not (self == other)

    def __lt__(self, other):
        # total_ordering derives <=, >, >= from this and __eq__.
        if isinstance(other, Country):
            return self.code < other.code
        elif isinstance(other, str):
            return self.code < other
        return NotImplemented

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        return self.name
109 |
--------------------------------------------------------------------------------
/sqlalchemy_utils/primitives/currency.py:
--------------------------------------------------------------------------------
1 | from .. import i18n, ImproperlyConfigured
2 | from ..utils import str_coercible
3 |
4 |
@str_coercible
class Currency:
    """
    Currency class wraps a 3-letter currency code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Currency, i18n


        # First lets add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')


        Currency('USD').name  # US Dollar
        Currency('USD').symbol  # $

        Currency(Currency('USD')).code  # 'USD'

    Currency requires the optional 'babel' dependency: constructing one
    without it raises ImproperlyConfigured. String codes are validated
    against babel's English currency list.

    ::

        Currency(None)  # raises TypeError

        Currency('UnknownCode')  # raises ValueError


    Currency supports equality operators.

    ::

        Currency('USD') == Currency('USD')
        Currency('USD') != Currency('EUR')


    Currencies are hashable.


    ::

        len(set([Currency('USD'), Currency('USD')]))  # 1


    """

    def __init__(self, code):
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use Currency class."
            )
        if isinstance(code, Currency):
            # Copy the wrapped code string, not the Currency object itself,
            # so that ``Currency(Currency('USD')).code == 'USD'``.
            self.code = code.code
        elif isinstance(code, str):
            self.validate(code)
            self.code = code
        else:
            raise TypeError(
                'First argument given to Currency constructor should be '
                'either an instance of Currency or valid three letter '
                'currency code.'
            )

    @classmethod
    def validate(cls, code):
        """Check that ``code`` is a currency known to babel's English locale.

        Does nothing when the optional babel dependency is not installed.

        :raises ValueError: if babel is available and ``code`` is unknown.
        """
        # babel is optional; without it there is nothing to validate against.
        if i18n.babel is None:
            return
        try:
            i18n.babel.Locale('en').currencies[code]
        except KeyError:
            raise ValueError(f"'{code}' is not valid currency code.") from None

    @property
    def symbol(self):
        """Currency symbol for the locale from ``i18n.get_locale()``."""
        return i18n.babel.numbers.get_currency_symbol(self.code, i18n.get_locale())

    @property
    def name(self):
        """Currency name for the locale from ``i18n.get_locale()``."""
        return i18n.get_locale().currencies[self.code]

    def __eq__(self, other):
        if isinstance(other, Currency):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash like the bare code so it stays consistent with ``== code``.
        return hash(self.code)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        return self.code
108 |
--------------------------------------------------------------------------------
/tests/test_auto_delete_orphans.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import sqlalchemy as sa
3 | from sqlalchemy.orm import backref
4 |
5 | from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured
6 |
7 |
@pytest.fixture
def tagging_tbl(Base):
    """Association table between tags and entries (composite primary key)."""
    def _cascading_fk(column_name, target):
        # Integer primary-key column whose foreign key cascades on delete.
        return sa.Column(
            column_name,
            sa.Integer,
            sa.ForeignKey(target, ondelete='cascade'),
            primary_key=True,
        )

    return sa.Table(
        'tagging',
        Base.metadata,
        _cascading_fk('tag_id', 'tag.id'),
        _cascading_fk('entry_id', 'entry.id'),
    )
26 |
27 |
@pytest.fixture
def Tag(Base):
    """Tag model with a unique, non-nullable ``name``."""
    class Tag(Base):
        __tablename__ = 'tag'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(100), unique=True, nullable=False)

        def __init__(self, name=None):
            self.name = name

    return Tag
38 |
39 |
@pytest.fixture(
    params=['entries', backref('entries', lazy='select')],
    ids=['backref_string', 'backref_with_keywords'],
)
def Entry(Base, Tag, tagging_tbl, request):
    """Entry model whose ``tags`` relationship is registered with
    ``auto_delete_orphans``.

    Parametrized over declaring the ``entries`` backref as a plain string
    and as a ``backref(...)`` object.
    """
    tags_backref = request.param

    class Entry(Base):
        __tablename__ = 'entry'

        id = sa.Column(sa.Integer, primary_key=True)

        tags = sa.orm.relationship(
            Tag,
            secondary=tagging_tbl,
            backref=tags_backref,
        )

    auto_delete_orphans(Entry.tags)
    return Entry
57 |
58 |
@pytest.fixture
def EntryWithoutTagsBackref(Base, Tag, tagging_tbl):
    """Entry model whose ``tags`` relationship deliberately has no backref."""
    class EntryWithoutTagsBackref(Base):
        __tablename__ = 'entry'

        id = sa.Column(sa.Integer, primary_key=True)

        tags = sa.orm.relationship(Tag, secondary=tagging_tbl)

    return EntryWithoutTagsBackref
71 |
72 |
class TestAutoDeleteOrphans:
    """Tags that lose their last entry are removed automatically."""

    @pytest.fixture
    def init_models(self, Entry, Tag):
        """Request the ``Entry`` and ``Tag`` fixtures so both get declared."""

    def test_orphan_deletion(self, session, Entry, Tag):
        r1, r2, r3 = Entry(), Entry(), Entry()
        t1, t2, t3, t4 = (Tag(label) for label in ('t1', 't2', 't3', 't4'))

        r1.tags.extend([t1, t2])
        r2.tags.extend([t2, t3])
        r3.tags.extend([t4])
        session.add_all([r1, r2, r3])

        assert session.query(Tag).count() == 4
        # Detach one tag per step and check how many tags remain. t2 survives
        # its first removal because r1 still references it.
        steps = (
            (r2, t2, 4),
            (r1, t2, 3),
            (r1, t1, 2),
        )
        for entry, tag, remaining in steps:
            entry.tags.remove(tag)
            assert session.query(Tag).count() == remaining
102 |
103 |
class TestAutoDeleteOrphansWithoutBackref:
    """``auto_delete_orphans`` refuses a relationship that has no backref."""

    @pytest.fixture
    def init_models(self, EntryWithoutTagsBackref, Tag):
        """Request the backref-less entry model together with ``Tag``."""

    def test_orphan_deletion(self, EntryWithoutTagsBackref):
        expected_error = ImproperlyConfigured
        with pytest.raises(expected_error):
            auto_delete_orphans(EntryWithoutTagsBackref.tags)
113 |
--------------------------------------------------------------------------------