├── docs ├── .nojekyll ├── _sidebar.md ├── _navbar.md ├── _coverpage.md ├── index.html └── overview.md ├── tests ├── __init__.py ├── core │ ├── injector │ │ ├── __init__.py │ │ ├── test_config.py │ │ └── test_registry.py │ └── test_workspace.py ├── integrations │ └── feature_flag │ │ ├── test_file.py │ │ ├── test_harness.py │ │ └── test_launchdarkly.py ├── builtin │ ├── test_filters.py │ └── test_metrics.py └── types │ └── test_monads.py ├── src └── cdf │ ├── core │ ├── __init__.py │ ├── component │ │ ├── operation.py │ │ ├── service.py │ │ ├── publisher.py │ │ ├── __init__.py │ │ ├── pipeline.py │ │ └── base.py │ ├── injector │ │ ├── errors.py │ │ ├── __init__.py │ │ └── registry.py │ ├── context.py │ ├── workspace.py │ └── configuration.py │ ├── integrations │ ├── __init__.py │ ├── feature_flag │ │ ├── noop.py │ │ ├── split.py │ │ ├── launchdarkly.py │ │ ├── __init__.py │ │ ├── file.py │ │ ├── base.py │ │ └── harness.py │ └── slack.py │ ├── types │ ├── __init__.py │ └── monads.py │ ├── proxy │ ├── __init__.py │ ├── mysql.py │ └── planner.py │ ├── __init__.py │ └── builtin │ ├── filters.py │ └── metrics.py ├── pytest.ini ├── Makefile ├── .gitignore ├── .neoconf.json ├── .github └── workflows │ └── python-package.yml ├── pyproject.toml ├── README.md └── LICENSE /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/cdf/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/cdf/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/core/injector/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integrations/feature_flag/test_file.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integrations/feature_flag/test_harness.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integrations/feature_flag/test_launchdarkly.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | credentials: mark a test as requiring credentials 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: update-docs 2 | 3 | update-docs: 4 | @echo "Updating docs..." 5 | @pydoc-markdown -I src/cdf >docs/api_reference.md 6 | @echo "Done." 
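For context, a hedged sketch of how the `credentials` marker declared in pytest.ini above is typically applied (the test itself is illustrative, not from this repo):

```python
import pytest


@pytest.mark.credentials  # marker registered in pytest.ini above
def test_live_roundtrip():
    """Talks to a real external service, so it needs credentials."""
    ...
```

Tests marked this way can then be deselected where no secrets are available, e.g. `pytest -m "not credentials"`.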
7 | 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | .git/ 4 | 5 | dist/ 6 | logs/ 7 | 8 | wip/ 9 | 10 | _storage/ 11 | _rendered/ 12 | 13 | .cache/ 14 | .ipynb_checkpoints/ 15 | 16 | *.duckdb 17 | -------------------------------------------------------------------------------- /.neoconf.json: -------------------------------------------------------------------------------- 1 | { 2 | "lspconfig": { 3 | "pyright": { 4 | "python": { 5 | "analysis": { 6 | "diagnosticMode": "workspace" 7 | } 8 | } 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/cdf/core/component/operation.py: -------------------------------------------------------------------------------- 1 | from .base import Entrypoint 2 | 3 | OperationProto = int 4 | 5 | 6 | class Operation(Entrypoint[OperationProto], frozen=True): 7 | """A generic callable that returns an exit code.""" 8 | -------------------------------------------------------------------------------- /src/cdf/core/component/service.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from .base import Component 4 | 5 | ServiceProto = t.Any 6 | 7 | 8 | class Service(Component[ServiceProto], frozen=True): 9 | """A service that the workspace provides, e.g. an API, database, requests client, etc.""" 10 | -------------------------------------------------------------------------------- /docs/_sidebar.md: -------------------------------------------------------------------------------- 1 | - Getting started 2 | 3 | - [Quick start](quickstart.md) 4 | 5 | - Customization 6 | 7 | - [Configuration](configuration.md) 8 | 9 | - API Reference 10 | 11 | - [Python](api_reference.md) 12 | - [CLI](cli_reference.md) 13 | 14 | - [Changelog](changelog.md) 15 | -------------------------------------------------------------------------------- /src/cdf/core/injector/errors.py: -------------------------------------------------------------------------------- 1 | class DependencyCycleError(Exception): 2 | """Raised when a dependency cycle is detected.""" 3 | 4 | pass 5 | 6 | 7 | class DependencyMutationError(Exception): 8 | """Raised when an instance/singleton dependency has already been resolved but a mutation is attempted.""" 9 | 10 | pass 11 | -------------------------------------------------------------------------------- /src/cdf/core/injector/__init__.py: -------------------------------------------------------------------------------- 1 | from cdf.core.injector.registry import ( 2 | GLOBAL_REGISTRY, 3 | Dependency, 4 | DependencyKey, 5 | DependencyRegistry, 6 | Lifecycle, 7 | ) 8 | 9 | __all__ = [ 10 | "Dependency", 11 | "DependencyRegistry", 12 | "DependencyKey", 13 | "Lifecycle", 14 | "GLOBAL_REGISTRY", 15 | ] 16 | -------------------------------------------------------------------------------- /docs/_navbar.md: -------------------------------------------------------------------------------- 1 | * Getting started 2 | 3 | * [Quick start](quickstart.md) 4 | * [Writing more pages](more-pages.md) 5 | * [Custom navbar](custom-navbar.md) 6 | * [Cover page](cover.md) 7 | 8 | * Configuration 9 | 10 | * [Configuration](configuration.md) 11 | * [Themes](themes.md) 12 | * [Using plugins](plugins.md) 13 | * [Markdown configuration](markdown.md) 14 | * [Language highlight](language-highlight.md) 15 |
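The injector exports above compose roughly as follows: a minimal sketch whose behavior is inferred from tests/core/injector/test_registry.py later in this document (the `compute` function is illustrative):

```python
from cdf.core.injector import Dependency, DependencyRegistry, Lifecycle

registry = DependencyRegistry()
registry.add("a", lambda: 1)                      # callables are singletons by default
registry.add("obj", object, Lifecycle.PROTOTYPE)  # a fresh instance per resolution


def compute(a: int) -> int:
    return a + 1


wired = registry.wire(compute)  # "a" is injected by parameter name
assert wired() == 2
assert registry.get("obj") is not registry.get("obj")
```

If two registered factories depend on each other, resolution raises the `DependencyCycleError` defined in errors.py above.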
-------------------------------------------------------------------------------- /docs/_coverpage.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Continuous Data Framework (cdf) 4 | 5 | > A framework for managing data, continuously. 6 | 7 | - Simple and lightweight 8 | - Leverages the power of `sqlmesh` and `dlt` 9 | - Opinionated, but flexible 10 | 11 | [GitHub](https://github.com/z3z1ma/cdf/) 12 | [Get Started](overview.md) 13 | -------------------------------------------------------------------------------- /src/cdf/types/__init__.py: -------------------------------------------------------------------------------- 1 | """A module for shared types.""" 2 | 3 | import sys 4 | import typing as t 5 | from pathlib import Path 6 | 7 | import cdf.types.monads as M 8 | 9 | if t.TYPE_CHECKING: 10 | import decimal 11 | 12 | PathLike = t.Union[str, Path] 13 | Number = t.Union[int, float, "decimal.Decimal"] 14 | 15 | if sys.version_info < (3, 10): 16 | from typing_extensions import ParamSpec 17 | else: 18 | from typing import ParamSpec 19 | 20 | P = ParamSpec("P") 21 | 22 | __all__ = ["M", "P", "PathLike", "Number"] 23 | -------------------------------------------------------------------------------- /src/cdf/proxy/__init__.py: -------------------------------------------------------------------------------- 1 | """The proxy module provides a MySQL proxy server for the CDF. 2 | 3 | The proxy server is used to intercept MySQL queries and execute them using SQLMesh. 4 | This allows it to integrate with BI tools and other MySQL clients. Furthermore, 5 | during interception, the server can rewrite queries expanding semantic references 6 | making it an easy to use semantic layer for SQLMesh. 7 | """ 8 | 9 | from cdf.proxy.mysql import run_mysql_proxy 10 | from cdf.proxy.planner import run_plan_server 11 | 12 | 13 | __all__ = ["run_mysql_proxy", "run_plan_server"] 14 | -------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/noop.py: -------------------------------------------------------------------------------- 1 | """No-op feature flag provider.""" 2 | 3 | import typing as t 4 | 5 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter 6 | 7 | 8 | class NoopFeatureFlagAdapter(AbstractFeatureFlagAdapter): 9 | """A feature flag adapter that does nothing.""" 10 | 11 | def __init__(self, **kwargs: t.Any) -> None: 12 | """Initialize the adapter.""" 13 | pass 14 | 15 | def get(self, feature_name: str) -> bool: 16 | return True 17 | 18 | def save(self, feature_name: str, flag: bool) -> None: 19 | pass 20 | 21 | def get_all_feature_names(self) -> t.List[str]: 22 | return [] 23 | 24 | 25 | __all__ = ["NoopFeatureFlagAdapter"] 26 | -------------------------------------------------------------------------------- /src/cdf/core/component/publisher.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from .base import Entrypoint 4 | 5 | DataPublisherProto = t.Tuple[ 6 | t.Callable[..., None], # run 7 | t.Callable[..., bool], # preflight 8 | t.Optional[t.Callable[..., None]], # success hook 9 | t.Optional[t.Callable[..., None]], # failure hook 10 | ] 11 | 12 | 13 | class DataPublisher( 14 | Entrypoint[DataPublisherProto], 15 | frozen=True, 16 | ): 17 | """A data publisher which pushes data to an operational system.""" 18 | 19 | def __call__(self, *args: t.Any, **kwargs: t.Any) -> None: 20 | """Publish the data""" 21 | publisher, 
pre, success, err = self.main(*args, **kwargs) 22 | if not pre(): 23 | raise ValueError("Preflight check failed") 24 | try: 25 | return publisher() 26 | except Exception as e: 27 | if err: 28 | err() 29 | raise e 30 | else: 31 | if success: 32 | success() 33 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.9", "3.10", "3.11"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install '.[dev]' 31 | - name: Test with pytest 32 | run: | 33 | pytest 34 | -------------------------------------------------------------------------------- /src/cdf/__init__.py: -------------------------------------------------------------------------------- 1 | import cdf.core.configuration as conf 2 | from cdf.core.component import ( 3 | DataPipeline, 4 | DataPublisher, 5 | Operation, 6 | Service, 7 | ServiceLevelAgreement, 8 | ) 9 | from cdf.core.configuration import ( 10 | ConfigResolver, 11 | Request, 12 | map_config_section, 13 | map_config_values, 14 | ) 15 | from cdf.core.context import ( 16 | get_active_workspace, 17 | invoke, 18 | resolve, 19 | set_active_workspace, 20 | use_workspace, 21 | ) 22 | from cdf.core.injector import Dependency, DependencyRegistry 23 | from cdf.core.workspace import Workspace 24 | 25 | __all__ = [ 26 | "conf", 27 | "DataPipeline", 28 | "DataPublisher", 29 | "Operation", 30 | "Service", 31 | "ServiceLevelAgreement", 32 | "ConfigResolver", 33 | "Request", 34 | "map_config_section", 35 | "map_config_values", 36 | "Workspace", 37 | "Dependency", 38 | "DependencyRegistry", 39 | "get_active_workspace", 40 | "set_active_workspace", 41 | "resolve", 42 | "invoke", 43 | "use_workspace", 44 | ] 45 | -------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/split.py: -------------------------------------------------------------------------------- 1 | """Split feature flag provider.""" 2 | 3 | import typing as t 4 | 5 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter 6 | 7 | 8 | class SplitFeatureFlagAdapter(AbstractFeatureFlagAdapter): 9 | """A feature flag adapter that uses Split.""" 10 | 11 | def __init__(self, sdk_key: str, **kwargs: t.Any) -> None: 12 | """Initialize the Split feature flags. 13 | 14 | Args: 15 | sdk_key: The SDK key to use for Split. 
16 | """ 17 | self.sdk_key = sdk_key 18 | 19 | def __repr__(self) -> str: 20 | return f"{type(self).__name__}(sdk_key={self.sdk_key!r})" 21 | 22 | def __str__(self) -> str: 23 | return self.sdk_key 24 | 25 | def get(self, feature_name: str) -> bool: 26 | raise NotImplementedError("This provider is not yet implemented") 27 | 28 | def save(self, feature_name: str, flag: bool) -> None: 29 | raise NotImplementedError("This provider is not yet implemented") 30 | 31 | def get_all_feature_names(self) -> t.List[str]: 32 | raise NotImplementedError("This provider is not yet implemented") 33 | 34 | 35 | __all__ = ["SplitFeatureFlagAdapter"] 36 | -------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/launchdarkly.py: -------------------------------------------------------------------------------- 1 | """LaunchDarkly feature flag provider.""" 2 | 3 | import typing as t 4 | 5 | from dlt.common.configuration import with_config 6 | 7 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter 8 | 9 | 10 | class LaunchDarklyFeatureFlagAdapter(AbstractFeatureFlagAdapter): 11 | """A feature flag adapter that uses LaunchDarkly.""" 12 | 13 | @with_config(sections=("feature_flags",)) 14 | def __init__(self, sdk_key: str, **kwargs: t.Any) -> None: 15 | """Initialize the LaunchDarkly feature flags. 16 | 17 | Args: 18 | sdk_key: The SDK key to use for LaunchDarkly. 19 | """ 20 | self.sdk_key = sdk_key 21 | 22 | def __repr__(self) -> str: 23 | return f"{type(self).__name__}(sdk_key={self.sdk_key!r})" 24 | 25 | def __str__(self) -> str: 26 | return self.sdk_key 27 | 28 | def get(self, feature_name: str) -> bool: 29 | raise NotImplementedError("This provider is not yet implemented") 30 | 31 | def save(self, feature_name: str, flag: bool) -> None: 32 | raise NotImplementedError("This provider is not yet implemented") 33 | 34 | def get_all_feature_names(self) -> t.List[str]: 35 | raise NotImplementedError("This provider is not yet implemented") 36 | 37 | 38 | __all__ = ["LaunchDarklyFeatureFlagAdapter"] 39 | -------------------------------------------------------------------------------- /src/cdf/core/component/__init__.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from .base import Component, Entrypoint, ServiceLevelAgreement 4 | from .operation import Operation, OperationProto 5 | from .pipeline import DataPipeline, DataPipelineProto 6 | from .publisher import DataPublisher, DataPublisherProto 7 | from .service import Service, ServiceProto 8 | 9 | __all__ = [ 10 | "DataPipeline", 11 | "DataPublisher", 12 | "Operation", 13 | "Service", 14 | "ServiceDef", 15 | "DataPipelineDef", 16 | "DataPublisherDef", 17 | "OperationDef", 18 | "TComponent", 19 | "TComponentDef", 20 | "ServiceLevelAgreement", 21 | ] 22 | 23 | ServiceDef = t.Union[ 24 | Service, 25 | t.Callable[..., ServiceProto], 26 | t.Dict[str, t.Any], 27 | ] 28 | DataPipelineDef = t.Union[ 29 | DataPipeline, 30 | t.Callable[..., DataPipelineProto], 31 | t.Dict[str, t.Any], 32 | ] 33 | DataPublisherDef = t.Union[ 34 | DataPublisher, 35 | t.Callable[..., DataPublisherProto], 36 | t.Dict[str, t.Any], 37 | ] 38 | OperationDef = t.Union[ 39 | Operation, 40 | t.Callable[..., OperationProto], 41 | t.Dict[str, t.Any], 42 | ] 43 | 44 | TComponent = t.TypeVar("TComponent", bound=t.Union[Component, Entrypoint]) 45 | TComponentDef = t.TypeVar( 46 | "TComponentDef", 47 | ServiceDef, 48 | DataPipelineDef, 49 | DataPublisherDef,
OperationDef, 51 | ) 52 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Document 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | # cdf 2 | 3 | CDF (Continuous Data Framework) is an integrated framework designed to manage 4 | data across the entire lifecycle, from ingestion through transformation to 5 | publishing. It is built on top of two open-source projects, `sqlmesh` and 6 | `dlt`, providing a unified interface for complex data operations. CDF 7 | simplifies data engineering workflows, offering scalable solutions from small 8 | to large projects through an opinionated project structure that supports both 9 | multi-workspace and single-workspace layouts. We place a heavy emphasis on the 10 | inner loop of data engineering. We believe that the most important part of data 11 | engineering is the ability to rapidly iterate on the data, and we have designed 12 | CDF to make that as easy as possible. We achieve this through a combination of 13 | dlt's simplicity in authoring pipelines, dynamic parameterization of sinks, 14 | and developer utilities such as `head` and `discover`. We streamline the 15 | process of scaffolding out new components and treat a workspace as a home for 16 | business-specific components: pipelines, transformations, publishers, scripts, 17 | and notebooks. Spend less time on boilerplate, less time consolidating your 18 | custom code into perfect collections of software engineering best practices, 19 | and much more time on point solutions that solve your business problems. 20 | That's the benefit of opinionation, and we offer it in a way that is flexible 21 | and extensible. 22 | 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "python-cdf" 3 | version = "0.9.4" 4 | description = "A framework to manage data continuously" 5 | authors = [ 6 | { name = "z3z1ma", email = "butler.alex2010@gmail.com" }, 7 | ] 8 | dependencies = [ 9 | # We will find a good version range for these eventually, but allowing them to float 10 | # and be pinned by the user for now is far more useful.
11 | "sqlmesh", 12 | "dlt[duckdb]", 13 | "duckdb", 14 | # The following deps have well-defined version ranges 15 | "mysql-mimic>=2,<3", 16 | "harness-featureflags>=1.2.0,<1.6.1", 17 | "python-dotenv>=1,<2", 18 | "pex>=2.1.100,<2.2.0", 19 | "pydantic>=2.5.0,<3", 20 | "psutil~=5.9.0", 21 | "typing-extensions>=4,<5", 22 | "fsspec>=2022", 23 | "dynaconf>=3,<4", 24 | "eval_type_backport~=0.1.3; python_version<'3.10'", 25 | ] 26 | requires-python = ">=3.9,<3.13" 27 | readme = "README.md" 28 | license.file = "LICENSE" 29 | 30 | [tool.poetry] 31 | packages = [ 32 | { include = "cdf", from = "src" } 33 | ] 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | # "poetry @ git+https://github.com/radoering/poetry.git@pep621-support", 38 | "pytest>=7.4.3", 39 | "pytest-mock>=3.12.0", 40 | "pydoc-markdown>4", 41 | ] 42 | 43 | [build-system] 44 | requires = ["poetry-core@ git+https://github.com/radoering/poetry-core.git@pep621-support"] 45 | build-backend = "poetry.core.masonry.api" 46 | 47 | [tool.hatch.metadata] 48 | allow-direct-references = true 49 | 50 | [tool.pyright] 51 | include = ["src"] 52 | exclude = ["examples/", "docs/"] 53 | ignore = ["src/builtin"] 54 | reportPrivateImportUsage = false 55 | -------------------------------------------------------------------------------- /tests/builtin/test_filters.py: -------------------------------------------------------------------------------- 1 | from cdf.builtin.filters import ( 2 | eq, 3 | gt, 4 | gte, 5 | in_list, 6 | lt, 7 | lte, 8 | ne, 9 | not_empty, 10 | not_in_list, 11 | not_null, 12 | ) 13 | 14 | 15 | def test_eq(): 16 | assert eq("name", "Alice")({"name": "Alice"}) 17 | 18 | 19 | def test_ne(): 20 | assert ne("name", "Alice")({"name": "Bob"}) 21 | 22 | 23 | def test_gt(): 24 | assert gt("age", 30)({"age": 35}) 25 | 26 | 27 | def test_gte(): 28 | assert gte("age", 30)({"age": 30}) 29 | 30 | 31 | def test_lt(): 32 | assert lt("age", 30)({"age": 25}) 33 | 34 | 35 | def test_lte(): 36 | assert lte("age", 30)({"age": 30}) 37 | 38 | 39 | def test_in_list(): 40 | assert in_list("name", ["Alice", "Bob"])({"name": "Alice"}) 41 | 42 | 43 | def test_not_in_list(): 44 | assert not_in_list("name", ["Alice", "Bob"])({"name": "Charlie"}) 45 | 46 | 47 | def test_not_empty(): 48 | assert not_empty("name")({"name": "Alice"}) 49 | assert not_empty("name")({"name": 0}) 50 | assert not_empty("name")({"name": False}) 51 | 52 | assert not_empty("name")({"name": ""}) is False 53 | assert not_empty("name")({"name": []}) is False 54 | assert not_empty("name")({"name": {}}) is False 55 | assert not_empty("name")({"name": None}) is False 56 | 57 | 58 | def test_not_null(): 59 | assert not_null("name")({"name": "Alice"}) 60 | assert not_null("name")({"name": 0}) 61 | assert not_null("name")({"name": False}) 62 | assert not_null("name")({"name": []}) 63 | assert not_null("name")({"name": {}}) 64 | 65 | assert not_null("name")({"name": None}) is False 66 | assert not_null("name")({"whatever": 1}) is False 67 | -------------------------------------------------------------------------------- /tests/builtin/test_metrics.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import pytest 4 | 5 | from cdf.builtin.metrics import ( 6 | avg_value, 7 | count, 8 | max_value, 9 | median_value, 10 | min_value, 11 | mode_value, 12 | stdev_value, 13 | sum_value, 14 | unique, 15 | variance_value, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def data(): 21 | return [ 22 | {"name": "Alice", "age": 25}, 23 | 
{"name": "Bob", "age": 30}, 24 | {"name": "Charlie", "age": 35}, 25 | {"name": "David", "age": 40}, 26 | {"name": "Eve", "age": 45}, 27 | {"name": "Frank", "age": 50}, 28 | {"name": "Alice", "age": 25}, 29 | {"name": "Bob"}, 30 | ] 31 | 32 | 33 | def test_count(data): 34 | assert reduce(lambda metric, item: count(item, metric), data, 0) == 8 35 | 36 | 37 | def test_unique(data): 38 | func = unique("name") 39 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 6 40 | 41 | 42 | def test_max_value(data): 43 | func = max_value("age") 44 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 50 45 | 46 | 47 | def test_min_value(data): 48 | func = min_value("age") 49 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 25 50 | 51 | 52 | def test_sum_value(data): 53 | func = sum_value("age") 54 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 250 55 | 56 | 57 | def test_avg_value(data): 58 | func = avg_value("age") 59 | assert ( 60 | reduce(lambda metric, item: func(item, metric), data, 0) == 35.714285714285715 61 | ) 62 | 63 | 64 | def test_median_value(data): 65 | func = median_value("age") 66 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 35 67 | 68 | 69 | def test_variance_value(data): 70 | func = variance_value("age") 71 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 81.63265306122435 72 | 73 | 74 | def test_stdev_value(data): 75 | func = stdev_value("age") 76 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 9.035079029052504 77 | 78 | 79 | def test_mode_value(data): 80 | func = mode_value("age") 81 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 25 82 | -------------------------------------------------------------------------------- /tests/types/test_monads.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | import typing as t 4 | from collections import defaultdict 5 | 6 | import requests 7 | 8 | from cdf.types.monads import State, promise, state, to_state 9 | 10 | threadtime = defaultdict(list) 11 | 12 | T = t.TypeVar("T") 13 | 14 | 15 | @promise 16 | def fetch(url: str) -> requests.Response: 17 | tid = threading.get_ident() 18 | threadtime[tid].append(time.perf_counter()) 19 | resp = requests.get(url) 20 | return resp 21 | 22 | 23 | @promise 24 | def track(v: T) -> T: 25 | tid = threading.get_ident() 26 | threadtime[tid].append(time.perf_counter()) 27 | return v 28 | 29 | 30 | @promise 31 | def num_abilities(resp: requests.Response) -> int: 32 | data = resp.json() 33 | i = len(data["abilities"]) 34 | tid = threading.get_ident() 35 | threadtime[tid].append(time.perf_counter()) 36 | return i 37 | 38 | 39 | def test_fetch(): 40 | futs = [] 41 | for i in range(5): 42 | print(f"Starting iteration {i}") 43 | futs.append( 44 | track("https://pokeapi.co/api/v2/pokemon/ditto") 45 | >> fetch 46 | >> track 47 | >> num_abilities 48 | >> track 49 | ) 50 | 51 | for fut in futs: 52 | print(fut.unwrap()) 53 | 54 | 55 | def test_state(): 56 | state_x = to_state(1) # 1 is the value for the computations NOT the state 57 | 58 | @state 59 | def add_one(x: int) -> int: 60 | return x + 1 61 | 62 | def print_state(x: int): 63 | def _print(state: int): 64 | print(state) 65 | return x, state 66 | 67 | return _print 68 | 69 | add_one(State(print_state(1))) 70 | 71 | def process_int(x: int) -> State[list[int], int]: 72 | """Process an integer, tracking unique values""" 73 | 74 | def 
process(state: list[int]) -> t.Tuple[int, list[int]]: 75 | nonlocal x 76 | x += 1 77 | if x in state: 78 | return x, state 79 | state.append(x) 80 | return x, state 81 | 82 | return State(process) 83 | 84 | state_y = state_x >> process_int >> add_one >> process_int >> add_one 85 | x, y = state_y.run_state([]) # type: ignore 86 | assert x == 5 87 | assert y == [2, 4] 88 | -------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/__init__.py: -------------------------------------------------------------------------------- 1 | """Feature flag providers implement a uniform interface and are wrapped by an adapter. 2 | 3 | The adapter is responsible for loading the correct provider and applying the feature flags within 4 | various contexts in cdf. This allows for a clean separation of concerns and makes it easy to 5 | implement new feature flag providers in the future. 6 | """ 7 | 8 | import typing as t 9 | 10 | import dlt 11 | from dlt.common.configuration import with_config 12 | 13 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter 14 | from cdf.integrations.feature_flag.file import FilesystemFeatureFlagAdapter 15 | from cdf.integrations.feature_flag.harness import HarnessFeatureFlagAdapter 16 | from cdf.integrations.feature_flag.launchdarkly import LaunchDarklyFeatureFlagAdapter 17 | from cdf.integrations.feature_flag.noop import NoopFeatureFlagAdapter 18 | from cdf.integrations.feature_flag.split import SplitFeatureFlagAdapter 19 | from cdf.types import M 20 | 21 | ADAPTERS: t.Dict[str, t.Type[AbstractFeatureFlagAdapter]] = { 22 | "filesystem": FilesystemFeatureFlagAdapter, 23 | "harness": HarnessFeatureFlagAdapter, 24 | "launchdarkly": LaunchDarklyFeatureFlagAdapter, 25 | "split": SplitFeatureFlagAdapter, 26 | "noop": NoopFeatureFlagAdapter, 27 | } 28 | """Feature flag provider adapters classes by name.""" 29 | 30 | 31 | @with_config(sections=("feature_flags",)) 32 | def get_feature_flag_adapter_cls( 33 | provider: str = dlt.config.value, 34 | ) -> M.Result[t.Type[AbstractFeatureFlagAdapter], Exception]: 35 | """Get a feature flag adapter by name. 36 | 37 | Args: 38 | provider: The name of the feature flag adapter. 39 | options: The configuration for the feature flag adapter. 40 | 41 | Returns: 42 | The feature flag adapter. 43 | """ 44 | try: 45 | if provider not in ADAPTERS: 46 | raise KeyError( 47 | f"Unknown provider: {provider}. 
Available providers: {', '.join(ADAPTERS.keys())}" 48 | ) 49 | return M.ok(ADAPTERS[provider]) 50 | except KeyError as e: 51 | # Notify available providers 52 | return M.error(e) 53 | except Exception as e: 54 | return M.error(e) 55 | 56 | 57 | __all__ = [ 58 | "ADAPTERS", 59 | "AbstractFeatureFlagAdapter", 60 | "FilesystemFeatureFlagAdapter", 61 | "HarnessFeatureFlagAdapter", 62 | "LaunchDarklyFeatureFlagAdapter", 63 | "NoopFeatureFlagAdapter", 64 | "SplitFeatureFlagAdapter", 65 | "get_feature_flag_adapter_cls", 66 | ] 67 | -------------------------------------------------------------------------------- /src/cdf/proxy/mysql.py: -------------------------------------------------------------------------------- 1 | """A MySQL proxy server which uses SQLMesh to execute queries.""" 2 | 3 | import typing as t 4 | import asyncio 5 | import logging 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import sqlmesh 10 | from mysql_mimic import MysqlServer, Session 11 | from mysql_mimic.server import logger 12 | from sqlglot import exp 13 | 14 | 15 | async def file_watcher(context: sqlmesh.Context) -> None: 16 | """Watch for changes in the workspace and refresh the context.""" 17 | while True: 18 | await asyncio.sleep(5.0) 19 | await asyncio.to_thread(context.refresh) 20 | 21 | 22 | class SQLMeshSession(Session): 23 | """A session for the MySQL proxy server which uses SQLMesh.""" 24 | 25 | context: sqlmesh.Context 26 | 27 | async def query( 28 | self, expression: exp.Expression, sql: str, attrs: t.Dict[str, str] 29 | ) -> t.Tuple[t.Tuple[t.Tuple[t.Any], ...], t.List[str]]: 30 | """Execute a query.""" 31 | tables = list(expression.find_all(exp.Table)) 32 | if any((table.db, table.name) == ("__semantic", "__table") for table in tables): 33 | expression = self.context.rewrite(sql) 34 | logger.info("Compiled semantic expression!") 35 | logger.info(expression.sql(self.context.default_dialect)) 36 | df = self.context.fetchdf(expression) 37 | logger.debug(df) 38 | df.replace({np.nan: None}, inplace=True) 39 | return tuple(df.itertuples(index=False)), list(df.columns) 40 | 41 | async def schema(self) -> t.Dict[str, t.Dict[str, t.Dict[str, str]]]: 42 | """Get the schema of the database.""" 43 | schema = defaultdict(dict) 44 | for model in self.context.models.values(): 45 | fqn = model.fully_qualified_table 46 | if model.columns_to_types and all( 47 | typ is not None for typ in model.columns_to_types.values() 48 | ): 49 | schema[fqn.db][fqn.name] = model.columns_to_types 50 | return schema 51 | 52 | 53 | async def run_mysql_proxy(context: sqlmesh.Context) -> None: 54 | """Run the MySQL proxy server.""" 55 | 56 | logging.basicConfig(level=logging.DEBUG) 57 | server = MysqlServer( 58 | session_factory=type( 59 | "BoundSQLMeshSession", 60 | (SQLMeshSession,), 61 | {"context": context}, 62 | ) 63 | ) 64 | asyncio.create_task(file_watcher(context)) 65 | try: 66 | await server.serve_forever() 67 | except asyncio.CancelledError: 68 | await server.wait_closed() 69 | 70 | 71 | __all__ = ["run_mysql_proxy"] 72 | -------------------------------------------------------------------------------- /tests/core/test_workspace.py: -------------------------------------------------------------------------------- 1 | import cdf.core.component as cmp 2 | import cdf.core.configuration as conf 3 | import cdf.core.injector as injector 4 | from cdf.core.workspace import Workspace 5 | 6 | 7 | def test_workspace(): 8 | import dlt 9 | 10 | @dlt.source 11 | def test_source(a: int, prod_bigquery: str): 12 | @dlt.resource 
13 | def test_resource(): 14 | yield from [{"a": a, "prod_bigquery": prod_bigquery}] 15 | 16 | return [test_resource] 17 | 18 | # Define a workspace 19 | datateam = Workspace( 20 | name="data-team", 21 | version="0.1.1", 22 | configuration_sources=[ 23 | # DATATEAM_CONFIG, 24 | { 25 | "sfdc": {"username": "abc"}, 26 | "bigquery": {"project_id": ...}, 27 | }, 28 | ], 29 | service_definitions=[ 30 | cmp.Service( 31 | name="a", 32 | main=injector.Dependency.instance(1), 33 | owner="Alex", 34 | description="A secret number", 35 | sla=cmp.ServiceLevelAgreement.CRITICAL, 36 | ), 37 | cmp.Service( 38 | name="b", 39 | main=injector.Dependency.prototype(lambda a: a + 1 * 5 / 10), 40 | owner="Alex", 41 | ), 42 | cmp.Service( 43 | name="prod_bigquery", 44 | main=injector.Dependency.instance("dwh-123"), 45 | owner="DataTeam", 46 | ), 47 | cmp.Service( 48 | name="sfdc", 49 | main=injector.Dependency( 50 | factory=lambda username: f"https://sfdc.com/{username}", 51 | conf_spec=("sfdc",), 52 | ), 53 | owner="RevOps", 54 | ), 55 | ], 56 | ) 57 | 58 | @conf.map_config_values(secret_number="a.b.c") 59 | def c(secret_number: int, sfdc: str) -> int: 60 | print(f"SFDC: {sfdc=}") 61 | return secret_number * 10 62 | 63 | # Imperatively add dependencies or config if needed 64 | datateam.container.add_from_dependency(injector.Dependency.prototype(c)) 65 | datateam.conf_resolver.import_source({"a.b.c": 10}) 66 | 67 | def source_a(a: int, prod_bigquery: str): 68 | print(f"Source A: {a=}, {prod_bigquery=}") 69 | 70 | # Some interface examples 71 | assert datateam.name == "data-team" 72 | datateam.invoke(source_a) 73 | assert datateam.conf_resolver["sfdc.username"] == "abc" 74 | assert datateam.container.resolve_or_raise("sfdc") == "https://sfdc.com/abc" 75 | assert datateam.invoke(c) == 100 76 | assert list(datateam.invoke(test_source)) == [{"a": 1, "prod_bigquery": "dwh-123"}] 77 | -------------------------------------------------------------------------------- /src/cdf/proxy/planner.py: -------------------------------------------------------------------------------- 1 | """An HTTP server which executes a plan (a pickled pydantic model). 2 | 3 | This is purely a POC. It will be replaced by a more robust solution in the future 4 | using flask or fastapi. It will always be designed such that input must be 5 | trusted. In an environment where the input is not trusted, the server should 6 | never be exposed to the internet. It should always be behind a firewall and 7 | only accessible by trusted clients.
8 | """ 9 | 10 | import pickle 11 | import typing as t 12 | import http.server 13 | import socketserver 14 | import traceback 15 | import logging 16 | import uuid 17 | import json 18 | import io 19 | from contextlib import redirect_stdout, redirect_stderr 20 | 21 | import sqlmesh 22 | 23 | 24 | def run_plan_server(port: int, context: sqlmesh.Context) -> None: 25 | """Listen on a port and execute plans.""" 26 | 27 | # TODO: move this 28 | logging.basicConfig(level=logging.DEBUG) 29 | 30 | def _plan(plan: t.Any) -> t.Any: 31 | """Run a plan""" 32 | stdout = io.StringIO() 33 | stderr = io.StringIO() 34 | with redirect_stdout(stdout), redirect_stderr(stderr): 35 | context.apply(plan) 36 | return { 37 | "stdout": stdout.getvalue(), 38 | "stderr": stderr.getvalue(), 39 | "execution_id": uuid.uuid4().hex, 40 | } 41 | 42 | class Handler(http.server.SimpleHTTPRequestHandler): 43 | def do_GET(self) -> None: 44 | """Ping the server""" 45 | self.send_response(200) 46 | self.send_header("Content-type", "text/plain") 47 | self.end_headers() 48 | self.wfile.write(b"Pong") 49 | 50 | def do_POST(self) -> None: 51 | """Run the plan""" 52 | content_length = int(self.headers["Content-Length"]) 53 | ser_plan = self.rfile.read(content_length) 54 | try: 55 | plan = pickle.loads(ser_plan) 56 | resp = _plan(plan) 57 | self.send_response(200) 58 | self.send_header("Content-type", "application/json") 59 | self.end_headers() 60 | self.wfile.write(json.dumps(resp).encode()) 61 | except Exception as e: 62 | self.send_response(500) 63 | self.send_header("Content-type", "text/plain") 64 | self.end_headers() 65 | self.wfile.write(str(e).encode()) 66 | self.wfile.write(b"\n") 67 | self.wfile.write(traceback.format_exc().encode()) 68 | 69 | with socketserver.TCPServer(("", port), Handler) as httpd: 70 | logging.info("serving at port %s", port) 71 | try: 72 | httpd.serve_forever() 73 | except KeyboardInterrupt: 74 | pass 75 | -------------------------------------------------------------------------------- /src/cdf/core/component/pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing as t 3 | 4 | from .base import Entrypoint 5 | 6 | if t.TYPE_CHECKING: 7 | from dlt.common.destination import Destination as DltDestination 8 | from dlt.common.pipeline import LoadInfo 9 | from dlt.pipeline.pipeline import Pipeline as DltPipeline 10 | 11 | _GRN = "\033[32;1m" 12 | _YLW = "\033[33;1m" 13 | _RED = "\033[31;1m" 14 | _CLR = "\033[0m" 15 | 16 | TEST_RESULT_MAP = { 17 | None: f"{_YLW}SKIP{_CLR}", 18 | True: f"{_GRN}PASS{_CLR}", 19 | False: f"{_RED}FAIL{_CLR}", 20 | } 21 | 22 | DataPipelineProto = t.Tuple[ 23 | "DltPipeline", 24 | t.Union[ 25 | t.Callable[..., "LoadInfo"], 26 | t.Callable[..., t.Iterator["LoadInfo"]], 27 | ], # run 28 | t.Sequence[t.Callable[..., t.Optional[t.Union[bool, t.Tuple[bool, str]]]]], # tests 29 | ] 30 | 31 | 32 | class DataPipeline( 33 | Entrypoint[DataPipelineProto], 34 | frozen=True, 35 | ): 36 | """A data pipeline which loads data from a source to a destination.""" 37 | 38 | def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.List["LoadInfo"]: 39 | """Run the data pipeline""" 40 | _, runner, _ = self.main(*args, **kwargs) 41 | if inspect.isgeneratorfunction(runner): 42 | return list(runner()) 43 | return [t.cast("LoadInfo", runner())] 44 | 45 | run = __call__ 46 | 47 | def unwrap(self) -> "DltPipeline": 48 | """Get the dlt pipeline object.""" 49 | pipeline, _, _ = self.main() 50 | return pipeline 51 | 52 | def get_schemas(self, 
destination: t.Optional["DltDestination"] = None): 53 | """Get the schemas for the pipeline.""" 54 | pipeline = self.unwrap() 55 | pipeline.sync_destination(destination=destination) 56 | return pipeline.schemas 57 | 58 | def run_tests(self) -> None: 59 | """Run the integration test for the pipeline.""" 60 | _, _, tests = self.main() 61 | if not tests: 62 | raise ValueError("No tests found for pipeline") 63 | tpl = "[{nr}/{tot}] {message} ({state})" 64 | tot = len(tests) 65 | for nr, test in enumerate(tests, 1): 66 | result_struct = test() 67 | if isinstance(result_struct, bool) or result_struct is None: 68 | result, reason = result_struct, "No message" 69 | elif isinstance(result_struct, tuple): 70 | result, reason = result_struct 71 | else: 72 | raise ValueError( 73 | f"Invalid return type `{type(result_struct)}`, expected none, bool, or tuple(bool, str)" 74 | ) 75 | if result not in TEST_RESULT_MAP: 76 | raise ValueError(f"Invalid return status for test: `{result}`") 77 | print( 78 | tpl.format( 79 | nr=nr, tot=tot, state=TEST_RESULT_MAP[result], message=reason 80 | ) 81 | ) 82 | -------------------------------------------------------------------------------- /src/cdf/builtin/filters.py: -------------------------------------------------------------------------------- 1 | """Built-in filters for CDF 2 | 3 | They can be referenced via absolute import paths in a pipeline spec. 4 | """ 5 | 6 | import typing as t 7 | 8 | FilterFunc = t.Callable[[t.Any], bool] 9 | 10 | 11 | def not_empty(key: str) -> FilterFunc: 12 | """Filters out items where a key is empty""" 13 | 14 | def _not_empty(item: t.Any) -> bool: 15 | if item.get(key) is None: 16 | return False 17 | if isinstance(item[key], str): 18 | return item[key].strip() != "" 19 | if isinstance(item[key], list): 20 | return len(item[key]) > 0 21 | if isinstance(item[key], dict): 22 | return len(item[key]) > 0 23 | return True 24 | 25 | return _not_empty 26 | 27 | 28 | def not_null(key: str) -> FilterFunc: 29 | """Filters out items where a key is null""" 30 | 31 | def _not_null(item: t.Any) -> bool: 32 | return item.get(key) is not None 33 | 34 | return _not_null 35 | 36 | 37 | def gt(key: str, value: t.Any) -> FilterFunc: 38 | """Keeps only items where the key is greater than the value""" 39 | 40 | def _greater_than(item: t.Any) -> bool: 41 | return item[key] > value 42 | 43 | return _greater_than 44 | 45 | 46 | def lt(key: str, value: t.Any) -> FilterFunc: 47 | """Keeps only items where the key is less than the value""" 48 | 49 | def _less_than(item: t.Any) -> bool: 50 | return item[key] < value 51 | 52 | return _less_than 53 | 54 | 55 | def gte(key: str, value: t.Any) -> FilterFunc: 56 | """Keeps only items where the key is greater than or equal to the value""" 57 | 58 | def _greater_than_or_equal(item: t.Any) -> bool: 59 | return item[key] >= value 60 | 61 | return _greater_than_or_equal 62 | 63 | 64 | def lte(key: str, value: t.Any) -> FilterFunc: 65 | """Keeps only items where the key is less than or equal to the value""" 66 | 67 | def _less_than_or_equal(item: t.Any) -> bool: 68 | return item[key] <= value 69 | 70 | return _less_than_or_equal 71 | 72 | 73 | def eq(key: str, value: t.Any) -> FilterFunc: 74 | """Keeps only items where the key is equal to the value""" 75 | 76 | def _equal(item: t.Any) -> bool: 77 | return item[key] == value 78 | 79 | return _equal 80 | 81 | 82 | def ne(key: str, value: t.Any) -> FilterFunc: 83 | """Keeps only items where the key is not equal to the value""" 84 | 85 | def _not_equal(item: t.Any) -> bool: 86 | return item[key]
!= value 87 | 88 | return _not_equal 89 | 90 | 91 | def in_list(key: str, value: t.List[str]) -> FilterFunc: 92 | """Keeps only items where the key is in a list of values""" 93 | 94 | def _in_list(item: t.Any) -> bool: 95 | return item[key] in value 96 | 97 | return _in_list 98 | 99 | 100 | def not_in_list(key: str, value: t.List[str]) -> FilterFunc: 101 | """Keeps only items where the key is not in a list of values""" 102 | 103 | def _not_in_list(item: t.Any) -> bool: 104 | return item[key] not in value 105 | 106 | return _not_in_list 107 | -------------------------------------------------------------------------------- /tests/core/injector/test_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from cdf.core.configuration import ConfigResolver 7 | 8 | 9 | def test_apply_converters(): 10 | with patch.dict("os.environ", {}): 11 | os.environ["CDF_TEST"] = "1" 12 | assert ConfigResolver.apply_converters("$CDF_TEST") == "1" 13 | assert ConfigResolver.apply_converters("@int ${CDF_TEST}") == 1 14 | os.environ["CDF_BOOL"] = "true" 15 | assert ConfigResolver.apply_converters("@bool ${CDF_BOOL}") is True 16 | os.environ["CDF_FLOAT"] = "3.14" 17 | assert ConfigResolver.apply_converters("@float ${CDF_FLOAT}") == 3.14 18 | os.environ["CDF_JSON"] = '{"key": "value"}' 19 | assert ConfigResolver.apply_converters("@json ${CDF_JSON}") == {"key": "value"} 20 | os.environ["CDF_PATH"] = "tests/v2/test_config.py" 21 | assert ConfigResolver.apply_converters("@path ${CDF_PATH}") == os.path.abspath( 22 | "tests/v2/test_config.py" 23 | ) 24 | os.environ["CDF_DICT"] = "{'key': 'value'}" 25 | assert ConfigResolver.apply_converters("@dict ${CDF_DICT}") == {"key": "value"} 26 | os.environ["CDF_LIST"] = "['key', 'value']" 27 | assert ConfigResolver.apply_converters("@list ${CDF_LIST}") == ["key", "value"] 28 | os.environ["CDF_TUPLE"] = "('key', 'value')" 29 | assert ConfigResolver.apply_converters("@tuple ${CDF_TUPLE}") == ( 30 | "key", 31 | "value", 32 | ) 33 | os.environ["CDF_SET"] = "{'key', 'value'}" 34 | assert ConfigResolver.apply_converters("@set ${CDF_SET}") == {"key", "value"} 35 | 36 | with pytest.raises(ValueError): 37 | ConfigResolver.apply_converters("@unknown_converter idk") 38 | with pytest.raises(ValueError): 39 | ConfigResolver.apply_converters("@int something") 40 | 41 | assert ConfigResolver.apply_converters("no conversion") == "no conversion" 42 | 43 | 44 | def test_config_resolver(): 45 | os.environ["CDF_TEST"] = "1" 46 | resolver = ConfigResolver( 47 | { 48 | "main_api": { 49 | "user": "someone", 50 | "password": "secret", 51 | "database": "test", 52 | }, 53 | "db_1": "@int ${CDF_TEST}", 54 | "db_2": "@resolve main_api", 55 | } 56 | ) 57 | 58 | assert resolver["main_api"] == { 59 | "user": "someone", 60 | "password": "secret", 61 | "database": "test", 62 | } 63 | assert resolver["db_1"] == 1 64 | assert resolver["main_api"] == resolver["db_2"] 65 | resolver.import_source({"db_1": 2}) 66 | assert resolver["db_1"] == 2 67 | 68 | @ConfigResolver.map_values(db_1="db_1", db_2="db_2") 69 | def foo(db_1: int, db_2: dict): 70 | return db_1, db_2 71 | 72 | foo_configured = resolver.resolve_defaults(foo) 73 | assert foo_configured() == ( 74 | 2, 75 | {"user": "someone", "password": "secret", "database": "test"}, 76 | ) 77 | 78 | @ConfigResolver.map_values( 79 | user="main_api.user", password="main_api.password", database="main_api.database" 80 | ) 81 | def bar(user: str, password: str,
database: str): 82 | return user, password, database 83 | 84 | bar_configured = resolver.resolve_defaults(bar) 85 | assert bar_configured() == ("someone", "secret", "test") 86 | 87 | assert "main_api" in resolver 88 | -------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/file.py: -------------------------------------------------------------------------------- 1 | """File-based feature flag provider.""" 2 | 3 | import json 4 | import logging 5 | import typing as t 6 | from collections import defaultdict 7 | from threading import Lock 8 | 9 | import dlt 10 | import fsspec 11 | from dlt.common.configuration import with_config 12 | 13 | from cdf.integrations.feature_flag.base import ( 14 | AbstractFeatureFlagAdapter, 15 | FlagAdapterResponse, 16 | ) 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class FilesystemFeatureFlagAdapter(AbstractFeatureFlagAdapter): 22 | """A feature flag adapter that uses the filesystem.""" 23 | 24 | _LOCK = defaultdict(Lock) 25 | 26 | __cdf_resolve__ = ("feature_flags",) 27 | 28 | @with_config(sections=("feature_flags",)) 29 | def __init__( 30 | self, 31 | filesystem: fsspec.AbstractFileSystem, 32 | filename: str = dlt.config.value, 33 | **kwargs: t.Any, 34 | ) -> None: 35 | """Initialize the filesystem feature flags. 36 | 37 | Args: 38 | filesystem: The filesystem to use. 39 | filename: The filename to use for the feature flags. 40 | """ 41 | self.filename = filename 42 | self.filesystem = filesystem 43 | self.__flags: t.Optional[t.Dict[str, FlagAdapterResponse]] = None 44 | 45 | def __repr__(self) -> str: 46 | return f"{type(self).__name__}(filename={self.filename!r})" 47 | 48 | def __str__(self) -> str: 49 | return self.filename 50 | 51 | def _read(self) -> t.Dict[str, FlagAdapterResponse]: 52 | """Read the feature flags from the filesystem.""" 53 | logger.info("Reading feature flags from %s", self.filename) 54 | if not self.filesystem.exists(self.filename): 55 | flags = {} 56 | else: 57 | with self.filesystem.open(self.filename) as file: 58 | flags = json.load(file) 59 | return {k: FlagAdapterResponse.from_bool(v) for k, v in flags.items()} 60 | 61 | def _commit(self) -> None: 62 | """Commit the feature flags to the filesystem.""" 63 | logger.info("Committing feature flags to %s", self.filename) 64 | with ( 65 | self._LOCK[self.filename], 66 | self.filesystem.open(self.filename, "w") as file, 67 | ): 68 | json.dump({k: v.to_bool() for k, v in self._flags.items()}, file, indent=2) 69 | 70 | @property 71 | def _flags(self) -> t.Dict[str, FlagAdapterResponse]: 72 | """Get the feature flags.""" 73 | if self.__flags is None: 74 | self.__flags = self._read() 75 | return t.cast(t.Dict[str, FlagAdapterResponse], self.__flags) 76 | 77 | def get(self, feature_name: str) -> FlagAdapterResponse: 78 | """Get a feature flag. 79 | 80 | Args: 81 | feature_name: The name of the feature flag. 82 | 83 | Returns: 84 | The feature flag. 85 | """ 86 | return self._flags.get(feature_name, FlagAdapterResponse.NOT_FOUND) 87 | 88 | def get_all_feature_names(self) -> t.List[str]: 89 | """Get all feature flag names. 90 | 91 | Returns: 92 | The feature flag names. 93 | """ 94 | return list(self._flags.keys()) 95 | 96 | def save(self, feature_name: str, flag: bool) -> None: 97 | """Save a feature flag. 98 | 99 | Args: 100 | feature_name: The name of the feature flag. 101 | flag: The value of the feature flag. 
102 | """ 103 | self._flags[feature_name] = FlagAdapterResponse.from_bool(flag) 104 | self._commit() 105 | 106 | def save_many(self, flags: t.Dict[str, bool]) -> None: 107 | """Save multiple feature flags. 108 | 109 | Args: 110 | flags: The feature flags to save. 111 | """ 112 | self._flags.update( 113 | {k: FlagAdapterResponse.from_bool(v) for k, v in flags.items()} 114 | ) 115 | self._commit() 116 | 117 | 118 | __all__ = ["FilesystemFeatureFlagAdapter"] 119 | -------------------------------------------------------------------------------- /src/cdf/core/context.py: -------------------------------------------------------------------------------- 1 | """Context management utilities for managing the active workspace.""" 2 | 3 | import contextlib 4 | import functools 5 | import logging 6 | import typing as t 7 | from contextvars import ContextVar, Token 8 | 9 | if t.TYPE_CHECKING: 10 | from cdf.core.injector import Lifecycle 11 | from cdf.core.workspace import Workspace 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | _ACTIVE_WORKSPACE: ContextVar[t.Optional["Workspace"]] = ContextVar( 16 | "active_workspace", default=None 17 | ) 18 | """The active workspace for resolving injected dependencies.""" 19 | 20 | _DEFAULT_CALLABLE_LIFECYCLE: ContextVar[t.Optional["Lifecycle"]] = ContextVar( 21 | "default_callable_lifecycle", default=None 22 | ) 23 | """The default lifecycle for callables when otherwise unspecified.""" 24 | 25 | 26 | def get_active_workspace() -> t.Optional["Workspace"]: 27 | """Get the active workspace for resolving injected dependencies.""" 28 | return _ACTIVE_WORKSPACE.get() 29 | 30 | 31 | def set_active_workspace(workspace: t.Optional["Workspace"]) -> Token: 32 | """Set the active workspace for resolving injected dependencies.""" 33 | return _ACTIVE_WORKSPACE.set(workspace) 34 | 35 | 36 | @contextlib.contextmanager 37 | def use_workspace(workspace: t.Optional["Workspace"]) -> t.Iterator[None]: 38 | """Context manager for temporarily setting the active workspace.""" 39 | token = set_active_workspace(workspace) 40 | try: 41 | yield 42 | finally: 43 | set_active_workspace(token.old_value) 44 | 45 | 46 | T = t.TypeVar("T") 47 | 48 | 49 | @t.overload 50 | def resolve( 51 | dependencies: t.Callable[..., T], 52 | configuration: bool = ..., 53 | eagerly_bind_workspace: bool = ..., 54 | ) -> t.Callable[..., T]: ... 55 | 56 | 57 | @t.overload 58 | def resolve( 59 | dependencies: bool = ..., 60 | configuration: bool = ..., 61 | eagerly_bind_workspace: bool = ..., 62 | ) -> t.Callable[[t.Callable[..., T]], t.Callable[..., T]]: ... 
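# Usage sketch (hedged; inferred from the overloads above and the
# implementation below; the function names are illustrative):
#
#   @resolve                            # bare decorator: dependencies and
#   def job(sfdc: str) -> None: ...     #   configuration are both resolved
#
#   @resolve(configuration=False)       # decorator factory: inject
#   def job2(sfdc: str) -> None: ...    #   dependencies only
#
# Both forms fall back to calling the function unchanged when no workspace
# is active.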
63 | 64 | 65 | def resolve( 66 | dependencies: t.Union[t.Callable[..., T], bool] = True, 67 | configuration: bool = True, 68 | eagerly_bind_workspace: bool = False, 69 | ) -> t.Callable[..., t.Union[T, t.Callable[..., T]]]: 70 | """Decorator for injecting dependencies and resolving configuration for a function.""" 71 | 72 | if eagerly_bind_workspace: 73 | # Get the active workspace before the function is resolved 74 | workspace = get_active_workspace() 75 | else: 76 | workspace = None 77 | 78 | def _resolve(func: t.Callable[..., T]) -> t.Callable[..., T]: 79 | @functools.wraps(func) 80 | def wrapper(*args: t.Any, **kwargs: t.Any) -> T: 81 | nonlocal func, workspace 82 | workspace = workspace or get_active_workspace() 83 | if workspace is None: 84 | return func(*args, **kwargs) 85 | if configuration: 86 | func = workspace.conf_resolver.resolve_defaults(func) 87 | if dependencies: 88 | func = workspace.container.wire(func) 89 | return func(*args, **kwargs) 90 | 91 | return wrapper 92 | 93 | if callable(dependencies): 94 | return _resolve(dependencies) 95 | 96 | return _resolve 97 | 98 | 99 | def invoke(func_or_cls: t.Callable, *args: t.Any, **kwargs: t.Any) -> t.Any: 100 | """Invoke a function or class with resolved dependencies.""" 101 | workspace = get_active_workspace() 102 | if workspace is None: 103 | logger.debug("Invoking %s without a bound workspace", func_or_cls) 104 | return func_or_cls(*args, **kwargs) 105 | logger.debug("Invoking %s bound to workspace %s", func_or_cls, workspace) 106 | return workspace.invoke(func_or_cls, *args, **kwargs) 107 | 108 | 109 | def get_default_callable_lifecycle() -> "Lifecycle": 110 | """Get the default lifecycle for callables when otherwise unspecified.""" 111 | from cdf.core.injector import Lifecycle 112 | 113 | return _DEFAULT_CALLABLE_LIFECYCLE.get() or Lifecycle.SINGLETON 114 | 115 | 116 | def set_default_callable_lifecycle(lifecycle: t.Optional["Lifecycle"]) -> Token: 117 | """Set the default lifecycle for callables when otherwise unspecified.""" 118 | if lifecycle and lifecycle.is_instance: 119 | raise ValueError("Default callable lifecycle cannot be set to INSTANCE") 120 | return _DEFAULT_CALLABLE_LIFECYCLE.set(lifecycle) 121 | 122 | 123 | @contextlib.contextmanager 124 | def use_default_callable_lifecycle( 125 | lifecycle: t.Optional["Lifecycle"], 126 | ) -> t.Iterator[None]: 127 | """Context manager for temporarily setting the default callable lifecycle.""" 128 | token = set_default_callable_lifecycle(lifecycle) 129 | try: 130 | yield 131 | finally: 132 | set_default_callable_lifecycle(token.old_value) 133 | -------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing as t 3 | from enum import Enum, auto 4 | 5 | if t.TYPE_CHECKING: 6 | from dlt.sources import DltSource 7 | 8 | 9 | class FlagAdapterResponse(Enum): 10 | """Feature flag response. 11 | 12 | This enum is used to represent the state of a feature flag. It is similar 13 | to a boolean but with an extra state for when the flag is not found. 
14 | """ 15 | 16 | ENABLED = auto() 17 | """The feature flag is enabled.""" 18 | DISABLED = auto() 19 | """The feature flag is disabled.""" 20 | NOT_FOUND = auto() 21 | """The feature flag is not found.""" 22 | 23 | def __bool__(self) -> bool: 24 | """Return True if the flag is enabled and False otherwise.""" 25 | return self is FlagAdapterResponse.ENABLED 26 | 27 | to_bool = __bool__ 28 | 29 | def __eq__(self, value: object, /) -> bool: 30 | """Compare the flag to a boolean; DISABLED and NOT_FOUND equal False.""" 31 | if isinstance(value, bool): 32 | return bool(self) == value 33 | return super().__eq__(value) 34 | 35 | @classmethod 36 | def from_bool(cls, flag: bool) -> "FlagAdapterResponse": 37 | """Convert a boolean to a flag response.""" 38 | return cls.ENABLED if flag else cls.DISABLED 39 | 40 | 41 | class AbstractFeatureFlagAdapter(abc.ABC): 42 | """Abstract feature flag adapter.""" 43 | 44 | def __init__(self, **kwargs: t.Any) -> None: 45 | """Initialize the adapter.""" 46 | pass 47 | 48 | @abc.abstractmethod 49 | def get(self, feature_name: str) -> FlagAdapterResponse: 50 | """Get the feature flag.""" 51 | pass 52 | 53 | def __getitem__(self, feature_name: str) -> FlagAdapterResponse: 54 | """Get the feature flag.""" 55 | return self.get(feature_name) 56 | 57 | def get_many(self, feature_names: t.List[str]) -> t.Dict[str, FlagAdapterResponse]: 58 | """Get many feature flags. 59 | 60 | Implementations should override this method if they can optimize it. The default 61 | will call get in a loop. 62 | """ 63 | return {feature_name: self.get(feature_name) for feature_name in feature_names} 64 | 65 | @abc.abstractmethod 66 | def save(self, feature_name: str, flag: bool) -> None: 67 | """Save the feature flag.""" 68 | pass 69 | 70 | def __setitem__(self, feature_name: str, flag: bool) -> None: 71 | """Save the feature flag.""" 72 | self.save(feature_name, flag) 73 | 74 | def save_many(self, flags: t.Dict[str, bool]) -> None: 75 | """Save many feature flags. 76 | 77 | Implementations should override this method if they can optimize it. The default 78 | will call save in a loop. 79 | """ 80 | for feature_name, flag in flags.items(): 81 | self.save(feature_name, flag) 82 | 83 | @abc.abstractmethod 84 | def get_all_feature_names(self) -> t.List[str]: 85 | """Get all feature names.""" 86 | pass 87 | 88 | def keys(self) -> t.List[str]: 89 | """Get all feature names.""" 90 | return self.get_all_feature_names() 91 | 92 | def __iter__(self) -> t.Iterator[str]: 93 | """Iterate over the feature names.""" 94 | return iter(self.get_all_feature_names()) 95 | 96 | def __contains__(self, feature_name: str) -> bool: 97 | """Check if a feature flag exists.""" 98 | return self.get(feature_name) is not FlagAdapterResponse.NOT_FOUND 99 | 100 | def __len__(self) -> int: 101 | """Get the number of feature flags.""" 102 | return len(self.get_all_feature_names()) 103 | 104 | def delete(self, feature_name: str) -> None: 105 | """Delete a feature flag. 106 | 107 | By default, this will disable the flag but implementations can override this method 108 | to delete the flag. 109 | """ 110 | self.save(feature_name, False) 111 | 112 | __delitem__ = delete 113 | 114 | def delete_many(self, feature_names: t.List[str]) -> None: 115 | """Delete many feature flags.""" 116 | self.save_many({feature_name: False for feature_name in feature_names}) 117 | 118 | def apply_source(self, source: "DltSource", *namespace: str) -> "DltSource": 119 | """Apply the feature flags to a dlt source.
120 | 121 | Args: 122 | source: The source to apply the feature flags to. 123 | 124 | Returns: 125 | The source with the feature flags applied. 126 | """ 127 | new = {} 128 | source_name = source.name 129 | for resource_name, resource in source.selected_resources.items(): 130 | k = ".".join(filter(lambda s: s, [*namespace, source_name, resource_name])) 131 | resp = self.get(k) 132 | resource.selected = bool(resp) 133 | if resp is FlagAdapterResponse.NOT_FOUND: 134 | new[k] = False 135 | if new: 136 | self.save_many(new) 137 | 138 | return source 139 | -------------------------------------------------------------------------------- /tests/core/injector/test_registry.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | 6 | from cdf.core.injector.errors import DependencyCycleError 7 | from cdf.core.injector.registry import ( 8 | Dependency, 9 | DependencyRegistry, 10 | Lifecycle, 11 | TypedKey, 12 | ) 13 | 14 | 15 | def test_registry(): 16 | # A generic test to show some of the API together 17 | container = DependencyRegistry() 18 | container.add("a", lambda: 1) 19 | 20 | def b(a: int) -> int: 21 | return a + 1 22 | 23 | container.add("b", b) 24 | container.add("obj_proto", object, container.lifecycle.PROTOTYPE) 25 | container.add("obj_singleton", object) 26 | 27 | def foo(a: int, b: int, c: int = 0) -> int: 28 | return a + b 29 | 30 | foo_wired = container.wire(foo) 31 | 32 | assert foo_wired() == 3 33 | assert foo_wired(1) == 3 34 | assert foo_wired(2) == 4 35 | assert foo_wired(3, 3) == 6 36 | 37 | assert container.get("obj_proto") is not container.get("obj_proto") 38 | assert container.get("obj_singleton") is container.get("obj_singleton") 39 | 40 | assert container(foo) == 3 41 | 42 | container.add("c", lambda a, b: a + b, container.lifecycle.PROTOTYPE) 43 | 44 | assert container(foo) == 3 45 | 46 | def bar(a: int, b: int, c: t.Optional[int] = None) -> int: 47 | if c is None: 48 | raise ValueError("c is required") 49 | return a + b + c 50 | 51 | assert container(bar) == 6 52 | assert container(bar, c=5) == 8 53 | 54 | @container.wire 55 | def baz(a: int, b: int, c: int = 0) -> int: 56 | return a + b + c 57 | 58 | assert baz() == 3 59 | 60 | 61 | @pytest.fixture 62 | def registry(): 63 | return DependencyRegistry() 64 | 65 | 66 | def test_typed_key_creation(): 67 | key = TypedKey(name="test", type_=int) 68 | assert key.name == "test" 69 | assert key.type_ is int 70 | 71 | 72 | def test_typed_key_equality(): 73 | key1 = TypedKey("name1", int) 74 | key2 = TypedKey("name1", int) 75 | assert key1 == key2 76 | 77 | 78 | def test_typed_key_inequality(): 79 | key1 = TypedKey("name1", int) 80 | key2 = TypedKey("name2", str) 81 | assert key1 != key2 82 | 83 | 84 | def test_typed_key_string_representation(): 85 | key1 = TypedKey("name1", int) 86 | assert str(key1) == "name1: int" 87 | 88 | 89 | def test_instance_dependency(): 90 | instance = Dependency.instance(42) 91 | assert instance() == 42 92 | assert instance.lifecycle.is_instance 93 | 94 | 95 | def test_singleton_dependency(): 96 | factory = MagicMock(return_value=42) 97 | singleton_dep = Dependency.singleton(factory) 98 | assert singleton_dep.lifecycle == Lifecycle.SINGLETON 99 | 100 | 101 | def test_prototype_dependency(): 102 | factory = MagicMock(return_value=42) 103 | prototype_dep = Dependency.prototype(factory) 104 | assert prototype_dep.lifecycle == Lifecycle.PROTOTYPE 105 | 106 | 107 | def 
test_apply_function_to_instance(): 108 | dep = Dependency.instance(42) 109 | new_dep = dep.map_value(lambda x: x + 1) 110 | assert new_dep() == 43 111 | 112 | 113 | def test_apply_wrappers_to_factory(): 114 | factory = MagicMock(return_value=42) 115 | dep = Dependency.singleton(factory) 116 | wrapper = MagicMock(side_effect=lambda f: lambda: f() + 1) 117 | dep_wrapped = dep.map(wrapper) 118 | assert dep_wrapped.unwrap() == 43 119 | 120 | 121 | def test_add_and_get_singleton(registry: DependencyRegistry): 122 | registry.add_singleton("test", object) 123 | retrieved1 = registry.get("test") 124 | retrieved2 = registry.get("test") 125 | assert retrieved1 is retrieved2 # Ensure same instance each time 126 | 127 | 128 | def test_add_and_get_prototype(registry: DependencyRegistry): 129 | registry.add_prototype("test", object) 130 | retrieved1 = registry.get("test") 131 | retrieved2 = registry.get("test") 132 | assert retrieved1 is not retrieved2 # Ensure new instance each time 133 | 134 | 135 | def test_add_and_get_instance(registry: DependencyRegistry): 136 | registry.add_instance("test", 42) 137 | retrieved = registry.get("test") 138 | assert retrieved == 42 139 | 140 | 141 | def test_wire_function(registry: DependencyRegistry): 142 | factory = MagicMock(return_value=42) 143 | registry.add_singleton("test", factory) 144 | 145 | @registry.wire 146 | def func(test): 147 | return test 148 | 149 | assert func() == 42 150 | 151 | 152 | def test_dependency_cycle(registry: DependencyRegistry): 153 | registry.add("left", lambda right: right) 154 | registry.add("right", lambda left: left) 155 | with pytest.raises(DependencyCycleError): 156 | registry["left"] 157 | 158 | 159 | def test_contains(registry: DependencyRegistry): 160 | factory = MagicMock(return_value=42) 161 | registry.add_singleton("test", factory) 162 | assert "test" in registry 163 | 164 | 165 | def test_remove_dependency(registry: DependencyRegistry): 166 | factory = MagicMock(return_value=42) 167 | registry.add_singleton("test", factory) 168 | assert "test" in registry 169 | registry.remove("test") 170 | assert "test" not in registry 171 | -------------------------------------------------------------------------------- /src/cdf/builtin/metrics.py: -------------------------------------------------------------------------------- 1 | """Built-in metrics for CDF 2 | 3 | They can be referenced via absolute import paths in a pipeline spec. 
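
Each metric is a callable of shape ``(item, current_value) -> new_value`` that a
runner folds over the items of a dataset. A rough sketch of that contract
(illustrative; the exact runner wiring may differ):

    total = 0
    for item in items:
        total = count(item, total)

    distinct_ids = unique("id")  # stateful closure counting distinct ids
    n = 0
    for item in items:
        n = distinct_ids(item, n)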
4 | """ 5 | 6 | import bisect 7 | import decimal 8 | import math 9 | import statistics 10 | import typing as t 11 | from collections import defaultdict 12 | 13 | TNumber = t.TypeVar("TNumber", int, float, decimal.Decimal) 14 | 15 | MetricFunc = t.Callable[[t.Any, TNumber], TNumber] 16 | 17 | 18 | def count(_: t.Any, metric: TNumber = 0) -> TNumber: 19 | """Counts the number of items in a dataset""" 20 | return metric + 1 21 | 22 | 23 | def unique(key: str) -> MetricFunc: 24 | """Counts the number of unique items in a dataset by a given key""" 25 | seen = set() 26 | 27 | def _unique(item: t.Any, _: t.Optional[TNumber] = None) -> int: 28 | k = item.get(key) 29 | if k is not None and k not in seen: 30 | seen.add(k) 31 | return len(seen) 32 | 33 | return _unique 34 | 35 | 36 | def max_value(key: str) -> MetricFunc: 37 | """Returns the maximum value of a key in a dataset""" 38 | first = True 39 | 40 | def _max_value(item: t.Any, metric: t.Optional[TNumber] = None) -> TNumber: 41 | nonlocal first 42 | k = item.get(key) 43 | if metric is None or first: 44 | first = False 45 | return k 46 | if k is None: 47 | return metric 48 | return max(metric, k) 49 | 50 | return _max_value 51 | 52 | 53 | def min_value(key: str) -> MetricFunc: 54 | """Returns the minimum value of a key in a dataset""" 55 | first = True 56 | 57 | def _min_value(item: t.Any, metric: t.Optional[TNumber] = None) -> TNumber: 58 | nonlocal first 59 | k = item.get(key) 60 | if metric is None or first: 61 | first = False 62 | return k 63 | if k is None: 64 | return metric 65 | return min(metric, k) 66 | 67 | return _min_value 68 | 69 | 70 | def sum_value(key: str) -> MetricFunc: 71 | """Returns the sum of a key in a dataset""" 72 | 73 | def _sum_value(item: t.Any, metric: TNumber = 0) -> TNumber: 74 | k = item.get(key, 0) 75 | return metric + k 76 | 77 | return _sum_value 78 | 79 | 80 | def avg_value(key: str) -> MetricFunc: 81 | """Returns the average of a key in a dataset""" 82 | n_sum, n_count = 0, 0 83 | 84 | def _avg_value( 85 | item: t.Any, last_value: t.Optional[TNumber] = None 86 | ) -> t.Optional[TNumber]: 87 | nonlocal n_sum, n_count 88 | k = item.get(key) 89 | if k is None: 90 | return last_value 91 | n_sum += k 92 | n_count += 1 93 | return n_sum / n_count 94 | 95 | return _avg_value 96 | 97 | 98 | def median_value(key: str, window: int = 1000) -> MetricFunc: 99 | """Returns the median of a key in a dataset""" 100 | arr = [] 101 | 102 | def _median_value( 103 | item: t.Any, last_value: t.Optional[TNumber] = None 104 | ) -> t.Optional[TNumber]: 105 | nonlocal arr 106 | k = item.get(key) 107 | if k is None: 108 | return last_value 109 | bisect.insort(arr, k) 110 | if len(arr) > window: 111 | del arr[0], arr[-1] 112 | return statistics.median(arr) 113 | 114 | return _median_value 115 | 116 | 117 | def stdev_value(key: str) -> MetricFunc: 118 | """Returns the standard deviation of a key in a dataset""" 119 | n_sum, n_squared_sum, n_count = 0, 0, 0 120 | 121 | def _stdev_value( 122 | item: t.Any, last_value: t.Optional[TNumber] = None 123 | ) -> t.Optional[float]: 124 | nonlocal n_sum, n_squared_sum, n_count 125 | k = item.get(key) 126 | if k is None: 127 | return t.cast(t.Optional[float], last_value) 128 | n_sum += k 129 | n_squared_sum += k**2 130 | n_count += 1 131 | mean = n_sum / n_count 132 | return math.sqrt(n_squared_sum / n_count - mean**2) 133 | 134 | return _stdev_value 135 | 136 | 137 | def variance_value(key: str) -> MetricFunc: 138 | """Returns the variance of a key in a dataset""" 139 | n_sum, n_squared_sum, 
n_count = 0, 0, 0 140 | 141 | def _variance_value( 142 | item: t.Any, last_value: t.Optional[TNumber] = None 143 | ) -> t.Optional[float]: 144 | nonlocal n_sum, n_squared_sum, n_count 145 | k = item.get(key) 146 | if k is None: 147 | return t.cast(t.Optional[float], last_value) 148 | n_sum += k 149 | n_squared_sum += k**2 150 | n_count += 1 151 | if n_count == 1: 152 | return 0 153 | mean = n_sum / n_count 154 | return (n_squared_sum / n_count) - mean**2 155 | 156 | return _variance_value 157 | 158 | 159 | def mode_value(key: str) -> MetricFunc: 160 | """Returns the mode of a key in a dataset.""" 161 | frequency = defaultdict(int) 162 | 163 | def _mode_value(item: t.Any, last_value: t.Optional[t.Any] = None) -> t.Any: 164 | nonlocal frequency 165 | k = item.get(key) 166 | if k is None: 167 | return last_value 168 | frequency[k] += 1 169 | return max(frequency.items(), key=lambda x: x[1])[0] 170 | 171 | return _mode_value 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # CDF (Continuous Data Framework)
3 | 
4 | > Craft end-to-end data pipelines and manage them continuously
5 | 
6 | Built with Python, SQLMesh, and dlt.
7 | 
8 | ---
9 | 
10 | ## 📖 Table of Contents
11 | - [📖 Table of Contents](#-table-of-contents)
12 | - [📍 Overview](#-overview)
13 | - [📦 Features](#-features)
14 | - [🚀 Getting Started](#-getting-started)
15 | - [📚 Documentation](#-documentation)
16 | - [🛣 Project Roadmap](#-project-roadmap)
17 | - [🤝 Contributing](#-contributing)
18 | - [📄 License](#-license)
19 | - [👏 Acknowledgments](#-acknowledgments)
20 | 
21 | ---
22 | 
23 | ## 📍 Overview
24 | 
25 | CDF (Continuous Data Framework) is an integrated framework designed to manage data across the entire lifecycle, from ingestion through transformation to publishing. It is built on top of two open-source projects, `sqlmesh` and `dlt`, providing a unified interface for complex data operations. CDF simplifies data engineering workflows, offering scalable solutions from small to large projects through an opinionated project structure that supports both multi-workspace and single-workspace layouts.
26 | 
27 | > [!WARNING]
28 | > The repo is currently under ACTIVE development, with multiple large refactors already completed. The codebase is not yet stable and is subject to change; until this disclaimer is removed, the code (and tests) are the most accurate and up-to-date reference.
29 | 
30 | ## 📦 Features
31 | 
32 | ...
33 | 
34 | ## 🚀 Getting Started
35 | 
36 | 1. **Installation**:
37 | 
38 |    CDF requires Python 3.9 or newer. Install CDF using pip:
39 | 
40 |    ```bash
41 |    pip install python-cdf
42 |    ```
43 | 
44 | ## 📚 Documentation
45 | 
46 | For detailed documentation, including API references and tutorials, visit [CDF Documentation](#).
47 | 
48 | ### 🧪 Tests
49 | 
50 | Run the tests with `pytest`:
51 | 
52 | ```sh
53 | pytest tests
54 | ```
55 | 
56 | ## 🛣 Project Roadmap
57 | 
58 | TODO: Add a roadmap for the project.
59 | 
60 | ## 🤝 Contributing
61 | 
62 | Contributions are welcome! Here are several ways you can contribute:
63 | 
64 | - **[Submit Pull Requests](https://github.com/z3z1ma/cdf/blob/main/CONTRIBUTING.md)**: Review open PRs, and submit your own PRs.
65 | - **[Join the Discussions](https://github.com/z3z1ma/cdf/discussions)**: Share your insights, provide feedback, or ask questions.
66 | - **[Report Issues](https://github.com/z3z1ma/cdf/issues)**: Submit bugs found or log feature requests for z3z1ma.
67 | 
68 | #### *Contributing Guidelines*
69 | 
70 | <details>
71 | <summary>Click to expand</summary>
72 | 
73 | 1. **Fork the Repository**: Start by forking the project repository to your GitHub account.
74 | 2. **Clone Locally**: Clone the forked repository to your local machine using a Git client.
75 |    ```sh
76 |    git clone <your-fork-url>
77 |    ```
78 | 3. **Create a New Branch**: Always work on a new branch, giving it a descriptive name.
79 |    ```sh
80 |    git checkout -b new-feature-x
81 |    ```
82 | 4. **Make Your Changes**: Develop and test your changes locally.
83 | 5. **Commit Your Changes**: Commit with a clear and concise message describing your updates.
84 |    ```sh
85 |    git commit -m 'Implemented new feature x.'
86 |    ```
87 | 6. **Push to GitHub**: Push the changes to your forked repository.
88 |    ```sh
89 |    git push origin new-feature-x
90 |    ```
91 | 7. **Submit a Pull Request**: Create a PR against the original project repository. Clearly describe the changes and their motivations.
92 | 
93 | Once your PR is reviewed and approved, it will be merged into the main branch.
94 | 
95 | </details>
96 | 
97 | ---
98 | 
99 | ## 📄 License
100 | 
101 | This project is distributed under the [Apache 2.0](http://www.apache.org/licenses/LICENSE-2.0) License. For more details, refer to the [LICENSE](https://github.com/z3z1ma/cdf/blob/main/LICENSE) file.
102 | 
103 | ---
104 | 
105 | ## 👏 Acknowledgments
106 | 
107 | - Harness (https://harness.io/) for being the proving grounds in which the initial concept of this project was born.
108 | - SQLMesh (https://sqlmesh.com) for being a foundational pillar of this project as well as the team for their support, advice, and guidance.
109 | - DLT (https://dlthub.com) for being the other foundational pillar of this project as well as the team for their support, advice, and guidance.
110 | 
111 | [**Return to top**](#cdf-continuous-data-framework)
112 | 
113 | ---
-------------------------------------------------------------------------------- /src/cdf/integrations/feature_flag/harness.py: --------------------------------------------------------------------------------
1 | """Harness feature flag provider."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import logging
6 | import os
7 | import typing as t
8 | from concurrent.futures import ThreadPoolExecutor
9 | 
10 | import dlt
11 | from dlt.common.configuration import with_config
12 | from dlt.sources import DltSource
13 | from dlt.sources.helpers import requests
14 | from featureflags.client import CfClient, Config, Target
15 | from featureflags.evaluations.feature import FeatureConfigKind
16 | from featureflags.interface import Cache
17 | from featureflags.util import log as _ff_logger
18 | 
19 | from cdf.integrations.feature_flag.base import (
20 |     AbstractFeatureFlagAdapter,
21 |     FlagAdapterResponse,
22 | )
23 | 
24 | logger = logging.getLogger(__name__)
25 | 
26 | 
27 | # This exists because the default harness LRU implementation does not store >1000 flags
28 | # The interface is mostly satisfied by dict, so we subclass it and implement the missing methods
29 | class _HarnessCache(dict, Cache):
30 |     """A cache implementation for the harness feature flag provider."""
31 | 
32 |     def set(self, key: str, value: bool) -> None:
33 |         self[key] = value
34 | 
35 |     def remove(self, key: str | t.List[str]) -> None:
36 |         if isinstance(key, str):
37 |             key = [key]  # normalize so we never iterate over the characters of a string
38 |         for k in key:
39 |             self.pop(k, None)
40 | 
41 | 
42 | def _quiet_logger():
43 |     """Configure the harness FF logger to only show errors.
Its too verbose otherwise.""" 44 | _ff_logger.setLevel(logging.ERROR) 45 | 46 | 47 | class HarnessFeatureFlagAdapter(AbstractFeatureFlagAdapter): 48 | _TARGET = Target("cdf") 49 | 50 | @with_config(sections=("feature_flags",)) 51 | def __init__( 52 | self, 53 | sdk_key: str = dlt.secrets.value, 54 | api_key: str = dlt.secrets.value, 55 | account: str = dlt.secrets.value, 56 | organization: str = dlt.secrets.value, 57 | project: str = dlt.secrets.value, 58 | **kwargs: t.Any, 59 | ) -> None: 60 | """Initialize the adapter.""" 61 | self.sdk_key = sdk_key 62 | self.api_key = api_key 63 | self.account = account 64 | self.organization = organization 65 | self.project = project 66 | self._pool = None 67 | self._client = None 68 | _quiet_logger() 69 | 70 | @property 71 | def client(self) -> CfClient: 72 | """Get the client and cache it in the instance.""" 73 | if self._client is None: 74 | client = CfClient( 75 | sdk_key=str(self.sdk_key), 76 | config=Config( 77 | enable_stream=False, enable_analytics=False, cache=_HarnessCache() 78 | ), 79 | ) 80 | client.wait_for_initialization() 81 | self._client = client 82 | return self._client 83 | 84 | @property 85 | def pool(self) -> ThreadPoolExecutor: 86 | """Get the thread pool.""" 87 | if self._pool is None: 88 | self._pool = ThreadPoolExecutor(thread_name_prefix="cdf-ff-") 89 | return self._pool 90 | 91 | def get(self, feature_name: str) -> FlagAdapterResponse: 92 | """Get a feature flag.""" 93 | if feature_name not in self.get_all_feature_names(): 94 | return FlagAdapterResponse.NOT_FOUND 95 | return FlagAdapterResponse.from_bool( 96 | self.client.bool_variation(feature_name, self._TARGET, False) 97 | ) 98 | 99 | def get_all_feature_names(self) -> t.List[str]: 100 | """Get all the feature flags.""" 101 | return list( 102 | map(lambda f: f.split("/", 1)[1], self.client._repository.cache.keys()) 103 | ) 104 | 105 | def _toggle(self, feature_name: str, flag: bool) -> None: 106 | """Toggle a feature flag.""" 107 | if flag is self.get(feature_name).to_bool(): 108 | return 109 | logger.info(f"Toggling feature flag {feature_name} to {flag}") 110 | requests.patch( 111 | f"https://app.harness.io/cf/admin/features/{feature_name}", 112 | headers={"x-api-key": self.api_key}, 113 | params={ 114 | "accountIdentifier": self.account, 115 | "orgIdentifier": self.organization, 116 | "projectIdentifier": self.project, 117 | }, 118 | json={ 119 | "instructions": [ 120 | { 121 | "kind": "setFeatureFlagState", 122 | "parameters": {"state": "on" if flag else "off"}, 123 | } 124 | ] 125 | }, 126 | ) 127 | 128 | def save(self, feature_name: str, flag: bool) -> None: 129 | """Create a feature flag.""" 130 | if self.get(feature_name) is FlagAdapterResponse.NOT_FOUND: 131 | logger.info(f"Creating feature flag {feature_name}") 132 | try: 133 | requests.post( 134 | "https://app.harness.io/cf/admin/features", 135 | params={ 136 | "accountIdentifier": self.account, 137 | "orgIdentifier": self.organization, 138 | }, 139 | headers={ 140 | "Content-Type": "application/json", 141 | "x-api-key": self.api_key, 142 | }, 143 | json={ 144 | "defaultOnVariation": "on-variation", 145 | "defaultOffVariation": "off-variation", 146 | "description": "Managed by CDF", 147 | "identifier": feature_name, 148 | "name": feature_name.upper(), 149 | "kind": FeatureConfigKind.BOOLEAN.value, 150 | "permanent": True, 151 | "project": self.project, 152 | "variations": [ 153 | {"identifier": "on-variation", "value": "true"}, 154 | {"identifier": "off-variation", "value": "false"}, 155 | ], 156 | }, 157 | ) 
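            # NOTE: creation is best-effort; failures are logged by the handler
            # below and we still fall through to toggle the flag afterwards.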
158 | except Exception: 159 | logger.exception(f"Failed to create feature flag {feature_name}") 160 | self._toggle(feature_name, flag) 161 | 162 | def save_many(self, flags: t.Dict[str, bool]) -> None: 163 | """Create many feature flags.""" 164 | list(self.pool.map(lambda f: self.save(*f), flags.items())) 165 | 166 | def delete(self, feature_name: str) -> None: 167 | """Drop a feature flag.""" 168 | logger.info(f"Deleting feature flag {feature_name}") 169 | requests.delete( 170 | f"https://app.harness.io/cf/admin/features/{feature_name}", 171 | headers={"x-api-key": self.api_key}, 172 | params={ 173 | "accountIdentifier": self.account, 174 | "orgIdentifier": self.organization, 175 | "projectIdentifier": self.project, 176 | "forceDelete": True, 177 | }, 178 | ) 179 | 180 | def delete_many(self, feature_names: t.List[str]) -> None: 181 | """Drop many feature flags.""" 182 | list(self.pool.map(self.delete, feature_names)) 183 | 184 | def apply_source(self, source: DltSource, *namespace: str) -> DltSource: 185 | """Apply the feature flags to a dlt source.""" 186 | # NOTE: we use just the last section due to legacy design decisions 187 | # We will remove this when the Harness team cleans up the feature flag namespace 188 | ns = f"pipeline__{namespace[-1]}__{source.name}" 189 | 190 | # A closure to produce a resource id 191 | def _get_resource_id(resource: str) -> str: 192 | return f"{ns}__{resource}" 193 | 194 | resource_lookup = { 195 | _get_resource_id(key): resource 196 | for key, resource in source.resources.items() 197 | } 198 | every_resource = resource_lookup.keys() 199 | selected_resources = set( 200 | map(_get_resource_id, source.selected_resources.keys()) 201 | ) 202 | 203 | current_flags = set( 204 | filter(lambda f: f.startswith(ns), self.get_all_feature_names()) 205 | ) 206 | 207 | removed = current_flags.difference(every_resource) 208 | added = selected_resources.difference(current_flags) 209 | 210 | # TODO: reconciliation will be promoted to a top level context function 211 | if os.getenv("HARNESS_FF_AUTORECONCILE", "0") == "1": 212 | self.delete_many(list(removed)) 213 | 214 | self.save_many({f: False for f in added}) 215 | for f in added: 216 | resource_lookup[f].selected = False 217 | for f in current_flags.intersection(selected_resources): 218 | resource_lookup[f].selected = self.get(f).to_bool() 219 | 220 | return source 221 | 222 | def __del__(self) -> None: 223 | """Close the client.""" 224 | if self._client is not None: 225 | self._client.close() 226 | if self._pool is not None: 227 | self._pool.shutdown() 228 | 229 | 230 | __all__ = ["HarnessFeatureFlagAdapter"] 231 | -------------------------------------------------------------------------------- /src/cdf/core/component/base.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | import typing as t 4 | from contextlib import suppress 5 | from dataclasses import field 6 | from enum import Enum 7 | 8 | import pydantic 9 | 10 | import cdf.core.context as ctx 11 | import cdf.core.injector as injector 12 | 13 | if t.TYPE_CHECKING: 14 | from cdf.core.workspace import Workspace 15 | 16 | T = t.TypeVar("T") 17 | 18 | __all__ = [ 19 | "Component", 20 | "Entrypoint", 21 | "ServiceLevelAgreement", 22 | ] 23 | 24 | 25 | class ServiceLevelAgreement(Enum): 26 | """An SLA to assign to a component. 
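
    Levels range from DEPRECATING (-1) through CRITICAL (4), and configuration
    strings parse via name lookup, e.g. (illustrative):

        sla = ServiceLevelAgreement["HIGH"]
        assert sla.value == 3
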
    Users can define the meaning of each level."""
27 | 
28 |     DEPRECATING = -1
29 |     NONE = 0
30 |     LOW = 1
31 |     MEDIUM = 2
32 |     HIGH = 3
33 |     CRITICAL = 4
34 | 
35 | 
36 | class _Node(pydantic.BaseModel, frozen=True):
37 |     """A node in a graph of components."""
38 | 
39 |     owner: t.Optional[str] = None
40 |     """The owner of the node. Useful for tracking who to contact for issues or config."""
41 |     description: str = "No description provided"
42 |     """A description of the node."""
43 |     sla: ServiceLevelAgreement = ServiceLevelAgreement.MEDIUM
44 |     """The SLA for the node."""
45 |     enabled: bool = True
46 |     """Whether the node is enabled or disabled. Disabled components are not loaded."""
47 |     version: str = "0.1.0"
48 |     """A semantic version for the node. Can signal breaking changes to dependents."""
49 |     tags: t.List[str] = field(default_factory=list)
50 |     """Tags to categorize the node."""
51 |     metadata: t.Dict[str, t.Any] = field(default_factory=dict)
52 |     """Additional metadata for the node. Useful for custom integrations."""
53 | 
54 |     @pydantic.field_validator("sla", mode="before")
55 |     @classmethod
56 |     def _validate_sla(cls, value: t.Any) -> t.Any:
57 |         if isinstance(value, str):
58 |             value = ServiceLevelAgreement[value.upper()]
59 |         return value
60 | 
61 |     @pydantic.field_validator("tags", mode="before")
62 |     @classmethod
63 |     def _validate_tags(cls, value: t.Any) -> t.Any:
64 |         if isinstance(value, str):
65 |             value = value.split(",")
66 |         return value
67 | 
68 | 
69 | def _parse_metadata_from_callable(func: t.Callable) -> t.Dict[str, t.Any]:
70 |     """Parse _Node metadata from a function or class allowing looser coupling of configuration.
71 | 
72 |     The function or class docstring is used as the description if available. The rest
73 |     of the metadata is inferred from the function or class attributes. The attributes
74 |     may be in global form (<NAME>) or dunder form (__<name>__).
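
    For example (hypothetical):

        __version__ = "1.2.3"
        TAGS = ["reporting"]

        def my_operation():
            '''Rebuilds the reporting tables.'''

    would be parsed with a description of "Rebuilds the reporting tables.",
    a version of "1.2.3", and tags including "reporting".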
75 | 76 | We look for the following attributes in the function or class: 77 | - name: The name of the component 78 | 79 | The following attributes in the function or class with fallback to the module: 80 | - version: The version of the component 81 | - enabled: Whether the component is enabled 82 | - sla: The SLA of the component 83 | - owner: The owner of the component 84 | 85 | And the following are merged from both the function and module: 86 | - tags: Tags for the component 87 | - metadata: Additional metadata for the component 88 | """ 89 | if not callable(func): 90 | return {} 91 | 92 | mod = inspect.getmodule(func) 93 | 94 | def _lookup_attributes( 95 | *attrs: str, callback: t.Optional[t.Callable[[t.Any], t.Any]] = None 96 | ) -> t.Optional[t.Any]: 97 | # Look for the attribute in the function and module 98 | for attr in attrs: 99 | with suppress(AttributeError): 100 | v = getattr(func, attr) 101 | if callback: 102 | callback(v) 103 | else: 104 | return v 105 | if mod is not None: 106 | with suppress(AttributeError): 107 | v = getattr(mod, attr) 108 | if callback: 109 | callback(v) 110 | else: 111 | return v 112 | 113 | parsed: t.Dict[str, t.Any] = { 114 | "description": inspect.getdoc(func) or "No description provided" 115 | } 116 | for k in ("name", "version", "enabled", "sla", "owner"): 117 | if (v := _lookup_attributes(k.upper(), f"__{k}__")) is not None: 118 | parsed[k] = v 119 | 120 | _lookup_attributes( 121 | "TAGS", "__tags__", callback=parsed.setdefault("tags", []).extend 122 | ) 123 | _lookup_attributes( 124 | "METADATA", "__metadata__", callback=parsed.setdefault("metadata", {}).update 125 | ) 126 | 127 | return parsed 128 | 129 | 130 | def _bind_active_workspace(func: t.Any) -> t.Any: 131 | """Bind the active workspace to a function or class. 132 | 133 | Args: 134 | func: The function or class to bind the workspace to. 135 | 136 | Returns: 137 | The bound function or class. 138 | """ 139 | if callable(func): 140 | return ctx.resolve(eagerly_bind_workspace=True)(func) 141 | return func 142 | 143 | 144 | def _get_bind_func(info: pydantic.ValidationInfo) -> t.Callable: 145 | """Get the bind function from the pydantic context or use the active workspace. 146 | 147 | Args: 148 | info: The pydantic validation info. 149 | 150 | Returns: 151 | The bind function to use for the component. 152 | """ 153 | context = info.context 154 | if context: 155 | bind = t.cast("Workspace", context["parent"]).bind 156 | else: 157 | bind = _bind_active_workspace 158 | return bind 159 | 160 | 161 | def _unwrap_entrypoint(value: t.Any) -> t.Any: 162 | """Import an entrypoint if it is a string. 163 | 164 | Args: 165 | value: The value to import. 166 | 167 | Returns: 168 | The imported value if it is a string, otherwise the original value. 169 | """ 170 | if isinstance(value, str): 171 | mod, func = value.split(":", 1) 172 | mod = importlib.import_module(mod) 173 | value = getattr(mod, func) 174 | return value 175 | 176 | 177 | class Component(_Node, t.Generic[T], frozen=True): 178 | """A component with a binding to a dependency.""" 179 | 180 | main: injector.Dependency[T] 181 | """The dependency for the component. This is what is injected into the workspace.""" 182 | 183 | name: t.Annotated[str, pydantic.Field(..., pattern=r"^[a-zA-Z_][a-zA-Z0-9_]*$")] 184 | """The key to register the component in the container. 185 | 186 | Must be a valid Python identifier. Users can use these names as function parameters 187 | for implicit dependency injection. Names must be unique within the workspace. 
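
    For example (sketch), a service registered as ``db`` can be requested
    implicitly by any wired function that declares a ``db`` parameter:

        def my_pipeline(db):  # ``db`` is injected from the workspace container
            ...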
188 | """ 189 | 190 | def __call__(self) -> T: 191 | """Unwrap the main dependency invoking the underlying callable.""" 192 | return self.main.unwrap() 193 | 194 | @pydantic.model_validator(mode="before") 195 | @classmethod 196 | def _parse_main(cls, data: t.Any) -> t.Any: 197 | """Parse function metadata into node defaults.""" 198 | if inspect.isfunction(data) or isinstance(data, injector.Dependency): 199 | data = {"main": data} 200 | if isinstance(data, dict): 201 | dep = data["main"] 202 | if isinstance(dep, dict): 203 | func = dep["factory"] 204 | if dep.get("alias", None): 205 | data.setdefault("name", dep["alias"]) 206 | elif isinstance(dep, injector.Dependency): 207 | func = dep.factory 208 | if dep.alias: 209 | data.setdefault("name", dep.alias) 210 | else: 211 | func = dep 212 | return {**_parse_metadata_from_callable(func), **data} 213 | return data 214 | 215 | @pydantic.field_validator("main", mode="before") 216 | @classmethod 217 | def _ensure_dependency(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any: 218 | """Ensure the main function is a dependency.""" 219 | value = _unwrap_entrypoint(value) 220 | if isinstance(value, (dict, injector.Dependency)): 221 | parsed_dep = injector.Dependency.model_validate(value, context=info.context) 222 | else: 223 | parsed_dep = injector.Dependency.wrap(value) 224 | # NOTE: We do this extra round-trip to bypass the unecessary Generic type check in pydantic 225 | return parsed_dep.model_dump() 226 | 227 | @pydantic.model_validator(mode="after") 228 | def _bind_main(self, info: pydantic.ValidationInfo) -> t.Any: 229 | """Bind the active workspace to the main function.""" 230 | self.main.map(_get_bind_func(info), idempotent=True) 231 | return self 232 | 233 | def __str__(self): 234 | return f"" 235 | 236 | 237 | class Entrypoint(_Node, t.Generic[T], frozen=True): 238 | """An entrypoint representing an invokeable set of functions.""" 239 | 240 | main: t.Callable[..., T] 241 | """The main function associated with the entrypoint.""" 242 | 243 | name: str 244 | """The name of the entrypoint. 245 | 246 | This is used to register the entrypoint in the workspace and CLI. Names must be 247 | unique within the workspace. The name can contain spaces and special characters. 248 | """ 249 | 250 | @pydantic.model_validator(mode="before") 251 | @classmethod 252 | def _parse_main(cls, data: t.Any) -> t.Any: 253 | """Parse function metadata into node defaults.""" 254 | if inspect.isfunction(data): 255 | data = {"main": data} 256 | if isinstance(data, dict): 257 | func = _unwrap_entrypoint(data["main"]) 258 | return {**_parse_metadata_from_callable(func), **data} 259 | return data 260 | 261 | @pydantic.field_validator("main", mode="before") 262 | @classmethod 263 | def _bind_main(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any: 264 | """Bind the active workspace to the main function.""" 265 | return _get_bind_func(info)(_unwrap_entrypoint(value)) 266 | 267 | def __str__(self): 268 | return f"" 269 | 270 | def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: 271 | """Invoke the entrypoint.""" 272 | return self.main(*args, **kwargs) 273 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/cdf/core/workspace.py: -------------------------------------------------------------------------------- 1 | """A workspace is a container for components and configurations.""" 2 | 3 | import os 4 | import time 5 | import typing as t 6 | from functools import cached_property, partialmethod 7 | from pathlib import Path 8 | 9 | import pydantic 10 | from typing_extensions import ParamSpec, Self 11 | 12 | import cdf.core.component as cmp 13 | import cdf.core.configuration as conf 14 | import cdf.core.context as ctx 15 | import cdf.core.injector as injector 16 | 17 | if t.TYPE_CHECKING: 18 | import click 19 | import sqlmesh 20 | 21 | 22 | T = t.TypeVar("T") 23 | P = ParamSpec("P") 24 | 25 | __all__ = ["Workspace"] 26 | 27 | 28 | class Workspace(pydantic.BaseModel, frozen=True): 29 | """A CDF workspace that allows for dependency injection and configuration resolution.""" 30 | 31 | name: str = "default" 32 | """A human-readable name for the workspace.""" 33 | version: str = "0.1.0" 34 | """A semver version string for the workspace.""" 35 | environment: str = pydantic.Field( 36 | default_factory=lambda: os.getenv("CDF_ENVIRONMENT", "dev") 37 | ) 38 | """The runtime environment used to resolve configuration.""" 39 | conf_resolver: conf.ConfigResolver = pydantic.Field( 40 | default_factory=conf.ConfigResolver 41 | ) 42 | """The configuration resolver for the workspace.""" 43 | container: injector.DependencyRegistry = pydantic.Field( 44 | default_factory=injector.DependencyRegistry 45 | ) 46 | """The dependency injection container for the workspace.""" 47 | configuration_sources: t.Iterable[conf.ConfigSource] = ( 48 | "cdf.toml", 49 | "cdf.yaml", 50 | "cdf.json", 51 | "~/.cdf.toml", 52 | ) 53 | """A list of configuration sources resolved and merged by the workspace.""" 54 | services_: t.Iterable[cmp.ServiceDef] = pydantic.Field( 55 | default_factory=tuple, alias="services" 56 | ) 57 | """An iterable of raw service definitions that the workspace provides.""" 58 | pipelines_: t.Iterable[cmp.DataPipelineDef] = pydantic.Field( 59 | default_factory=tuple, alias="pipelines" 60 | ) 61 | """An iterable of raw pipeline definitions that the workspace provides.""" 62 | publishers_: t.Iterable[cmp.DataPublisherDef] = pydantic.Field( 63 | default_factory=tuple, 
alias="publishers" 64 | ) 65 | """An iterable of raw publisher definitions that the workspace provides.""" 66 | operations_: t.Iterable[cmp.OperationDef] = pydantic.Field( 67 | default_factory=tuple, alias="operations" 68 | ) 69 | """An iterable of raw generic operation definitions that the workspace provides.""" 70 | 71 | # TODO: define an adapter for transformation providers 72 | sqlmesh_path: t.Optional[t.Union[str, Path]] = None 73 | """The path to the sqlmesh root for the workspace.""" 74 | sqlmesh_context_kwargs: t.Dict[str, t.Any] = {} 75 | """Keyword arguments to pass to the sqlmesh context.""" 76 | 77 | if t.TYPE_CHECKING: 78 | # PERF: this is a workaround for pydantic not being able to build a model with a forwardref 79 | # we still want the deferred import at runtime 80 | sqlmesh_context_class: t.Optional[t.Type[sqlmesh.Context]] = None 81 | """A custom context class to use for sqlmesh.""" 82 | else: 83 | sqlmesh_context_class: t.Optional[t.Type[t.Any]] = None 84 | """A custom context class to use for sqlmesh.""" 85 | 86 | @pydantic.model_validator(mode="after") 87 | def _setup(self) -> Self: 88 | """Initialize the workspace.""" 89 | for source in self.configuration_sources: 90 | self.conf_resolver.import_source(source) 91 | self.conf_resolver.set_environment(self.environment) 92 | self.container.add_from_dependency( 93 | injector.Dependency.instance(self), 94 | key="cdf_workspace", 95 | override=True, 96 | ) 97 | self.container.add_from_dependency( 98 | injector.Dependency.instance(self.environment), 99 | key="cdf_environment", 100 | override=True, 101 | ) 102 | self.container.add_from_dependency( 103 | injector.Dependency.instance(self.conf_resolver), 104 | key="cdf_config", 105 | override=True, 106 | ) 107 | self.container.add_from_dependency( 108 | injector.Dependency.singleton(self.get_sqlmesh_context_or_raise), 109 | key="cdf_transform", 110 | override=True, 111 | ) 112 | for service in self.services.values(): 113 | self.container.add_from_dependency(service.main, key=service.name) 114 | self.activate() 115 | return self 116 | 117 | def activate(self) -> Self: 118 | """Activate the workspace for the current context.""" 119 | ctx.set_active_workspace(self) 120 | return self 121 | 122 | def _parse_definitions( 123 | self, defs: t.Iterable[cmp.TComponentDef], into: t.Type[cmp.TComponent] 124 | ) -> t.Dict[str, cmp.TComponent]: 125 | """Parse a list of component definitions into a lookup.""" 126 | components = {} 127 | with ctx.use_workspace(self): 128 | for definition in defs: 129 | component = into.model_validate(definition, context={"parent": self}) 130 | components[component.name] = component 131 | return components 132 | 133 | @cached_property 134 | def services(self) -> t.Dict[str, cmp.Service]: 135 | """Return the resolved services of the workspace.""" 136 | return self._parse_definitions(self.services_, cmp.Service) 137 | 138 | @cached_property 139 | def pipelines(self) -> t.Dict[str, cmp.DataPipeline]: 140 | """Return the resolved data pipelines of the workspace.""" 141 | return self._parse_definitions(self.pipelines_, cmp.DataPipeline) 142 | 143 | @cached_property 144 | def publishers(self) -> t.Dict[str, cmp.DataPublisher]: 145 | """Return the resolved data publishers of the workspace.""" 146 | return self._parse_definitions(self.publishers_, cmp.DataPublisher) 147 | 148 | @cached_property 149 | def operations(self) -> t.Dict[str, cmp.Operation]: 150 | """Return the resolved operations of the workspace.""" 151 | return self._parse_definitions(self.operations_, 
cmp.Operation) 152 | 153 | @t.overload 154 | def get_sqlmesh_context( 155 | self, 156 | gateway: t.Optional[str], 157 | must_exist: t.Literal[False], 158 | **kwargs: t.Any, 159 | ) -> t.Optional["sqlmesh.Context"]: ... 160 | 161 | @t.overload 162 | def get_sqlmesh_context( 163 | self, 164 | gateway: t.Optional[str], 165 | must_exist: t.Literal[True], 166 | **kwargs: t.Any, 167 | ) -> "sqlmesh.Context": ... 168 | 169 | def get_sqlmesh_context( 170 | self, gateway: t.Optional[str] = None, must_exist: bool = False, **kwargs: t.Any 171 | ) -> t.Optional["sqlmesh.Context"]: 172 | """Return the transform context or raise an error if not defined.""" 173 | import sqlmesh 174 | 175 | if self.sqlmesh_path is None: 176 | if must_exist: 177 | raise ValueError("Transformation provider not defined.") 178 | return None 179 | 180 | kwargs = {**self.sqlmesh_context_kwargs, **kwargs} 181 | with ctx.use_workspace(self): 182 | klass = self.sqlmesh_context_class or sqlmesh.Context 183 | return klass(paths=[self.sqlmesh_path], gateway=gateway, **kwargs) 184 | 185 | if t.TYPE_CHECKING: 186 | 187 | def get_sqlmesh_context_or_raise( 188 | self, gateway: t.Optional[str] = None, **kwargs: t.Any 189 | ) -> "sqlmesh.Context": ... 190 | 191 | else: 192 | get_sqlmesh_context_or_raise = partialmethod( 193 | get_sqlmesh_context, must_exist=True 194 | ) 195 | 196 | @property 197 | def cli(self) -> "click.Group": 198 | """Dynamically generate a CLI entrypoint for the workspace.""" 199 | import click 200 | 201 | @click.group() 202 | def cli() -> None: 203 | """A dynamically generated CLI for the workspace.""" 204 | self.activate() 205 | 206 | def _list(d: t.Dict[str, cmp.TComponent], verbose: bool = False) -> None: 207 | for name in sorted(d.keys()): 208 | if verbose: 209 | click.echo(d[name].model_dump_json(indent=2, exclude={"main"})) 210 | else: 211 | click.echo(d[name]) 212 | 213 | for k in ("services", "pipelines", "publishers", "operations"): 214 | cli.command( 215 | f"list-{k}", help=f"List the {k} in the {self.name} workspace." 
216 | )( 217 | click.option("-v", "--verbose", is_flag=True)( 218 | lambda verbose=False, k=k: _list(getattr(self, k), verbose=verbose) 219 | ) 220 | ) 221 | 222 | @cli.command("run-pipeline") 223 | @click.argument( 224 | "pipeline_name", 225 | required=False, 226 | type=click.Choice(list(self.pipelines.keys())), 227 | ) 228 | @click.option( 229 | "--test", 230 | is_flag=True, 231 | help="Run the pipelines integration test if defined.", 232 | ) 233 | @click.pass_context 234 | def run_pipeline( 235 | ctx: click.Context, 236 | pipeline_name: t.Optional[str] = None, 237 | test: bool = False, 238 | ) -> None: 239 | """Run a data pipeline.""" 240 | if pipeline_name is None: 241 | pipeline_name = click.prompt( 242 | "Enter a pipeline", 243 | type=click.Choice(list(self.pipelines.keys())), 244 | show_choices=True, 245 | ) 246 | if pipeline_name is None: 247 | raise click.BadParameter( 248 | "Pipeline must be specified.", ctx=ctx, param_hint="pipeline" 249 | ) 250 | 251 | pipeline = self.pipelines[pipeline_name] 252 | 253 | if test: 254 | click.echo("Running pipeline tests.", err=True) 255 | try: 256 | pipeline.run_tests() 257 | except Exception as e: 258 | click.echo(f"Pipeline test(s) failed: {e}", err=True) 259 | ctx.exit(1) 260 | else: 261 | click.echo("Pipeline test(s) passed!", err=True) 262 | ctx.exit(0) 263 | 264 | start = time.time() 265 | try: 266 | jobs = pipeline() 267 | except Exception as e: 268 | click.echo( 269 | f"Pipeline failed after {time.time() - start:.2f} seconds: {e}", 270 | err=True, 271 | ) 272 | ctx.exit(1) 273 | 274 | click.echo( 275 | f"Pipeline process finished in {time.time() - start:.2f} seconds.", 276 | err=True, 277 | ) 278 | 279 | for job in jobs: 280 | if job.has_failed_jobs: 281 | ctx.fail("Pipeline failed.") 282 | 283 | ctx.exit(0) 284 | 285 | @cli.command("run-publisher") 286 | @click.argument( 287 | "publisher_name", 288 | required=False, 289 | type=click.Choice(list(self.publishers.keys())), 290 | ) 291 | @click.option( 292 | "--test", 293 | is_flag=True, 294 | help="Run the publishers integration test if defined.", 295 | ) 296 | @click.pass_context 297 | def run_publisher( 298 | ctx: click.Context, 299 | publisher_name: t.Optional[str] = None, 300 | test: bool = False, 301 | ) -> None: 302 | """Run a data publisher.""" 303 | if publisher_name is None: 304 | publisher_name = click.prompt( 305 | "Enter a publisher", 306 | type=click.Choice(list(self.publishers.keys())), 307 | show_choices=True, 308 | ) 309 | if publisher_name is None: 310 | raise click.BadParameter( 311 | "Publisher must be specified.", ctx=ctx, param_hint="publisher" 312 | ) 313 | 314 | publisher = self.publishers[publisher_name] 315 | 316 | start = time.time() 317 | try: 318 | publisher() 319 | except Exception as e: 320 | click.echo( 321 | f"Publisher failed after {time.time() - start:.2f} seconds: {e}", 322 | err=True, 323 | ) 324 | ctx.exit(1) 325 | 326 | click.echo( 327 | f"Publisher process finished in {time.time() - start:.2f} seconds.", 328 | err=True, 329 | ) 330 | ctx.exit(0) 331 | 332 | @cli.command("run-operation") 333 | @click.argument( 334 | "operation_name", 335 | required=False, 336 | type=click.Choice(list(self.operations.keys())), 337 | ) 338 | @click.pass_context 339 | def run_operation( 340 | ctx: click.Context, operation_name: t.Optional[str] = None 341 | ) -> int: 342 | """Run a generic operation.""" 343 | if operation_name is None: 344 | operation_name = click.prompt( 345 | "Enter an operation", 346 | type=click.Choice(list(self.operations.keys())), 347 | 
show_choices=True, 348 | ) 349 | if operation_name is None: 350 | raise click.BadParameter( 351 | "Operation must be specified.", ctx=ctx, param_hint="operation" 352 | ) 353 | 354 | operation = self.operations[operation_name] 355 | 356 | ctx.exit(operation()) 357 | 358 | return cli 359 | 360 | def bind(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: 361 | """Wrap a function with configuration and dependencies defined in the workspace.""" 362 | configured_f = self.conf_resolver.resolve_defaults(func_or_cls) 363 | return self.container.wire(configured_f) 364 | 365 | def invoke(self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any) -> T: 366 | """Invoke a function with configuration and dependencies defined in the workspace.""" 367 | with ctx.use_workspace(self): 368 | return self.bind(func_or_cls)(*args, **kwargs) 369 | -------------------------------------------------------------------------------- /src/cdf/integrations/slack.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import traceback 4 | import typing as t 5 | from datetime import datetime, timezone 6 | from enum import Enum 7 | from textwrap import dedent, indent 8 | 9 | import requests 10 | 11 | SLACK_MAX_TEXT_LENGTH = 3000 12 | SLACK_MAX_ALERT_PREVIEW_BLOCKS = 5 13 | SLACK_MAX_ATTACHMENTS_BLOCKS = 50 14 | CONTINUATION_SYMBOL = "..." 15 | 16 | 17 | TSlackBlock = t.Dict[str, t.Any] 18 | 19 | 20 | class TSlackBlocks(t.TypedDict): 21 | blocks: t.List[TSlackBlock] 22 | 23 | 24 | class TSlackMessage(TSlackBlocks, t.TypedDict): 25 | attachments: t.List[TSlackBlocks] 26 | 27 | 28 | class SlackMessageComposer: 29 | """Builds Slack message with primary and secondary blocks""" 30 | 31 | def __init__(self, initial_message: t.Optional[TSlackMessage] = None) -> None: 32 | """Initialize the Slack message builder""" 33 | self.slack_message = initial_message or { 34 | "blocks": [], 35 | "attachments": [{"blocks": []}], 36 | } 37 | 38 | def add_primary_blocks(self, *blocks: TSlackBlock) -> "SlackMessageComposer": 39 | """Add blocks to the message. Blocks are always displayed""" 40 | self.slack_message["blocks"].extend(blocks) 41 | return self 42 | 43 | def add_secondary_blocks(self, *blocks: TSlackBlock) -> "SlackMessageComposer": 44 | """Add attachments to the message 45 | 46 | Attachments are hidden behind "show more" button. The first 5 attachments 47 | are always displayed. NOTICE: attachments blocks are deprecated by Slack 48 | """ 49 | self.slack_message["attachments"][0]["blocks"].extend(blocks) 50 | if ( 51 | len(self.slack_message["attachments"][0]["blocks"]) 52 | >= SLACK_MAX_ATTACHMENTS_BLOCKS 53 | ): 54 | raise ValueError("Too many attachments") 55 | return self 56 | 57 | def _introspect(self) -> "SlackMessageComposer": 58 | """Print the message to stdout 59 | 60 | This is a debugging method. 
Useful during composition of the message.""" 61 | print(json.dumps(self.slack_message, indent=2)) 62 | return self 63 | 64 | 65 | def normalize_message(message: t.Union[str, t.List[str], t.Iterable[str]]) -> str: 66 | """Normalize message to fit Slack's max text length""" 67 | if isinstance(message, (list, tuple, set)): 68 | message = stringify_list(list(message)) 69 | assert isinstance(message, str), f"Message must be a string, got {type(message)}" 70 | dedented_message = dedent(message) 71 | if len(dedented_message) < SLACK_MAX_TEXT_LENGTH: 72 | return dedented_message 73 | return dedent( 74 | dedented_message[: SLACK_MAX_TEXT_LENGTH - len(CONTINUATION_SYMBOL) - 3] 75 | + CONTINUATION_SYMBOL 76 | + dedented_message[-3:] 77 | ) 78 | 79 | 80 | def divider_block() -> dict: 81 | """Create a divider block""" 82 | return {"type": "divider"} 83 | 84 | 85 | def fields_section_block(*messages: str) -> dict: 86 | """Create a section block with multiple fields""" 87 | return { 88 | "type": "section", 89 | "fields": [ 90 | { 91 | "type": "mrkdwn", 92 | "text": normalize_message(message), 93 | } 94 | for message in messages 95 | ], 96 | } 97 | 98 | 99 | def text_section_block(message: str) -> dict: 100 | """Create a section block with text""" 101 | return { 102 | "type": "section", 103 | "text": { 104 | "type": "mrkdwn", 105 | "text": normalize_message(message), 106 | }, 107 | } 108 | 109 | 110 | def empty_section_block() -> dict: 111 | """Create an empty section block""" 112 | return { 113 | "type": "section", 114 | "text": { 115 | "type": "mrkdwn", 116 | "text": normalize_message("\t"), 117 | }, 118 | } 119 | 120 | 121 | def context_block(*messages: str) -> dict: 122 | """Create a context block with multiple fields""" 123 | return { 124 | "type": "context", 125 | "elements": [ 126 | { 127 | "type": "mrkdwn", 128 | "text": normalize_message(message), 129 | } 130 | for message in messages 131 | ], 132 | } 133 | 134 | 135 | def header_block(message: str) -> dict: 136 | """Create a header block""" 137 | return { 138 | "type": "header", 139 | "text": { 140 | "type": "plain_text", 141 | "text": message, 142 | }, 143 | } 144 | 145 | 146 | def button_action_block(text: str, url: str) -> dict: 147 | """Create a button action block""" 148 | return { 149 | "type": "actions", 150 | "elements": [ 151 | { 152 | "type": "button", 153 | "text": {"type": "plain_text", "text": text, "emoji": True}, 154 | "value": text, 155 | "url": url, 156 | } 157 | ], 158 | } 159 | 160 | 161 | def compacted_sections_blocks(*messages: t.Union[str, t.Iterable[str]]) -> t.List[dict]: 162 | """Create a list of compacted sections blocks""" 163 | return [ 164 | { 165 | "type": "section", 166 | "fields": [ 167 | { 168 | "type": "mrkdwn", 169 | "text": normalize_message(message), 170 | } 171 | for message in messages[i : i + 2] 172 | ], 173 | } 174 | for i in range(0, len(messages), 2) 175 | ] 176 | 177 | 178 | class SlackAlertIcon(str, Enum): 179 | """Enum for status of the alert""" 180 | 181 | # simple statuses 182 | OK = ":large_green_circle:" 183 | WARN = ":warning:" 184 | ERROR = ":x:" 185 | START = ":arrow_forward:" 186 | ALERT = ":rotating_light:" 187 | STOP = ":stop_button:" 188 | 189 | # log levels 190 | UNKNOWN = ":question:" 191 | INFO = ":information_source:" 192 | DEBUG = ":beetle:" 193 | CRITICAL = ":fire:" 194 | FATAL = ":skull_and_crossbones:" 195 | EXCEPTION = ":boom:" 196 | 197 | # test statuses 198 | FAILURE = ":no_entry_sign:" 199 | SUCCESS = ":white_check_mark:" 200 | WARNING = ":warning:" 201 | SKIPPED = 
":fast_forward:" 202 | PASSED = ":white_check_mark:" 203 | 204 | def __str__(self) -> str: 205 | return self.value 206 | 207 | 208 | def stringify_list(list_variation: t.Union[t.List[str], str]) -> str: 209 | """Prettify and deduplicate list of strings converting it to a newline delimited string""" 210 | if isinstance(list_variation, str): 211 | return list_variation 212 | if len(list_variation) == 1: 213 | return list_variation[0] 214 | list_variation = list(list_variation) 215 | for i, item in enumerate(list_variation): 216 | if not isinstance(item, str): 217 | list_variation[i] = str(item) 218 | order = {item: i for i, item in reversed(list(enumerate(list_variation)))} 219 | return "\n".join(sorted(set(list_variation), key=lambda item: order[item])) 220 | 221 | 222 | def send_basic_slack_message( 223 | incoming_hook: str, message: str, is_markdown: bool = True 224 | ) -> None: 225 | """Sends a `message` to Slack `incoming_hook`, by default formatted as markdown.""" 226 | resp = requests.post( 227 | incoming_hook, 228 | data=json.dumps({"text": message, "mrkdwn": is_markdown}).encode("utf-8"), 229 | headers={"Content-Type": "application/json;charset=utf-8"}, 230 | ) 231 | resp.raise_for_status() 232 | 233 | 234 | def send_extract_start_slack_message( 235 | incoming_hook: str, 236 | source: str, 237 | run_id: str, 238 | tags: t.List[str], 239 | owners: t.List[str], 240 | environment: str, 241 | resources_selected: t.List[str], 242 | resources_count: int, 243 | ) -> None: 244 | """Sends a Slack message for the start of an extract""" 245 | resp = requests.post( 246 | incoming_hook, 247 | json=SlackMessageComposer() 248 | .add_primary_blocks( 249 | header_block(f"{SlackAlertIcon.START} Starting Extract (id: {run_id})"), 250 | context_block( 251 | "*Source:* {source} |".format(source=source), 252 | "*Status:* Starting Extraction |", 253 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 254 | ), 255 | divider_block(), 256 | *compacted_sections_blocks( 257 | ("*Tags*", stringify_list(tags)), 258 | ("*Owners*", stringify_list(owners)), 259 | ), 260 | *compacted_sections_blocks( 261 | ("*Environment*", environment), 262 | ( 263 | "*Resources*", 264 | f"{len(resources_selected)}/{resources_count} selected", 265 | ), 266 | ), 267 | divider_block(), 268 | text_section_block( 269 | f""" 270 | Resources selected for extraction :test_tube: 271 | 272 | {stringify_list(resources_selected)} 273 | """ 274 | ), 275 | button_action_block("View in Harness", url="https://app.harness.io"), 276 | context_block(f"*Python Version:* {sys.version}"), 277 | ) 278 | .slack_message, 279 | ) 280 | resp.raise_for_status() 281 | 282 | 283 | def send_extract_failure_message( 284 | incoming_hook: str, source: str, run_id: str, duration: float, error: Exception 285 | ) -> None: 286 | """Sends a Slack message for the failure of an extract""" 287 | # trace = "\n".join(f"> {line}" for line in traceback.format_exc().splitlines()) 288 | resp = requests.post( 289 | incoming_hook, 290 | json=SlackMessageComposer() 291 | .add_primary_blocks( 292 | header_block(f"{SlackAlertIcon.ERROR} Extract Failed (id: {run_id})"), 293 | context_block( 294 | "*Source:* {source} |".format(source=source), 295 | "*Status:* Extraction Failed |", 296 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 297 | ), 298 | divider_block(), 299 | text_section_block( 300 | f""" 301 | Extract failed after {duration:.2f}s :fire: 302 | 303 | ``` 304 | {indent(traceback.format_exc(), " " * 12, lambda line: (not 
line.startswith("Traceback")))} 305 | ``` 306 | 307 | Please check the logs for more information. 308 | """ 309 | ), 310 | button_action_block("View in Harness", url="https://app.harness.io"), 311 | context_block(f"*Python Version:* {sys.version}"), 312 | ) 313 | .slack_message, 314 | ) 315 | resp.raise_for_status() 316 | 317 | 318 | def send_extract_success_message( 319 | incoming_hook: str, source: str, run_id: str, duration: float 320 | ) -> None: 321 | """Sends a Slack message for the success of an extract""" 322 | resp = requests.post( 323 | incoming_hook, 324 | json=SlackMessageComposer() 325 | .add_primary_blocks( 326 | header_block(f"{SlackAlertIcon.OK} Extract Succeeded (id: {run_id})"), 327 | context_block( 328 | "*Source:* {source} |".format(source=source), 329 | "*Status:* Extraction Succeeded |", 330 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 331 | ), 332 | divider_block(), 333 | text_section_block( 334 | f""" 335 | Extract succeeded after {duration:.2f}s :tada: 336 | 337 | Please check the logs for more information. 338 | """ 339 | ), 340 | button_action_block("View in Harness", url="https://app.harness.io"), 341 | context_block(f"*Python Version:* {sys.version}"), 342 | ) 343 | .slack_message, 344 | ) 345 | resp.raise_for_status() 346 | 347 | 348 | def send_normalize_start_slack_message( 349 | incoming_hook: str, 350 | source: str, 351 | blob_name: str, 352 | run_id: str, 353 | environment: str, 354 | ) -> None: 355 | """Sends a Slack message for the start of a normalization""" 356 | _ = environment 357 | resp = requests.post( 358 | incoming_hook, 359 | json=SlackMessageComposer() 360 | .add_primary_blocks( 361 | header_block(f"{SlackAlertIcon.START} Normalizing (id: {run_id})"), 362 | context_block( 363 | "*Source:* {source} |".format(source=source), 364 | "*Status:* Starting Normalization |", 365 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 366 | ), 367 | divider_block(), 368 | text_section_block( 369 | f""" 370 | Pending load package discovered in stage :package: 371 | 372 | Starting normalization for: :file_folder: 373 | 374 | `{blob_name}` 375 | """ 376 | ), 377 | context_block(f"*Python Version:* {sys.version}"), 378 | ) 379 | .slack_message, 380 | ) 381 | resp.raise_for_status() 382 | 383 | 384 | def send_normalize_failure_message( 385 | incoming_hook: str, 386 | source: str, 387 | blob_name: str, 388 | run_id: str, 389 | duration: float, 390 | error: Exception, 391 | ) -> None: 392 | """Sends a Slack message for the failure of a normalization""" 393 | # trace = "\n".join(f"> {line}" for line in traceback.format_exc().splitlines()) 394 | resp = requests.post( 395 | incoming_hook, 396 | json=SlackMessageComposer() 397 | .add_primary_blocks( 398 | header_block(f"{SlackAlertIcon.ERROR} Normalization Failed (id: {run_id})"), 399 | context_block( 400 | "*Source:* {source} |".format(source=source), 401 | "*Status:* Normalization Failed |", 402 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 403 | ), 404 | divider_block(), 405 | text_section_block( 406 | f""" 407 | Normalization failed after {duration:.2f}s :fire: 408 | 409 | ``` 410 | {indent(traceback.format_exc(), " " * 12, lambda line: (not line.startswith("Traceback")))} 411 | ``` 412 | 413 | Please check the pod logs for more information. 
414 | """ 415 | ), 416 | context_block(f"*Python Version:* {sys.version}"), 417 | ) 418 | .slack_message, 419 | ) 420 | resp.raise_for_status() 421 | 422 | 423 | def send_normalization_success_message( 424 | incoming_hook: str, source: str, blob_name: str, run_id: str, duration: float 425 | ) -> None: 426 | """Sends a Slack message for the success of an normalization""" 427 | resp = requests.post( 428 | incoming_hook, 429 | json=SlackMessageComposer() 430 | .add_primary_blocks( 431 | header_block(f"{SlackAlertIcon.OK} Normalization Succeeded (id: {run_id})"), 432 | context_block( 433 | "*Source:* {source} |".format(source=source), 434 | "*Status:* Normalization Succeeded |", 435 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 436 | ), 437 | divider_block(), 438 | text_section_block( 439 | f""" 440 | Normalization took {duration:.2f}s :tada: 441 | 442 | The package was normalized successfully: :file_folder: 443 | 444 | `{blob_name}` 445 | 446 | This package is now prepared for loading. 447 | """ 448 | ), 449 | ) 450 | .slack_message, 451 | ) 452 | resp.raise_for_status() 453 | 454 | 455 | def send_load_start_slack_message( 456 | incoming_hook: str, 457 | source: str, 458 | destination: str, 459 | dataset: str, 460 | run_id: str, 461 | ) -> None: 462 | """Sends a Slack message for the start of a load""" 463 | resp = requests.post( 464 | incoming_hook, 465 | json=SlackMessageComposer() 466 | .add_primary_blocks( 467 | header_block(f"{SlackAlertIcon.START} Loading (id: {run_id})"), 468 | context_block( 469 | "*Source:* {source} |".format(source=source), 470 | "*Status:* Starting Load |", 471 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 472 | ), 473 | divider_block(), 474 | *compacted_sections_blocks( 475 | ("*Destination*", destination), 476 | ("*Dataset*", dataset), 477 | ), 478 | context_block(f"*Python Version:* {sys.version}"), 479 | ) 480 | .slack_message, 481 | ) 482 | resp.raise_for_status() 483 | 484 | 485 | def send_load_failure_message( 486 | incoming_hook: str, 487 | source: str, 488 | destination: str, 489 | dataset: str, 490 | run_id: str, 491 | ) -> None: 492 | """Sends a Slack message for the failure of an load""" 493 | # trace = "\n".join(f"> {line}" for line in traceback.format_exc().splitlines()) 494 | resp = requests.post( 495 | incoming_hook, 496 | json=SlackMessageComposer() 497 | .add_primary_blocks( 498 | header_block(f"{SlackAlertIcon.ERROR} Load Failed (id: {run_id})"), 499 | context_block( 500 | "*Source:* {source} |".format(source=source), 501 | "*Status:* Normalization Failed |", 502 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 503 | ), 504 | divider_block(), 505 | text_section_block( 506 | f""" 507 | Load to {destination} dataset named {dataset} failed :fire: 508 | 509 | ``` 510 | {indent(traceback.format_exc(), " " * 12, lambda line: (not line.startswith("Traceback")))} 511 | ``` 512 | 513 | Please check the pod logs for more information. 
514 | """ 515 | ), 516 | context_block(f"*Python Version:* {sys.version}"), 517 | ) 518 | .slack_message, 519 | ) 520 | resp.raise_for_status() 521 | 522 | 523 | def send_load_success_message( 524 | incoming_hook: str, 525 | source: str, 526 | destination: str, 527 | dataset: str, 528 | run_id: str, 529 | payload: str, 530 | ) -> None: 531 | """Sends a Slack message for the success of an normalization""" 532 | resp = requests.post( 533 | incoming_hook, 534 | json=SlackMessageComposer() 535 | .add_primary_blocks( 536 | header_block(f"{SlackAlertIcon.OK} Load Succeeded (id: {run_id})"), 537 | context_block( 538 | "*Source:* {source} |".format(source=source), 539 | "*Status:* Loading Succeeded |", 540 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")), 541 | ), 542 | divider_block(), 543 | *compacted_sections_blocks( 544 | ("*Destination*", destination), 545 | ("*Dataset*", dataset), 546 | ), 547 | divider_block(), 548 | text_section_block( 549 | f""" 550 | The package was loaded successfully: :file_folder: 551 | 552 | ``` 553 | {payload} 554 | ``` 555 | """ 556 | ), 557 | context_block(f"*Python Version:* {sys.version}"), 558 | ) 559 | .slack_message, 560 | ) 561 | resp.raise_for_status() 562 | -------------------------------------------------------------------------------- /src/cdf/core/configuration.py: -------------------------------------------------------------------------------- 1 | """Configuration utilities for the CDF configuration resolver system. 2 | 3 | There are 3 ways to request configuration values: 4 | 5 | 1. Using a Request annotation: 6 | 7 | Pro: It's explicit and re-usable. An annotation can be used in multiple places. 8 | 9 | ```python 10 | import typing as t 11 | import cdf.core.configuration as conf 12 | 13 | def foo(bar: t.Annotated[str, conf.Request["api.key"]]) -> None: 14 | print(bar) 15 | ``` 16 | 17 | 2. Setting a __cdf_resolve__ attribute on a callable object. This can be done 18 | directly or by using the `map_section` or `map_values` decorators: 19 | 20 | Pro: It's concise and can be used in a decorator. It also works with classes. 21 | 22 | ```python 23 | import cdf.core.configuration as conf 24 | 25 | @conf.map_section("api") 26 | def foo(key: str) -> None: 27 | print(key) 28 | 29 | @conf.map_values(key="api.key") 30 | def bar(key: str) -> None: 31 | print(key) 32 | 33 | def baz(key: str) -> None: 34 | print(key) 35 | 36 | baz.__cdf_resolve__ = ("api",) 37 | ``` 38 | 39 | 3. Using the `_cdf_resolve` kwarg to request the resolver: 40 | 41 | Pro: It's flexible and can be used in any function. It requires no imports. 
42 | 43 | ```python 44 | def foo(key: str, _cdf_resolve=("api",)) -> None: 45 | print(key) 46 | 47 | def bar(key: str, _cdf_resolve={"key": "api.key"}) -> None: 48 | print(key) 49 | ``` 50 | """ 51 | 52 | import ast 53 | import functools 54 | import inspect 55 | import json 56 | import logging 57 | import os 58 | import re 59 | import string 60 | import typing as t 61 | from collections import ChainMap 62 | from contextlib import suppress 63 | from pathlib import Path 64 | 65 | import pydantic 66 | import pydantic_core 67 | from typing_extensions import ParamSpec 68 | 69 | if t.TYPE_CHECKING: 70 | from dynaconf.vendor.box import Box 71 | 72 | from cdf.types import M 73 | 74 | logger = logging.getLogger(__name__) 75 | 76 | T = t.TypeVar("T") 77 | P = ParamSpec("P") 78 | 79 | __all__ = [ 80 | "ConfigLoader", 81 | "ConfigResolver", 82 | "ConfigSource", 83 | "Request", 84 | "add_custom_converter", 85 | "remove_converter", 86 | "load_file", 87 | "map_config_section", 88 | "map_config_values", 89 | ] 90 | 91 | 92 | def load_file(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]: 93 | """Load a configuration from a file path. 94 | 95 | Args: 96 | path: The file path. 97 | 98 | Returns: 99 | A Result monad with the configuration dictionary if the file format is JSON, YAML or TOML. 100 | Otherwise, a Result monad with an error. 101 | """ 102 | if path.suffix == ".json": 103 | return _load_json(path) 104 | if path.suffix in (".yaml", ".yml"): 105 | return _load_yaml(path) 106 | if path.suffix == ".toml": 107 | return _load_toml(path) 108 | return M.error(ValueError("Invalid file format, must be JSON, YAML or TOML")) 109 | 110 | 111 | def _load_json(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]: 112 | """Load a configuration from a JSON file. 113 | 114 | Args: 115 | path: The file path to a valid JSON document. 116 | 117 | Returns: 118 | A Result monad with the configuration dictionary if the file format is JSON. Otherwise, a 119 | Result monad with an error. 120 | """ 121 | try: 122 | return M.ok(json.loads(path.read_text())) 123 | except Exception as e: 124 | return M.error(e) 125 | 126 | 127 | def _load_yaml(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]: 128 | """Load a configuration from a YAML file. 129 | 130 | Args: 131 | path: The file path to a valid YAML document. 132 | 133 | Returns: 134 | A Result monad with the configuration dictionary if the file format is YAML. Otherwise, a 135 | Result monad with an error. 136 | """ 137 | try: 138 | import ruamel.yaml as yaml 139 | 140 | yaml_ = yaml.YAML() 141 | return M.ok(yaml_.load(path)) 142 | except Exception as e: 143 | return M.error(e) 144 | 145 | 146 | def _load_toml(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]: 147 | """Load a configuration from a TOML file. 148 | 149 | Args: 150 | path: The file path to a valid TOML document. 151 | 152 | Returns: 153 | A Result monad with the configuration dictionary if the file format is TOML. Otherwise, a 154 | Result monad with an error. 
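The parsed tomlkit document is unwrapped to plain Python types before being returned.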
155 | """ 156 | try: 157 | import tomlkit 158 | 159 | return M.ok(tomlkit.loads(path.read_text()).unwrap()) 160 | except Exception as e: 161 | return M.error(e) 162 | 163 | 164 | def _to_bool(value: str) -> bool: 165 | """Convert a string to a boolean.""" 166 | return value.lower() in ["true", "1", "yes"] 167 | 168 | 169 | def _resolve_template(template: str, **overrides: t.Any) -> str: 170 | """Resolve a template string using environment variables.""" 171 | return string.Template(template).substitute(overrides, **os.environ) 172 | 173 | 174 | _CONVERTERS = { 175 | "json": json.loads, 176 | "int": int, 177 | "float": float, 178 | "str": str, 179 | "bool": _to_bool, 180 | "path": os.path.abspath, 181 | "dict": ast.literal_eval, 182 | "list": ast.literal_eval, 183 | "tuple": ast.literal_eval, 184 | "set": ast.literal_eval, 185 | "resolve": None, 186 | } 187 | """Converters for configuration values.""" 188 | 189 | _CONVERTER_PATTERN = re.compile(r"@(\w+) ", re.IGNORECASE) 190 | """Pattern to match converters in a string.""" 191 | 192 | 193 | def add_custom_converter(name: str, converter: t.Callable[[str], t.Any]) -> None: 194 | """Add a custom converter to the configuration system.""" 195 | if name in _CONVERTERS: 196 | raise ValueError(f"Converter {name} already exists.") 197 | _CONVERTERS[name] = converter 198 | 199 | 200 | def get_converter(name: str) -> t.Callable[[str], t.Any]: 201 | """Get a custom converter from the configuration system.""" 202 | return _CONVERTERS[name] 203 | 204 | 205 | def remove_converter(name: str) -> None: 206 | """Remove a custom converter from the configuration system.""" 207 | if name not in _CONVERTERS: 208 | raise ValueError(f"Converter {name} does not exist.") 209 | del _CONVERTERS[name] 210 | 211 | 212 | def apply_converters( 213 | input_value: t.Any, resolver: t.Optional["ConfigResolver"] = None 214 | ) -> t.Any: 215 | """Apply converters to a string.""" 216 | if not isinstance(input_value, str): 217 | return input_value 218 | expanded_value = _resolve_template(input_value) 219 | converters = _CONVERTER_PATTERN.findall(expanded_value) 220 | if len(converters) == 0: 221 | return expanded_value 222 | base_value = _CONVERTER_PATTERN.sub("", expanded_value).lstrip() 223 | if not base_value: 224 | return None 225 | transformed_value = base_value 226 | for converter in reversed(converters): 227 | try: 228 | if converter.lower() == "resolve": 229 | if resolver is None: 230 | raise ValueError( 231 | "Resolver instance not provided but found @resolve converter" 232 | ) 233 | if transformed_value not in resolver: 234 | raise ValueError(f"Key not found in resolver: {transformed_value}") 235 | transformed_value = resolver[transformed_value] 236 | continue 237 | transformed_value = _CONVERTERS[converter.lower()](transformed_value) 238 | except KeyError as e: 239 | raise ValueError(f"Unknown converter: {converter}") from e 240 | except Exception as e: 241 | raise ValueError(f"Failed to convert value: {e}") from e 242 | return transformed_value 243 | 244 | 245 | def _to_box(mapping: t.Mapping[str, t.Any]) -> "Box": 246 | """Convert a mapping to a standardized Box.""" 247 | from dynaconf.vendor.box import Box 248 | 249 | return Box(mapping, box_dots=True) 250 | 251 | 252 | class _ConfigScopes(t.NamedTuple): 253 | """A struct to store named configuration scopes by precedence.""" 254 | 255 | explicit: "Box" 256 | """User-provided configuration passed as a dictionary.""" 257 | environment: "Box" 258 | """Environment-specific configuration loaded from a file.""" 259 | 
baseline: "Box" 260 | """Configuration loaded from a base config file.""" 261 | 262 | def resolve(self) -> "Box": 263 | """Resolve the configuration scopes.""" 264 | output = self.baseline 265 | output.merge_update(self.environment) 266 | output.merge_update(self.explicit) 267 | return output 268 | 269 | 270 | ConfigSource = t.Union[str, Path, t.Mapping[str, t.Any]] 271 | 272 | 273 | class ConfigLoader: 274 | """Load configuration from multiple sources.""" 275 | 276 | def __init__( 277 | self, 278 | *sources: ConfigSource, 279 | environment: str = "dev", 280 | ) -> None: 281 | """Initialize the configuration loader.""" 282 | self.environment = environment 283 | self.sources = list(sources) 284 | 285 | def load(self) -> t.MutableMapping[str, t.Any]: 286 | """Load configuration from sources.""" 287 | scopes = _ConfigScopes( 288 | explicit=_to_box({}), environment=_to_box({}), baseline=_to_box({}) 289 | ) 290 | for source in self.sources: 291 | if isinstance(source, dict): 292 | # User may provide configuration as a dictionary directly 293 | # in which case it takes precedence over other sources 294 | scopes.explicit.merge_update(source) 295 | elif isinstance(source, (str, Path)): 296 | # Load configuration from file 297 | path = Path(source) 298 | result = load_file(path) 299 | if result.is_ok(): 300 | scopes.baseline.merge_update(result.unwrap()) 301 | else: 302 | err = result.unwrap_err() 303 | if not isinstance(err, FileNotFoundError): 304 | logger.warning( 305 | f"Failed to load configuration from {path}: {result.unwrap_err()}" 306 | ) 307 | else: 308 | logger.debug(f"Configuration file not found: {path}") 309 | # Load environment-specific configuration from corresponding file 310 | # e.g. config.dev.json, config.dev.yaml, config.dev.toml 311 | env_path = path.with_name( 312 | f"{path.stem}.{self.environment}{path.suffix}" 313 | ) 314 | result = load_file(env_path) 315 | if result.is_ok(): 316 | scopes.environment.merge_update(result.unwrap()) 317 | else: 318 | err = result.unwrap_err() 319 | if not isinstance(err, FileNotFoundError): 320 | logger.warning( 321 | f"Failed to load configuration from {path}: {err}" 322 | ) 323 | else: 324 | logger.debug(f"Configuration file not found: {env_path}") 325 | return scopes.resolve() 326 | 327 | def import_source(self, source: ConfigSource, append: bool = True) -> None: 328 | """Include a new source of configuration.""" 329 | if append: 330 | # Takes priority within the same scope 331 | self.sources.append(source) 332 | else: 333 | self.sources.insert(0, source) 334 | 335 | def clear_sources(self) -> t.List[ConfigSource]: 336 | """Clear all sources of configuration returning the previous sources.""" 337 | cleared_sources = self.sources.copy() 338 | self.sources.clear() 339 | return cleared_sources 340 | 341 | 342 | _MISSING: t.Any = object() 343 | """A sentinel value for a missing configuration value.""" 344 | 345 | RESOLVER_HINT = "__cdf_resolve__" 346 | """A hint to engage the configuration resolver.""" 347 | 348 | 349 | def map_config_section( 350 | *sections: str, 351 | ) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]: 352 | """Mark a function to inject configuration values from a specific section.""" 353 | 354 | def decorator(func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]: 355 | setattr(inspect.unwrap(func_or_cls), RESOLVER_HINT, sections) 356 | return func_or_cls 357 | 358 | return decorator 359 | 360 | 361 | def map_config_values( 362 | **mapping: t.Any, 363 | ) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]: 364 | """Mark a 
function to inject configuration values from a specific mapping of param names to keys.""" 365 | 366 | def decorator(func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]: 367 | setattr(inspect.unwrap(func_or_cls), RESOLVER_HINT, mapping) 368 | return func_or_cls 369 | 370 | return decorator 371 | 372 | 373 | class Request: 374 | def __init__(self, item: str): 375 | self.item = item 376 | 377 | def __class_getitem__(cls, item: str) -> "Request": 378 | return cls(item) 379 | 380 | 381 | class ConfigLoaderProtocol(t.Protocol): 382 | environment: str 383 | sources: t.List[ConfigSource] 384 | 385 | def load(self) -> t.MutableMapping[str, t.Any]: ... 386 | 387 | def import_source(self, source: ConfigSource, append: bool = True) -> None: ... 388 | 389 | def clear_sources(self) -> t.List[ConfigSource]: ... 390 | 391 | 392 | class ConfigResolver(t.MutableMapping[str, t.Any]): 393 | """Resolve configuration values.""" 394 | 395 | def __init__( 396 | self, 397 | *sources: ConfigSource, 398 | environment: str = "dev", 399 | loader: ConfigLoaderProtocol = ConfigLoader("config.json"), 400 | deferred: bool = False, 401 | ) -> None: 402 | """Initialize the configuration resolver. 403 | 404 | The environment serves 2 purposes: 405 | 1. It determines supplementary configuration file to load, e.g. config.dev.json. 406 | 2. It prefixes configuration keys and prioritizes them over non-prefixed keys. e.g. dev.api.key. 407 | 408 | These are not mutually exclusive and can be used together. 409 | 410 | Args: 411 | sources: The sources of configuration. 412 | environment: The environment to load configuration for. 413 | loader: The configuration loader. 414 | deferred: If True, the configuration is not loaded until requested. 415 | """ 416 | self.environment = environment 417 | for source in sources: 418 | loader.import_source(source) 419 | self._loader = loader 420 | self._config = loader.load() if not deferred else None 421 | self._frozen_environment = os.environ.copy() 422 | self._explicit_values = _to_box({}) 423 | 424 | @property 425 | def wrapped(self) -> t.MutableMapping[str, t.Any]: 426 | """Get the configuration dictionary.""" 427 | if self._config is None: 428 | self._config = _to_box(self._loader.load()) 429 | return ChainMap(self._explicit_values, self._config) 430 | 431 | def __getitem__(self, key: str) -> t.Any: 432 | """Get a configuration value.""" 433 | try: 434 | v = self.wrapped[f"{self.environment}.{key}"] 435 | except KeyError: 436 | v = self.wrapped[key] 437 | return self.apply_converters(v, self) 438 | 439 | def __setitem__(self, key: str, value: t.Any) -> None: 440 | """Set a configuration value.""" 441 | self._explicit_values[f"{self.environment}.{key}"] = value 442 | 443 | def __delitem__(self, key: str) -> None: 444 | self._explicit_values.pop(f"{self.environment}.{key}", None) 445 | 446 | def __iter__(self) -> t.Iterator[str]: 447 | return iter(self.wrapped) 448 | 449 | def __len__(self) -> int: 450 | return len(self.wrapped) 451 | 452 | def __getattr__(self, key: str) -> t.Any: 453 | """Get a configuration value.""" 454 | try: 455 | return self[key] 456 | except KeyError as e: 457 | raise AttributeError from e 458 | 459 | def __enter__(self) -> "ConfigResolver": 460 | """Enter a context.""" 461 | return self 462 | 463 | def __exit__(self, *args) -> None: 464 | """Exit a context.""" 465 | os.environ.clear() 466 | os.environ.update(self._frozen_environment) 467 | 468 | def __repr__(self) -> str: 469 | """Get a string representation of the configuration resolver.""" 470 | return 
f"{self.__class__.__name__}(<{len(self._loader.sources)} sources>)" 471 | 472 | def set_environment(self, environment: str) -> None: 473 | """Set the environment of the configuration resolver.""" 474 | self.environment = environment 475 | self._loader.environment = environment 476 | self._config = None 477 | 478 | def import_source(self, source: ConfigSource, append: bool = True) -> None: 479 | """Include a new source of configuration.""" 480 | self._loader.import_source(source, append) 481 | self._config = None 482 | 483 | def clear_sources(self) -> t.List[ConfigSource]: 484 | """Clear all sources of configuration returning the previous sources.""" 485 | sources = self._loader.clear_sources() 486 | self._config = None 487 | return sources 488 | 489 | map_section = staticmethod(map_config_section) 490 | """Mark a function to inject configuration values from a specific section.""" 491 | 492 | map_values = staticmethod(map_config_values) 493 | """Mark a function to inject configuration values from a specific mapping of param names to keys.""" 494 | 495 | add_custom_converter = staticmethod(add_custom_converter) 496 | """Add a custom converter to the configuration system.""" 497 | 498 | apply_converters = staticmethod(apply_converters) 499 | """Apply converters to a string.""" 500 | 501 | KWARG_HINT = "_cdf_resolve" 502 | """A hint supplied in a kwarg to engage the configuration resolver.""" 503 | 504 | def _parse_hint_from_params( 505 | self, func_or_cls: t.Callable, sig: t.Optional[inspect.Signature] = None 506 | ) -> t.Optional[t.Union[t.Tuple[str, ...], t.Mapping[str, str]]]: 507 | """Get the sections or explicit lookups from a function. 508 | 509 | This assumes a kwarg named `_cdf_resolve` that is either a tuple of section names or 510 | a dictionary of param names to config keys is present in the function signature. 511 | """ 512 | sig = sig or inspect.signature(func_or_cls) 513 | if self.KWARG_HINT in sig.parameters: 514 | resolver_spec = sig.parameters[self.KWARG_HINT] 515 | if isinstance(resolver_spec.default, (tuple, dict)): 516 | return resolver_spec.default 517 | 518 | def resolve_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: 519 | """Resolve configuration values into a function or class.""" 520 | if not callable(func_or_cls): 521 | return func_or_cls 522 | 523 | sig = inspect.signature(func_or_cls) 524 | is_resolved_sentinel = "__config_resolved__" 525 | 526 | resolver_hint = getattr( 527 | inspect.unwrap(func_or_cls), 528 | RESOLVER_HINT, 529 | self._parse_hint_from_params(func_or_cls, sig), 530 | ) 531 | 532 | if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)): 533 | return func_or_cls 534 | 535 | @functools.wraps(func_or_cls) 536 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 537 | bound_args = sig.bind_partial(*args, **kwargs) 538 | bound_args.apply_defaults() 539 | 540 | # Apply converters to string literal arguments 541 | for arg_name, arg_value in bound_args.arguments.items(): 542 | if isinstance(arg_value, str): 543 | with suppress(Exception): 544 | bound_args.arguments[arg_name] = self.apply_converters( 545 | arg_value, 546 | self, 547 | ) 548 | 549 | # Resolve configuration values 550 | for name, param in sig.parameters.items(): 551 | value = _MISSING 552 | if not self.is_resolvable(param): 553 | continue 554 | 555 | # 1. Prioritize Request annotations 556 | elif request := self.extract_request_annotation(param): 557 | value = self.get(request, _MISSING) 558 | 559 | # 2. 
Use explicit lookups if provided 560 | elif isinstance(resolver_hint, dict): 561 | if name not in resolver_hint: 562 | continue 563 | value = self.get(resolver_hint[name], _MISSING) 564 | 565 | # 3. Use section-based lookups if provided 566 | elif isinstance(resolver_hint, (tuple, list)): 567 | value = self.get(".".join((*resolver_hint, name)), _MISSING) 568 | 569 | # Inject the value into the function 570 | if value is not _MISSING: 571 | bound_args.arguments[name] = self.apply_converters(value, self) 572 | 573 | return func_or_cls(*bound_args.args, **bound_args.kwargs) 574 | 575 | setattr(wrapper, is_resolved_sentinel, True) 576 | return wrapper 577 | 578 | def is_resolvable(self, param: inspect.Parameter) -> bool: 579 | """Check if a parameter is injectable.""" 580 | return param.default in (param.empty, None) 581 | 582 | @staticmethod 583 | def extract_request_annotation(param: inspect.Parameter) -> t.Optional[str]: 584 | """Extract a request annotation from a parameter.""" 585 | for hint in getattr(param.annotation, "__metadata__", ()): 586 | if isinstance(hint, Request): 587 | return hint.item 588 | 589 | def __call__( 590 | self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any 591 | ) -> T: 592 | """Invoke a callable with injected configuration values.""" 593 | configured_f = self.resolve_defaults(func_or_cls) 594 | if not callable(configured_f): 595 | return configured_f 596 | return configured_f(*args, **kwargs) 597 | 598 | @classmethod 599 | def __get_pydantic_core_schema__( 600 | cls, source_type: t.Any, handler: pydantic.GetCoreSchemaHandler 601 | ) -> pydantic_core.CoreSchema: 602 | return pydantic_core.core_schema.dict_schema( 603 | keys_schema=pydantic_core.core_schema.str_schema(), 604 | values_schema=pydantic_core.core_schema.any_schema(), 605 | ) 606 | 607 | 608 | def _iter_wrapped(f: t.Callable): 609 | yield f 610 | f_w = inspect.unwrap(f) 611 | if f_w is not f: 612 | yield from _iter_wrapped(f_w) 613 | -------------------------------------------------------------------------------- /src/cdf/core/injector/registry.py: -------------------------------------------------------------------------------- 1 | """Dependency registry with lifecycle management.""" 2 | 3 | import enum 4 | import inspect 5 | import logging 6 | import os 7 | import sys 8 | import types 9 | import typing as t 10 | from collections import ChainMap 11 | from functools import partial, partialmethod, wraps 12 | 13 | import pydantic 14 | import pydantic_core 15 | from typing_extensions import ParamSpec, Self 16 | 17 | import cdf.core.configuration as conf 18 | from cdf.core.context import get_default_callable_lifecycle 19 | from cdf.core.injector.errors import DependencyCycleError, DependencyMutationError 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | T = t.TypeVar("T") 24 | P = ParamSpec("P") 25 | 26 | __all__ = [ 27 | "DependencyRegistry", 28 | "Dependency", 29 | "Lifecycle", 30 | "DependencyKey", 31 | "GLOBAL_REGISTRY", 32 | ] 33 | 34 | 35 | class Lifecycle(enum.Enum): 36 | """Lifecycle of a dependency.""" 37 | 38 | PROTOTYPE = enum.auto() 39 | """A prototype dependency is created every time it is requested""" 40 | 41 | SINGLETON = enum.auto() 42 | """A singleton dependency is created once and shared.""" 43 | 44 | INSTANCE = enum.auto() 45 | """An instance dependency is a global object which is not created by the container.""" 46 | 47 | @property 48 | def is_prototype(self) -> bool: 49 | """Check if the lifecycle is prototype.""" 50 | return self == Lifecycle.PROTOTYPE 51 | 52 | 
@property 53 | def is_singleton(self) -> bool: 54 | """Check if the lifecycle is singleton.""" 55 | return self == Lifecycle.SINGLETON 56 | 57 | @property 58 | def is_instance(self) -> bool: 59 | """Check if the lifecycle is instance.""" 60 | return self == Lifecycle.INSTANCE 61 | 62 | @property 63 | def is_deferred(self) -> bool: 64 | """Check if the object to be created is deferred.""" 65 | return self.is_prototype or self.is_singleton 66 | 67 | def __str__(self) -> str: 68 | return self.name.lower() 69 | 70 | @classmethod 71 | def default_for(cls, obj: t.Any) -> "Lifecycle": 72 | """Get the default lifecycle.""" 73 | if callable(obj): 74 | return get_default_callable_lifecycle() 75 | return cls.INSTANCE 76 | 77 | 78 | class TypedKey(t.NamedTuple): 79 | """A key which is a tuple of a name and a type.""" 80 | 81 | name: str 82 | type_: t.Type[t.Any] 83 | 84 | @property 85 | def type_name(self) -> t.Optional[str]: 86 | """Get the name of the type if applicable.""" 87 | return self.type_.__name__ 88 | 89 | def __str__(self) -> str: 90 | return f"{self.name}: {self.type_name}" 91 | 92 | def __repr__(self) -> str: 93 | return f"<TypedKey {self.name}: {self.type_name}>" 94 | 95 | def __eq__(self, other: t.Any) -> bool: 96 | """Two keys are equal if their names and base types match.""" 97 | if not isinstance(other, (TypedKey, tuple)): 98 | return False 99 | return self.name == other[0] and _same_type(self.type_, other[1]) 100 | 101 | def __hash__(self) -> int: 102 | """Hash the key with the effective type if possible.""" 103 | try: 104 | return hash((self.name, _unwrap_type(self.type_))) 105 | except TypeError as e: 106 | logger.warning(f"Failed to hash key {self!r}: {e}") 107 | return hash((self.name, self.type_)) 108 | 109 | 110 | DependencyKey = t.Union[str, t.Tuple[str, t.Type[t.Any]], TypedKey] 111 | """A string or a typed key.""" 112 | 113 | 114 | def _unwrap_optional(hint: t.Type) -> t.Type: 115 | """Unwrap Optional type hint. Also unwraps types.UnionType like str | None 116 | 117 | Args: 118 | hint: The type hint. 119 | 120 | Returns: 121 | The unwrapped type hint. 122 | """ 123 | args = t.get_args(hint) 124 | if len(args) != 2 or args[1] is not type(None): 125 | return hint 126 | return args[0] 127 | 128 | 129 | def _is_union(hint: t.Type) -> bool: 130 | """Check if a type hint is a Union. 131 | 132 | Args: 133 | hint: The type hint. 134 | 135 | Returns: 136 | True if the type hint is a Union. 137 | """ 138 | return t.get_origin(hint) is t.Union or (sys.version_info >= (3, 10) and t.get_origin(hint) is types.UnionType) 139 | 140 | 141 | def _is_ambiguous_type(hint: t.Optional[t.Type]) -> bool: 142 | """Check if a type hint or Signature annotation is ambiguous. 143 | 144 | Args: 145 | hint: The type hint. 146 | 147 | Returns: 148 | True if the type hint is ambiguous. 149 | """ 150 | return hint in ( 151 | object, 152 | t.Any, 153 | None, 154 | type(None), 155 | t.NoReturn, 156 | inspect.Parameter.empty, 157 | type(lambda: None), 158 | ) 159 | 160 | 161 | def _unwrap_type(hint: t.Type) -> t.Type: 162 | """Unwrap a type hint. 163 | 164 | For a Union, this is the base type if all types are the same base type. 165 | Otherwise, it is the hint itself with Optional unwrapped. 166 | 167 | Args: 168 | hint: The type hint. 169 | 170 | Returns: 171 | The unwrapped type hint. 
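For example, Union[HarnessFFProvider, SplitFFProvider, LaunchDarklyFFProvider] unwraps to their shared BaseFFProvider, while Optional[str] unwraps to str.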
172 | """ 173 | hint = _unwrap_optional(hint) 174 | if _is_union(hint): 175 | args = list(map(_unwrap_optional, t.get_args(hint))) 176 | if not args: 177 | return hint 178 | f_base = getattr(args[0], "__base__", None) 179 | if f_base and all(f_base is getattr(arg, "__base__", None) for arg in args[1:]): 180 | # Ex. Union[HarnessFFProvider, SplitFFProvider, LaunchDarklyFFProvider] 181 | # == BaseFFProvider 182 | return f_base 183 | return hint 184 | 185 | 186 | def _same_type(hint1: t.Type, hint2: t.Type) -> bool: 187 | """Check if two type hints are of the same unwrapped type. 188 | 189 | Args: 190 | hint1: The first type hint. 191 | hint2: The second type hint. 192 | 193 | Returns: 194 | True if the unwrapped types are the same. 195 | """ 196 | return _unwrap_type(hint1) is _unwrap_type(hint2) 197 | 198 | 199 | @t.overload 200 | def _normalize_key(key: str) -> str: ... 201 | 202 | 203 | @t.overload 204 | def _normalize_key(key: t.Union[t.Tuple[str, t.Any], TypedKey]) -> TypedKey: ... 205 | 206 | 207 | def _normalize_key( 208 | key: t.Union[str, t.Tuple[str, t.Type[t.Any]], TypedKey], 209 | ) -> t.Union[str, TypedKey]: 210 | """Normalize a key 2-tuple to a TypedKey if it is not already, preserve str. 211 | 212 | Args: 213 | key: The key to normalize. 214 | 215 | Returns: 216 | The normalized key. 217 | """ 218 | if isinstance(key, str): 219 | return key 220 | k, t_ = key 221 | return TypedKey(k, _unwrap_type(t_)) 222 | 223 | 224 | def _safe_get_type_hints(obj: t.Any) -> t.Dict[str, t.Type]: 225 | """Get type hints for an object, ignoring errors. 226 | 227 | Args: 228 | obj: The object to get type hints for. 229 | 230 | Returns: 231 | A dictionary of attribute names to type hints. 232 | """ 233 | try: 234 | if isinstance(obj, partial): 235 | obj = obj.func 236 | return t.get_type_hints(obj) 237 | except Exception as e: 238 | logger.debug(f"Failed to get type hints for {obj!r}: {e}") 239 | return {} 240 | 241 | 242 | class Dependency(pydantic.BaseModel, t.Generic[T]): 243 | """A Monadic type which wraps a value with lifecycle and allows simple transformations.""" 244 | 245 | factory: t.Callable[..., T] 246 | """The factory or instance of the dependency.""" 247 | lifecycle: Lifecycle = Lifecycle.SINGLETON 248 | """The lifecycle of the dependency.""" 249 | 250 | conf_spec: t.Optional[t.Union[t.Tuple[str, ...], t.Dict[str, str]]] = None 251 | """A hint for configuration values.""" 252 | alias: t.Optional[str] = None 253 | """Used as an alternative to inferring the name from the factory.""" 254 | 255 | _instance: t.Optional[T] = None 256 | """The instance of the dependency once resolved.""" 257 | _is_resolved: bool = False 258 | """Flag to indicate if the dependency has been unwrapped.""" 259 | 260 | @pydantic.model_validator(mode="after") 261 | def _apply_spec(self) -> Self: 262 | """Apply the configuration spec to the dependency.""" 263 | spec = self.conf_spec 264 | if isinstance(spec, dict): 265 | self.map(conf.map_config_values(**spec)) 266 | elif isinstance(spec, tuple): 267 | self.map(conf.map_config_section(*spec)) 268 | return self 269 | 270 | @pydantic.model_validator(mode="before") 271 | @classmethod 272 | def _ensure_lifecycle(cls, data: t.Any) -> t.Any: 273 | """Ensure a valid lifecycle is set for the dependency.""" 274 | 275 | if isinstance(data, dict): 276 | factory = data["factory"] 277 | lc = data.get("lifecycle", Lifecycle.default_for(factory)) 278 | if isinstance(lc, str): 279 | lc = Lifecycle[lc.upper()] 280 | if not isinstance(lc, Lifecycle): 281 | raise ValueError(f"Invalid 
lifecycle {lc=}") 282 | if not (lc.is_instance or callable(factory)): 283 | raise ValueError(f"Value must be callable for {lc=}") 284 | data["lifecycle"] = lc 285 | return data 286 | 287 | @pydantic.field_validator("factory", mode="before") 288 | @classmethod 289 | def _ensure_callable(cls, factory: t.Any) -> t.Any: 290 | """Ensure the factory is callable.""" 291 | if not callable(factory): 292 | 293 | def defer() -> T: 294 | return factory 295 | 296 | defer.__name__ = f"factory_{os.urandom(4).hex()}" 297 | return defer 298 | return factory 299 | 300 | @classmethod 301 | def instance(cls, instance: t.Any) -> "Dependency": 302 | """Create a dependency from an instance. 303 | 304 | Args: 305 | instance: The instance to use as the dependency. 306 | 307 | Returns: 308 | A new Dependency object with the instance lifecycle. 309 | """ 310 | obj = cls(factory=instance, lifecycle=Lifecycle.INSTANCE) 311 | obj._instance = instance 312 | obj._is_resolved = True 313 | return obj 314 | 315 | @classmethod 316 | def singleton( 317 | cls, factory: t.Callable[..., T], *args: t.Any, **kwargs: t.Any 318 | ) -> "Dependency": 319 | """Create a singleton dependency. 320 | 321 | Args: 322 | factory: The factory function to create the dependency. 323 | args: Positional arguments to pass to the factory. 324 | kwargs: Keyword arguments to pass to the factory. 325 | 326 | Returns: 327 | A new Dependency object with the singleton lifecycle. 328 | """ 329 | if callable(factory) and (args or kwargs): 330 | factory = partial(factory, *args, **kwargs) 331 | return cls(factory=factory, lifecycle=Lifecycle.SINGLETON) 332 | 333 | @classmethod 334 | def prototype( 335 | cls, factory: t.Callable[..., T], *args: t.Any, **kwargs: t.Any 336 | ) -> "Dependency": 337 | """Create a prototype dependency. 338 | 339 | Args: 340 | factory: The factory function to create the dependency. 341 | args: Positional arguments to pass to the factory. 342 | kwargs: Keyword arguments to pass to the factory. 343 | 344 | Returns: 345 | A new Dependency object with the prototype lifecycle. 346 | """ 347 | if callable(factory) and (args or kwargs): 348 | factory = partial(factory, *args, **kwargs) 349 | return cls(factory=factory, lifecycle=Lifecycle.PROTOTYPE) 350 | 351 | @classmethod 352 | def wrap(cls, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> Self: 353 | """Wrap an object as a dependency. 354 | 355 | Assumes singleton lifecycle for callables unless a default lifecycle context is set. 356 | 357 | Args: 358 | obj: The object to wrap. 359 | 360 | Returns: 361 | A new Dependency object with the object as the factory. 362 | """ 363 | if callable(obj): 364 | if args or kwargs: 365 | obj = partial(obj, *args, **kwargs) 366 | return cls(factory=obj, lifecycle=get_default_callable_lifecycle()) 367 | return cls(factory=obj, lifecycle=Lifecycle.INSTANCE) 368 | 369 | def map_value(self, func: t.Callable[[T], T]) -> Self: 370 | """Apply a function to the unwrapped value. 371 | 372 | Args: 373 | func: The function to apply to the unwrapped value. 374 | 375 | Returns: 376 | A new Dependency object with the function applied. 
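For example, Dependency.singleton(dict).map_value(lambda d: {**d, "ready": True}) applies the function lazily, when the factory is first unwrapped.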
377 | """ 378 | if self._is_resolved: 379 | self._instance = func(self._instance) # type: ignore 380 | return self 381 | 382 | factory = self.factory 383 | 384 | @wraps(factory) 385 | def wrapper() -> T: 386 | return func(factory()) 387 | 388 | self.factory = wrapper 389 | return self 390 | 391 | def map( 392 | self, 393 | *funcs: t.Callable[[t.Callable[..., T]], t.Callable[..., T]], 394 | idempotent: bool = False, 395 | ) -> Self: 396 | """Apply a sequence of transformations to the wrapped value. 397 | 398 | The transformations are applied in order. This is a no-op if the dependency is 399 | already resolved and idempotent is True or the dependency is an instance. 400 | 401 | Args: 402 | funcs: The functions to apply to the wrapped value. 403 | idempotent: If True, allow transformations on resolved dependencies to be a no-op. 404 | 405 | Returns: 406 | The Dependency object with the transformations applied. 407 | """ 408 | if self._is_resolved: 409 | if self.lifecycle.is_instance or idempotent: 410 | return self 411 | raise DependencyMutationError( 412 | f"Dependency {self!r} is already resolved, cannot apply transformations to factory" 413 | ) 414 | factory = self.factory 415 | for func in funcs: 416 | factory = func(factory) 417 | self.factory = factory 418 | return self 419 | 420 | def unwrap(self) -> T: 421 | """Unwrap the value from the factory.""" 422 | if self.lifecycle.is_prototype: 423 | return self.factory() 424 | if self._instance is not None: 425 | return self._instance 426 | self._instance = self.factory() 427 | if self.lifecycle.is_singleton: 428 | self._is_resolved = True 429 | return self._instance 430 | 431 | def __str__(self) -> str: 432 | return f"{self.factory} ({self.lifecycle})" 433 | 434 | def __repr__(self) -> str: 435 | return f"" 436 | 437 | def __call__(self) -> T: 438 | """Alias for unwrap.""" 439 | return self.unwrap() 440 | 441 | def try_infer_type(self) -> t.Optional[t.Type[T]]: 442 | """Get the effective type of the dependency.""" 443 | if inspect.isclass(self.factory): 444 | return _unwrap_type(self.factory) 445 | if inspect.isfunction(self.factory): 446 | if hint := _safe_get_type_hints(inspect.unwrap(self.factory)).get("return"): 447 | return _unwrap_type(hint) 448 | if self._is_resolved: 449 | return _unwrap_type(type(self._instance)) 450 | 451 | def try_infer_name(self) -> t.Optional[str]: 452 | """Infer the name of the dependency from the factory.""" 453 | if self.alias: 454 | return self.alias 455 | if isinstance(self.factory, partial): 456 | f = inspect.unwrap(self.factory.func) 457 | else: 458 | f = inspect.unwrap(self.factory) 459 | if inspect.isfunction(f): 460 | return f.__name__ 461 | if inspect.isclass(f): 462 | return f.__name__ 463 | return getattr(f, "name", None) 464 | 465 | def generate_key( 466 | self, name: t.Optional[DependencyKey] = None 467 | ) -> t.Union[str, TypedKey]: 468 | """Generate a typed key for the dependency. 469 | 470 | Args: 471 | name: The name of the dependency. 472 | 473 | Returns: 474 | A typed key if the type can be inferred, else the name. 
475 | """ 476 | if not name: 477 | name = self.try_infer_name() 478 | if not name: 479 | raise ValueError( 480 | "Cannot infer name for dependency and no name or alias provided" 481 | ) 482 | if isinstance(name, TypedKey): 483 | return name 484 | elif isinstance(name, tuple): 485 | return TypedKey(name[0], name[1]) 486 | hint = self.try_infer_type() 487 | return TypedKey(name, hint) if hint and not _is_ambiguous_type(hint) else name 488 | 489 | 490 | class DependencyRegistry(t.MutableMapping[DependencyKey, Dependency]): 491 | """A registry for dependencies with lifecycle management. 492 | 493 | Dependencies can be registered with a name or a typed key. Typed keys are tuples 494 | of a name and a type hint. Dependencies can be added with a lifecycle, which can be 495 | one of prototype, singleton, or instance. Dependencies can be retrieved by name or 496 | typed key. Dependencies can be injected into functions or classes. Dependencies can 497 | be wired into callables to resolve a dependency graph. 498 | """ 499 | 500 | lifecycle = Lifecycle 501 | 502 | def __init__(self, strict: bool = False) -> None: 503 | """Initialize the registry. 504 | 505 | Args: 506 | strict: If True, do not inject an untyped lookup for a typed dependency. 507 | """ 508 | self.strict = strict 509 | self._typed_dependencies: t.Dict[TypedKey, Dependency] = {} 510 | self._untyped_dependencies: t.Dict[str, Dependency] = {} 511 | self._resolving: t.Set[t.Union[str, TypedKey]] = set() 512 | 513 | @property 514 | def dependencies(self) -> ChainMap[t.Any, Dependency]: 515 | """Get all dependencies.""" 516 | return ChainMap(self._typed_dependencies, self._untyped_dependencies) 517 | 518 | def add( 519 | self, 520 | key: DependencyKey, 521 | value: t.Any, 522 | lifecycle: t.Optional[Lifecycle] = None, 523 | override: bool = False, 524 | init_args: t.Tuple[t.Any, ...] = (), 525 | init_kwargs: t.Optional[t.Dict[str, t.Any]] = None, 526 | ) -> None: 527 | """Register a dependency with the container. 528 | 529 | Args: 530 | key: The name of the dependency. 531 | value: The factory or instance of the dependency. 532 | lifecycle: The lifecycle of the dependency. 533 | override: If True, override an existing dependency. 534 | init_args: Arguments to initialize the factory with. 535 | init_kwargs: Keyword arguments to initialize the factory with. 
536 | """ 537 | 538 | # Assume singleton lifecycle if the value is callable unless set in context 539 | if lifecycle is None: 540 | lifecycle = Lifecycle.default_for(value) 541 | 542 | # If the value is callable and has initialization args, bind them early so 543 | # we don't need to schlepp them around 544 | if callable(value) and (init_args or init_kwargs): 545 | value = partial(value, *init_args, **(init_kwargs or {})) 546 | 547 | # Register the dependency 548 | dependency = Dependency(factory=value, lifecycle=lifecycle) 549 | dependency_key = dependency.generate_key(key) 550 | if self.has(dependency_key) and not override: 551 | raise ValueError(f'Dependency "{dependency_key}" is already registered') 552 | if isinstance(dependency_key, TypedKey): 553 | self._typed_dependencies[dependency_key] = dependency 554 | # Allow untyped access to typed dependencies for convenience if not strict 555 | # or if the hint is not a distinct type 556 | if not self.strict or _is_ambiguous_type(dependency_key.type_): 557 | self._untyped_dependencies[dependency_key.name] = dependency 558 | else: 559 | self._untyped_dependencies[dependency_key] = dependency 560 | 561 | add_prototype = partialmethod(add, lifecycle=Lifecycle.PROTOTYPE) 562 | add_singleton = partialmethod(add, lifecycle=Lifecycle.SINGLETON) 563 | add_instance = partialmethod(add, lifecycle=Lifecycle.INSTANCE) 564 | 565 | def add_from_dependency( 566 | self, 567 | dependency: Dependency, 568 | key: t.Optional[DependencyKey] = None, 569 | override: bool = False, 570 | ) -> None: 571 | """Add a Dependency object to the container. 572 | 573 | Args: 574 | key: The name or typed key of the dependency. 575 | dependency: The dependency object. 576 | override: If True, override an existing dependency 577 | """ 578 | dependency_key = dependency.generate_key(key) 579 | if self.has(dependency_key) and not override: 580 | raise ValueError( 581 | f'Dependency "{dependency_key}" is already registered, use a different name to avoid conflicts' 582 | ) 583 | if isinstance(dependency_key, TypedKey): 584 | self._typed_dependencies[dependency_key] = dependency 585 | if not self.strict or _is_ambiguous_type(dependency_key.type_): 586 | self._untyped_dependencies[dependency_key.name] = dependency 587 | else: 588 | self._untyped_dependencies[dependency_key] = dependency 589 | 590 | def remove(self, name_or_key: DependencyKey) -> None: 591 | """Remove a dependency by name or key from the container. 592 | 593 | Args: 594 | name_or_key: The name or typed key of the dependency. 595 | """ 596 | key = _normalize_key(name_or_key) 597 | if isinstance(key, str): 598 | if key in self._untyped_dependencies: 599 | del self._untyped_dependencies[key] 600 | else: 601 | raise KeyError(f'Dependency "{key}" is not registered') 602 | elif key in self._typed_dependencies: 603 | del self._typed_dependencies[key] 604 | else: 605 | raise KeyError(f'Dependency "{key}" is not registered') 606 | 607 | def clear(self) -> None: 608 | """Clear all dependencies and singletons.""" 609 | self._typed_dependencies.clear() 610 | self._untyped_dependencies.clear() 611 | 612 | def has(self, name_or_key: DependencyKey) -> bool: 613 | """Check if a dependency is registered. 614 | 615 | Args: 616 | name_or_key: The name or typed key of the dependency. 617 | """ 618 | return name_or_key in self.dependencies 619 | 620 | def resolve(self, name_or_key: DependencyKey, must_exist: bool = False) -> t.Any: 621 | """Get a dependency. 622 | 623 | Args: 624 | name_or_key: The name or typed key of the dependency. 
625 | must_exist: If True, raise KeyError if the dependency is not found. 626 | 627 | Returns: 628 | The dependency if found, else None. 629 | """ 630 | key = _normalize_key(name_or_key) 631 | 632 | # Resolve the dependency 633 | if isinstance(key, str): 634 | if key not in self._untyped_dependencies: 635 | if must_exist: 636 | raise KeyError(f'Dependency "{key}" is not registered') 637 | return 638 | dep = self.dependencies[key] 639 | else: 640 | if _is_union(key.type_): 641 | types = map(_unwrap_type, t.get_args(key.type_)) 642 | else: 643 | types = [key.type_] 644 | for type_ in types: 645 | key = TypedKey(key.name, type_) 646 | if self.has(key): 647 | break 648 | else: 649 | if must_exist: 650 | raise KeyError(f'Dependency "{key}" is not registered') 651 | return 652 | dep = self.dependencies[key] 653 | 654 | # Detect dependency cycles 655 | if key in self._resolving: 656 | raise DependencyCycleError( 657 | f"Dependency cycle detected while resolving {key} for {dep.factory!r}" 658 | ) 659 | 660 | # Handle the lifecycle of the dependency, recursively resolving dependencies 661 | self._resolving.add(key) 662 | try: 663 | return dep.map(self.wire).unwrap() 664 | except DependencyMutationError: 665 | return dep.unwrap() 666 | finally: 667 | self._resolving.remove(key) 668 | 669 | resolve_or_raise = partialmethod(resolve, must_exist=True) 670 | 671 | def __contains__(self, name: t.Any) -> bool: 672 | """Check if a dependency is registered.""" 673 | return self.has(name) 674 | 675 | def __getitem__(self, name: DependencyKey) -> t.Any: 676 | """Get a dependency. Raises KeyError if not found.""" 677 | return self.resolve(name, must_exist=True) 678 | 679 | def __setitem__(self, name: DependencyKey, value: t.Any) -> None: 680 | """Add a dependency. Defaults to singleton lifecycle if callable, else instance.""" 681 | self.add(name, value, override=True) 682 | 683 | def __delitem__(self, name: DependencyKey) -> None: 684 | """Remove a dependency.""" 685 | self.remove(name) 686 | 687 | def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: 688 | """Inject dependencies into a function. 689 | 690 | Args: 691 | func_or_cls: The function or class to inject dependencies into. 
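Only parameters without a default (or with a None default) are considered for injection; a typed lookup from the annotation is tried first, then the bare parameter name.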
692 | 
693 | Returns:
694 | A function that can be called with dependencies injected.
695 | """
696 | if not callable(func_or_cls):
697 | return func_or_cls
698 | 
699 | sig = inspect.signature(func_or_cls)
700 | is_resolved_sentinel = "__deps_resolved__"
701 | 
702 | if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)):
703 | return func_or_cls
704 | 
705 | @wraps(func_or_cls)
706 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
707 | bound_args = sig.bind_partial(*args, **kwargs)
708 | for name, param in sig.parameters.items():
709 | if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
710 | continue
711 | if param.default not in (param.empty, None):
712 | continue
713 | if name not in bound_args.arguments:
714 | dep = None
715 | # Try to resolve a typed dependency
716 | if not _is_ambiguous_type(param.annotation):
717 | dep = self.resolve((name, param.annotation))
718 | # Fallback to untyped injection
719 | if dep is None:
720 | dep = self.resolve(name)
721 | # If a dependency is found, inject it
722 | if dep is not None:
723 | bound_args.arguments[name] = dep
724 | return func_or_cls(*bound_args.args, **bound_args.kwargs)
725 | 
726 | setattr(wrapper, is_resolved_sentinel, True)
727 | return wrapper
728 | 
729 | def __call__(
730 | self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any
731 | ) -> T:
732 | """Invoke a callable with dependencies injected from the registry.
733 | 
734 | Args:
735 | func_or_cls: The function or class to invoke.
736 | args: Positional arguments to pass to the callable.
737 | kwargs: Keyword arguments to pass to the callable.
738 | 
739 | Returns:
740 | The result of the callable.
741 | """
742 | wired_f = self.wire(func_or_cls)
743 | if not callable(wired_f):
744 | return wired_f
745 | return wired_f(*args, **kwargs)
746 | 
747 | def __iter__(self) -> t.Iterator[TypedKey]:
748 | """Iterate over dependency keys."""
749 | return iter(self.dependencies)
750 | 
751 | def __len__(self) -> int:
752 | """Return the number of dependencies."""
753 | return len(self.dependencies)
754 | 
755 | def __repr__(self) -> str:
756 | return f"DependencyRegistry(<{list(self.dependencies.keys())}>)"
757 | 
758 | def __str__(self) -> str:
759 | return repr(self)
760 | 
761 | def __bool__(self) -> bool:
762 | """True if the registry has dependencies."""
763 | return bool(self.dependencies)
764 | 
765 | def __or__(self, other: "DependencyRegistry") -> "DependencyRegistry":
766 | """Merge another registry into this one in place, akin to Python's dict union operator."""
767 | self._untyped_dependencies = {
768 | **self._untyped_dependencies,
769 | **other._untyped_dependencies,
770 | }
771 | self._typed_dependencies = {
772 | **self._typed_dependencies,
773 | **other._typed_dependencies,
774 | }
775 | return self
776 | 
777 | def __getstate__(self) -> t.Dict[str, t.Any]:
778 | """Serialize the state."""
779 | return {
780 | "_typed_dependencies": self._typed_dependencies,
781 | "_untyped_dependencies": self._untyped_dependencies,
782 | "_resolving": self._resolving,
783 | }
784 | 
785 | def __setstate__(self, state: t.Dict[str, t.Any]) -> None:
786 | """Deserialize the state."""
787 | self._typed_dependencies = state["_typed_dependencies"]
788 | self._untyped_dependencies = state["_untyped_dependencies"]
789 | self._resolving = state["_resolving"]
790 | 
791 | @classmethod
792 | def __get_pydantic_core_schema__(
793 | cls, source_type: t.Any, handler: pydantic.GetCoreSchemaHandler
794 | ) -> pydantic_core.CoreSchema:
795 | return pydantic_core.core_schema.dict_schema(
796 | 
keys_schema=pydantic_core.core_schema.union_schema( 797 | [ 798 | pydantic_core.core_schema.str_schema(), 799 | pydantic_core.core_schema.tuple_schema( 800 | [ 801 | pydantic_core.core_schema.str_schema(), 802 | pydantic_core.core_schema.any_schema(), 803 | ] 804 | ), 805 | pydantic_core.core_schema.is_instance_schema(TypedKey), 806 | ] 807 | ), 808 | values_schema=pydantic_core.core_schema.any_schema(), 809 | ) 810 | 811 | 812 | def _iter_wrapped(f: t.Callable): 813 | yield f 814 | f_w = inspect.unwrap(f) 815 | if f_w is not f: 816 | yield from _iter_wrapped(f_w) 817 | 818 | 819 | GLOBAL_REGISTRY = DependencyRegistry() 820 | """A global dependency registry.""" 821 | -------------------------------------------------------------------------------- /src/cdf/types/monads.py: -------------------------------------------------------------------------------- 1 | """Contains monadic types and functions for working with them.""" 2 | 3 | from __future__ import annotations 4 | 5 | import abc 6 | import asyncio 7 | import functools 8 | import inspect 9 | import sys 10 | import typing as t 11 | 12 | from typing_extensions import Self 13 | 14 | if sys.version_info < (3, 10): 15 | from typing_extensions import ParamSpec 16 | else: 17 | from typing import ParamSpec 18 | 19 | T = t.TypeVar("T") # The type of the value inside the Monad 20 | U = t.TypeVar("U") # The transformed type of the value inside the Monad 21 | K = t.TypeVar("K") # A known type that is not necessarily the same as T 22 | L = t.TypeVar("L") # A known type that is not necessarily the same as U 23 | E = t.TypeVar( 24 | "E", bound=BaseException, covariant=True 25 | ) # The type of the error inside the Result 26 | P = ParamSpec("P") 27 | 28 | TState = t.TypeVar("TState") # The type of the state 29 | TMonad = t.TypeVar("TMonad", bound="Monad") # Generic Self type for Monad 30 | 31 | 32 | class Monad(t.Generic[T], abc.ABC): 33 | def __init__(self, value: T) -> None: 34 | self._value = value 35 | 36 | def __hash__(self) -> int: 37 | return hash(self._value) 38 | 39 | @abc.abstractmethod 40 | def bind(self, func: t.Callable[[T], "Monad[U]"]) -> "Monad[U]": 41 | pass 42 | 43 | @abc.abstractmethod 44 | def map(self, func: t.Callable[[T], U]) -> "Monad[U]": 45 | pass 46 | 47 | @abc.abstractmethod 48 | def filter(self, predicate: t.Callable[[T], bool]) -> Self: 49 | pass 50 | 51 | @abc.abstractmethod 52 | def unwrap(self) -> T: 53 | pass 54 | 55 | @abc.abstractmethod 56 | def unwrap_or(self, default: U) -> t.Union[T, U]: 57 | pass 58 | 59 | def __call__(self, func: t.Callable[[T], "Monad[U]"]) -> "Monad[U]": 60 | return self.bind(func) 61 | 62 | def __rshift__(self, func: t.Callable[[T], "Monad[U]"]) -> "Monad[U]": 63 | return self.bind(func) 64 | 65 | 66 | class Maybe(Monad[T], abc.ABC): 67 | @classmethod 68 | def pure(cls, value: K) -> "Maybe[K]": 69 | """Creates a Maybe with a value.""" 70 | return Just(value) 71 | 72 | @abc.abstractmethod 73 | def is_just(self) -> bool: 74 | pass 75 | 76 | @abc.abstractmethod 77 | def is_nothing(self) -> bool: 78 | pass 79 | 80 | if t.TYPE_CHECKING: 81 | 82 | def bind(self, func: t.Callable[[T], "Maybe[U]"]) -> "Maybe[U]": ... 83 | 84 | def map(self, func: t.Callable[[T], U]) -> "Maybe[U]": ... 85 | 86 | def filter(self, predicate: t.Callable[[T], bool]) -> "Maybe[T]": ... 87 | 88 | def unwrap(self) -> T: 89 | """Unwraps the value of the Maybe. 90 | 91 | Returns: 92 | The unwrapped value. 
93 | """ 94 | if self.is_just(): 95 | return self._value 96 | else: 97 | raise ValueError("Cannot unwrap Nothing.") 98 | 99 | def unwrap_or(self, default: U) -> t.Union[T, U]: 100 | """Tries to unwrap the Maybe, returning a default value if the Maybe is Nothing. 101 | 102 | Args: 103 | default: The value to return if unwrapping Nothing. 104 | 105 | Returns: 106 | The unwrapped value or the default value. 107 | """ 108 | if self.is_just(): 109 | return self._value 110 | else: 111 | return default 112 | 113 | @classmethod 114 | def lift(cls, func: t.Callable[[U], K]) -> t.Callable[["U | Maybe[U]"], "Maybe[K]"]: 115 | """Lifts a function to work within the Maybe monad. 116 | 117 | Args: 118 | func: A function to lift. 119 | 120 | Returns: 121 | A new function that returns a Maybe value. 122 | """ 123 | 124 | @functools.wraps(func) 125 | def wrapper(value: U | Maybe[U]) -> Maybe[K]: 126 | if isinstance(value, Maybe): 127 | return value.map(func) # type: ignore 128 | value = t.cast(U, value) 129 | try: 130 | result = func(value) 131 | if result is None: 132 | return Nothing() 133 | return Just(result) 134 | except Exception: 135 | return Nothing() 136 | 137 | return wrapper 138 | 139 | def __iter__(self) -> t.Iterator[T]: 140 | """Allows safely unwrapping the value of the Maybe using a for construct.""" 141 | if self.is_just(): 142 | yield self.unwrap() 143 | 144 | 145 | class Just(Maybe[T]): 146 | def bind(self, func: t.Callable[[T], Maybe[U]]) -> Maybe[U]: 147 | """Applies a function to the value inside the Just. 148 | 149 | Args: 150 | func: A function that takes a value of type T and returns a Maybe containing a value of type U. 151 | 152 | Returns: 153 | The result of applying the function to the value inside the Just. 154 | """ 155 | return func(self._value) 156 | 157 | def map(self, func: t.Callable[[T], U]) -> "Maybe[U]": 158 | """Applies a mapping function to the value inside the Just. 159 | 160 | Args: 161 | func: A function that takes a value of type T and returns a value of type U. 162 | 163 | Returns: 164 | A new Just containing the result of applying the function to the value inside the Just. 165 | """ 166 | try: 167 | result = func(self._value) 168 | if result is None: 169 | return Nothing() 170 | return Just(result) 171 | except Exception: 172 | return Nothing() 173 | 174 | def filter(self, predicate: t.Callable[[T], bool]) -> Maybe[T]: 175 | """Filters the value inside the Just based on a predicate. 176 | 177 | Args: 178 | predicate: A function that takes a value of type T and returns a boolean. 179 | 180 | Returns: 181 | A new Just containing the value inside the Just if the predicate holds. 182 | """ 183 | if predicate(self._value): 184 | return self 185 | else: 186 | return Nothing() 187 | 188 | def is_just(self) -> bool: 189 | """Returns True if the Maybe is a Just.""" 190 | return True 191 | 192 | def is_nothing(self) -> bool: 193 | """Returns False if the Maybe is a Just.""" 194 | return False 195 | 196 | def __repr__(self) -> str: 197 | return f"Just({self._value})" 198 | 199 | 200 | class Nothing(Maybe[T]): 201 | def __init__(self) -> None: 202 | super().__init__(t.cast(T, None)) 203 | 204 | def bind(self, func: t.Callable[[T], Maybe[U]]) -> "Nothing[T]": 205 | """Applies a function to the value inside the Just. 206 | 207 | Args: 208 | func: A function that takes a value of type T and returns a Maybe containing a value of type U. 209 | 210 | Returns: 211 | The result of applying the function to the value inside the Just. 
212 | """ 213 | return self 214 | 215 | def map(self, func: t.Callable[[T], U]) -> "Nothing[T]": 216 | """Applies a mapping function to the value inside the Just. 217 | 218 | Args: 219 | func: A function that takes a value of type T and returns a value of type U. 220 | 221 | Returns: 222 | A new Just containing the result of applying the function to the value inside the Just. 223 | """ 224 | return self 225 | 226 | def filter(self, predicate: t.Callable[[T], bool]) -> "Nothing[T]": 227 | """Filters the value inside the Just based on a predicate. 228 | 229 | Args: 230 | predicate: A function that takes a value of type T and returns a boolean. 231 | 232 | Returns: 233 | A new Just containing the value inside the Just if the predicate holds. 234 | """ 235 | return self 236 | 237 | def is_just(self) -> bool: 238 | """Returns False if the Maybe is a Nothing.""" 239 | return False 240 | 241 | def is_nothing(self) -> bool: 242 | """Returns True if the Maybe is a Nothing.""" 243 | return True 244 | 245 | def __repr__(self) -> str: 246 | return "Nothing()" 247 | 248 | 249 | class Result(Monad[T], t.Generic[T, E]): 250 | @classmethod 251 | def pure(cls, value: K) -> "Result[K, E]": 252 | """Creates an Ok with a value.""" 253 | return Ok(value) 254 | 255 | @abc.abstractmethod 256 | def is_ok(self) -> bool: 257 | pass 258 | 259 | @abc.abstractmethod 260 | def is_err(self) -> bool: 261 | pass 262 | 263 | @abc.abstractmethod 264 | def unwrap(self) -> T: 265 | pass 266 | 267 | @abc.abstractmethod 268 | def unwrap_or(self, default: U) -> t.Union[T, U]: 269 | pass 270 | 271 | @abc.abstractmethod 272 | def unwrap_err(self) -> BaseException: 273 | pass 274 | 275 | @abc.abstractmethod 276 | def to_parts(self) -> t.Tuple[T, E | None]: 277 | pass 278 | 279 | @classmethod 280 | def lift( 281 | cls, func: t.Callable[[U], K] 282 | ) -> t.Callable[["U | Result[U, Exception]"], "Result[K, Exception]"]: 283 | """Transforms a function to work with arguments and output wrapped in Result monads. 284 | 285 | Args: 286 | func: A function that takes any number of arguments and returns a value of type T. 287 | 288 | Returns: 289 | A function that takes the same number of unwrapped arguments and returns a Result-wrapped result. 290 | """ 291 | 292 | def wrapper(result: U | Result[U, Exception]) -> Result[K, Exception]: 293 | if isinstance(result, Result): 294 | return result.map(func) 295 | result = t.cast(U, result) 296 | try: 297 | return Ok(func(result)) 298 | except Exception as e: 299 | return Err(e) 300 | 301 | if hasattr(func, "__defaults__") and func.__defaults__: 302 | default = func.__defaults__[0] 303 | wrapper.__defaults__ = (default,) 304 | 305 | return wrapper 306 | 307 | if t.TYPE_CHECKING: 308 | 309 | def bind(self, func: t.Callable[[T], "Result[U, E]"]) -> "Result[U, E]": ... 310 | 311 | def map(self, func: t.Callable[[T], U]) -> "Result[U, E]": ... 312 | 313 | def filter(self, predicate: t.Callable[[T], bool]) -> "Result[T, E]": ... 314 | 315 | def __call__(self, func: t.Callable[[T], "Result[U, E]"]) -> "Result[U, E]": ... 316 | 317 | def __rshift__( 318 | self, func: t.Callable[[T], "Result[U, E]"] 319 | ) -> "Result[U, E]": ... 320 | 321 | def __iter__(self) -> t.Iterator[T]: 322 | """Allows safely unwrapping the value of the Result using a for construct.""" 323 | if self.is_ok(): 324 | yield self.unwrap() 325 | 326 | 327 | class Ok(Result[T, E]): 328 | def bind(self, func: t.Callable[[T], Result[U, E]]) -> Result[U, E]: 329 | """Applies a function to the result of the Ok. 
330 | 
331 | Args:
332 | func: A function that takes a value of type T and returns a Result containing a value of type U.
333 | 
334 | Returns:
335 | A new Result containing the result of the original Result after applying the function.
336 | """
337 | return func(self._value)
338 | 
339 | def map(self, func: t.Callable[[T], U]) -> Result[U, E]:
340 | """Applies a mapping function to the result of the Ok.
341 | 
342 | Args:
343 | func: A function that takes a value of type T and returns a value of type U.
344 | 
345 | Returns:
346 | A new Ok containing the result of the original Ok after applying the function.
347 | """
348 | try:
349 | return Ok(func(self._value))
350 | except Exception as e:
351 | return Err(t.cast(E, e))
352 | 
353 | def is_ok(self) -> bool:
354 | """Returns True if the Result is an Ok."""
355 | return True
356 | 
357 | def is_err(self) -> bool:
358 | """Returns False if the Result is an Ok."""
359 | return False
360 | 
361 | def unwrap(self) -> T:
362 | """Unwraps the value of the Ok.
363 | 
364 | Returns:
365 | The unwrapped value.
366 | """
367 | return self._value
368 | 
369 | def unwrap_or(self, default: t.Any) -> T:
370 | """Unwraps the Ok; the default is never needed since an Ok always holds a value.
371 | 
372 | Args:
373 | default: The fallback value; unused for Ok.
374 | 
375 | Returns:
376 | The unwrapped value.
377 | """
378 | return self._value
379 | 
380 | def unwrap_err(self) -> BaseException:
381 | """Raises a ValueError since the Result is an Ok."""
382 | raise ValueError("Called unwrap_err on Ok")
383 | 
384 | def filter(self, predicate: t.Callable[[T], bool]) -> Result[T, E]:
385 | """Filters the result of the Ok based on a predicate.
386 | 
387 | Args:
388 | predicate: A function that takes a value of type T and returns a boolean.
389 | 
390 | Returns:
391 | The original Ok if the predicate holds, else an Err wrapping
392 | a ValueError indicating the predicate did not hold.
393 | """
394 | if predicate(self._value):
395 | return self
396 | else:
397 | return Err(t.cast(E, ValueError("Predicate does not hold")))
398 | 
399 | def to_parts(self) -> t.Tuple[T, None]:
400 | """Unpacks the value of the Ok."""
401 | return (self._value, None)
402 | 
403 | def __repr__(self) -> str:
404 | return f"Ok({self._value})"
405 | 
406 | 
407 | class Err(Result[T, E]):
408 | def __init__(self, error: E) -> None:
409 | """Initializes an Err with an error.
410 | 
411 | Args:
412 | error: The error to wrap in the Err.
413 | """
414 | self._error = error
415 | 
416 | def __hash__(self) -> int:
417 | return hash(self._error)
418 | 
419 | def bind(self, func: t.Callable[[T], Result[U, E]]) -> "Err[T, E]":
420 | """Ignores the function and propagates the Err.
421 | 
422 | Args:
423 | func: A function that takes a value of type T and returns a Result containing a value of type U; ignored for Err.
424 | 
425 | Returns:
426 | An Err containing the original error.
427 | """
428 | return self
429 | 
430 | def map(self, func: t.Callable[[T], U]) -> "Err[T, E]":
431 | """Ignores the mapping function and propagates the Err.
432 | 
433 | Args:
434 | func: A function that takes a value of type T and returns a value of type U; ignored for Err.
435 | 
436 | Returns:
437 | An Err containing the original error.
438 | """ 439 | return self 440 | 441 | def is_ok(self) -> bool: 442 | """Returns False if the Result is an Err.""" 443 | return False 444 | 445 | def is_err(self) -> bool: 446 | """Returns True if the Result is an Err.""" 447 | return True 448 | 449 | def unwrap(self) -> T: 450 | """Raises a ValueError since the Result is an Err.""" 451 | raise self._error 452 | 453 | def unwrap_or(self, default: U) -> U: 454 | """Returns a default value since the Result is an Err. 455 | 456 | Args: 457 | default: The value to return. 458 | 459 | Returns: 460 | The default value. 461 | """ 462 | return default 463 | 464 | def unwrap_err(self) -> BaseException: 465 | """Unwraps the error of the Err. 466 | 467 | Returns: 468 | The unwrapped error. 469 | """ 470 | return self._error 471 | 472 | def filter(self, predicate: t.Callable[[T], bool]) -> "Err[T, E]": 473 | """Filters the result of the Err based on a predicate. 474 | 475 | Args: 476 | predicate: A function that takes a value of type T and returns a boolean. 477 | 478 | Returns: 479 | An Err containing the original error. 480 | """ 481 | return self 482 | 483 | def to_parts(self) -> t.Tuple[None, E]: 484 | """Unpacks the error of the Err.""" 485 | return (None, self._error) 486 | 487 | def __repr__(self) -> str: 488 | return f"Err({self._error})" 489 | 490 | 491 | class Promise(t.Generic[T], t.Awaitable[T], Monad[T]): 492 | def __init__( 493 | self, 494 | coro_func: t.Callable[P, t.Coroutine[None, None, T]], 495 | *args: P.args, 496 | **kwargs: P.kwargs, 497 | ) -> None: 498 | """Initializes a Promise with a coroutine function. 499 | 500 | Args: 501 | coro_func: A coroutine function that returns a value of type T. 502 | args: Positional arguments to pass to the coroutine function. 503 | kwargs: Keyword arguments to pass to the coroutine function. 504 | """ 505 | self._loop = asyncio.get_event_loop() 506 | if callable(coro_func): 507 | coro = coro_func(*args, **kwargs) 508 | elif inspect.iscoroutine(coro_func): 509 | coro = t.cast(t.Coroutine[None, None, T], coro_func) 510 | else: 511 | raise ValueError("Invalid coroutine function") 512 | self._future: asyncio.Future[T] = asyncio.ensure_future(coro, loop=self._loop) 513 | 514 | @classmethod 515 | def pure(cls, value: K) -> "Promise[K]": 516 | """Creates a Promise that is already resolved with a value. 517 | 518 | Args: 519 | value: The value to resolve the Promise with. 520 | 521 | Returns: 522 | A new Promise that is already resolved with the value. 523 | """ 524 | return cls.from_value(value) # type: ignore 525 | 526 | def __hash__(self) -> int: 527 | return hash(self._future) 528 | 529 | def __await__(self): 530 | """Allows the Promise to be awaited.""" 531 | yield from self._future.__await__() 532 | return (yield from self._future.__await__()) 533 | 534 | def set_result(self, result: T) -> None: 535 | """Sets a result on the Promise. 536 | 537 | Args: 538 | result: The result to set on the Promise. 539 | """ 540 | if not self._future.done(): 541 | self._loop.call_soon_threadsafe(self._future.set_result, result) 542 | 543 | def set_exception(self, exception: Exception) -> None: 544 | """Sets an exception on the Promise. 545 | 546 | Args: 547 | exception: The exception to set on the Promise. 548 | """ 549 | if not self._future.done(): 550 | self._loop.call_soon_threadsafe(self._future.set_exception, exception) 551 | 552 | def bind(self, func: t.Callable[[T], "Promise[U]"]) -> "Promise[U]": 553 | """Applies a function to the result of the Promise. 
554 | 555 | Args: 556 | func: A function that takes a value of type T and returns a Promise containing a value of type U. 557 | 558 | Returns: 559 | A new Promise containing the result of the original Promise after applying the function. 560 | """ 561 | 562 | async def bound_coro() -> U: 563 | try: 564 | value = await self 565 | next_promise = func(value) 566 | return await next_promise 567 | except Exception as e: 568 | future = self._loop.create_future() 569 | future.set_exception(e) 570 | return t.cast(U, await future) 571 | 572 | return Promise(bound_coro) 573 | 574 | def map(self, func: t.Callable[[T], U]) -> "Promise[U]": 575 | """Applies a mapping function to the result of the Promise. 576 | 577 | Args: 578 | func: A function that takes a value of type T and returns a value of type U. 579 | 580 | Returns: 581 | A new Promise containing the result of the original Promise after applying the function. 582 | """ 583 | 584 | async def mapped_coro() -> U: 585 | try: 586 | value = await self 587 | return func(value) 588 | except Exception as e: 589 | future = self._loop.create_future() 590 | future.set_exception(e) 591 | return t.cast(U, await future) 592 | 593 | return Promise(mapped_coro) 594 | 595 | then = map # syntactic sugar, equivalent to map 596 | 597 | def filter(self, predicate: t.Callable[[T], bool]) -> "Promise[T]": 598 | """Filters the result of the Promise based on a predicate. 599 | 600 | Args: 601 | predicate: A function that takes a value of type T and returns a boolean. 602 | 603 | Returns: 604 | A new Promise containing the result of the original Promise if the predicate holds. 605 | """ 606 | 607 | async def filtered_coro() -> T: 608 | try: 609 | value = await self 610 | if predicate(value): 611 | return value 612 | else: 613 | raise ValueError("Filter predicate failed") 614 | except Exception as e: 615 | future = self._loop.create_future() 616 | future.set_exception(e) 617 | return await future 618 | 619 | return Promise(filtered_coro) 620 | 621 | def unwrap(self) -> T: 622 | return self._loop.run_until_complete(self) 623 | 624 | def unwrap_or(self, default: T) -> T: 625 | """Tries to unwrap the Promise, returning a default value if unwrapping raises an exception. 626 | 627 | Args: 628 | default: The value to return if unwrapping raises an exception. 629 | 630 | Returns: 631 | The unwrapped value or the default value if an exception is raised. 632 | """ 633 | try: 634 | return self._loop.run_until_complete(self) 635 | except Exception: 636 | return default 637 | 638 | @classmethod 639 | def from_value(cls, value: T) -> "Promise[T]": 640 | """Creates a Promise that is already resolved with a value. 641 | 642 | Args: 643 | value: The value to resolve the Promise with. 644 | 645 | Returns: 646 | A new Promise that is already resolved with the value. 647 | """ 648 | 649 | async def _fut(): 650 | return value 651 | 652 | return cls(_fut) 653 | 654 | @classmethod 655 | def from_exception(cls, exception: BaseException) -> "Promise[T]": 656 | """Creates a Promise that is already resolved with an exception. 657 | 658 | Args: 659 | exception: The exception to resolve the Promise with. 660 | 661 | Returns: 662 | A new Promise that is already resolved with the exception. 
663 | """ 664 | 665 | async def _fut(): 666 | raise exception 667 | 668 | return cls(_fut) 669 | 670 | @classmethod 671 | def lift( 672 | cls, func: t.Callable[[U], T] 673 | ) -> t.Callable[["U | Promise[U]"], "Promise[T]"]: 674 | """ 675 | Lifts a synchronous function to work within the Promise context, 676 | making it return a Promise of the result and allowing it to be used 677 | with Promise inputs. 678 | 679 | Args: 680 | func: A synchronous function that returns a value of type T. 681 | 682 | Returns: 683 | A function that, when called, returns a Promise wrapping the result of the original function. 684 | """ 685 | 686 | @functools.wraps(func) 687 | def wrapper(value: "U | Promise[U]") -> "Promise[T]": 688 | if isinstance(value, Promise): 689 | return value.map(func) 690 | value = t.cast(U, value) 691 | 692 | async def async_wrapper() -> T: 693 | return func(value) 694 | 695 | return cls(async_wrapper) 696 | 697 | return wrapper 698 | 699 | 700 | class Lazy(Monad[T]): 701 | def __init__(self, computation: t.Callable[[], T]) -> None: 702 | """Initializes a Lazy monad with a computation that will be executed lazily. 703 | 704 | Args: 705 | computation: A function that takes no arguments and returns a value of type T. 706 | """ 707 | self._computation = computation 708 | self._value = None 709 | self._evaluated = False 710 | 711 | @classmethod 712 | def pure(cls, value: T) -> "Lazy[T]": 713 | """Creates a Lazy monad with a pure value.""" 714 | return cls(lambda: value) 715 | 716 | def evaluate(self) -> T: 717 | """Evaluates the computation if it has not been evaluated yet and caches the result. 718 | 719 | Returns: 720 | The result of the computation. 721 | """ 722 | if not self._evaluated: 723 | self._value = self._computation() 724 | self._evaluated = True 725 | return t.cast(T, self._value) 726 | 727 | def bind(self, func: t.Callable[[T], "Lazy[U]"]) -> "Lazy[U]": 728 | """Lazily applies a function to the result of the current computation. 729 | 730 | Args: 731 | func: A function that takes a value of type T and returns a Lazy monad containing a value of type U. 732 | 733 | Returns: 734 | A new Lazy monad containing the result of the computation after applying the function. 735 | """ 736 | return Lazy(lambda: func(self.evaluate()).evaluate()) 737 | 738 | def map(self, func: t.Callable[[T], U]) -> "Lazy[U]": 739 | """Lazily applies a mapping function to the result of the computation. 740 | 741 | Args: 742 | func: A function that takes a value of type T and returns a value of type U. 743 | 744 | Returns: 745 | A new Lazy monad containing the result of the computation after applying the function. 746 | """ 747 | return Lazy(lambda: func(self.evaluate())) 748 | 749 | def filter(self, predicate: t.Callable[[T], bool]) -> "Lazy[T]": 750 | """Lazily filters the result of the computation based on a predicate. 751 | 752 | Args: 753 | predicate: A function that takes a value of type T and returns a boolean. 754 | 755 | Returns: 756 | A new Lazy monad containing the result of the computation if the predicate holds. 757 | """ 758 | 759 | def filter_computation(): 760 | result = self.evaluate() 761 | if predicate(result): 762 | return result 763 | else: 764 | raise ValueError("Predicate does not hold for the value.") 765 | 766 | return Lazy(filter_computation) 767 | 768 | def unwrap(self) -> T: 769 | """Forces evaluation of the computation and returns its result. 770 | 771 | Returns: 772 | The result of the computation. 
773 | """ 774 | return self.evaluate() 775 | 776 | def unwrap_or(self, default: T) -> T: 777 | """Tries to evaluate the computation, returning a default value if evaluation raises an exception. 778 | 779 | Args: 780 | default: The value to return if the computation raises an exception. 781 | 782 | Returns: 783 | The result of the computation or the default value if an exception is raised. 784 | """ 785 | try: 786 | return self.evaluate() 787 | except Exception: 788 | return default 789 | 790 | @classmethod 791 | def lift(cls, func: t.Callable[[U], T]) -> t.Callable[["U | Lazy[U]"], "Lazy[T]"]: 792 | """Transforms a function to work with arguments and output wrapped in Lazy monads. 793 | 794 | Args: 795 | func: A function that takes any number of arguments and returns a value of type U. 796 | 797 | Returns: 798 | A function that takes the same number of Lazy-wrapped arguments and returns a Lazy-wrapped result. 799 | """ 800 | 801 | @functools.wraps(func) 802 | def wrapper(value: "U | Lazy[U]") -> "Lazy[T]": 803 | if isinstance(value, Lazy): 804 | return value.map(func) 805 | value = t.cast(U, value) 806 | 807 | def computation() -> T: 808 | return func(value) 809 | 810 | return cls(computation) 811 | 812 | return wrapper 813 | 814 | 815 | Defer = Lazy # Defer is an alias for Lazy 816 | 817 | S = t.TypeVar("S") # State type 818 | A = t.TypeVar("A") # Return type 819 | B = t.TypeVar("B") # Transformed type 820 | 821 | 822 | class State(t.Generic[S, A], Monad[A], abc.ABC): 823 | def __init__(self, run_state: t.Callable[[S], t.Tuple[A, S]]) -> None: 824 | self.run_state = run_state 825 | 826 | def bind(self, func: t.Callable[[A], "State[S, B]"]) -> "State[S, B]": 827 | def new_run_state(s: S) -> t.Tuple[B, S]: 828 | a, state_prime = self.run_state(s) 829 | return func(a).run_state(state_prime) 830 | 831 | return State(new_run_state) 832 | 833 | def map(self, func: t.Callable[[A], B]) -> "State[S, B]": 834 | def new_run_state(s: S) -> t.Tuple[B, S]: 835 | a, state_prime = self.run_state(s) 836 | return func(a), state_prime 837 | 838 | return State(new_run_state) 839 | 840 | def filter(self, predicate: t.Callable[[A], bool]) -> "State[S, A]": 841 | def new_run_state(s: S) -> t.Tuple[A, S]: 842 | a, state_prime = self.run_state(s) 843 | if predicate(a): 844 | return a, state_prime 845 | else: 846 | raise ValueError("Value does not satisfy predicate") 847 | 848 | return State(new_run_state) 849 | 850 | def unwrap(self) -> A: 851 | raise NotImplementedError( 852 | "State cannot be directly unwrapped without providing an initial state." 853 | ) 854 | 855 | def unwrap_or(self, default: B) -> t.Union[A, B]: 856 | raise NotImplementedError( 857 | "State cannot directly return a value without an initial state." 858 | ) 859 | 860 | def __hash__(self) -> int: 861 | return id(self.run_state) 862 | 863 | @staticmethod 864 | def pure(value: A) -> "State[S, A]": 865 | return State(lambda s: (value, s)) 866 | 867 | def __call__(self, state: S) -> t.Tuple[A, S]: 868 | return self.run_state(state) 869 | 870 | def __repr__(self) -> str: 871 | return f"State({self.run_state})" 872 | 873 | @classmethod 874 | def lift( 875 | cls, func: t.Callable[[U], A] 876 | ) -> t.Callable[["U | State[S, U]"], "State[S, A]"]: 877 | """Lifts a function to work within the State monad. 878 | Args: 879 | func: A function to lift. 880 | Returns: 881 | A new function that returns a State value. 
882 | """ 883 | 884 | @functools.wraps(func) 885 | def wrapper(value: "U | State[S, U]") -> "State[S, A]": 886 | if isinstance(value, State): 887 | return value.map(func) 888 | value = t.cast(U, value) 889 | 890 | def run_state(s: S) -> t.Tuple[A, S]: 891 | return func(value), s 892 | 893 | return cls(run_state) 894 | 895 | return wrapper 896 | 897 | 898 | # Aliases for monadic converters 899 | # to_ is the pure function 900 | # is the lift function 901 | 902 | to_maybe = just = Maybe.pure 903 | nothing = Nothing[t.Any]() 904 | maybe = Maybe.lift 905 | 906 | to_result = ok = Result.pure 907 | error = lambda e: Err(e) # noqa: E731 908 | result = Result.lift 909 | 910 | to_promise = Promise.pure 911 | promise = Promise.lift 912 | 913 | to_lazy = Lazy.pure 914 | lazy = Lazy.lift 915 | 916 | to_deferred = Defer.pure 917 | deferred = Defer.lift 918 | 919 | to_state = State.pure 920 | state = State.lift 921 | 922 | # to_io = IO.pure 923 | # io = IO.lift 924 | --------------------------------------------------------------------------------