├── docs
│   ├── .nojekyll
│   ├── _sidebar.md
│   ├── _navbar.md
│   ├── _coverpage.md
│   ├── index.html
│   └── overview.md
├── tests
│   ├── __init__.py
│   ├── core
│   │   ├── injector
│   │   │   ├── __init__.py
│   │   │   ├── test_config.py
│   │   │   └── test_registry.py
│   │   └── test_workspace.py
│   ├── integrations
│   │   └── feature_flag
│   │       ├── test_file.py
│   │       ├── test_harness.py
│   │       └── test_launchdarkly.py
│   ├── builtin
│   │   ├── test_filters.py
│   │   └── test_metrics.py
│   └── types
│       └── test_monads.py
├── src
│   └── cdf
│       ├── core
│       │   ├── __init__.py
│       │   ├── component
│       │   │   ├── operation.py
│       │   │   ├── service.py
│       │   │   ├── publisher.py
│       │   │   ├── __init__.py
│       │   │   ├── pipeline.py
│       │   │   └── base.py
│       │   ├── injector
│       │   │   ├── errors.py
│       │   │   ├── __init__.py
│       │   │   └── registry.py
│       │   ├── context.py
│       │   ├── workspace.py
│       │   └── configuration.py
│       ├── integrations
│       │   ├── __init__.py
│       │   ├── feature_flag
│       │   │   ├── noop.py
│       │   │   ├── split.py
│       │   │   ├── launchdarkly.py
│       │   │   ├── __init__.py
│       │   │   ├── file.py
│       │   │   ├── base.py
│       │   │   └── harness.py
│       │   └── slack.py
│       ├── types
│       │   ├── __init__.py
│       │   └── monads.py
│       ├── proxy
│       │   ├── __init__.py
│       │   ├── mysql.py
│       │   └── planner.py
│       ├── __init__.py
│       └── builtin
│           ├── filters.py
│           └── metrics.py
├── pytest.ini
├── Makefile
├── .gitignore
├── .neoconf.json
├── .github
│   └── workflows
│       └── python-package.yml
├── pyproject.toml
├── README.md
└── LICENSE
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/cdf/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/cdf/integrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/core/injector/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integrations/feature_flag/test_file.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integrations/feature_flag/test_harness.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integrations/feature_flag/test_launchdarkly.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 | credentials: mark a test as requiring credentials
4 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: update-docs
2 |
3 | update-docs:
4 | @echo "Updating docs..."
5 | @pydoc-markdown -I src/cdf >docs/api_reference.md
6 | @echo "Done."
7 |
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
3 | .git/
4 |
5 | dist/
6 | logs/
7 |
8 | wip/
9 |
10 | _storage/
11 | _rendered/
12 |
13 | .cache/
14 | .ipynb_checkpoints/
15 |
16 | *.duckdb
17 |
--------------------------------------------------------------------------------
/.neoconf.json:
--------------------------------------------------------------------------------
1 | {
2 | "lspconfig": {
3 | "pyright": {
4 | "python": {
5 | "analysis": {
6 | "diagnosticMode": "workspace"
7 | }
8 | }
9 | }
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/src/cdf/core/component/operation.py:
--------------------------------------------------------------------------------
1 | from .base import Entrypoint
2 |
3 | OperationProto = int
4 |
5 |
6 | class Operation(Entrypoint[OperationProto], frozen=True):
7 | """A generic callable that returns an exit code."""
8 |
--------------------------------------------------------------------------------
/src/cdf/core/component/service.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from .base import Component
4 |
5 | ServiceProto = t.Any
6 |
7 |
8 | class Service(Component[ServiceProto], frozen=True):
 9 |     """A service that the workspace provides, e.g. an API, database, requests client, etc."""
10 |
--------------------------------------------------------------------------------
/docs/_sidebar.md:
--------------------------------------------------------------------------------
1 | - Getting started
2 |
3 | - [Quick start](quickstart.md)
4 |
5 | - Customization
6 |
7 | - [Configuration](configuration.md)
8 |
9 | - API Reference
10 |
11 | - [Python](api_reference.md)
12 | - [CLI](cli_reference.md)
13 |
14 | - [Changelog](changelog.md)
15 |
--------------------------------------------------------------------------------
/src/cdf/core/injector/errors.py:
--------------------------------------------------------------------------------
1 | class DependencyCycleError(Exception):
2 | """Raised when a dependency cycle is detected."""
3 |
4 | pass
5 |
6 |
7 | class DependencyMutationError(Exception):
8 | """Raised when an instance/singleton dependency has already been resolved but a mutation is attempted."""
9 |
10 | pass
11 |
--------------------------------------------------------------------------------
/src/cdf/core/injector/__init__.py:
--------------------------------------------------------------------------------
1 | from cdf.core.injector.registry import (
2 | GLOBAL_REGISTRY,
3 | Dependency,
4 | DependencyKey,
5 | DependencyRegistry,
6 | Lifecycle,
7 | )
8 |
9 | __all__ = [
10 | "Dependency",
11 | "DependencyRegistry",
12 | "DependencyKey",
13 | "Lifecycle",
14 | "GLOBAL_REGISTRY",
15 | ]
16 |
--------------------------------------------------------------------------------
/docs/_navbar.md:
--------------------------------------------------------------------------------
1 | * Getting started
2 |
3 | * [Quick start](quickstart.md)
4 | * [Writing more pages](more-pages.md)
5 | * [Custom navbar](custom-navbar.md)
6 | * [Cover page](cover.md)
7 |
8 | * Configuration
9 |
10 | * [Configuration](configuration.md)
11 | * [Themes](themes.md)
12 | * [Using plugins](plugins.md)
13 | * [Markdown configuration](markdown.md)
14 | * [Language highlight](language-highlight.md)
15 |
--------------------------------------------------------------------------------
/docs/_coverpage.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Continuous Data Framework (cdf)
4 |
5 | > A framework for managing data, continuously.
6 |
7 | - Simple and lightweight
8 | - Leverages the power of `sqlmesh` and `dlt`
9 | - Opinionated, but flexible
10 |
11 | [GitHub](https://github.com/z3z1ma/cdf/)
12 | [Get Started](overview.md)
13 |
--------------------------------------------------------------------------------
/src/cdf/types/__init__.py:
--------------------------------------------------------------------------------
1 | """A module for shared types."""
2 |
3 | import sys
4 | import typing as t
5 | from pathlib import Path
6 |
7 | import cdf.types.monads as M
8 |
9 | if t.TYPE_CHECKING:
10 | import decimal
11 |
12 | PathLike = t.Union[str, Path]
13 | Number = t.Union[int, float, "decimal.Decimal"]
14 |
15 | if sys.version_info < (3, 10):
16 | from typing_extensions import ParamSpec
17 | else:
18 | from typing import ParamSpec
19 |
20 | P = ParamSpec("P")
21 |
22 | __all__ = ["M", "P", "PathLike", "Number"]
23 |
--------------------------------------------------------------------------------
/src/cdf/proxy/__init__.py:
--------------------------------------------------------------------------------
1 | """The proxy module provides a MySQL proxy server for the CDF.
2 |
3 | The proxy server is used to intercept MySQL queries and execute them using SQLMesh.
4 | This allows it to integrate with BI tools and other MySQL clients. Furthermore,
 5 | during interception, the server can rewrite queries, expanding semantic references
 6 | and making it an easy-to-use semantic layer for SQLMesh.
7 | """
8 |
9 | from cdf.proxy.mysql import run_mysql_proxy
10 | from cdf.proxy.planner import run_plan_server
11 |
12 |
13 | __all__ = ["run_mysql_proxy", "run_plan_server"]
14 |
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/noop.py:
--------------------------------------------------------------------------------
1 | """No-op feature flag provider."""
2 |
3 | import typing as t
4 |
 5 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter, FlagAdapterResponse
6 |
7 |
8 | class NoopFeatureFlagAdapter(AbstractFeatureFlagAdapter):
 9 |     """A feature flag adapter that does nothing; every flag reads as enabled."""
10 |
11 | def __init__(self, **kwargs: t.Any) -> None:
12 | """Initialize the adapter."""
13 | pass
14 |
15 |     def get(self, feature_name: str) -> FlagAdapterResponse:
16 |         return FlagAdapterResponse.ENABLED
17 |
18 | def save(self, feature_name: str, flag: bool) -> None:
19 | pass
20 |
21 | def get_all_feature_names(self) -> t.List[str]:
22 | return []
23 |
24 |
25 | __all__ = ["NoopFeatureFlagAdapter"]
26 |
--------------------------------------------------------------------------------
/src/cdf/core/component/publisher.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from .base import Entrypoint
4 |
5 | DataPublisherProto = t.Tuple[
6 | t.Callable[..., None], # run
7 | t.Callable[..., bool], # preflight
8 | t.Optional[t.Callable[..., None]], # success hook
9 | t.Optional[t.Callable[..., None]], # failure hook
10 | ]
11 |
12 |
13 | class DataPublisher(
14 | Entrypoint[DataPublisherProto],
15 | frozen=True,
16 | ):
17 | """A data publisher which pushes data to an operational system."""
18 |
19 | def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
20 | """Publish the data"""
21 | publisher, pre, success, err = self.main(*args, **kwargs)
22 | if not pre():
23 | raise ValueError("Preflight check failed")
24 |         try:
25 |             publisher()
26 |         except Exception:
27 |             # Fire the failure hook (if any) before re-raising
28 |             if err:
29 |                 err()
30 |             raise
31 |         else:
32 |             # The success hook fires only when the publisher ran cleanly; a
33 |             # `return` inside `try` would have skipped this `else` branch
34 |             if success:
35 |                 success()
36 |
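37 |
38 | # A minimal usage sketch (illustrative names; assumes Entrypoint accepts
39 | # `name` and `main` like the sibling components):
40 | #
41 | #     def crm_sync() -> DataPublisherProto:
42 | #         return (
43 | #             lambda: print("pushing rows to the CRM"),  # run
44 | #             lambda: True,                              # preflight
45 | #             lambda: print("publish succeeded"),        # success hook
46 | #             None,                                      # no failure hook
47 | #         )
48 | #
49 | #     DataPublisher(name="crm_sync", main=crm_sync)()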
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.9", "3.10", "3.11"]
20 |
21 | steps:
22 | - uses: actions/checkout@v3
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v3
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install '.[dev]'
31 | - name: Test with pytest
32 | run: |
33 | pytest
34 |
--------------------------------------------------------------------------------
/src/cdf/__init__.py:
--------------------------------------------------------------------------------
1 | import cdf.core.configuration as conf
2 | from cdf.core.component import (
3 | DataPipeline,
4 | DataPublisher,
5 | Operation,
6 | Service,
7 | ServiceLevelAgreement,
8 | )
9 | from cdf.core.configuration import (
10 | ConfigResolver,
11 | Request,
12 | map_config_section,
13 | map_config_values,
14 | )
15 | from cdf.core.context import (
16 | get_active_workspace,
17 | invoke,
18 | resolve,
19 | set_active_workspace,
20 | use_workspace,
21 | )
22 | from cdf.core.injector import Dependency, DependencyRegistry
23 | from cdf.core.workspace import Workspace
24 |
25 | __all__ = [
26 | "conf",
27 | "DataPipeline",
28 | "DataPublisher",
29 | "Operation",
30 | "Service",
31 | "ServiceLevelAgreement",
32 | "ConfigResolver",
33 | "Request",
34 | "map_config_section",
35 | "map_config_values",
36 | "Workspace",
37 | "Dependency",
38 | "DependencyRegistry",
39 | "get_active_workspace",
40 | "set_active_workspace",
41 | "resolve",
42 | "invoke",
43 | "use_workspace",
44 | ]
45 |
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/split.py:
--------------------------------------------------------------------------------
1 | """Split feature flag provider."""
2 |
3 | import typing as t
4 |
5 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter
6 |
7 |
8 | class SplitFeatureFlagAdapter(AbstractFeatureFlagAdapter):
9 | """A feature flag adapter that uses Split."""
10 |
11 | def __init__(self, sdk_key: str, **kwargs: t.Any) -> None:
12 | """Initialize the Split feature flags.
13 |
14 | Args:
15 | sdk_key: The SDK key to use for Split.
16 | """
17 | self.sdk_key = sdk_key
18 |
19 | def __repr__(self) -> str:
20 | return f"{type(self).__name__}(sdk_key={self.sdk_key!r})"
21 |
22 | def __str__(self) -> str:
23 | return self.sdk_key
24 |
25 | def get(self, feature_name: str) -> bool:
26 | raise NotImplementedError("This provider is not yet implemented")
27 |
28 | def save(self, feature_name: str, flag: bool) -> None:
29 | raise NotImplementedError("This provider is not yet implemented")
30 |
31 | def get_all_feature_names(self) -> t.List[str]:
32 | raise NotImplementedError("This provider is not yet implemented")
33 |
34 |
35 | __all__ = ["SplitFeatureFlagAdapter"]
36 |
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/launchdarkly.py:
--------------------------------------------------------------------------------
1 | """LaunchDarkly feature flag provider."""
2 |
3 | import typing as t
4 |
5 | from dlt.common.configuration import with_config
6 |
7 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter
8 |
9 |
10 | class LaunchDarklyFeatureFlagAdapter(AbstractFeatureFlagAdapter):
11 | """A feature flag adapter that uses LaunchDarkly."""
12 |
13 | @with_config(sections=("feature_flags",))
14 | def __init__(self, sdk_key: str, **kwargs: t.Any) -> None:
15 | """Initialize the LaunchDarkly feature flags.
16 |
17 | Args:
18 | sdk_key: The SDK key to use for LaunchDarkly.
19 | """
20 | self.sdk_key = sdk_key
21 |
22 | def __repr__(self) -> str:
23 | return f"{type(self).__name__}(sdk_key={self.sdk_key!r})"
24 |
25 | def __str__(self) -> str:
26 | return self.sdk_key
27 |
28 | def get(self, feature_name: str) -> bool:
29 | raise NotImplementedError("This provider is not yet implemented")
30 |
31 | def save(self, feature_name: str, flag: bool) -> None:
32 | raise NotImplementedError("This provider is not yet implemented")
33 |
34 | def get_all_feature_names(self) -> t.List[str]:
35 | raise NotImplementedError("This provider is not yet implemented")
36 |
37 |
38 | __all__ = ["LaunchDarklyFeatureFlagAdapter"]
39 |
--------------------------------------------------------------------------------
/src/cdf/core/component/__init__.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from .base import Component, Entrypoint, ServiceLevelAgreement
4 | from .operation import Operation, OperationProto
5 | from .pipeline import DataPipeline, DataPipelineProto
6 | from .publisher import DataPublisher, DataPublisherProto
7 | from .service import Service, ServiceProto
8 |
9 | __all__ = [
10 | "DataPipeline",
11 | "DataPublisher",
12 | "Operation",
13 | "Service",
14 | "ServiceDef",
15 | "DataPipelineDef",
16 | "DataPublisherDef",
17 | "OperationDef",
18 | "TComponent",
19 |     "TComponentDef",
20 | "ServiceLevelAgreement",
21 | ]
22 |
23 | ServiceDef = t.Union[
24 | Service,
25 | t.Callable[..., ServiceProto],
26 | t.Dict[str, t.Any],
27 | ]
28 | DataPipelineDef = t.Union[
29 | DataPipeline,
30 | t.Callable[..., DataPipelineProto],
31 | t.Dict[str, t.Any],
32 | ]
33 | DataPublisherDef = t.Union[
34 | DataPublisher,
35 | t.Callable[..., DataPublisherProto],
36 | t.Dict[str, t.Any],
37 | ]
38 | OperationDef = t.Union[
39 | Operation,
40 | t.Callable[..., OperationProto],
41 | t.Dict[str, t.Any],
42 | ]
43 |
44 | TComponent = t.TypeVar("TComponent", bound=t.Union[Component, Entrypoint])
45 | TComponentDef = t.TypeVar(
46 | "TComponentDef",
47 | ServiceDef,
48 | DataPipelineDef,
49 | DataPublisherDef,
50 | OperationDef,
51 | )
52 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
 1 | <!-- docsify entry point: the original markup was lost in extraction; only
 2 |      the page title ("Document") and a configuration script block remain -->
--------------------------------------------------------------------------------
/docs/overview.md:
--------------------------------------------------------------------------------
1 | # cdf
2 |
 3 | CDF (Continuous Data Framework) is an integrated framework designed to manage
 4 | data across the entire lifecycle, from ingestion through transformation to
 5 | publishing. It is built on top of two open-source projects, `sqlmesh` and
 6 | `dlt`, providing a unified interface for complex data operations. CDF
 7 | simplifies data engineering workflows, offering scalable solutions from small
 8 | to large projects through an opinionated project structure that supports both
 9 | multi-workspace and single-workspace layouts.
10 |
11 | We place a heavy emphasis on the inner loop of data engineering. We believe
12 | the most important part of data engineering is the ability to rapidly iterate
13 | on data, and we have designed CDF to make that as easy as possible. We achieve
14 | this through a combination of dlt's simplicity in authoring pipelines, dynamic
15 | parameterization of sinks, and developer utilities such as `head` and
16 | `discover`. We streamline the scaffolding of new components and view a
17 | workspace as a home for business-specific components: pipelines,
18 | transformations, publishers, scripts, and notebooks. Spend less time on
19 | boilerplate and on consolidating custom code into perfect collections of
20 | software engineering best practices, and spend much more time on point
21 | solutions that solve your business problems. That's the benefit of
22 | opinionation, and we offer it in a way that is flexible and extensible.
23 |
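24 | As a small taste of the Python API, a workspace bundles configuration with
25 | dependency definitions and can inject both into any callable. A minimal
26 | sketch, adapted from the test suite (names and values are illustrative):
27 |
28 | ```python
29 | import cdf
30 |
31 | workspace = cdf.Workspace(
32 |     name="data-team",
33 |     configuration_sources=[{"sfdc": {"username": "abc"}}],
34 |     service_definitions=[
35 |         cdf.Service(
36 |             name="prod_bigquery",
37 |             main=cdf.Dependency.instance("dwh-123"),
38 |             owner="DataTeam",
39 |         ),
40 |     ],
41 | )
42 |
43 | def job(prod_bigquery: str) -> None:
44 |     print(f"connected to {prod_bigquery}")
45 |
46 | workspace.invoke(job)  # prints "connected to dwh-123"
47 | ```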
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "python-cdf"
3 | version = "0.9.4"
4 | description = "A framework to manage data continuously"
5 | authors = [
6 | { name = "z3z1ma", email = "butler.alex2010@gmail.com" },
7 | ]
8 | dependencies = [
9 | # We will find a good version range for these eventually, but allowing them to float
10 | # and be pinned by the user for now is far more useful.
11 | "sqlmesh",
12 | "dlt[duckdb]",
13 | "duckdb",
14 | # The following deps have well-defined version ranges
15 | "mysql-mimic>=2,<3",
16 | "harness-featureflags>=1.2.0,<1.6.1",
17 | "python-dotenv>=1,<2",
18 | "pex>=2.1.100,<2.2.0",
19 | "pydantic>=2.5.0,<3",
20 | "psutil~=5.9.0",
21 | "typing-extensions>=4,<5",
22 | "fsspec>=2022",
23 | "dynaconf>=3,<4",
24 | "eval_type_backport~=0.1.3; python_version<'3.10'",
25 | ]
26 | requires-python = ">=3.9,<3.13"
27 | readme = "README.md"
28 | license.file = "LICENSE"
29 |
30 | [tool.poetry]
31 | packages = [
32 | { include = "cdf", from = "src" }
33 | ]
34 |
35 | [project.optional-dependencies]
36 | dev = [
37 | # "poetry @ git+https://github.com/radoering/poetry.git@pep621-support",
38 | "pytest>=7.4.3",
39 | "pytest-mock>=3.12.0",
40 | "pydoc-markdown>4",
41 | ]
42 |
43 | [build-system]
44 | requires = ["poetry-core@ git+https://github.com/radoering/poetry-core.git@pep621-support"]
45 | build-backend = "poetry.core.masonry.api"
46 |
47 | [tool.hatch.metadata]
48 | allow-direct-references = true
49 |
50 | [tool.pyright]
51 | include = ["src"]
52 | exclude = ["examples/", "docs/"]
53 | ignore = ["src/builtin"]
54 | reportPrivateImportUsage = false
55 |
--------------------------------------------------------------------------------
/tests/builtin/test_filters.py:
--------------------------------------------------------------------------------
1 | from cdf.builtin.filters import (
2 | eq,
3 | gt,
4 | gte,
5 | in_list,
6 | lt,
7 | lte,
8 | ne,
9 | not_empty,
10 | not_in_list,
11 | not_null,
12 | )
13 |
14 |
15 | def test_eq():
16 | assert eq("name", "Alice")({"name": "Alice"})
17 |
18 |
19 | def test_ne():
20 | assert ne("name", "Alice")({"name": "Bob"})
21 |
22 |
23 | def test_gt():
24 | assert gt("age", 30)({"age": 35})
25 |
26 |
27 | def test_gte():
28 | assert gte("age", 30)({"age": 30})
29 |
30 |
31 | def test_lt():
32 | assert lt("age", 30)({"age": 25})
33 |
34 |
35 | def test_lte():
36 | assert lte("age", 30)({"age": 30})
37 |
38 |
39 | def test_in_list():
40 | assert in_list("name", ["Alice", "Bob"])({"name": "Alice"})
41 |
42 |
43 | def test_not_in_list():
44 | assert not_in_list("name", ["Alice", "Bob"])({"name": "Charlie"})
45 |
46 |
47 | def test_not_empty():
48 | assert not_empty("name")({"name": "Alice"})
49 | assert not_empty("name")({"name": 0})
50 | assert not_empty("name")({"name": False})
51 |
52 | assert not_empty("name")({"name": ""}) is False
53 | assert not_empty("name")({"name": []}) is False
54 | assert not_empty("name")({"name": {}}) is False
55 | assert not_empty("name")({"name": None}) is False
56 |
57 |
58 | def test_not_null():
59 | assert not_null("name")({"name": "Alice"})
60 | assert not_null("name")({"name": 0})
61 | assert not_null("name")({"name": False})
62 | assert not_null("name")({"name": []})
63 | assert not_null("name")({"name": {}})
64 |
65 | assert not_null("name")({"name": None}) is False
66 | assert not_null("name")({"whatever": 1}) is False
67 |
--------------------------------------------------------------------------------
/tests/builtin/test_metrics.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 |
3 | import pytest
4 |
5 | from cdf.builtin.metrics import (
6 | avg_value,
7 | count,
8 | max_value,
9 | median_value,
10 | min_value,
11 | mode_value,
12 | stdev_value,
13 | sum_value,
14 | unique,
15 | variance_value,
16 | )
17 |
18 |
19 | @pytest.fixture
20 | def data():
21 | return [
22 | {"name": "Alice", "age": 25},
23 | {"name": "Bob", "age": 30},
24 | {"name": "Charlie", "age": 35},
25 | {"name": "David", "age": 40},
26 | {"name": "Eve", "age": 45},
27 | {"name": "Frank", "age": 50},
28 | {"name": "Alice", "age": 25},
29 | {"name": "Bob"},
30 | ]
31 |
32 |
33 | def test_count(data):
34 | assert reduce(lambda metric, item: count(item, metric), data, 0) == 8
35 |
36 |
37 | def test_unique(data):
38 | func = unique("name")
39 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 6
40 |
41 |
42 | def test_max_value(data):
43 | func = max_value("age")
44 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 50
45 |
46 |
47 | def test_min_value(data):
48 | func = min_value("age")
49 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 25
50 |
51 |
52 | def test_sum_value(data):
53 | func = sum_value("age")
54 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 250
55 |
56 |
57 | def test_avg_value(data):
58 | func = avg_value("age")
59 | assert (
60 | reduce(lambda metric, item: func(item, metric), data, 0) == 35.714285714285715
61 | )
62 |
63 |
64 | def test_median_value(data):
65 | func = median_value("age")
66 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 35
67 |
68 |
69 | def test_variance_value(data):
70 | func = variance_value("age")
71 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 81.63265306122435
72 |
73 |
74 | def test_stdev_value(data):
75 | func = stdev_value("age")
76 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 9.035079029052504
77 |
78 |
79 | def test_mode_value(data):
80 | func = mode_value("age")
81 | assert reduce(lambda metric, item: func(item, metric), data, 0) == 25
82 |
--------------------------------------------------------------------------------
/tests/types/test_monads.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import time
3 | import typing as t
4 | from collections import defaultdict
5 |
6 | import requests
7 |
8 | from cdf.types.monads import State, promise, state, to_state
9 |
10 | threadtime = defaultdict(list)
11 |
12 | T = t.TypeVar("T")
13 |
14 |
15 | @promise
16 | def fetch(url: str) -> requests.Response:
17 | tid = threading.get_ident()
18 | threadtime[tid].append(time.perf_counter())
19 | resp = requests.get(url)
20 | return resp
21 |
22 |
23 | @promise
24 | def track(v: T) -> T:
25 | tid = threading.get_ident()
26 | threadtime[tid].append(time.perf_counter())
27 | return v
28 |
29 |
30 | @promise
31 | def num_abilities(resp: requests.Response) -> int:
32 | data = resp.json()
33 | i = len(data["abilities"])
34 | tid = threading.get_ident()
35 | threadtime[tid].append(time.perf_counter())
36 | return i
37 |
38 |
39 | def test_fetch():
40 | futs = []
41 | for i in range(5):
42 | print(f"Starting iteration {i}")
43 | futs.append(
44 | track("https://pokeapi.co/api/v2/pokemon/ditto")
45 | >> fetch
46 | >> track
47 | >> num_abilities
48 | >> track
49 | )
50 |
51 | for fut in futs:
52 | print(fut.unwrap())
53 |
54 |
55 | def test_state():
56 | state_x = to_state(1) # 1 is the value for the computations NOT the state
57 |
58 | @state
59 | def add_one(x: int) -> int:
60 | return x + 1
61 |
62 | def print_state(x: int):
63 | def _print(state: int):
64 | print(state)
65 | return x, state
66 |
67 | return _print
68 |
69 | add_one(State(print_state(1)))
70 |
71 | def process_int(x: int) -> State[list[int], int]:
72 | """Process an integer, tracking unique values"""
73 |
74 | def process(state: list[int]) -> t.Tuple[int, list[int]]:
75 | nonlocal x
76 | x += 1
77 | if x in state:
78 | return x, state
79 | state.append(x)
80 | return x, state
81 |
82 | return State(process)
83 |
84 | state_y = state_x >> process_int >> add_one >> process_int >> add_one
85 | x, y = state_y.run_state([]) # type: ignore
86 | assert x == 5
87 | assert y == [2, 4]
88 |
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/__init__.py:
--------------------------------------------------------------------------------
1 | """Feature flag providers implement a uniform interface and are wrapped by an adapter.
2 |
3 | The adapter is responsible for loading the correct provider and applying the feature flags within
4 | various contexts in cdf. This allows for a clean separation of concerns and makes it easy to
5 | implement new feature flag providers in the future.
6 | """
7 |
8 | import typing as t
9 |
10 | import dlt
11 | from dlt.common.configuration import with_config
12 |
13 | from cdf.integrations.feature_flag.base import AbstractFeatureFlagAdapter
14 | from cdf.integrations.feature_flag.file import FilesystemFeatureFlagAdapter
15 | from cdf.integrations.feature_flag.harness import HarnessFeatureFlagAdapter
16 | from cdf.integrations.feature_flag.launchdarkly import LaunchDarklyFeatureFlagAdapter
17 | from cdf.integrations.feature_flag.noop import NoopFeatureFlagAdapter
18 | from cdf.integrations.feature_flag.split import SplitFeatureFlagAdapter
19 | from cdf.types import M
20 |
21 | ADAPTERS: t.Dict[str, t.Type[AbstractFeatureFlagAdapter]] = {
22 | "filesystem": FilesystemFeatureFlagAdapter,
23 | "harness": HarnessFeatureFlagAdapter,
24 | "launchdarkly": LaunchDarklyFeatureFlagAdapter,
25 | "split": SplitFeatureFlagAdapter,
26 | "noop": NoopFeatureFlagAdapter,
27 | }
28 | """Feature flag provider adapters classes by name."""
29 |
30 |
31 | @with_config(sections=("feature_flags",))
32 | def get_feature_flag_adapter_cls(
33 | provider: str = dlt.config.value,
34 | ) -> M.Result[t.Type[AbstractFeatureFlagAdapter], Exception]:
35 | """Get a feature flag adapter by name.
36 |
37 | Args:
38 |         provider: The name of the feature flag adapter. Must be one of the
39 |             keys in ADAPTERS.
40 |
41 | Returns:
42 | The feature flag adapter.
43 | """
44 | try:
45 | if provider not in ADAPTERS:
46 | raise KeyError(
47 | f"Unknown provider: {provider}. Available providers: {', '.join(ADAPTERS.keys())}"
48 | )
49 | return M.ok(ADAPTERS[provider])
50 | except KeyError as e:
51 |         # The error message above lists the available providers for the caller
52 | return M.error(e)
53 | except Exception as e:
54 | return M.error(e)
55 |
56 |
57 | __all__ = [
58 | "ADAPTERS",
59 | "AbstractFeatureFlagAdapter",
60 | "FilesystemFeatureFlagAdapter",
61 | "HarnessFeatureFlagAdapter",
62 | "LaunchDarklyFeatureFlagAdapter",
63 | "NoopFeatureFlagAdapter",
64 | "SplitFeatureFlagAdapter",
65 | "get_feature_flag_adapter_cls",
66 | ]
67 |
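68 |
69 | # A minimal sketch of resolving an adapter class (assumes the returned
70 | # M.Result exposes `unwrap()` like the other monads in cdf.types):
71 | #
72 | #     adapter_cls = get_feature_flag_adapter_cls("noop").unwrap()
73 | #     adapter = adapter_cls()
74 | #     assert bool(adapter.get("any_flag"))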
--------------------------------------------------------------------------------
/src/cdf/proxy/mysql.py:
--------------------------------------------------------------------------------
1 | """A MySQL proxy server which uses SQLMesh to execute queries."""
2 |
3 | import typing as t
4 | import asyncio
5 | import logging
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import sqlmesh
10 | from mysql_mimic import MysqlServer, Session
11 | from mysql_mimic.server import logger
12 | from sqlglot import exp
13 |
14 |
15 | async def file_watcher(context: sqlmesh.Context) -> None:
16 | """Watch for changes in the workspace and refresh the context."""
17 | while True:
18 | await asyncio.sleep(5.0)
19 | await asyncio.to_thread(context.refresh)
20 |
21 |
22 | class SQLMeshSession(Session):
23 | """A session for the MySQL proxy server which uses SQLMesh."""
24 |
25 | context: sqlmesh.Context
26 |
27 | async def query(
28 | self, expression: exp.Expression, sql: str, attrs: t.Dict[str, str]
29 |     ) -> t.Tuple[t.Tuple[t.Tuple[t.Any, ...], ...], t.List[str]]:
30 | """Execute a query."""
31 | tables = list(expression.find_all(exp.Table))
32 | if any((table.db, table.name) == ("__semantic", "__table") for table in tables):
33 | expression = self.context.rewrite(sql)
34 | logger.info("Compiled semantic expression!")
35 | logger.info(expression.sql(self.context.default_dialect))
36 | df = self.context.fetchdf(expression)
37 | logger.debug(df)
38 | df.replace({np.nan: None}, inplace=True)
39 | return tuple(df.itertuples(index=False)), list(df.columns)
40 |
41 | async def schema(self) -> t.Dict[str, t.Dict[str, t.Dict[str, str]]]:
42 | """Get the schema of the database."""
43 | schema = defaultdict(dict)
44 | for model in self.context.models.values():
45 | fqn = model.fully_qualified_table
46 | if model.columns_to_types and all(
47 | typ is not None for typ in model.columns_to_types.values()
48 | ):
49 | schema[fqn.db][fqn.name] = model.columns_to_types
50 | return schema
51 |
52 |
53 | async def run_mysql_proxy(context: sqlmesh.Context) -> None:
54 | """Run the MySQL proxy server."""
55 |
56 | logging.basicConfig(level=logging.DEBUG)
57 | server = MysqlServer(
58 | session_factory=type(
59 | "BoundSQLMeshSession",
60 | (SQLMeshSession,),
61 | {"context": context},
62 | )
63 | )
64 |     watcher = asyncio.create_task(file_watcher(context))  # noqa: F841 -- hold a reference so the task is not GC'd
65 | try:
66 | await server.serve_forever()
67 | except asyncio.CancelledError:
68 | await server.wait_closed()
69 |
70 |
71 | __all__ = ["run_mysql_proxy"]
72 |
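73 |
74 | # A minimal sketch of serving an existing SQLMesh project over the MySQL
75 | # protocol (project path is illustrative):
76 | #
77 | #     import asyncio
78 | #     import sqlmesh
79 | #
80 | #     context = sqlmesh.Context(paths=["./my_project"])
81 | #     asyncio.run(run_mysql_proxy(context))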
--------------------------------------------------------------------------------
/tests/core/test_workspace.py:
--------------------------------------------------------------------------------
1 | import cdf.core.component as cmp
2 | import cdf.core.configuration as conf
3 | import cdf.core.injector as injector
4 | from cdf.core.workspace import Workspace
5 |
6 |
7 | def test_workspace():
8 | import dlt
9 |
10 | @dlt.source
11 | def test_source(a: int, prod_bigquery: str):
12 | @dlt.resource
13 | def test_resource():
14 | yield from [{"a": a, "prod_bigquery": prod_bigquery}]
15 |
16 | return [test_resource]
17 |
18 | # Define a workspace
19 | datateam = Workspace(
20 | name="data-team",
21 | version="0.1.1",
22 | configuration_sources=[
23 | # DATATEAM_CONFIG,
24 | {
25 | "sfdc": {"username": "abc"},
26 | "bigquery": {"project_id": ...},
27 | },
28 | ],
29 | service_definitions=[
30 | cmp.Service(
31 | name="a",
32 | main=injector.Dependency.instance(1),
33 | owner="Alex",
34 | description="A secret number",
35 | sla=cmp.ServiceLevelAgreement.CRITICAL,
36 | ),
37 | cmp.Service(
38 | name="b",
39 | main=injector.Dependency.prototype(lambda a: a + 1 * 5 / 10),
40 | owner="Alex",
41 | ),
42 | cmp.Service(
43 | name="prod_bigquery",
44 | main=injector.Dependency.instance("dwh-123"),
45 | owner="DataTeam",
46 | ),
47 | cmp.Service(
48 | name="sfdc",
49 | main=injector.Dependency(
50 | factory=lambda username: f"https://sfdc.com/{username}",
51 | conf_spec=("sfdc",),
52 | ),
53 | owner="RevOps",
54 | ),
55 | ],
56 | )
57 |
58 | @conf.map_config_values(secret_number="a.b.c")
59 | def c(secret_number: int, sfdc: str) -> int:
60 | print(f"SFDC: {sfdc=}")
61 | return secret_number * 10
62 |
63 | # Imperatively add dependencies or config if needed
64 | datateam.container.add_from_dependency(injector.Dependency.prototype(c))
65 | datateam.conf_resolver.import_source({"a.b.c": 10})
66 |
67 | def source_a(a: int, prod_bigquery: str):
68 | print(f"Source A: {a=}, {prod_bigquery=}")
69 |
70 | # Some interface examples
71 | assert datateam.name == "data-team"
72 | datateam.invoke(source_a)
73 | assert datateam.conf_resolver["sfdc.username"] == "abc"
74 | assert datateam.container.resolve_or_raise("sfdc") == "https://sfdc.com/abc"
75 | assert datateam.invoke(c) == 100
76 | assert list(datateam.invoke(test_source)) == [{"a": 1, "prod_bigquery": "dwh-123"}]
77 |
--------------------------------------------------------------------------------
/src/cdf/proxy/planner.py:
--------------------------------------------------------------------------------
1 | """An http server which executed a plan which is a pickled pydantic model
2 |
3 | This is purely a POC. It will be replaced by a more robust solution in the future
4 | using flask or fastapi. It will always be designed such that input must be
5 | trusted. In an environment where the input is not trusted, the server should
6 | never be exposed to the internet. It should always be behind a firewall and
7 | only accessible by trusted clients.
8 | """
9 |
10 | import pickle
11 | import typing as t
12 | import http.server
13 | import socketserver
14 | import traceback
15 | import logging
16 | import uuid
17 | import json
18 | import io
19 | from contextlib import redirect_stdout, redirect_stderr
20 |
21 | import sqlmesh
22 |
23 |
24 | def run_plan_server(port: int, context: sqlmesh.Context) -> None:
25 | """Listen on a port and execute plans."""
26 |
27 | # TODO: move this
28 | logging.basicConfig(level=logging.DEBUG)
29 |
30 | def _plan(plan: t.Any) -> t.Any:
31 | """Run a plan"""
32 | stdout = io.StringIO()
33 | stderr = io.StringIO()
34 | with redirect_stdout(stdout), redirect_stderr(stderr):
35 | context.apply(plan)
36 | return {
37 | "stdout": stdout.getvalue(),
38 | "stderr": stderr.getvalue(),
39 | "execution_id": uuid.uuid4().hex,
40 | }
41 |
42 | class Handler(http.server.SimpleHTTPRequestHandler):
43 | def do_GET(self) -> None:
44 | """Ping the server"""
45 | self.send_response(200)
46 | self.send_header("Content-type", "text/plain")
47 | self.end_headers()
48 | self.wfile.write(b"Pong")
49 |
50 | def do_POST(self) -> None:
51 | """Run the plan"""
52 | content_length = int(self.headers["Content-Length"])
53 | ser_plan = self.rfile.read(content_length)
54 | try:
55 | plan = pickle.loads(ser_plan)
56 | resp = _plan(plan)
57 | self.send_response(200)
58 | self.send_header("Content-type", "application/json")
59 | self.end_headers()
60 | self.wfile.write(json.dumps(resp).encode())
61 | except Exception as e:
62 | self.send_response(500)
63 | self.send_header("Content-type", "text/plain")
64 | self.end_headers()
65 | self.wfile.write(str(e).encode())
66 | self.wfile.write(b"\n")
67 | self.wfile.write(traceback.format_exc().encode())
68 |
69 | with socketserver.TCPServer(("", port), Handler) as httpd:
70 | logging.info("serving at port %s", port)
71 | try:
72 | httpd.serve_forever()
73 | except KeyboardInterrupt:
74 | pass
75 |
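76 | # A minimal sketch of a trusted client posting a pickled plan (assumes a
77 | # sqlmesh Context `context` built elsewhere; host and port are illustrative):
78 | #
79 | #     import pickle
80 | #     import urllib.request
81 | #
82 | #     plan = context.plan()
83 | #     req = urllib.request.Request(
84 | #         "http://localhost:8000", data=pickle.dumps(plan), method="POST"
85 | #     )
86 | #     print(urllib.request.urlopen(req).read().decode())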
--------------------------------------------------------------------------------
/src/cdf/core/component/pipeline.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import typing as t
3 |
4 | from .base import Entrypoint
5 |
6 | if t.TYPE_CHECKING:
7 | from dlt.common.destination import Destination as DltDestination
8 | from dlt.common.pipeline import LoadInfo
9 | from dlt.pipeline.pipeline import Pipeline as DltPipeline
10 |
11 | _GRN = "\033[32;1m"
12 | _YLW = "\033[33;1m"
13 | _RED = "\033[31;1m"
14 | _CLR = "\033[0m"
15 |
16 | TEST_RESULT_MAP = {
17 | None: f"{_YLW}SKIP{_CLR}",
18 | True: f"{_GRN}PASS{_CLR}",
19 | False: f"{_RED}FAIL{_CLR}",
20 | }
21 |
22 | DataPipelineProto = t.Tuple[
23 | "DltPipeline",
24 | t.Union[
25 | t.Callable[..., "LoadInfo"],
26 | t.Callable[..., t.Iterator["LoadInfo"]],
27 | ], # run
28 | t.Sequence[t.Callable[..., t.Optional[t.Union[bool, t.Tuple[bool, str]]]]], # tests
29 | ]
30 |
31 |
32 | class DataPipeline(
33 | Entrypoint[DataPipelineProto],
34 | frozen=True,
35 | ):
36 | """A data pipeline which loads data from a source to a destination."""
37 |
38 | def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.List["LoadInfo"]:
39 | """Run the data pipeline"""
40 | _, runner, _ = self.main(*args, **kwargs)
41 | if inspect.isgeneratorfunction(runner):
42 | return list(runner())
43 | return [t.cast("LoadInfo", runner())]
44 |
45 | run = __call__
46 |
47 | def unwrap(self) -> "DltPipeline":
48 | """Get the dlt pipeline object."""
49 | pipeline, _, _ = self.main()
50 | return pipeline
51 |
52 | def get_schemas(self, destination: t.Optional["DltDestination"] = None):
53 | """Get the schemas for the pipeline."""
54 | pipeline = self.unwrap()
55 | pipeline.sync_destination(destination=destination)
56 | return pipeline.schemas
57 |
58 | def run_tests(self) -> None:
59 |         """Run the integration tests for the pipeline."""
60 | _, _, tests = self.main()
61 | if not tests:
62 | raise ValueError("No tests found for pipeline")
63 | tpl = "[{nr}/{tot}] {message} ({state})"
64 | tot = len(tests)
65 | for nr, test in enumerate(tests, 1):
66 | result_struct = test()
67 | if isinstance(result_struct, bool) or result_struct is None:
68 | result, reason = result_struct, "No message"
69 | elif isinstance(result_struct, tuple):
70 | result, reason = result_struct
71 | else:
72 | raise ValueError(
73 | f"Invalid return type `{type(result_struct)}`, expected none, bool, or tuple(bool, str)"
74 | )
75 | if result not in TEST_RESULT_MAP:
76 | raise ValueError(f"Invalid return status for test: `{result}`")
77 | print(
78 | tpl.format(
79 | nr=nr, tot=tot, state=TEST_RESULT_MAP[result], message=reason
80 | )
81 | )
82 |
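83 |
84 | # A minimal sketch of the protocol: `main` returns (pipeline, run, tests) per
85 | # DataPipelineProto (illustrative; assumes Entrypoint accepts `name`/`main`):
86 | #
87 | #     import dlt
88 | #
89 | #     def demo() -> DataPipelineProto:
90 | #         pipeline = dlt.pipeline("demo", destination="duckdb")
91 | #         run = lambda: pipeline.run([{"a": 1}], table_name="demo")
92 | #         tests = [lambda: (True, "loads a row")]
93 | #         return pipeline, run, tests
94 | #
95 | #     DataPipeline(name="demo", main=demo).run_tests()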
--------------------------------------------------------------------------------
/src/cdf/builtin/filters.py:
--------------------------------------------------------------------------------
1 | """Built-in filters for CDF
2 |
3 | They can be referenced via absolute import paths in a pipeline spec.
4 | """
5 |
6 | import typing as t
7 |
8 | FilterFunc = t.Callable[[t.Any], bool]
9 |
10 |
11 | def not_empty(key: str) -> FilterFunc:
12 | """Filters out items where a key is empty"""
13 |
14 | def _not_empty(item: t.Any) -> bool:
15 | if item.get(key) is None:
16 | return False
17 | if isinstance(item[key], str):
18 | return item[key].strip() != ""
19 | if isinstance(item[key], list):
20 | return len(item[key]) > 0
21 | if isinstance(item[key], dict):
22 | return len(item[key]) > 0
23 | return True
24 |
25 | return _not_empty
26 |
27 |
28 | def not_null(key: str) -> FilterFunc:
29 | """Filters out items where a key is null"""
30 |
31 | def _not_null(item: t.Any) -> bool:
32 | return item.get(key) is not None
33 |
34 | return _not_null
35 |
36 |
37 | def gt(key: str, value: t.Any) -> FilterFunc:
38 |     """Keeps items where the key is greater than the value"""
39 |
40 | def _greater_than(item: t.Any) -> bool:
41 | return item[key] > value
42 |
43 | return _greater_than
44 |
45 |
46 | def lt(key: str, value: t.Any) -> FilterFunc:
47 |     """Keeps items where the key is less than the value"""
48 |
49 | def _less_than(item: t.Any) -> bool:
50 | return item[key] < value
51 |
52 | return _less_than
53 |
54 |
55 | def gte(key: str, value: t.Any) -> FilterFunc:
56 |     """Keeps items where the key is greater than or equal to the value"""
57 |
58 | def _greater_than_or_equal(item: t.Any) -> bool:
59 | return item[key] >= value
60 |
61 | return _greater_than_or_equal
62 |
63 |
64 | def lte(key: str, value: t.Any) -> FilterFunc:
65 |     """Keeps items where the key is less than or equal to the value"""
66 |
67 | def _less_than_or_equal(item: t.Any) -> bool:
68 | return item[key] <= value
69 |
70 | return _less_than_or_equal
71 |
72 |
73 | def eq(key: str, value: t.Any) -> FilterFunc:
74 |     """Keeps items where the key is equal to the value"""
75 |
76 | def _equal(item: t.Any) -> bool:
77 | return item[key] == value
78 |
79 | return _equal
80 |
81 |
82 | def ne(key: str, value: t.Any) -> FilterFunc:
83 |     """Keeps items where the key is not equal to the value"""
84 |
85 | def _not_equal(item: t.Any) -> bool:
86 | return item[key] != value
87 |
88 | return _not_equal
89 |
90 |
91 | def in_list(key: str, value: t.List[str]) -> FilterFunc:
92 |     """Keeps items where the key is in the list of values"""
93 |
94 | def _in_list(item: t.Any) -> bool:
95 | return item[key] in value
96 |
97 | return _in_list
98 |
99 |
100 | def not_in_list(key: str, value: t.List[str]) -> FilterFunc:
101 |     """Keeps items where the key is not in the list of values"""
102 |
103 | def _not_in_list(item: t.Any) -> bool:
104 | return item[key] not in value
105 |
106 | return _not_in_list
107 |
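108 |
109 | # Example: each factory returns a predicate suitable for the built-in
110 | # `filter` (or a dlt resource's `add_filter`):
111 | #
112 | #     rows = [{"age": 25}, {"age": 41}]
113 | #     list(filter(gt("age", 30), rows))  # -> [{"age": 41}]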
--------------------------------------------------------------------------------
/tests/core/injector/test_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from cdf.core.configuration import ConfigResolver
7 |
8 |
9 | def test_apply_converters():
10 | with patch.dict("os.environ", {}):
11 | os.environ["CDF_TEST"] = "1"
12 | assert ConfigResolver.apply_converters("$CDF_TEST") == "1"
13 | assert ConfigResolver.apply_converters("@int ${CDF_TEST}") == 1
14 | os.environ["CDF_BOOL"] = "true"
15 | assert ConfigResolver.apply_converters("@bool ${CDF_BOOL}") is True
16 | os.environ["CDF_FLOAT"] = "3.14"
17 | assert ConfigResolver.apply_converters("@float ${CDF_FLOAT}") == 3.14
18 | os.environ["CDF_JSON"] = '{"key": "value"}'
19 | assert ConfigResolver.apply_converters("@json ${CDF_JSON}") == {"key": "value"}
20 | os.environ["CDF_PATH"] = "tests/v2/test_config.py"
21 | assert ConfigResolver.apply_converters("@path ${CDF_PATH}") == os.path.abspath(
22 | "tests/v2/test_config.py"
23 | )
24 | os.environ["CDF_DICT"] = "{'key': 'value'}"
25 | assert ConfigResolver.apply_converters("@dict ${CDF_DICT}") == {"key": "value"}
26 | os.environ["CDF_LIST"] = "['key', 'value']"
27 | assert ConfigResolver.apply_converters("@list ${CDF_LIST}") == ["key", "value"]
28 | os.environ["CDF_TUPLE"] = "('key', 'value')"
29 | assert ConfigResolver.apply_converters("@tuple ${CDF_TUPLE}") == (
30 | "key",
31 | "value",
32 | )
33 | os.environ["CDF_SET"] = "{'key', 'value'}"
34 | assert ConfigResolver.apply_converters("@set ${CDF_SET}") == {"key", "value"}
35 |
36 | with pytest.raises(ValueError):
37 | ConfigResolver.apply_converters("@unknown_converter idk")
38 | with pytest.raises(ValueError):
39 | ConfigResolver.apply_converters("@int something")
40 |
41 | assert ConfigResolver.apply_converters("no conversion") == "no conversion"
42 |
43 |
44 | def test_config_resolver():
45 | os.environ["CDF_TEST"] = "1"
46 | resolver = ConfigResolver(
47 | {
48 | "main_api": {
49 | "user": "someone",
50 | "password": "secret",
51 | "database": "test",
52 | },
53 | "db_1": "@int ${CDF_TEST}",
54 | "db_2": "@resolve main_api",
55 | }
56 | )
57 |
58 | assert resolver["main_api"] == {
59 | "user": "someone",
60 | "password": "secret",
61 | "database": "test",
62 | }
63 | assert resolver["db_1"] == 1
64 | assert resolver["main_api"] == resolver["db_2"]
65 | resolver.import_source({"db_1": 2})
66 | assert resolver["db_1"] == 2
67 |
68 | @ConfigResolver.map_values(db_1="db_1", db_2="db_2")
69 | def foo(db_1: int, db_2: dict):
70 | return db_1, db_2
71 |
72 | foo_configured = resolver.resolve_defaults(foo)
73 | assert foo_configured() == (
74 | 2,
75 | {"user": "someone", "password": "secret", "database": "test"},
76 | )
77 |
78 | @ConfigResolver.map_values(
79 | user="main_api.user", password="main_api.password", database="main_api.database"
80 | )
81 | def bar(user: str, password: str, database: str):
82 | return user, password, database
83 |
84 | bar_configured = resolver.resolve_defaults(bar)
85 | assert bar_configured() == ("someone", "secret", "test")
86 |
87 | assert "main_api" in resolver
88 |
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/file.py:
--------------------------------------------------------------------------------
1 | """File-based feature flag provider."""
2 |
3 | import json
4 | import logging
5 | import typing as t
6 | from collections import defaultdict
7 | from threading import Lock
8 |
9 | import dlt
10 | import fsspec
11 | from dlt.common.configuration import with_config
12 |
13 | from cdf.integrations.feature_flag.base import (
14 | AbstractFeatureFlagAdapter,
15 | FlagAdapterResponse,
16 | )
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class FilesystemFeatureFlagAdapter(AbstractFeatureFlagAdapter):
22 | """A feature flag adapter that uses the filesystem."""
23 |
24 | _LOCK = defaultdict(Lock)
25 |
26 | __cdf_resolve__ = ("feature_flags",)
27 |
28 | @with_config(sections=("feature_flags",))
29 | def __init__(
30 | self,
31 | filesystem: fsspec.AbstractFileSystem,
32 | filename: str = dlt.config.value,
33 | **kwargs: t.Any,
34 | ) -> None:
35 | """Initialize the filesystem feature flags.
36 |
37 | Args:
38 | filesystem: The filesystem to use.
39 | filename: The filename to use for the feature flags.
40 | """
41 | self.filename = filename
42 | self.filesystem = filesystem
43 | self.__flags: t.Optional[t.Dict[str, FlagAdapterResponse]] = None
44 |
45 | def __repr__(self) -> str:
46 | return f"{type(self).__name__}(filename={self.filename!r})"
47 |
48 | def __str__(self) -> str:
49 | return self.filename
50 |
51 | def _read(self) -> t.Dict[str, FlagAdapterResponse]:
52 | """Read the feature flags from the filesystem."""
53 | logger.info("Reading feature flags from %s", self.filename)
54 | if not self.filesystem.exists(self.filename):
55 | flags = {}
56 | else:
57 | with self.filesystem.open(self.filename) as file:
58 | flags = json.load(file)
59 | return {k: FlagAdapterResponse.from_bool(v) for k, v in flags.items()}
60 |
61 | def _commit(self) -> None:
62 | """Commit the feature flags to the filesystem."""
63 | logger.info("Committing feature flags to %s", self.filename)
64 | with (
65 | self._LOCK[self.filename],
66 | self.filesystem.open(self.filename, "w") as file,
67 | ):
68 | json.dump({k: v.to_bool() for k, v in self._flags.items()}, file, indent=2)
69 |
70 | @property
71 | def _flags(self) -> t.Dict[str, FlagAdapterResponse]:
72 | """Get the feature flags."""
73 | if self.__flags is None:
74 | self.__flags = self._read()
75 | return t.cast(t.Dict[str, FlagAdapterResponse], self.__flags)
76 |
77 | def get(self, feature_name: str) -> FlagAdapterResponse:
78 | """Get a feature flag.
79 |
80 | Args:
81 | feature_name: The name of the feature flag.
82 |
83 | Returns:
84 | The feature flag.
85 | """
86 | return self._flags.get(feature_name, FlagAdapterResponse.NOT_FOUND)
87 |
88 | def get_all_feature_names(self) -> t.List[str]:
89 | """Get all feature flag names.
90 |
91 | Returns:
92 | The feature flag names.
93 | """
94 | return list(self._flags.keys())
95 |
96 | def save(self, feature_name: str, flag: bool) -> None:
97 | """Save a feature flag.
98 |
99 | Args:
100 | feature_name: The name of the feature flag.
101 | flag: The value of the feature flag.
102 | """
103 | self._flags[feature_name] = FlagAdapterResponse.from_bool(flag)
104 | self._commit()
105 |
106 | def save_many(self, flags: t.Dict[str, bool]) -> None:
107 | """Save multiple feature flags.
108 |
109 | Args:
110 | flags: The feature flags to save.
111 | """
112 | self._flags.update(
113 | {k: FlagAdapterResponse.from_bool(v) for k, v in flags.items()}
114 | )
115 | self._commit()
116 |
117 |
118 | __all__ = ["FilesystemFeatureFlagAdapter"]
119 |
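120 |
121 | # A minimal usage sketch against the local filesystem (path is illustrative):
122 | #
123 | #     import fsspec
124 | #
125 | #     adapter = FilesystemFeatureFlagAdapter(
126 | #         filesystem=fsspec.filesystem("file"), filename="/tmp/flags.json"
127 | #     )
128 | #     adapter.save("my_pipeline.users", True)
129 | #     assert bool(adapter.get("my_pipeline.users"))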
--------------------------------------------------------------------------------
/src/cdf/core/context.py:
--------------------------------------------------------------------------------
1 | """Context management utilities for managing the active workspace."""
2 |
3 | import contextlib
4 | import functools
5 | import logging
6 | import typing as t
7 | from contextvars import ContextVar, Token
8 |
9 | if t.TYPE_CHECKING:
10 | from cdf.core.injector import Lifecycle
11 | from cdf.core.workspace import Workspace
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | _ACTIVE_WORKSPACE: ContextVar[t.Optional["Workspace"]] = ContextVar(
16 | "active_workspace", default=None
17 | )
18 | """The active workspace for resolving injected dependencies."""
19 |
20 | _DEFAULT_CALLABLE_LIFECYCLE: ContextVar[t.Optional["Lifecycle"]] = ContextVar(
21 | "default_callable_lifecycle", default=None
22 | )
23 | """The default lifecycle for callables when otherwise unspecified."""
24 |
25 |
26 | def get_active_workspace() -> t.Optional["Workspace"]:
27 | """Get the active workspace for resolving injected dependencies."""
28 | return _ACTIVE_WORKSPACE.get()
29 |
30 |
31 | def set_active_workspace(workspace: t.Optional["Workspace"]) -> Token:
32 | """Set the active workspace for resolving injected dependencies."""
33 | return _ACTIVE_WORKSPACE.set(workspace)
34 |
35 |
36 | @contextlib.contextmanager
37 | def use_workspace(workspace: t.Optional["Workspace"]) -> t.Iterator[None]:
38 | """Context manager for temporarily setting the active workspace."""
39 | token = set_active_workspace(workspace)
40 | try:
41 | yield
42 | finally:
43 | set_active_workspace(token.old_value)
44 |
45 |
46 | T = t.TypeVar("T")
47 |
48 |
49 | @t.overload
50 | def resolve(
51 | dependencies: t.Callable[..., T],
52 | configuration: bool = ...,
53 | eagerly_bind_workspace: bool = ...,
54 | ) -> t.Callable[..., T]: ...
55 |
56 |
57 | @t.overload
58 | def resolve(
59 | dependencies: bool = ...,
60 | configuration: bool = ...,
61 | eagerly_bind_workspace: bool = ...,
62 | ) -> t.Callable[[t.Callable[..., T]], t.Callable[..., T]]: ...
63 |
64 |
65 | def resolve(
66 | dependencies: t.Union[t.Callable[..., T], bool] = True,
67 | configuration: bool = True,
68 | eagerly_bind_workspace: bool = False,
69 | ) -> t.Callable[..., t.Union[T, t.Callable[..., T]]]:
70 | """Decorator for injecting dependencies and resolving configuration for a function."""
71 |
72 | if eagerly_bind_workspace:
73 | # Get the active workspace before the function is resolved
74 | workspace = get_active_workspace()
75 | else:
76 | workspace = None
77 |
78 | def _resolve(func: t.Callable[..., T]) -> t.Callable[..., T]:
79 | @functools.wraps(func)
80 | def wrapper(*args: t.Any, **kwargs: t.Any) -> T:
81 |             ws = workspace or get_active_workspace()
82 |             if ws is None:
83 |                 return func(*args, **kwargs)
84 |             fn = func  # bind locally: nonlocal mutation would re-wrap `func` on every call
85 |             if configuration:
86 |                 fn = ws.conf_resolver.resolve_defaults(fn)
87 |             if dependencies:
88 |                 fn = ws.container.wire(fn)
89 |             return fn(*args, **kwargs)
90 |
91 | return wrapper
92 |
93 | if callable(dependencies):
94 | return _resolve(dependencies)
95 |
96 | return _resolve
97 |
98 |
99 | def invoke(func_or_cls: t.Callable, *args: t.Any, **kwargs: t.Any) -> t.Any:
100 | """Invoke a function or class with resolved dependencies."""
101 | workspace = get_active_workspace()
102 | if workspace is None:
103 | logger.debug("Invoking %s without a bound workspace", func_or_cls)
104 | return func_or_cls(*args, **kwargs)
105 | logger.debug("Invoking %s bound to workspace %s", func_or_cls, workspace)
106 | return workspace.invoke(func_or_cls, *args, **kwargs)
107 |
108 |
109 | def get_default_callable_lifecycle() -> "Lifecycle":
110 | """Get the default lifecycle for callables when otherwise unspecified."""
111 | from cdf.core.injector import Lifecycle
112 |
113 | return _DEFAULT_CALLABLE_LIFECYCLE.get() or Lifecycle.SINGLETON
114 |
115 |
116 | def set_default_callable_lifecycle(lifecycle: t.Optional["Lifecycle"]) -> Token:
117 | """Set the default lifecycle for callables when otherwise unspecified."""
118 | if lifecycle and lifecycle.is_instance:
119 | raise ValueError("Default callable lifecycle cannot be set to INSTANCE")
120 | return _DEFAULT_CALLABLE_LIFECYCLE.set(lifecycle)
121 |
122 |
123 | @contextlib.contextmanager
124 | def use_default_callable_lifecycle(
125 | lifecycle: t.Optional["Lifecycle"],
126 | ) -> t.Iterator[None]:
127 | """Context manager for temporarily setting the default callable lifecycle."""
128 | token = set_default_callable_lifecycle(lifecycle)
129 | try:
130 | yield
131 | finally:
132 | set_default_callable_lifecycle(token.old_value)
133 |
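134 |
135 | # A minimal usage sketch of `resolve` with `use_workspace` (assumes
136 | # `my_workspace` provides a `db_url` dependency):
137 | #
138 | #     @resolve
139 | #     def handler(db_url: str) -> None:
140 | #         print(db_url)
141 | #
142 | #     with use_workspace(my_workspace):
143 | #         handler()  # `db_url` is injected from the workspace container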
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import typing as t
3 | from enum import Enum, auto
4 |
5 | if t.TYPE_CHECKING:
6 | from dlt.sources import DltSource
7 |
8 |
9 | class FlagAdapterResponse(Enum):
10 | """Feature flag response.
11 |
12 | This enum is used to represent the state of a feature flag. It is similar
13 | to a boolean but with an extra state for when the flag is not found.
14 | """
15 |
16 | ENABLED = auto()
17 | """The feature flag is enabled."""
18 | DISABLED = auto()
19 | """The feature flag is disabled."""
20 | NOT_FOUND = auto()
21 | """The feature flag is not found."""
22 |
23 | def __bool__(self) -> bool:
24 | """Return True if the flag is enabled and False otherwise."""
25 | return self is FlagAdapterResponse.ENABLED
26 |
27 | to_bool = __bool__
28 |
29 | def __eq__(self, value: object, /) -> bool:
30 | """Compare the flag to a boolean."""
31 | if isinstance(value, bool):
32 |             return self.to_bool() == value
33 | return super().__eq__(value)
34 |
35 | @classmethod
36 | def from_bool(cls, flag: bool) -> "FlagAdapterResponse":
37 | """Convert a boolean to a flag response."""
38 | return cls.ENABLED if flag else cls.DISABLED
39 |
40 |
41 | class AbstractFeatureFlagAdapter(abc.ABC):
42 | """Abstract feature flag adapter."""
43 |
44 | def __init__(self, **kwargs: t.Any) -> None:
45 | """Initialize the adapter."""
46 | pass
47 |
48 | @abc.abstractmethod
49 | def get(self, feature_name: str) -> FlagAdapterResponse:
50 | """Get the feature flag."""
51 | pass
52 |
53 | def __getitem__(self, feature_name: str) -> FlagAdapterResponse:
54 | """Get the feature flag."""
55 | return self.get(feature_name)
56 |
57 | def get_many(self, feature_names: t.List[str]) -> t.Dict[str, FlagAdapterResponse]:
58 | """Get many feature flags.
59 |
60 | Implementations should override this method if they can optimize it. The default
61 | will call get in a loop.
62 | """
63 | return {feature_name: self.get(feature_name) for feature_name in feature_names}
64 |
65 | @abc.abstractmethod
66 | def save(self, feature_name: str, flag: bool) -> None:
67 | """Save the feature flag."""
68 | pass
69 |
70 | def __setitem__(self, feature_name: str, flag: bool) -> None:
71 | """Save the feature flag."""
72 | self.save(feature_name, flag)
73 |
74 | def save_many(self, flags: t.Dict[str, bool]) -> None:
75 | """Save many feature flags.
76 |
77 | Implementations should override this method if they can optimize it. The default
78 | will call save in a loop.
79 | """
80 | for feature_name, flag in flags.items():
81 | self.save(feature_name, flag)
82 |
83 | @abc.abstractmethod
84 | def get_all_feature_names(self) -> t.List[str]:
85 | """Get all feature names."""
86 | pass
87 |
88 | def keys(self) -> t.List[str]:
89 | """Get all feature names."""
90 | return self.get_all_feature_names()
91 |
92 | def __iter__(self) -> t.Iterator[str]:
93 | """Iterate over the feature names."""
94 | return iter(self.get_all_feature_names())
95 |
96 | def __contains__(self, feature_name: str) -> bool:
97 | """Check if a feature flag exists."""
98 | return self.get(feature_name) is not FlagAdapterResponse.NOT_FOUND
99 |
100 | def __len__(self) -> int:
101 | """Get the number of feature flags."""
102 | return len(self.get_all_feature_names())
103 |
104 | def delete(self, feature_name: str) -> None:
105 | """Delete a feature flag.
106 |
107 |         By default this merely disables the flag, but implementations can
108 |         override it to actually delete the flag.
109 | """
110 | self.save(feature_name, False)
111 |
112 | __delitem__ = delete
113 |
114 | def delete_many(self, feature_names: t.List[str]) -> None:
115 | """Delete many feature flags."""
116 | self.save_many({feature_name: False for feature_name in feature_names})
117 |
118 | def apply_source(self, source: "DltSource", *namespace: str) -> "DltSource":
119 | """Apply the feature flags to a dlt source.
120 |
121 | Args:
122 |             source: The source to apply the feature flags to.
123 |             *namespace: Optional key segments prepended to each feature flag name.
124 | 
124 | Returns:
125 | The source with the feature flags applied.
126 | """
127 | new = {}
128 | source_name = source.name
129 | for resource_name, resource in source.selected_resources.items():
130 | k = ".".join(filter(lambda s: s, [*namespace, source_name, resource_name]))
131 | resp = self.get(k)
132 | resource.selected = bool(resp)
133 | if resp is FlagAdapterResponse.NOT_FOUND:
134 | new[k] = False
135 | if new:
136 | self.save_many(new)
137 |
138 | return source
139 |
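140 | 
141 | # A minimal in-memory adapter, sketched to illustrate the abstract interface
142 | # and the dict-like conveniences it provides. Illustrative only; it is not one
143 | # of the shipped adapters.
144 | if __name__ == "__main__":
145 | 
146 |     class _DictAdapter(AbstractFeatureFlagAdapter):
147 |         def __init__(self, **kwargs: t.Any) -> None:
148 |             self._flags: t.Dict[str, bool] = {}
149 | 
150 |         def get(self, feature_name: str) -> FlagAdapterResponse:
151 |             if feature_name not in self._flags:
152 |                 return FlagAdapterResponse.NOT_FOUND
153 |             return FlagAdapterResponse.from_bool(self._flags[feature_name])
154 | 
155 |         def save(self, feature_name: str, flag: bool) -> None:
156 |             self._flags[feature_name] = flag
157 | 
158 |         def get_all_feature_names(self) -> t.List[str]:
159 |             return list(self._flags)
160 | 
161 |     ff = _DictAdapter()
162 |     ff["new_resource"] = True  # __setitem__ -> save
163 |     assert ff["new_resource"] == FlagAdapterResponse.ENABLED
164 |     assert "new_resource" in ff and len(ff) == 1
165 |     del ff["new_resource"]  # the default delete() disables rather than removes
166 |     assert ff["new_resource"] is FlagAdapterResponse.DISABLED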
--------------------------------------------------------------------------------
/tests/core/injector/test_registry.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 | from unittest.mock import MagicMock
3 |
4 | import pytest
5 |
6 | from cdf.core.injector.errors import DependencyCycleError
7 | from cdf.core.injector.registry import (
8 | Dependency,
9 | DependencyRegistry,
10 | Lifecycle,
11 | TypedKey,
12 | )
13 |
14 |
15 | def test_registry():
16 | # A generic test to show some of the API together
17 | container = DependencyRegistry()
18 | container.add("a", lambda: 1)
19 |
20 | def b(a: int) -> int:
21 | return a + 1
22 |
23 | container.add("b", b)
24 | container.add("obj_proto", object, container.lifecycle.PROTOTYPE)
25 | container.add("obj_singleton", object)
26 |
27 | def foo(a: int, b: int, c: int = 0) -> int:
28 | return a + b
29 |
30 | foo_wired = container.wire(foo)
31 |
32 | assert foo_wired() == 3
33 | assert foo_wired(1) == 3
34 | assert foo_wired(2) == 4
35 | assert foo_wired(3, 3) == 6
36 |
37 | assert container.get("obj_proto") is not container.get("obj_proto")
38 | assert container.get("obj_singleton") is container.get("obj_singleton")
39 |
40 | assert container(foo) == 3
41 |
42 | container.add("c", lambda a, b: a + b, container.lifecycle.PROTOTYPE)
43 |
44 | assert container(foo) == 3
45 |
46 | def bar(a: int, b: int, c: t.Optional[int] = None) -> int:
47 | if c is None:
48 | raise ValueError("c is required")
49 | return a + b + c
50 |
51 | assert container(bar) == 6
52 | assert container(bar, c=5) == 8
53 |
54 | @container.wire
55 | def baz(a: int, b: int, c: int = 0) -> int:
56 | return a + b + c
57 |
58 | assert baz() == 3
59 |
60 |
61 | @pytest.fixture
62 | def registry():
63 | return DependencyRegistry()
64 |
65 |
66 | def test_typed_key_creation():
67 | key = TypedKey(name="test", type_=int)
68 | assert key.name == "test"
69 | assert key.type_ is int
70 |
71 |
72 | def test_typed_key_equality():
73 | key1 = TypedKey("name1", int)
74 | key2 = TypedKey("name1", int)
75 | assert key1 == key2
76 |
77 |
78 | def test_typed_key_inequality():
79 | key1 = TypedKey("name1", int)
80 | key2 = TypedKey("name2", str)
81 | assert key1 != key2
82 |
83 |
84 | def test_typed_key_string_representation():
85 | key1 = TypedKey("name1", int)
86 | assert str(key1) == "name1: int"
87 |
88 |
89 | def test_instance_dependency():
90 | instance = Dependency.instance(42)
91 | assert instance() == 42
92 | assert instance.lifecycle.is_instance
93 |
94 |
95 | def test_singleton_dependency():
96 | factory = MagicMock(return_value=42)
97 | singleton_dep = Dependency.singleton(factory)
98 | assert singleton_dep.lifecycle == Lifecycle.SINGLETON
99 |
100 |
101 | def test_prototype_dependency():
102 | factory = MagicMock(return_value=42)
103 | prototype_dep = Dependency.prototype(factory)
104 | assert prototype_dep.lifecycle == Lifecycle.PROTOTYPE
105 |
106 |
107 | def test_apply_function_to_instance():
108 | dep = Dependency.instance(42)
109 | new_dep = dep.map_value(lambda x: x + 1)
110 | assert new_dep() == 43
111 |
112 |
113 | def test_apply_wrappers_to_factory():
114 | factory = MagicMock(return_value=42)
115 | dep = Dependency.singleton(factory)
116 | wrapper = MagicMock(side_effect=lambda f: lambda: f() + 1)
117 | dep_wrapped = dep.map(wrapper)
118 | assert dep_wrapped.unwrap() == 43
119 |
120 |
121 | def test_add_and_get_singleton(registry: DependencyRegistry):
122 | registry.add_singleton("test", object)
123 | retrieved1 = registry.get("test")
124 | retrieved2 = registry.get("test")
125 | assert retrieved1 is retrieved2 # Ensure same instance each time
126 |
127 |
128 | def test_add_and_get_prototype(registry: DependencyRegistry):
129 | registry.add_prototype("test", object)
130 | retrieved1 = registry.get("test")
131 | retrieved2 = registry.get("test")
132 | assert retrieved1 is not retrieved2 # Ensure new instance each time
133 |
134 |
135 | def test_add_and_get_instance(registry: DependencyRegistry):
136 | registry.add_instance("test", 42)
137 | retrieved = registry.get("test")
138 | assert retrieved == 42
139 |
140 |
141 | def test_wire_function(registry: DependencyRegistry):
142 | factory = MagicMock(return_value=42)
143 | registry.add_singleton("test", factory)
144 |
145 | @registry.wire
146 | def func(test):
147 | return test
148 |
149 | assert func() == 42
150 |
151 |
152 | def test_dependency_cycle(registry: DependencyRegistry):
153 | registry.add("left", lambda right: right)
154 | registry.add("right", lambda left: left)
155 | with pytest.raises(DependencyCycleError):
156 | registry["left"]
157 |
158 |
159 | def test_contains(registry: DependencyRegistry):
160 | factory = MagicMock(return_value=42)
161 | registry.add_singleton("test", factory)
162 | assert "test" in registry
163 |
164 |
165 | def test_remove_dependency(registry: DependencyRegistry):
166 | factory = MagicMock(return_value=42)
167 | registry.add_singleton("test", factory)
168 | assert "test" in registry
169 | registry.remove("test")
170 | assert "test" not in registry
171 |
--------------------------------------------------------------------------------
/src/cdf/builtin/metrics.py:
--------------------------------------------------------------------------------
1 | """Built-in metrics for CDF
2 |
3 | They can be referenced via absolute import paths in a pipeline spec.
4 | """
5 |
6 | import bisect
7 | import decimal
8 | import math
9 | import statistics
10 | import typing as t
11 | from collections import defaultdict
12 |
13 | TNumber = t.TypeVar("TNumber", int, float, decimal.Decimal)
14 |
15 | MetricFunc = t.Callable[[t.Any, TNumber], TNumber]
16 |
17 |
18 | def count(_: t.Any, metric: TNumber = 0) -> TNumber:
19 | """Counts the number of items in a dataset"""
20 | return metric + 1
21 |
22 |
23 | def unique(key: str) -> MetricFunc:
24 | """Counts the number of unique items in a dataset by a given key"""
25 | seen = set()
26 |
27 | def _unique(item: t.Any, _: t.Optional[TNumber] = None) -> int:
28 | k = item.get(key)
29 | if k is not None and k not in seen:
30 | seen.add(k)
31 | return len(seen)
32 |
33 | return _unique
34 |
35 |
36 | def max_value(key: str) -> MetricFunc:
37 | """Returns the maximum value of a key in a dataset"""
38 | first = True
39 |
40 | def _max_value(item: t.Any, metric: t.Optional[TNumber] = None) -> TNumber:
41 | nonlocal first
42 | k = item.get(key)
43 | if metric is None or first:
44 | first = False
45 | return k
46 | if k is None:
47 | return metric
48 | return max(metric, k)
49 |
50 | return _max_value
51 |
52 |
53 | def min_value(key: str) -> MetricFunc:
54 | """Returns the minimum value of a key in a dataset"""
55 | first = True
56 |
57 | def _min_value(item: t.Any, metric: t.Optional[TNumber] = None) -> TNumber:
58 | nonlocal first
59 | k = item.get(key)
60 | if metric is None or first:
61 | first = False
62 | return k
63 | if k is None:
64 | return metric
65 | return min(metric, k)
66 |
67 | return _min_value
68 |
69 |
70 | def sum_value(key: str) -> MetricFunc:
71 | """Returns the sum of a key in a dataset"""
72 |
73 | def _sum_value(item: t.Any, metric: TNumber = 0) -> TNumber:
74 | k = item.get(key, 0)
75 | return metric + k
76 |
77 | return _sum_value
78 |
79 |
80 | def avg_value(key: str) -> MetricFunc:
81 | """Returns the average of a key in a dataset"""
82 | n_sum, n_count = 0, 0
83 |
84 | def _avg_value(
85 | item: t.Any, last_value: t.Optional[TNumber] = None
86 | ) -> t.Optional[TNumber]:
87 | nonlocal n_sum, n_count
88 | k = item.get(key)
89 | if k is None:
90 | return last_value
91 | n_sum += k
92 | n_count += 1
93 | return n_sum / n_count
94 |
95 | return _avg_value
96 |
97 |
98 | def median_value(key: str, window: int = 1000) -> MetricFunc:
99 | """Returns the median of a key in a dataset"""
100 | arr = []
101 |
102 | def _median_value(
103 | item: t.Any, last_value: t.Optional[TNumber] = None
104 | ) -> t.Optional[TNumber]:
105 | nonlocal arr
106 | k = item.get(key)
107 | if k is None:
108 | return last_value
109 | bisect.insort(arr, k)
110 | if len(arr) > window:
111 | del arr[0], arr[-1]
112 | return statistics.median(arr)
113 |
114 | return _median_value
115 |
116 |
117 | def stdev_value(key: str) -> MetricFunc:
118 | """Returns the standard deviation of a key in a dataset"""
119 | n_sum, n_squared_sum, n_count = 0, 0, 0
120 |
121 | def _stdev_value(
122 | item: t.Any, last_value: t.Optional[TNumber] = None
123 | ) -> t.Optional[float]:
124 | nonlocal n_sum, n_squared_sum, n_count
125 | k = item.get(key)
126 | if k is None:
127 | return t.cast(t.Optional[float], last_value)
128 | n_sum += k
129 | n_squared_sum += k**2
130 | n_count += 1
131 | mean = n_sum / n_count
132 | return math.sqrt(n_squared_sum / n_count - mean**2)
133 |
134 | return _stdev_value
135 |
136 |
137 | def variance_value(key: str) -> MetricFunc:
138 | """Returns the variance of a key in a dataset"""
139 | n_sum, n_squared_sum, n_count = 0, 0, 0
140 |
141 | def _variance_value(
142 | item: t.Any, last_value: t.Optional[TNumber] = None
143 | ) -> t.Optional[float]:
144 | nonlocal n_sum, n_squared_sum, n_count
145 | k = item.get(key)
146 | if k is None:
147 | return t.cast(t.Optional[float], last_value)
148 | n_sum += k
149 | n_squared_sum += k**2
150 | n_count += 1
151 | if n_count == 1:
152 | return 0
153 | mean = n_sum / n_count
154 | return (n_squared_sum / n_count) - mean**2
155 |
156 | return _variance_value
157 |
158 |
159 | def mode_value(key: str) -> MetricFunc:
160 | """Returns the mode of a key in a dataset."""
161 | frequency = defaultdict(int)
162 |
163 | def _mode_value(item: t.Any, last_value: t.Optional[t.Any] = None) -> t.Any:
164 | nonlocal frequency
165 | k = item.get(key)
166 | if k is None:
167 | return last_value
168 | frequency[k] += 1
169 | return max(frequency.items(), key=lambda x: x[1])[0]
170 |
171 | return _mode_value
172 |
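173 | 
174 | # Usage sketch: each metric function is a fold over a stream of items, taking
175 | # (item, previous_value) and returning the updated value. The rows below are
176 | # illustrative sample data.
177 | if __name__ == "__main__":
178 |     rows = [{"amount": 10}, {"amount": 20}, {"amount": 20}]
179 | 
180 |     total = 0
181 |     running_sum = sum_value("amount")
182 |     for row in rows:
183 |         total = running_sum(row, total)
184 |     assert total == 50
185 | 
186 |     distinct = 0
187 |     running_unique = unique("amount")
188 |     for row in rows:
189 |         distinct = running_unique(row, distinct)
190 |     assert distinct == 2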
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CDF (Continuous Data Framework)
2 | 
3 | > Craft end-to-end data pipelines and manage them continuously
4 | 
18 | ---
19 |
20 | ## 📖 Table of Contents
21 | - [📖 Table of Contents](#-table-of-contents)
22 | - [📍 Overview](#-overview)
23 | - [📦 Features](#-features)
24 | - [🚀 Getting Started](#-getting-started)
25 | - [📚 Documentation](#-documentation)
26 | - [🧪 Tests](#-tests)
27 | - [🛣 Roadmap](#-roadmap)
28 | - [🤝 Contributing](#-contributing)
29 | - [📄 License](#-license)
30 | - [👏 Acknowledgments](#-acknowledgments)
30 |
31 | ---
32 |
33 | ## 📍 Overview
34 |
35 | CDF (Continuous Data Framework) is an integrated framework designed to manage data across the entire lifecycle, from ingestion through transformation to publishing. It is built on top of two open-source projects, `sqlmesh` and `dlt`, providing a unified interface for complex data operations. CDF simplifies data engineering workflows, offering scalable solutions from small to large projects through an opinionated project structure that supports both multi-workspace and single-workspace layouts.
36 |
37 | > [!WARNING]
38 | > The repo is under ACTIVE development, with multiple large refactors already completed. The codebase is not yet stable and is subject to change; until this disclaimer is removed, treat the code (and tests) as the most accurate, up-to-date reference.
39 |
40 | ## 📦 Features
41 |
42 | ...
43 |
44 | ## 🚀 Getting Started
45 |
46 | 1. **Installation**:
47 |
48 | CDF requires Python 3.9 or newer. Install CDF using pip:
49 |
50 | ```bash
51 | pip install python-cdf
52 | ```
53 |
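54 | 2. **Quickstart**:
55 | 
56 | The sketch below is illustrative rather than a stable API reference; the import path and the injected parameter name are assumptions based on the current source layout, so consult the code and tests for the authoritative interface.
57 | 
58 | ```python
59 | from cdf.core.workspace import Workspace
60 | 
61 | workspace = Workspace(name="example")  # merges cdf.toml/cdf.yaml/cdf.json when present
62 | 
63 | def greet(cdf_environment: str) -> str:
64 |     # `cdf_environment` is provided by the workspace's dependency container
65 |     return f"hello from the {cdf_environment} environment"
66 | 
67 | print(workspace.invoke(greet))  # -> "hello from the dev environment" by default
68 | ```
69 | 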
54 | ## 📚 Documentation
55 |
56 | For detailed documentation, including API references and tutorials, visit [CDF Documentation](#).
57 |
69 |
70 | ## 🧪 Tests
71 |
72 | Run the tests with `pytest`:
73 |
74 | ```sh
75 | pytest tests
76 | ```
77 |
78 | ## 🛣 Roadmap
79 |
80 | TODO: Add a roadmap for the project.
81 |
82 |
83 | ## 🤝 Contributing
84 |
85 | Contributions are welcome! Here are several ways you can contribute:
86 |
87 | - **[Submit Pull Requests](https://github.com/z3z1ma/cdf/blob/main/CONTRIBUTING.md)**: Review open PRs, and submit your own PRs.
88 | - **[Join the Discussions](https://github.com/z3z1ma/cdf/discussions)**: Share your insights, provide feedback, or ask questions.
89 | - **[Report Issues](https://github.com/z3z1ma/cdf/issues)**: Submit bug reports or log feature requests for the `cdf` project.
90 |
91 |
92 | #### *Contributing Guidelines*
93 |
97 | 1. **Fork the Repository**: Start by forking the project repository to your GitHub account.
98 | 2. **Clone Locally**: Clone the forked repository to your local machine using a Git client.
99 | ```sh
100 | git clone
101 | ```
102 | 3. **Create a New Branch**: Always work on a new branch, giving it a descriptive name.
103 | ```sh
104 | git checkout -b new-feature-x
105 | ```
106 | 4. **Make Your Changes**: Develop and test your changes locally.
107 | 5. **Commit Your Changes**: Commit with a clear and concise message describing your updates.
108 | ```sh
109 | git commit -m 'Implemented new feature x.'
110 | ```
111 | 6. **Push to GitHub**: Push the changes to your forked repository.
112 | ```sh
113 | git push origin new-feature-x
114 | ```
115 | 7. **Submit a Pull Request**: Create a PR against the original project repository. Clearly describe the changes and their motivations.
116 |
117 | Once your PR is reviewed and approved, it will be merged into the main branch.
118 |
119 |
120 |
121 | ---
122 |
123 | ## 📄 License
124 |
125 |
126 | This project is distributed under the [Apache 2.0](http://www.apache.org/licenses/LICENSE-2.0) License. For more details, refer to the [LICENSE](https://github.com/z3z1ma/cdf/blob/main/LICENSE) file.
127 |
128 | ---
129 |
130 | ## 👏 Acknowledgments
131 |
132 | - Harness (https://harness.io/) for being the proving grounds in which the initial concept of this project was born.
133 | - SQLMesh (https://sqlmesh.com) for being a foundational pillar of this project as well as the team for their support,
134 | advice, and guidance.
135 | - DLT (https://dlthub.com) for being the other foundational pillar of this project as well as the team for their
136 | support, advice, and guidance.
137 |
--------------------------------------------------------------------------------
/src/cdf/integrations/feature_flag/harness.py:
--------------------------------------------------------------------------------
1 | """Harness feature flag provider."""
2 |
3 | from __future__ import annotations
4 |
5 | import logging
6 | import os
7 | import typing as t
8 | from concurrent.futures import ThreadPoolExecutor
9 |
10 | import dlt
11 | from dlt.common.configuration import with_config
12 | from dlt.sources import DltSource
13 | from dlt.sources.helpers import requests
14 | from featureflags.client import CfClient, Config, Target
15 | from featureflags.evaluations.feature import FeatureConfigKind
16 | from featureflags.interface import Cache
17 | from featureflags.util import log as _ff_logger
18 |
19 | from cdf.integrations.feature_flag.base import (
20 | AbstractFeatureFlagAdapter,
21 | FlagAdapterResponse,
22 | )
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | # This exists because the default harness LRU implementation does not store >1000 flags
28 | # The interface is mostly satisfied by dict, so we subclass it and implement the missing methods
29 | class _HarnessCache(dict, Cache):
30 | """A cache implementation for the harness feature flag provider."""
31 |
32 | def set(self, key: str, value: bool) -> None:
33 | self[key] = value
34 |
35 | def remove(self, key: str | t.List[str]) -> None:
36 |         if isinstance(key, str):
37 |             self.pop(key, None)
38 |             return  # a bare str would otherwise be iterated character by character
39 |         for k in key:
40 |             self.pop(k, None)
40 |
41 |
42 | def _quiet_logger():
43 | """Configure the harness FF logger to only show errors. Its too verbose otherwise."""
44 | _ff_logger.setLevel(logging.ERROR)
45 |
46 |
47 | class HarnessFeatureFlagAdapter(AbstractFeatureFlagAdapter):
48 | _TARGET = Target("cdf")
49 |
50 | @with_config(sections=("feature_flags",))
51 | def __init__(
52 | self,
53 | sdk_key: str = dlt.secrets.value,
54 | api_key: str = dlt.secrets.value,
55 | account: str = dlt.secrets.value,
56 | organization: str = dlt.secrets.value,
57 | project: str = dlt.secrets.value,
58 | **kwargs: t.Any,
59 | ) -> None:
60 | """Initialize the adapter."""
61 | self.sdk_key = sdk_key
62 | self.api_key = api_key
63 | self.account = account
64 | self.organization = organization
65 | self.project = project
66 | self._pool = None
67 | self._client = None
68 | _quiet_logger()
69 |
70 | @property
71 | def client(self) -> CfClient:
72 | """Get the client and cache it in the instance."""
73 | if self._client is None:
74 | client = CfClient(
75 | sdk_key=str(self.sdk_key),
76 | config=Config(
77 | enable_stream=False, enable_analytics=False, cache=_HarnessCache()
78 | ),
79 | )
80 | client.wait_for_initialization()
81 | self._client = client
82 | return self._client
83 |
84 | @property
85 | def pool(self) -> ThreadPoolExecutor:
86 | """Get the thread pool."""
87 | if self._pool is None:
88 | self._pool = ThreadPoolExecutor(thread_name_prefix="cdf-ff-")
89 | return self._pool
90 |
91 | def get(self, feature_name: str) -> FlagAdapterResponse:
92 | """Get a feature flag."""
93 | if feature_name not in self.get_all_feature_names():
94 | return FlagAdapterResponse.NOT_FOUND
95 | return FlagAdapterResponse.from_bool(
96 | self.client.bool_variation(feature_name, self._TARGET, False)
97 | )
98 |
99 | def get_all_feature_names(self) -> t.List[str]:
100 | """Get all the feature flags."""
101 | return list(
102 | map(lambda f: f.split("/", 1)[1], self.client._repository.cache.keys())
103 | )
104 |
105 | def _toggle(self, feature_name: str, flag: bool) -> None:
106 | """Toggle a feature flag."""
107 | if flag is self.get(feature_name).to_bool():
108 | return
109 | logger.info(f"Toggling feature flag {feature_name} to {flag}")
110 | requests.patch(
111 | f"https://app.harness.io/cf/admin/features/{feature_name}",
112 | headers={"x-api-key": self.api_key},
113 | params={
114 | "accountIdentifier": self.account,
115 | "orgIdentifier": self.organization,
116 | "projectIdentifier": self.project,
117 | },
118 | json={
119 | "instructions": [
120 | {
121 | "kind": "setFeatureFlagState",
122 | "parameters": {"state": "on" if flag else "off"},
123 | }
124 | ]
125 | },
126 | )
127 |
128 | def save(self, feature_name: str, flag: bool) -> None:
129 | """Create a feature flag."""
130 | if self.get(feature_name) is FlagAdapterResponse.NOT_FOUND:
131 | logger.info(f"Creating feature flag {feature_name}")
132 | try:
133 | requests.post(
134 | "https://app.harness.io/cf/admin/features",
135 | params={
136 | "accountIdentifier": self.account,
137 | "orgIdentifier": self.organization,
138 | },
139 | headers={
140 | "Content-Type": "application/json",
141 | "x-api-key": self.api_key,
142 | },
143 | json={
144 | "defaultOnVariation": "on-variation",
145 | "defaultOffVariation": "off-variation",
146 | "description": "Managed by CDF",
147 | "identifier": feature_name,
148 | "name": feature_name.upper(),
149 | "kind": FeatureConfigKind.BOOLEAN.value,
150 | "permanent": True,
151 | "project": self.project,
152 | "variations": [
153 | {"identifier": "on-variation", "value": "true"},
154 | {"identifier": "off-variation", "value": "false"},
155 | ],
156 | },
157 | )
158 | except Exception:
159 | logger.exception(f"Failed to create feature flag {feature_name}")
160 | self._toggle(feature_name, flag)
161 |
162 | def save_many(self, flags: t.Dict[str, bool]) -> None:
163 | """Create many feature flags."""
164 | list(self.pool.map(lambda f: self.save(*f), flags.items()))
165 |
166 | def delete(self, feature_name: str) -> None:
167 | """Drop a feature flag."""
168 | logger.info(f"Deleting feature flag {feature_name}")
169 | requests.delete(
170 | f"https://app.harness.io/cf/admin/features/{feature_name}",
171 | headers={"x-api-key": self.api_key},
172 | params={
173 | "accountIdentifier": self.account,
174 | "orgIdentifier": self.organization,
175 | "projectIdentifier": self.project,
176 | "forceDelete": True,
177 | },
178 | )
179 |
180 | def delete_many(self, feature_names: t.List[str]) -> None:
181 | """Drop many feature flags."""
182 | list(self.pool.map(self.delete, feature_names))
183 |
184 | def apply_source(self, source: DltSource, *namespace: str) -> DltSource:
185 | """Apply the feature flags to a dlt source."""
186 | # NOTE: we use just the last section due to legacy design decisions
187 | # We will remove this when the Harness team cleans up the feature flag namespace
188 | ns = f"pipeline__{namespace[-1]}__{source.name}"
189 |
190 | # A closure to produce a resource id
191 | def _get_resource_id(resource: str) -> str:
192 | return f"{ns}__{resource}"
193 |
194 | resource_lookup = {
195 | _get_resource_id(key): resource
196 | for key, resource in source.resources.items()
197 | }
198 | every_resource = resource_lookup.keys()
199 | selected_resources = set(
200 | map(_get_resource_id, source.selected_resources.keys())
201 | )
202 |
203 | current_flags = set(
204 | filter(lambda f: f.startswith(ns), self.get_all_feature_names())
205 | )
206 |
207 | removed = current_flags.difference(every_resource)
208 | added = selected_resources.difference(current_flags)
209 |
210 | # TODO: reconciliation will be promoted to a top level context function
211 | if os.getenv("HARNESS_FF_AUTORECONCILE", "0") == "1":
212 | self.delete_many(list(removed))
213 |
214 | self.save_many({f: False for f in added})
215 | for f in added:
216 | resource_lookup[f].selected = False
217 | for f in current_flags.intersection(selected_resources):
218 | resource_lookup[f].selected = self.get(f).to_bool()
219 |
220 | return source
221 |
222 | def __del__(self) -> None:
223 | """Close the client."""
224 | if self._client is not None:
225 | self._client.close()
226 | if self._pool is not None:
227 | self._pool.shutdown()
228 |
229 |
230 | __all__ = ["HarnessFeatureFlagAdapter"]
231 |
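232 | 
233 | # Usage sketch: credentials are normally resolved by dlt's `with_config` from
234 | # the `feature_flags` config section; the literals and flag key below are
235 | # placeholders, not working values.
236 | if __name__ == "__main__":
237 |     adapter = HarnessFeatureFlagAdapter(
238 |         sdk_key="<sdk-key>",
239 |         api_key="<api-key>",
240 |         account="<account>",
241 |         organization="<org>",
242 |         project="<project>",
243 |     )
244 |     # FlagAdapterResponse is truthy only when the flag is ENABLED
245 |     if adapter.get("pipeline__demo__my_source__my_resource"):
246 |         print("resource is enabled")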
--------------------------------------------------------------------------------
/src/cdf/core/component/base.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import inspect
3 | import typing as t
4 | from contextlib import suppress
5 | from dataclasses import field
6 | from enum import Enum
7 |
8 | import pydantic
9 |
10 | import cdf.core.context as ctx
11 | import cdf.core.injector as injector
12 |
13 | if t.TYPE_CHECKING:
14 | from cdf.core.workspace import Workspace
15 |
16 | T = t.TypeVar("T")
17 |
18 | __all__ = [
19 | "Component",
20 | "Entrypoint",
21 | "ServiceLevelAgreement",
22 | ]
23 |
24 |
25 | class ServiceLevelAgreement(Enum):
26 | """An SLA to assign to a component. Users can define the meaning of each level."""
27 |
28 | DEPRECATING = -1
29 | NONE = 0
30 | LOW = 1
31 | MEDIUM = 2
32 | HIGH = 3
33 | CRITICAL = 4
34 |
35 |
36 | class _Node(pydantic.BaseModel, frozen=True):
37 | """A node in a graph of components."""
38 |
39 | owner: t.Optional[str] = None
40 | """The owner of the node. Useful for tracking who to contact for issues or config."""
41 | description: str = "No description provided"
42 | """A description of the node."""
43 | sla: ServiceLevelAgreement = ServiceLevelAgreement.MEDIUM
44 | """The SLA for the node."""
45 | enabled: bool = True
46 | """Whether the node is enabled or disabled. Disabled components are not loaded."""
47 | version: str = "0.1.0"
48 | """A semantic version for the node. Can signal breaking changes to dependents."""
49 | tags: t.List[str] = field(default_factory=list)
50 | """Tags to categorize the node."""
51 | metadata: t.Dict[str, t.Any] = field(default_factory=dict)
52 | """Additional metadata for the node. Useful for custom integrations."""
53 |
54 | @pydantic.field_validator("sla", mode="before")
55 | @classmethod
56 | def _validate_sla(cls, value: t.Any) -> t.Any:
57 | if isinstance(value, str):
58 | value = ServiceLevelAgreement[value.upper()]
59 | return value
60 |
61 | @pydantic.field_validator("tags", mode="before")
62 | @classmethod
63 | def _validate_tags(cls, value: t.Any) -> t.Any:
64 | if isinstance(value, str):
65 | value = value.split(",")
66 | return value
67 |
68 |
69 | def _parse_metadata_from_callable(func: t.Callable) -> t.Dict[str, t.Any]:
70 | """Parse _Node metadata from a function or class allowing looser coupling of configuration.
71 |
72 | The function or class docstring is used as the description if available. The rest
73 | of the metadata is inferred from the function or class attributes. The attributes
74 |     may be in uppercase global form (e.g. VERSION) or dunder form (e.g. __version__).
75 |
76 | We look for the following attributes in the function or class:
77 | - name: The name of the component
78 |
79 | The following attributes in the function or class with fallback to the module:
80 | - version: The version of the component
81 | - enabled: Whether the component is enabled
82 | - sla: The SLA of the component
83 | - owner: The owner of the component
84 |
85 | And the following are merged from both the function and module:
86 | - tags: Tags for the component
87 | - metadata: Additional metadata for the component
88 | """
89 | if not callable(func):
90 | return {}
91 |
92 | mod = inspect.getmodule(func)
93 |
94 | def _lookup_attributes(
95 | *attrs: str, callback: t.Optional[t.Callable[[t.Any], t.Any]] = None
96 | ) -> t.Optional[t.Any]:
97 | # Look for the attribute in the function and module
98 | for attr in attrs:
99 | with suppress(AttributeError):
100 | v = getattr(func, attr)
101 | if callback:
102 | callback(v)
103 | else:
104 | return v
105 | if mod is not None:
106 | with suppress(AttributeError):
107 | v = getattr(mod, attr)
108 | if callback:
109 | callback(v)
110 | else:
111 | return v
112 |
113 | parsed: t.Dict[str, t.Any] = {
114 | "description": inspect.getdoc(func) or "No description provided"
115 | }
116 | for k in ("name", "version", "enabled", "sla", "owner"):
117 | if (v := _lookup_attributes(k.upper(), f"__{k}__")) is not None:
118 | parsed[k] = v
119 |
120 | _lookup_attributes(
121 | "TAGS", "__tags__", callback=parsed.setdefault("tags", []).extend
122 | )
123 | _lookup_attributes(
124 | "METADATA", "__metadata__", callback=parsed.setdefault("metadata", {}).update
125 | )
126 |
127 | return parsed
128 |
129 |
130 | def _bind_active_workspace(func: t.Any) -> t.Any:
131 | """Bind the active workspace to a function or class.
132 |
133 | Args:
134 | func: The function or class to bind the workspace to.
135 |
136 | Returns:
137 | The bound function or class.
138 | """
139 | if callable(func):
140 | return ctx.resolve(eagerly_bind_workspace=True)(func)
141 | return func
142 |
143 |
144 | def _get_bind_func(info: pydantic.ValidationInfo) -> t.Callable:
145 | """Get the bind function from the pydantic context or use the active workspace.
146 |
147 | Args:
148 | info: The pydantic validation info.
149 |
150 | Returns:
151 | The bind function to use for the component.
152 | """
153 | context = info.context
154 | if context:
155 | bind = t.cast("Workspace", context["parent"]).bind
156 | else:
157 | bind = _bind_active_workspace
158 | return bind
159 |
160 |
161 | def _unwrap_entrypoint(value: t.Any) -> t.Any:
162 | """Import an entrypoint if it is a string.
163 |
164 | Args:
165 | value: The value to import.
166 |
167 | Returns:
168 | The imported value if it is a string, otherwise the original value.
169 | """
170 | if isinstance(value, str):
171 | mod, func = value.split(":", 1)
172 | mod = importlib.import_module(mod)
173 | value = getattr(mod, func)
174 | return value
175 |
176 |
177 | class Component(_Node, t.Generic[T], frozen=True):
178 | """A component with a binding to a dependency."""
179 |
180 | main: injector.Dependency[T]
181 | """The dependency for the component. This is what is injected into the workspace."""
182 |
183 | name: t.Annotated[str, pydantic.Field(..., pattern=r"^[a-zA-Z_][a-zA-Z0-9_]*$")]
184 | """The key to register the component in the container.
185 |
186 | Must be a valid Python identifier. Users can use these names as function parameters
187 | for implicit dependency injection. Names must be unique within the workspace.
188 | """
189 |
190 | def __call__(self) -> T:
191 | """Unwrap the main dependency invoking the underlying callable."""
192 | return self.main.unwrap()
193 |
194 | @pydantic.model_validator(mode="before")
195 | @classmethod
196 | def _parse_main(cls, data: t.Any) -> t.Any:
197 | """Parse function metadata into node defaults."""
198 | if inspect.isfunction(data) or isinstance(data, injector.Dependency):
199 | data = {"main": data}
200 | if isinstance(data, dict):
201 | dep = data["main"]
202 | if isinstance(dep, dict):
203 | func = dep["factory"]
204 | if dep.get("alias", None):
205 | data.setdefault("name", dep["alias"])
206 | elif isinstance(dep, injector.Dependency):
207 | func = dep.factory
208 | if dep.alias:
209 | data.setdefault("name", dep.alias)
210 | else:
211 | func = dep
212 | return {**_parse_metadata_from_callable(func), **data}
213 | return data
214 |
215 | @pydantic.field_validator("main", mode="before")
216 | @classmethod
217 | def _ensure_dependency(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any:
218 | """Ensure the main function is a dependency."""
219 | value = _unwrap_entrypoint(value)
220 | if isinstance(value, (dict, injector.Dependency)):
221 | parsed_dep = injector.Dependency.model_validate(value, context=info.context)
222 | else:
223 | parsed_dep = injector.Dependency.wrap(value)
224 |         # NOTE: We do this extra round-trip to bypass the unnecessary Generic type check in pydantic
225 | return parsed_dep.model_dump()
226 |
227 | @pydantic.model_validator(mode="after")
228 | def _bind_main(self, info: pydantic.ValidationInfo) -> t.Any:
229 | """Bind the active workspace to the main function."""
230 | self.main.map(_get_bind_func(info), idempotent=True)
231 | return self
232 |
233 | def __str__(self):
234 |         return f"<Component {self.name}>"
235 |
236 |
237 | class Entrypoint(_Node, t.Generic[T], frozen=True):
238 | """An entrypoint representing an invokeable set of functions."""
239 |
240 | main: t.Callable[..., T]
241 | """The main function associated with the entrypoint."""
242 |
243 | name: str
244 | """The name of the entrypoint.
245 |
246 | This is used to register the entrypoint in the workspace and CLI. Names must be
247 | unique within the workspace. The name can contain spaces and special characters.
248 | """
249 |
250 | @pydantic.model_validator(mode="before")
251 | @classmethod
252 | def _parse_main(cls, data: t.Any) -> t.Any:
253 | """Parse function metadata into node defaults."""
254 | if inspect.isfunction(data):
255 | data = {"main": data}
256 | if isinstance(data, dict):
257 | func = _unwrap_entrypoint(data["main"])
258 | return {**_parse_metadata_from_callable(func), **data}
259 | return data
260 |
261 | @pydantic.field_validator("main", mode="before")
262 | @classmethod
263 | def _bind_main(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any:
264 | """Bind the active workspace to the main function."""
265 | return _get_bind_func(info)(_unwrap_entrypoint(value))
266 |
267 | def __str__(self):
268 |         return f"<Entrypoint {self.name}>"
269 |
270 | def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
271 | """Invoke the entrypoint."""
272 | return self.main(*args, **kwargs)
273 |
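274 | 
275 | # Metadata parsing sketch: uppercase or dunder attributes on a function (or its
276 | # module) seed the node defaults. Illustrative only; it assumes the model
277 | # validators accept a bare function as shown.
278 | if __name__ == "__main__":
279 | 
280 |     def build_client() -> object:
281 |         """Constructs the demo client."""
282 |         return object()
283 | 
284 |     build_client.__version__ = "1.2.3"  # dunder form
285 |     build_client.SLA = "high"  # global form, coerced by the field validator
286 | 
287 |     component = Component.model_validate(build_client)
288 |     assert component.name == "build_client"  # falls back to __name__
289 |     assert component.version == "1.2.3"
290 |     assert component.sla is ServiceLevelAgreement.HIGH
291 |     assert component.description == "Constructs the demo client."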
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/cdf/core/workspace.py:
--------------------------------------------------------------------------------
1 | """A workspace is a container for components and configurations."""
2 |
3 | import os
4 | import time
5 | import typing as t
6 | from functools import cached_property, partialmethod
7 | from pathlib import Path
8 |
9 | import pydantic
10 | from typing_extensions import ParamSpec, Self
11 |
12 | import cdf.core.component as cmp
13 | import cdf.core.configuration as conf
14 | import cdf.core.context as ctx
15 | import cdf.core.injector as injector
16 |
17 | if t.TYPE_CHECKING:
18 | import click
19 | import sqlmesh
20 |
21 |
22 | T = t.TypeVar("T")
23 | P = ParamSpec("P")
24 |
25 | __all__ = ["Workspace"]
26 |
27 |
28 | class Workspace(pydantic.BaseModel, frozen=True):
29 | """A CDF workspace that allows for dependency injection and configuration resolution."""
30 |
31 | name: str = "default"
32 | """A human-readable name for the workspace."""
33 | version: str = "0.1.0"
34 | """A semver version string for the workspace."""
35 | environment: str = pydantic.Field(
36 | default_factory=lambda: os.getenv("CDF_ENVIRONMENT", "dev")
37 | )
38 | """The runtime environment used to resolve configuration."""
39 | conf_resolver: conf.ConfigResolver = pydantic.Field(
40 | default_factory=conf.ConfigResolver
41 | )
42 | """The configuration resolver for the workspace."""
43 | container: injector.DependencyRegistry = pydantic.Field(
44 | default_factory=injector.DependencyRegistry
45 | )
46 | """The dependency injection container for the workspace."""
47 | configuration_sources: t.Iterable[conf.ConfigSource] = (
48 | "cdf.toml",
49 | "cdf.yaml",
50 | "cdf.json",
51 | "~/.cdf.toml",
52 | )
53 | """A list of configuration sources resolved and merged by the workspace."""
54 | services_: t.Iterable[cmp.ServiceDef] = pydantic.Field(
55 | default_factory=tuple, alias="services"
56 | )
57 | """An iterable of raw service definitions that the workspace provides."""
58 | pipelines_: t.Iterable[cmp.DataPipelineDef] = pydantic.Field(
59 | default_factory=tuple, alias="pipelines"
60 | )
61 | """An iterable of raw pipeline definitions that the workspace provides."""
62 | publishers_: t.Iterable[cmp.DataPublisherDef] = pydantic.Field(
63 | default_factory=tuple, alias="publishers"
64 | )
65 | """An iterable of raw publisher definitions that the workspace provides."""
66 | operations_: t.Iterable[cmp.OperationDef] = pydantic.Field(
67 | default_factory=tuple, alias="operations"
68 | )
69 | """An iterable of raw generic operation definitions that the workspace provides."""
70 |
71 | # TODO: define an adapter for transformation providers
72 | sqlmesh_path: t.Optional[t.Union[str, Path]] = None
73 | """The path to the sqlmesh root for the workspace."""
74 | sqlmesh_context_kwargs: t.Dict[str, t.Any] = {}
75 | """Keyword arguments to pass to the sqlmesh context."""
76 |
77 | if t.TYPE_CHECKING:
78 | # PERF: this is a workaround for pydantic not being able to build a model with a forwardref
79 | # we still want the deferred import at runtime
80 | sqlmesh_context_class: t.Optional[t.Type[sqlmesh.Context]] = None
81 | """A custom context class to use for sqlmesh."""
82 | else:
83 | sqlmesh_context_class: t.Optional[t.Type[t.Any]] = None
84 | """A custom context class to use for sqlmesh."""
85 |
86 | @pydantic.model_validator(mode="after")
87 | def _setup(self) -> Self:
88 | """Initialize the workspace."""
89 | for source in self.configuration_sources:
90 | self.conf_resolver.import_source(source)
91 | self.conf_resolver.set_environment(self.environment)
92 | self.container.add_from_dependency(
93 | injector.Dependency.instance(self),
94 | key="cdf_workspace",
95 | override=True,
96 | )
97 | self.container.add_from_dependency(
98 | injector.Dependency.instance(self.environment),
99 | key="cdf_environment",
100 | override=True,
101 | )
102 | self.container.add_from_dependency(
103 | injector.Dependency.instance(self.conf_resolver),
104 | key="cdf_config",
105 | override=True,
106 | )
107 | self.container.add_from_dependency(
108 | injector.Dependency.singleton(self.get_sqlmesh_context_or_raise),
109 | key="cdf_transform",
110 | override=True,
111 | )
112 | for service in self.services.values():
113 | self.container.add_from_dependency(service.main, key=service.name)
114 | self.activate()
115 | return self
116 |
117 | def activate(self) -> Self:
118 | """Activate the workspace for the current context."""
119 | ctx.set_active_workspace(self)
120 | return self
121 |
122 | def _parse_definitions(
123 | self, defs: t.Iterable[cmp.TComponentDef], into: t.Type[cmp.TComponent]
124 | ) -> t.Dict[str, cmp.TComponent]:
125 | """Parse a list of component definitions into a lookup."""
126 | components = {}
127 | with ctx.use_workspace(self):
128 | for definition in defs:
129 | component = into.model_validate(definition, context={"parent": self})
130 | components[component.name] = component
131 | return components
132 |
133 | @cached_property
134 | def services(self) -> t.Dict[str, cmp.Service]:
135 | """Return the resolved services of the workspace."""
136 | return self._parse_definitions(self.services_, cmp.Service)
137 |
138 | @cached_property
139 | def pipelines(self) -> t.Dict[str, cmp.DataPipeline]:
140 | """Return the resolved data pipelines of the workspace."""
141 | return self._parse_definitions(self.pipelines_, cmp.DataPipeline)
142 |
143 | @cached_property
144 | def publishers(self) -> t.Dict[str, cmp.DataPublisher]:
145 | """Return the resolved data publishers of the workspace."""
146 | return self._parse_definitions(self.publishers_, cmp.DataPublisher)
147 |
148 | @cached_property
149 | def operations(self) -> t.Dict[str, cmp.Operation]:
150 | """Return the resolved operations of the workspace."""
151 | return self._parse_definitions(self.operations_, cmp.Operation)
152 |
153 | @t.overload
154 | def get_sqlmesh_context(
155 | self,
156 | gateway: t.Optional[str],
157 | must_exist: t.Literal[False],
158 | **kwargs: t.Any,
159 | ) -> t.Optional["sqlmesh.Context"]: ...
160 |
161 | @t.overload
162 | def get_sqlmesh_context(
163 | self,
164 | gateway: t.Optional[str],
165 | must_exist: t.Literal[True],
166 | **kwargs: t.Any,
167 | ) -> "sqlmesh.Context": ...
168 |
169 | def get_sqlmesh_context(
170 | self, gateway: t.Optional[str] = None, must_exist: bool = False, **kwargs: t.Any
171 | ) -> t.Optional["sqlmesh.Context"]:
172 | """Return the transform context or raise an error if not defined."""
173 | import sqlmesh
174 |
175 | if self.sqlmesh_path is None:
176 | if must_exist:
177 | raise ValueError("Transformation provider not defined.")
178 | return None
179 |
180 | kwargs = {**self.sqlmesh_context_kwargs, **kwargs}
181 | with ctx.use_workspace(self):
182 | klass = self.sqlmesh_context_class or sqlmesh.Context
183 | return klass(paths=[self.sqlmesh_path], gateway=gateway, **kwargs)
184 |
185 | if t.TYPE_CHECKING:
186 |
187 | def get_sqlmesh_context_or_raise(
188 | self, gateway: t.Optional[str] = None, **kwargs: t.Any
189 | ) -> "sqlmesh.Context": ...
190 |
191 | else:
192 | get_sqlmesh_context_or_raise = partialmethod(
193 | get_sqlmesh_context, must_exist=True
194 | )
195 |
196 | @property
197 | def cli(self) -> "click.Group":
198 | """Dynamically generate a CLI entrypoint for the workspace."""
199 | import click
200 |
201 | @click.group()
202 | def cli() -> None:
203 | """A dynamically generated CLI for the workspace."""
204 | self.activate()
205 |
206 | def _list(d: t.Dict[str, cmp.TComponent], verbose: bool = False) -> None:
207 | for name in sorted(d.keys()):
208 | if verbose:
209 | click.echo(d[name].model_dump_json(indent=2, exclude={"main"}))
210 | else:
211 | click.echo(d[name])
212 |
213 | for k in ("services", "pipelines", "publishers", "operations"):
214 | cli.command(
215 | f"list-{k}", help=f"List the {k} in the {self.name} workspace."
216 | )(
217 | click.option("-v", "--verbose", is_flag=True)(
218 | lambda verbose=False, k=k: _list(getattr(self, k), verbose=verbose)
219 | )
220 | )
221 |
222 | @cli.command("run-pipeline")
223 | @click.argument(
224 | "pipeline_name",
225 | required=False,
226 | type=click.Choice(list(self.pipelines.keys())),
227 | )
228 | @click.option(
229 | "--test",
230 | is_flag=True,
231 |             help="Run the pipeline's integration test if defined.",
232 | )
233 | @click.pass_context
234 | def run_pipeline(
235 | ctx: click.Context,
236 | pipeline_name: t.Optional[str] = None,
237 | test: bool = False,
238 | ) -> None:
239 | """Run a data pipeline."""
240 | if pipeline_name is None:
241 | pipeline_name = click.prompt(
242 | "Enter a pipeline",
243 | type=click.Choice(list(self.pipelines.keys())),
244 | show_choices=True,
245 | )
246 | if pipeline_name is None:
247 | raise click.BadParameter(
248 | "Pipeline must be specified.", ctx=ctx, param_hint="pipeline"
249 | )
250 |
251 | pipeline = self.pipelines[pipeline_name]
252 |
253 | if test:
254 | click.echo("Running pipeline tests.", err=True)
255 | try:
256 | pipeline.run_tests()
257 | except Exception as e:
258 | click.echo(f"Pipeline test(s) failed: {e}", err=True)
259 | ctx.exit(1)
260 | else:
261 | click.echo("Pipeline test(s) passed!", err=True)
262 | ctx.exit(0)
263 |
264 | start = time.time()
265 | try:
266 | jobs = pipeline()
267 | except Exception as e:
268 | click.echo(
269 | f"Pipeline failed after {time.time() - start:.2f} seconds: {e}",
270 | err=True,
271 | )
272 | ctx.exit(1)
273 |
274 | click.echo(
275 | f"Pipeline process finished in {time.time() - start:.2f} seconds.",
276 | err=True,
277 | )
278 |
279 | for job in jobs:
280 | if job.has_failed_jobs:
281 | ctx.fail("Pipeline failed.")
282 |
283 | ctx.exit(0)
284 |
285 | @cli.command("run-publisher")
286 | @click.argument(
287 | "publisher_name",
288 | required=False,
289 | type=click.Choice(list(self.publishers.keys())),
290 | )
291 | @click.option(
292 | "--test",
293 | is_flag=True,
294 |             help="Run the publisher's integration test if defined.",
295 | )
296 | @click.pass_context
297 | def run_publisher(
298 | ctx: click.Context,
299 | publisher_name: t.Optional[str] = None,
300 | test: bool = False,
301 | ) -> None:
302 | """Run a data publisher."""
303 | if publisher_name is None:
304 | publisher_name = click.prompt(
305 | "Enter a publisher",
306 | type=click.Choice(list(self.publishers.keys())),
307 | show_choices=True,
308 | )
309 | if publisher_name is None:
310 | raise click.BadParameter(
311 | "Publisher must be specified.", ctx=ctx, param_hint="publisher"
312 | )
313 |
314 | publisher = self.publishers[publisher_name]
315 |
316 | start = time.time()
317 | try:
318 | publisher()
319 | except Exception as e:
320 | click.echo(
321 | f"Publisher failed after {time.time() - start:.2f} seconds: {e}",
322 | err=True,
323 | )
324 | ctx.exit(1)
325 |
326 | click.echo(
327 | f"Publisher process finished in {time.time() - start:.2f} seconds.",
328 | err=True,
329 | )
330 | ctx.exit(0)
331 |
332 | @cli.command("run-operation")
333 | @click.argument(
334 | "operation_name",
335 | required=False,
336 | type=click.Choice(list(self.operations.keys())),
337 | )
338 | @click.pass_context
339 | def run_operation(
340 | ctx: click.Context, operation_name: t.Optional[str] = None
341 |         ) -> None:
342 | """Run a generic operation."""
343 | if operation_name is None:
344 | operation_name = click.prompt(
345 | "Enter an operation",
346 | type=click.Choice(list(self.operations.keys())),
347 | show_choices=True,
348 | )
349 | if operation_name is None:
350 | raise click.BadParameter(
351 | "Operation must be specified.", ctx=ctx, param_hint="operation"
352 | )
353 |
354 | operation = self.operations[operation_name]
355 |
356 | ctx.exit(operation())
357 |
358 | return cli
359 |
360 | def bind(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
361 | """Wrap a function with configuration and dependencies defined in the workspace."""
362 | configured_f = self.conf_resolver.resolve_defaults(func_or_cls)
363 | return self.container.wire(configured_f)
364 |
365 | def invoke(self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any) -> T:
366 | """Invoke a function with configuration and dependencies defined in the workspace."""
367 | with ctx.use_workspace(self):
368 | return self.bind(func_or_cls)(*args, **kwargs)
369 |
--------------------------------------------------------------------------------
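
A minimal usage sketch for the `bind`/`invoke`/`cli` API above. The `Workspace(name=...)` constructor argument and the `api.key` configuration entry are assumptions for illustration; the `Request` annotation is the documented pattern from `cdf.core.configuration`:

```python
# Hypothetical sketch -- constructor arguments and config keys are assumed.
import typing as t

import cdf.core.configuration as conf
from cdf.core.workspace import Workspace

def fetch(api_key: t.Annotated[str, conf.Request["api.key"]]) -> str:
    # Resolved from workspace configuration; assumes "api.key" is defined.
    return f"fetching with {api_key!r}"

ws = Workspace(name="demo")  # assumed signature

bound = ws.bind(fetch)     # wires config + dependencies, does not call
result = ws.invoke(fetch)  # additionally activates the workspace context and calls

if __name__ == "__main__":
    ws.cli()  # expose the dynamically generated click.Group as an entrypoint
```
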
/src/cdf/integrations/slack.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import traceback
4 | import typing as t
5 | from datetime import datetime, timezone
6 | from enum import Enum
7 | from textwrap import dedent, indent
8 |
9 | import requests
10 |
11 | SLACK_MAX_TEXT_LENGTH = 3000
12 | SLACK_MAX_ALERT_PREVIEW_BLOCKS = 5
13 | SLACK_MAX_ATTACHMENTS_BLOCKS = 50
14 | CONTINUATION_SYMBOL = "..."
15 |
16 |
17 | TSlackBlock = t.Dict[str, t.Any]
18 |
19 |
20 | class TSlackBlocks(t.TypedDict):
21 | blocks: t.List[TSlackBlock]
22 |
23 |
24 | class TSlackMessage(TSlackBlocks, t.TypedDict):
25 | attachments: t.List[TSlackBlocks]
26 |
27 |
28 | class SlackMessageComposer:
29 |     """Builds a Slack message with primary and secondary blocks"""
30 |
31 | def __init__(self, initial_message: t.Optional[TSlackMessage] = None) -> None:
32 | """Initialize the Slack message builder"""
33 | self.slack_message = initial_message or {
34 | "blocks": [],
35 | "attachments": [{"blocks": []}],
36 | }
37 |
38 | def add_primary_blocks(self, *blocks: TSlackBlock) -> "SlackMessageComposer":
39 | """Add blocks to the message. Blocks are always displayed"""
40 | self.slack_message["blocks"].extend(blocks)
41 | return self
42 |
43 | def add_secondary_blocks(self, *blocks: TSlackBlock) -> "SlackMessageComposer":
44 | """Add attachments to the message
45 |
46 |         Attachments are hidden behind a "show more" button. The first 5 attachments
47 | are always displayed. NOTICE: attachments blocks are deprecated by Slack
48 | """
49 | self.slack_message["attachments"][0]["blocks"].extend(blocks)
50 | if (
51 | len(self.slack_message["attachments"][0]["blocks"])
52 | >= SLACK_MAX_ATTACHMENTS_BLOCKS
53 | ):
54 | raise ValueError("Too many attachments")
55 | return self
56 |
57 | def _introspect(self) -> "SlackMessageComposer":
58 | """Print the message to stdout
59 |
60 | This is a debugging method. Useful during composition of the message."""
61 | print(json.dumps(self.slack_message, indent=2))
62 | return self
63 |
64 |
65 | def normalize_message(message: t.Union[str, t.List[str], t.Iterable[str]]) -> str:
66 | """Normalize message to fit Slack's max text length"""
67 | if isinstance(message, (list, tuple, set)):
68 | message = stringify_list(list(message))
69 | assert isinstance(message, str), f"Message must be a string, got {type(message)}"
70 | dedented_message = dedent(message)
71 |     if len(dedented_message) < SLACK_MAX_TEXT_LENGTH:
72 |         return dedented_message
73 |     return (
74 |         dedented_message[: SLACK_MAX_TEXT_LENGTH - len(CONTINUATION_SYMBOL) - 3]
75 |         + CONTINUATION_SYMBOL
76 |         + dedented_message[-3:]
77 |     )
78 |
79 |
80 | def divider_block() -> dict:
81 | """Create a divider block"""
82 | return {"type": "divider"}
83 |
84 |
85 | def fields_section_block(*messages: str) -> dict:
86 |     """Create a section block with multiple fields"""
87 |     return {
88 |         "type": "section",
89 |         "fields": [
90 |             {
91 |                 "type": "mrkdwn",
92 |                 "text": normalize_message(message),
93 |             }
94 |             for message in messages
95 |         ],
96 |     }
97 |
98 |
99 | def text_section_block(message: str) -> dict:
100 | """Create a section block with text"""
101 | return {
102 | "type": "section",
103 | "text": {
104 | "type": "mrkdwn",
105 | "text": normalize_message(message),
106 | },
107 | }
108 |
109 |
110 | def empty_section_block() -> dict:
111 | """Create an empty section block"""
112 | return {
113 | "type": "section",
114 | "text": {
115 | "type": "mrkdwn",
116 | "text": normalize_message("\t"),
117 | },
118 | }
119 |
120 |
121 | def context_block(*messages: str) -> dict:
122 | """Create a context block with multiple fields"""
123 | return {
124 | "type": "context",
125 | "elements": [
126 | {
127 | "type": "mrkdwn",
128 | "text": normalize_message(message),
129 | }
130 | for message in messages
131 | ],
132 | }
133 |
134 |
135 | def header_block(message: str) -> dict:
136 | """Create a header block"""
137 | return {
138 | "type": "header",
139 | "text": {
140 | "type": "plain_text",
141 | "text": message,
142 | },
143 | }
144 |
145 |
146 | def button_action_block(text: str, url: str) -> dict:
147 | """Create a button action block"""
148 | return {
149 | "type": "actions",
150 | "elements": [
151 | {
152 | "type": "button",
153 | "text": {"type": "plain_text", "text": text, "emoji": True},
154 | "value": text,
155 | "url": url,
156 | }
157 | ],
158 | }
159 |
160 |
161 | def compacted_sections_blocks(*messages: t.Union[str, t.Iterable[str]]) -> t.List[dict]:
162 | """Create a list of compacted sections blocks"""
163 | return [
164 | {
165 | "type": "section",
166 | "fields": [
167 | {
168 | "type": "mrkdwn",
169 | "text": normalize_message(message),
170 | }
171 | for message in messages[i : i + 2]
172 | ],
173 | }
174 | for i in range(0, len(messages), 2)
175 | ]
176 |
177 |
178 | class SlackAlertIcon(str, Enum):
179 | """Enum for status of the alert"""
180 |
181 | # simple statuses
182 | OK = ":large_green_circle:"
183 | WARN = ":warning:"
184 | ERROR = ":x:"
185 | START = ":arrow_forward:"
186 | ALERT = ":rotating_light:"
187 | STOP = ":stop_button:"
188 |
189 | # log levels
190 | UNKNOWN = ":question:"
191 | INFO = ":information_source:"
192 | DEBUG = ":beetle:"
193 | CRITICAL = ":fire:"
194 | FATAL = ":skull_and_crossbones:"
195 | EXCEPTION = ":boom:"
196 |
197 | # test statuses
198 | FAILURE = ":no_entry_sign:"
199 | SUCCESS = ":white_check_mark:"
200 | WARNING = ":warning:"
201 | SKIPPED = ":fast_forward:"
202 | PASSED = ":white_check_mark:"
203 |
204 | def __str__(self) -> str:
205 | return self.value
206 |
207 |
208 | def stringify_list(list_variation: t.Union[t.List[str], str]) -> str:
209 |     """Prettify and deduplicate a list of strings, converting it to a newline-delimited string"""
210 | if isinstance(list_variation, str):
211 | return list_variation
212 | if len(list_variation) == 1:
213 |         return str(list_variation[0])
214 | list_variation = list(list_variation)
215 | for i, item in enumerate(list_variation):
216 | if not isinstance(item, str):
217 | list_variation[i] = str(item)
218 | order = {item: i for i, item in reversed(list(enumerate(list_variation)))}
219 | return "\n".join(sorted(set(list_variation), key=lambda item: order[item]))
220 |
221 |
222 | def send_basic_slack_message(
223 | incoming_hook: str, message: str, is_markdown: bool = True
224 | ) -> None:
225 | """Sends a `message` to Slack `incoming_hook`, by default formatted as markdown."""
226 | resp = requests.post(
227 | incoming_hook,
228 | data=json.dumps({"text": message, "mrkdwn": is_markdown}).encode("utf-8"),
229 | headers={"Content-Type": "application/json;charset=utf-8"},
230 | )
231 | resp.raise_for_status()
232 |
233 |
234 | def send_extract_start_slack_message(
235 | incoming_hook: str,
236 | source: str,
237 | run_id: str,
238 | tags: t.List[str],
239 | owners: t.List[str],
240 | environment: str,
241 | resources_selected: t.List[str],
242 | resources_count: int,
243 | ) -> None:
244 | """Sends a Slack message for the start of an extract"""
245 | resp = requests.post(
246 | incoming_hook,
247 | json=SlackMessageComposer()
248 | .add_primary_blocks(
249 | header_block(f"{SlackAlertIcon.START} Starting Extract (id: {run_id})"),
250 | context_block(
251 | "*Source:* {source} |".format(source=source),
252 | "*Status:* Starting Extraction |",
253 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
254 | ),
255 | divider_block(),
256 | *compacted_sections_blocks(
257 | ("*Tags*", stringify_list(tags)),
258 | ("*Owners*", stringify_list(owners)),
259 | ),
260 | *compacted_sections_blocks(
261 | ("*Environment*", environment),
262 | (
263 | "*Resources*",
264 | f"{len(resources_selected)}/{resources_count} selected",
265 | ),
266 | ),
267 | divider_block(),
268 | text_section_block(
269 | f"""
270 | Resources selected for extraction :test_tube:
271 |
272 | {stringify_list(resources_selected)}
273 | """
274 | ),
275 | button_action_block("View in Harness", url="https://app.harness.io"),
276 | context_block(f"*Python Version:* {sys.version}"),
277 | )
278 | .slack_message,
279 | )
280 | resp.raise_for_status()
281 |
282 |
283 | def send_extract_failure_message(
284 | incoming_hook: str, source: str, run_id: str, duration: float, error: Exception
285 | ) -> None:
286 | """Sends a Slack message for the failure of an extract"""
287 | # trace = "\n".join(f"> {line}" for line in traceback.format_exc().splitlines())
288 | resp = requests.post(
289 | incoming_hook,
290 | json=SlackMessageComposer()
291 | .add_primary_blocks(
292 | header_block(f"{SlackAlertIcon.ERROR} Extract Failed (id: {run_id})"),
293 | context_block(
294 | "*Source:* {source} |".format(source=source),
295 | "*Status:* Extraction Failed |",
296 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
297 | ),
298 | divider_block(),
299 | text_section_block(
300 | f"""
301 | Extract failed after {duration:.2f}s :fire:
302 |
303 | ```
304 | {indent(traceback.format_exc(), " " * 12, lambda line: (not line.startswith("Traceback")))}
305 | ```
306 |
307 | Please check the logs for more information.
308 | """
309 | ),
310 | button_action_block("View in Harness", url="https://app.harness.io"),
311 | context_block(f"*Python Version:* {sys.version}"),
312 | )
313 | .slack_message,
314 | )
315 | resp.raise_for_status()
316 |
317 |
318 | def send_extract_success_message(
319 | incoming_hook: str, source: str, run_id: str, duration: float
320 | ) -> None:
321 | """Sends a Slack message for the success of an extract"""
322 | resp = requests.post(
323 | incoming_hook,
324 | json=SlackMessageComposer()
325 | .add_primary_blocks(
326 | header_block(f"{SlackAlertIcon.OK} Extract Succeeded (id: {run_id})"),
327 | context_block(
328 | "*Source:* {source} |".format(source=source),
329 | "*Status:* Extraction Succeeded |",
330 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
331 | ),
332 | divider_block(),
333 | text_section_block(
334 | f"""
335 | Extract succeeded after {duration:.2f}s :tada:
336 |
337 | Please check the logs for more information.
338 | """
339 | ),
340 | button_action_block("View in Harness", url="https://app.harness.io"),
341 | context_block(f"*Python Version:* {sys.version}"),
342 | )
343 | .slack_message,
344 | )
345 | resp.raise_for_status()
346 |
347 |
348 | def send_normalize_start_slack_message(
349 | incoming_hook: str,
350 | source: str,
351 | blob_name: str,
352 | run_id: str,
353 | environment: str,
354 | ) -> None:
355 |     """Sends a Slack message for the start of a normalization"""
356 | _ = environment
357 | resp = requests.post(
358 | incoming_hook,
359 | json=SlackMessageComposer()
360 | .add_primary_blocks(
361 | header_block(f"{SlackAlertIcon.START} Normalizing (id: {run_id})"),
362 | context_block(
363 | "*Source:* {source} |".format(source=source),
364 | "*Status:* Starting Normalization |",
365 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
366 | ),
367 | divider_block(),
368 | text_section_block(
369 | f"""
370 | Pending load package discovered in stage :package:
371 |
372 | Starting normalization for: :file_folder:
373 |
374 | `{blob_name}`
375 | """
376 | ),
377 | context_block(f"*Python Version:* {sys.version}"),
378 | )
379 | .slack_message,
380 | )
381 | resp.raise_for_status()
382 |
383 |
384 | def send_normalize_failure_message(
385 | incoming_hook: str,
386 | source: str,
387 | blob_name: str,
388 | run_id: str,
389 | duration: float,
390 | error: Exception,
391 | ) -> None:
392 |     """Sends a Slack message for the failure of a normalization"""
393 | # trace = "\n".join(f"> {line}" for line in traceback.format_exc().splitlines())
394 | resp = requests.post(
395 | incoming_hook,
396 | json=SlackMessageComposer()
397 | .add_primary_blocks(
398 | header_block(f"{SlackAlertIcon.ERROR} Normalization Failed (id: {run_id})"),
399 | context_block(
400 | "*Source:* {source} |".format(source=source),
401 | "*Status:* Normalization Failed |",
402 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
403 | ),
404 | divider_block(),
405 | text_section_block(
406 | f"""
407 | Normalization failed after {duration:.2f}s :fire:
408 |
409 | ```
410 | {indent(traceback.format_exc(), " " * 12, lambda line: (not line.startswith("Traceback")))}
411 | ```
412 |
413 | Please check the pod logs for more information.
414 | """
415 | ),
416 | context_block(f"*Python Version:* {sys.version}"),
417 | )
418 | .slack_message,
419 | )
420 | resp.raise_for_status()
421 |
422 |
423 | def send_normalization_success_message(
424 | incoming_hook: str, source: str, blob_name: str, run_id: str, duration: float
425 | ) -> None:
426 |     """Sends a Slack message for the success of a normalization"""
427 | resp = requests.post(
428 | incoming_hook,
429 | json=SlackMessageComposer()
430 | .add_primary_blocks(
431 | header_block(f"{SlackAlertIcon.OK} Normalization Succeeded (id: {run_id})"),
432 | context_block(
433 | "*Source:* {source} |".format(source=source),
434 | "*Status:* Normalization Succeeded |",
435 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
436 | ),
437 | divider_block(),
438 | text_section_block(
439 | f"""
440 | Normalization took {duration:.2f}s :tada:
441 |
442 | The package was normalized successfully: :file_folder:
443 |
444 | `{blob_name}`
445 |
446 | This package is now prepared for loading.
447 | """
448 | ),
449 | )
450 | .slack_message,
451 | )
452 | resp.raise_for_status()
453 |
454 |
455 | def send_load_start_slack_message(
456 | incoming_hook: str,
457 | source: str,
458 | destination: str,
459 | dataset: str,
460 | run_id: str,
461 | ) -> None:
462 | """Sends a Slack message for the start of a load"""
463 | resp = requests.post(
464 | incoming_hook,
465 | json=SlackMessageComposer()
466 | .add_primary_blocks(
467 | header_block(f"{SlackAlertIcon.START} Loading (id: {run_id})"),
468 | context_block(
469 | "*Source:* {source} |".format(source=source),
470 | "*Status:* Starting Load |",
471 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
472 | ),
473 | divider_block(),
474 | *compacted_sections_blocks(
475 | ("*Destination*", destination),
476 | ("*Dataset*", dataset),
477 | ),
478 | context_block(f"*Python Version:* {sys.version}"),
479 | )
480 | .slack_message,
481 | )
482 | resp.raise_for_status()
483 |
484 |
485 | def send_load_failure_message(
486 | incoming_hook: str,
487 | source: str,
488 | destination: str,
489 | dataset: str,
490 | run_id: str,
491 | ) -> None:
492 |     """Sends a Slack message for the failure of a load"""
493 | # trace = "\n".join(f"> {line}" for line in traceback.format_exc().splitlines())
494 | resp = requests.post(
495 | incoming_hook,
496 | json=SlackMessageComposer()
497 | .add_primary_blocks(
498 | header_block(f"{SlackAlertIcon.ERROR} Load Failed (id: {run_id})"),
499 | context_block(
500 | "*Source:* {source} |".format(source=source),
501 |                 "*Status:* Load Failed |",
502 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
503 | ),
504 | divider_block(),
505 | text_section_block(
506 | f"""
507 | Load to {destination} dataset named {dataset} failed :fire:
508 |
509 | ```
510 | {indent(traceback.format_exc(), " " * 12, lambda line: (not line.startswith("Traceback")))}
511 | ```
512 |
513 | Please check the pod logs for more information.
514 | """
515 | ),
516 | context_block(f"*Python Version:* {sys.version}"),
517 | )
518 | .slack_message,
519 | )
520 | resp.raise_for_status()
521 |
522 |
523 | def send_load_success_message(
524 | incoming_hook: str,
525 | source: str,
526 | destination: str,
527 | dataset: str,
528 | run_id: str,
529 | payload: str,
530 | ) -> None:
531 |     """Sends a Slack message for the success of a load"""
532 | resp = requests.post(
533 | incoming_hook,
534 | json=SlackMessageComposer()
535 | .add_primary_blocks(
536 | header_block(f"{SlackAlertIcon.OK} Load Succeeded (id: {run_id})"),
537 | context_block(
538 | "*Source:* {source} |".format(source=source),
539 | "*Status:* Loading Succeeded |",
540 | "*{date}*".format(date=datetime.now(timezone.utc).strftime("%x %X")),
541 | ),
542 | divider_block(),
543 | *compacted_sections_blocks(
544 | ("*Destination*", destination),
545 | ("*Dataset*", dataset),
546 | ),
547 | divider_block(),
548 | text_section_block(
549 | f"""
550 | The package was loaded successfully: :file_folder:
551 |
552 | ```
553 | {payload}
554 | ```
555 | """
556 | ),
557 | context_block(f"*Python Version:* {sys.version}"),
558 | )
559 | .slack_message,
560 | )
561 | resp.raise_for_status()
562 |
--------------------------------------------------------------------------------
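
A short sketch of composing and sending a message with the block helpers above; the webhook URL is a placeholder, everything else is this module's own API:

```python
import requests

from cdf.integrations.slack import (
    SlackMessageComposer,
    context_block,
    divider_block,
    header_block,
    send_basic_slack_message,
    text_section_block,
)

HOOK = "https://hooks.slack.com/services/T000/B000/XXXX"  # placeholder

# One-liner for plain markdown text:
send_basic_slack_message(HOOK, "*deploy finished* :tada:")

# Structured message built from blocks, posted the same way the
# send_*_message helpers above post theirs:
message = (
    SlackMessageComposer()
    .add_primary_blocks(
        header_block("Nightly run"),
        context_block("*Status:* OK"),
        divider_block(),
        text_section_block("All pipelines completed."),
    )
    .slack_message
)
requests.post(HOOK, json=message).raise_for_status()
```
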
/src/cdf/core/configuration.py:
--------------------------------------------------------------------------------
1 | """Configuration utilities for the CDF configuration resolver system.
2 |
3 | There are 3 ways to request configuration values:
4 |
5 | 1. Using a Request annotation:
6 |
7 | Pro: It's explicit and re-usable. An annotation can be used in multiple places.
8 |
9 | ```python
10 | import typing as t
11 | import cdf.core.configuration as conf
12 |
13 | def foo(bar: t.Annotated[str, conf.Request["api.key"]]) -> None:
14 | print(bar)
15 | ```
16 |
17 | 2. Setting a __cdf_resolve__ attribute on a callable object. This can be done
18 | directly or by using the `map_section` or `map_values` decorators:
19 |
20 | Pro: It's concise and can be used in a decorator. It also works with classes.
21 |
22 | ```python
23 | import cdf.core.configuration as conf
24 |
25 | @conf.map_section("api")
26 | def foo(key: str) -> None:
27 | print(key)
28 |
29 | @conf.map_values(key="api.key")
30 | def bar(key: str) -> None:
31 | print(key)
32 |
33 | def baz(key: str) -> None:
34 | print(key)
35 |
36 | baz.__cdf_resolve__ = ("api",)
37 | ```
38 |
39 | 3. Using the `_cdf_resolve` kwarg to request the resolver:
40 |
41 | Pro: It's flexible and can be used in any function. It requires no imports.
42 |
43 | ```python
44 | def foo(key: str, _cdf_resolve=("api",)) -> None:
45 | print(key)
46 |
47 | def bar(key: str, _cdf_resolve={"key": "api.key"}) -> None:
48 | print(key)
49 | ```
50 | """
51 |
52 | import ast
53 | import functools
54 | import inspect
55 | import json
56 | import logging
57 | import os
58 | import re
59 | import string
60 | import typing as t
61 | from collections import ChainMap
62 | from contextlib import suppress
63 | from pathlib import Path
64 |
65 | import pydantic
66 | import pydantic_core
67 | from typing_extensions import ParamSpec
68 |
69 | if t.TYPE_CHECKING:
70 | from dynaconf.vendor.box import Box
71 |
72 | from cdf.types import M
73 |
74 | logger = logging.getLogger(__name__)
75 |
76 | T = t.TypeVar("T")
77 | P = ParamSpec("P")
78 |
79 | __all__ = [
80 | "ConfigLoader",
81 | "ConfigResolver",
82 | "ConfigSource",
83 | "Request",
84 | "add_custom_converter",
85 | "remove_converter",
86 | "load_file",
87 | "map_config_section",
88 | "map_config_values",
89 | ]
90 |
91 |
92 | def load_file(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]:
93 | """Load a configuration from a file path.
94 |
95 | Args:
96 | path: The file path.
97 |
98 | Returns:
99 | A Result monad with the configuration dictionary if the file format is JSON, YAML or TOML.
100 | Otherwise, a Result monad with an error.
101 | """
102 | if path.suffix == ".json":
103 | return _load_json(path)
104 | if path.suffix in (".yaml", ".yml"):
105 | return _load_yaml(path)
106 | if path.suffix == ".toml":
107 | return _load_toml(path)
108 | return M.error(ValueError("Invalid file format, must be JSON, YAML or TOML"))
109 |
110 |
111 | def _load_json(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]:
112 | """Load a configuration from a JSON file.
113 |
114 | Args:
115 | path: The file path to a valid JSON document.
116 |
117 | Returns:
118 | A Result monad with the configuration dictionary if the file format is JSON. Otherwise, a
119 | Result monad with an error.
120 | """
121 | try:
122 | return M.ok(json.loads(path.read_text()))
123 | except Exception as e:
124 | return M.error(e)
125 |
126 |
127 | def _load_yaml(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]:
128 | """Load a configuration from a YAML file.
129 |
130 | Args:
131 | path: The file path to a valid YAML document.
132 |
133 | Returns:
134 | A Result monad with the configuration dictionary if the file format is YAML. Otherwise, a
135 | Result monad with an error.
136 | """
137 | try:
138 | import ruamel.yaml as yaml
139 |
140 | yaml_ = yaml.YAML()
141 | return M.ok(yaml_.load(path))
142 | except Exception as e:
143 | return M.error(e)
144 |
145 |
146 | def _load_toml(path: Path) -> M.Result[t.Dict[str, t.Any], Exception]:
147 | """Load a configuration from a TOML file.
148 |
149 | Args:
150 | path: The file path to a valid TOML document.
151 |
152 | Returns:
153 | A Result monad with the configuration dictionary if the file format is TOML. Otherwise, a
154 | Result monad with an error.
155 | """
156 | try:
157 | import tomlkit
158 |
159 | return M.ok(tomlkit.loads(path.read_text()).unwrap())
160 | except Exception as e:
161 | return M.error(e)
162 |
163 |
164 | def _to_bool(value: str) -> bool:
165 | """Convert a string to a boolean."""
166 | return value.lower() in ["true", "1", "yes"]
167 |
168 |
169 | def _resolve_template(template: str, **overrides: t.Any) -> str:
170 | """Resolve a template string using environment variables."""
171 | return string.Template(template).substitute(overrides, **os.environ)
172 |
173 |
174 | _CONVERTERS = {
175 | "json": json.loads,
176 | "int": int,
177 | "float": float,
178 | "str": str,
179 | "bool": _to_bool,
180 | "path": os.path.abspath,
181 | "dict": ast.literal_eval,
182 | "list": ast.literal_eval,
183 | "tuple": ast.literal_eval,
184 | "set": ast.literal_eval,
185 | "resolve": None,
186 | }
187 | """Converters for configuration values."""
188 |
189 | _CONVERTER_PATTERN = re.compile(r"@(\w+) ", re.IGNORECASE)
190 | """Pattern to match converters in a string."""
191 |
192 |
193 | def add_custom_converter(name: str, converter: t.Callable[[str], t.Any]) -> None:
194 | """Add a custom converter to the configuration system."""
195 | if name in _CONVERTERS:
196 | raise ValueError(f"Converter {name} already exists.")
197 | _CONVERTERS[name] = converter
198 |
199 |
200 | def get_converter(name: str) -> t.Callable[[str], t.Any]:
201 | """Get a custom converter from the configuration system."""
202 | return _CONVERTERS[name]
203 |
204 |
205 | def remove_converter(name: str) -> None:
206 | """Remove a custom converter from the configuration system."""
207 | if name not in _CONVERTERS:
208 | raise ValueError(f"Converter {name} does not exist.")
209 | del _CONVERTERS[name]
210 |
211 |
212 | def apply_converters(
213 | input_value: t.Any, resolver: t.Optional["ConfigResolver"] = None
214 | ) -> t.Any:
215 | """Apply converters to a string."""
216 | if not isinstance(input_value, str):
217 | return input_value
218 | expanded_value = _resolve_template(input_value)
219 | converters = _CONVERTER_PATTERN.findall(expanded_value)
220 | if len(converters) == 0:
221 | return expanded_value
222 | base_value = _CONVERTER_PATTERN.sub("", expanded_value).lstrip()
223 | if not base_value:
224 | return None
225 | transformed_value = base_value
226 | for converter in reversed(converters):
227 | try:
228 | if converter.lower() == "resolve":
229 | if resolver is None:
230 | raise ValueError(
231 | "Resolver instance not provided but found @resolve converter"
232 | )
233 | if transformed_value not in resolver:
234 | raise ValueError(f"Key not found in resolver: {transformed_value}")
235 | transformed_value = resolver[transformed_value]
236 | continue
237 | transformed_value = _CONVERTERS[converter.lower()](transformed_value)
238 | except KeyError as e:
239 | raise ValueError(f"Unknown converter: {converter}") from e
240 | except Exception as e:
241 | raise ValueError(f"Failed to convert value: {e}") from e
242 | return transformed_value
243 |
244 |
245 | def _to_box(mapping: t.Mapping[str, t.Any]) -> "Box":
246 | """Convert a mapping to a standardized Box."""
247 | from dynaconf.vendor.box import Box
248 |
249 | return Box(mapping, box_dots=True)
250 |
251 |
252 | class _ConfigScopes(t.NamedTuple):
253 | """A struct to store named configuration scopes by precedence."""
254 |
255 | explicit: "Box"
256 | """User-provided configuration passed as a dictionary."""
257 | environment: "Box"
258 | """Environment-specific configuration loaded from a file."""
259 | baseline: "Box"
260 | """Configuration loaded from a base config file."""
261 |
262 | def resolve(self) -> "Box":
263 | """Resolve the configuration scopes."""
264 | output = self.baseline
265 | output.merge_update(self.environment)
266 | output.merge_update(self.explicit)
267 | return output
268 |
269 |
270 | ConfigSource = t.Union[str, Path, t.Mapping[str, t.Any]]
271 |
272 |
273 | class ConfigLoader:
274 | """Load configuration from multiple sources."""
275 |
276 | def __init__(
277 | self,
278 | *sources: ConfigSource,
279 | environment: str = "dev",
280 | ) -> None:
281 | """Initialize the configuration loader."""
282 | self.environment = environment
283 | self.sources = list(sources)
284 |
285 | def load(self) -> t.MutableMapping[str, t.Any]:
286 | """Load configuration from sources."""
287 | scopes = _ConfigScopes(
288 | explicit=_to_box({}), environment=_to_box({}), baseline=_to_box({})
289 | )
290 | for source in self.sources:
291 | if isinstance(source, dict):
292 | # User may provide configuration as a dictionary directly
293 | # in which case it takes precedence over other sources
294 | scopes.explicit.merge_update(source)
295 | elif isinstance(source, (str, Path)):
296 | # Load configuration from file
297 | path = Path(source)
298 | result = load_file(path)
299 | if result.is_ok():
300 | scopes.baseline.merge_update(result.unwrap())
301 | else:
302 | err = result.unwrap_err()
303 | if not isinstance(err, FileNotFoundError):
304 | logger.warning(
305 |                         f"Failed to load configuration from {path}: {err}"
306 | )
307 | else:
308 | logger.debug(f"Configuration file not found: {path}")
309 | # Load environment-specific configuration from corresponding file
310 | # e.g. config.dev.json, config.dev.yaml, config.dev.toml
311 | env_path = path.with_name(
312 | f"{path.stem}.{self.environment}{path.suffix}"
313 | )
314 | result = load_file(env_path)
315 | if result.is_ok():
316 | scopes.environment.merge_update(result.unwrap())
317 | else:
318 | err = result.unwrap_err()
319 | if not isinstance(err, FileNotFoundError):
320 | logger.warning(
321 | f"Failed to load configuration from {path}: {err}"
322 | )
323 | else:
324 | logger.debug(f"Configuration file not found: {env_path}")
325 | return scopes.resolve()
326 |
327 | def import_source(self, source: ConfigSource, append: bool = True) -> None:
328 | """Include a new source of configuration."""
329 | if append:
330 | # Takes priority within the same scope
331 | self.sources.append(source)
332 | else:
333 | self.sources.insert(0, source)
334 |
335 | def clear_sources(self) -> t.List[ConfigSource]:
336 | """Clear all sources of configuration returning the previous sources."""
337 | cleared_sources = self.sources.copy()
338 | self.sources.clear()
339 | return cleared_sources
340 |
341 |
342 | _MISSING: t.Any = object()
343 | """A sentinel value for a missing configuration value."""
344 |
345 | RESOLVER_HINT = "__cdf_resolve__"
346 | """A hint to engage the configuration resolver."""
347 |
348 |
349 | def map_config_section(
350 | *sections: str,
351 | ) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
352 | """Mark a function to inject configuration values from a specific section."""
353 |
354 | def decorator(func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]:
355 | setattr(inspect.unwrap(func_or_cls), RESOLVER_HINT, sections)
356 | return func_or_cls
357 |
358 | return decorator
359 |
360 |
361 | def map_config_values(
362 | **mapping: t.Any,
363 | ) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
364 | """Mark a function to inject configuration values from a specific mapping of param names to keys."""
365 |
366 | def decorator(func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]:
367 | setattr(inspect.unwrap(func_or_cls), RESOLVER_HINT, mapping)
368 | return func_or_cls
369 |
370 | return decorator
371 |
372 |
373 | class Request:
374 | def __init__(self, item: str):
375 | self.item = item
376 |
377 | def __class_getitem__(cls, item: str) -> "Request":
378 | return cls(item)
379 |
380 |
381 | class ConfigLoaderProtocol(t.Protocol):
382 | environment: str
383 | sources: t.List[ConfigSource]
384 |
385 | def load(self) -> t.MutableMapping[str, t.Any]: ...
386 |
387 | def import_source(self, source: ConfigSource, append: bool = True) -> None: ...
388 |
389 | def clear_sources(self) -> t.List[ConfigSource]: ...
390 |
391 |
392 | class ConfigResolver(t.MutableMapping[str, t.Any]):
393 | """Resolve configuration values."""
394 |
395 | def __init__(
396 | self,
397 | *sources: ConfigSource,
398 | environment: str = "dev",
399 |         loader: t.Optional[ConfigLoaderProtocol] = None,
400 | deferred: bool = False,
401 | ) -> None:
402 | """Initialize the configuration resolver.
403 |
404 | The environment serves 2 purposes:
405 | 1. It determines supplementary configuration file to load, e.g. config.dev.json.
406 | 2. It prefixes configuration keys and prioritizes them over non-prefixed keys. e.g. dev.api.key.
407 |
408 | These are not mutually exclusive and can be used together.
409 |
410 | Args:
411 | sources: The sources of configuration.
412 | environment: The environment to load configuration for.
413 |             loader: The configuration loader; a fresh ConfigLoader("config.json") by default.
414 | deferred: If True, the configuration is not loaded until requested.
415 | """
416 | self.environment = environment
417 |         self._loader = loader = loader or ConfigLoader("config.json", environment=environment)
418 |         for source in sources:
419 |             loader.import_source(source)
420 |         self._config = _to_box(loader.load()) if not deferred else None
421 | self._frozen_environment = os.environ.copy()
422 | self._explicit_values = _to_box({})
423 |
424 | @property
425 | def wrapped(self) -> t.MutableMapping[str, t.Any]:
426 | """Get the configuration dictionary."""
427 | if self._config is None:
428 | self._config = _to_box(self._loader.load())
429 | return ChainMap(self._explicit_values, self._config)
430 |
431 | def __getitem__(self, key: str) -> t.Any:
432 | """Get a configuration value."""
433 | try:
434 | v = self.wrapped[f"{self.environment}.{key}"]
435 | except KeyError:
436 | v = self.wrapped[key]
437 | return self.apply_converters(v, self)
438 |
439 | def __setitem__(self, key: str, value: t.Any) -> None:
440 | """Set a configuration value."""
441 | self._explicit_values[f"{self.environment}.{key}"] = value
442 |
443 | def __delitem__(self, key: str) -> None:
444 | self._explicit_values.pop(f"{self.environment}.{key}", None)
445 |
446 | def __iter__(self) -> t.Iterator[str]:
447 | return iter(self.wrapped)
448 |
449 | def __len__(self) -> int:
450 | return len(self.wrapped)
451 |
452 | def __getattr__(self, key: str) -> t.Any:
453 | """Get a configuration value."""
454 | try:
455 | return self[key]
456 | except KeyError as e:
457 | raise AttributeError from e
458 |
459 | def __enter__(self) -> "ConfigResolver":
460 | """Enter a context."""
461 | return self
462 |
463 | def __exit__(self, *args) -> None:
464 | """Exit a context."""
465 | os.environ.clear()
466 | os.environ.update(self._frozen_environment)
467 |
468 | def __repr__(self) -> str:
469 | """Get a string representation of the configuration resolver."""
470 | return f"{self.__class__.__name__}(<{len(self._loader.sources)} sources>)"
471 |
472 | def set_environment(self, environment: str) -> None:
473 | """Set the environment of the configuration resolver."""
474 | self.environment = environment
475 | self._loader.environment = environment
476 | self._config = None
477 |
478 | def import_source(self, source: ConfigSource, append: bool = True) -> None:
479 | """Include a new source of configuration."""
480 | self._loader.import_source(source, append)
481 | self._config = None
482 |
483 | def clear_sources(self) -> t.List[ConfigSource]:
484 | """Clear all sources of configuration returning the previous sources."""
485 | sources = self._loader.clear_sources()
486 | self._config = None
487 | return sources
488 |
489 | map_section = staticmethod(map_config_section)
490 | """Mark a function to inject configuration values from a specific section."""
491 |
492 | map_values = staticmethod(map_config_values)
493 | """Mark a function to inject configuration values from a specific mapping of param names to keys."""
494 |
495 | add_custom_converter = staticmethod(add_custom_converter)
496 | """Add a custom converter to the configuration system."""
497 |
498 | apply_converters = staticmethod(apply_converters)
499 | """Apply converters to a string."""
500 |
501 | KWARG_HINT = "_cdf_resolve"
502 | """A hint supplied in a kwarg to engage the configuration resolver."""
503 |
504 | def _parse_hint_from_params(
505 | self, func_or_cls: t.Callable, sig: t.Optional[inspect.Signature] = None
506 | ) -> t.Optional[t.Union[t.Tuple[str, ...], t.Mapping[str, str]]]:
507 | """Get the sections or explicit lookups from a function.
508 |
509 | This assumes a kwarg named `_cdf_resolve` that is either a tuple of section names or
510 | a dictionary of param names to config keys is present in the function signature.
511 | """
512 | sig = sig or inspect.signature(func_or_cls)
513 | if self.KWARG_HINT in sig.parameters:
514 | resolver_spec = sig.parameters[self.KWARG_HINT]
515 | if isinstance(resolver_spec.default, (tuple, dict)):
516 | return resolver_spec.default
517 |
518 | def resolve_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
519 | """Resolve configuration values into a function or class."""
520 | if not callable(func_or_cls):
521 | return func_or_cls
522 |
523 | sig = inspect.signature(func_or_cls)
524 | is_resolved_sentinel = "__config_resolved__"
525 |
526 | resolver_hint = getattr(
527 | inspect.unwrap(func_or_cls),
528 | RESOLVER_HINT,
529 | self._parse_hint_from_params(func_or_cls, sig),
530 | )
531 |
532 | if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)):
533 | return func_or_cls
534 |
535 | @functools.wraps(func_or_cls)
536 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
537 | bound_args = sig.bind_partial(*args, **kwargs)
538 | bound_args.apply_defaults()
539 |
540 | # Apply converters to string literal arguments
541 | for arg_name, arg_value in bound_args.arguments.items():
542 | if isinstance(arg_value, str):
543 | with suppress(Exception):
544 | bound_args.arguments[arg_name] = self.apply_converters(
545 | arg_value,
546 | self,
547 | )
548 |
549 | # Resolve configuration values
550 | for name, param in sig.parameters.items():
551 | value = _MISSING
552 | if not self.is_resolvable(param):
553 | continue
554 |
555 | # 1. Prioritize Request annotations
556 | elif request := self.extract_request_annotation(param):
557 | value = self.get(request, _MISSING)
558 |
559 | # 2. Use explicit lookups if provided
560 | elif isinstance(resolver_hint, dict):
561 | if name not in resolver_hint:
562 | continue
563 | value = self.get(resolver_hint[name], _MISSING)
564 |
565 | # 3. Use section-based lookups if provided
566 | elif isinstance(resolver_hint, (tuple, list)):
567 | value = self.get(".".join((*resolver_hint, name)), _MISSING)
568 |
569 | # Inject the value into the function
570 | if value is not _MISSING:
571 | bound_args.arguments[name] = self.apply_converters(value, self)
572 |
573 | return func_or_cls(*bound_args.args, **bound_args.kwargs)
574 |
575 | setattr(wrapper, is_resolved_sentinel, True)
576 | return wrapper
577 |
578 | def is_resolvable(self, param: inspect.Parameter) -> bool:
579 | """Check if a parameter is injectable."""
580 | return param.default in (param.empty, None)
581 |
582 | @staticmethod
583 | def extract_request_annotation(param: inspect.Parameter) -> t.Optional[str]:
584 | """Extract a request annotation from a parameter."""
585 | for hint in getattr(param.annotation, "__metadata__", ()):
586 | if isinstance(hint, Request):
587 | return hint.item
588 |
589 | def __call__(
590 | self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any
591 | ) -> T:
592 | """Invoke a callable with injected configuration values."""
593 | configured_f = self.resolve_defaults(func_or_cls)
594 | if not callable(configured_f):
595 | return configured_f
596 | return configured_f(*args, **kwargs)
597 |
598 | @classmethod
599 | def __get_pydantic_core_schema__(
600 | cls, source_type: t.Any, handler: pydantic.GetCoreSchemaHandler
601 | ) -> pydantic_core.CoreSchema:
602 | return pydantic_core.core_schema.dict_schema(
603 | keys_schema=pydantic_core.core_schema.str_schema(),
604 | values_schema=pydantic_core.core_schema.any_schema(),
605 | )
606 |
607 |
608 | def _iter_wrapped(f: t.Callable):
609 | yield f
610 | f_w = inspect.unwrap(f)
611 | if f_w is not f:
612 | yield from _iter_wrapped(f_w)
613 |
--------------------------------------------------------------------------------
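
A minimal sketch of the resolver above in action, tying together `${VAR}` templating, `@converter` tokens, environment-prefixed lookups, and section-based injection. The configuration values and the `API_KEY` variable are invented for illustration:

```python
import os

from cdf.core.configuration import ConfigResolver

os.environ.setdefault("API_KEY", "s3cr3t")  # consumed by ${API_KEY} below

resolver = ConfigResolver(
    {"api": {"key": "@str ${API_KEY}", "timeout": "@int 30"}},
    environment="dev",
)

# Lookups try the environment-prefixed key first (dev.api.timeout),
# then fall back to api.timeout; converters are applied on read.
assert resolver["api.timeout"] == 30

# Inject values into a callable via a section hint:
@ConfigResolver.map_section("api")
def client(key: str, timeout: int) -> str:
    return f"{key}:{timeout}"

assert resolver(client) == "s3cr3t:30"
```
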
/src/cdf/core/injector/registry.py:
--------------------------------------------------------------------------------
1 | """Dependency registry with lifecycle management."""
2 |
3 | import enum
4 | import inspect
5 | import logging
6 | import os
7 | import sys
8 | import types
9 | import typing as t
10 | from collections import ChainMap
11 | from functools import partial, partialmethod, wraps
12 |
13 | import pydantic
14 | import pydantic_core
15 | from typing_extensions import ParamSpec, Self
16 |
17 | import cdf.core.configuration as conf
18 | from cdf.core.context import get_default_callable_lifecycle
19 | from cdf.core.injector.errors import DependencyCycleError, DependencyMutationError
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | T = t.TypeVar("T")
24 | P = ParamSpec("P")
25 |
26 | __all__ = [
27 | "DependencyRegistry",
28 | "Dependency",
29 | "Lifecycle",
30 | "DependencyKey",
31 | "GLOBAL_REGISTRY",
32 | ]
33 |
34 |
35 | class Lifecycle(enum.Enum):
36 | """Lifecycle of a dependency."""
37 |
38 | PROTOTYPE = enum.auto()
39 | """A prototype dependency is created every time it is requested"""
40 |
41 | SINGLETON = enum.auto()
42 | """A singleton dependency is created once and shared."""
43 |
44 | INSTANCE = enum.auto()
45 | """An instance dependency is a global object which is not created by the container."""
46 |
47 | @property
48 | def is_prototype(self) -> bool:
49 | """Check if the lifecycle is prototype."""
50 | return self == Lifecycle.PROTOTYPE
51 |
52 | @property
53 | def is_singleton(self) -> bool:
54 | """Check if the lifecycle is singleton."""
55 | return self == Lifecycle.SINGLETON
56 |
57 | @property
58 | def is_instance(self) -> bool:
59 | """Check if the lifecycle is instance."""
60 | return self == Lifecycle.INSTANCE
61 |
62 | @property
63 | def is_deferred(self) -> bool:
64 | """Check if the object to be created is deferred."""
65 | return self.is_prototype or self.is_singleton
66 |
67 | def __str__(self) -> str:
68 | return self.name.lower()
69 |
70 | @classmethod
71 | def default_for(cls, obj: t.Any) -> "Lifecycle":
72 | """Get the default lifecycle."""
73 | if callable(obj):
74 | return get_default_callable_lifecycle()
75 | return cls.INSTANCE
76 |
77 |
78 | class TypedKey(t.NamedTuple):
79 | """A key which is a tuple of a name and a type."""
80 |
81 | name: str
82 | type_: t.Type[t.Any]
83 |
84 | @property
85 | def type_name(self) -> t.Optional[str]:
86 | """Get the name of the type if applicable."""
87 | return self.type_.__name__
88 |
89 | def __str__(self) -> str:
90 | return f"{self.name}: {self.type_name}"
91 |
92 | def __repr__(self) -> str:
93 |         return f"<{self.name}: {self.type_name}>"
94 |
95 | def __eq__(self, other: t.Any) -> bool:
96 | """Two keys are equal if their names and base types match."""
97 | if not isinstance(other, (TypedKey, tuple)):
98 | return False
99 | return self.name == other[0] and _same_type(self.type_, other[1])
100 |
101 | def __hash__(self) -> int:
102 | """Hash the key with the effective type if possible."""
103 | try:
104 | return hash((self.name, _unwrap_type(self.type_)))
105 | except TypeError as e:
106 | logger.warning(f"Failed to hash key {self!r}: {e}")
107 | return hash((self.name, self.type_))
108 |
109 |
110 | DependencyKey = t.Union[str, t.Tuple[str, t.Type[t.Any]], TypedKey]
111 | """A string or a typed key."""
112 |
113 |
114 | def _unwrap_optional(hint: t.Type) -> t.Type:
115 | """Unwrap Optional type hint. Also unwraps types.UnionType like str | None
116 |
117 | Args:
118 | hint: The type hint.
119 |
120 | Returns:
121 | The unwrapped type hint.
122 | """
123 | args = t.get_args(hint)
124 | if len(args) != 2 or args[1] is not type(None):
125 | return hint
126 | return args[0]
127 |
128 |
129 | def _is_union(hint: t.Type) -> bool:
130 | """Check if a type hint is a Union.
131 |
132 | Args:
133 | hint: The type hint.
134 |
135 | Returns:
136 | True if the type hint is a Union.
137 | """
138 |     return t.get_origin(hint) is t.Union or (sys.version_info >= (3, 10) and isinstance(hint, types.UnionType))
139 |
140 |
141 | def _is_ambiguous_type(hint: t.Optional[t.Type]) -> bool:
142 | """Check if a type hint or Signature annotation is ambiguous.
143 |
144 | Args:
145 | hint: The type hint.
146 |
147 | Returns:
148 | True if the type hint is ambiguous.
149 | """
150 | return hint in (
151 | object,
152 | t.Any,
153 | None,
154 | type(None),
155 | t.NoReturn,
156 | inspect.Parameter.empty,
157 | type(lambda: None),
158 | )
159 |
160 |
161 | def _unwrap_type(hint: t.Type) -> t.Type:
162 | """Unwrap a type hint.
163 |
164 | For a Union, this is the base type if all types are the same base type.
165 | Otherwise, it is the hint itself with Optional unwrapped.
166 |
167 | Args:
168 | hint: The type hint.
169 |
170 | Returns:
171 | The unwrapped type hint.
172 | """
173 | hint = _unwrap_optional(hint)
174 | if _is_union(hint):
175 | args = list(map(_unwrap_optional, t.get_args(hint)))
176 | if not args:
177 | return hint
178 | f_base = getattr(args[0], "__base__", None)
179 | if f_base and all(f_base is getattr(arg, "__base__", None) for arg in args[1:]):
180 | # Ex. Union[HarnessFFProvider, SplitFFProvider, LaunchDarklyFFProvider]
181 | # == BaseFFProvider
182 | return f_base
183 | return hint
184 |
185 |
186 | def _same_type(hint1: t.Type, hint2: t.Type) -> bool:
187 | """Check if two type hints are of the same unwrapped type.
188 |
189 | Args:
190 | hint1: The first type hint.
191 | hint2: The second type hint.
192 |
193 | Returns:
194 | True if the unwrapped types are the same.
195 | """
196 | return _unwrap_type(hint1) is _unwrap_type(hint2)
197 |
198 |
199 | @t.overload
200 | def _normalize_key(key: str) -> str: ...
201 |
202 |
203 | @t.overload
204 | def _normalize_key(key: t.Union[t.Tuple[str, t.Any], TypedKey]) -> TypedKey: ...
205 |
206 |
207 | def _normalize_key(
208 | key: t.Union[str, t.Tuple[str, t.Type[t.Any]], TypedKey],
209 | ) -> t.Union[str, TypedKey]:
210 | """Normalize a key 2-tuple to a TypedKey if it is not already, preserve str.
211 |
212 | Args:
213 | key: The key to normalize.
214 |
215 | Returns:
216 | The normalized key.
217 | """
218 | if isinstance(key, str):
219 | return key
220 | k, t_ = key
221 | return TypedKey(k, _unwrap_type(t_))
222 |
223 |
224 | def _safe_get_type_hints(obj: t.Any) -> t.Dict[str, t.Type]:
225 | """Get type hints for an object, ignoring errors.
226 |
227 | Args:
228 | obj: The object to get type hints for.
229 |
230 | Returns:
231 | A dictionary of attribute names to type hints.
232 | """
233 | try:
234 | if isinstance(obj, partial):
235 | obj = obj.func
236 | return t.get_type_hints(obj)
237 | except Exception as e:
238 | logger.debug(f"Failed to get type hints for {obj!r}: {e}")
239 | return {}
240 |
241 |
242 | class Dependency(pydantic.BaseModel, t.Generic[T]):
243 | """A Monadic type which wraps a value with lifecycle and allows simple transformations."""
244 |
245 | factory: t.Callable[..., T]
246 | """The factory or instance of the dependency."""
247 | lifecycle: Lifecycle = Lifecycle.SINGLETON
248 | """The lifecycle of the dependency."""
249 |
250 | conf_spec: t.Optional[t.Union[t.Tuple[str, ...], t.Dict[str, str]]] = None
251 | """A hint for configuration values."""
252 | alias: t.Optional[str] = None
253 | """Used as an alternative to inferring the name from the factory."""
254 |
255 | _instance: t.Optional[T] = None
256 | """The instance of the dependency once resolved."""
257 | _is_resolved: bool = False
258 | """Flag to indicate if the dependency has been unwrapped."""
259 |
260 | @pydantic.model_validator(mode="after")
261 | def _apply_spec(self) -> Self:
262 | """Apply the configuration spec to the dependency."""
263 | spec = self.conf_spec
264 | if isinstance(spec, dict):
265 | self.map(conf.map_config_values(**spec))
266 | elif isinstance(spec, tuple):
267 | self.map(conf.map_config_section(*spec))
268 | return self
269 |
270 | @pydantic.model_validator(mode="before")
271 | @classmethod
272 | def _ensure_lifecycle(cls, data: t.Any) -> t.Any:
273 | """Ensure a valid lifecycle is set for the dependency."""
274 |
275 | if isinstance(data, dict):
276 | factory = data["factory"]
277 | lc = data.get("lifecycle", Lifecycle.default_for(factory))
278 | if isinstance(lc, str):
279 | lc = Lifecycle[lc.upper()]
280 | if not isinstance(lc, Lifecycle):
281 | raise ValueError(f"Invalid lifecycle {lc=}")
282 | if not (lc.is_instance or callable(factory)):
283 | raise ValueError(f"Value must be callable for {lc=}")
284 | data["lifecycle"] = lc
285 | return data
286 |
287 | @pydantic.field_validator("factory", mode="before")
288 | @classmethod
289 | def _ensure_callable(cls, factory: t.Any) -> t.Any:
290 | """Ensure the factory is callable."""
291 | if not callable(factory):
292 |
293 | def defer() -> T:
294 | return factory
295 |
296 | defer.__name__ = f"factory_{os.urandom(4).hex()}"
297 | return defer
298 | return factory
299 |
300 | @classmethod
301 | def instance(cls, instance: t.Any) -> "Dependency":
302 | """Create a dependency from an instance.
303 |
304 | Args:
305 | instance: The instance to use as the dependency.
306 |
307 | Returns:
308 | A new Dependency object with the instance lifecycle.
309 | """
310 | obj = cls(factory=instance, lifecycle=Lifecycle.INSTANCE)
311 | obj._instance = instance
312 | obj._is_resolved = True
313 | return obj
314 |
315 | @classmethod
316 | def singleton(
317 | cls, factory: t.Callable[..., T], *args: t.Any, **kwargs: t.Any
318 | ) -> "Dependency":
319 | """Create a singleton dependency.
320 |
321 | Args:
322 | factory: The factory function to create the dependency.
323 | args: Positional arguments to pass to the factory.
324 | kwargs: Keyword arguments to pass to the factory.
325 |
326 | Returns:
327 | A new Dependency object with the singleton lifecycle.
328 | """
329 | if callable(factory) and (args or kwargs):
330 | factory = partial(factory, *args, **kwargs)
331 | return cls(factory=factory, lifecycle=Lifecycle.SINGLETON)
332 |
333 | @classmethod
334 | def prototype(
335 | cls, factory: t.Callable[..., T], *args: t.Any, **kwargs: t.Any
336 | ) -> "Dependency":
337 | """Create a prototype dependency.
338 |
339 | Args:
340 | factory: The factory function to create the dependency.
341 | args: Positional arguments to pass to the factory.
342 | kwargs: Keyword arguments to pass to the factory.
343 |
344 | Returns:
345 | A new Dependency object with the prototype lifecycle.
346 | """
347 | if callable(factory) and (args or kwargs):
348 | factory = partial(factory, *args, **kwargs)
349 | return cls(factory=factory, lifecycle=Lifecycle.PROTOTYPE)
350 |
351 | @classmethod
352 | def wrap(cls, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> Self:
353 | """Wrap an object as a dependency.
354 |
355 | Assumes singleton lifecycle for callables unless a default lifecycle context is set.
356 |
357 | Args:
358 | obj: The object to wrap.
359 |
360 | Returns:
361 | A new Dependency object with the object as the factory.
362 | """
363 | if callable(obj):
364 | if args or kwargs:
365 | obj = partial(obj, *args, **kwargs)
366 | return cls(factory=obj, lifecycle=get_default_callable_lifecycle())
367 | return cls(factory=obj, lifecycle=Lifecycle.INSTANCE)
368 |
369 | def map_value(self, func: t.Callable[[T], T]) -> Self:
370 | """Apply a function to the unwrapped value.
371 |
372 | Args:
373 | func: The function to apply to the unwrapped value.
374 |
375 | Returns:
376 | A new Dependency object with the function applied.
377 | """
378 | if self._is_resolved:
379 | self._instance = func(self._instance) # type: ignore
380 | return self
381 |
382 | factory = self.factory
383 |
384 | @wraps(factory)
385 | def wrapper() -> T:
386 | return func(factory())
387 |
388 | self.factory = wrapper
389 | return self
390 |
391 | def map(
392 | self,
393 | *funcs: t.Callable[[t.Callable[..., T]], t.Callable[..., T]],
394 | idempotent: bool = False,
395 | ) -> Self:
396 | """Apply a sequence of transformations to the wrapped value.
397 |
398 | The transformations are applied in order. This is a no-op if the dependency is
399 | already resolved and idempotent is True or the dependency is an instance.
400 |
401 | Args:
402 | funcs: The functions to apply to the wrapped value.
403 | idempotent: If True, allow transformations on resolved dependencies to be a no-op.
404 |
405 | Returns:
406 | The Dependency object with the transformations applied.
407 | """
408 | if self._is_resolved:
409 | if self.lifecycle.is_instance or idempotent:
410 | return self
411 | raise DependencyMutationError(
412 | f"Dependency {self!r} is already resolved, cannot apply transformations to factory"
413 | )
414 | factory = self.factory
415 | for func in funcs:
416 | factory = func(factory)
417 | self.factory = factory
418 | return self
419 |
420 | def unwrap(self) -> T:
421 | """Unwrap the value from the factory."""
422 | if self.lifecycle.is_prototype:
423 | return self.factory()
424 | if self._instance is not None:
425 | return self._instance
426 | self._instance = self.factory()
427 | if self.lifecycle.is_singleton:
428 | self._is_resolved = True
429 | return self._instance
430 |
431 | def __str__(self) -> str:
432 | return f"{self.factory} ({self.lifecycle})"
433 |
434 | def __repr__(self) -> str:
435 |         return f"<Dependency {self.factory} ({self.lifecycle})>"
436 |
437 | def __call__(self) -> T:
438 | """Alias for unwrap."""
439 | return self.unwrap()
440 |
441 | def try_infer_type(self) -> t.Optional[t.Type[T]]:
442 | """Get the effective type of the dependency."""
443 | if inspect.isclass(self.factory):
444 | return _unwrap_type(self.factory)
445 | if inspect.isfunction(self.factory):
446 | if hint := _safe_get_type_hints(inspect.unwrap(self.factory)).get("return"):
447 | return _unwrap_type(hint)
448 | if self._is_resolved:
449 | return _unwrap_type(type(self._instance))
450 |
451 | def try_infer_name(self) -> t.Optional[str]:
452 | """Infer the name of the dependency from the factory."""
453 | if self.alias:
454 | return self.alias
455 | if isinstance(self.factory, partial):
456 | f = inspect.unwrap(self.factory.func)
457 | else:
458 | f = inspect.unwrap(self.factory)
459 | if inspect.isfunction(f):
460 | return f.__name__
461 | if inspect.isclass(f):
462 | return f.__name__
463 | return getattr(f, "name", None)
464 |
465 | def generate_key(
466 | self, name: t.Optional[DependencyKey] = None
467 | ) -> t.Union[str, TypedKey]:
468 | """Generate a typed key for the dependency.
469 |
470 | Args:
471 | name: The name of the dependency.
472 |
473 | Returns:
474 | A typed key if the type can be inferred, else the name.
475 | """
476 | if not name:
477 | name = self.try_infer_name()
478 | if not name:
479 | raise ValueError(
480 | "Cannot infer name for dependency and no name or alias provided"
481 | )
482 | if isinstance(name, TypedKey):
483 | return name
484 | elif isinstance(name, tuple):
485 | return TypedKey(name[0], name[1])
486 | hint = self.try_infer_type()
487 | return TypedKey(name, hint) if hint and not _is_ambiguous_type(hint) else name
488 |
489 |
490 | class DependencyRegistry(t.MutableMapping[DependencyKey, Dependency]):
491 | """A registry for dependencies with lifecycle management.
492 |
493 | Dependencies can be registered with a name or a typed key. Typed keys are tuples
494 | of a name and a type hint. Dependencies can be added with a lifecycle, which can be
495 | one of prototype, singleton, or instance. Dependencies can be retrieved by name or
496 | typed key. Dependencies can be injected into functions or classes. Dependencies can
497 | be wired into callables to resolve a dependency graph.
498 | """
499 |
500 | lifecycle = Lifecycle
501 |
502 | def __init__(self, strict: bool = False) -> None:
503 | """Initialize the registry.
504 |
505 | Args:
506 | strict: If True, do not inject an untyped lookup for a typed dependency.
507 | """
508 | self.strict = strict
509 | self._typed_dependencies: t.Dict[TypedKey, Dependency] = {}
510 | self._untyped_dependencies: t.Dict[str, Dependency] = {}
511 | self._resolving: t.Set[t.Union[str, TypedKey]] = set()
512 |
513 | @property
514 | def dependencies(self) -> ChainMap[t.Any, Dependency]:
515 | """Get all dependencies."""
516 | return ChainMap(self._typed_dependencies, self._untyped_dependencies)
517 |
518 | def add(
519 | self,
520 | key: DependencyKey,
521 | value: t.Any,
522 | lifecycle: t.Optional[Lifecycle] = None,
523 | override: bool = False,
524 | init_args: t.Tuple[t.Any, ...] = (),
525 | init_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
526 | ) -> None:
527 | """Register a dependency with the container.
528 |
529 | Args:
530 | key: The name of the dependency.
531 | value: The factory or instance of the dependency.
532 | lifecycle: The lifecycle of the dependency.
533 | override: If True, override an existing dependency.
534 | init_args: Arguments to initialize the factory with.
535 | init_kwargs: Keyword arguments to initialize the factory with.
536 | """
537 |
538 | # Assume singleton lifecycle if the value is callable unless set in context
539 | if lifecycle is None:
540 | lifecycle = Lifecycle.default_for(value)
541 |
542 | # If the value is callable and has initialization args, bind them early so
543 | # we don't need to schlepp them around
544 | if callable(value) and (init_args or init_kwargs):
545 | value = partial(value, *init_args, **(init_kwargs or {}))
546 |
547 | # Register the dependency
548 | dependency = Dependency(factory=value, lifecycle=lifecycle)
549 | dependency_key = dependency.generate_key(key)
550 | if self.has(dependency_key) and not override:
551 | raise ValueError(f'Dependency "{dependency_key}" is already registered')
552 | if isinstance(dependency_key, TypedKey):
553 | self._typed_dependencies[dependency_key] = dependency
554 | # Allow untyped access to typed dependencies for convenience if not strict
555 | # or if the hint is not a distinct type
556 | if not self.strict or _is_ambiguous_type(dependency_key.type_):
557 | self._untyped_dependencies[dependency_key.name] = dependency
558 | else:
559 | self._untyped_dependencies[dependency_key] = dependency
560 |
561 | add_prototype = partialmethod(add, lifecycle=Lifecycle.PROTOTYPE)
562 | add_singleton = partialmethod(add, lifecycle=Lifecycle.SINGLETON)
563 | add_instance = partialmethod(add, lifecycle=Lifecycle.INSTANCE)
564 |
565 | def add_from_dependency(
566 | self,
567 | dependency: Dependency,
568 | key: t.Optional[DependencyKey] = None,
569 | override: bool = False,
570 | ) -> None:
571 | """Add a Dependency object to the container.
572 |
573 | Args:
574 | dependency: The dependency object to register.
575 | key: The name or typed key of the dependency.
576 | override: If True, override an existing dependency.
577 | """
578 | dependency_key = dependency.generate_key(key)
579 | if self.has(dependency_key) and not override:
580 | raise ValueError(
581 | f'Dependency "{dependency_key}" is already registered, use a different name to avoid conflicts'
582 | )
583 | if isinstance(dependency_key, TypedKey):
584 | self._typed_dependencies[dependency_key] = dependency
585 | if not self.strict or _is_ambiguous_type(dependency_key.type_):
586 | self._untyped_dependencies[dependency_key.name] = dependency
587 | else:
588 | self._untyped_dependencies[dependency_key] = dependency
589 |
590 | def remove(self, name_or_key: DependencyKey) -> None:
591 | """Remove a dependency by name or key from the container.
592 |
593 | Args:
594 | name_or_key: The name or typed key of the dependency.
595 | """
596 | key = _normalize_key(name_or_key)
597 | if isinstance(key, str):
598 | if key in self._untyped_dependencies:
599 | del self._untyped_dependencies[key]
600 | else:
601 | raise KeyError(f'Dependency "{key}" is not registered')
602 | elif key in self._typed_dependencies:
603 | del self._typed_dependencies[key]
604 | else:
605 | raise KeyError(f'Dependency "{key}" is not registered')
606 |
607 | def clear(self) -> None:
608 | """Clear all dependencies and singletons."""
609 | self._typed_dependencies.clear()
610 | self._untyped_dependencies.clear()
611 |
612 | def has(self, name_or_key: DependencyKey) -> bool:
613 | """Check if a dependency is registered.
614 |
615 | Args:
616 | name_or_key: The name or typed key of the dependency.
617 | """
618 | return name_or_key in self.dependencies
619 |
620 | def resolve(self, name_or_key: DependencyKey, must_exist: bool = False) -> t.Any:
621 | """Get a dependency.
622 |
623 | Args:
624 | name_or_key: The name or typed key of the dependency.
625 | must_exist: If True, raise KeyError if the dependency is not found.
626 |
627 | Returns:
628 | The dependency if found, else None.
629 | """
630 | key = _normalize_key(name_or_key)
631 |
632 | # Resolve the dependency
633 | if isinstance(key, str):
634 | if key not in self._untyped_dependencies:
635 | if must_exist:
636 | raise KeyError(f'Dependency "{key}" is not registered')
637 | return
638 | dep = self.dependencies[key]
639 | else:
640 | if _is_union(key.type_):
641 | types = map(_unwrap_type, t.get_args(key.type_))
642 | else:
643 | types = [key.type_]
644 | for type_ in types:
645 | key = TypedKey(key.name, type_)
646 | if self.has(key):
647 | break
648 | else:
649 | if must_exist:
650 | raise KeyError(f'Dependency "{key}" is not registered')
651 | return
652 | dep = self.dependencies[key]
653 |
654 | # Detect dependency cycles
655 | if key in self._resolving:
656 | raise DependencyCycleError(
657 | f"Dependency cycle detected while resolving {key} for {dep.factory!r}"
658 | )
659 |
660 | # Handle the lifecycle of the dependency, recursively resolving dependencies
661 | self._resolving.add(key)
662 | try:
663 | return dep.map(self.wire).unwrap()
664 | except DependencyMutationError:
665 | return dep.unwrap()
666 | finally:
667 | self._resolving.remove(key)
668 |
669 | resolve_or_raise = partialmethod(resolve, must_exist=True)
670 |
671 | def __contains__(self, name: t.Any) -> bool:
672 | """Check if a dependency is registered."""
673 | return self.has(name)
674 |
675 | def __getitem__(self, name: DependencyKey) -> t.Any:
676 | """Get a dependency. Raises KeyError if not found."""
677 | return self.resolve(name, must_exist=True)
678 |
679 | def __setitem__(self, name: DependencyKey, value: t.Any) -> None:
680 | """Add a dependency. Defaults to singleton lifecycle if callable, else instance."""
681 | self.add(name, value, override=True)
682 |
683 | def __delitem__(self, name: DependencyKey) -> None:
684 | """Remove a dependency."""
685 | self.remove(name)
686 |
687 | def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
688 | """Inject dependencies into a function.
689 |
690 | Args:
691 | func_or_cls: The function or class to inject dependencies into.
692 |
693 | Returns:
694 | A function that can be called with dependencies injected
695 | """
696 | if not callable(func_or_cls):
697 | return func_or_cls
698 |
699 | sig = inspect.signature(func_or_cls)
700 | is_resolved_sentinel = "__deps_resolved__"
701 |
702 | if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)):
703 | return func_or_cls
704 |
705 | @wraps(func_or_cls)
706 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
707 | bound_args = sig.bind_partial(*args, **kwargs)
708 | for name, param in sig.parameters.items():
709 | if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
710 | continue
711 | if param.default not in (param.empty, None):
712 | continue
713 | if name not in bound_args.arguments:
714 | dep = None
715 | # Try to resolve a typed dependency
716 | if not _is_ambiguous_type(param.annotation):
717 | dep = self.resolve((name, param.annotation))
718 | # Fallback to untyped injection
719 | if dep is None:
720 | dep = self.resolve(name)
721 | # If a dependency is found, inject it
722 | if dep is not None:
723 | bound_args.arguments[name] = dep
724 | return func_or_cls(*bound_args.args, **bound_args.kwargs)
725 |
726 | setattr(wrapper, is_resolved_sentinel, True)
727 | return wrapper
728 |
729 | def __call__(
730 | self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any
731 | ) -> T:
732 | """Invoke a callable with dependencies injected from the registry.
733 |
734 | Args:
735 | func_or_cls: The function or class to invoke.
736 | args: Positional arguments to pass to the callable.
737 | kwargs: Keyword arguments to pass to the callable.
738 |
739 | Returns:
740 | The result of the callable
741 | """
742 | wired_f = self.wire(func_or_cls)
743 | if not callable(wired_f):
744 | return wired_f
745 | return wired_f(*args, **kwargs)
746 |
747 | def __iter__(self) -> t.Iterator[DependencyKey]:
748 | """Iterate over dependency names."""
749 | return iter(self.dependencies)
750 |
751 | def __len__(self) -> int:
752 | """Return the number of dependencies."""
753 | return len(self.dependencies)
754 |
755 | def __repr__(self) -> str:
756 | return f"DependencyRegistry(<{list(self.dependencies.keys())}>)"
757 |
758 | def __str__(self) -> str:
759 | return repr(self)
760 |
761 | def __bool__(self) -> bool:
762 | """True if the registry has dependencies."""
763 | return bool(self.dependencies)
764 |
765 | def __or__(self, other: "DependencyRegistry") -> "DependencyRegistry":
766 | """Merge two registries like pythons dict union overload."""
767 | self._untyped_dependencies = {
768 | **self._untyped_dependencies,
769 | **other._untyped_dependencies,
770 | }
771 | self._typed_dependencies = {
772 | **self._typed_dependencies,
773 | **other._typed_dependencies,
774 | }
775 | return self
776 |
777 | def __getstate__(self) -> t.Dict[str, t.Any]:
778 | """Serialize the state, including the strict flag."""
779 | return {
780 | "strict": self.strict,
781 | "_typed_dependencies": self._typed_dependencies,
782 | "_untyped_dependencies": self._untyped_dependencies,
783 | "_resolving": self._resolving}
784 |
785 | def __setstate__(self, state: t.Dict[str, t.Any]) -> None:
786 | """Deserialize the state, restoring the strict flag."""
787 | self.strict, self._resolving = state.get("strict", False), state["_resolving"]
788 | self._typed_dependencies = state["_typed_dependencies"]
789 | self._untyped_dependencies = state["_untyped_dependencies"]
790 |
791 | @classmethod
792 | def __get_pydantic_core_schema__(
793 | cls, source_type: t.Any, handler: pydantic.GetCoreSchemaHandler
794 | ) -> pydantic_core.CoreSchema:
795 | return pydantic_core.core_schema.dict_schema(
796 | keys_schema=pydantic_core.core_schema.union_schema(
797 | [
798 | pydantic_core.core_schema.str_schema(),
799 | pydantic_core.core_schema.tuple_schema(
800 | [
801 | pydantic_core.core_schema.str_schema(),
802 | pydantic_core.core_schema.any_schema(),
803 | ]
804 | ),
805 | pydantic_core.core_schema.is_instance_schema(TypedKey),
806 | ]
807 | ),
808 | values_schema=pydantic_core.core_schema.any_schema(),
809 | )
810 |
811 |
812 | def _iter_wrapped(f: t.Callable):
813 | yield f
814 | f_w = inspect.unwrap(f)
815 | if f_w is not f:
816 | yield from _iter_wrapped(f_w)
817 |
818 |
819 | GLOBAL_REGISTRY = DependencyRegistry()
820 | """A global dependency registry."""
821 |
--------------------------------------------------------------------------------
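
A minimal usage sketch of the registry above (illustrative only, not part of the repository source; `Config` and `make_client` are hypothetical stand-ins, and the import assumes `DependencyRegistry` is importable from `cdf.core.injector.registry`):

    from cdf.core.injector.registry import DependencyRegistry  # assumed import path


    class Config:  # hypothetical dependency
        url = "https://example.com"


    def make_client(config: Config) -> dict:
        # `config` has no default, so `wire` resolves and injects it by name/type
        return {"url": config.url}


    registry = DependencyRegistry()
    registry.add("config", Config)                 # callable -> singleton by default
    registry.add_prototype("client", make_client)  # re-built on every resolve

    client = registry.resolve("client")            # factory is wired, then invoked
    assert client == {"url": "https://example.com"}

Calling the registry itself, e.g. `registry(make_client)`, wires and invokes a callable ad hoc without registering it first.
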
/src/cdf/types/monads.py:
--------------------------------------------------------------------------------
1 | """Contains monadic types and functions for working with them."""
2 |
3 | from __future__ import annotations
4 |
5 | import abc
6 | import asyncio
7 | import functools
8 | import inspect
9 | import sys
10 | import typing as t
11 |
12 | from typing_extensions import Self
13 |
14 | if sys.version_info < (3, 10):
15 | from typing_extensions import ParamSpec
16 | else:
17 | from typing import ParamSpec
18 |
19 | T = t.TypeVar("T") # The type of the value inside the Monad
20 | U = t.TypeVar("U") # The transformed type of the value inside the Monad
21 | K = t.TypeVar("K") # A known type that is not necessarily the same as T
22 | L = t.TypeVar("L") # A known type that is not necessarily the same as U
23 | E = t.TypeVar(
24 | "E", bound=BaseException, covariant=True
25 | ) # The type of the error inside the Result
26 | P = ParamSpec("P")
27 |
28 | TState = t.TypeVar("TState") # The type of the state
29 | TMonad = t.TypeVar("TMonad", bound="Monad") # Generic Self type for Monad
30 |
31 |
32 | class Monad(t.Generic[T], abc.ABC):
33 | def __init__(self, value: T) -> None:
34 | self._value = value
35 |
36 | def __hash__(self) -> int:
37 | return hash(self._value)
38 |
39 | @abc.abstractmethod
40 | def bind(self, func: t.Callable[[T], "Monad[U]"]) -> "Monad[U]":
41 | pass
42 |
43 | @abc.abstractmethod
44 | def map(self, func: t.Callable[[T], U]) -> "Monad[U]":
45 | pass
46 |
47 | @abc.abstractmethod
48 | def filter(self, predicate: t.Callable[[T], bool]) -> Self:
49 | pass
50 |
51 | @abc.abstractmethod
52 | def unwrap(self) -> T:
53 | pass
54 |
55 | @abc.abstractmethod
56 | def unwrap_or(self, default: U) -> t.Union[T, U]:
57 | pass
58 |
59 | def __call__(self, func: t.Callable[[T], "Monad[U]"]) -> "Monad[U]":
60 | return self.bind(func)
61 |
62 | def __rshift__(self, func: t.Callable[[T], "Monad[U]"]) -> "Monad[U]":
63 | return self.bind(func)
64 |
65 |
66 | class Maybe(Monad[T], abc.ABC):
67 | @classmethod
68 | def pure(cls, value: K) -> "Maybe[K]":
69 | """Creates a Maybe with a value."""
70 | return Just(value)
71 |
72 | @abc.abstractmethod
73 | def is_just(self) -> bool:
74 | pass
75 |
76 | @abc.abstractmethod
77 | def is_nothing(self) -> bool:
78 | pass
79 |
80 | if t.TYPE_CHECKING:
81 |
82 | def bind(self, func: t.Callable[[T], "Maybe[U]"]) -> "Maybe[U]": ...
83 |
84 | def map(self, func: t.Callable[[T], U]) -> "Maybe[U]": ...
85 |
86 | def filter(self, predicate: t.Callable[[T], bool]) -> "Maybe[T]": ...
87 |
88 | def unwrap(self) -> T:
89 | """Unwraps the value of the Maybe.
90 |
91 | Returns:
92 | The unwrapped value.
93 | """
94 | if self.is_just():
95 | return self._value
96 | else:
97 | raise ValueError("Cannot unwrap Nothing.")
98 |
99 | def unwrap_or(self, default: U) -> t.Union[T, U]:
100 | """Tries to unwrap the Maybe, returning a default value if the Maybe is Nothing.
101 |
102 | Args:
103 | default: The value to return if unwrapping Nothing.
104 |
105 | Returns:
106 | The unwrapped value or the default value.
107 | """
108 | if self.is_just():
109 | return self._value
110 | else:
111 | return default
112 |
113 | @classmethod
114 | def lift(cls, func: t.Callable[[U], K]) -> t.Callable[["U | Maybe[U]"], "Maybe[K]"]:
115 | """Lifts a function to work within the Maybe monad.
116 |
117 | Args:
118 | func: A function to lift.
119 |
120 | Returns:
121 | A new function that returns a Maybe value.
122 | """
123 |
124 | @functools.wraps(func)
125 | def wrapper(value: U | Maybe[U]) -> Maybe[K]:
126 | if isinstance(value, Maybe):
127 | return value.map(func) # type: ignore
128 | value = t.cast(U, value)
129 | try:
130 | result = func(value)
131 | if result is None:
132 | return Nothing()
133 | return Just(result)
134 | except Exception:
135 | return Nothing()
136 |
137 | return wrapper
138 |
139 | def __iter__(self) -> t.Iterator[T]:
140 | """Allows safely unwrapping the value of the Maybe using a for construct."""
141 | if self.is_just():
142 | yield self.unwrap()
143 |
144 |
145 | class Just(Maybe[T]):
146 | def bind(self, func: t.Callable[[T], Maybe[U]]) -> Maybe[U]:
147 | """Applies a function to the value inside the Just.
148 |
149 | Args:
150 | func: A function that takes a value of type T and returns a Maybe containing a value of type U.
151 |
152 | Returns:
153 | The result of applying the function to the value inside the Just.
154 | """
155 | return func(self._value)
156 |
157 | def map(self, func: t.Callable[[T], U]) -> "Maybe[U]":
158 | """Applies a mapping function to the value inside the Just.
159 |
160 | Args:
161 | func: A function that takes a value of type T and returns a value of type U.
162 |
163 | Returns:
164 | A Just containing the result, or Nothing if the function returns None or raises.
165 | """
166 | try:
167 | result = func(self._value)
168 | if result is None:
169 | return Nothing()
170 | return Just(result)
171 | except Exception:
172 | return Nothing()
173 |
174 | def filter(self, predicate: t.Callable[[T], bool]) -> Maybe[T]:
175 | """Filters the value inside the Just based on a predicate.
176 |
177 | Args:
178 | predicate: A function that takes a value of type T and returns a boolean.
179 |
180 | Returns:
181 | A new Just containing the value inside the Just if the predicate holds.
182 | """
183 | if predicate(self._value):
184 | return self
185 | else:
186 | return Nothing()
187 |
188 | def is_just(self) -> bool:
189 | """Returns True if the Maybe is a Just."""
190 | return True
191 |
192 | def is_nothing(self) -> bool:
193 | """Returns False if the Maybe is a Just."""
194 | return False
195 |
196 | def __repr__(self) -> str:
197 | return f"Just({self._value})"
198 |
199 |
200 | class Nothing(Maybe[T]):
201 | def __init__(self) -> None:
202 | super().__init__(t.cast(T, None))
203 |
204 | def bind(self, func: t.Callable[[T], Maybe[U]]) -> "Nothing[T]":
205 | """Ignores the function since there is no value to bind.
206 | 
207 | Args:
208 | func: A function that takes a value of type T and returns a Maybe containing a value of type U.
209 | 
210 | Returns:
211 | This Nothing, unchanged.
212 | """
213 | return self
214 | 
215 | def map(self, func: t.Callable[[T], U]) -> "Nothing[T]":
216 | """Ignores the mapping function since there is no value to map.
217 | 
218 | Args:
219 | func: A function that takes a value of type T and returns a value of type U.
220 | 
221 | Returns:
222 | This Nothing, unchanged.
223 | """
224 | return self
225 | 
226 | def filter(self, predicate: t.Callable[[T], bool]) -> "Nothing[T]":
227 | """Ignores the predicate since there is no value to filter.
228 | 
229 | Args:
230 | predicate: A function that takes a value of type T and returns a boolean.
231 | 
232 | Returns:
233 | This Nothing, unchanged.
234 | """
235 | return self
236 |
237 | def is_just(self) -> bool:
238 | """Returns False if the Maybe is a Nothing."""
239 | return False
240 |
241 | def is_nothing(self) -> bool:
242 | """Returns True if the Maybe is a Nothing."""
243 | return True
244 |
245 | def __repr__(self) -> str:
246 | return "Nothing()"
247 |
248 |
249 | class Result(Monad[T], t.Generic[T, E]):
250 | @classmethod
251 | def pure(cls, value: K) -> "Result[K, E]":
252 | """Creates an Ok with a value."""
253 | return Ok(value)
254 |
255 | @abc.abstractmethod
256 | def is_ok(self) -> bool:
257 | pass
258 |
259 | @abc.abstractmethod
260 | def is_err(self) -> bool:
261 | pass
262 |
263 | @abc.abstractmethod
264 | def unwrap(self) -> T:
265 | pass
266 |
267 | @abc.abstractmethod
268 | def unwrap_or(self, default: U) -> t.Union[T, U]:
269 | pass
270 |
271 | @abc.abstractmethod
272 | def unwrap_err(self) -> BaseException:
273 | pass
274 |
275 | @abc.abstractmethod
276 | def to_parts(self) -> t.Tuple[T, E | None]:
277 | pass
278 |
279 | @classmethod
280 | def lift(
281 | cls, func: t.Callable[[U], K]
282 | ) -> t.Callable[["U | Result[U, Exception]"], "Result[K, Exception]"]:
283 | """Transforms a function to work with arguments and output wrapped in Result monads.
284 |
285 | Args:
286 | func: A function that takes any number of arguments and returns a value of type T.
287 |
288 | Returns:
289 | A function that takes the same number of unwrapped arguments and returns a Result-wrapped result.
290 | """
291 |
292 | def wrapper(result: U | Result[U, Exception]) -> Result[K, Exception]:
293 | if isinstance(result, Result):
294 | return result.map(func)
295 | result = t.cast(U, result)
296 | try:
297 | return Ok(func(result))
298 | except Exception as e:
299 | return Err(e)
300 |
301 | if hasattr(func, "__defaults__") and func.__defaults__:
302 | default = func.__defaults__[0]
303 | wrapper.__defaults__ = (default,)
304 |
305 | return wrapper
306 |
307 | if t.TYPE_CHECKING:
308 |
309 | def bind(self, func: t.Callable[[T], "Result[U, E]"]) -> "Result[U, E]": ...
310 |
311 | def map(self, func: t.Callable[[T], U]) -> "Result[U, E]": ...
312 |
313 | def filter(self, predicate: t.Callable[[T], bool]) -> "Result[T, E]": ...
314 |
315 | def __call__(self, func: t.Callable[[T], "Result[U, E]"]) -> "Result[U, E]": ...
316 |
317 | def __rshift__(
318 | self, func: t.Callable[[T], "Result[U, E]"]
319 | ) -> "Result[U, E]": ...
320 |
321 | def __iter__(self) -> t.Iterator[T]:
322 | """Allows safely unwrapping the value of the Result using a for construct."""
323 | if self.is_ok():
324 | yield self.unwrap()
325 |
326 |
327 | class Ok(Result[T, E]):
328 | def bind(self, func: t.Callable[[T], Result[U, E]]) -> Result[U, E]:
329 | """Applies a function to the result of the Ok.
330 |
331 | Args:
332 | func: A function that takes a value of type T and returns a Result containing a value of type U.
333 |
334 | Returns:
335 | A new Result containing the result of the original Result after applying the function.
336 | """
337 | return func(self._value)
338 |
339 | def map(self, func: t.Callable[[T], U]) -> Result[U, E]:
340 | """Applies a mapping function to the result of the Ok.
341 |
342 | Args:
343 | func: A function that takes a value of type T and returns a value of type U.
344 |
345 | Returns:
346 | A new Ok containing the result of the original Ok after applying the function.
347 | """
348 | try:
349 | return Ok(func(self._value))
350 | except Exception as e:
351 | return Err(t.cast(E, e))
352 |
353 | def is_ok(self) -> bool:
354 | """Returns True if the Result is an Ok."""
355 | return True
356 |
357 | def is_err(self) -> bool:
358 | """Returns False if the Result is an Ok."""
359 | return False
360 |
361 | def unwrap(self) -> T:
362 | """Unwraps the value of the Ok.
363 |
364 | Returns:
365 | The unwrapped value.
366 | """
367 | return self._value
368 |
369 | def unwrap_or(self, default: t.Any) -> T:
370 | """Tries to unwrap the Ok, returning a default value if unwrapping raises an exception.
371 |
372 | Args:
373 | default: The value to return if unwrapping raises an exception.
374 |
375 | Returns:
376 | The unwrapped value or the default value if an exception is raised.
377 | """
378 | return self._value
379 |
380 | def unwrap_err(self) -> BaseException:
381 | """Raises a ValueError since the Result is an Ok."""
382 | raise ValueError("Called unwrap_err on Ok")
383 |
384 | def filter(self, predicate: t.Callable[[T], bool]) -> Result[T, E]:
385 | """Filters the result of the Ok based on a predicate.
386 |
387 | Args:
388 | predicate: A function that takes a value of type T and returns a boolean.
389 | 
390 | Returns:
391 | This Ok if the predicate holds, else an Err wrapping a ValueError
392 | indicating the predicate did not hold.
393 | """
394 | if predicate(self._value):
395 | return self
396 | else:
397 | return Err(t.cast(E, ValueError("Predicate does not hold")))
398 |
399 | def to_parts(self) -> t.Tuple[T, None]:
400 | """Unpacks the value of the Ok."""
401 | return (self._value, None)
402 |
403 | def __repr__(self) -> str:
404 | return f"Ok({self._value})"
405 |
406 |
407 | class Err(Result[T, E]):
408 | def __init__(self, error: E) -> None:
409 | """Initializes an Err with an error.
410 |
411 | Args:
412 | error: The error to wrap in the Err.
413 | """
414 | self._error = error
415 |
416 | def __hash__(self) -> int:
417 | return hash(self._error)
418 |
419 | def bind(self, func: t.Callable[[T], Result[U, E]]) -> "Err[T, E]":
420 | """Applies a function to the result of the Err.
421 |
422 | Args:
423 | func: A function that takes a value of type T and returns a Result containing a value of type U.
424 |
425 | Returns:
426 | An Err containing the original error.
427 | """
428 | return self
429 |
430 | def map(self, func: t.Callable[[T], U]) -> "Err[T, E]":
431 | """Applies a mapping function to the result of the Err.
432 |
433 | Args:
434 | func: A function that takes a value of type T and returns a value of type U.
435 |
436 | Returns:
437 | An Err containing the original error.
438 | """
439 | return self
440 |
441 | def is_ok(self) -> bool:
442 | """Returns False if the Result is an Err."""
443 | return False
444 |
445 | def is_err(self) -> bool:
446 | """Returns True if the Result is an Err."""
447 | return True
448 |
449 | def unwrap(self) -> T:
450 | """Raises a ValueError since the Result is an Err."""
451 | raise self._error
452 |
453 | def unwrap_or(self, default: U) -> U:
454 | """Returns a default value since the Result is an Err.
455 |
456 | Args:
457 | default: The value to return.
458 |
459 | Returns:
460 | The default value.
461 | """
462 | return default
463 |
464 | def unwrap_err(self) -> BaseException:
465 | """Unwraps the error of the Err.
466 |
467 | Returns:
468 | The unwrapped error.
469 | """
470 | return self._error
471 |
472 | def filter(self, predicate: t.Callable[[T], bool]) -> "Err[T, E]":
473 | """Filters the result of the Err based on a predicate.
474 |
475 | Args:
476 | predicate: A function that takes a value of type T and returns a boolean.
477 |
478 | Returns:
479 | An Err containing the original error.
480 | """
481 | return self
482 |
483 | def to_parts(self) -> t.Tuple[None, E]:
484 | """Unpacks the error of the Err."""
485 | return (None, self._error)
486 |
487 | def __repr__(self) -> str:
488 | return f"Err({self._error})"
489 |
490 |
491 | class Promise(t.Generic[T], t.Awaitable[T], Monad[T]):
492 | def __init__(
493 | self,
494 | coro_func: t.Callable[P, t.Coroutine[None, None, T]],
495 | *args: P.args,
496 | **kwargs: P.kwargs,
497 | ) -> None:
498 | """Initializes a Promise with a coroutine function.
499 |
500 | Args:
501 | coro_func: A coroutine function (or an already-created coroutine) that produces a value of type T.
502 | args: Positional arguments to pass to the coroutine function.
503 | kwargs: Keyword arguments to pass to the coroutine function.
504 | """
505 | self._loop = asyncio.get_event_loop()
506 | if callable(coro_func):
507 | coro = coro_func(*args, **kwargs)
508 | elif inspect.iscoroutine(coro_func):
509 | coro = t.cast(t.Coroutine[None, None, T], coro_func)
510 | else:
511 | raise ValueError("Invalid coroutine function")
512 | self._future: asyncio.Future[T] = asyncio.ensure_future(coro, loop=self._loop)
513 |
514 | @classmethod
515 | def pure(cls, value: K) -> "Promise[K]":
516 | """Creates a Promise that is already resolved with a value.
517 |
518 | Args:
519 | value: The value to resolve the Promise with.
520 |
521 | Returns:
522 | A new Promise that is already resolved with the value.
523 | """
524 | return cls.from_value(value) # type: ignore
525 |
526 | def __hash__(self) -> int:
527 | return hash(self._future)
528 |
529 | def __await__(self):
530 | """Allows the Promise to be awaited."""
531 | # Await the underlying future exactly once and propagate its result
532 | return (yield from self._future.__await__())
533 |
534 | def set_result(self, result: T) -> None:
535 | """Sets a result on the Promise.
536 |
537 | Args:
538 | result: The result to set on the Promise.
539 | """
540 | if not self._future.done():
541 | self._loop.call_soon_threadsafe(self._future.set_result, result)
542 |
543 | def set_exception(self, exception: Exception) -> None:
544 | """Sets an exception on the Promise.
545 |
546 | Args:
547 | exception: The exception to set on the Promise.
548 | """
549 | if not self._future.done():
550 | self._loop.call_soon_threadsafe(self._future.set_exception, exception)
551 |
552 | def bind(self, func: t.Callable[[T], "Promise[U]"]) -> "Promise[U]":
553 | """Applies a function to the result of the Promise.
554 |
555 | Args:
556 | func: A function that takes a value of type T and returns a Promise containing a value of type U.
557 |
558 | Returns:
559 | A new Promise containing the result of the original Promise after applying the function.
560 | """
561 |
562 | async def bound_coro() -> U:
563 | try:
564 | value = await self
565 | next_promise = func(value)
566 | return await next_promise
567 | except Exception as e:
568 | future = self._loop.create_future()
569 | future.set_exception(e)
570 | return t.cast(U, await future)
571 |
572 | return Promise(bound_coro)
573 |
574 | def map(self, func: t.Callable[[T], U]) -> "Promise[U]":
575 | """Applies a mapping function to the result of the Promise.
576 |
577 | Args:
578 | func: A function that takes a value of type T and returns a value of type U.
579 |
580 | Returns:
581 | A new Promise containing the result of the original Promise after applying the function.
582 | """
583 |
584 | async def mapped_coro() -> U:
585 | try:
586 | value = await self
587 | return func(value)
588 | except Exception as e:
589 | future = self._loop.create_future()
590 | future.set_exception(e)
591 | return t.cast(U, await future)
592 |
593 | return Promise(mapped_coro)
594 |
595 | then = map # syntactic sugar, equivalent to map
596 |
597 | def filter(self, predicate: t.Callable[[T], bool]) -> "Promise[T]":
598 | """Filters the result of the Promise based on a predicate.
599 |
600 | Args:
601 | predicate: A function that takes a value of type T and returns a boolean.
602 |
603 | Returns:
604 | A new Promise containing the result of the original Promise if the predicate holds.
605 | """
606 |
607 | async def filtered_coro() -> T:
608 | try:
609 | value = await self
610 | if predicate(value):
611 | return value
612 | else:
613 | raise ValueError("Filter predicate failed")
614 | except Exception as e:
615 | future = self._loop.create_future()
616 | future.set_exception(e)
617 | return await future
618 |
619 | return Promise(filtered_coro)
620 |
621 | def unwrap(self) -> T:
622 | return self._loop.run_until_complete(self)
623 |
624 | def unwrap_or(self, default: T) -> T:
625 | """Tries to unwrap the Promise, returning a default value if unwrapping raises an exception.
626 |
627 | Args:
628 | default: The value to return if unwrapping raises an exception.
629 |
630 | Returns:
631 | The unwrapped value or the default value if an exception is raised.
632 | """
633 | try:
634 | return self._loop.run_until_complete(self)
635 | except Exception:
636 | return default
637 |
638 | @classmethod
639 | def from_value(cls, value: T) -> "Promise[T]":
640 | """Creates a Promise that is already resolved with a value.
641 |
642 | Args:
643 | value: The value to resolve the Promise with.
644 |
645 | Returns:
646 | A new Promise that is already resolved with the value.
647 | """
648 |
649 | async def _fut():
650 | return value
651 |
652 | return cls(_fut)
653 |
654 | @classmethod
655 | def from_exception(cls, exception: BaseException) -> "Promise[T]":
656 | """Creates a Promise that is already resolved with an exception.
657 |
658 | Args:
659 | exception: The exception to resolve the Promise with.
660 |
661 | Returns:
662 | A new Promise that is already resolved with the exception.
663 | """
664 |
665 | async def _fut():
666 | raise exception
667 |
668 | return cls(_fut)
669 |
670 | @classmethod
671 | def lift(
672 | cls, func: t.Callable[[U], T]
673 | ) -> t.Callable[["U | Promise[U]"], "Promise[T]"]:
674 | """
675 | Lifts a synchronous function to work within the Promise context,
676 | making it return a Promise of the result and allowing it to be used
677 | with Promise inputs.
678 |
679 | Args:
680 | func: A synchronous function that returns a value of type T.
681 |
682 | Returns:
683 | A function that, when called, returns a Promise wrapping the result of the original function.
684 | """
685 |
686 | @functools.wraps(func)
687 | def wrapper(value: "U | Promise[U]") -> "Promise[T]":
688 | if isinstance(value, Promise):
689 | return value.map(func)
690 | value = t.cast(U, value)
691 |
692 | async def async_wrapper() -> T:
693 | return func(value)
694 |
695 | return cls(async_wrapper)
696 |
697 | return wrapper
698 |
699 |
700 | class Lazy(Monad[T]):
701 | def __init__(self, computation: t.Callable[[], T]) -> None:
702 | """Initializes a Lazy monad with a computation that will be executed lazily.
703 |
704 | Args:
705 | computation: A function that takes no arguments and returns a value of type T.
706 | """
707 | self._computation = computation
708 | self._value = None
709 | self._evaluated = False
710 |
711 | @classmethod
712 | def pure(cls, value: T) -> "Lazy[T]":
713 | """Creates a Lazy monad with a pure value."""
714 | return cls(lambda: value)
715 |
716 | def evaluate(self) -> T:
717 | """Evaluates the computation if it has not been evaluated yet and caches the result.
718 |
719 | Returns:
720 | The result of the computation.
721 | """
722 | if not self._evaluated:
723 | self._value = self._computation()
724 | self._evaluated = True
725 | return t.cast(T, self._value)
726 |
727 | def bind(self, func: t.Callable[[T], "Lazy[U]"]) -> "Lazy[U]":
728 | """Lazily applies a function to the result of the current computation.
729 |
730 | Args:
731 | func: A function that takes a value of type T and returns a Lazy monad containing a value of type U.
732 |
733 | Returns:
734 | A new Lazy monad containing the result of the computation after applying the function.
735 | """
736 | return Lazy(lambda: func(self.evaluate()).evaluate())
737 |
738 | def map(self, func: t.Callable[[T], U]) -> "Lazy[U]":
739 | """Lazily applies a mapping function to the result of the computation.
740 |
741 | Args:
742 | func: A function that takes a value of type T and returns a value of type U.
743 |
744 | Returns:
745 | A new Lazy monad containing the result of the computation after applying the function.
746 | """
747 | return Lazy(lambda: func(self.evaluate()))
748 |
749 | def filter(self, predicate: t.Callable[[T], bool]) -> "Lazy[T]":
750 | """Lazily filters the result of the computation based on a predicate.
751 |
752 | Args:
753 | predicate: A function that takes a value of type T and returns a boolean.
754 |
755 | Returns:
756 | A new Lazy monad containing the result of the computation if the predicate holds.
757 | """
758 |
759 | def filter_computation():
760 | result = self.evaluate()
761 | if predicate(result):
762 | return result
763 | else:
764 | raise ValueError("Predicate does not hold for the value.")
765 |
766 | return Lazy(filter_computation)
767 |
768 | def unwrap(self) -> T:
769 | """Forces evaluation of the computation and returns its result.
770 |
771 | Returns:
772 | The result of the computation.
773 | """
774 | return self.evaluate()
775 |
776 | def unwrap_or(self, default: T) -> T:
777 | """Tries to evaluate the computation, returning a default value if evaluation raises an exception.
778 |
779 | Args:
780 | default: The value to return if the computation raises an exception.
781 |
782 | Returns:
783 | The result of the computation or the default value if an exception is raised.
784 | """
785 | try:
786 | return self.evaluate()
787 | except Exception:
788 | return default
789 |
790 | @classmethod
791 | def lift(cls, func: t.Callable[[U], T]) -> t.Callable[["U | Lazy[U]"], "Lazy[T]"]:
792 | """Transforms a function to work with arguments and output wrapped in Lazy monads.
793 |
794 | Args:
795 | func: A function that takes any number of arguments and returns a value of type U.
796 |
797 | Returns:
798 | A function that takes the same number of Lazy-wrapped arguments and returns a Lazy-wrapped result.
799 | """
800 |
801 | @functools.wraps(func)
802 | def wrapper(value: "U | Lazy[U]") -> "Lazy[T]":
803 | if isinstance(value, Lazy):
804 | return value.map(func)
805 | value = t.cast(U, value)
806 |
807 | def computation() -> T:
808 | return func(value)
809 |
810 | return cls(computation)
811 |
812 | return wrapper
813 |
814 |
815 | Defer = Lazy # Defer is an alias for Lazy
816 |
817 | S = t.TypeVar("S") # State type
818 | A = t.TypeVar("A") # Return type
819 | B = t.TypeVar("B") # Transformed type
820 |
821 |
822 | class State(t.Generic[S, A], Monad[A]):
823 | def __init__(self, run_state: t.Callable[[S], t.Tuple[A, S]]) -> None:
824 | self.run_state = run_state
825 |
826 | def bind(self, func: t.Callable[[A], "State[S, B]"]) -> "State[S, B]":
827 | def new_run_state(s: S) -> t.Tuple[B, S]:
828 | a, state_prime = self.run_state(s)
829 | return func(a).run_state(state_prime)
830 |
831 | return State(new_run_state)
832 |
833 | def map(self, func: t.Callable[[A], B]) -> "State[S, B]":
834 | def new_run_state(s: S) -> t.Tuple[B, S]:
835 | a, state_prime = self.run_state(s)
836 | return func(a), state_prime
837 |
838 | return State(new_run_state)
839 |
840 | def filter(self, predicate: t.Callable[[A], bool]) -> "State[S, A]":
841 | def new_run_state(s: S) -> t.Tuple[A, S]:
842 | a, state_prime = self.run_state(s)
843 | if predicate(a):
844 | return a, state_prime
845 | else:
846 | raise ValueError("Value does not satisfy predicate")
847 |
848 | return State(new_run_state)
849 |
850 | def unwrap(self) -> A:
851 | raise NotImplementedError(
852 | "State cannot be directly unwrapped without providing an initial state."
853 | )
854 |
855 | def unwrap_or(self, default: B) -> t.Union[A, B]:
856 | raise NotImplementedError(
857 | "State cannot directly return a value without an initial state."
858 | )
859 |
860 | def __hash__(self) -> int:
861 | return id(self.run_state)
862 |
863 | @staticmethod
864 | def pure(value: A) -> "State[S, A]":
865 | return State(lambda s: (value, s))
866 |
867 | def __call__(self, state: S) -> t.Tuple[A, S]:
868 | return self.run_state(state)
869 |
870 | def __repr__(self) -> str:
871 | return f"State({self.run_state})"
872 |
873 | @classmethod
874 | def lift(
875 | cls, func: t.Callable[[U], A]
876 | ) -> t.Callable[["U | State[S, U]"], "State[S, A]"]:
877 | """Lifts a function to work within the State monad.
878 | Args:
879 | func: A function to lift.
880 | Returns:
881 | A new function that returns a State value.
882 | """
883 |
884 | @functools.wraps(func)
885 | def wrapper(value: "U | State[S, U]") -> "State[S, A]":
886 | if isinstance(value, State):
887 | return value.map(func)
888 | value = t.cast(U, value)
889 |
890 | def run_state(s: S) -> t.Tuple[A, S]:
891 | return func(value), s
892 |
893 | return cls(run_state)
894 |
895 | return wrapper
896 |
897 |
898 | # Aliases for monadic constructors and converters:
899 | # to_<name> is the pure constructor,
900 | # <name> is the lift function.
901 |
902 | to_maybe = just = Maybe.pure
903 | nothing = Nothing[t.Any]()
904 | maybe = Maybe.lift
905 |
906 | to_result = ok = Result.pure
907 | error = lambda e: Err(e) # noqa: E731
908 | result = Result.lift
909 |
910 | to_promise = Promise.pure
911 | promise = Promise.lift
912 |
913 | to_lazy = Lazy.pure
914 | lazy = Lazy.lift
915 |
916 | to_deferred = Defer.pure
917 | deferred = Defer.lift
918 |
919 | to_state = State.pure
920 | state = State.lift
921 |
922 | # to_io = IO.pure
923 | # io = IO.lift
924 |
--------------------------------------------------------------------------------
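
A minimal sketch of the monads above in use (illustrative only; every name is defined in this module):

    from cdf.types.monads import Just, Lazy, State, maybe, result

    # Maybe: `maybe` lifts a plain function; None results and raised exceptions become Nothing
    parse = maybe(int)
    assert parse("42").unwrap() == 42
    assert parse("not a number").is_nothing()

    # Result: `result` lifts a function so exceptions are captured as Err values
    safe_div = result(lambda x: 10 / x)
    value, err = safe_div(2).to_parts()  # (5.0, None)
    _, boom = safe_div(0).to_parts()     # (None, ZeroDivisionError(...))

    # Bind with >> to sequence monadic steps
    assert (Just(21) >> (lambda x: Just(x * 2))).unwrap() == 42

    # Lazy: nothing runs until unwrap() forces and caches the computation
    expensive = Lazy(lambda: sum(range(1_000))).map(lambda n: n + 1)
    assert expensive.unwrap() == 499501

    # State: thread a counter through a computation
    tick = State(lambda s: (s, s + 1))
    value, next_state = tick.run_state(0)  # (0, 1)

Promise is omitted from the sketch because its constructor schedules work on an asyncio event loop, which requires loop setup beyond a few lines.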