├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── HISTORY.md ├── LICENSE ├── Makefile ├── README.md ├── pydantic_extra_types ├── __init__.py ├── color.py ├── coordinate.py ├── country.py ├── currency_code.py ├── domain.py ├── epoch.py ├── isbn.py ├── language_code.py ├── mac_address.py ├── mongo_object_id.py ├── path.py ├── payment.py ├── pendulum_dt.py ├── phone_numbers.py ├── py.typed ├── routing_number.py ├── s3.py ├── script_code.py ├── semantic_version.py ├── semver.py ├── timezone_name.py └── ulid.py ├── pyproject.toml ├── tests ├── __init__.py ├── test_coordinate.py ├── test_country_code.py ├── test_currency_code.py ├── test_domain.py ├── test_epoch.py ├── test_isbn.py ├── test_json_schema.py ├── test_language_codes.py ├── test_mac_address.py ├── test_mongo_object_id.py ├── test_path.py ├── test_pendulum_dt.py ├── test_phone_numbers.py ├── test_phone_numbers_validator.py ├── test_routing_number.py ├── test_s3.py ├── test_scripts.py ├── test_semantic_version.py ├── test_semver.py ├── test_timezone_names.py ├── test_types_color.py ├── test_types_payment.py └── test_ulid.py └── uv.lock /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | # GitHub Actions 5 | - package-ecosystem: "github-actions" 6 | directory: "/" 7 | schedule: 8 | interval: "monthly" 9 | commit-message: 10 | prefix: ⬆ 11 | # Python 12 | - package-ecosystem: "uv" 13 | directory: "/" 14 | schedule: 15 | interval: "monthly" 16 | groups: 17 | python-packages: 18 | patterns: 19 | - "*" 20 | commit-message: 21 | prefix: ⬆ 22 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '**' 9 | pull_request: {} 10 | 11 | env: 12 | COLUMNS: 150 13 | 14 | 
jobs: 15 | lint: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | fail-fast: false 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - uses: astral-sh/setup-uv@v6 23 | with: 24 | enable-cache: true 25 | 26 | - name: Install dependencies 27 | run: uv sync --python 3.12 --group lint --all-extras 28 | 29 | - uses: pre-commit/action@v3.0.1 30 | with: 31 | extra_args: --all-files --verbose 32 | env: 33 | SKIP: no-commit-to-branch 34 | 35 | test: 36 | name: test py${{ matrix.python-version }} 37 | runs-on: ubuntu-latest 38 | 39 | strategy: 40 | fail-fast: false 41 | matrix: 42 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 43 | 44 | env: 45 | UV_PYTHON: ${{ matrix.python-version }} 46 | 47 | steps: 48 | - uses: actions/checkout@v4 49 | 50 | - uses: astral-sh/setup-uv@v6 51 | with: 52 | enable-cache: true 53 | 54 | - name: Install dependencies 55 | run: uv sync --extra all 56 | 57 | - name: Make coverage directory 58 | run: mkdir coverage 59 | 60 | - run: uv run --frozen coverage run -m pytest 61 | env: 62 | COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }} 63 | 64 | - name: store coverage files 65 | uses: actions/upload-artifact@v4 66 | with: 67 | name: coverage-${{ matrix.python-version }} 68 | path: coverage 69 | include-hidden-files: true 70 | 71 | coverage: 72 | runs-on: ubuntu-latest 73 | needs: [test] 74 | steps: 75 | - uses: actions/checkout@v4 76 | 77 | - uses: astral-sh/setup-uv@v6 78 | with: 79 | enable-cache: true 80 | 81 | - name: get coverage files 82 | uses: actions/download-artifact@v4 83 | with: 84 | merge-multiple: true 85 | path: coverage 86 | 87 | - run: uv run --frozen coverage combine coverage 88 | 89 | - run: uv run --frozen coverage report --fail-under 85 90 | 91 | # https://github.com/marketplace/actions/alls-green#why used for branch protection checks 92 | check: 93 | if: always() 94 | needs: [lint, test, coverage] 95 | runs-on: ubuntu-latest 96 | steps: 97 | - name: Decide whether the needed 
jobs succeeded or failed 98 | uses: re-actors/alls-green@release/v1 99 | with: 100 | jobs: ${{ toJSON(needs) }} 101 | 102 | release: 103 | needs: [check] 104 | if: "success() && startsWith(github.ref, 'refs/tags/')" 105 | runs-on: ubuntu-latest 106 | environment: release 107 | 108 | permissions: 109 | id-token: write 110 | 111 | steps: 112 | - uses: actions/checkout@v4 113 | 114 | - uses: astral-sh/setup-uv@v6 115 | with: 116 | enable-cache: true 117 | 118 | - name: check GITHUB_REF matches package version 119 | uses: samuelcolvin/check-python-version@v4.1 120 | with: 121 | version_file_path: pydantic_extra_types/__init__.py 122 | 123 | - run: uv build 124 | 125 | - name: Publish to PyPI 126 | uses: pypa/gh-action-pypi-publish@release/v1 127 | with: 128 | skip-existing: true 129 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | env/ 3 | venv/ 4 | .venv/ 5 | env3*/ 6 | Pipfile 7 | *.py[cod] 8 | *.egg-info/ 9 | .python-version 10 | /build/ 11 | dist/ 12 | .cache/ 13 | .mypy_cache/ 14 | test.py 15 | .coverage 16 | /htmlcov/ 17 | /site/ 18 | /site.zip 19 | .pytest_cache/ 20 | .vscode/ 21 | _build/ 22 | .auto-format 23 | /sandbox/ 24 | /.ghtopdep_cache/ 25 | /worktrees/ 26 | /.ruff_cache/ 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: no-commit-to-branch # prevent direct commits to the `main` branch 6 | - id: check-yaml 7 | - id: check-toml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | 11 | - repo: local 12 | hooks: 13 | - id: format 14 | name: Format 15 | entry: make 16 | args: [format] 17 | language: system 18 | types: [python] 19 | pass_filenames: false 20 | - 
id: lint 21 | name: Lint 22 | entry: make 23 | args: [lint] 24 | types: [python] 25 | language: system 26 | pass_filenames: false 27 | - id: Typecheck 28 | name: Typecheck 29 | entry: make 30 | args: [typecheck] 31 | types: [python] 32 | language: system 33 | pass_filenames: false 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Samuel Colvin and other contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := all 2 | sources = pydantic_extra_types tests 3 | 4 | .PHONY: .uv # Check that uv is installed 5 | .uv: 6 | @uv --version || echo 'Please install uv: https://docs.astral.sh/uv/getting-started/installation/' 7 | 8 | .PHONY: install ## Install the package, dependencies, and pre-commit for local development 9 | install: .uv 10 | uv sync --frozen --all-groups --all-extras 11 | uv pip install pre-commit 12 | uv run pre-commit install --install-hooks 13 | 14 | .PHONY: rebuild-lockfiles ## Rebuild lockfiles from scratch, updating all dependencies 15 | rebuild-lockfiles: .uv 16 | uv lock --upgrade 17 | 18 | .PHONY: format # Format the code 19 | format: 20 | uv run ruff format 21 | uv run ruff check --fix --fix-only 22 | 23 | .PHONY: lint # Lint the code 24 | lint: 25 | uv run ruff format --check 26 | uv run ruff check 27 | 28 | .PHONY: typecheck # Typecheck the code 29 | typecheck: 30 | uv run mypy pydantic_extra_types 31 | 32 | .PHONY: test 33 | test: 34 | uv run pytest 35 | 36 | .PHONY: test-all-python # Run tests on Python 3.9 to 3.13 37 | test-all-python: 38 | UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 coverage run -p -m pytest 39 | UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 coverage run -p -m pytest 40 | UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 coverage run -p -m pytest 41 | UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 coverage run -p -m pytest 42 | UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 coverage run -p -m pytest 43 | @uv run coverage combine 44 | @uv run coverage report 45 | 46 | .PHONY: testcov # Run tests and collect coverage data 47 | testcov: 48 | uv run coverage run -m pytest 49 | @uv run coverage report 50 | @uv run coverage html 51 | 52 | .PHONY: all 53 | all: format lint testcov 54 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pydantic Extra Types 2 | 3 | [![CI](https://github.com/pydantic/pydantic-extra-types/workflows/CI/badge.svg?event=push)](https://github.com/pydantic/pydantic-extra-types/actions?query=event%3Apush+branch%3Amain+workflow%3ACI) 4 | [![Coverage](https://codecov.io/gh/pydantic/pydantic-extra-types/branch/main/graph/badge.svg)](https://codecov.io/gh/pydantic/pydantic-extra-types) 5 | [![pypi](https://img.shields.io/pypi/v/pydantic-extra-types.svg)](https://pypi.python.org/pypi/pydantic-extra-types) 6 | [![license](https://img.shields.io/github/license/pydantic/pydantic-extra-types.svg)](https://github.com/pydantic/pydantic-extra-types/blob/main/LICENSE) 7 | 8 | A place for pydantic types that probably shouldn't exist in the main pydantic lib. 9 | 10 | See [pydantic/pydantic#5012](https://github.com/pydantic/pydantic/issues/5012) for more info. 11 | 12 | ## Installation 13 | 14 | Install this library with the desired extras dependencies as listed in [project.optional-dependencies](./pyproject.toml). 
"""The `pydantic_extra_types.coordinate` module provides the [`Latitude`][pydantic_extra_types.coordinate.Latitude],
[`Longitude`][pydantic_extra_types.coordinate.Longitude], and
[`Coordinate`][pydantic_extra_types.coordinate.Coordinate] data types.
"""

from __future__ import annotations

from dataclasses import dataclass
from decimal import Decimal
from typing import Any, ClassVar, Tuple, Union

from pydantic import GetCoreSchemaHandler
from pydantic._internal import _repr
from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema

# Public aliases for the numeric representations a coordinate component accepts.
LatitudeType = Union[float, Decimal]
LongitudeType = Union[float, Decimal]
CoordinateType = Tuple[LatitudeType, LongitudeType]


class Latitude(float):
    """Latitude value should be between -90 and 90, inclusive.

    Supports both float and Decimal types.

    ```py
    from decimal import Decimal
    from pydantic import BaseModel
    from pydantic_extra_types.coordinate import Latitude


    class Location(BaseModel):
        latitude: Latitude


    # Using float
    location1 = Location(latitude=41.40338)
    # Using Decimal
    location2 = Location(latitude=Decimal('41.40338'))
    ```
    """

    # Inclusive bounds enforced by the generated schema.
    min: ClassVar[float] = -90.00
    max: ClassVar[float] = 90.00

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Accept either a float or a Decimal, each range-checked against the class bounds."""
        float_variant = core_schema.float_schema(ge=cls.min, le=cls.max)
        decimal_variant = core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max))
        return core_schema.union_schema([float_variant, decimal_variant])


class Longitude(float):
    """Longitude value should be between -180 and 180, inclusive.

    Supports both float and Decimal types.

    ```py
    from decimal import Decimal
    from pydantic import BaseModel

    from pydantic_extra_types.coordinate import Longitude


    class Location(BaseModel):
        longitude: Longitude


    # Using float
    location1 = Location(longitude=2.17403)
    # Using Decimal
    location2 = Location(longitude=Decimal('2.17403'))
    ```
    """

    # Inclusive bounds enforced by the generated schema.
    min: ClassVar[float] = -180.00
    max: ClassVar[float] = 180.00

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Accept either a float or a Decimal, each range-checked against the class bounds."""
        float_variant = core_schema.float_schema(ge=cls.min, le=cls.max)
        decimal_variant = core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max))
        return core_schema.union_schema([float_variant, decimal_variant])


@dataclass
class Coordinate(_repr.Representation):
    """Coordinate parses Latitude and Longitude.

    You can use the `Coordinate` data type for storing coordinates. Coordinates can be
    defined using one of the following formats:

    1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)` or `(Decimal('41.40338'), Decimal('2.17403'))`.
    2. `Coordinate` instance: `Coordinate(latitude=Latitude, longitude=Longitude)`.

    ```py
    from decimal import Decimal
    from pydantic import BaseModel

    from pydantic_extra_types.coordinate import Coordinate


    class Location(BaseModel):
        coordinate: Coordinate


    # Using float values
    location1 = Location(coordinate=(41.40338, 2.17403))
    # > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)

    # Using Decimal values
    location2 = Location(coordinate=(Decimal('41.40338'), Decimal('2.17403')))
    # > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)
    ```
    """

    # Default coordinate used when the type is constructed with no arguments.
    _NULL_ISLAND: ClassVar[Tuple[float, float]] = (0.0, 0.0)

    latitude: Latitude
    longitude: Longitude

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Build a schema that accepts a `Coordinate`, a `(lat, lon)` tuple, or a `'lat,lon'` string."""
        steps = [
            core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()),
            core_schema.no_info_wrap_validator_function(
                cls._parse_tuple,
                handler.generate_schema(CoordinateType),
            ),
            handler(source),
        ]

        # Union members are suffixes of the pipeline, shortest first: an
        # already-built instance needs only the last step, a tuple needs the
        # last two, and a string runs the whole chain.
        variants = [core_schema.chain_schema(steps[start:]) for start in reversed(range(len(steps)))]
        return core_schema.no_info_wrap_validator_function(
            cls._parse_args,
            core_schema.union_schema(variants),  # type: ignore[arg-type]
        )

    @classmethod
    def _parse_args(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any:
        """Normalize positional-argument construction before delegating to the union schema."""
        if isinstance(value, ArgsKwargs) and not value.kwargs:
            positional = value.args
            if len(positional) == 0:
                # No arguments at all: fall back to (0.0, 0.0).
                value = cls._NULL_ISLAND
            elif len(positional) == 1:
                # A single argument is the coordinate itself (tuple/str/instance).
                value = positional[0]
        return handler(value)

    @classmethod
    def _parse_str(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any:
        """Parse a `'lat,lon'` string into positional constructor arguments."""
        if not isinstance(value, str):
            return value
        try:
            parts = tuple(float(piece) for piece in value.split(','))
        except ValueError as exc:
            raise PydanticCustomError(
                'coordinate_error',
                'value is not a valid coordinate: string is not recognized as a valid coordinate',
            ) from exc
        return ArgsKwargs(args=parts)

    @classmethod
    def _parse_tuple(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any:
        """Validate a `(lat, lon)` tuple and wrap it as positional constructor arguments."""
        if not isinstance(value, tuple):
            return value
        return ArgsKwargs(args=handler(value))

    def __str__(self) -> str:
        return f'{self.latitude},{self.longitude}'

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Coordinate):
            return False
        return (self.latitude, self.longitude) == (other.latitude, other.longitude)

    def __hash__(self) -> int:
        return hash((self.latitude, self.longitude))
"""Country definitions that are based on the [ISO 3166](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes)."""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Any

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema

try:
    import pycountry
except ModuleNotFoundError as e:  # pragma: no cover
    raise RuntimeError(
        'The `country` module requires "pycountry" to be installed. You can install it with "pip install pycountry".'
    ) from e


@dataclass
class CountryInfo:
    """A single country record: all ISO 3166-1 representations plus its short name."""

    alpha2: str
    alpha3: str
    numeric_code: str
    short_name: str


@lru_cache
def _countries() -> list[CountryInfo]:
    """Build the full country list from the pycountry database (computed once, then cached)."""
    return [
        CountryInfo(
            alpha2=country.alpha_2,
            alpha3=country.alpha_3,
            numeric_code=country.numeric,
            short_name=country.name,
        )
        for country in pycountry.countries
    ]


@lru_cache
def _index_by_alpha2() -> dict[str, CountryInfo]:
    """Index countries by ISO 3166-1 alpha-2 code (cached)."""
    return {country.alpha2: country for country in _countries()}


@lru_cache
def _index_by_alpha3() -> dict[str, CountryInfo]:
    """Index countries by ISO 3166-1 alpha-3 code (cached)."""
    return {country.alpha3: country for country in _countries()}


@lru_cache
def _index_by_numeric_code() -> dict[str, CountryInfo]:
    """Index countries by ISO 3166-1 numeric code (cached)."""
    return {country.numeric_code: country for country in _countries()}


@lru_cache
def _index_by_short_name() -> dict[str, CountryInfo]:
    """Index countries by short name (cached)."""
    return {country.short_name: country for country in _countries()}


class CountryAlpha2(str):
    """CountryAlpha2 parses country codes in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2)
    format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.country import CountryAlpha2


    class Product(BaseModel):
        made_in: CountryAlpha2


    product = Product(made_in='ES')
    print(product)
    # > made_in='ES'
    ```
    """

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryAlpha2:
        """Return the validated alpha-2 code, or raise `PydanticCustomError` if unknown."""
        if __input_value not in _index_by_alpha2():
            raise PydanticCustomError('country_alpha2', 'Invalid country alpha2 code')
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(to_upper=True),
            # Serialize as a plain string, consistent with CountryAlpha3,
            # CountryNumericCode and CountryShortName (previously omitted here).
            serialization=core_schema.to_string_ser_schema(),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(schema)
        json_schema.update({'pattern': r'^\w{2}$'})
        return json_schema

    @property
    def alpha3(self) -> str:
        """The country code in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) format."""
        return _index_by_alpha2()[self].alpha3

    @property
    def numeric_code(self) -> str:
        """The country code in the [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format."""
        return _index_by_alpha2()[self].numeric_code

    @property
    def short_name(self) -> str:
        """The country short name."""
        return _index_by_alpha2()[self].short_name


class CountryAlpha3(str):
    """CountryAlpha3 parses country codes in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3)
    format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.country import CountryAlpha3


    class Product(BaseModel):
        made_in: CountryAlpha3


    product = Product(made_in='USA')
    print(product)
    # > made_in='USA'
    ```
    """

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryAlpha3:
        """Return the validated alpha-3 code, or raise `PydanticCustomError` if unknown."""
        if __input_value not in _index_by_alpha3():
            raise PydanticCustomError('country_alpha3', 'Invalid country alpha3 code')
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(to_upper=True),
            serialization=core_schema.to_string_ser_schema(),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(schema)
        json_schema.update({'pattern': r'^\w{3}$'})
        return json_schema

    @property
    def alpha2(self) -> str:
        """The country code in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) format."""
        return _index_by_alpha3()[self].alpha2

    @property
    def numeric_code(self) -> str:
        """The country code in the [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format."""
        return _index_by_alpha3()[self].numeric_code

    @property
    def short_name(self) -> str:
        """The country short name."""
        return _index_by_alpha3()[self].short_name


class CountryNumericCode(str):
    """CountryNumericCode parses country codes in the
    [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.country import CountryNumericCode


    class Product(BaseModel):
        made_in: CountryNumericCode


    product = Product(made_in='840')
    print(product)
    # > made_in='840'
    ```
    """

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryNumericCode:
        """Return the validated numeric code, or raise `PydanticCustomError` if unknown."""
        if __input_value not in _index_by_numeric_code():
            raise PydanticCustomError('country_numeric_code', 'Invalid country numeric code')
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(to_upper=True),
            serialization=core_schema.to_string_ser_schema(),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(schema)
        json_schema.update({'pattern': r'^[0-9]{3}$'})
        return json_schema

    @property
    def alpha2(self) -> str:
        """The country code in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) format."""
        return _index_by_numeric_code()[self].alpha2

    @property
    def alpha3(self) -> str:
        """The country code in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) format."""
        return _index_by_numeric_code()[self].alpha3

    @property
    def short_name(self) -> str:
        """The country short name."""
        return _index_by_numeric_code()[self].short_name


class CountryShortName(str):
    """CountryShortName parses country codes in the short name format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.country import CountryShortName


    class Product(BaseModel):
        made_in: CountryShortName


    product = Product(made_in='United States')
    print(product)
    # > made_in='United States'
    ```
    """

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryShortName:
        """Return the validated short name, or raise `PydanticCustomError` if unknown."""
        if __input_value not in _index_by_short_name():
            raise PydanticCustomError('country_short_name', 'Invalid country short name')
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(),
            serialization=core_schema.to_string_ser_schema(),
        )

    @property
    def alpha2(self) -> str:
        """The country code in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) format."""
        return _index_by_short_name()[self].alpha2

    @property
    def alpha3(self) -> str:
        """The country code in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) format."""
        return _index_by_short_name()[self].alpha3

    @property
    def numeric_code(self) -> str:
        """The country code in the [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format."""
        return _index_by_short_name()[self].numeric_code
"""Currency definitions that are based on the [ISO4217](https://en.wikipedia.org/wiki/ISO_4217)."""

from __future__ import annotations

from typing import Any

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema

try:
    import pycountry
except ModuleNotFoundError as e:  # pragma: no cover
    raise RuntimeError(
        'The `currency_code` module requires "pycountry" to be installed. You can install it with "pip install '
        'pycountry".'
    ) from e

# List of codes that should not be usually used within regular transactions
_CODES_FOR_BONDS_METAL_TESTING = {
    'XTS',  # testing
    'XAU',  # gold
    'XAG',  # silver
    'XPD',  # palladium
    'XPT',  # platinum
    'XBA',  # Bond Markets Unit European Composite Unit (EURCO)
    'XBB',  # Bond Markets Unit European Monetary Unit (E.M.U.-6)
    'XBC',  # Bond Markets Unit European Unit of Account 9 (E.U.A.-9)
    'XBD',  # Bond Markets Unit European Unit of Account 17 (E.U.A.-17)
    'XXX',  # no currency
    'XDR',  # SDR (Special Drawing Right)
}


class ISO4217(str):
    """ISO4217 parses Currency in the [ISO 4217](https://en.wikipedia.org/wiki/ISO_4217) format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.currency_code import ISO4217


    class Currency(BaseModel):
        alpha_3: ISO4217


    currency = Currency(alpha_3='AED')
    print(currency)
    # > alpha_3='AED'
    ```
    """

    # NOTE: despite the name, this lists currency alpha-3 codes (built from
    # `pycountry.currencies`), not countries. The name is kept because it is a
    # public class attribute that `Currency` below also relies on.
    allowed_countries_list = [country.alpha_3 for country in pycountry.currencies]
    allowed_currencies = set(allowed_countries_list)

    @classmethod
    def _validate(cls, currency_code: str, _: core_schema.ValidationInfo) -> str:
        """Validate a ISO 4217 currency code from the provided str value.

        Args:
            currency_code: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated ISO 4217 currency code.

        Raises:
            PydanticCustomError: If the ISO 4217 currency code is not valid.
        """
        # Comparison is case-insensitive: codes are normalized to upper case first.
        currency_code = currency_code.upper()
        if currency_code not in cls.allowed_currencies:
            raise PydanticCustomError(
                'ISO4217', 'Invalid ISO 4217 currency code. See https://en.wikipedia.org/wiki/ISO_4217'
            )
        return currency_code

    @classmethod
    def __get_pydantic_core_schema__(cls, _: type[Any], __: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(min_length=3, max_length=3),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(schema)
        json_schema.update({'enum': cls.allowed_countries_list})
        return json_schema


class Currency(str):
    """Currency parses currency subset of the [ISO 4217](https://en.wikipedia.org/wiki/ISO_4217) format.
    It excludes bonds testing codes and precious metals.
    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.currency_code import Currency


    class currency(BaseModel):
        alpha_3: Currency


    cur = currency(alpha_3='AED')
    print(cur)
    # > alpha_3='AED'
    ```
    """

    # Same misnomer as on ISO4217 (kept for backward compatibility): these are
    # currency codes with bonds/testing/precious-metal codes filtered out.
    allowed_countries_list = list(
        filter(lambda x: x not in _CODES_FOR_BONDS_METAL_TESTING, ISO4217.allowed_countries_list)
    )
    allowed_currencies = set(allowed_countries_list)

    @classmethod
    def _validate(cls, currency_symbol: str, _: core_schema.ValidationInfo) -> str:
        """Validate a subset of the [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format.
        It excludes bonds testing codes and precious metals.

        Args:
            currency_symbol: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated ISO 4217 currency code.

        Raises:
            PydanticCustomError: If the ISO 4217 currency code is not valid or is bond, precious metal or testing code.
        """
        currency_symbol = currency_symbol.upper()
        if currency_symbol not in cls.allowed_currencies:
            raise PydanticCustomError(
                'InvalidCurrency',
                'Invalid currency code.'
                ' See https://en.wikipedia.org/wiki/ISO_4217 . '
                'Bonds, testing and precious metals codes are not allowed.',
            )
        return currency_symbol

    @classmethod
    def __get_pydantic_core_schema__(cls, _: type[Any], __: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Return a Pydantic CoreSchema with the currency subset of the
        [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format.
        It excludes bonds testing codes and precious metals.

        Args:
            _: The source type.
            __: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the subset of the currency subset of the
            [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format.
            It excludes bonds testing codes and precious metals.
        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(min_length=3, max_length=3),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """Return a Pydantic JSON Schema with subset of the [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format.
        Excluding bonds testing codes and precious metals.

        Args:
            schema: The Pydantic CoreSchema.
            handler: The handler to get the JSON Schema.

        Returns:
            A Pydantic JSON Schema with the subset of the ISO4217 currency code validation. without bonds testing codes
            and precious metals.

        """
        json_schema = handler(schema)
        json_schema.update({'enum': cls.allowed_countries_list})
        return json_schema
176 | 177 | """ 178 | json_schema = handler(schema) 179 | json_schema.update({'enum': cls.allowed_countries_list}) 180 | return json_schema 181 | -------------------------------------------------------------------------------- /pydantic_extra_types/domain.py: -------------------------------------------------------------------------------- 1 | """The `domain_str` module provides the `DomainStr` data type. 2 | This class depends on the `pydantic` package and implements custom validation for domain string format. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import re 8 | from typing import Any 9 | 10 | from pydantic import GetCoreSchemaHandler 11 | from pydantic_core import PydanticCustomError, core_schema 12 | 13 | 14 | class DomainStr(str): 15 | """A string subclass with custom validation for domain string format.""" 16 | 17 | _domain_re_pattern = r'(?=^.{1,253}$)(^((?!-)[a-zA-Z0-9-]{1,63}(? str: 21 | """Validate a domain name from the provided value. 22 | 23 | Args: 24 | __input_value: The value to be validated. 25 | _: The source type to be converted. 26 | 27 | Returns: 28 | str: The parsed domain name. 
29 | 30 | """ 31 | return cls._validate(__input_value) 32 | 33 | @classmethod 34 | def _validate(cls, v: Any) -> DomainStr: 35 | if not isinstance(v, str): 36 | raise PydanticCustomError('domain_type', 'Value must be a string') 37 | 38 | v = v.strip().lower() 39 | if len(v) < 1 or len(v) > 253: 40 | raise PydanticCustomError('domain_length', 'Domain must be between 1 and 253 characters') 41 | 42 | if not re.match(cls._domain_re_pattern, v): 43 | raise PydanticCustomError('domain_format', 'Invalid domain format') 44 | 45 | return cls(v) 46 | 47 | @classmethod 48 | def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: 49 | return core_schema.with_info_before_validator_function( 50 | cls.validate, 51 | core_schema.str_schema(), 52 | ) 53 | 54 | @classmethod 55 | def __get_pydantic_json_schema__( 56 | cls, schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler 57 | ) -> dict[str, Any]: 58 | # Cast the return value to dict[str, Any] 59 | return dict(handler(schema)) 60 | -------------------------------------------------------------------------------- /pydantic_extra_types/epoch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import datetime 4 | from typing import Any, Callable 5 | 6 | import pydantic_core.core_schema 7 | from pydantic import GetJsonSchemaHandler 8 | from pydantic.json_schema import JsonSchemaValue 9 | from pydantic_core import CoreSchema, core_schema 10 | 11 | EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) 12 | 13 | 14 | class _Base(datetime.datetime): 15 | TYPE: str = '' 16 | SCHEMA: pydantic_core.core_schema.CoreSchema 17 | 18 | @classmethod 19 | def __get_pydantic_json_schema__( 20 | cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler 21 | ) -> JsonSchemaValue: 22 | field_schema: dict[str, Any] = {} 23 | field_schema.update(type=cls.TYPE, format='date-time') 24 
        return field_schema

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: Callable[[Any], CoreSchema]
    ) -> core_schema.CoreSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            cls.SCHEMA,
            serialization=core_schema.wrap_serializer_function_ser_schema(cls._f, return_schema=cls.SCHEMA),
        )

    @classmethod
    def _validate(cls, __input_value: Any, _: Any) -> datetime.datetime:
        # Interpret the numeric input as seconds since the Unix epoch (UTC).
        return EPOCH + datetime.timedelta(seconds=__input_value)

    @classmethod
    def _f(cls, value: Any, serializer: Callable[[Any], Any]) -> Any:  # pragma: no cover
        # Serialization hook; concrete subclasses must override.
        raise NotImplementedError(cls)


class Number(_Base):
    """epoch.Number parses unix timestamp as float and converts it to datetime.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types import epoch


    class LogEntry(BaseModel):
        timestamp: epoch.Number


    logentry = LogEntry(timestamp=1.1)
    print(logentry)
    # > timestamp=datetime.datetime(1970, 1, 1, 0, 0, 1, 100000, tzinfo=datetime.timezone.utc)
    ```
    """

    TYPE = 'number'
    SCHEMA = core_schema.float_schema()

    @classmethod
    def _f(cls, value: Any, serializer: Callable[[float], float]) -> float:
        ts = value.timestamp()
        return serializer(ts)


class Integer(_Base):
    """epoch.Integer parses unix timestamp as integer and converts it to datetime.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types import epoch

    class LogEntry(BaseModel):
        timestamp: epoch.Integer

    logentry = LogEntry(timestamp=1)
    print(logentry)
    #> timestamp=datetime.datetime(1970, 1, 1, 0, 0, 1, tzinfo=datetime.timezone.utc)
    ```
    """

    TYPE = 'integer'
    SCHEMA = core_schema.int_schema()

    @classmethod
    def _f(cls, value: Any, serializer: Callable[[int], int]) -> int:
        # Truncate sub-second precision before serializing.
        ts = value.timestamp()
        return serializer(int(ts))


# --------------------------------------------------------------------------------
# File: pydantic_extra_types/isbn.py
# --------------------------------------------------------------------------------

"""The `pydantic_extra_types.isbn` module provides functionality to receive and validate ISBN.

ISBN (International Standard Book Number) is a numeric commercial book identifier which is intended to be unique. This module provides a ISBN type for Pydantic models.
"""

from __future__ import annotations

import itertools as it
from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


def isbn10_digit_calc(isbn: str) -> str:
    """Calc a ISBN-10 last digit from the provided str value. More information of validation algorithm on [Wikipedia](https://en.wikipedia.org/wiki/ISBN#Check_digits)

    Args:
        isbn: The str value representing the ISBN in 10 digits.

    Returns:
        The calculated last digit of the ISBN-10 value.
    """
    # Weighted sum of the first nine digits (x10, x9, ..., x2); the check digit
    # brings the total to 0 mod 11, with 'X' standing in for the value 10.
    total = sum(int(digit) * (10 - idx) for idx, digit in enumerate(isbn[:9]))
    diff = (11 - total) % 11
    valid_check_digit = 'X' if diff == 10 else str(diff)
    return valid_check_digit


def isbn13_digit_calc(isbn: str) -> str:
    """Calc a ISBN-13 last digit from the provided str value.
    More information of validation algorithm on [Wikipedia](https://en.wikipedia.org/wiki/ISBN#Check_digits)

    Args:
        isbn: The str value representing the ISBN in 13 digits.

    Returns:
        The calculated last digit of the ISBN-13 value.
    """
    # Alternating 1,3 weights over the first 12 digits (ISO 2108 check-digit scheme).
    total = sum(int(digit) * factor for digit, factor in zip(isbn[:12], it.cycle((1, 3))))

    check_digit = (10 - total) % 10

    return str(check_digit)


class ISBN(str):
    """Represents a ISBN and provides methods for conversion, validation, and serialization.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.isbn import ISBN


    class Book(BaseModel):
        isbn: ISBN


    book = Book(isbn='8537809667')
    print(book)
    # > isbn='9788537809662'
    ```
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Return a Pydantic CoreSchema with the ISBN validation.

        Args:
            source: The source type to be converted.
            handler: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the ISBN validation.

        """
        return core_schema.with_info_before_validator_function(
            cls._validate,
            core_schema.str_schema(),
        )

    @classmethod
    def _validate(cls, __input_value: str, _: Any) -> str:
        """Validate a ISBN from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The source type to be converted.

        Returns:
            The validated ISBN, normalized to the ISBN-13 form.

        Raises:
            PydanticCustomError: If the ISBN is not valid.
        """
        cls.validate_isbn_format(__input_value)

        return cls.convert_isbn10_to_isbn13(__input_value)

    @staticmethod
    def validate_isbn_format(value: str) -> None:
        """Validate a ISBN format from the provided str value.

        Args:
            value: The str value representing the ISBN in 10 or 13 digits.

        Raises:
            PydanticCustomError: If the ISBN is not valid.
        """
        isbn_length = len(value)

        if isbn_length not in (10, 13):
            raise PydanticCustomError('isbn_length', f'Length for ISBN must be 10 or 13 digits, not {isbn_length}')

        if isbn_length == 10:
            # ISBN-10: nine digits plus a check character that may be 'X' (value 10).
            if not value[:-1].isdigit() or ((value[-1] != 'X') and (not value[-1].isdigit())):
                raise PydanticCustomError('isbn10_invalid_characters', 'First 9 digits of ISBN-10 must be integers')
            if isbn10_digit_calc(value) != value[-1]:
                raise PydanticCustomError('isbn_invalid_digit_check_isbn10', 'Provided digit is invalid for given ISBN')

        if isbn_length == 13:
            if not value.isdigit():
                raise PydanticCustomError('isbn13_invalid_characters', 'All digits of ISBN-13 must be integers')
            # Bookland prefixes: every ISBN-13 starts with 978 or 979.
            if value[:3] not in ('978', '979'):
                raise PydanticCustomError(
                    'isbn_invalid_early_characters', 'The first 3 digits of ISBN-13 must be 978 or 979'
                )
            if isbn13_digit_calc(value) != value[-1]:
                raise PydanticCustomError('isbn_invalid_digit_check_isbn13', 'Provided digit is invalid for given ISBN')

    @staticmethod
    def convert_isbn10_to_isbn13(value: str) -> str:
        """Convert an ISBN-10 to ISBN-13.

        Args:
            value: The ISBN-10 value to be converted.

        Returns:
            The converted ISBN or the original value if no conversion is necessary.
        """
        if len(value) == 10:
            # Prefix with Bookland 978, drop the old check char, recompute for 13 digits.
            base_isbn = f'978{value[:-1]}'
            isbn13_digit = isbn13_digit_calc(base_isbn)
            return ISBN(f'{base_isbn}{isbn13_digit}')

        return ISBN(value)


# --------------------------------------------------------------------------------
# File: pydantic_extra_types/language_code.py
# --------------------------------------------------------------------------------

"""Language definitions that are based on the [ISO 639-3](https://en.wikipedia.org/wiki/ISO_639-3) & [ISO 639-5](https://en.wikipedia.org/wiki/ISO_639-5)."""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Union

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema

try:
    import pycountry
except ModuleNotFoundError as e:  # pragma: no cover
    raise RuntimeError(
        'The `language_code` module requires "pycountry" to be installed.'
        ' You can install it with "pip install pycountry".'
    ) from e


@dataclass
class LanguageInfo:
    """LanguageInfo is a dataclass that contains the language information.

    Args:
        alpha2: The language code in the [ISO 639-1 alpha-2](https://en.wikipedia.org/wiki/ISO_639-1) format.
        alpha3: The language code in the [ISO 639-3 alpha-3](https://en.wikipedia.org/wiki/ISO_639-3) format.
        name: The language name.
    """

    alpha2: Union[str, None]
    alpha3: str
    name: str


@lru_cache
def _languages() -> list[LanguageInfo]:
    """Return a list of LanguageInfo objects containing the language information.

    Returns:
        A list of LanguageInfo objects containing the language information.
    """
    return [
        LanguageInfo(
            alpha2=getattr(language, 'alpha_2', None),  # not every language has an alpha-2 code
            alpha3=language.alpha_3,
            name=language.name,
        )
        for language in pycountry.languages
    ]


@lru_cache
def _index_by_alpha2() -> dict[str, LanguageInfo]:
    """Return a dictionary with the language code in the [ISO 639-1 alpha-2](https://en.wikipedia.org/wiki/ISO_639-1) format as the key and the LanguageInfo object as the value."""
    return {language.alpha2: language for language in _languages() if language.alpha2 is not None}


@lru_cache
def _index_by_alpha3() -> dict[str, LanguageInfo]:
    """Return a dictionary with the language code in the [ISO 639-3 alpha-3](https://en.wikipedia.org/wiki/ISO_639-3) format as the key and the LanguageInfo object as the value."""
    return {language.alpha3: language for language in _languages()}


@lru_cache
def _index_by_name() -> dict[str, LanguageInfo]:
    """Return a dictionary with the language name as the key and the LanguageInfo object as the value."""
    return {language.name: language for language in _languages()}


class LanguageAlpha2(str):
    """LanguageAlpha2 parses languages codes in the [ISO 639-1 alpha-2](https://en.wikipedia.org/wiki/ISO_639-1)
    format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.language_code import LanguageAlpha2


    class Movie(BaseModel):
        audio_lang: LanguageAlpha2
        subtitles_lang: LanguageAlpha2


    movie = Movie(audio_lang='de', subtitles_lang='fr')
    print(movie)
    # > audio_lang='de' subtitles_lang='fr'
    ```
    """

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> LanguageAlpha2:
        """Validate a language code in the ISO 639-1 alpha-2 format from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated language code in the ISO 639-1 alpha-2 format.
        """
        if __input_value not in _index_by_alpha2():
            raise PydanticCustomError('language_alpha2', 'Invalid language alpha2 code')
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        """Return a Pydantic CoreSchema with the language code in the ISO 639-1 alpha-2 format validation.

        Args:
            source: The source type.
            handler: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the language code in the ISO 639-1 alpha-2 format validation.
        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(to_lower=True),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """Return a Pydantic JSON Schema with the language code in the ISO 639-1 alpha-2 format validation.

        Args:
            schema: The Pydantic CoreSchema.
            handler: The handler to get the JSON Schema.

        Returns:
            A Pydantic JSON Schema with the language code in the ISO 639-1 alpha-2 format validation.
        """
        json_schema = handler(schema)
        json_schema.update({'pattern': r'^\w{2}$'})
        return json_schema

    @property
    def alpha3(self) -> str:
        """The language code in the [ISO 639-3 alpha-3](https://en.wikipedia.org/wiki/ISO_639-3) format."""
        return _index_by_alpha2()[self].alpha3

    @property
    def name(self) -> str:
        """The language name."""
        return _index_by_alpha2()[self].name


class LanguageName(str):
    """LanguageName parses languages names listed in the [ISO 639-3 standard](https://en.wikipedia.org/wiki/ISO_639-3)
    format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.language_code import LanguageName


    class Movie(BaseModel):
        audio_lang: LanguageName
        subtitles_lang: LanguageName


    movie = Movie(audio_lang='Dutch', subtitles_lang='Mandarin Chinese')
    print(movie)
    # > audio_lang='Dutch' subtitles_lang='Mandarin Chinese'
    ```
    """

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> LanguageName:
        """Validate a language name from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated language name.
        """
        if __input_value not in _index_by_name():
            raise PydanticCustomError('language_name', 'Invalid language name')
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        """Return a Pydantic CoreSchema with the language name validation.

        Args:
            source: The source type.
            handler: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the language name validation.
        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(),
            serialization=core_schema.to_string_ser_schema(),
        )

    @property
    def alpha2(self) -> Union[str, None]:
        """The language code in the [ISO 639-1 alpha-2](https://en.wikipedia.org/wiki/ISO_639-1) format. Does not exist for all languages."""
        return _index_by_name()[self].alpha2

    @property
    def alpha3(self) -> str:
        """The language code in the [ISO 639-3 alpha-3](https://en.wikipedia.org/wiki/ISO_639-3) format."""
        return _index_by_name()[self].alpha3


class ISO639_3(str):
    """ISO639_3 parses Language in the [ISO 639-3 alpha-3](https://en.wikipedia.org/wiki/ISO_639-3_alpha-3)
    format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.language_code import ISO639_3


    class Language(BaseModel):
        alpha_3: ISO639_3


    lang = Language(alpha_3='ssr')
    print(lang)
    # > alpha_3='ssr'
    ```
    """

    allowed_values_list = [lang.alpha_3 for lang in pycountry.languages]
    allowed_values = set(allowed_values_list)

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> ISO639_3:
        """Validate a ISO 639-3 language code from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated ISO 639-3 language code.

        Raises:
            PydanticCustomError: If the ISO 639-3 language code is not valid.
        """
        if __input_value not in cls.allowed_values:
            # NOTE(review): error type 'ISO649_3' looks like a typo for 'ISO639_3';
            # kept unchanged because callers may match on the error type string.
            raise PydanticCustomError(
                'ISO649_3', 'Invalid ISO 639-3 language code. See https://en.wikipedia.org/wiki/ISO_639-3'
            )
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _: type[Any], __: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        """Return a Pydantic CoreSchema with the ISO 639-3 language code validation.

        Args:
            _: The source type.
            __: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the ISO 639-3 language code validation.

        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(min_length=3, max_length=3),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """Return a Pydantic JSON Schema with the ISO 639-3 language code validation.

        Args:
            schema: The Pydantic CoreSchema.
            handler: The handler to get the JSON Schema.

        Returns:
            A Pydantic JSON Schema with the ISO 639-3 language code validation.

        """
        json_schema = handler(schema)
        json_schema.update({'enum': cls.allowed_values_list})
        return json_schema


class ISO639_5(str):
    """ISO639_5 parses Language in the [ISO 639-5 alpha-3](https://en.wikipedia.org/wiki/ISO_639-5_alpha-3)
    format.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.language_code import ISO639_5


    class Language(BaseModel):
        alpha_3: ISO639_5


    lang = Language(alpha_3='gem')
    print(lang)
    # > alpha_3='gem'
    ```
    """

    allowed_values_list = [lang.alpha_3 for lang in pycountry.language_families]
    allowed_values_list.sort()
    allowed_values = set(allowed_values_list)

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> ISO639_5:
        """Validate a ISO 639-5 language code from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated ISO 639-5 language code.

        Raises:
            PydanticCustomError: If the ISO 639-5 language code is not valid.
        """
        if __input_value not in cls.allowed_values:
            # NOTE(review): 'ISO649_5' likewise looks like a typo for 'ISO639_5';
            # kept for backward compatibility.
            raise PydanticCustomError(
                'ISO649_5', 'Invalid ISO 639-5 language code. See https://en.wikipedia.org/wiki/ISO_639-5'
            )
        return cls(__input_value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _: type[Any], __: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        """Return a Pydantic CoreSchema with the ISO 639-5 language code validation.

        Args:
            _: The source type.
            __: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the ISO 639-5 language code validation.

        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(min_length=3, max_length=3),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """Return a Pydantic JSON Schema with the ISO 639-5 language code validation.

        Args:
            schema: The Pydantic CoreSchema.
            handler: The handler to get the JSON Schema.

        Returns:
            A Pydantic JSON Schema with the ISO 639-5 language code validation.

        """
        json_schema = handler(schema)
        json_schema.update({'enum': cls.allowed_values_list})
        return json_schema


# --------------------------------------------------------------------------------
# File: pydantic_extra_types/mac_address.py
# --------------------------------------------------------------------------------

"""The MAC address module provides functionality to parse and validate MAC addresses in different
formats, such as IEEE 802 MAC-48, EUI-48, EUI-64, or a 20-octet format.
"""

from __future__ import annotations

from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


class MacAddress(str):
    """Represents a MAC address and provides methods for conversion, validation, and serialization.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.mac_address import MacAddress


    class Network(BaseModel):
        mac_address: MacAddress


    network = Network(mac_address='00:00:5e:00:53:01')
    print(network)
    # > mac_address='00:00:5e:00:53:01'
    ```
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Return a Pydantic CoreSchema with the MAC address validation.

        Args:
            source: The source type to be converted.
            handler: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the MAC address validation.

        """
        return core_schema.with_info_before_validator_function(
            cls._validate,
            core_schema.str_schema(),
        )

    @classmethod
    def _validate(cls, __input_value: str, _: Any) -> str:
        """Validate a MAC Address from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The source type to be converted.

        Returns:
            str: The parsed MAC address.

        """
        return cls.validate_mac_address(__input_value.encode())

    @staticmethod
    def validate_mac_address(value: bytes) -> str:
        """Validate a MAC Address from the provided byte value."""
        string = value.decode()
        # 14 chars is the shortest supported spelling ('xxxx.xxxx.xxxx', dotted 6-octet form).
        if len(string) < 14:
            raise PydanticCustomError(
                'mac_address_len',
                'Length for a {mac_address} MAC address must be {required_length}',
                {'mac_address': string, 'required_length': 14},
            )
        # Try each supported separator with its group width (':'/'-' -> 2 hex chars, '.' -> 4).
        for sep, partbytes in ((':', 2), ('-', 2), ('.', 4)):
            if sep in string:
                parts = string.split(sep)
                if any(len(part) != partbytes for part in parts):
                    raise PydanticCustomError(
                        'mac_address_format',
                        f'Must have the format xx{sep}xx{sep}xx{sep}xx{sep}xx{sep}xx',
                    )
                # Accept 6 octets (MAC-48/EUI-48), 8 (EUI-64) or 20 octets only.
                if len(parts) * partbytes // 2 not in (6, 8, 20):
                    raise PydanticCustomError(
                        'mac_address_format',
                        'Length for a {mac_address} MAC address must be {required_length}',
                        {'mac_address': string, 'required_length': (6, 8, 20)},
                    )
                mac_address = []
                for part in parts:
                    for idx in range(0, partbytes, 2):
                        try:
                            byte_value = int(part[idx : idx + 2], 16)
                        except ValueError as exc:
                            raise PydanticCustomError('mac_address_format', 'Unrecognized format') from exc
                        else:
                            mac_address.append(byte_value)
                # Normalize to lowercase colon-separated form.
                return ':'.join(f'{b:02x}' for b in mac_address)
        else:
            # for/else: no known separator found in the input.
            raise PydanticCustomError('mac_address_format', 'Unrecognized format')


# --------------------------------------------------------------------------------
# File: pydantic_extra_types/mongo_object_id.py
# --------------------------------------------------------------------------------

"""
Validation for MongoDB ObjectId fields.

Ref: https://github.com/pydantic/pydantic-extra-types/issues/133
"""

from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

try:
    from bson import ObjectId
except ModuleNotFoundError as e:  # pragma: no cover
    raise RuntimeError(
        'The `mongo_object_id` module requires "pymongo" to be installed. You can install it with "pip install '
        'pymongo".'
    ) from e


class MongoObjectId(str):
    """MongoObjectId parses and validates MongoDB bson.ObjectId.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.mongo_object_id import MongoObjectId


    class MongoDocument(BaseModel):
        id: MongoObjectId


    doc = MongoDocument(id='5f9f2f4b9d3c5a7b4c7e6c1d')
    print(doc)
    # > id='5f9f2f4b9d3c5a7b4c7e6c1d'
    ```

    Raises:
        PydanticCustomError: If the provided value is not a valid MongoDB ObjectId.
    """

    # An ObjectId is always a 24-character hex string.
    OBJECT_ID_LENGTH = 24

    @classmethod
    def __get_pydantic_core_schema__(cls, _: Any, __: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # JSON input must be a 24-char string; Python input may already be an ObjectId
        # instance, otherwise it is validated as a 24-char string and converted.
        return core_schema.json_or_python_schema(
            json_schema=core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(ObjectId),
                    core_schema.chain_schema(
                        [
                            core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
                            core_schema.no_info_plain_validator_function(cls.validate),
                        ]
                    ),
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(lambda x: str(x), when_used='json'),
        )

    @classmethod
    def validate(cls, value: str) -> ObjectId:
        """Validate the MongoObjectId str is a valid ObjectId instance."""
        if not ObjectId.is_valid(value):
            raise ValueError(
                f"Invalid ObjectId {value} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'."
            )

        return ObjectId(value)


# --------------------------------------------------------------------------------
# File: pydantic_extra_types/path.py
# --------------------------------------------------------------------------------

from __future__ import annotations

import typing
from dataclasses import dataclass
from pathlib import Path

import pydantic
from pydantic.types import PathType
from pydantic_core import core_schema
from typing_extensions import Annotated

ExistingPath = typing.Union[pydantic.FilePath, pydantic.DirectoryPath]


@dataclass
class ResolvedPathType(PathType):
    """A custom PathType that resolves the path to its absolute form.

    Args:
        path_type (typing.Literal['file', 'dir', 'new']): The type of path to resolve. Can be 'file', 'dir' or 'new'.

    Returns:
        Resolved path as a pathlib.Path object.

    Example:
        ```python
        from pydantic import BaseModel
        from pydantic_extra_types.path import ResolvedFilePath, ResolvedDirectoryPath, ResolvedNewPath


        class MyModel(BaseModel):
            file_path: ResolvedFilePath
            dir_path: ResolvedDirectoryPath
            new_path: ResolvedNewPath


        model = MyModel(file_path='~/myfile.txt', dir_path='~/mydir', new_path='~/newfile.txt')
        print(model.file_path)
        # > file_path=PosixPath('/home/user/myfile.txt') dir_path=PosixPath('/home/user/mydir') new_path=PosixPath('/home/user/newfile.txt')"""

    @staticmethod
    def validate_file(path: Path, _: core_schema.ValidationInfo) -> Path:
        # Expand '~' and resolve relative parts/symlinks before the stock file check.
        return PathType.validate_file(path.expanduser().resolve(), _)

    @staticmethod
    def validate_directory(path: Path, _: core_schema.ValidationInfo) -> Path:
        return PathType.validate_directory(path.expanduser().resolve(), _)

    @staticmethod
    def validate_new(path: Path, _: core_schema.ValidationInfo) -> Path:
        return PathType.validate_new(path.expanduser().resolve(), _)

    def __hash__(self) -> int:
        # NOTE(review): this hashes type(self.path_type) (always `str`), so every
        # instance hashes identically; `hash(self.path_type)` looks intended.
        # Kept as-is — confirm no caller relies on the current hashing before changing.
        return hash(type(self.path_type))


ResolvedFilePath = Annotated[Path, ResolvedPathType('file')]
ResolvedDirectoryPath = Annotated[Path, ResolvedPathType('dir')]
ResolvedNewPath = Annotated[Path, ResolvedPathType('new')]
ResolvedExistingPath = typing.Union[ResolvedFilePath, ResolvedDirectoryPath]


# --------------------------------------------------------------------------------
# File: pydantic_extra_types/payment.py
# --------------------------------------------------------------------------------

"""The `pydantic_extra_types.payment` module provides the
[`PaymentCardNumber`][pydantic_extra_types.payment.PaymentCardNumber] data type.
"""

from __future__ import annotations

from enum import Enum
from typing import Any, ClassVar

from pydantic import GetCoreSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


class PaymentCardBrand(str, Enum):
    """Payment card brands supported by the [`PaymentCardNumber`][pydantic_extra_types.payment.PaymentCardNumber]."""

    amex = 'American Express'
    mastercard = 'Mastercard'
    visa = 'Visa'
    mir = 'Mir'
    maestro = 'Maestro'
    discover = 'Discover'
    verve = 'Verve'
    dankort = 'Dankort'
    troy = 'Troy'
    unionpay = 'UnionPay'
    jcb = 'JCB'
    other = 'other'

    def __str__(self) -> str:
        return self.value


class PaymentCardNumber(str):
    """A [payment card number](https://en.wikipedia.org/wiki/Payment_card_number)."""

    strip_whitespace: ClassVar[bool] = True
    """Whether to strip whitespace from the input value."""
    min_length: ClassVar[int] = 12
    """The minimum length of the card number."""
    max_length: ClassVar[int] = 19
    """The maximum length of the card number."""
    bin: str
    """The first 6 digits of the card number."""
    last4: str
    """The last 4 digits of the card number."""
    brand: PaymentCardBrand
    """The brand of the card."""

    def __init__(self, card_number: str):
        # str.__new__ has already stored the text; here we validate and derive fields.
        self.validate_digits(card_number)

        card_number = self.validate_luhn_check_digit(card_number)

        self.bin = card_number[:6]
        self.last4 = card_number[-4:]
        self.brand = self.validate_brand(card_number)

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        return core_schema.with_info_after_validator_function(
            cls.validate,
            core_schema.str_schema(
                min_length=cls.min_length, max_length=cls.max_length, strip_whitespace=cls.strip_whitespace
            ),
        )

    @classmethod
    def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> PaymentCardNumber:
        """Validate the `PaymentCardNumber` instance.

        Args:
            __input_value: The input value to validate.
            _: The validation info.

        Returns:
            The validated `PaymentCardNumber` instance.
        """
        return cls(__input_value)

    @property
    def masked(self) -> str:
        """The masked card number."""
        num_masked = len(self) - 10  # len(bin) + len(last4) == 10
        return f'{self.bin}{"*" * num_masked}{self.last4}'

    @classmethod
    def validate_digits(cls, card_number: str) -> None:
        """Validate that the card number is all digits.

        Args:
            card_number: The card number to validate.

        Raises:
            PydanticCustomError: If the card number is not all digits.
        """
        if not card_number or not all('0' <= c <= '9' for c in card_number):
            raise PydanticCustomError('payment_card_number_digits', 'Card number is not all digits')

    @classmethod
    def validate_luhn_check_digit(cls, card_number: str) -> str:
        """Validate the payment card number.
        Based on the [Luhn algorithm](https://en.wikipedia.org/wiki/Luhn_algorithm).

        Args:
            card_number: The card number to validate.

        Returns:
            The validated card number.

        Raises:
            PydanticCustomError: If the card number is not valid.
113 | """ 114 | sum_ = int(card_number[-1]) 115 | length = len(card_number) 116 | parity = length % 2 117 | for i in range(length - 1): 118 | digit = int(card_number[i]) 119 | if i % 2 == parity: 120 | digit *= 2 121 | if digit > 9: 122 | digit -= 9 123 | sum_ += digit 124 | valid = sum_ % 10 == 0 125 | if not valid: 126 | raise PydanticCustomError('payment_card_number_luhn', 'Card number is not luhn valid') 127 | return card_number 128 | 129 | @staticmethod 130 | def validate_brand(card_number: str) -> PaymentCardBrand: 131 | """Validate length based on 132 | [BIN](https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN)) 133 | for major brands. 134 | 135 | Args: 136 | card_number: The card number to validate. 137 | 138 | Returns: 139 | The validated card brand. 140 | 141 | Raises: 142 | PydanticCustomError: If the card number is not valid. 143 | """ 144 | brand = PaymentCardBrand.other 145 | 146 | if card_number[0] == '4': 147 | brand = PaymentCardBrand.visa 148 | required_length = [13, 16, 19] 149 | elif 51 <= int(card_number[:2]) <= 55: 150 | brand = PaymentCardBrand.mastercard 151 | required_length = [16] 152 | elif card_number[:2] in {'34', '37'}: 153 | brand = PaymentCardBrand.amex 154 | required_length = [15] 155 | elif 2200 <= int(card_number[:4]) <= 2204: 156 | brand = PaymentCardBrand.mir 157 | required_length = list(range(16, 20)) 158 | elif card_number[:4] in {'5018', '5020', '5038', '5893', '6304', '6759', '6761', '6762', '6763'} or card_number[ 159 | :6 160 | ] in ( 161 | '676770', 162 | '676774', 163 | ): 164 | brand = PaymentCardBrand.maestro 165 | required_length = list(range(12, 20)) 166 | elif card_number.startswith('65') or 644 <= int(card_number[:3]) <= 649 or card_number.startswith('6011'): 167 | brand = PaymentCardBrand.discover 168 | required_length = list(range(16, 20)) 169 | elif ( 170 | 506099 <= int(card_number[:6]) <= 506198 171 | or 650002 <= int(card_number[:6]) <= 650027 172 | or 507865 <= 
int(card_number[:6]) <= 507964 173 | ): 174 | brand = PaymentCardBrand.verve 175 | required_length = [16, 18, 19] 176 | elif card_number[:4] in {'5019', '4571'}: 177 | brand = PaymentCardBrand.dankort 178 | required_length = [16] 179 | elif card_number.startswith('9792'): 180 | brand = PaymentCardBrand.troy 181 | required_length = [16] 182 | elif card_number[:2] in {'62', '81'}: 183 | brand = PaymentCardBrand.unionpay 184 | required_length = [16, 19] 185 | elif 3528 <= int(card_number[:4]) <= 3589: 186 | brand = PaymentCardBrand.jcb 187 | required_length = [16, 19] 188 | 189 | valid = len(card_number) in required_length if brand != PaymentCardBrand.other else True 190 | 191 | if not valid: 192 | raise PydanticCustomError( 193 | 'payment_card_number_brand', 194 | f'Length for a {brand} card must be {" or ".join(map(str, required_length))}', 195 | {'brand': brand, 'required_length': required_length}, 196 | ) 197 | 198 | return brand 199 | -------------------------------------------------------------------------------- /pydantic_extra_types/pendulum_dt.py: -------------------------------------------------------------------------------- 1 | """Native Pendulum DateTime object implementation. This is a copy of the Pendulum DateTime object, but with a Pydantic 2 | CoreSchema implementation. This allows Pydantic to validate the DateTime object. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | try: 8 | from pendulum import Date as _Date 9 | from pendulum import DateTime as _DateTime 10 | from pendulum import Duration as _Duration 11 | from pendulum import parse 12 | except ModuleNotFoundError as e: # pragma: no cover 13 | raise RuntimeError( 14 | 'The `pendulum_dt` module requires "pendulum" to be installed. You can install it with "pip install pendulum".' 
    ) from e
from datetime import date, datetime, timedelta
from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


class DateTimeSettings(type):
    # Metaclass that turns a class keyword argument into a class attribute,
    # e.g. ``class MyDT(DateTime, strict=False)`` sets ``MyDT.strict = False``.
    # ``strict`` is later read by DateTime._validate to control pendulum's
    # parsing strictness.
    def __new__(cls, name, bases, dct, **kwargs):  # type: ignore[no-untyped-def]
        dct['strict'] = kwargs.pop('strict', True)
        return super().__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct, **kwargs):  # type: ignore[no-untyped-def]
        super().__init__(name, bases, dct)
        cls.strict = kwargs.get('strict', True)


class DateTime(_DateTime, metaclass=DateTimeSettings):
    """A `pendulum.DateTime` object. At runtime, this type decomposes into pendulum.DateTime automatically.
    This type exists because Pydantic throws a fit on unknown types.

    ```python
    from pydantic import BaseModel
    from pydantic_extra_types.pendulum_dt import DateTime


    class test_model(BaseModel):
        dt: DateTime


    print(test_model(dt='2021-01-01T00:00:00+00:00'))

    # > test_model(dt=DateTime(2021, 1, 1, 0, 0, 0, tzinfo=FixedTimezone(0, name="+00:00")))
    ```
    """

    __slots__: list[str] = []

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Return a Pydantic CoreSchema with the Datetime validation

        Args:
            source: The source type to be converted.
            handler: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the Datetime validation.
        """
        return core_schema.no_info_wrap_validator_function(cls._validate, core_schema.datetime_schema())

    @classmethod
    def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> DateTime:
        """Validate the datetime object and return it.

        Args:
            value: The value to validate.
            handler: The handler to get the CoreSchema.

        Returns:
            The validated value or raises a PydanticCustomError.
        """
        # if we are passed an existing instance, pass it straight through.
        if isinstance(value, (_DateTime, datetime)):
            return DateTime.instance(value)
        try:
            # probably the best way to have feature parity with
            # https://docs.pydantic.dev/latest/api/standard_library_types/#datetimedatetime
            value = handler(value)
            return DateTime.instance(value)
        except ValueError:
            # Pydantic's own datetime parsing failed; fall back to pendulum's
            # parser (strictness comes from the DateTimeSettings metaclass).
            try:
                value = parse(value, strict=cls.strict)
                if isinstance(value, _DateTime):
                    return DateTime.instance(value)
                raise ValueError(f'value is not a valid datetime it is a {type(value)}')
            except ValueError:
                # Re-raise ValueError unchanged so pydantic reports the parse failure.
                raise
            except Exception as exc:
                raise PydanticCustomError('value_error', 'value is not a valid datetime') from exc


class Date(_Date):
    """A `pendulum.Date` object. At runtime, this type decomposes into pendulum.Date automatically.
    This type exists because Pydantic throws a fit on unknown types.

    ```python
    from pydantic import BaseModel
    from pydantic_extra_types.pendulum_dt import Date


    class test_model(BaseModel):
        dt: Date


    print(test_model(dt='2021-01-01'))

    # > test_model(dt=Date(2021, 1, 1))
    ```
    """

    __slots__: list[str] = []

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        """Return a Pydantic CoreSchema with the Date validation

        Args:
            source: The source type to be converted.
            handler: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the Date validation.
129 | """ 130 | return core_schema.no_info_wrap_validator_function(cls._validate, core_schema.date_schema()) 131 | 132 | @classmethod 133 | def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Date: 134 | """Validate the date object and return it. 135 | 136 | Args: 137 | value: The value to validate. 138 | handler: The handler to get the CoreSchema. 139 | 140 | Returns: 141 | The validated value or raises a PydanticCustomError. 142 | """ 143 | # if we are passed an existing instance, pass it straight through. 144 | if isinstance(value, (_Date, date)): 145 | return Date(value.year, value.month, value.day) 146 | 147 | # otherwise, parse it. 148 | try: 149 | parsed = parse(value) 150 | if isinstance(parsed, (_DateTime, _Date)): 151 | return Date(parsed.year, parsed.month, parsed.day) 152 | raise ValueError('value is not a valid date it is a {type(parsed)}') 153 | except Exception as exc: 154 | raise PydanticCustomError('value_error', 'value is not a valid date') from exc 155 | 156 | 157 | class Duration(_Duration): 158 | """A `pendulum.Duration` object. At runtime, this type decomposes into pendulum.Duration automatically. 159 | This type exists because Pydantic throws a fit on unknown types. 160 | 161 | ```python 162 | from pydantic import BaseModel 163 | from pydantic_extra_types.pendulum_dt import Duration 164 | 165 | 166 | class test_model(BaseModel): 167 | delta_t: Duration 168 | 169 | 170 | print(test_model(delta_t='P1DT25H')) 171 | 172 | # > test_model(delta_t=Duration(days=2, hours=1)) 173 | ``` 174 | """ 175 | 176 | __slots__: list[str] = [] 177 | 178 | @classmethod 179 | def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: 180 | """Return a Pydantic CoreSchema with the Duration validation 181 | 182 | Args: 183 | source: The source type to be converted. 184 | handler: The handler to get the CoreSchema. 

        Returns:
            A Pydantic CoreSchema with the Duration validation.
        """
        # Validate via _validate; serialize to an ISO 8601 string in JSON mode.
        return core_schema.no_info_wrap_validator_function(
            cls._validate,
            core_schema.timedelta_schema(),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: instance.to_iso8601_string(), when_used='json-unless-none'
            ),
        )

    def to_iso8601_string(self) -> str:
        """
        Convert a Duration object to an ISO 8601 string.

        In addition to the standard ISO 8601 format, this method also supports the representation of fractions of a second and negative durations.

        Returns:
            str: The ISO 8601 string representation of the duration.
        """
        # Extracting components from the Duration object
        years = self.years
        months = self.months
        # NOTE(review): reads the private ``_days`` attribute of
        # pendulum.Duration -- presumably the day count excluding years/months;
        # confirm against pendulum internals.
        days = self._days
        hours = self.hours
        minutes = self.minutes
        seconds = self.remaining_seconds
        milliseconds = self.microseconds // 1000
        microseconds = self.microseconds % 1000

        # Constructing the ISO 8601 duration string
        iso_duration = 'P'
        if years or months or days:
            if years:
                iso_duration += f'{years}Y'
            if months:
                iso_duration += f'{months}M'
            if days:
                iso_duration += f'{days}D'

        # Time components follow the 'T' designator; fractional seconds are
        # appended as '.mmm' (milliseconds) plus optional 'uuu' (microseconds).
        if hours or minutes or seconds or milliseconds or microseconds:
            iso_duration += 'T'
            if hours:
                iso_duration += f'{hours}H'
            if minutes:
                iso_duration += f'{minutes}M'
            if seconds or milliseconds or microseconds:
                iso_duration += f'{seconds}'
                if milliseconds or microseconds:
                    iso_duration += f'.{milliseconds:03d}'
                    if microseconds:
                        iso_duration += f'{microseconds:03d}'
                iso_duration += 'S'

        # Prefix with '-' if the duration is negative
        if self.total_seconds() < 0:
            iso_duration = '-' + iso_duration

        # An all-zero duration still needs a designator: ISO 8601 spells it 'P0D'.
        if iso_duration == 'P':
            iso_duration = 'P0D'

        return iso_duration

    @classmethod
    def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Duration:
        """Validate the Duration object and return it.

        Args:
            value: The value to validate.
            handler: The handler to get the CoreSchema.

        Returns:
            The validated value or raises a PydanticCustomError.
        """

        # Copy pendulum Durations field-by-field so the result is this subclass.
        if isinstance(value, _Duration):
            return Duration(
                years=value.years,
                months=value.months,
                weeks=value.weeks,
                days=value.remaining_days,
                hours=value.hours,
                minutes=value.minutes,
                seconds=value.remaining_seconds,
                microseconds=value.microseconds,
            )

        if isinstance(value, timedelta):
            return Duration(
                days=value.days,
                seconds=value.seconds,
                microseconds=value.microseconds,
            )

        # NOTE(review): ``assert`` is stripped under ``python -O``; a non-str
        # value would then reach ``parse`` instead of failing here.
        assert isinstance(value, str)
        try:
            # https://github.com/python-pendulum/pendulum/issues/532
            if value.startswith('-'):
                parsed = parse(value.lstrip('-'), exact=True)
            else:
                parsed = parse(value, exact=True)
            if not isinstance(parsed, _Duration):
                raise ValueError(f'value is not a valid duration it is a {type(parsed)}')
            if value.startswith('-'):
                parsed = -parsed

            return Duration(
                years=parsed.years,
                months=parsed.months,
                weeks=parsed.weeks,
                days=parsed.remaining_days,
                hours=parsed.hours,
                minutes=parsed.minutes,
                seconds=parsed.remaining_seconds,
                microseconds=parsed.microseconds,
            )
        except Exception as exc:
            raise PydanticCustomError('value_error', 'value is not a valid duration') from exc
-------------------------------------------------------------------------------- /pydantic_extra_types/phone_numbers.py: --------------------------------------------------------------------------------
"""The
`pydantic_extra_types.phone_numbers` module provides the
[`PhoneNumber`][pydantic_extra_types.phone_numbers.PhoneNumber] data type.

This class depends on the [phonenumbers] package, which is a Python port of Google's [libphonenumber].
"""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from functools import partial
from typing import Any, ClassVar

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema

try:
    import phonenumbers
    from phonenumbers import PhoneNumber as BasePhoneNumber
    from phonenumbers.phonenumberutil import NumberParseException
except ModuleNotFoundError as e:  # pragma: no cover
    raise RuntimeError(
        '`PhoneNumber` requires "phonenumbers" to be installed. You can install it with "pip install phonenumbers"'
    ) from e


class PhoneNumber(str):
    """A wrapper around [phonenumbers](https://pypi.org/project/phonenumbers/) package, which
    is a Python port of Google's [libphonenumber](https://github.com/google/libphonenumber/).
    """

    supported_regions: ClassVar[list[str]] = []
    """The supported regions. If empty, all regions are supported."""

    default_region_code: ClassVar[str | None] = None
    """The default region code to use when parsing phone numbers without an international prefix."""
    phone_format: ClassVar[str] = 'RFC3966'
    """The format of the phone number."""

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(schema)
        json_schema.update({'format': 'phone'})
        return json_schema

    @classmethod
    def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(),
        )

    @classmethod
    def _validate(cls, phone_number: str, _: core_schema.ValidationInfo) -> str:
        # Parse, check validity (and region membership when restricted), then
        # normalise to the configured output format.
        try:
            parsed_number = phonenumbers.parse(phone_number, cls.default_region_code)
        except phonenumbers.phonenumberutil.NumberParseException as exc:
            raise PydanticCustomError('value_error', 'value is not a valid phone number') from exc
        if not phonenumbers.is_valid_number(parsed_number):
            raise PydanticCustomError('value_error', 'value is not a valid phone number')

        if cls.supported_regions and not any(
            phonenumbers.is_valid_number_for_region(parsed_number, region_code=region)
            for region in cls.supported_regions
        ):
            raise PydanticCustomError('value_error', 'value is not from a supported region')

        return phonenumbers.format_number(parsed_number, getattr(phonenumbers.PhoneNumberFormat, cls.phone_format))

    def __eq__(self, other: Any) -> bool:
        # Explicit delegation to plain str equality/hashing, keeping the pair in sync.
        return super().__eq__(other)

    def __hash__(self) -> int:
        return super().__hash__()


@dataclass(frozen=True)
class PhoneNumberValidator:
    """A pydantic before validator for phone numbers using the
[phonenumbers](https://pypi.org/project/phonenumbers/) package,
    a Python port of Google's [libphonenumber](https://github.com/google/libphonenumber/).

    Intended to be used to create custom pydantic data types using the `typing.Annotated` type construct.

    Args:
        default_region (str | None): The default region code to use when parsing phone numbers without an international prefix.
            If `None` (default), the region must be supplied in the phone number as an international prefix.
        number_format (str): The format of the phone number to return. See `phonenumbers.PhoneNumberFormat` for valid values.
        supported_regions (list[str]): The supported regions. If empty, all regions are supported (default).

    Returns:
        The formatted phone number.

    Example:
        MyNumberType = Annotated[
            Union[str, phonenumbers.PhoneNumber],
            PhoneNumberValidator()
        ]
        USNumberType = Annotated[
            Union[str, phonenumbers.PhoneNumber],
            PhoneNumberValidator(supported_regions=['US'], default_region='US')
        ]

        class SomeModel(BaseModel):
            phone_number: MyNumberType
            us_number: USNumberType
    """

    default_region: str | None = None
    number_format: str = 'RFC3966'
    supported_regions: Sequence[str] | None = None

    def __post_init__(self) -> None:
        # Fail fast at type-definition time on bad validator configuration.
        if self.default_region and self.default_region not in phonenumbers.SUPPORTED_REGIONS:
            raise ValueError(f'Invalid default region code: {self.default_region}')

        # Accept only the public upper-case constants on PhoneNumberFormat.
        if self.number_format not in (
            number_format
            for number_format in dir(phonenumbers.PhoneNumberFormat)
            if not number_format.startswith('_') and number_format.isupper()
        ):
            raise ValueError(f'Invalid number format: {self.number_format}')

        if self.supported_regions:
            for supported_region in self.supported_regions:
                if supported_region not in phonenumbers.SUPPORTED_REGIONS:
                    raise ValueError(f'Invalid supported region code: {supported_region}')

    @staticmethod
    def _parse(
        region: str | None,
        number_format: str,
        supported_regions: Sequence[str] | None,
        phone_number: Any,
    ) -> str:
        # Reject empty/None and non-str/PhoneNumber inputs outright.
        if not phone_number:
            raise PydanticCustomError('value_error', 'value is not a valid phone number')

        if not isinstance(phone_number, (str, BasePhoneNumber)):
            raise PydanticCustomError('value_error', 'value is not a valid phone number')

        parsed_number = None
        if isinstance(phone_number, BasePhoneNumber):
            parsed_number = phone_number
        else:
            try:
                parsed_number = phonenumbers.parse(phone_number, region=region)
            except NumberParseException as exc:
                raise PydanticCustomError('value_error', 'value is not a valid phone number') from exc

        if not phonenumbers.is_valid_number(parsed_number):
            raise PydanticCustomError('value_error', 'value is not a valid phone number')

        # NOTE(review): the generator variable shadows the ``region`` parameter;
        # harmless here because the parameter is not used afterwards.
        if supported_regions and not any(
            phonenumbers.is_valid_number_for_region(parsed_number, region_code=region) for region in supported_regions
        ):
            raise PydanticCustomError('value_error', 'value is not from a supported region')

        return phonenumbers.format_number(parsed_number, getattr(phonenumbers.PhoneNumberFormat, number_format))

    def __get_pydantic_core_schema__(self, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # Bind this validator's configuration now; pydantic calls the partial
        # with just the input value.
        return core_schema.no_info_before_validator_function(
            partial(
                self._parse,
                self.default_region,
                self.number_format,
                self.supported_regions,
            ),
            core_schema.str_schema(),
        )

    def __get_pydantic_json_schema__(
        self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(schema)
        json_schema.update({'format': 'phone'})
        return json_schema
    def __hash__(self) -> int:
        # Delegates to object.__hash__ (identity-based) rather than the
        # field-based hash a frozen dataclass would otherwise synthesise.
        return super().__hash__()
-------------------------------------------------------------------------------- /pydantic_extra_types/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pydantic/pydantic-extra-types/7332abb2337527c6d259b9b07bde17012da020a7/pydantic_extra_types/py.typed
-------------------------------------------------------------------------------- /pydantic_extra_types/routing_number.py: --------------------------------------------------------------------------------
"""The `pydantic_extra_types.routing_number` module provides the
[`ABARoutingNumber`][pydantic_extra_types.routing_number.ABARoutingNumber] data type.
"""

from __future__ import annotations

import itertools as it
from typing import Any, ClassVar

from pydantic import GetCoreSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


class ABARoutingNumber(str):
    """The `ABARoutingNumber` data type is a string of 9 digits representing an ABA routing transit number.

    The algorithm used to validate the routing number is described in the
    [ABA routing transit number](https://en.wikipedia.org/wiki/ABA_routing_transit_number#Check_digit)
    Wikipedia article.

    ```py
    from pydantic import BaseModel

    from pydantic_extra_types.routing_number import ABARoutingNumber


    class BankAccount(BaseModel):
        routing_number: ABARoutingNumber


    account = BankAccount(routing_number='122105155')
    print(account)
    # > routing_number='122105155'
    ```
    """

    # Constraints applied by the generated core schema (exactly 9 characters).
    strip_whitespace: ClassVar[bool] = True
    min_length: ClassVar[int] = 9
    max_length: ClassVar[int] = 9

    def __init__(self, routing_number: str):
        # str.__new__ has already stored the value; validate and keep a
        # checked copy of the routing number.
        self._validate_digits(routing_number)
        self._routing_number = self._validate_routing_number(routing_number)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(
                min_length=cls.min_length,
                max_length=cls.max_length,
                strip_whitespace=cls.strip_whitespace,
                strict=False,
            ),
        )

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> ABARoutingNumber:
        return cls(__input_value)

    @classmethod
    def _validate_digits(cls, routing_number: str) -> None:
        """Check that the routing number is all digits.

        Args:
            routing_number: The routing number to validate.

        Raises:
            PydanticCustomError: If the routing number is not all digits.
        """
        if not routing_number.isdigit():
            raise PydanticCustomError('aba_routing_number', 'routing number is not all digits')

    @classmethod
    def _validate_routing_number(cls, routing_number: str) -> str:
        """Check [digit algorithm](https://en.wikipedia.org/wiki/ABA_routing_transit_number#Check_digit) for
        [ABA routing transit number](https://www.routingnumber.com/).

        Args:
            routing_number: The routing number to validate.

        Raises:
            PydanticCustomError: If the routing number is incorrect.
        """
        # Weighted checksum: digits are multiplied by the repeating 3-7-1
        # pattern; a valid number's total is a multiple of 10.
        checksum = sum(int(digit) * factor for digit, factor in zip(routing_number, it.cycle((3, 7, 1))))
        if checksum % 10:
            raise PydanticCustomError('aba_routing_number', 'Incorrect ABA routing transit number')
        return routing_number
-------------------------------------------------------------------------------- /pydantic_extra_types/s3.py: --------------------------------------------------------------------------------
"""The `pydantic_extra_types.s3` module provides the
[`S3Path`][pydantic_extra_types.s3.S3Path] data type.

A simple AWS S3 URL parser.
It also provides the `Bucket`, `Key` component.
"""

from __future__ import annotations

import re
from typing import Any, ClassVar

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema


class S3Path(str):
    """An object representing a valid S3 path.
    This type also allows you to access the `bucket` and `key` component of the S3 path.
    It also contains the `last_key` which represents the last part of the path (typically a file).
21 | 22 | ```python 23 | from pydantic import BaseModel 24 | from pydantic_extra_types.s3 import S3Path 25 | 26 | 27 | class TestModel(BaseModel): 28 | path: S3Path 29 | 30 | 31 | p = 's3://my-data-bucket/2023/08/29/sales-report.csv' 32 | model = TestModel(path=p) 33 | model 34 | 35 | # > TestModel(path=S3Path('s3://my-data-bucket/2023/08/29/sales-report.csv')) 36 | 37 | model.path.bucket 38 | 39 | # > 'my-data-bucket' 40 | ``` 41 | """ 42 | 43 | patt: ClassVar[str] = r'^s3://([^/]+)/(.*?([^/]+)/?)$' 44 | 45 | def __init__(self, value: str) -> None: 46 | self.value = value 47 | groups: tuple[str, str, str] = re.match(self.patt, self.value).groups() # type: ignore 48 | self.bucket: str = groups[0] 49 | self.key: str = groups[1] 50 | self.last_key: str = groups[2] 51 | 52 | def __str__(self) -> str: # pragma: no cover 53 | return self.value 54 | 55 | def __repr__(self) -> str: # pragma: no cover 56 | return f'{self.__class__.__name__}({self.value!r})' 57 | 58 | @classmethod 59 | def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> S3Path: 60 | return cls(__input_value) 61 | 62 | @classmethod 63 | def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: 64 | _, _ = source, handler 65 | return core_schema.with_info_after_validator_function( 66 | cls._validate, 67 | core_schema.str_schema(pattern=cls.patt), 68 | field_name=cls.__class__.__name__, 69 | ) 70 | -------------------------------------------------------------------------------- /pydantic_extra_types/script_code.py: -------------------------------------------------------------------------------- 1 | """script definitions that are based on the [ISO 15924](https://en.wikipedia.org/wiki/ISO_15924)""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler 8 | from pydantic_core import PydanticCustomError, core_schema 9 | 10 | try: 11 | import 
pycountry 12 | except ModuleNotFoundError as e: # pragma: no cover 13 | raise RuntimeError( 14 | 'The `script_code` module requires "pycountry" to be installed.' 15 | ' You can install it with "pip install pycountry".' 16 | ) from e 17 | 18 | 19 | class ISO_15924(str): 20 | """ISO_15924 parses script in the [ISO 15924](https://en.wikipedia.org/wiki/ISO_15924) 21 | format. 22 | 23 | ```py 24 | from pydantic import BaseModel 25 | 26 | from pydantic_extra_types.language_code import ISO_15924 27 | 28 | 29 | class Script(BaseModel): 30 | alpha_4: ISO_15924 31 | 32 | 33 | script = Script(alpha_4='Java') 34 | print(lang) 35 | # > script='Java' 36 | ``` 37 | """ 38 | 39 | allowed_values_list = [script.alpha_4 for script in pycountry.scripts] 40 | allowed_values = set(allowed_values_list) 41 | 42 | @classmethod 43 | def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> ISO_15924: 44 | """Validate a ISO 15924 language code from the provided str value. 45 | 46 | Args: 47 | __input_value: The str value to be validated. 48 | _: The Pydantic ValidationInfo. 49 | 50 | Returns: 51 | The validated ISO 15924 script code. 52 | 53 | Raises: 54 | PydanticCustomError: If the ISO 15924 script code is not valid. 55 | """ 56 | if __input_value not in cls.allowed_values: 57 | raise PydanticCustomError( 58 | 'ISO_15924', 'Invalid ISO 15924 script code. See https://en.wikipedia.org/wiki/ISO_15924' 59 | ) 60 | return cls(__input_value) 61 | 62 | @classmethod 63 | def __get_pydantic_core_schema__( 64 | cls, _: type[Any], __: GetCoreSchemaHandler 65 | ) -> core_schema.AfterValidatorFunctionSchema: 66 | """Return a Pydantic CoreSchema with the ISO 639-3 language code validation. 67 | 68 | Args: 69 | _: The source type. 70 | __: The handler to get the CoreSchema. 71 | 72 | Returns: 73 | A Pydantic CoreSchema with the ISO 639-3 language code validation. 
74 | 75 | """ 76 | return core_schema.with_info_after_validator_function( 77 | cls._validate, 78 | core_schema.str_schema(min_length=4, max_length=4), 79 | ) 80 | 81 | @classmethod 82 | def __get_pydantic_json_schema__( 83 | cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler 84 | ) -> dict[str, Any]: 85 | """Return a Pydantic JSON Schema with the ISO 639-3 language code validation. 86 | 87 | Args: 88 | schema: The Pydantic CoreSchema. 89 | handler: The handler to get the JSON Schema. 90 | 91 | Returns: 92 | A Pydantic JSON Schema with the ISO 639-3 language code validation. 93 | 94 | """ 95 | json_schema = handler(schema) 96 | json_schema.update({'enum': cls.allowed_values_list}) 97 | return json_schema 98 | -------------------------------------------------------------------------------- /pydantic_extra_types/semantic_version.py: -------------------------------------------------------------------------------- 1 | """SemanticVersion definition that is based on the Semantiv Versioning Specification [semver](https://semver.org/).""" 2 | 3 | from typing import Any, Callable 4 | 5 | from pydantic import GetJsonSchemaHandler 6 | from pydantic.json_schema import JsonSchemaValue 7 | from pydantic_core import core_schema 8 | 9 | try: 10 | import semver 11 | except ModuleNotFoundError as e: # pragma: no cover 12 | raise RuntimeError( 13 | 'The `semantic_version` module requires "semver" to be installed. You can install it with "pip install semver".' 
class SemanticVersion(semver.Version):
    """Semantic version based on the official [semver thread](https://python-semver.readthedocs.io/en/latest/advanced/combine-pydantic-and-semver.html)."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        """Build a schema that parses strings and passes through existing ``semver.Version`` instances."""
        # String input: validate it is a str, then parse it with semver.
        str_then_parse = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(cls.validate_from_str),
            ]
        )

        # JSON input is always a string; Python input may already be a Version.
        return core_schema.json_or_python_schema(
            json_schema=str_then_parse,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(semver.Version),
                    str_then_parse,
                ]
            ),
            serialization=core_schema.to_string_ser_schema(),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        """Describe the value in JSON Schema as a SemVer-2.0.0-constrained string."""
        return handler(
            core_schema.str_schema(
                pattern=r'^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$'
            )
        )

    @classmethod
    def validate_from_str(cls, value: str) -> 'SemanticVersion':
        """Parse *value* into a SemanticVersion via ``semver``."""
        return cls.parse(value)
4 | """ 5 | 6 | import warnings 7 | from typing import Any, Callable 8 | 9 | from pydantic import GetJsonSchemaHandler 10 | from pydantic.json_schema import JsonSchemaValue 11 | from pydantic_core import core_schema 12 | from semver import Version 13 | from typing_extensions import Annotated 14 | 15 | warnings.warn( 16 | 'Use from pydantic_extra_types.semver import SemanticVersion instead. Will be removed in 3.0.0.', DeprecationWarning 17 | ) 18 | 19 | 20 | class _VersionPydanticAnnotation(Version): 21 | """Represents a Semantic Versioning (SemVer). 22 | 23 | Wraps the `version` type from `semver`. 24 | 25 | Example: 26 | ```python 27 | from pydantic import BaseModel 28 | 29 | from pydantic_extra_types.semver import _VersionPydanticAnnotation 30 | 31 | 32 | class appVersion(BaseModel): 33 | version: _VersionPydanticAnnotation 34 | 35 | 36 | app_version = appVersion(version='1.2.3') 37 | 38 | print(app_version.version) 39 | # > 1.2.3 40 | ``` 41 | """ 42 | 43 | @classmethod 44 | def __get_pydantic_core_schema__( 45 | cls, 46 | _source_type: Any, 47 | _handler: Callable[[Any], core_schema.CoreSchema], 48 | ) -> core_schema.CoreSchema: 49 | def validate_from_str(value: str) -> Version: 50 | return Version.parse(value) 51 | 52 | from_str_schema = core_schema.chain_schema( 53 | [ 54 | core_schema.str_schema(), 55 | core_schema.no_info_plain_validator_function(validate_from_str), 56 | ] 57 | ) 58 | 59 | return core_schema.json_or_python_schema( 60 | json_schema=from_str_schema, 61 | python_schema=core_schema.union_schema( 62 | [ 63 | core_schema.is_instance_schema(Version), 64 | from_str_schema, 65 | ] 66 | ), 67 | serialization=core_schema.to_string_ser_schema(), 68 | ) 69 | 70 | @classmethod 71 | def __get_pydantic_json_schema__( 72 | cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler 73 | ) -> JsonSchemaValue: 74 | return handler(core_schema.str_schema()) 75 | 76 | 77 | ManifestVersion = Annotated[Version, _VersionPydanticAnnotation] 78 | 
-------------------------------------------------------------------------------- /pydantic_extra_types/timezone_name.py: -------------------------------------------------------------------------------- 1 | """Time zone name validation and serialization module.""" 2 | 3 | from __future__ import annotations 4 | 5 | import importlib 6 | import sys 7 | import warnings 8 | from typing import Any, Callable, cast 9 | 10 | from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler 11 | from pydantic_core import PydanticCustomError, core_schema 12 | 13 | 14 | def _is_available(name: str) -> bool: 15 | """Check if a module is available for import.""" 16 | try: 17 | importlib.import_module(name) 18 | return True 19 | except ModuleNotFoundError: # pragma: no cover 20 | return False 21 | 22 | 23 | def _tz_provider_from_zone_info() -> set[str]: # pragma: no cover 24 | """Get timezones from the zoneinfo module.""" 25 | from zoneinfo import available_timezones 26 | 27 | return set(available_timezones()) 28 | 29 | 30 | def _tz_provider_from_pytz() -> set[str]: # pragma: no cover 31 | """Get timezones from the pytz module.""" 32 | from pytz import all_timezones 33 | 34 | return set(all_timezones) 35 | 36 | 37 | def _warn_about_pytz_usage() -> None: 38 | """Warn about using pytz with Python 3.9 or later.""" 39 | warnings.warn( # pragma: no cover 40 | 'Projects using Python 3.9 or later should be using the support now included as part of the standard library. ' 41 | 'Please consider switching to the standard library (zoneinfo) module.' 42 | ) 43 | 44 | 45 | def get_timezones() -> set[str]: 46 | """Determine the timezone provider and return available timezones.""" 47 | if _is_available('zoneinfo'): # pragma: no cover 48 | timezones = _tz_provider_from_zone_info() 49 | if len(timezones) == 0: # pragma: no cover 50 | raise ImportError('No timezone provider found. 
Please install tzdata with "pip install tzdata"') 51 | return timezones 52 | elif _is_available('pytz'): # pragma: no cover 53 | return _tz_provider_from_pytz() 54 | else: # pragma: no cover 55 | if sys.version_info[:2] == (3, 8): 56 | raise ImportError('No pytz module found. Please install it with "pip install pytz"') 57 | raise ImportError('No timezone provider found. Please install tzdata with "pip install tzdata"') 58 | 59 | 60 | class TimeZoneNameSettings(type): 61 | def __new__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> type[TimeZoneName]: 62 | dct['strict'] = kwargs.pop('strict', True) 63 | return cast('type[TimeZoneName]', super().__new__(cls, name, bases, dct)) 64 | 65 | def __init__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> None: 66 | super().__init__(name, bases, dct) 67 | cls.strict = kwargs.get('strict', True) 68 | 69 | 70 | def timezone_name_settings(**kwargs: Any) -> Callable[[type[TimeZoneName]], type[TimeZoneName]]: 71 | def wrapper(cls: type[TimeZoneName]) -> type[TimeZoneName]: 72 | cls.strict = kwargs.get('strict', True) 73 | return cls 74 | 75 | return wrapper 76 | 77 | 78 | @timezone_name_settings(strict=True) 79 | class TimeZoneName(str): 80 | """TimeZoneName is a custom string subclass for validating and serializing timezone names. 81 | 82 | The TimeZoneName class uses the IANA Time Zone Database for validation. 83 | It supports both strict and non-strict modes for timezone name validation. 
@timezone_name_settings(strict=True)
class TimeZoneName(str):
    """Custom string type that validates and serializes IANA timezone names.

    Validation is backed by the IANA Time Zone Database. In strict mode (the
    default) only exact canonical names are accepted; in non-strict mode the
    input is stripped, matched case-insensitively, and normalized to the
    canonical spelling.

    ## Examples:

    Normal usage:

    ```python
    from pydantic_extra_types.timezone_name import TimeZoneName
    from pydantic import BaseModel


    class Location(BaseModel):
        city: str
        timezone: TimeZoneName


    loc = Location(city='New York', timezone='America/New_York')
    print(loc.timezone)

    >> America/New_York
    ```

    Non-strict mode:

    ```python
    from pydantic_extra_types.timezone_name import TimeZoneName, timezone_name_settings


    @timezone_name_settings(strict=False)
    class TZNonStrict(TimeZoneName):
        pass


    tz = TZNonStrict('america/new_york')
    print(tz)

    >> america/new_york
    ```
    """

    __slots__: list[str] = []
    # Canonical names reported by the active timezone provider.
    allowed_values: set[str] = set(get_timezones())
    allowed_values_list: list[str] = sorted(allowed_values)
    # Upper-cased name -> canonical spelling, used for non-strict matching.
    allowed_values_upper_to_correct: dict[str, str] = {val.upper(): val for val in allowed_values}
    strict: bool

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName:
        """Validate a time zone name from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The Pydantic ValidationInfo.

        Returns:
            The validated time zone name.

        Raises:
            PydanticCustomError: If the timezone name is not valid.
        """
        if __input_value in cls.allowed_values:  # fast path: exact canonical match
            return cls(__input_value)
        if not cls.strict:
            corrected = cls.allowed_values_upper_to_correct.get(__input_value.strip().upper())
            if corrected is not None:
                return cls(corrected)
        raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.')

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _: type[Any], __: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        """Return a Pydantic CoreSchema with the timezone name validation.

        Args:
            _: The source type.
            __: The handler to get the CoreSchema.

        Returns:
            A Pydantic CoreSchema with the timezone name validation.
        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(min_length=1),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """Return a Pydantic JSON Schema listing every valid timezone name.

        Args:
            schema: The Pydantic CoreSchema.
            handler: The handler to get the JSON Schema.

        Returns:
            A Pydantic JSON Schema with the timezone name validation.
        """
        json_schema = handler(schema)
        json_schema.update({'enum': cls.allowed_values_list})
        return json_schema
4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | import uuid 9 | from dataclasses import dataclass 10 | from typing import Any, Union 11 | 12 | from pydantic import GetCoreSchemaHandler 13 | from pydantic._internal import _repr 14 | from pydantic_core import PydanticCustomError, core_schema 15 | 16 | try: 17 | from ulid import ULID as _ULID 18 | except ModuleNotFoundError as e: # pragma: no cover 19 | raise RuntimeError( 20 | 'The `ulid` module requires "python-ulid" to be installed. You can install it with "pip install python-ulid".' 21 | ) from e 22 | 23 | UlidType = Union[str, bytes, int] 24 | 25 | 26 | @dataclass 27 | class ULID(_repr.Representation): 28 | """A wrapper around [python-ulid](https://pypi.org/project/python-ulid/) package, which 29 | is a validate by the [ULID-spec](https://github.com/ulid/spec#implementations-in-other-languages). 30 | """ 31 | 32 | ulid: _ULID 33 | 34 | @classmethod 35 | def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: 36 | return core_schema.no_info_wrap_validator_function( 37 | cls._validate_ulid, 38 | core_schema.union_schema( 39 | [ 40 | core_schema.is_instance_schema(_ULID), 41 | core_schema.int_schema(), 42 | core_schema.bytes_schema(), 43 | core_schema.str_schema(), 44 | core_schema.uuid_schema(), 45 | ] 46 | ), 47 | ) 48 | 49 | @classmethod 50 | def _validate_ulid(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: 51 | ulid: _ULID 52 | if isinstance(value, bool): 53 | raise PydanticCustomError('ulid_format', 'Unrecognized format') 54 | try: 55 | if isinstance(value, int): 56 | ulid = _ULID.from_int(value) 57 | elif isinstance(value, str): 58 | ulid = _ULID.from_str(value) 59 | elif isinstance(value, uuid.UUID): 60 | ulid = _ULID.from_uuid(value) 61 | elif isinstance(value, _ULID): 62 | ulid = value 63 | else: 64 | ulid = _ULID.from_bytes(value) 65 | except ValueError as e: 66 | raise PydanticCustomError('ulid_format', 
'Unrecognized format') from e 67 | return handler(ulid) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ['hatchling'] 3 | build-backend = 'hatchling.build' 4 | 5 | [tool.hatch.version] 6 | path = 'pydantic_extra_types/__init__.py' 7 | 8 | [project] 9 | name = 'pydantic-extra-types' 10 | description = 'Extra Pydantic types.' 11 | authors = [ 12 | { name = 'Samuel Colvin', email = 's@muelcolvin.com' }, 13 | { name = 'Yasser Tahiri', email = 'hello@yezz.me' }, 14 | ] 15 | license = 'MIT' 16 | readme = 'README.md' 17 | classifiers = [ 18 | 'Development Status :: 5 - Production/Stable', 19 | 'Programming Language :: Python', 20 | 'Programming Language :: Python :: 3', 21 | 'Programming Language :: Python :: 3 :: Only', 22 | 'Programming Language :: Python :: 3.8', 23 | 'Programming Language :: Python :: 3.9', 24 | 'Programming Language :: Python :: 3.10', 25 | 'Programming Language :: Python :: 3.11', 26 | 'Programming Language :: Python :: 3.12', 27 | 'Programming Language :: Python :: 3.13', 28 | 'Intended Audience :: Developers', 29 | 'Intended Audience :: Information Technology', 30 | 'Intended Audience :: System Administrators', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Operating System :: Unix', 33 | 'Operating System :: POSIX :: Linux', 34 | 'Environment :: Console', 35 | 'Environment :: MacOS X', 36 | 'Framework :: Pydantic', 37 | 'Framework :: Pydantic :: 2', 38 | 'Topic :: Software Development :: Libraries :: Python Modules', 39 | 'Topic :: Internet', 40 | ] 41 | requires-python = '>=3.8' 42 | dependencies = ['pydantic>=2.5.2','typing-extensions'] 43 | dynamic = ['version'] 44 | 45 | [project.optional-dependencies] 46 | all = [ 47 | 'phonenumbers>=8,<10', 48 | 'pycountry>=23', 49 | 'semver>=3.0.2', 50 | 'python-ulid>=1,<2; python_version<"3.9"', 51 | 'python-ulid>=1,<4; 
python_version>="3.9"', 52 | 'pendulum>=3.0.0,<4.0.0', 53 | 'pymongo>=4.0.0,<5.0.0', 54 | 'pytz>=2024.1', 55 | 'semver~=3.0.2', 56 | 'tzdata>=2024.1', 57 | ] 58 | phonenumbers = ['phonenumbers>=8,<10'] 59 | pycountry = ['pycountry>=23'] 60 | semver = ['semver>=3.0.2'] 61 | python_ulid = [ 62 | 'python-ulid>=1,<2; python_version<"3.9"', 63 | 'python-ulid>=1,<4; python_version>="3.9"', 64 | ] 65 | pendulum = ['pendulum>=3.0.0,<4.0.0'] 66 | 67 | [dependency-groups] 68 | dev = [ 69 | "coverage[toml]>=7.6.1", 70 | "pytest-pretty>=1.2.0", 71 | "dirty-equals>=0.7.1", 72 | "pytest>=8.3.2", 73 | ] 74 | lint = [ 75 | "ruff>=0.7.4", 76 | "mypy>=0.910", 77 | "annotated-types>=0.7.0", 78 | "types-pytz>=2024.1.0.20240417", 79 | ] 80 | extra = [ 81 | { include-group = 'dev' }, 82 | { include-group = 'lint' }, 83 | ] 84 | 85 | [project.urls] 86 | Homepage = 'https://github.com/pydantic/pydantic-extra-types' 87 | Source = 'https://github.com/pydantic/pydantic-extra-types' 88 | Changelog = 'https://github.com/pydantic/pydantic-extra-types/releases' 89 | Documentation = 'https://docs.pydantic.dev/latest/' 90 | 91 | [tool.ruff.lint.pyupgrade] 92 | keep-runtime-typing = true 93 | 94 | [tool.ruff] 95 | line-length = 120 96 | target-version = 'py38' 97 | 98 | [tool.ruff.lint] 99 | extend-select = [ 100 | "Q", 101 | "RUF100", 102 | "C90", 103 | "UP", 104 | "I", 105 | ] 106 | flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' } 107 | isort = {known-first-party = ['pydantic_extra_types', 'tests'] } 108 | mccabe = { max-complexity = 14 } 109 | pydocstyle = { convention = 'google' } 110 | 111 | [tool.ruff.format] 112 | docstring-code-format = true 113 | quote-style = "single" 114 | 115 | [tool.ruff.lint.per-file-ignores] 116 | 'pydantic_extra_types/color.py' = ['E741'] 117 | 118 | [tool.coverage.run] 119 | source = ['pydantic_extra_types'] 120 | branch = true 121 | context = '${CONTEXT}' 122 | 123 | [tool.coverage.paths] 124 | source = [ 125 | 'pydantic_extra_types/', 126 
| '/Users/runner/work/pydantic-extra-types/pydantic-extra-types/pydantic_extra_types/', 127 | 'D:\a\pydantic-extra-types\pydantic-extra-types\pydantic_extra_types', 128 | ] 129 | 130 | [tool.coverage.report] 131 | precision = 2 132 | fail_under = 100 133 | show_missing = true 134 | skip_covered = true 135 | exclude_lines = [ 136 | 'pragma: no cover', 137 | 'raise NotImplementedError', 138 | 'if TYPE_CHECKING:', 139 | '@overload', 140 | ] 141 | 142 | [tool.mypy] 143 | strict = true 144 | plugins = 'pydantic.mypy' 145 | 146 | [tool.pytest.ini_options] 147 | filterwarnings = [ 148 | 'error', 149 | # This ignore will be removed when pycountry will drop py36 & support py311 150 | 'ignore:::pkg_resources', 151 | # This ignore will be removed when pendulum fixes https://github.com/sdispater/pendulum/issues/834 152 | 'ignore:datetime.datetime.utcfromtimestamp.*:DeprecationWarning', 153 | ' ignore:Use from pydantic_extra_types.semver import SemanticVersion instead. Will be removed in 3.0.0.:DeprecationWarning' 154 | ] 155 | 156 | # configuring https://github.com/pydantic/hooky 157 | [tool.hooky] 158 | reviewers = ['yezz123', 'Kludex'] 159 | require_change_file = false 160 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pydantic/pydantic-extra-types/7332abb2337527c6d259b9b07bde17012da020a7/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_coordinate.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | from re import Pattern 3 | from typing import Any, Optional, Union 4 | 5 | import pytest 6 | from pydantic import BaseModel, ValidationError 7 | from pydantic_core._pydantic_core import ArgsKwargs 8 | 9 | from pydantic_extra_types.coordinate import Coordinate, Latitude, Longitude 10 | 
# Model fixtures: each wraps exactly one of the coordinate types under test.
class Coord(BaseModel):
    coord: Coordinate


class Lat(BaseModel):
    lat: Latitude


class Lng(BaseModel):
    lng: Longitude


@pytest.mark.parametrize(
    'coord, result, error',
    [
        # Valid coordinates
        ((20.0, 10.0), (20.0, 10.0), None),
        ((-90.0, 0.0), (-90.0, 0.0), None),
        (('20.0', 10.0), (20.0, 10.0), None),
        ((20.0, '10.0'), (20.0, 10.0), None),
        ((45.678, -123.456), (45.678, -123.456), None),
        (('45.678, -123.456'), (45.678, -123.456), None),
        (Coordinate(20.0, 10.0), (20.0, 10.0), None),
        (Coordinate(latitude=0, longitude=0), (0, 0), None),
        (ArgsKwargs(args=()), (0, 0), None),
        (ArgsKwargs(args=(1, 0.0)), (1.0, 0), None),
        # Decimal test cases
        ((Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
        ((Decimal('-90.0'), Decimal('0.0')), (Decimal('-90.0'), Decimal('0.0')), None),
        ((Decimal('45.678'), Decimal('-123.456')), (Decimal('45.678'), Decimal('-123.456')), None),
        (Coordinate(Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
        (Coordinate(latitude=Decimal('0'), longitude=Decimal('0')), (Decimal('0'), Decimal('0')), None),
        (ArgsKwargs(args=(Decimal('1'), Decimal('0.0'))), (Decimal('1.0'), Decimal('0.0')), None),
        # Invalid coordinates
        ((), None, 'Field required'),  # Empty tuple
        ((10.0,), None, 'Field required'),  # Tuple with only one value
        (('ten, '), None, 'string is not recognized as a valid coordinate'),
        ((20.0, 10.0, 30.0), None, 'Tuple should have at most 2 items'),  # Tuple with more than 2 values
        (ArgsKwargs(args=(1.0,)), None, 'Input should be a dictionary or an instance of Coordinate'),
        (
            '20.0, 10.0, 30.0',
            None,
            'Input should be a dictionary or an instance of Coordinate ',
        ),  # Str with more than 2 values
        ('20.0, 10.0, 30.0', None, 'Unexpected positional argument'),  # Str with more than 2 values
        (2, None, 'Input should be a dictionary or an instance of Coordinate'),  # Wrong type
    ],
)
def test_format_for_coordinate(
    coord: (Any, Any), result: (Union[float, Decimal], Union[float, Decimal]), error: Optional[Pattern]
):
    """Check every accepted input shape (tuple/str/Coordinate/ArgsKwargs/Decimal) and rejection message."""
    if error is None:
        _coord: Coordinate = Coord(coord=coord).coord
        assert _coord.latitude == result[0]
        assert _coord.longitude == result[1]
    else:
        with pytest.raises(ValidationError, match=error):
            Coord(coord=coord).coord


@pytest.mark.parametrize(
    'coord, error',
    [
        # Valid coordinates
        ((-90.0, 0.0), None),
        ((50.0, 180.0), None),
        # Invalid coordinates
        ((-91.0, 0.0), 'Input should be greater than or equal to -90'),
        ((50.0, 181.0), 'Input should be less than or equal to 180'),
        # Valid Decimal coordinates
        ((Decimal('-90.0'), Decimal('0.0')), None),
        ((Decimal('50.0'), Decimal('180.0')), None),
        ((Decimal('-89.999999'), Decimal('179.999999')), None),
        ((Decimal('0.0'), Decimal('0.0')), None),
        # Invalid Decimal coordinates
        ((Decimal('-90.1'), Decimal('0.0')), 'Input should be greater than or equal to -90'),
        ((Decimal('50.0'), Decimal('180.1')), 'Input should be less than or equal to 180'),
        ((Decimal('90.1'), Decimal('0.0')), 'Input should be less than or equal to 90'),
        ((Decimal('0.0'), Decimal('-180.1')), 'Input should be greater than or equal to -180'),
    ],
)
def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
    """Check latitude [-90, 90] and longitude [-180, 180] bounds, inclusive, for float and Decimal."""
    if error is None:
        _coord: Coordinate = Coord(coord=coord).coord
        assert _coord.latitude == coord[0]
        assert _coord.longitude == coord[1]
    else:
        with pytest.raises(ValidationError, match=error):
            Coord(coord=coord).coord


@pytest.mark.parametrize(
    'latitude, valid',
    [
        # Valid latitude
        (20.0, True),
        (3.0000000000000000000000, True),
        (90.0, True),
        ('90.0', True),
        (-90.0, True),
        ('-90.0', True),
        (Decimal('90.0'), True),
        (Decimal('-90.0'), True),
        # Invalid latitude
        (91.0, False),
        (-91.0, False),
        (Decimal('91.0'), False),
        (Decimal('-91.0'), False),
    ],
)
def test_format_latitude(latitude: float, valid: bool):
    """Standalone Latitude: accepts [-90, 90] in float/str/Decimal form, rejects out-of-range."""
    if valid:
        _lat = Lat(lat=latitude).lat
        assert _lat == float(latitude)
    else:
        with pytest.raises(ValidationError, match='2 validation errors for Lat'):
            Lat(lat=latitude)


@pytest.mark.parametrize(
    'longitude, valid',
    [
        # Valid longitude
        (20.0, True),
        (3.0000000000000000000000, True),
        (90.0, True),
        ('90.0', True),
        (-90.0, True),
        ('-90.0', True),
        (91.0, True),
        (-91.0, True),
        (180.0, True),
        (-180.0, True),
        (Decimal('180.0'), True),
        (Decimal('-180.0'), True),
        # Invalid longitude
        (181.0, False),
        (-181.0, False),
        (Decimal('181.0'), False),
        (Decimal('-181.0'), False),
    ],
)
def test_format_longitude(longitude: float, valid: bool):
    """Standalone Longitude: accepts [-180, 180] in float/str/Decimal form, rejects out-of-range."""
    if valid:
        _lng = Lng(lng=longitude).lng
        assert _lng == float(longitude)
    else:
        with pytest.raises(ValidationError, match='2 validation errors for Lng'):
            Lng(lng=longitude)


def test_str_repr():
    """str() is 'lat,lon'; repr() shows the stored numeric types (float vs Decimal)."""
    # Float tests
    assert str(Coord(coord=(20.0, 10.0)).coord) == '20.0,10.0'
    assert str(Coord(coord=('20.0, 10.0')).coord) == '20.0,10.0'
    assert repr(Coord(coord=(20.0, 10.0)).coord) == 'Coordinate(latitude=20.0, longitude=10.0)'
    # Decimal tests
    assert str(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == '20.0,10.0'
    assert str(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == '20.000,10.000'
    assert (
        repr(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord)
        == "Coordinate(latitude=Decimal('20.0'), longitude=Decimal('10.0'))"
    )


def test_eq():
    """Equality is by numeric value, across float/str/Decimal construction paths."""
    # Float tests
    assert Coord(coord=(20.0, 10.0)).coord != Coord(coord='20.0,11.0').coord
    assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
    assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
    assert Coord(coord=(20.0, 10.0)).coord == Coord(coord='20.0,10.0').coord

    # Decimal tests
    assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord='20.0,10.0').coord
    assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord=(20.0, 10.0)).coord
    assert (
        Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord != Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord
    )
    assert (
        Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord
        == Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
    )


def test_hashable():
    """Hash follows equality, including float/Decimal cross-type equality."""
    # Float tests
    assert hash(Coord(coord=(20.0, 10.0)).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
    assert hash(Coord(coord=(20.0, 11.0)).coord) != hash(Coord(coord=(20.0, 10.0)).coord)

    # Decimal tests
    assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(
        Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
    )
    assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
    assert hash(Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord) != hash(
        Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
    )
    assert hash(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == hash(
        Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
    )


def test_json_schema():
    """Pin the exact validation- and serialization-mode JSON schemas for Coordinate."""

    class Model(BaseModel):
        value: Coordinate

    assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == {
        'properties': {
            'latitude': {
                'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
                'title': 'Latitude',
            },
            'longitude': {
                'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
                'title': 'Longitude',
            },
        },
        'required': ['latitude', 'longitude'],
        'title': 'Coordinate',
        'type': 'object',
    }
    assert Model.model_json_schema(mode='validation')['properties']['value'] == {
        'anyOf': [
            {'$ref': '#/$defs/Coordinate'},
            {
                'maxItems': 2,
                'minItems': 2,
                'prefixItems': [
                    {'anyOf': [{'type': 'number'}, {'type': 'string'}]},
                    {'anyOf': [{'type': 'number'}, {'type': 'string'}]},
                ],
                'type': 'array',
            },
            {'type': 'string'},
        ],
        'title': 'Value',
    }
    assert Model.model_json_schema(mode='serialization') == {
        '$defs': {
            'Coordinate': {
                'properties': {
                    'latitude': {
                        'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
                        'title': 'Latitude',
                    },
                    'longitude': {
                        'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
                        'title': 'Longitude',
                    },
                },
                'required': ['latitude', 'longitude'],
                'title': 'Coordinate',
                'type': 'object',
            }
        },
        'properties': {'value': {'$ref': '#/$defs/Coordinate', 'title': 'Value'}},
        'required': ['value'],
        'title': 'Model',
        'type': 'object',
    }
# Only exercise the first N registry entries per representation to keep runtime bounded.
PARAMS_AMOUNT = 20


@pytest.fixture(scope='module', name='ProductAlpha2')
def product_alpha2_fixture():
    """Model with a CountryAlpha2 field."""

    class Product(BaseModel):
        made_in: CountryAlpha2

    return Product


@pytest.fixture(scope='module', name='ProductAlpha3')
def product_alpha3_fixture():
    """Model with a CountryAlpha3 field."""

    class Product(BaseModel):
        made_in: CountryAlpha3

    return Product


@pytest.fixture(scope='module', name='ProductShortName')
def product_short_name_fixture():
    """Model with a CountryShortName field."""

    class Product(BaseModel):
        made_in: CountryShortName

    return Product


@pytest.fixture(scope='module', name='ProductNumericCode')
def product_numeric_code_fixture():
    """Model with a CountryNumericCode field."""

    class Product(BaseModel):
        made_in: CountryNumericCode

    return Product


@pytest.mark.parametrize('alpha2, country_data', list(_index_by_alpha2().items())[:PARAMS_AMOUNT])
def test_valid_alpha2(alpha2: str, country_data: CountryInfo, ProductAlpha2):
    """A valid alpha-2 code validates and exposes the other three representations."""
    banana = ProductAlpha2(made_in=alpha2)
    assert banana.made_in == country_data.alpha2
    assert banana.made_in.alpha3 == country_data.alpha3
    assert banana.made_in.numeric_code == country_data.numeric_code
    assert banana.made_in.short_name == country_data.short_name


@pytest.mark.parametrize('alpha2', list(printable))
def test_invalid_alpha2(alpha2: str, ProductAlpha2):
    """Single printable characters are never valid alpha-2 codes."""
    with pytest.raises(ValidationError, match='Invalid country alpha2 code'):
        ProductAlpha2(made_in=alpha2)


@pytest.mark.parametrize('alpha3, country_data', list(_index_by_alpha3().items())[:PARAMS_AMOUNT])
def test_valid_alpha3(alpha3: str, country_data: CountryInfo, ProductAlpha3):
    """A valid alpha-3 code validates and exposes the other three representations."""
    banana = ProductAlpha3(made_in=alpha3)
    assert banana.made_in == country_data.alpha3
    assert banana.made_in.alpha2 == country_data.alpha2
    assert banana.made_in.numeric_code == country_data.numeric_code
    assert banana.made_in.short_name == country_data.short_name


@pytest.mark.parametrize('alpha3', list(printable))
def test_invalid_alpha3(alpha3: str, ProductAlpha3):
    """Single printable characters are never valid alpha-3 codes."""
    with pytest.raises(ValidationError, match='Invalid country alpha3 code'):
        ProductAlpha3(made_in=alpha3)


@pytest.mark.parametrize('short_name, country_data', list(_index_by_short_name().items())[:PARAMS_AMOUNT])
def test_valid_short_name(short_name: str, country_data: CountryInfo, ProductShortName):
    """A valid short name validates and exposes the other three representations."""
    banana = ProductShortName(made_in=short_name)
    assert banana.made_in == country_data.short_name
    assert banana.made_in.alpha2 == country_data.alpha2
    assert banana.made_in.alpha3 == country_data.alpha3
    assert banana.made_in.numeric_code == country_data.numeric_code


@pytest.mark.parametrize('short_name', list(printable))
def test_invalid_short_name(short_name: str, ProductShortName):
    """Single printable characters are never valid country short names."""
    with pytest.raises(ValidationError, match='Invalid country short name'):
        ProductShortName(made_in=short_name)


@pytest.mark.parametrize('numeric_code, country_data', list(_index_by_numeric_code().items())[:PARAMS_AMOUNT])
def test_valid_numeric_code(numeric_code: str, country_data: CountryInfo, ProductNumericCode):
    """A valid numeric code validates and exposes the other three representations."""
    banana = ProductNumericCode(made_in=numeric_code)
    assert banana.made_in == country_data.numeric_code
    assert banana.made_in.alpha2 == country_data.alpha2
    assert banana.made_in.alpha3 == country_data.alpha3
    assert banana.made_in.short_name == country_data.short_name


@pytest.mark.parametrize('numeric_code', list(printable))
def test_invalid_numeric_code(numeric_code: str, ProductNumericCode):
    """Single printable characters are never valid country numeric codes."""
    with pytest.raises(ValidationError, match='Invalid country numeric code'):
        ProductNumericCode(made_in=numeric_code)
-------------------------------------------------------------------------------- /tests/test_currency_code.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pycountry 4 | import pytest 5 | from pydantic import BaseModel, ValidationError 6 | 7 | from pydantic_extra_types import currency_code 8 | 9 | 10 | class ISO4217CheckingModel(BaseModel): 11 | currency: currency_code.ISO4217 12 | 13 | 14 | class CurrencyCheckingModel(BaseModel): 15 | currency: currency_code.Currency 16 | 17 | 18 | forbidden_currencies = sorted(currency_code._CODES_FOR_BONDS_METAL_TESTING) 19 | 20 | 21 | @pytest.mark.parametrize('currency', map(lambda code: code.alpha_3, pycountry.currencies)) 22 | def test_ISO4217_code_ok(currency: str): 23 | model = ISO4217CheckingModel(currency=currency) 24 | assert model.currency == currency 25 | assert model.model_dump() == {'currency': currency} # test serialization 26 | 27 | 28 | @pytest.mark.parametrize('currency', ['USD', 'usd', 'UsD']) 29 | def test_ISO4217_code_ok_lower_case(currency: str): 30 | model = ISO4217CheckingModel(currency=currency) 31 | assert model.currency == currency.upper() 32 | 33 | 34 | @pytest.mark.parametrize( 35 | 'currency', 36 | filter( 37 | lambda code: code not in currency_code._CODES_FOR_BONDS_METAL_TESTING, 38 | map(lambda code: code.alpha_3, pycountry.currencies), 39 | ), 40 | ) 41 | def test_everyday_code_ok(currency: str): 42 | model = CurrencyCheckingModel(currency=currency) 43 | assert model.currency == currency 44 | assert model.model_dump() == {'currency': currency} # test serialization 45 | 46 | 47 | @pytest.mark.parametrize('currency', ['USD', 'usd', 'UsD']) 48 | def test_everyday_code_ok_lower_case(currency: str): 49 | model = CurrencyCheckingModel(currency=currency) 50 | assert model.currency == currency.upper() 51 | 52 | 53 | def test_ISO4217_fails(): 54 | with pytest.raises( 55 | ValidationError, 56 | match=re.escape( 57 | '1 validation error for 
ISO4217CheckingModel\ncurrency\n ' 58 | 'Invalid ISO 4217 currency code. See https://en.wikipedia.org/wiki/ISO_4217 ' 59 | "[type=ISO4217, input_value='OMG', input_type=str]" 60 | ), 61 | ): 62 | ISO4217CheckingModel(currency='OMG') 63 | 64 | 65 | @pytest.mark.parametrize('forbidden_currency', forbidden_currencies) 66 | def test_forbidden_everyday(forbidden_currency): 67 | with pytest.raises( 68 | ValidationError, 69 | match=re.escape( 70 | '1 validation error for CurrencyCheckingModel\ncurrency\n ' 71 | 'Invalid currency code. See https://en.wikipedia.org/wiki/ISO_4217 . ' 72 | 'Bonds, testing and precious metals codes are not allowed. ' 73 | f"[type=InvalidCurrency, input_value='{forbidden_currency}', input_type=str]" 74 | ), 75 | ): 76 | CurrencyCheckingModel(currency=forbidden_currency) 77 | -------------------------------------------------------------------------------- /tests/test_domain.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | from pydantic_extra_types.domain import DomainStr 7 | 8 | 9 | class MyModel(BaseModel): 10 | domain: DomainStr 11 | 12 | 13 | valid_domains = [ 14 | 'example.com', 15 | 'sub.example.com', 16 | 'sub-domain.example-site.co.uk', 17 | 'a.com', 18 | 'x.com', 19 | '1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.com', # Multiple subdomains 20 | ] 21 | 22 | invalid_domains = [ 23 | '', # Empty string 24 | 'example', # Missing TLD 25 | '.com', # Missing domain name 26 | 'example.', # Trailing dot 27 | 'exam ple.com', # Space in domain 28 | 'exa_mple.com', # Underscore in domain 29 | 'example.com.', # Trailing dot 30 | '192.168.1.23', # Ip address 31 | '192.168.1.0/23', # CIDR 32 | ] 33 | 34 | very_long_domains = [ 35 | 'a' * 249 + '.com', # Just under the limit 36 | 'a' * 250 + '.com', # At the limit 37 | 'a' * 251 + '.com', # Just over the limit 38 | 
'sub1.sub2.sub3.sub4.sub5.sub6.sub7.sub8.sub9.sub10.sub11.sub12.sub13.sub14.sub15.sub16.sub17.sub18.sub19.sub20.sub21.sub22.sub23.sub24.sub25.sub26.sub27.sub28.sub29.sub30.sub31.sub32.sub33.extremely-long-domain-name-example-to-test-the-253-character-limit.com', 39 | ] 40 | 41 | invalid_domain_types = [1, 2, 1.1, 2.1, False, [], {}, None] 42 | 43 | 44 | @pytest.mark.parametrize('domain', valid_domains) 45 | def test_valid_domains(domain: str): 46 | try: 47 | MyModel.model_validate({'domain': domain}) 48 | assert len(domain) < 254 and len(domain) > 0 49 | except ValidationError: 50 | assert len(domain) > 254 or len(domain) == 0 51 | 52 | 53 | @pytest.mark.parametrize('domain', invalid_domains) 54 | def test_invalid_domains(domain: str): 55 | try: 56 | MyModel.model_validate({'domain': domain}) 57 | raise Exception( 58 | f"This test case has only samples that should raise a ValidationError. This domain '{domain}' did not raise such an exception." 59 | ) 60 | except ValidationError: 61 | # An error is expected on this test 62 | pass 63 | 64 | 65 | @pytest.mark.parametrize('domain', very_long_domains) 66 | def test_very_long_domains(domain: str): 67 | try: 68 | MyModel.model_validate({'domain': domain}) 69 | assert len(domain) < 254 and len(domain) > 0 70 | except ValidationError: 71 | # An error is expected on this test 72 | pass 73 | 74 | 75 | @pytest.mark.parametrize('domain', invalid_domain_types) 76 | def test_invalid_domain_types(domain: Any): 77 | with pytest.raises(ValidationError, match='Value must be a string'): 78 | MyModel(domain=domain) 79 | -------------------------------------------------------------------------------- /tests/test_epoch.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pytest 4 | 5 | from pydantic_extra_types import epoch 6 | 7 | 8 | @pytest.mark.parametrize('type_,cls_', [(int, epoch.Integer), (float, epoch.Number)], ids=['integer', 'number']) 9 | def test_type(type_, 
cls_): 10 | from pydantic import BaseModel 11 | 12 | class A(BaseModel): 13 | epoch: cls_ 14 | 15 | now = datetime.datetime.now(tz=datetime.timezone.utc) 16 | ts = type_(now.timestamp()) 17 | a = A.model_validate({'epoch': ts}) 18 | v = a.model_dump() 19 | assert v['epoch'] == ts 20 | 21 | b = A.model_construct(epoch=now) 22 | 23 | v = b.model_dump() 24 | assert v['epoch'] == ts 25 | 26 | c = A.model_validate(dict(epoch=ts)) 27 | v = c.model_dump() 28 | assert v['epoch'] == ts 29 | 30 | 31 | @pytest.mark.parametrize('cls_', [(epoch.Integer), (epoch.Number)], ids=['integer', 'number']) 32 | def test_schema(cls_): 33 | from pydantic import BaseModel 34 | 35 | class A(BaseModel): 36 | dt: cls_ 37 | 38 | v = A.model_json_schema() 39 | assert (dt := v['properties']['dt'])['type'] == cls_.TYPE and dt['format'] == 'date-time' 40 | -------------------------------------------------------------------------------- /tests/test_isbn.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | from pydantic_extra_types.isbn import ISBN 7 | 8 | 9 | class Book(BaseModel): 10 | isbn: ISBN 11 | 12 | 13 | isbn_length_test_cases = [ 14 | # Valid ISBNs 15 | ('8537809667', '9788537809662', True), # ISBN-10 as input 16 | ('9788537809662', '9788537809662', True), # ISBN-13 as input 17 | ('080442957X', '9780804429573', True), # ISBN-10 ending in "X" as input 18 | ('9788584390670', '9788584390670', True), # ISBN-13 Starting with 978 19 | ('9790306406156', '9790306406156', True), # ISBN-13 starting with 979 20 | # Invalid ISBNs 21 | ('97885843906701', None, False), # Length: 14 (Higher) 22 | ('978858439067', None, False), # Length: 12 (In Between) 23 | ('97885843906', None, False), # Length: 11 (In Between) 24 | ('978858439', None, False), # Length: 9 (Lower) 25 | ('', None, False), # Length: 0 (Lower) 26 | ] 27 | 28 | 29 | @pytest.mark.parametrize('input_isbn, 
output_isbn, valid', isbn_length_test_cases) 30 | def test_isbn_length(input_isbn: Any, output_isbn: str, valid: bool) -> None: 31 | if valid: 32 | assert Book(isbn=ISBN(input_isbn)).isbn == output_isbn 33 | else: 34 | with pytest.raises(ValidationError, match='isbn_length'): 35 | Book(isbn=ISBN(input_isbn)) 36 | 37 | 38 | isbn10_digits_test_cases = [ 39 | # Valid ISBNs 40 | ('8537809667', '9788537809662', True), # ISBN-10 as input 41 | ('080442957X', '9780804429573', True), # ISBN-10 ending in "X" as input 42 | # Invalid ISBNs 43 | ('@80442957X', None, False), # Non Integer in [0] position 44 | ('8@37809667', None, False), # Non Integer in [1] position 45 | ('85@7809667', None, False), # Non Integer in [2] position 46 | ('853@809667', None, False), # Non Integer in [3] position 47 | ('8537@09667', None, False), # Non Integer in [4] position 48 | ('85378@9667', None, False), # Non Integer in [5] position 49 | ('853780@667', None, False), # Non Integer in [6] position 50 | ('8537809@67', None, False), # Non Integer in [7] position 51 | ('85378096@7', None, False), # Non Integer in [8] position 52 | ('853780966@', None, False), # Non Integer or X in [9] position 53 | ] 54 | 55 | 56 | @pytest.mark.parametrize('input_isbn, output_isbn, valid', isbn10_digits_test_cases) 57 | def test_isbn10_digits(input_isbn: Any, output_isbn: str, valid: bool) -> None: 58 | if valid: 59 | assert Book(isbn=ISBN(input_isbn)).isbn == output_isbn 60 | else: 61 | with pytest.raises(ValidationError, match='isbn10_invalid_characters'): 62 | Book(isbn=ISBN(input_isbn)) 63 | 64 | 65 | isbn13_digits_test_cases = [ 66 | # Valid ISBNs 67 | ('9788537809662', '9788537809662', True), # ISBN-13 as input 68 | ('9780306406157', '9780306406157', True), # ISBN-13 as input 69 | ('9788584390670', '9788584390670', True), # ISBN-13 Starting with 978 70 | ('9790306406156', '9790306406156', True), # ISBN-13 starting with 979 71 | # Invalid ISBNs 72 | ('@788537809662', None, False), # Non Integer in [0] position 
73 | ('9@88537809662', None, False), # Non Integer in [1] position 74 | ('97@8537809662', None, False), # Non Integer in [2] position 75 | ('978@537809662', None, False), # Non Integer in [3] position 76 | ('9788@37809662', None, False), # Non Integer in [4] position 77 | ('97885@7809662', None, False), # Non Integer in [5] position 78 | ('978853@809662', None, False), # Non Integer in [6] position 79 | ('9788537@09662', None, False), # Non Integer in [7] position 80 | ('97885378@9662', None, False), # Non Integer in [8] position 81 | ('978853780@662', None, False), # Non Integer in [9] position 82 | ('9788537809@62', None, False), # Non Integer in [10] position 83 | ('97885378096@2', None, False), # Non Integer in [11] position 84 | ('978853780966@', None, False), # Non Integer in [12] position 85 | ] 86 | 87 | 88 | @pytest.mark.parametrize('input_isbn, output_isbn, valid', isbn13_digits_test_cases) 89 | def test_isbn13_digits(input_isbn: Any, output_isbn: str, valid: bool) -> None: 90 | if valid: 91 | assert Book(isbn=ISBN(input_isbn)).isbn == output_isbn 92 | else: 93 | with pytest.raises(ValidationError, match='isbn13_invalid_characters'): 94 | Book(isbn=ISBN(input_isbn)) 95 | 96 | 97 | isbn13_early_digits_test_cases = [ 98 | # Valid ISBNs 99 | ('9780306406157', '9780306406157', True), # ISBN-13 as input 100 | ('9788584390670', '9788584390670', True), # ISBN-13 Starting with 978 101 | ('9790306406156', '9790306406156', True), # ISBN-13 starting with 979 102 | # Invalid ISBNs 103 | ('1788584390670', None, False), # Does not start with 978 or 979 104 | ('9288584390670', None, False), # Does not start with 978 or 979 105 | ('9738584390670', None, False), # Does not start with 978 or 979 106 | ] 107 | 108 | 109 | @pytest.mark.parametrize('input_isbn, output_isbn, valid', isbn13_early_digits_test_cases) 110 | def test_isbn13_early_digits(input_isbn: Any, output_isbn: str, valid: bool) -> None: 111 | if valid: 112 | assert Book(isbn=ISBN(input_isbn)).isbn == 
output_isbn 113 | else: 114 | with pytest.raises(ValidationError, match='isbn_invalid_early_characters'): 115 | Book(isbn=ISBN(input_isbn)) 116 | 117 | 118 | isbn_last_digit_test_cases = [ 119 | # Valid ISBNs 120 | ('8537809667', '9788537809662', True), # ISBN-10 as input 121 | ('9788537809662', '9788537809662', True), # ISBN-13 as input 122 | ('080442957X', '9780804429573', True), # ISBN-10 ending in "X" as input 123 | ('9788584390670', '9788584390670', True), # ISBN-13 Starting with 978 124 | ('9790306406156', '9790306406156', True), # ISBN-13 starting with 979 125 | ('8306018060', '9788306018066', True), # ISBN-10 as input 126 | # Invalid ISBNs 127 | ('8537809663', None, False), # ISBN-10 as input with wrong last digit 128 | ('9788537809661', None, False), # ISBN-13 as input with wrong last digit 129 | ('080442953X', None, False), # ISBN-10 ending in "X" as input with wrong last digit 130 | ('9788584390671', None, False), # ISBN-13 Starting with 978 with wrong last digit 131 | ('9790306406155', None, False), # ISBN-13 starting with 979 with wrong last digit 132 | ] 133 | 134 | 135 | @pytest.mark.parametrize('input_isbn, output_isbn, valid', isbn_last_digit_test_cases) 136 | def test_isbn_last_digit(input_isbn: Any, output_isbn: str, valid: bool) -> None: 137 | if valid: 138 | assert Book(isbn=ISBN(input_isbn)).isbn == output_isbn 139 | else: 140 | with pytest.raises(ValidationError, match='isbn_invalid_digit_check_isbn'): 141 | Book(isbn=ISBN(input_isbn)) 142 | 143 | 144 | isbn_conversion_test_cases = [ 145 | # Valid ISBNs 146 | ('8537809667', '9788537809662'), 147 | ('080442957X', '9780804429573'), 148 | ('9788584390670', '9788584390670'), 149 | ('9790306406156', '9790306406156'), 150 | ] 151 | 152 | 153 | @pytest.mark.parametrize('input_isbn, output_isbn', isbn_conversion_test_cases) 154 | def test_isbn_conversion(input_isbn: Any, output_isbn: str) -> None: 155 | assert Book(isbn=ISBN(input_isbn)).isbn == output_isbn 156 | 
-------------------------------------------------------------------------------- /tests/test_language_codes.py: -------------------------------------------------------------------------------- 1 | import re 2 | from string import printable 3 | 4 | import pycountry 5 | import pytest 6 | from pydantic import BaseModel, ValidationError 7 | 8 | from pydantic_extra_types import language_code 9 | from pydantic_extra_types.language_code import ( 10 | LanguageAlpha2, 11 | LanguageInfo, 12 | LanguageName, 13 | _index_by_alpha2, 14 | _index_by_alpha3, 15 | _index_by_name, 16 | ) 17 | 18 | PARAMS_AMOUNT = 20 19 | 20 | 21 | @pytest.fixture(scope='module', name='MovieAlpha2') 22 | def movie_alpha2_fixture(): 23 | class Movie(BaseModel): 24 | audio_lang: LanguageAlpha2 25 | 26 | return Movie 27 | 28 | 29 | @pytest.fixture(scope='module', name='MovieName') 30 | def movie_name_fixture(): 31 | class Movie(BaseModel): 32 | audio_lang: LanguageName 33 | 34 | return Movie 35 | 36 | 37 | class ISO3CheckingModel(BaseModel): 38 | lang: language_code.ISO639_3 39 | 40 | 41 | class ISO5CheckingModel(BaseModel): 42 | lang: language_code.ISO639_5 43 | 44 | 45 | @pytest.mark.parametrize('alpha2, language_data', list(_index_by_alpha2().items())) 46 | def test_valid_alpha2(alpha2: str, language_data: LanguageInfo, MovieAlpha2): 47 | the_godfather = MovieAlpha2(audio_lang=alpha2) 48 | assert the_godfather.audio_lang == language_data.alpha2 49 | assert the_godfather.audio_lang.alpha3 == language_data.alpha3 50 | assert the_godfather.audio_lang.name == language_data.name 51 | 52 | 53 | @pytest.mark.parametrize('alpha2', list(printable) + list(_index_by_alpha3().keys())[:PARAMS_AMOUNT]) 54 | def test_invalid_alpha2(alpha2: str, MovieAlpha2): 55 | with pytest.raises(ValidationError, match='Invalid language alpha2 code'): 56 | MovieAlpha2(audio_lang=alpha2) 57 | 58 | 59 | @pytest.mark.parametrize('name, language_data', list(_index_by_name().items())[:PARAMS_AMOUNT]) 60 | def test_valid_name(name: str, 
language_data: LanguageInfo, MovieName): 61 | the_godfather = MovieName(audio_lang=name) 62 | assert the_godfather.audio_lang == language_data.name 63 | assert the_godfather.audio_lang.alpha2 == language_data.alpha2 64 | assert the_godfather.audio_lang.alpha3 == language_data.alpha3 65 | 66 | 67 | @pytest.mark.parametrize('name', set(printable) - {'E', 'U'}) # E and U are valid language codes 68 | def test_invalid_name(name: str, MovieName): 69 | with pytest.raises(ValidationError, match='Invalid language name'): 70 | MovieName(audio_lang=name) 71 | 72 | 73 | @pytest.mark.parametrize('lang', map(lambda lang: lang.alpha_3, pycountry.languages)) 74 | def test_iso_ISO639_3_code_ok(lang: str): 75 | model = ISO3CheckingModel(lang=lang) 76 | assert model.lang == lang 77 | assert model.model_dump() == {'lang': lang} # test serialization 78 | 79 | 80 | @pytest.mark.parametrize('lang', map(lambda lang: lang.alpha_3, pycountry.language_families)) 81 | def test_iso_639_5_code_ok(lang: str): 82 | model = ISO5CheckingModel(lang=lang) 83 | assert model.lang == lang 84 | assert model.model_dump() == {'lang': lang} # test serialization 85 | 86 | 87 | def test_iso3_language_fail(): 88 | with pytest.raises( 89 | ValidationError, 90 | match=re.escape( 91 | '1 validation error for ISO3CheckingModel\nlang\n ' 92 | 'Invalid ISO 639-3 language code. ' 93 | "See https://en.wikipedia.org/wiki/ISO_639-3 [type=ISO649_3, input_value='LOL', input_type=str]" 94 | ), 95 | ): 96 | ISO3CheckingModel(lang='LOL') 97 | 98 | 99 | def test_iso5_language_fail(): 100 | with pytest.raises( 101 | ValidationError, 102 | match=re.escape( 103 | '1 validation error for ISO5CheckingModel\nlang\n ' 104 | 'Invalid ISO 639-5 language code. 
' 105 | "See https://en.wikipedia.org/wiki/ISO_639-5 [type=ISO649_5, input_value='LOL', input_type=str]" 106 | ), 107 | ): 108 | ISO5CheckingModel(lang='LOL') 109 | -------------------------------------------------------------------------------- /tests/test_mac_address.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | from pydantic_extra_types.mac_address import MacAddress 7 | 8 | 9 | class Network(BaseModel): 10 | mac_address: MacAddress 11 | 12 | 13 | @pytest.mark.parametrize( 14 | 'mac_address, result, valid', 15 | [ 16 | # Valid MAC addresses 17 | ('00:00:5e:00:53:01', '00:00:5e:00:53:01', True), 18 | ('02:00:5e:10:00:00:00:01', '02:00:5e:10:00:00:00:01', True), 19 | ( 20 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 21 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 22 | True, 23 | ), 24 | ('00-00-5e-00-53-01', '00:00:5e:00:53:01', True), 25 | ('02-00-5e-10-00-00-00-01', '02:00:5e:10:00:00:00:01', True), 26 | ( 27 | '00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01', 28 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 29 | True, 30 | ), 31 | ('0000.5e00.5301', '00:00:5e:00:53:01', True), 32 | ('0200.5e10.0000.0001', '02:00:5e:10:00:00:00:01', True), 33 | ( 34 | '0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001', 35 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 36 | True, 37 | ), 38 | # Invalid MAC addresses 39 | ('0200.5e10.0000.001', None, False), 40 | ('00-00-5e-00-53-0', None, False), 41 | ('00:00:5e:00:53:1', None, False), 42 | ('02:00:5e:10:00:00:00:1', None, False), 43 | ('00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:1', None, False), 44 | ('0200.5e10.0000.001', None, False), # Invalid length 45 | ('00-00-5e-00-53-0', None, False), # Missing character 46 | ('00:00:5e:00:53:1', None, False), # Missing leading zero 47 | 
('00:00:5g:00:53:01', None, False), # Invalid hex digit 'g' 48 | ('00.00.5e.0.3.01.0.0.5e.0.53.01', None, False), 49 | ('00-00-5e-00-53-01:', None, False), # Extra separator at the end 50 | ('00000.5e000.5301', None, False), 51 | ('000.5e0.530001', None, False), 52 | ('0000.5e#0./301', None, False), 53 | (b'12.!4.5!.7/.#G.AB......', None, False), 54 | ('12.!4.5!.7/.#G.AB', None, False), 55 | ('00-00-5e-00-53-01-', None, False), # Extra separator at the end 56 | ('00.00.5e.00.53.01.', None, False), # Extra separator at the end 57 | ('00:00:5e:00:53:', None, False), # Incomplete MAC address 58 | (float(12345678910111213), None, False), 59 | ], 60 | ) 61 | def test_format_for_mac_address(mac_address: Any, result: str, valid: bool): 62 | if valid: 63 | assert Network(mac_address=MacAddress(mac_address)).mac_address == result 64 | else: 65 | with pytest.raises(ValidationError, match='format'): 66 | Network(mac_address=MacAddress(mac_address)) 67 | 68 | 69 | @pytest.mark.parametrize( 70 | 'mac_address, result, valid', 71 | [ 72 | # Valid MAC addresses 73 | ('00:00:5e:00:53:01', '00:00:5e:00:53:01', True), 74 | ('02:00:5e:10:00:00:00:01', '02:00:5e:10:00:00:00:01', True), 75 | ( 76 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 77 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 78 | True, 79 | ), 80 | ('00-00-5e-00-53-01', '00:00:5e:00:53:01', True), 81 | ('02-00-5e-10-00-00-00-01', '02:00:5e:10:00:00:00:01', True), 82 | ( 83 | '00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01', 84 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 85 | True, 86 | ), 87 | ('0000.5e00.5301', '00:00:5e:00:53:01', True), 88 | ('0200.5e10.0000.0001', '02:00:5e:10:00:00:00:01', True), 89 | ( 90 | '0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001', 91 | '00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01', 92 | True, 93 | ), 94 | # Invalid MAC addresses 95 | ('0', None, False), 96 | ('00:00:00', None, False), 97 | 
('00-00-5e-00-53-01-01', None, False), 98 | ('0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001.0000.0001', None, False), 99 | ], 100 | ) 101 | def test_length_for_mac_address(mac_address: str, result: str, valid: bool): 102 | if valid: 103 | assert Network(mac_address=MacAddress(mac_address)).mac_address == result 104 | else: 105 | with pytest.raises(ValueError, match='Length'): 106 | Network(mac_address=MacAddress(mac_address)) 107 | 108 | 109 | @pytest.mark.parametrize( 110 | 'mac_address, valid', 111 | [ 112 | # Valid MAC addresses 113 | ('00:00:5e:00:53:01', True), 114 | (MacAddress('00:00:5e:00:53:01'), True), 115 | # Invalid MAC addresses 116 | (0, False), 117 | (['00:00:00'], False), 118 | ], 119 | ) 120 | def test_type_for_mac_address(mac_address: Any, valid: bool): 121 | if valid: 122 | Network(mac_address=MacAddress(mac_address)) 123 | else: 124 | with pytest.raises(ValidationError, match='MAC address must be 14'): 125 | Network(mac_address=MacAddress(mac_address)) 126 | 127 | 128 | def test_model_validation(): 129 | class Model(BaseModel): 130 | mac_address: MacAddress 131 | 132 | assert Model(mac_address='00:00:5e:00:53:01').mac_address == '00:00:5e:00:53:01' 133 | with pytest.raises(ValidationError) as exc_info: 134 | Model(mac_address='1234') 135 | 136 | assert exc_info.value.errors() == [ 137 | { 138 | 'ctx': {'mac_address': '1234', 'required_length': 14}, 139 | 'input': '1234', 140 | 'loc': ('mac_address',), 141 | 'msg': 'Length for a 1234 MAC address must be 14', 142 | 'type': 'mac_address_len', 143 | } 144 | ] 145 | -------------------------------------------------------------------------------- /tests/test_mongo_object_id.py: -------------------------------------------------------------------------------- 1 | """Tests for the mongo_object_id module.""" 2 | 3 | import pytest 4 | from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError 5 | from pydantic.json_schema import JsonSchemaMode 6 | 7 | from 
pydantic_extra_types.mongo_object_id import MongoObjectId 8 | 9 | 10 | class MongoDocument(BaseModel): 11 | object_id: MongoObjectId 12 | 13 | 14 | @pytest.mark.parametrize( 15 | 'object_id, result, valid', 16 | [ 17 | # Valid ObjectId for str format 18 | ('611827f2878b88b49ebb69fc', '611827f2878b88b49ebb69fc', True), 19 | ('611827f2878b88b49ebb69fd', '611827f2878b88b49ebb69fd', True), 20 | # Invalid ObjectId for str format 21 | ('611827f2878b88b49ebb69f', None, False), # Invalid ObjectId (short length) 22 | ('611827f2878b88b49ebb69fca', None, False), # Invalid ObjectId (long length) 23 | # Valid ObjectId for bytes format 24 | ], 25 | ) 26 | def test_format_for_object_id(object_id: str, result: str, valid: bool) -> None: 27 | """Test the MongoObjectId validation.""" 28 | if valid: 29 | assert str(MongoDocument(object_id=object_id).object_id) == result 30 | else: 31 | with pytest.raises(ValidationError): 32 | MongoDocument(object_id=object_id) 33 | with pytest.raises( 34 | ValueError, 35 | match=f"Invalid ObjectId {object_id} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'.", 36 | ): 37 | MongoObjectId.validate(object_id) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | 'schema_mode', 42 | [ 43 | 'validation', 44 | 'serialization', 45 | ], 46 | ) 47 | def test_json_schema(schema_mode: JsonSchemaMode) -> None: 48 | """Test the MongoObjectId model_json_schema implementation.""" 49 | expected_json_schema = { 50 | 'properties': { 51 | 'object_id': { 52 | 'maxLength': MongoObjectId.OBJECT_ID_LENGTH, 53 | 'minLength': MongoObjectId.OBJECT_ID_LENGTH, 54 | 'title': 'Object Id', 55 | 'type': 'string', 56 | } 57 | }, 58 | 'required': ['object_id'], 59 | 'title': 'MongoDocument', 60 | 'type': 'object', 61 | } 62 | assert MongoDocument.model_json_schema(mode=schema_mode) == expected_json_schema 63 | 64 | 65 | def test_get_pydantic_core_schema() -> None: 66 | """Test the __get_pydantic_core_schema__ method override.""" 67 | schema = 
MongoObjectId.__get_pydantic_core_schema__(MongoObjectId, GetCoreSchemaHandler()) 68 | assert isinstance(schema, dict) 69 | assert 'json_schema' in schema 70 | assert 'python_schema' in schema 71 | assert schema['json_schema']['type'] == 'str' 72 | -------------------------------------------------------------------------------- /tests/test_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import pytest 5 | from pydantic import BaseModel 6 | 7 | from pydantic_extra_types.path import ( 8 | ExistingPath, 9 | ResolvedDirectoryPath, 10 | ResolvedExistingPath, 11 | ResolvedFilePath, 12 | ResolvedNewPath, 13 | ) 14 | 15 | 16 | class File(BaseModel): 17 | file: ResolvedFilePath 18 | 19 | 20 | class Directory(BaseModel): 21 | directory: ResolvedDirectoryPath 22 | 23 | 24 | class NewPath(BaseModel): 25 | new_path: ResolvedNewPath 26 | 27 | 28 | class Existing(BaseModel): 29 | existing: ExistingPath 30 | 31 | 32 | class ResolvedExisting(BaseModel): 33 | resolved_existing: ResolvedExistingPath 34 | 35 | 36 | @pytest.fixture 37 | def absolute_file_path(tmp_path: pathlib.Path) -> pathlib.Path: 38 | directory = tmp_path / 'test-relative' 39 | directory.mkdir() 40 | file_path = directory / 'test-relative.txt' 41 | file_path.touch() 42 | return file_path 43 | 44 | 45 | @pytest.fixture 46 | def relative_file_path(absolute_file_path: pathlib.Path) -> pathlib.Path: 47 | return pathlib.Path(os.path.relpath(absolute_file_path, os.getcwd())) 48 | 49 | 50 | @pytest.fixture 51 | def absolute_directory_path(tmp_path: pathlib.Path) -> pathlib.Path: 52 | directory = tmp_path / 'test-relative' 53 | directory.mkdir() 54 | return directory 55 | 56 | 57 | @pytest.fixture 58 | def relative_directory_path(absolute_directory_path: pathlib.Path) -> pathlib.Path: 59 | return pathlib.Path(os.path.relpath(absolute_directory_path, os.getcwd())) 60 | 61 | 62 | @pytest.fixture 63 | def absolute_new_path(tmp_path: pathlib.Path) 
-> pathlib.Path: 64 | return tmp_path / 'test-relative' 65 | 66 | 67 | @pytest.fixture 68 | def relative_new_path(absolute_new_path: pathlib.Path) -> pathlib.Path: 69 | return pathlib.Path(os.path.relpath(absolute_new_path, os.getcwd())) 70 | 71 | 72 | def test_relative_file(absolute_file_path: pathlib.Path, relative_file_path: pathlib.Path): 73 | file = File(file=relative_file_path) 74 | assert file.file == absolute_file_path 75 | 76 | 77 | def test_absolute_file(absolute_file_path: pathlib.Path): 78 | file = File(file=absolute_file_path) 79 | assert file.file == absolute_file_path 80 | 81 | 82 | def test_relative_directory(absolute_directory_path: pathlib.Path, relative_directory_path: pathlib.Path): 83 | directory = Directory(directory=relative_directory_path) 84 | assert directory.directory == absolute_directory_path 85 | 86 | 87 | def test_absolute_directory(absolute_directory_path: pathlib.Path): 88 | directory = Directory(directory=absolute_directory_path) 89 | assert directory.directory == absolute_directory_path 90 | 91 | 92 | def test_relative_new_path(absolute_new_path: pathlib.Path, relative_new_path: pathlib.Path): 93 | new_path = NewPath(new_path=relative_new_path) 94 | assert new_path.new_path == absolute_new_path 95 | 96 | 97 | def test_absolute_new_path(absolute_new_path: pathlib.Path): 98 | new_path = NewPath(new_path=absolute_new_path) 99 | assert new_path.new_path == absolute_new_path 100 | 101 | 102 | @pytest.mark.parametrize( 103 | ('pass_fixture', 'expect_fixture'), 104 | ( 105 | ('relative_file_path', 'relative_file_path'), 106 | ('absolute_file_path', 'absolute_file_path'), 107 | ('relative_directory_path', 'relative_directory_path'), 108 | ('absolute_directory_path', 'absolute_directory_path'), 109 | ), 110 | ) 111 | def test_existing_path(request: pytest.FixtureRequest, pass_fixture: str, expect_fixture: str): 112 | existing = Existing(existing=request.getfixturevalue(pass_fixture)) 113 | assert existing.existing == 
request.getfixturevalue(expect_fixture) 114 | 115 | 116 | @pytest.mark.parametrize( 117 | ('pass_fixture', 'expect_fixture'), 118 | ( 119 | ('relative_file_path', 'absolute_file_path'), 120 | ('absolute_file_path', 'absolute_file_path'), 121 | ('relative_directory_path', 'absolute_directory_path'), 122 | ('absolute_directory_path', 'absolute_directory_path'), 123 | ), 124 | ) 125 | def test_resolved_existing_path(request: pytest.FixtureRequest, pass_fixture: str, expect_fixture: str): 126 | resolved_existing = ResolvedExisting(resolved_existing=request.getfixturevalue(pass_fixture)) 127 | assert resolved_existing.resolved_existing == request.getfixturevalue(expect_fixture) 128 | -------------------------------------------------------------------------------- /tests/test_phone_numbers.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | from pydantic_extra_types.phone_numbers import PhoneNumber 7 | 8 | 9 | class Something(BaseModel): 10 | phone_number: PhoneNumber 11 | 12 | 13 | # Note: the 555 area code will result in an invalid phone number 14 | def test_valid_phone_number() -> None: 15 | Something(phone_number='+1 901 555 1212') 16 | 17 | 18 | def test_when_extension_provided() -> None: 19 | Something(phone_number='+1 901 555 1212 ext 12533') 20 | 21 | 22 | @pytest.mark.parametrize('invalid_number', ['', '123', 12, None, object(), '55 121']) 23 | def test_invalid_phone_number(invalid_number: Any) -> None: 24 | with pytest.raises(ValidationError): 25 | Something(phone_number=invalid_number) 26 | 27 | 28 | def test_formats_phone_number() -> None: 29 | result = Something(phone_number='+1 901 555 1212 ext 12533') 30 | assert result.phone_number == 'tel:+1-901-555-1212;ext=12533' 31 | 32 | 33 | def test_supported_regions() -> None: 34 | assert PhoneNumber.supported_regions == [] 35 | PhoneNumber.supported_regions = ['US'] 36 | 37 
| assert Something(phone_number='+1 901 555 1212') 38 | 39 | with pytest.raises(ValidationError, match='value is not from a supported region'): 40 | Something(phone_number='+44 20 7946 0958') 41 | 42 | USPhoneNumber = PhoneNumber() 43 | USPhoneNumber.supported_regions = ['US'] 44 | assert USPhoneNumber.supported_regions == ['US'] 45 | assert Something(phone_number='+1 901 555 1212') 46 | 47 | with pytest.raises(ValidationError, match='value is not from a supported region'): 48 | Something(phone_number='+44 20 7946 0958') 49 | 50 | 51 | def test_parse_error() -> None: 52 | with pytest.raises(ValidationError, match='value is not a valid phone number'): 53 | Something(phone_number='555 1212') 54 | 55 | 56 | def test_parsed_but_not_a_valid_number() -> None: 57 | with pytest.raises(ValidationError, match='value is not a valid phone number'): 58 | Something(phone_number='+1 555-1212') 59 | 60 | 61 | def test_hashes() -> None: 62 | assert hash(PhoneNumber('555-1212')) == hash(PhoneNumber('555-1212')) 63 | assert hash(PhoneNumber('555-1212')) == hash('555-1212') 64 | assert hash(PhoneNumber('555-1212')) != hash('555-1213') 65 | assert hash(PhoneNumber('555-1212')) != hash(PhoneNumber('555-1213')) 66 | 67 | 68 | def test_eq() -> None: 69 | assert PhoneNumber('555-1212') == PhoneNumber('555-1212') 70 | assert PhoneNumber('555-1212') == '555-1212' 71 | assert PhoneNumber('555-1212') != '555-1213' 72 | assert PhoneNumber('555-1212') != PhoneNumber('555-1213') 73 | -------------------------------------------------------------------------------- /tests/test_phone_numbers_validator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | 3 | import phonenumbers 4 | import pytest 5 | from phonenumbers import PhoneNumber 6 | from pydantic import BaseModel, TypeAdapter, ValidationError 7 | from typing_extensions import Annotated 8 | 9 | from pydantic_extra_types.phone_numbers import PhoneNumberValidator 10 | 

# Annotated aliases: a bare validator, a North-American variant with a default
# region, and a UK variant that also forces E.164 output.
Number = Annotated[Union[str, PhoneNumber], PhoneNumberValidator()]
NANumber = Annotated[
    Union[str, PhoneNumber],
    PhoneNumberValidator(
        supported_regions=['US', 'CA'],
        default_region='US',
    ),
]
UKNumber = Annotated[
    Union[str, PhoneNumber],
    PhoneNumberValidator(
        supported_regions=['GB'],
        default_region='GB',
        number_format='E164',
    ),
]

number_adapter = TypeAdapter(Number)


class Numbers(BaseModel):
    phone_number: Optional[Number] = None
    na_number: Optional[NANumber] = None
    uk_number: Optional[UKNumber] = None


def test_validator_constructor() -> None:
    """Valid configurations construct; each bad option raises its own error."""
    PhoneNumberValidator()
    PhoneNumberValidator(supported_regions=['US', 'CA'], default_region='US')
    PhoneNumberValidator(supported_regions=['GB'], default_region='GB', number_format='E164')

    bad_configs = [
        ('Invalid default region code: XX', {'default_region': 'XX'}),
        ('Invalid number format: XX', {'number_format': 'XX'}),
        ('Invalid supported region code: XX', {'supported_regions': ['XX']}),
    ]
    for message, kwargs in bad_configs:
        with pytest.raises(ValueError, match=message):
            PhoneNumberValidator(**kwargs)


# Note: the 555 area code will result in an invalid phone number
def test_valid_phone_number() -> None:
    """A well-formed US number is accepted."""
    Numbers(phone_number='+1 901 555 1212')


def test_when_extension_provided() -> None:
    """An extension suffix does not break validation."""
    Numbers(phone_number='+1 901 555 1212 ext 12533')


def test_when_phonenumber_instance() -> None:
    """An already-parsed PhoneNumber object is accepted and re-validated."""
    parsed = phonenumbers.parse('+1 901 555 1212', region='US')
    model = Numbers(phone_number=parsed)
    assert model.phone_number == 'tel:+1-901-555-1212'
    # Additional validation is still performed on the instance
    with pytest.raises(ValidationError, match='value is not from a supported region'):
        Numbers(uk_number=parsed)
@pytest.mark.parametrize('invalid_number', ['', '123', 12, object(), '55 121']) 68 | def test_invalid_phone_number(invalid_number: Any) -> None: 69 | # Use a TypeAdapter to test the validation logic for None otherwise 70 | # optional fields will not attempt to validate 71 | with pytest.raises(ValidationError, match='value is not a valid phone number'): 72 | number_adapter.validate_python(invalid_number) 73 | 74 | 75 | def test_formats_phone_number() -> None: 76 | result = Numbers(phone_number='+1 901 555 1212 ext 12533', uk_number='+44 20 7946 0958') 77 | assert result.phone_number == 'tel:+1-901-555-1212;ext=12533' 78 | assert result.uk_number == '+442079460958' 79 | 80 | 81 | def test_default_region() -> None: 82 | result = Numbers(na_number='901 555 1212') 83 | assert result.na_number == 'tel:+1-901-555-1212' 84 | with pytest.raises(ValidationError, match='value is not a valid phone number'): 85 | Numbers(phone_number='901 555 1212') 86 | 87 | 88 | def test_supported_regions() -> None: 89 | assert Numbers(na_number='+1 901 555 1212') 90 | assert Numbers(uk_number='+44 20 7946 0958') 91 | with pytest.raises(ValidationError, match='value is not from a supported region'): 92 | Numbers(na_number='+44 20 7946 0958') 93 | 94 | 95 | def test_parse_error() -> None: 96 | with pytest.raises(ValidationError, match='value is not a valid phone number'): 97 | Numbers(phone_number='555 1212') 98 | 99 | 100 | def test_parsed_but_not_a_valid_number() -> None: 101 | with pytest.raises(ValidationError, match='value is not a valid phone number'): 102 | Numbers(phone_number='+1 555-1212') 103 | -------------------------------------------------------------------------------- /tests/test_routing_number.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | from pydantic_extra_types.routing_number import ABARoutingNumber 7 | 8 | 9 | class 
# (continuation: body of `class Model` whose `class` keyword sits on the
# previous chunk line)
Model(BaseModel):
    # Field under test: validated ABA routing transit number.
    routing_number: ABARoutingNumber


# Non-string inputs are rejected by pydantic's str validation before any
# ABA-specific checks run.
@pytest.mark.parametrize('routing_number', [12, None, object(), 123456789])
def test_invalid_routing_number_string(routing_number: Any) -> None:
    with pytest.raises(ValidationError) as validation_error:
        Model(routing_number=routing_number)
    assert validation_error.match('Input should be a valid string')


# Routing numbers must be exactly 9 characters long.
@pytest.mark.parametrize('routing_number', ['', '123', '1234567890'])
def test_invalid_routing_number_length(routing_number: Any) -> None:
    with pytest.raises(ValidationError) as validation_error:
        Model(routing_number=routing_number)
    assert validation_error.match(r'String should have at (most|least) 9 characters')


# 9-digit strings whose check digit is wrong fail the ABA checksum.
@pytest.mark.parametrize('routing_number', ['122105154', '122235822', '123103723', '074900781'])
def test_invalid_routing_number(routing_number: Any) -> None:
    with pytest.raises(ValidationError) as validation_error:
        Model(routing_number=routing_number)
    assert validation_error.match('Incorrect ABA routing transit number')


# Known-good routing numbers validate without raising.
@pytest.mark.parametrize('routing_number', ['122105155', '122235821', '123103729', '074900783'])
def test_valid_routing_number(routing_number: str) -> None:
    Model(routing_number=routing_number)


def test_raises_error_when_not_a_string() -> None:
    # NOTE(review): despite the name, this exercises a non-digit *character*
    # inside a 9-char string, not a non-string input.
    with pytest.raises(ValidationError, match='routing number is not all digits'):
        Model(routing_number='A12210515')


# -----------------------------------------------------------------------------
# tests/test_s3.py
# -----------------------------------------------------------------------------
import pytest
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.s3 import S3Path


class S3Check(BaseModel):
    # Field under test: parsed/validated S3 URI.
    path: S3Path


# Each case: raw S3 URI, expected bucket, expected key, expected final key segment.
@pytest.mark.parametrize(
    'raw,bucket,key,last_key',
    [
        ('s3://my-data-bucket/2023/08/29/sales-report.csv', 'my-data-bucket', '2023/08/29/sales-report.csv', 'sales-report.csv'),
        ('s3://logs-bucket/app-logs/production/2024/07/01/application-log.txt', 'logs-bucket', 'app-logs/production/2024/07/01/application-log.txt', 'application-log.txt'),
        ('s3://backup-storage/user_data/john_doe/photos/photo-2024-08-15.jpg', 'backup-storage', 'user_data/john_doe/photos/photo-2024-08-15.jpg', 'photo-2024-08-15.jpg'),
        ('s3://analytics-bucket/weekly-reports/Q3/2023/week-35-summary.pdf', 'analytics-bucket', 'weekly-reports/Q3/2023/week-35-summary.pdf', 'week-35-summary.pdf'),
        ('s3://project-data/docs/presentations/quarterly_review.pptx', 'project-data', 'docs/presentations/quarterly_review.pptx', 'quarterly_review.pptx'),
        ('s3://my-music-archive/genres/rock/2024/favorite-songs.mp3', 'my-music-archive', 'genres/rock/2024/favorite-songs.mp3', 'favorite-songs.mp3'),
        ('s3://video-uploads/movies/2024/03/action/thriller/movie-trailer.mp4', 'video-uploads', 'movies/2024/03/action/thriller/movie-trailer.mp4', 'movie-trailer.mp4'),
        ('s3://company-files/legal/contracts/contract-2023-09-01.pdf', 'company-files', 'legal/contracts/contract-2023-09-01.pdf', 'contract-2023-09-01.pdf'),
        ('s3://dev-environment/source-code/release_v1.0.2.zip', 'dev-environment', 'source-code/release_v1.0.2.zip', 'release_v1.0.2.zip'),
        ('s3://public-bucket/open-data/geojson/maps/city_boundaries.geojson', 'public-bucket', 'open-data/geojson/maps/city_boundaries.geojson', 'city_boundaries.geojson'),
        ('s3://image-storage/2024/portfolio/shoots/wedding/couple_photo_12.jpg', 'image-storage', '2024/portfolio/shoots/wedding/couple_photo_12.jpg', 'couple_photo_12.jpg'),
        ('s3://finance-data/reports/2024/Q2/income_statement.xlsx', 'finance-data', 'reports/2024/Q2/income_statement.xlsx', 'income_statement.xlsx'),
        ('s3://training-data/nlp/corpora/english/2023/text_corpus.txt', 'training-data', 'nlp/corpora/english/2023/text_corpus.txt', 'text_corpus.txt'),
        ('s3://ecommerce-backup/2024/transactions/august/orders_2024_08_28.csv', 'ecommerce-backup', '2024/transactions/august/orders_2024_08_28.csv', 'orders_2024_08_28.csv'),
        ('s3://gaming-assets/3d_models/characters/hero/model_v5.obj', 'gaming-assets', '3d_models/characters/hero/model_v5.obj', 'model_v5.obj'),
        ('s3://iot-sensor-data/2024/temperature_sensors/sensor_42_readings.csv', 'iot-sensor-data', '2024/temperature_sensors/sensor_42_readings.csv', 'sensor_42_readings.csv'),
        ('s3://user-uploads/avatars/user123/avatar_2024_08_29.png', 'user-uploads', 'avatars/user123/avatar_2024_08_29.png', 'avatar_2024_08_29.png'),
        ('s3://media-library/podcasts/2023/episode_45.mp3', 'media-library', 'podcasts/2023/episode_45.mp3', 'episode_45.mp3'),
        ('s3://logs-bucket/security/firewall-logs/2024/08/failed_attempts.log', 'logs-bucket', 'security/firewall-logs/2024/08/failed_attempts.log', 'failed_attempts.log'),
        ('s3://data-warehouse/financials/quarterly/2024/Q1/profit_loss.csv', 'data-warehouse', 'financials/quarterly/2024/Q1/profit_loss.csv', 'profit_loss.csv'),
        # a key with no trailing file name: last_key is the final path segment
        ('s3://data-warehouse/financials/quarterly/2024/Q1', 'data-warehouse', 'financials/quarterly/2024/Q1', 'Q1'),
    ],
)
def test_s3(raw: str, bucket: str, key: str, last_key: str):
    model = S3Check(path=raw)
    # (this `assert` statement continues on the next chunk line)
    assert
# (continuation of the `assert` split at the previous chunk line)
model.path == S3Path(raw)
    assert model.path.bucket == bucket
    assert model.path.key == key
    assert model.path.last_key == last_key


def test_wrong_s3():
    # A string without the 's3://' scheme is rejected.
    with pytest.raises(ValidationError):
        S3Check(path='s3/ok')


# -----------------------------------------------------------------------------
# tests/test_scripts.py
# -----------------------------------------------------------------------------
import re

import pycountry
import pytest
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.script_code import ISO_15924


class ScriptCheck(BaseModel):
    # Field under test: ISO 15924 four-letter script code.
    script: ISO_15924


# Every script code known to pycountry must round-trip through the model.
@pytest.mark.parametrize('script', map(lambda lang: lang.alpha_4, pycountry.scripts))
def test_ISO_15924_code_ok(script: str):
    model = ScriptCheck(script=script)
    assert model.script == script
    assert str(model.script) == script
    assert model.model_dump() == {'script': script}  # test serialization


def test_ISO_15924_code_fail_not_enought_letters():
    # Fewer than 4 characters fails the min-length constraint.
    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for ScriptCheck\nscript\n '
            "String should have at least 4 characters [type=string_too_short, input_value='X', input_type=str]\n"
        ),
    ):
        ScriptCheck(script='X')


def test_ISO_15924_code_fail_too_much_letters():
    # More than 4 characters fails the max-length constraint.
    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for ScriptCheck\nscript\n '
            "String should have at most 4 characters [type=string_too_long, input_value='Klingon', input_type=str]"
        ),
    ):
        ScriptCheck(script='Klingon')


def test_ISO_15924_code_fail_not_existing():
    # Correct length, but not a registered ISO 15924 code.
    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for ScriptCheck\nscript\n '
            'Invalid ISO 15924 script code. See https://en.wikipedia.org/wiki/ISO_15924 '
            "[type=ISO_15924, input_value='Klin', input_type=str]"
        ),
    ):
        ScriptCheck(script='Klin')


# -----------------------------------------------------------------------------
# tests/test_semantic_version.py
# -----------------------------------------------------------------------------
import pytest
import semver
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.semantic_version import SemanticVersion


@pytest.fixture(scope='module', name='SemanticVersionObject')
def application_object_fixture():
    # Model factory: a minimal BaseModel carrying a SemanticVersion field.
    class Application(BaseModel):
        version: SemanticVersion

    return Application


# The version is accepted as a raw string and as pre-parsed version objects.
@pytest.mark.parametrize(
    'constructor', [str, semver.Version.parse, SemanticVersion.parse], ids=['str', 'semver.Version', 'SemanticVersion']
)
@pytest.mark.parametrize(
    'version',
    [
        '0.0.4', '1.2.3', '10.20.30',
        '1.1.2-prerelease+meta', '1.1.2+meta', '1.1.2+meta-valid',
        '1.0.0-alpha', '1.0.0-beta', '1.0.0-alpha.beta', '1.0.0-alpha.beta.1',
        '1.0.0-alpha.1', '1.0.0-alpha0.valid', '1.0.0-alpha.0valid',
        '1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay',
        '1.0.0-rc.1+build.1', '2.0.0-rc.1+build.123',
        '1.2.3-beta', '10.2.3-DEV-SNAPSHOT', '1.2.3-SNAPSHOT-123',
        '1.0.0', '2.0.0', '1.1.7',
        '2.0.0+build.1848', '2.0.1-alpha.1227', '1.0.0-alpha+beta',
        '1.2.3----RC-SNAPSHOT.12.9.1--.12+788',
        '1.2.3----R-S.12.9.1--.12+meta',
        '1.2.3----RC-SNAPSHOT.12.9.1--.12',
        '1.0.0+0.build.1-rc.10000aaa-kk-0.1',
        '99999999999999999999999.999999999999999999.99999999999999999',
        '1.0.0-0A.is.legal',
    ],
)
def test_valid_semantic_version(SemanticVersionObject, constructor, version):
    application = SemanticVersionObject(version=constructor(version))
    assert application.version
    # Serialization round-trips back to the original version string.
    assert application.model_dump() == {'version': version}


# NOTE(review): '1.2' appears twice in this list — harmless duplication.
@pytest.mark.parametrize(
    'invalid_version',
    [
        '', '1', '1.2',
        '1.2.3-0123', '1.2.3-0123.0123', '1.1.2+.123',
        '+invalid', '-invalid', '-invalid+invalid', '-invalid.01',
        'alpha', 'alpha.beta', 'alpha.beta.1', 'alpha.1',
        'alpha+beta', 'alpha_beta', 'alpha.', 'alpha..', 'beta',
        '1.0.0-alpha_beta', '-alpha.',
        '1.0.0-alpha..', '1.0.0-alpha..1', '1.0.0-alpha...1',
        '1.0.0-alpha....1', '1.0.0-alpha.....1', '1.0.0-alpha......1',
        '1.0.0-alpha.......1',
        '01.1.1', '1.01.1', '1.1.01',
        '1.2', '1.2.3.DEV', '1.2-SNAPSHOT',
        '1.2.31.2.3----RC-SNAPSHOT.12.09.1--..12+788',
        '1.2-RC-SNAPSHOT', '-1.0.3-gamma+b7718', '+justmeta',
        '9.8.7+meta+meta', '9.8.7-whatever+meta+meta',
        '99999999999999999999999.999999999999999999.99999999999999999----RC-SNAPSHOT.12.09.1--------------------------------..12',
    ],
)
def test_invalid_semantic_version(SemanticVersionObject, invalid_version):
    with pytest.raises(ValidationError):
        SemanticVersionObject(version=invalid_version)


# -----------------------------------------------------------------------------
# tests/test_semver.py
# -----------------------------------------------------------------------------
import pytest
from pydantic import BaseModel

from pydantic_extra_types.semver import _VersionPydanticAnnotation


class SomethingWithAVersion(BaseModel):
    # Field under test: semver-annotated version.
    version: _VersionPydanticAnnotation


def test_valid_semver() -> None:
    SomethingWithAVersion(version='1.2.3')


def test_valid_semver_with_prerelease() -> None:
    SomethingWithAVersion(version='1.2.3-alpha.1')


def test_invalid_semver() -> None:
    # (the statement under this context manager continues on the next chunk line)
    with pytest.raises(ValueError):
# (continuation of test_invalid_semver: the call under `pytest.raises`)
        SomethingWithAVersion(version='jim.was.here')


# -----------------------------------------------------------------------------
# tests/test_timezone_names.py
# -----------------------------------------------------------------------------
import re

import pytest
import pytz
from pydantic import BaseModel, ValidationError
from pydantic_core import PydanticCustomError

from pydantic_extra_types.timezone_name import TimeZoneName, TimeZoneNameSettings, timezone_name_settings

# zoneinfo is only available on newer Pythons; degrade gracefully without it.
has_zone_info = True
try:
    from zoneinfo import available_timezones
except ImportError:
    has_zone_info = False

# Deliberately malformed variants (lower-cased / leading space) paired with the
# canonical zone name they should normalize to in non-strict mode.
pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones]
pytz_zones_bad.extend([(f' {zone}', zone) for zone in pytz.all_timezones_set])


class TZNameCheck(BaseModel):
    # Strict (default) timezone-name field.
    timezone_name: TimeZoneName


@timezone_name_settings(strict=False)
class TZNonStrict(TimeZoneName):
    pass


class NonStrictTzName(BaseModel):
    # Non-strict field: accepts case/whitespace variants.
    timezone_name: TZNonStrict


@pytest.mark.parametrize('zone', pytz.all_timezones)
def test_all_timezones_non_strict_pytz(zone):
    # Canonical pytz names pass both strict and non-strict fields unchanged.
    assert TZNameCheck(timezone_name=zone).timezone_name == zone
    assert NonStrictTzName(timezone_name=zone).timezone_name == zone


@pytest.mark.parametrize('zone', pytz_zones_bad)
def test_all_timezones_pytz_lower(zone):
    # Non-strict mode normalizes the malformed variant back to the canonical name.
    assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1]


def test_fail_non_existing_timezone():
    # An unknown zone fails with the same custom error in both modes.
    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for TZNameCheck\n'
            'timezone_name\n '
            'Invalid timezone name. '
            "[type=TimeZoneName, input_value='mars', input_type=str]"
        ),
    ):
        TZNameCheck(timezone_name='mars')

    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for NonStrictTzName\n'
            'timezone_name\n '
            'Invalid timezone name. '
            "[type=TimeZoneName, input_value='mars', input_type=str]"
        ),
    ):
        NonStrictTzName(timezone_name='mars')


if has_zone_info:
    # Same checks driven by the stdlib zoneinfo database when it is available.
    zones = list(available_timezones())
    zones.sort()
    zones_bad = [(zone.lower(), zone) for zone in zones]

    @pytest.mark.parametrize('zone', zones)
    def test_all_timezones_zone_info(zone):
        assert TZNameCheck(timezone_name=zone).timezone_name == zone
        assert NonStrictTzName(timezone_name=zone).timezone_name == zone

    @pytest.mark.parametrize('zone', zones_bad)
    def test_all_timezones_zone_info_NonStrict(zone):
        assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1]


def test_timezone_name_settings_metaclass():
    # The `strict` class keyword is consumed by the TimeZoneNameSettings metaclass.
    class TestStrictTZ(TimeZoneName, strict=True, metaclass=TimeZoneNameSettings):
        pass

    class TestNonStrictTZ(TimeZoneName, strict=False, metaclass=TimeZoneNameSettings):
        pass

    assert TestStrictTZ.strict is True
    assert TestNonStrictTZ.strict is False

    # Test default value
    class TestDefaultStrictTZ(TimeZoneName, metaclass=TimeZoneNameSettings):
        pass

    assert TestDefaultStrictTZ.strict is True


def test_timezone_name_validation():
    valid_tz = 'America/New_York'
    invalid_tz = 'Invalid/Timezone'

    assert TimeZoneName._validate(valid_tz, None) == valid_tz

    with pytest.raises(PydanticCustomError):
        TimeZoneName._validate(invalid_tz, None)

    # Non-strict subclass normalizes case and surrounding whitespace.
    assert TZNonStrict._validate(valid_tz.lower(), None) == valid_tz
    assert TZNonStrict._validate(f' {valid_tz} ', None) == valid_tz

    with pytest.raises(PydanticCustomError):
        TZNonStrict._validate(invalid_tz, None)


def test_timezone_name_pydantic_core_schema():
    # The core schema is a function-after wrapper around a min-length-1 str schema.
    schema = TimeZoneName.__get_pydantic_core_schema__(TimeZoneName, None)
    assert isinstance(schema, dict)
    assert schema['type'] == 'function-after'
    assert 'function' in schema
    assert 'schema' in schema
    assert schema['schema']['type'] == 'str'
    assert schema['schema']['min_length'] == 1


def test_timezone_name_pydantic_json_schema():
    core_schema = TimeZoneName.__get_pydantic_core_schema__(TimeZoneName, None)

    class MockJsonSchemaHandler:
        def __call__(self, schema):
            return {'type': 'string'}

    handler = MockJsonSchemaHandler()
    json_schema = TimeZoneName.__get_pydantic_json_schema__(core_schema, handler)
    # The JSON schema enumerates every allowed timezone name.
    assert 'enum' in json_schema
    assert isinstance(json_schema['enum'], list)
    assert len(json_schema['enum']) > 0


def test_timezone_name_repr():
    tz = TimeZoneName('America/New_York')
    assert repr(tz) == "'America/New_York'"
    assert str(tz) == 'America/New_York'


def test_timezone_name_allowed_values():
    # Class-level lookup tables: set, sorted list, and UPPER->canonical mapping.
    assert isinstance(TimeZoneName.allowed_values, set)
    assert len(TimeZoneName.allowed_values) > 0
    assert all(isinstance(tz, str) for tz in TimeZoneName.allowed_values)

    assert isinstance(TimeZoneName.allowed_values_list, list)
    assert len(TimeZoneName.allowed_values_list) > 0
    assert all(isinstance(tz, str) for tz in TimeZoneName.allowed_values_list)

    assert isinstance(TimeZoneName.allowed_values_upper_to_correct, dict)
    assert len(TimeZoneName.allowed_values_upper_to_correct) > 0
    assert all(
        isinstance(k, str) and isinstance(v, str) for k, v in TimeZoneName.allowed_values_upper_to_correct.items()
    )


def test_timezone_name_inheritance():
    # (the `pass` body of this class continues on the next chunk line)
    class CustomTZ(TimeZoneName, metaclass=TimeZoneNameSettings):
# (continuation: body of `class CustomTZ` started on the previous chunk line)
        pass

    assert issubclass(CustomTZ, TimeZoneName)
    assert issubclass(CustomTZ, str)
    assert isinstance(CustomTZ('America/New_York'), (CustomTZ, TimeZoneName, str))


def test_timezone_name_string_operations():
    # TimeZoneName behaves like a plain str for common operations.
    tz = TimeZoneName('America/New_York')
    assert tz.upper() == 'AMERICA/NEW_YORK'
    assert tz.lower() == 'america/new_york'
    assert tz.strip() == 'America/New_York'
    assert f'{tz} Time' == 'America/New_York Time'
    assert tz.startswith('America')
    assert tz.endswith('York')


def test_timezone_name_comparison():
    tz1 = TimeZoneName('America/New_York')
    tz2 = TimeZoneName('Europe/London')
    tz3 = TimeZoneName('America/New_York')

    assert tz1 == tz3
    assert tz1 != tz2
    assert tz1 < tz2  # Alphabetical comparison
    assert tz2 > tz1
    assert tz1 <= tz3
    assert tz1 >= tz3


def test_timezone_name_hash():
    # Equal names hash equally, so duplicates collapse in a set.
    tz1 = TimeZoneName('America/New_York')
    tz2 = TimeZoneName('America/New_York')
    tz3 = TimeZoneName('Europe/London')

    assert hash(tz1) == hash(tz2)
    assert hash(tz1) != hash(tz3)

    tz_set = {tz1, tz2, tz3}
    assert len(tz_set) == 2


def test_timezone_name_slots():
    # __slots__-style restriction: arbitrary attribute assignment fails.
    tz = TimeZoneName('America/New_York')
    with pytest.raises(AttributeError):
        tz.new_attribute = 'test'


# -----------------------------------------------------------------------------
# tests/test_types_color.py
# -----------------------------------------------------------------------------
from datetime import datetime

import pytest
from pydantic import BaseModel, ValidationError
from pydantic_core import PydanticCustomError

from pydantic_extra_types.color import Color


# Each case: accepted color input paired with its expected RGB(A) tuple.
# (the parameter list continues on the next chunk line)
@pytest.mark.parametrize(
    'raw_color, as_tuple',
    [
        # named colors
        ('aliceblue', (240, 248, 255)),
        ('Antiquewhite', (250, 235, 215)),
        ('transparent', (0, 0, 0, 0)),
        # hex / bare-hex strings, with and without '#' or '0x', short and long form
        ('#000000', (0, 0, 0)),
        ('#DAB', (221, 170, 187)),
        ('#dab', (221, 170, 187)),
        ('#000', (0, 0, 0)),
        ('0x797979', (121, 121, 121)),
        ('0x777', (119, 119, 119)),
        ('0x777777', (119, 119, 119)),
        ('0x777777cc', (119, 119, 119, 0.8)),
        ('777', (119, 119, 119)),
        ('777c', (119, 119, 119, 0.8)),
        (' 777', (119, 119, 119)),
        ('777 ', (119, 119, 119)),
        (' 777 ', (119, 119, 119)),
        # tuples / lists; alpha of exactly 1.0 is dropped
        ((0, 0, 128), (0, 0, 128)),
        ([0, 0, 128], (0, 0, 128)),
        ((0, 0, 205, 1.0), (0, 0, 205)),
        ((0, 0, 205, 0.5), (0, 0, 205, 0.5)),
        # rgb()/rgba() strings, comma- and space-separated, % and '/' alpha forms
        ('rgb(0, 0, 205)', (0, 0, 205)),
        ('rgb(0, 0, 205.2)', (0, 0, 205)),
        ('rgb(0, 0.2, 205)', (0, 0, 205)),
        ('rgba(0, 0, 128, 0.6)', (0, 0, 128, 0.6)),
        ('rgba(0, 0, 128, .6)', (0, 0, 128, 0.6)),
        ('rgba(0, 0, 128, 60%)', (0, 0, 128, 0.6)),
        (' rgba(0, 0, 128,0.6) ', (0, 0, 128, 0.6)),
        ('rgba(00,0,128,0.6 )', (0, 0, 128, 0.6)),
        ('rgba(0, 0, 128, 0)', (0, 0, 128, 0)),
        ('rgba(0, 0, 128, 1)', (0, 0, 128)),
        ('rgb(0 0.2 205)', (0, 0, 205)),
        ('rgb(0 0.2 205 / 0.6)', (0, 0, 205, 0.6)),
        ('rgb(0 0.2 205 / 60%)', (0, 0, 205, 0.6)),
        ('rgba(0 0 128)', (0, 0, 128)),
        ('rgba(0 0 128 / 0.6)', (0, 0, 128, 0.6)),
        ('rgba(0 0 128 / 60%)', (0, 0, 128, 0.6)),
        # hsl()/hsla() strings; deg/turn/rad units and out-of-range hue wrapping
        ('hsl(270, 60%, 70%)', (178, 133, 224)),
        ('hsl(180, 100%, 50%)', (0, 255, 255)),
        ('hsl(630, 60%, 70%)', (178, 133, 224)),
        ('hsl(270deg, 60%, 70%)', (178, 133, 224)),
        ('hsl(.75turn, 60%, 70%)', (178, 133, 224)),
        ('hsl(-.25turn, 60%, 70%)', (178, 133, 224)),
        ('hsl(-0.25turn, 60%, 70%)', (178, 133, 224)),
        ('hsl(4.71238rad, 60%, 70%)', (178, 133, 224)),
        ('hsl(10.9955rad, 60%, 70%)', (178, 133, 224)),
        ('hsl(270, 60%, 50%, .15)', (127, 51, 204, 0.15)),
        ('hsl(270.00deg, 60%, 50%, 15%)', (127, 51, 204, 0.15)),
        ('hsl(630 60% 70%)', (178, 133, 224)),
        ('hsl(270 60% 50% / .15)', (127, 51, 204, 0.15)),
        ('hsla(630, 60%, 70%)', (178, 133, 224)),
        ('hsla(630 60% 70%)', (178, 133, 224)),
        ('hsla(270 60% 50% / .15)', (127, 51, 204, 0.15)),
    ],
)
def test_color_success(raw_color, as_tuple):
    c = Color(raw_color)
    assert c.as_rgb_tuple() == as_tuple
    # The original input is preserved verbatim.
    assert c.original() == raw_color


# Inputs that must be rejected with a 'color_error' PydanticCustomError.
@pytest.mark.parametrize(
    'color',
    [
        # named colors
        'nosuchname',
        'chucknorris',
        # hex
        '#0000000',
        'x000',
        # rgb/rgba tuples
        (256, 256, 256),
        (128, 128, 128, 0.5, 128),
        (0, 0, 'x'),
        (0, 0, 0, 1.5),
        (0, 0, 0, 'x'),
        (0, 0, 1280),
        (0, 0, 1205, 0.1),
        (0, 0, 1128, 0.5),
        (0, 0, 1128, -0.5),
        (0, 0, 1128, 1.5),
        ({}, 0, 0),
        # rgb/rgba strings
        'rgb(0, 0, 1205)',
        'rgb(0, 0, 1128)',
        'rgb(0, 0, 200 / 0.2)',
        'rgb(72 122 18, 0.3)',
        'rgba(0, 0, 11205, 0.1)',
        'rgba(0, 0, 128, 11.5)',
        'rgba(0, 0, 128 / 11.5)',
        'rgba(72 122 18 0.3)',
        # hsl/hsla strings
        'hsl(180, 101%, 50%)',
        'hsl(72 122 18 / 0.3)',
        'hsl(630 60% 70%, 0.3)',
        'hsla(72 122 18 / 0.3)',
        # neither a tuple, not a string
        datetime(2017, 10, 5, 19, 47, 7),
        object,
        range(10),
    ],
)
def test_color_fail(color):
    with pytest.raises(PydanticCustomError) as exc_info:
        Color(color)
    assert exc_info.value.type == 'color_error'


def test_model_validation():
    class Model(BaseModel):
        color: Color

    assert Model(color='red').color.as_hex() == '#f00'
    assert Model(color=Color('red')).color.as_hex() == '#f00'
    with pytest.raises(ValidationError) as exc_info:
        Model(color='snot')
    # insert_assert(exc_info.value.errors())
    # (this expected-errors literal closes on the next chunk line)
    assert exc_info.value.errors() == [
        {
            'type': 'color_error',
            'loc': ('color',),
            'msg': 'value is not a valid color: string not recognised as a valid color',
            'input': 'snot',
} 137 | ] 138 | 139 | 140 | def test_as_rgb(): 141 | assert Color('bad').as_rgb() == 'rgb(187, 170, 221)' 142 | assert Color((1, 2, 3, 0.123456)).as_rgb() == 'rgba(1, 2, 3, 0.12)' 143 | assert Color((1, 2, 3, 0.1)).as_rgb() == 'rgba(1, 2, 3, 0.1)' 144 | 145 | 146 | def test_as_rgb_tuple(): 147 | assert Color((1, 2, 3)).as_rgb_tuple(alpha=None) == (1, 2, 3) 148 | assert Color((1, 2, 3, 1)).as_rgb_tuple(alpha=None) == (1, 2, 3) 149 | assert Color((1, 2, 3, 0.3)).as_rgb_tuple(alpha=None) == (1, 2, 3, 0.3) 150 | assert Color((1, 2, 3, 0.3)).as_rgb_tuple(alpha=None) == (1, 2, 3, 0.3) 151 | 152 | assert Color((1, 2, 3)).as_rgb_tuple(alpha=False) == (1, 2, 3) 153 | assert Color((1, 2, 3, 0.3)).as_rgb_tuple(alpha=False) == (1, 2, 3) 154 | 155 | assert Color((1, 2, 3)).as_rgb_tuple(alpha=True) == (1, 2, 3, 1) 156 | assert Color((1, 2, 3, 0.3)).as_rgb_tuple(alpha=True) == (1, 2, 3, 0.3) 157 | 158 | 159 | def test_as_hsl(): 160 | assert Color('bad').as_hsl() == 'hsl(260, 43%, 77%)' 161 | assert Color((1, 2, 3, 0.123456)).as_hsl() == 'hsl(210, 50%, 1%, 0.12)' 162 | assert Color('hsl(260, 43%, 77%)').as_hsl() == 'hsl(260, 43%, 77%)' 163 | 164 | 165 | def test_as_hsl_tuple(): 166 | c = Color('016997') 167 | h, s, l_, a = c.as_hsl_tuple(alpha=True) 168 | assert h == pytest.approx(0.551, rel=0.01) 169 | assert s == pytest.approx(0.986, rel=0.01) 170 | assert l_ == pytest.approx(0.298, rel=0.01) 171 | assert a == 1 172 | 173 | assert c.as_hsl_tuple(alpha=False) == c.as_hsl_tuple(alpha=None) == (h, s, l_) 174 | 175 | c = Color((3, 40, 50, 0.5)) 176 | hsla = c.as_hsl_tuple(alpha=None) 177 | assert len(hsla) == 4 178 | assert hsla[3] == 0.5 179 | 180 | 181 | def test_as_hex(): 182 | assert Color((1, 2, 3)).as_hex() == '#010203' 183 | assert Color((119, 119, 119)).as_hex() == '#777' 184 | assert Color((119, 0, 238)).as_hex() == '#70e' 185 | assert Color('B0B').as_hex() == '#b0b' 186 | assert Color((1, 2, 3, 0.123456)).as_hex() == '#0102031f' 187 | assert Color((1, 2, 3, 0.1)).as_hex() 
# (continuation of the assertion split at the previous chunk line)
== '#0102031a'


def test_as_hex_long():
    # format='long' always emits the six-digit (or eight with alpha) form.
    assert Color((1, 2, 3)).as_hex(format='long') == '#010203'
    assert Color((119, 119, 119)).as_hex(format='long') == '#777777'
    assert Color((119, 0, 238)).as_hex(format='long') == '#7700ee'
    assert Color('B0B').as_hex(format='long') == '#bb00bb'
    assert Color('#0102031a').as_hex(format='long') == '#0102031a'


def test_as_named():
    assert Color((0, 255, 255)).as_named() == 'cyan'
    assert Color('#808000').as_named() == 'olive'
    assert Color('hsl(180, 100%, 50%)').as_named() == 'cyan'

    assert Color((240, 248, 255)).as_named() == 'aliceblue'
    # Without a fallback, an unnamed color raises.
    with pytest.raises(ValueError) as exc_info:
        Color((1, 2, 3)).as_named()
    assert exc_info.value.args[0] == 'no named color found, use fallback=True, as_hex() or as_rgb()'

    # fallback=True degrades to the hex representation.
    assert Color((1, 2, 3)).as_named(fallback=True) == '#010203'
    assert Color((1, 2, 3, 0.1)).as_named(fallback=True) == '#0102031a'


def test_str_repr():
    assert str(Color('red')) == 'red'
    assert repr(Color('red')) == "Color('red', rgb=(255, 0, 0))"
    assert str(Color((1, 2, 3))) == '#010203'
    assert repr(Color((1, 2, 3))) == "Color('#010203', rgb=(1, 2, 3))"


def test_eq():
    # Colors compare by value against other Colors, never against raw strings.
    assert Color('red') == Color('red')
    assert Color('red') != Color('blue')
    assert Color('red') != 'red'

    assert Color('red') == Color((255, 0, 0))
    assert Color('red') != Color((0, 0, 255))


def test_color_hashable():
    assert hash(Color('red')) != hash(Color('blue'))
    assert hash(Color('red')) == hash(Color((255, 0, 0)))
    assert hash(Color('red')) != hash(Color((255, 0, 0, 0.5)))


# -----------------------------------------------------------------------------
# tests/test_types_payment.py
# -----------------------------------------------------------------------------
from collections import namedtuple
from typing import Any

import pytest
from pydantic import BaseModel, ValidationError
from pydantic_core._pydantic_core import PydanticCustomError

from pydantic_extra_types.payment import PaymentCardBrand, PaymentCardNumber

# Known-valid card numbers per brand/length, plus two deliberately broken ones
# (LUHN_INVALID fails the checksum; LEN_INVALID has a bad length for its brand).
VALID_AMEX = '370000000000002'
VALID_MC = '5100000000000003'
VALID_VISA_13 = '4050000000001'
VALID_VISA_16 = '4050000000000001'
VALID_VISA_19 = '4050000000000000001'
VALID_MIR_16 = '2200000000000004'
VALID_MIR_17 = '22000000000000004'
VALID_MIR_18 = '220000000000000004'
VALID_MIR_19 = '2200000000000000004'
VALID_DISCOVER = '6011000000000004'
VALID_VERVE_16 = '5061000000000001'
VALID_VERVE_18 = '506100000000000001'
VALID_VERVE_19 = '5061000000000000001'
VALID_DANKORT = '5019000000000000'
VALID_UNIONPAY_16 = '6200000000000001'
VALID_UNIONPAY_19 = '8100000000000000001'
VALID_JCB_16 = '3528000000000001'
VALID_JCB_19 = '3528000000000000001'
VALID_MAESTRO = '6759649826438453'
VALID_TROY = '9792000000000001'
VALID_OTHER = '2000000000000000008'
LUHN_INVALID = '4000000000000000'
LEN_INVALID = '40000000000000006'


# Mock PaymentCardNumber
# NOTE(review): this mock appears unused — its only call site is commented out
# in test_length_for_brand below. Candidate for removal.
PCN = namedtuple('PaymentCardNumber', ['card_number', 'brand'])
PCN.__len__ = lambda v: len(v.card_number)


@pytest.fixture(scope='session', name='PaymentCard')
def payment_card_model_fixture():
    # Model factory: a minimal BaseModel with a PaymentCardNumber field.
    class PaymentCard(BaseModel):
        card_number: PaymentCardNumber

    return PaymentCard


def test_validate_digits():
    digits = '12345'
    assert PaymentCardNumber.validate_digits(digits) is None
    with pytest.raises(PydanticCustomError, match='Card number is not all digits'):
        PaymentCardNumber.validate_digits('hello')
    # Non-ASCII digit characters (superscript two) must also be rejected.
    with pytest.raises(PydanticCustomError, match='Card number is not all digits'):
        PaymentCardNumber.validate_digits('²')


# (the case list for this parametrize continues on the next chunk line)
@pytest.mark.parametrize(
    'card_number, valid',
    [
        # Luhn check-digit cases: (card number string, passes-checksum?)
        ('0', True),
        ('00', True),
        ('18', True),
        ('0000000000000000', True),
        ('4242424242424240', False),
        ('4242424242424241', False),
        ('4242424242424242', True),
        ('4242424242424243', False),
        ('4242424242424244', False),
        ('4242424242424245', False),
        ('4242424242424246', False),
        ('4242424242424247', False),
        ('4242424242424248', False),
        ('4242424242424249', False),
        ('42424242424242426', True),
        ('424242424242424267', True),
        ('4242424242424242675', True),
        ('5164581347216566', True),
        ('4345351087414150', True),
        ('343728738009846', True),
        ('5164581347216567', False),
        ('4345351087414151', False),
        ('343728738009847', False),
        ('000000018', True),
        ('99999999999999999999', True),
        ('99999999999999999999999999999999999999999999999999999999999999999997', True),
    ],
)
def test_validate_luhn_check_digit(card_number: str, valid: bool):
    if valid:
        # A valid number is returned unchanged.
        assert PaymentCardNumber.validate_luhn_check_digit(card_number) == card_number
    else:
        with pytest.raises(PydanticCustomError, match='Card number is not luhn valid'):
            PaymentCardNumber.validate_luhn_check_digit(card_number)


# Brand detection + per-brand length validation cases.
@pytest.mark.parametrize(
    'card_number, brand, valid',
    [
        (VALID_VISA_13, PaymentCardBrand.visa, True),
        (VALID_VISA_16, PaymentCardBrand.visa, True),
        (VALID_VISA_19, PaymentCardBrand.visa, True),
        (VALID_MC, PaymentCardBrand.mastercard, True),
        (VALID_AMEX, PaymentCardBrand.amex, True),
        (VALID_MIR_16, PaymentCardBrand.mir, True),
        (VALID_MIR_17, PaymentCardBrand.mir, True),
        (VALID_MIR_18, PaymentCardBrand.mir, True),
        (VALID_MIR_19, PaymentCardBrand.mir, True),
        (VALID_DISCOVER, PaymentCardBrand.discover, True),
        (VALID_VERVE_16, PaymentCardBrand.verve, True),
        (VALID_VERVE_18, PaymentCardBrand.verve, True),
        (VALID_VERVE_19, PaymentCardBrand.verve, True),
        (VALID_DANKORT, PaymentCardBrand.dankort, True),
        (VALID_UNIONPAY_16, PaymentCardBrand.unionpay, True),
        (VALID_UNIONPAY_19, PaymentCardBrand.unionpay, True),
        (VALID_JCB_16, PaymentCardBrand.jcb, True),
        (VALID_JCB_19, PaymentCardBrand.jcb, True),
        (LEN_INVALID, PaymentCardBrand.visa, False),
        (VALID_MAESTRO, PaymentCardBrand.maestro, True),
        (VALID_TROY, PaymentCardBrand.troy, True),
        (VALID_OTHER, PaymentCardBrand.other, True),
    ],
)
def test_length_for_brand(card_number: str, brand: PaymentCardBrand, valid: bool):
    # pcn = PCN(card_number, brand)
    if valid:
        assert PaymentCardNumber.validate_brand(card_number) == brand
    else:
        with pytest.raises(PydanticCustomError) as exc_info:
            PaymentCardNumber.validate_brand(card_number)
        assert exc_info.value.type == 'payment_card_number_brand'


@pytest.mark.parametrize(
    'card_number, brand',
    [
        (VALID_AMEX, PaymentCardBrand.amex),
        (VALID_MC, PaymentCardBrand.mastercard),
        (VALID_VISA_16, PaymentCardBrand.visa),
        (VALID_MIR_16, PaymentCardBrand.mir),
        (VALID_DISCOVER, PaymentCardBrand.discover),
        (VALID_VERVE_16, PaymentCardBrand.verve),
        (VALID_DANKORT, PaymentCardBrand.dankort),
        (VALID_UNIONPAY_16, PaymentCardBrand.unionpay),
        (VALID_JCB_16, PaymentCardBrand.jcb),
        (VALID_OTHER, PaymentCardBrand.other),
        (VALID_MAESTRO, PaymentCardBrand.maestro),
        (VALID_TROY, PaymentCardBrand.troy),
    ],
)
def test_get_brand(card_number: str, brand: PaymentCardBrand):
    assert PaymentCardNumber.validate_brand(card_number) == brand


def test_valid(PaymentCard):
    card = PaymentCard(card_number=VALID_VISA_16)
    assert str(card.card_number) == VALID_VISA_16
    # Masked form keeps the first 6 and last 4 digits.
    assert card.card_number.masked == '405000******0001'


# Each invalid input maps to a specific pydantic error type in the message.
@pytest.mark.parametrize(
    'card_number, error_message',
    [
        (None, 'type=string_type'),
        ('1' * 11, 'type=string_too_short,'),
        ('1' * 20, 'type=string_too_long,'),
        ('h' * 16, 'type=payment_card_number_digits'),
        (LUHN_INVALID, 'type=payment_card_number_luhn,'),
        (LEN_INVALID, 'type=payment_card_number_brand,'),
    ],
)
def test_error_types(card_number: Any, error_message: str, PaymentCard):
    with pytest.raises(ValidationError, match=error_message):
        PaymentCard(card_number=card_number)


def test_payment_card_brand():
    b = PaymentCardBrand.visa
    assert str(b) == 'Visa'
    assert b is PaymentCardBrand.visa
    assert b == PaymentCardBrand.visa
    assert b in {PaymentCardBrand.visa, PaymentCardBrand.mastercard}

    # A plain string equals the brand member but is not identical to it.
    b = 'Visa'
    assert b is not PaymentCardBrand.visa
    assert b == PaymentCardBrand.visa
    assert b in {PaymentCardBrand.visa, PaymentCardBrand.mastercard}

    b = PaymentCardBrand.amex
    assert b is not PaymentCardBrand.visa
    assert b != PaymentCardBrand.visa
    assert b not in {PaymentCardBrand.visa, PaymentCardBrand.mastercard}


# -----------------------------------------------------------------------------
# tests/test_ulid.py  (truncated at the end of this chunk)
# -----------------------------------------------------------------------------
import uuid
from datetime import datetime, timezone
from typing import Any

import pytest
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.ulid import ULID

try:
    from ulid import ULID as _ULID
except ModuleNotFoundError:  # pragma: no cover
    raise RuntimeError(
        'The `ulid` module requires "python-ulid" to be installed. You can install it with "pip install python-ulid".'
15 | ) 16 | 17 | 18 | class Something(BaseModel): 19 | ulid: ULID 20 | 21 | 22 | @pytest.mark.parametrize( 23 | 'ulid, result, valid', 24 | [ 25 | # Valid ULID for str format 26 | ('01BTGNYV6HRNK8K8VKZASZCFPE', '01BTGNYV6HRNK8K8VKZASZCFPE', True), 27 | ('01BTGNYV6HRNK8K8VKZASZCFPF', '01BTGNYV6HRNK8K8VKZASZCFPF', True), 28 | # Invalid ULID for str format 29 | ('01BTGNYV6HRNK8K8VKZASZCFP', None, False), # Invalid ULID (short length) 30 | ('01BTGNYV6HRNK8K8VKZASZCFPEA', None, False), # Invalid ULID (long length) 31 | # Valid ULID for UUID format 32 | (uuid.UUID('0196FEB3-9C99-8D8C-B3F3-4301C5E9DCE1'), '01JVZB774SHP6B7WT3072YKQ71', True), 33 | (uuid.UUID('0196FEB3-CD14-4B50-0015-C1E09BF7B221'), '01JVZB7K8M9D8005E1W2DZFCH1', True), 34 | # Valid ULID for _ULID format 35 | (_ULID.from_str('01BTGNYV6HRNK8K8VKZASZCFPE'), '01BTGNYV6HRNK8K8VKZASZCFPE', True), 36 | (_ULID.from_str('01BTGNYV6HRNK8K8VKZASZCFPF'), '01BTGNYV6HRNK8K8VKZASZCFPF', True), 37 | # Invalid _ULID for bytes format 38 | (b'\x01\xba\x1e\xb2\x8a\x9f\xfay\x10\xd5\xa5k\xc8', None, False), # Invalid ULID (short length) 39 | (b'\x01\xba\x1e\xb2\x8a\x9f\xfay\x10\xd5\xa5k\xc8\xb6\x00', None, False), # Invalid ULID (long length) 40 | # Valid ULID for int format 41 | (109667145845879622871206540411193812282, '2JG4FVY7N8XS4GFVHPXGJZ8S9T', True), 42 | (109667145845879622871206540411193812283, '2JG4FVY7N8XS4GFVHPXGJZ8S9V', True), 43 | (109667145845879622871206540411193812284, '2JG4FVY7N8XS4GFVHPXGJZ8S9W', True), 44 | # Invalid ULID for bool format 45 | (True, None, False), 46 | (False, None, False), 47 | ], 48 | ) 49 | def test_format_for_ulid(ulid: Any, result: Any, valid: bool): 50 | if valid: 51 | assert str(Something(ulid=ulid).ulid) == result 52 | else: 53 | with pytest.raises(ValidationError, match='format'): 54 | Something(ulid=ulid) 55 | 56 | 57 | def test_property_for_ulid(): 58 | ulid = Something(ulid='01BTGNYV6HRNK8K8VKZASZCFPE').ulid 59 | assert ulid.hex == '015ea15f6cd1c56689a373fab3f63ece' 60 | assert ulid 
== '01BTGNYV6HRNK8K8VKZASZCFPE' 61 | assert ulid.datetime == datetime(2017, 9, 20, 22, 18, 59, 153000, tzinfo=timezone.utc) 62 | assert ulid.timestamp == 1505945939.153 63 | 64 | 65 | def test_json_schema(): 66 | assert Something.model_json_schema(mode='validation') == { 67 | 'properties': { 68 | 'ulid': { 69 | 'anyOf': [ 70 | {'type': 'integer'}, 71 | {'format': 'binary', 'type': 'string'}, 72 | {'type': 'string'}, 73 | {'format': 'uuid', 'type': 'string'}, 74 | ], 75 | 'title': 'Ulid', 76 | } 77 | }, 78 | 'required': ['ulid'], 79 | 'title': 'Something', 80 | 'type': 'object', 81 | } 82 | assert Something.model_json_schema(mode='serialization') == { 83 | 'properties': { 84 | 'ulid': { 85 | 'anyOf': [ 86 | {'type': 'integer'}, 87 | {'format': 'binary', 'type': 'string'}, 88 | {'type': 'string'}, 89 | {'format': 'uuid', 'type': 'string'}, 90 | ], 91 | 'title': 'Ulid', 92 | } 93 | }, 94 | 'required': ['ulid'], 95 | 'title': 'Something', 96 | 'type': 'object', 97 | } 98 | --------------------------------------------------------------------------------