├── ccflow ├── py.typed ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── enums │ │ ├── __init__.py │ │ └── test_pydantic.py │ ├── examples │ │ ├── __init__.py │ │ ├── test_tpch.py │ │ └── test_etl.py │ ├── exttypes │ │ ├── __init__.py │ │ ├── test_jinja.py │ │ ├── test_exprtk.py │ │ ├── test_pyobjectpath.py │ │ ├── test_frequency.py │ │ ├── test_polars.py │ │ └── test_pydantic_numpy.py │ ├── models │ │ ├── __init__.py │ │ └── test_publisher.py │ ├── plugins │ │ └── __init__.py │ ├── result │ │ ├── __init__.py │ │ ├── test_dict.py │ │ ├── test_pandas.py │ │ ├── test_numpy.py │ │ ├── test_xarray.py │ │ ├── test_pyarrow.py │ │ ├── test_generic.py │ │ └── test_narwhals.py │ ├── utils │ │ ├── __init__.py │ │ ├── test_logging.py │ │ ├── test_formatter.py │ │ ├── test_compose_hydra.py │ │ └── test_arrow.py │ ├── evaluators │ │ ├── __init__.py │ │ └── util.py │ ├── data │ │ ├── __init__.py │ │ └── python_object_samples.py │ ├── config_user │ │ ├── sample.yaml │ │ └── sample2.yml │ ├── config │ │ ├── conf_out_of_order.yaml │ │ ├── conf_from_python.yaml │ │ ├── conf.yaml │ │ └── conf_sub.yaml │ ├── test_import.py │ ├── test_evaluator.py │ ├── test_lazy_result.py │ ├── test_global_state.py │ ├── test_base_load_config.py │ ├── publishers │ │ └── test_print.py │ ├── test_validators.py │ └── test_object_config.py ├── examples │ ├── etl │ │ ├── config │ │ │ ├── __init__.py │ │ │ ├── load │ │ │ │ └── db.yaml │ │ │ ├── extract │ │ │ │ └── rest.yaml │ │ │ ├── transform │ │ │ │ └── links.yaml │ │ │ └── base.yaml │ │ ├── explain.py │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── models.py │ ├── tpch │ │ ├── config │ │ │ └── conf.yaml │ │ ├── queries │ │ │ ├── __init__.py │ │ │ ├── README.md │ │ │ ├── q6.py │ │ │ ├── q13.py │ │ │ ├── q4.py │ │ │ ├── q14.py │ │ │ ├── q17.py │ │ │ ├── q11.py │ │ │ ├── LICENSE.md │ │ │ ├── q15.py │ │ │ ├── q16.py │ │ │ ├── q18.py │ │ │ ├── q1.py │ │ │ ├── q22.py │ │ │ ├── q12.py │ │ │ ├── q3.py │ │ │ ├── q5.py │ │ │ ├── q9.py │ │ │ ├── q21.py │ │ │ ├── q10.py │ │ │ ├── q20.py │ │ │ ├── q19.py │ │ │ ├── q2.py │ │ │ ├── q8.py │ │ │ └── q7.py │ │ ├── __init__.py │ │ ├── base.py │ │ └── query.py │ ├── __init__.py │ └── example.parquet ├── evaluators │ └── __init__.py ├── models │ ├── __init__.py │ └── publisher.py ├── plugins │ └── __init__.py ├── exttypes │ ├── pydantic_numpy │ │ ├── __init__.py │ │ └── ndarray.py │ ├── __init__.py │ ├── jinja.py │ ├── exprtk.py │ ├── frequency.py │ ├── polars.py │ └── pyobjectpath.py ├── result │ ├── __init__.py │ ├── list.py │ ├── dict.py │ ├── numpy.py │ ├── pandas.py │ ├── xarray.py │ ├── generic.py │ ├── narwhals.py │ └── pyarrow.py ├── utils │ ├── tokenize.py │ ├── __init__.py │ ├── logging.py │ ├── formatter.py │ ├── core.py │ ├── chunker.py │ └── arrow.py ├── publishers │ ├── __init__.py │ └── print.py ├── __init__.py ├── global_state.py ├── publisher.py ├── serialization.py ├── validators.py ├── object_config.py └── compose.py ├── docs ├── img │ ├── dark.png │ ├── light.png │ └── wiki │ │ └── etl │ │ ├── explain1.png │ │ └── explain2.png └── wiki │ ├── _Footer.md │ ├── Installation.md │ ├── _Sidebar.md │ ├── Contribute.md │ ├── contribute │ ├── Contribute.md │ ├── Build-from-Source.md │ └── Local-Development-Setup.md │ ├── First-Steps.md │ ├── Build-from-Source.md │ └── Local-Development-Setup.md ├── .github ├── CODEOWNERS ├── dependabot.yaml ├── workflows │ ├── copier.yaml │ ├── wiki.yaml │ └── build.yaml ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── CODE_OF_CONDUCT.md ├── .gitattributes ├── .copier-answers.yaml 
├── .gitignore ├── Makefile └── README.md /ccflow/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/conftest.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/enums/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/exttypes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/result/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/examples/etl/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/config/conf.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/tests/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ccflow/examples/__init__.py: -------------------------------------------------------------------------------- 1 | from .tpch import * 2 | -------------------------------------------------------------------------------- /ccflow/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | -------------------------------------------------------------------------------- /ccflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .publisher import * 2 | -------------------------------------------------------------------------------- /ccflow/plugins/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .omegaconf_resolvers import * 2 | -------------------------------------------------------------------------------- /docs/img/dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Point72/ccflow/HEAD/docs/img/dark.png -------------------------------------------------------------------------------- /docs/img/light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Point72/ccflow/HEAD/docs/img/light.png -------------------------------------------------------------------------------- /ccflow/exttypes/pydantic_numpy/__init__.py: -------------------------------------------------------------------------------- 1 | from .ndarray import NDArray 2 | from .ndtypes import * 3 | -------------------------------------------------------------------------------- /ccflow/result/__init__.py: -------------------------------------------------------------------------------- 1 | from .dict import * 2 | from .generic import * 3 | from .list import * 4 | -------------------------------------------------------------------------------- /ccflow/utils/tokenize.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | from dask.base import normalize_token, tokenize 3 | -------------------------------------------------------------------------------- /ccflow/publishers/__init__.py: -------------------------------------------------------------------------------- 1 | from .composite import * 2 | from .file import * 3 | from .print import * 4 | -------------------------------------------------------------------------------- /ccflow/examples/example.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Point72/ccflow/HEAD/ccflow/examples/example.parquet -------------------------------------------------------------------------------- /docs/img/wiki/etl/explain1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Point72/ccflow/HEAD/docs/img/wiki/etl/explain1.png -------------------------------------------------------------------------------- /docs/img/wiki/etl/explain2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Point72/ccflow/HEAD/docs/img/wiki/etl/explain2.png -------------------------------------------------------------------------------- /ccflow/examples/tpch/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .data_generators import * 3 | from .query import * 4 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Team Review 2 | * @feussy @hintse @timkpaine @ptomecek 3 | 4 | # Administrative 5 | LICENSE @ptomecek @timkpaine 6 | -------------------------------------------------------------------------------- /ccflow/tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Shared test data modules. 2 | 3 | Import sample configs using module-level objects in `python_object_samples`. 
4 | """ 5 | -------------------------------------------------------------------------------- /ccflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunker import * 2 | from .core import * 3 | from .logging import * 4 | from .tokenize import normalize_token, tokenize 5 | -------------------------------------------------------------------------------- /docs/wiki/_Footer.md: -------------------------------------------------------------------------------- 1 | _This wiki is autogenerated. To made updates, open a PR against the original source file in [`docs/wiki`](https://github.com/Point72/ccflow/tree/main/docs/wiki)._ 2 | -------------------------------------------------------------------------------- /ccflow/examples/etl/config/load/db.yaml: -------------------------------------------------------------------------------- 1 | _target_: ccflow.examples.etl.models.DBModel 2 | file: ${transform.publisher.name}${transform.publisher.suffix} 3 | db_file: etl.db 4 | table: links 5 | -------------------------------------------------------------------------------- /ccflow/tests/config_user/sample.yaml: -------------------------------------------------------------------------------- 1 | user_foo: 2 | _target_: ccflow.tests.test_base_registry.MyTestModel 3 | a: test 4 | b: 0.0 5 | c: 6 | - i 7 | - j 8 | d: 9 | k: 2.0 -------------------------------------------------------------------------------- /ccflow/tests/config_user/sample2.yml: -------------------------------------------------------------------------------- 1 | user_bar: 2 | _target_: ccflow.tests.test_base_registry.MyNestedModel 3 | x: foo 4 | y: # Note that when type is defined on parent model, no need to specify _target_ 5 | a: test2 6 | b: 2.0 -------------------------------------------------------------------------------- /ccflow/examples/etl/config/extract/rest.yaml: -------------------------------------------------------------------------------- 1 | _target_: ccflow.PublisherModel 2 | model: 3 | _target_: ccflow.examples.etl.models.RestModel 4 | publisher: 5 | _target_: ccflow.publishers.GenericFilePublisher 6 | name: raw 7 | suffix: .html 8 | field: value 9 | -------------------------------------------------------------------------------- /ccflow/result/list.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, List, TypeVar 2 | 3 | from ..base import ResultBase 4 | 5 | __all__ = ("ListResult",) 6 | 7 | 8 | V = TypeVar("V") 9 | 10 | 11 | class ListResult(ResultBase, Generic[V]): 12 | value: List[V] 13 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | examples/* linguist-documentation 2 | docs/* linguist-documentation 3 | *.ipynb linguist-documentation 4 | Makefile linguist-documentation 5 | 6 | *.md text=auto eol=lf 7 | *.py text=auto eol=lf 8 | *.toml text=auto eol=lf 9 | *.yaml text=auto eol=lf 10 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/README.md: -------------------------------------------------------------------------------- 1 | # Narwhals TPC-H Queries 2 | 3 | The queries in this folder are taken from the [Narwhals Repo](https://github.com/narwhals-dev/narwhals/tree/main/tpch/queries). 4 | See the `LICENSE.md` file in this folder for the license of the queries. 
5 | -------------------------------------------------------------------------------- /ccflow/result/dict.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Generic, TypeVar 2 | 3 | from ..base import ResultBase 4 | 5 | __all__ = ("DictResult",) 6 | 7 | 8 | K = TypeVar("K") 9 | V = TypeVar("V") 10 | 11 | 12 | class DictResult(ResultBase, Generic[K, V]): 13 | value: Dict[K, V] 14 | -------------------------------------------------------------------------------- /ccflow/exttypes/__init__.py: -------------------------------------------------------------------------------- 1 | from .arrow import * 2 | from .exprtk import * 3 | from .frequency import * 4 | from .jinja import * 5 | 6 | # Do NOT import .polars. We don't want ccflow (without flow) to have a dependency on polars! 7 | from .pydantic_numpy import * 8 | from .pyobjectpath import * 9 | -------------------------------------------------------------------------------- /ccflow/examples/etl/explain.py: -------------------------------------------------------------------------------- 1 | from ccflow.utils.hydra import cfg_explain_cli 2 | 3 | from .__main__ import main 4 | 5 | __all__ = ("explain",) 6 | 7 | 8 | def explain(): 9 | cfg_explain_cli(config_path="config", config_name="base", hydra_main=main) 10 | 11 | 12 | if __name__ == "__main__": 13 | explain() 14 | -------------------------------------------------------------------------------- /ccflow/examples/etl/config/transform/links.yaml: -------------------------------------------------------------------------------- 1 | _target_: ccflow.PublisherModel 2 | model: 3 | _target_: ccflow.examples.etl.models.LinksModel 4 | file: ${extract.publisher.name}${extract.publisher.suffix} 5 | publisher: 6 | _target_: ccflow.publishers.GenericFilePublisher 7 | name: extracted 8 | suffix: .csv 9 | field: value 10 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_dict.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from ccflow.result import DictResult 4 | 5 | 6 | class TestResult(TestCase): 7 | def test_dict(self): 8 | context = DictResult[str, float].model_validate({"value": {"a": 0, "b": 1.1}}) 9 | self.assertEqual(context.value, {"a": 0.0, "b": 1.1}) 10 | -------------------------------------------------------------------------------- /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | labels: 8 | - "part: github_actions" 9 | 10 | - package-ecosystem: "pip" 11 | directory: "/" 12 | schedule: 13 | interval: "monthly" 14 | labels: 15 | - "lang: python" 16 | - "part: dependencies" 17 | -------------------------------------------------------------------------------- /ccflow/utils/logging.py: -------------------------------------------------------------------------------- 1 | from logging import FileHandler as BaseFileHandler, StreamHandler 2 | from pathlib import Path 3 | 4 | __all__ = ("StreamHandler", "FileHandler") 5 | 6 | 7 | class FileHandler(BaseFileHandler): 8 | def __init__(self, filename, *args, **kwargs): 9 | Path(filename).parent.mkdir(parents=True, exist_ok=True) 10 | super().__init__(filename, *args, **kwargs) 11 | -------------------------------------------------------------------------------- 
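
The `FileHandler` above is a drop-in replacement for `logging.FileHandler` that creates any missing parent directories before opening the file. A minimal usage sketch, with a hypothetical logger name and log path:

```python
import logging

from ccflow import FileHandler

logger = logging.getLogger("my_app")
# Unlike logging.FileHandler, this does not raise if logs/nested/ does not exist yet.
logger.addHandler(FileHandler("logs/nested/run.log"))
logger.warning("hello")
```
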
/.github/workflows/copier.yaml: -------------------------------------------------------------------------------- 1 | name: Copier Updates 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "0 5 * * 0" 7 | 8 | jobs: 9 | update: 10 | permissions: 11 | contents: write 12 | pull-requests: write 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions-ext/copier-update@main 16 | with: 17 | token: ${{ secrets.WORKFLOW_TOKEN }} 18 | -------------------------------------------------------------------------------- /ccflow/result/numpy.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | 3 | from ..base import ResultBase 4 | from ..exttypes import NDArray 5 | 6 | __all__ = ("NumpyResult",) 7 | 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class NumpyResult(ResultBase, Generic[T]): 13 | array: NDArray[T] 14 | 15 | def __eq__(self, other): 16 | return type(self) is type(other) and (self.array == other.array).all() 17 | -------------------------------------------------------------------------------- /docs/wiki/Installation.md: -------------------------------------------------------------------------------- 1 | ## Pre-requisites 2 | 3 | You need Python >=3.10 on your machine to install `ccflow`. 4 | 5 | ## Install with `pip` 6 | 7 | ```bash 8 | pip install ccflow 9 | ``` 10 | 11 | ## Install with `conda` 12 | 13 | ```bash 14 | conda install ccflow --channel conda-forge 15 | ``` 16 | 17 | ## Source installation 18 | 19 | For other platforms and for development installations, [build `ccflow` from source](Build-from-Source). 20 | -------------------------------------------------------------------------------- /.copier-answers.yaml: -------------------------------------------------------------------------------- 1 | # Changes here will be overwritten by Copier 2 | _commit: b74d698 3 | _src_path: https://github.com/python-project-templates/base.git 4 | add_docs: false 5 | add_extension: python 6 | add_wiki: true 7 | email: OpenSource@point72.com 8 | github: Point72 9 | project_description: ccflow is a collection of tools for workflow configuration, orchestration, 10 | and dependency injection 11 | project_name: ccflow 12 | python_version_primary: '3.11' 13 | team: Point72, L.P. 
14 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/base.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from pydantic import conint 4 | 5 | from ccflow import ContextBase 6 | 7 | __all__ = ( 8 | "TPCHTable", 9 | "TPCHTableContext", 10 | "TPCHQueryContext", 11 | ) 12 | 13 | 14 | TPCHTable = Literal["customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier"] 15 | 16 | 17 | class TPCHTableContext(ContextBase): 18 | table: TPCHTable 19 | 20 | 21 | class TPCHQueryContext(ContextBase): 22 | query_id: conint(ge=1, le=22) 23 | -------------------------------------------------------------------------------- /.github/workflows/wiki.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Wiki 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "docs/**" 9 | - "README.md" 10 | workflow_dispatch: 11 | 12 | concurrency: 13 | group: docs 14 | cancel-in-progress: true 15 | 16 | permissions: 17 | contents: write 18 | 19 | jobs: 20 | deploy: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v6 24 | - run: cp README.md docs/wiki/Home.md 25 | - uses: Andrew-Chen-Wang/github-wiki-action@v5 26 | with: 27 | path: docs/wiki 28 | -------------------------------------------------------------------------------- /ccflow/tests/config/conf_out_of_order.yaml: -------------------------------------------------------------------------------- 1 | subregistry2: 2 | baz: 3 | _target_: ccflow.tests.test_base_registry.MyNestedModel 4 | x: /subregistry1/foo 5 | y: qux 6 | 7 | qux: 8 | _target_: ccflow.tests.test_base_registry.MyTestModel 9 | a: test 10 | b: 0.0 11 | 12 | subregistry1: 13 | foo: 14 | _target_: ccflow.tests.test_base_registry.MyTestModel 15 | a: test 16 | b: 0.0 17 | 18 | bar: 19 | _target_: ccflow.tests.test_base_registry.MyNestedModel 20 | x: foo 21 | y: /subregistry2/qux 22 | 23 | 24 | -------------------------------------------------------------------------------- /ccflow/result/pandas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pyarrow as pa 3 | from pydantic import field_validator 4 | 5 | from ..base import ResultBase 6 | 7 | __all__ = ("PandasResult",) 8 | 9 | 10 | class PandasResult(ResultBase): 11 | df: pd.DataFrame 12 | 13 | @field_validator("df", mode="before") 14 | def _from_arrow(cls, v): 15 | if isinstance(v, pa.Table): 16 | return v.to_pandas() 17 | return v 18 | 19 | @field_validator("df", mode="before") 20 | def _from_series(cls, v): 21 | if isinstance(v, pd.Series): 22 | return pd.DataFrame(v) 23 | return v 24 | -------------------------------------------------------------------------------- /ccflow/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.7.0" 2 | 3 | # Import exttypes early so modules that import `from ccflow import PyObjectPath` during 4 | # initialization find it (avoids circular import issues with functions that import utilities 5 | # which, in turn, import `ccflow`). 
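# A hypothetical illustration of the failure mode: if a module such as `ccflow.utils.something`
# executed `from ccflow import PyObjectPath` while this package was still initializing, the
# lookup would only succeed because the import below has already bound `PyObjectPath` onto the
# partially-initialized `ccflow` module.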
6 | from .exttypes import * # noqa: I001 7 | 8 | from .arrow import * 9 | from .base import * 10 | from .compose import * 11 | from .callable import * 12 | from .context import * 13 | from .enums import Enum 14 | from .global_state import * 15 | from .models import * 16 | from .object_config import * 17 | from .publisher import * 18 | from .result import * 19 | from .utils import FileHandler, StreamHandler 20 | -------------------------------------------------------------------------------- /ccflow/result/xarray.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pyarrow as pa 3 | import xarray as xr 4 | from pydantic import field_validator 5 | 6 | from ..base import ResultBase 7 | 8 | __all__ = ("XArrayResult",) 9 | 10 | 11 | class XArrayResult(ResultBase): 12 | array: xr.DataArray 13 | 14 | @field_validator("array", mode="before") 15 | def _from_pandas(cls, v): 16 | if isinstance(v, pd.DataFrame): 17 | return xr.DataArray(v) 18 | return v 19 | 20 | @field_validator("array", mode="before") 21 | def _from_arrow(cls, v): 22 | if isinstance(v, pa.Table): 23 | return xr.DataArray(v.to_pandas()) 24 | return v 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_pandas.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pandas as pd 4 | import pyarrow as pa 5 | 6 | from ccflow.result.pandas import PandasResult 7 | 8 | 9 | class TestResult(TestCase): 10 | def test_pandas(self): 11 | df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) 12 | t = pa.Table.from_pandas(df) 13 | 14 | r = PandasResult(df=t) 15 | self.assertIsInstance(r.df, pd.DataFrame) 16 | 17 | r = PandasResult.model_validate({"df": t}) 18 | self.assertIsInstance(r.df, pd.DataFrame) 19 | 20 | r = PandasResult(df=df["A"]) 21 | self.assertIsInstance(r.df, pd.DataFrame) 22 | self.assertEqual(r.df.columns, ["A"]) 23 | -------------------------------------------------------------------------------- /ccflow/examples/etl/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Optional 3 | 4 | from ccflow import RootModelRegistry, load_config as load_config_base 5 | 6 | __all__ = ("load_config",) 7 | 8 | 9 | def load_config( 10 | config_dir: str = "", 11 | config_name: str = "", 12 | overrides: Optional[List[str]] = None, 13 | *, 14 | overwrite: bool = True, 15 | basepath: str = "", 16 | ) -> RootModelRegistry: 17 | return load_config_base( 18 | root_config_dir=str(Path(__file__).resolve().parent / "config"), 19 | root_config_name="base", 20 | config_dir=config_dir, 21 | config_name=config_name, 22 | overrides=overrides, 23 | overwrite=overwrite, 24 | basepath=basepath, 25 | ) 26 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q6.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query(line_item_ds: FrameT) -> FrameT: 13 | var_1 = datetime(1994, 1, 1) 14 | var_2 = datetime(1995, 1, 1) 15 | var_3 = 24 16 | 17 | return ( 18 | line_item_ds.filter( 19 | nw.col("l_shipdate").is_between(var_1, var_2, closed="left"), 20 | nw.col("l_discount").is_between(0.05, 0.07), 21 | nw.col("l_quantity") < var_3, 22 | ) 23 | .with_columns((nw.col("l_extendedprice") * nw.col("l_discount")).alias("revenue")) 24 | .select(nw.sum("revenue")) 25 | ) 26 | -------------------------------------------------------------------------------- /ccflow/tests/config/conf_from_python.yaml: -------------------------------------------------------------------------------- 1 | shared_model: 2 | _target_: ccflow.compose.from_python 3 | py_object_path: ccflow.tests.data.python_object_samples.SHARED_MODEL 4 | 5 | consumer: 6 | _target_: ccflow.tests.data.python_object_samples.Consumer 7 | shared: shared_model 8 | tag: consumer1 9 | 10 | # Demonstrate from_python returning a dict (non-BaseModel) 11 | holder: 12 | _target_: ccflow.tests.data.python_object_samples.SharedHolder 13 | name: holder1 14 | cfg: 15 | _target_: ccflow.compose.from_python 16 | py_object_path: ccflow.tests.data.python_object_samples.SHARED_CFG 17 | 18 | # Use update_from_template to update a field while preserving shared identity 19 | consumer_updated: 20 | _target_: ccflow.compose.update_from_template 21 | base: consumer 22 | update: 23 | tag: consumer2 24 | 
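# A hedged sketch (not used by the tests): from_python resolves any module-level object,
# e.g. OTHER_CFG from the same samples module.
# other_holder:
#   _target_: ccflow.tests.data.python_object_samples.SharedHolder
#   name: holder2
#   cfg:
#     _target_: ccflow.compose.from_python
#     py_object_path: ccflow.tests.data.python_object_samples.OTHER_CFG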
-------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q13.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: 12 | var1 = "special" 13 | var2 = "requests" 14 | 15 | orders = orders_ds.filter(~nw.col("o_comment").str.contains(f"{var1}.*{var2}")) 16 | return ( 17 | customer_ds.join(orders, left_on="c_custkey", right_on="o_custkey", how="left") 18 | .group_by("c_custkey") 19 | .agg(nw.col("o_orderkey").count().alias("c_count")) 20 | .group_by("c_count") 21 | .agg(nw.len()) 22 | .select(nw.col("c_count"), nw.col("len").alias("custdist")) 23 | .sort(by=["custdist", "c_count"], descending=[True, True]) 24 | ) 25 | -------------------------------------------------------------------------------- /ccflow/tests/config/conf.yaml: -------------------------------------------------------------------------------- 1 | foo: 2 | _target_: ccflow.tests.test_base_registry.MyTestModel 3 | a: test 4 | b: 0.0 5 | c: 6 | - i 7 | - j 8 | d: 9 | k: 2.0 10 | bar: 11 | _target_: ccflow.tests.test_base_registry.MyNestedModel 12 | x: foo 13 | y: # Note that when type is defined on parent model, no need to specify _target_ 14 | a: test2 15 | b: 2.0 16 | baz: 17 | _target_: ccflow.tests.test_base_registry.MyNestedModel 18 | x: foo 19 | y: 20 | # We can still use _target_ to override the definition from the parent (i.e. subclass) 21 | _target_: ccflow.tests.test_base_registry.MyTestModelSubclass 22 | a: test3 23 | b: 3.0 24 | z: 25 | _target_: ccflow.tests.test_base_registry.MyClass 26 | p: pp 27 | q: 100. 
28 | 29 | -------------------------------------------------------------------------------- /docs/wiki/_Sidebar.md: -------------------------------------------------------------------------------- 1 | 8 | 9 | **[Home](Home)** 10 | 11 | **Get Started** 12 | 13 | - [Installation](Installation) 14 | - [Design Goals](Design-Goals) 15 | - [Key Features](Key-Features) 16 | - [First Steps](First-Steps) 17 | 18 | **Tutorials** 19 | 20 | - [Configuration](Configuration) 21 | - [Workflows](Workflows) 22 | - [ETL](ETL) 23 | 24 | **Developer Guide** 25 | 26 | - [Contributing](Contribute) 27 | - [Development Setup](Local-Development-Setup) 28 | - [Build from Source](Build-from-Source) 29 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_numpy.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | 5 | from ccflow.result.numpy import NumpyResult 6 | 7 | 8 | class TestResult(TestCase): 9 | def test_numpy(self): 10 | x = np.array([1.0, 3.0]) 11 | r = NumpyResult[np.float64](array=x) 12 | np.testing.assert_equal(r.array, x) 13 | 14 | # Check you can also construct from list 15 | r = NumpyResult[np.float64](array=x.tolist()) 16 | np.testing.assert_equal(r.array, x) 17 | 18 | self.assertRaises(TypeError, NumpyResult[np.float64], np.array(["foo"])) 19 | 20 | # Test generic 21 | r = NumpyResult[object](array=x) 22 | np.testing.assert_equal(r.array, x) 23 | r = NumpyResult[object](array=[None, "foo", 4.0]) 24 | np.testing.assert_equal(r.array, np.array([None, "foo", 4.0])) 25 | -------------------------------------------------------------------------------- /ccflow/tests/data/python_object_samples.py: -------------------------------------------------------------------------------- 1 | """Sample python objects for testing from_python and identity preservation.""" 2 | 3 | from typing import Dict 4 | 5 | from ccflow import BaseModel 6 | 7 | # Module-level objects 8 | SHARED_CFG: Dict[str, int] = {"x": 1, "y": 2} 9 | OTHER_CFG: Dict[str, int] = {"x": 10, "y": 20} 10 | """Dict samples; identity for dicts is not guaranteed by Pydantic.""" 11 | 12 | NESTED_CFG = { 13 | "db": {"host": "seed.local", "port": 7000, "name": "seed"}, 14 | "meta": {"env": "dev"}, 15 | } 16 | 17 | 18 | class SharedHolder(BaseModel): 19 | name: str 20 | cfg: Dict[str, int] 21 | 22 | 23 | class SharedModel(BaseModel): 24 | val: int = 0 25 | 26 | 27 | # Module-level instance to be resolved via from_python 28 | SHARED_MODEL = SharedModel(val=42) 29 | 30 | 31 | class Consumer(BaseModel): 32 | shared: SharedModel 33 | tag: str = "" 34 | -------------------------------------------------------------------------------- /ccflow/result/generic.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | 3 | from pydantic import model_validator 4 | 5 | from ..base import ResultBase 6 | 7 | __all__ = ("GenericResult",) 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class GenericResult(ResultBase, Generic[T]): 13 | """Holds anything.""" 14 | 15 | value: T 16 | 17 | @model_validator(mode="wrap") 18 | def _validate_generic_result(cls, v, handler, info): 19 | if isinstance(v, GenericResult) and not isinstance(v, cls): 20 | v = {"value": v.value} 21 | elif not isinstance(v, GenericResult) and not (isinstance(v, dict) and "value" in v): 22 | v = {"value": v} 23 | if isinstance(v, dict) and "value" in v: 24 | from ..context import GenericContext 25 | 26 
| if isinstance(v["value"], GenericContext): 27 | v["value"] = v["value"].value 28 | return handler(v) 29 | -------------------------------------------------------------------------------- /ccflow/examples/etl/config/base.yaml: -------------------------------------------------------------------------------- 1 | extract: 2 | _target_: ccflow.PublisherModel 3 | model: 4 | _target_: ccflow.examples.etl.models.RestModel 5 | publisher: 6 | _target_: ccflow.publishers.GenericFilePublisher 7 | name: raw 8 | suffix: .html 9 | field: value 10 | 11 | transform: 12 | _target_: ccflow.PublisherModel 13 | model: 14 | _target_: ccflow.examples.etl.models.LinksModel 15 | file: ${extract.publisher.name}${extract.publisher.suffix} 16 | publisher: 17 | _target_: ccflow.publishers.GenericFilePublisher 18 | name: extracted 19 | suffix: .csv 20 | field: value 21 | 22 | load: 23 | _target_: ccflow.examples.etl.models.DBModel 24 | file: ${transform.publisher.name}${transform.publisher.suffix} 25 | db_file: etl.db 26 | table: links 27 | 28 | # Alternative multi-file approach 29 | # defaults: 30 | # - extract: rest 31 | # - transform: links 32 | # - load: db 33 | -------------------------------------------------------------------------------- /ccflow/tests/utils/test_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | 5 | from ccflow import FileHandler 6 | 7 | 8 | def test_file_handler(): 9 | with TemporaryDirectory() as tempdir: 10 | # Make an arbitrary path in a temporary directory 11 | output_file = Path(tempdir) / "a" / "random" / "sub" / "path" / "file.log" 12 | assert not output_file.exists(), "Output file should not exist before the test" 13 | 14 | # Attach handler to loggers 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.DEBUG) 17 | handler = FileHandler(str(output_file)) 18 | logger.addHandler(handler) 19 | 20 | # Print some stuff 21 | logger.info("Test log message") 22 | 23 | # Assert everything is ok 24 | assert output_file.exists() 25 | assert "Test log message" in output_file.read_text() 26 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q4.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query(line_item_ds: FrameT, orders_ds: FrameT) -> FrameT: 13 | var_1 = datetime(1993, 7, 1) 14 | var_2 = datetime(1993, 10, 1) 15 | 16 | return ( 17 | line_item_ds.join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") 18 | .filter( 19 | nw.col("o_orderdate").is_between(var_1, var_2, closed="left"), 20 | nw.col("l_commitdate") < nw.col("l_receiptdate"), 21 | ) 22 | .unique(subset=["o_orderpriority", "l_orderkey"]) 23 | .group_by("o_orderpriority") 24 | .agg(nw.len().alias("order_count")) 25 | .sort(by="o_orderpriority") 26 | .with_columns(nw.col("order_count").cast(nw.Int64)) 27 | ) 28 | -------------------------------------------------------------------------------- /ccflow/tests/test_import.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | from unittest import TestCase 4 | 5 | 6 | class TestImports(TestCase): 7 | def test(self): 8 | # Assert that 
importing ccflow doesn't cause expensive modules to be imported. 9 | res = subprocess.run([sys.executable, __file__], capture_output=True, text=True) 10 | if res.returncode != 0: 11 | raise AssertionError(res.stderr) 12 | 13 | 14 | if __name__ == "__main__": 15 | import ccflow 16 | 17 | _ = ccflow 18 | expensive_imports = [ 19 | "ray", 20 | "deltalake", 21 | "emails", 22 | "matplotlib", 23 | "mlflow", 24 | "plotly", 25 | "pyarrow.dataset", 26 | # These aren't necessarily expensive, just things we don't want to import. 27 | "cexprtk", 28 | ] 29 | for m in expensive_imports: 30 | if m in sys.modules: 31 | raise AssertionError(f"{m} was imported!") 32 | -------------------------------------------------------------------------------- /ccflow/tests/config/conf_sub.yaml: -------------------------------------------------------------------------------- 1 | subregistry1: 2 | foo: 3 | _target_: ccflow.tests.test_base_registry.MyTestModel 4 | a: test 5 | b: 0.0 6 | c: 7 | - i 8 | - j 9 | d: 10 | k: 2.0 11 | bar: 12 | _target_: ccflow.tests.test_base_registry.MyNestedModel 13 | x: foo 14 | y: # Note that when type is defined on parent model, no need to specify _target_ 15 | a: test2 16 | b: 2.0 17 | 18 | subregistry2: 19 | baz: 20 | _target_: ccflow.tests.test_base_registry.MyNestedModel 21 | x: /subregistry1/foo 22 | y: 23 | # We can still use _target_ to override the definition from the parent (i.e. subclass) 24 | _target_: ccflow.tests.test_base_registry.MyTestModelSubclass 25 | a: test3 26 | b: 3.0 27 | z: 28 | _target_: ccflow.tests.test_base_registry.MyClass 29 | p: pp 30 | q: 100. 31 | -------------------------------------------------------------------------------- /ccflow/tests/exttypes/test_jinja.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from ccflow.exttypes.jinja import JinjaTemplate 4 | 5 | 6 | class TestJinjaTemplate(TestCase): 7 | def test_template(self): 8 | v = "My {{foo|lower}}" 9 | t = JinjaTemplate(v) 10 | self.assertEqual(t.template.render(foo="FOO"), "My foo") 11 | self.assertEqual(JinjaTemplate.validate(v), t) 12 | 13 | def test_bad(self): 14 | v = "My {{" 15 | self.assertRaises(ValueError, JinjaTemplate.validate, v) 16 | 17 | def test_deepcopy(self): 18 | # Pydantic models sometimes require deep copy, and this can pose problems 19 | # for Jinja templates if they are stored on the object, i.e. 
see https://github.com/pallets/jinja/issues/758 20 | from copy import deepcopy 21 | 22 | v = "My {{foo|lower}}" 23 | t = JinjaTemplate(v) 24 | 25 | # First access the template 26 | t.template 27 | # Then attempt the copy 28 | self.assertEqual(deepcopy(t), t) 29 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_xarray.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pandas as pd 4 | import pyarrow as pa 5 | import xarray as xr 6 | 7 | from ccflow.result.xarray import XArrayResult 8 | 9 | 10 | class TestResult(TestCase): 11 | def test_xarray(self): 12 | df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) 13 | da = xr.DataArray(df) 14 | 15 | r = XArrayResult(array=df) 16 | self.assertIsInstance(r.array, xr.DataArray) 17 | self.assertTrue(r.array.equals(da)) 18 | 19 | t = pa.Table.from_pandas(df) 20 | r = XArrayResult(array=t) 21 | self.assertIsInstance(r.array, xr.DataArray) 22 | self.assertTrue(r.array.equals(da)) 23 | 24 | r = XArrayResult.model_validate({"array": df}) 25 | self.assertIsInstance(r.array, xr.DataArray) 26 | self.assertTrue(r.array.equals(da)) 27 | 28 | r = XArrayResult.model_validate({"array": t}) 29 | self.assertIsInstance(r.array, xr.DataArray) 30 | self.assertTrue(r.array.equals(da)) 31 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q14.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query(line_item_ds: FrameT, part_ds: FrameT) -> FrameT: 13 | var1 = datetime(1995, 9, 1) 14 | var2 = datetime(1995, 10, 1) 15 | 16 | return ( 17 | line_item_ds.join(part_ds, left_on="l_partkey", right_on="p_partkey") 18 | .filter(nw.col("l_shipdate").is_between(var1, var2, closed="left")) 19 | .select( 20 | ( 21 | 100.00 22 | * nw.when(nw.col("p_type").str.contains("PROMO*")) 23 | .then(nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) 24 | .otherwise(0) 25 | .sum() 26 | / (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).sum() 27 | ) 28 | # .round(2) 29 | .alias("promo_revenue") 30 | ) 31 | ) 32 | -------------------------------------------------------------------------------- /ccflow/tests/models/test_publisher.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from ccflow import CallableModel, DictResult, Flow, GenericResult, NullContext 4 | from ccflow.models import PublisherModel 5 | from ccflow.publishers import PrintPublisher 6 | 7 | 8 | class ModelTest(CallableModel): 9 | @Flow.call 10 | def __call__(self, context: NullContext) -> DictResult[str, str]: 11 | return DictResult[str, str](value={"message": "Hello, World!"}) 12 | 13 | 14 | class TestPublisherModel: 15 | def test_run(self): 16 | with patch("ccflow.publishers.print.print") as mock_print: 17 | model = PublisherModel(model=ModelTest(), publisher=PrintPublisher()) 18 | res = model(None) 19 | assert isinstance(res, GenericResult) # from PrintPublisher 20 | assert isinstance(res.value, DictResult[str, str]) 21 | assert res.value.value == {"message": "Hello, World!"} 22 | assert mock_print.call_count == 1 23 | assert mock_print.call_args[0][0].value == {"message": "Hello, 
World!"} 24 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q17.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: 12 | var1 = "Brand#23" 13 | var2 = "MED BOX" 14 | 15 | query1 = ( 16 | part_ds.filter(nw.col("p_brand") == var1) 17 | .filter(nw.col("p_container") == var2) 18 | .join(lineitem_ds, how="left", left_on="p_partkey", right_on="l_partkey") 19 | ) 20 | 21 | return ( 22 | query1.with_columns(l_quantity_times_point_2=nw.col("l_quantity") * 0.2) 23 | .group_by("p_partkey") 24 | .agg(nw.col("l_quantity_times_point_2").mean().alias("avg_quantity")) 25 | .select(nw.col("p_partkey").alias("key"), nw.col("avg_quantity")) 26 | .join(query1, left_on="key", right_on="p_partkey") 27 | .filter(nw.col("l_quantity") < nw.col("avg_quantity")) 28 | .select((nw.col("l_extendedprice").sum() / 7.0).round(2).alias("avg_yearly")) 29 | ) 30 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q11.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query( 12 | nation_ds: FrameT, 13 | partsupp_ds: FrameT, 14 | supplier_ds: FrameT, 15 | ) -> FrameT: 16 | var1 = "GERMANY" 17 | var2 = 0.0001 18 | 19 | q1 = ( 20 | partsupp_ds.join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") 21 | .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") 22 | .filter(nw.col("n_name") == var1) 23 | ) 24 | q2 = q1.select( 25 | (nw.col("ps_supplycost") * nw.col("ps_availqty")).sum().round(2).alias("tmp") 26 | * var2 27 | ) 28 | 29 | return ( 30 | q1.with_columns((nw.col("ps_supplycost") * nw.col("ps_availqty")).alias("value")) 31 | .group_by("ps_partkey") 32 | .agg(nw.sum("value")) 33 | .join(q2, how="cross") 34 | .filter(nw.col("value") > nw.col("tmp")) 35 | .select("ps_partkey", "value") 36 | .sort("value", descending=True) 37 | ) 38 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024, Marco Gorelli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q15.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 | lineitem_ds: FrameT, 14 | supplier_ds: FrameT, 15 | ) -> FrameT: 16 | var1 = datetime(1996, 1, 1) 17 | var2 = datetime(1996, 4, 1) 18 | 19 | revenue = ( 20 | lineitem_ds.filter(nw.col("l_shipdate").is_between(var1, var2, closed="left")) 21 | .with_columns( 22 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias( 23 | "total_revenue" 24 | ) 25 | ) 26 | .group_by("l_suppkey") 27 | .agg(nw.sum("total_revenue")) 28 | .select(nw.col("l_suppkey").alias("supplier_no"), nw.col("total_revenue")) 29 | ) 30 | 31 | return ( 32 | supplier_ds.join(revenue, left_on="s_suppkey", right_on="supplier_no") 33 | .filter(nw.col("total_revenue") == nw.col("total_revenue").max()) 34 | .with_columns(nw.col("total_revenue").round(2)) 35 | .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") 36 | .sort("s_suppkey") 37 | ) 38 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q16.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query(part_ds: FrameT, partsupp_ds: FrameT, supplier_ds: FrameT) -> FrameT: 12 | var1 = "Brand#45" 13 | 14 | supplier = supplier_ds.filter( 15 | nw.col("s_comment").str.contains(".*Customer.*Complaints.*") 16 | ).select(nw.col("s_suppkey"), nw.col("s_suppkey").alias("ps_suppkey")) 17 | 18 | return ( 19 | part_ds.join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") 20 | .filter(nw.col("p_brand") != var1) 21 | .filter(~nw.col("p_type").str.contains("MEDIUM POLISHED*")) 22 | .filter(nw.col("p_size").is_in([49, 14, 23, 45, 19, 3, 36, 9])) 23 | .join(supplier, left_on="ps_suppkey", right_on="s_suppkey", how="left") 24 | .filter(nw.col("ps_suppkey_right").is_null()) 25 | .group_by("p_brand", "p_type", "p_size") 26 | .agg(nw.col("ps_suppkey").n_unique().alias("supplier_cnt")) 27 | .sort( 28 | by=["supplier_cnt", "p_brand", "p_type", "p_size"], 29 | descending=[True, False, False, False], 30 | ) 31 | ) 32 | -------------------------------------------------------------------------------- /docs/wiki/Contribute.md: -------------------------------------------------------------------------------- 1 | Contributions are welcome on this project. We distribute under the terms of the [Apache 2.0 license](https://github.com/Point72/ccflow/blob/main/LICENSE). 2 | 3 | > [!NOTE] 4 | > 5 | > `ccflow` requires [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) for all contributions. 6 | > This is enforced by a [Probot GitHub App](https://probot.github.io/apps/dco/), which checks that commits are "signed". 
7 | > Read [instructions to configure commit signing](Local-Development-Setup#configure-commit-signing). 8 | 9 | For **bug reports** or **small feature requests**, please open an issue on our [issues page](https://github.com/Point72/ccflow/issues). 10 | 11 | For **questions** or to discuss **larger changes or features**, please use our [discussions page](https://github.com/Point72/ccflow/discussions). 12 | 13 | For **contributions**, please see our [developer documentation](Local-Development-Setup). We have `help wanted` and `good first issue` tags on our issues page, so these are a great place to start. 14 | 15 | For **documentation updates**, make PRs that update the pages in `/docs/wiki`. The documentation is pushed to the GitHub wiki automatically through a GitHub workflow. Note that direct updates to this wiki will be overwritten. 16 | -------------------------------------------------------------------------------- /docs/wiki/contribute/Contribute.md: -------------------------------------------------------------------------------- 1 | Contributions are welcome on this project. We distribute under the terms of the [Apache 2.0 license](https://github.com/Point72/ccflow/blob/main/LICENSE). 2 | 3 | > [!NOTE] 4 | > 5 | > `ccflow` requires [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) for all contributions. 6 | > This is enforced by a [Probot GitHub App](https://probot.github.io/apps/dco/), which checks that commits are "signed". 7 | > Read [instructions to configure commit signing](Local-Development-Setup#configure-commit-signing). 8 | 9 | For **bug reports** or **small feature requests**, please open an issue on our [issues page](https://github.com/Point72/ccflow/issues). 10 | 11 | For **questions** or to discuss **larger changes or features**, please use our [discussions page](https://github.com/Point72/ccflow/discussions). 12 | 13 | For **contributions**, please see our [developer documentation](Local-Development-Setup). We have `help wanted` and `good first issue` tags on our issues page, so these are a great place to start. 14 | 15 | For **documentation updates**, make PRs that update the pages in `/docs/wiki`. The documentation is pushed to the GitHub wiki automatically through a GitHub workflow. Note that direct updates to this wiki will be overwritten. 
16 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q18.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query(customer_ds: FrameT, lineitem_ds: FrameT, orders_ds: FrameT) -> FrameT: 12 | var1 = 300 13 | 14 | query1 = ( 15 | lineitem_ds.group_by("l_orderkey") 16 | .agg(nw.col("l_quantity").sum().alias("sum_quantity")) 17 | .filter(nw.col("sum_quantity") > var1) 18 | ) 19 | 20 | return ( 21 | orders_ds.join(query1, left_on="o_orderkey", right_on="l_orderkey", how="semi") 22 | .join(lineitem_ds, left_on="o_orderkey", right_on="l_orderkey") 23 | .join(customer_ds, left_on="o_custkey", right_on="c_custkey") 24 | .group_by("c_name", "o_custkey", "o_orderkey", "o_orderdate", "o_totalprice") 25 | .agg(nw.col("l_quantity").sum().alias("sum")) 26 | .select( 27 | nw.col("c_name"), 28 | nw.col("o_custkey").alias("c_custkey"), 29 | nw.col("o_orderkey"), 30 | nw.col("o_orderdate"), 31 | nw.col("o_totalprice"), 32 | nw.col("sum"), 33 | ) 34 | .sort(by=["o_totalprice", "o_orderdate"], descending=[True, False]) 35 | .head(100) 36 | ) 37 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q1.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query(lineitem: FrameT) -> FrameT: 13 | var_1 = datetime(1998, 9, 2) 14 | return ( 15 | lineitem.filter(nw.col("l_shipdate") <= var_1) 16 | .with_columns( 17 | disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")), 18 | charge=( 19 | nw.col("l_extendedprice") 20 | * (1.0 - nw.col("l_discount")) 21 | * (1.0 + nw.col("l_tax")) 22 | ), 23 | ) 24 | .group_by("l_returnflag", "l_linestatus") 25 | .agg( 26 | nw.sum("l_quantity").alias("sum_qty"), 27 | nw.sum("l_extendedprice").alias("sum_base_price"), 28 | nw.sum("disc_price").alias("sum_disc_price"), 29 | nw.sum("charge").alias("sum_charge"), 30 | nw.mean("l_quantity").alias("avg_qty"), 31 | nw.mean("l_extendedprice").alias("avg_price"), 32 | nw.mean("l_discount").alias("avg_disc"), 33 | nw.len().alias("count_order"), 34 | ) 35 | .sort("l_returnflag", "l_linestatus") 36 | ) 37 | -------------------------------------------------------------------------------- /ccflow/utils/formatter.py: -------------------------------------------------------------------------------- 1 | """Custom log formatters for result types.""" 2 | 3 | import logging 4 | import pprint 5 | 6 | import narwhals as nw 7 | import polars as pl 8 | import pyarrow as pa 9 | from pydantic import BaseModel 10 | 11 | 12 | class PolarsTableFormatter(logging.Formatter): 13 | """Formats Arrow Tables and Narwhals eager Dataframes as polars tables. 14 | Leaves Narwhals LazyFrame and other types as-is. 
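
    A hedged usage sketch (the logger name and result object are illustrative): the formatter
    looks for a pydantic model on the ``result`` attribute of the log record, which is most
    easily supplied via ``extra``:

        handler = logging.StreamHandler()
        handler.setFormatter(PolarsTableFormatter(polars_config={"tbl_rows": 20}))
        logging.getLogger("ccflow").addHandler(handler)
        logging.getLogger("ccflow").info("Result", extra={"result": some_result_model})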
15 | """ 16 | 17 | def __init__(self, *args, **kwargs): 18 | self.polars_config = kwargs.pop("polars_config", {}) 19 | super().__init__(*args, **kwargs) 20 | 21 | def format(self, record: logging.LogRecord) -> str: 22 | """Formats the log record and converts Arrow Tables and Narwhals eager Dataframes to polars tables.""" 23 | if hasattr(record, "result") and isinstance(record.result, BaseModel): 24 | out = record.result.model_dump(by_alias=True) 25 | for k, v in out.items(): 26 | if isinstance(v, pa.Table): 27 | out[k] = pl.from_arrow(v) 28 | elif isinstance(v, nw.DataFrame): 29 | out[k] = v.to_polars() 30 | with pl.Config(**self.polars_config): 31 | record.msg = f"{record.msg}\n{pprint.pformat(out, width=120)}" 32 | return super().format(record) 33 | -------------------------------------------------------------------------------- /ccflow/examples/etl/__main__.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | from ccflow.utils.hydra import cfg_run 4 | 5 | __all__ = ("main",) 6 | 7 | 8 | @hydra.main(config_path="config", config_name="base", version_base=None) 9 | def main(cfg): 10 | cfg_run(cfg) 11 | 12 | 13 | # Extract step: 14 | # python -m ccflow.examples.etl +callable=extract +context=[] 15 | # Change url, as example of context override: 16 | # python -m ccflow.examples.etl +callable=extract +context=["http://lobste.rs"] 17 | # Change file name, as example of callable override: 18 | # python -m ccflow.examples.etl +callable=extract +context=["http://lobste.rs"] ++extract.publisher.name=lobsters 19 | 20 | # Transform step: 21 | # python -m ccflow.examples.etl +callable=transform +context=[] 22 | # python -m ccflow.examples.etl +callable=transform +context=[] ++transform.model.file=lobsters.html ++transform.publisher.name=lobsters 23 | 24 | # Load step: 25 | # python -m ccflow.examples.etl +callable=load +context=[] 26 | # python -m ccflow.examples.etl +callable=load +context=[] ++load.file=lobsters.csv ++load.db_file=":memory:" 27 | 28 | # View SQLite DB: 29 | # sqlite3 etl.db 30 | # .tables 31 | # select * from links; 32 | # .quit 33 | 34 | # [project.scripts] 35 | # etl = "ccflow.examples.etl:main" 36 | # etl-explain = "ccflow.examples.etl:explain" 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q22.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: 12 | q1 = ( 13 | customer_ds.with_columns(nw.col("c_phone").str.slice(0, 2).alias("cntrycode")) 14 | .filter(nw.col("cntrycode").str.contains("13|31|23|29|30|18|17")) 15 | .select("c_acctbal", "c_custkey", nw.col("cntrycode").cast(nw.Int64())) 16 | ) 17 | 18 | q2 = q1.filter(nw.col("c_acctbal") > 0.0).select( 19 | nw.col("c_acctbal").mean().alias("avg_acctbal") 20 | ) 21 | 22 | q3 = ( 23 | orders_ds.select("o_custkey") 24 | .unique("o_custkey") 25 | .with_columns(nw.col("o_custkey").alias("c_custkey")) 26 | ) 27 | 28 | return ( 29 | q1.join(q3, left_on="c_custkey", right_on="c_custkey", how="left") 30 | .filter(nw.col("o_custkey").is_null()) 31 | .join(q2, how="cross") 32 | .filter(nw.col("c_acctbal") > nw.col("avg_acctbal")) 33 | .group_by("cntrycode") 34 | .agg( 35 | 
nw.col("c_acctbal").count().alias("numcust"), 36 | nw.col("c_acctbal").sum().alias("totacctbal"), 37 | ) 38 | .sort("cntrycode") 39 | ) 40 | -------------------------------------------------------------------------------- /ccflow/global_state.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from pydantic import Field 4 | 5 | from .base import BaseModel, ModelRegistry 6 | from .callable import FlowOptionsOverride 7 | 8 | __all__ = ("GlobalState",) 9 | 10 | 11 | class GlobalState(BaseModel): 12 | """Representation of the global state of the ccflow library. 13 | 14 | Useful when running ccflow functions in remote processes, e.g. with Ray. 15 | """ 16 | 17 | root_registry: ModelRegistry = Field(default_factory=lambda: ModelRegistry.root().clone(name="_")) 18 | open_overrides: Dict[int, FlowOptionsOverride] = Field(default_factory=lambda: FlowOptionsOverride._OPEN_OVERRIDES.copy()) 19 | _old_state: Optional["GlobalState"] = None 20 | 21 | @classmethod 22 | def set(cls, state: "GlobalState"): 23 | root = ModelRegistry.root() 24 | root.clear() 25 | for name, model in state.root_registry.models.items(): 26 | root.add(name, model) 27 | 28 | FlowOptionsOverride._OPEN_OVERRIDES = state.open_overrides 29 | 30 | def __enter__(self): 31 | self._old_state = GlobalState() 32 | GlobalState.set(self) 33 | return self 34 | 35 | def __exit__(self, exc_type, exc_value, traceback): 36 | if self._old_state is not None: 37 | GlobalState.set(self._old_state) 38 | self._old_state = None 39 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q12.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query(line_item_ds: FrameT, orders_ds: FrameT) -> FrameT: 13 | var1 = "MAIL" 14 | var2 = "SHIP" 15 | var3 = datetime(1994, 1, 1) 16 | var4 = datetime(1995, 1, 1) 17 | 18 | return ( 19 | orders_ds.join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") 20 | .filter(nw.col("l_shipmode").is_in([var1, var2])) 21 | .filter(nw.col("l_commitdate") < nw.col("l_receiptdate")) 22 | .filter(nw.col("l_shipdate") < nw.col("l_commitdate")) 23 | .filter(nw.col("l_receiptdate").is_between(var3, var4, closed="left")) 24 | .with_columns( 25 | nw.when(nw.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"])) 26 | .then(1) 27 | .otherwise(0) 28 | .alias("high_line_count"), 29 | nw.when(~nw.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"])) 30 | .then(1) 31 | .otherwise(0) 32 | .alias("low_line_count"), 33 | ) 34 | .group_by("l_shipmode") 35 | .agg(nw.col("high_line_count").sum(), nw.col("low_line_count").sum()) 36 | .sort("l_shipmode") 37 | ) 38 | -------------------------------------------------------------------------------- /ccflow/utils/core.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Set, TypeVar 2 | 3 | from pydantic import BaseModel as PydanticBaseModel, ConfigDict, create_model 4 | 5 | __all__ = ( 6 | "PydanticModelType", 7 | "PydanticDictOptions", 8 | "dict_to_model", 9 | ) 10 | 11 | PydanticModelType = TypeVar("ModelType", bound=PydanticBaseModel) 12 | 13 | 14 | class PydanticDictOptions(PydanticBaseModel): 15 | """See 
https://pydantic-docs.helpmanual.io/usage/exporting_models/#modeldict""" 16 | 17 | model_config = ConfigDict( 18 | # Want to validate assignment so that if lists are assigned to include/exclude, they get validated 19 | validate_assignment=True 20 | ) 21 | 22 | include: Set[str] = None 23 | exclude: Set[str] = set() 24 | by_alias: bool = False 25 | exclude_unset: bool = False 26 | exclude_defaults: bool = False 27 | exclude_none: bool = False 28 | 29 | 30 | def dict_to_model(cls, v) -> PydanticBaseModel: 31 | """Validator to coerce dict to a pydantic base model without loss of data when no type specified. 32 | Without it, dict is coerced to PydanticBaseModel, losing all data. 33 | """ 34 | if isinstance(v, dict): 35 | config = ConfigDict(arbitrary_types_allowed=True) 36 | 37 | fields = {f: (Any, None) for f in v} 38 | v = create_model("DynamicDictModel", **fields, __config__=config)(**v) 39 | return v 40 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 | customer_ds: FrameT, 14 | line_item_ds: FrameT, 15 | orders_ds: FrameT, 16 | ) -> FrameT: 17 | var_1 = var_2 = datetime(1995, 3, 15) 18 | var_3 = "BUILDING" 19 | 20 | return ( 21 | customer_ds.filter(nw.col("c_mktsegment") == var_3) 22 | .join(orders_ds, left_on="c_custkey", right_on="o_custkey") 23 | .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") 24 | .filter( 25 | nw.col("o_orderdate") < var_2, 26 | nw.col("l_shipdate") > var_1, 27 | ) 28 | .with_columns( 29 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") 30 | ) 31 | .group_by(["o_orderkey", "o_orderdate", "o_shippriority"]) 32 | .agg([nw.sum("revenue")]) 33 | .select( 34 | [ 35 | nw.col("o_orderkey").alias("l_orderkey"), 36 | "revenue", 37 | "o_orderdate", 38 | "o_shippriority", 39 | ] 40 | ) 41 | .sort(by=["revenue", "o_orderdate"], descending=[True, False]) 42 | .head(10) 43 | ) 44 | -------------------------------------------------------------------------------- /ccflow/tests/test_evaluator.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from unittest import TestCase 3 | 4 | from ccflow import DateContext, Evaluator, ModelEvaluationContext 5 | 6 | from .evaluators.util import MyDateCallable 7 | 8 | 9 | class TestEvaluator(TestCase): 10 | def test_evaluator(self): 11 | m1 = MyDateCallable(offset=1) 12 | context = DateContext(date=date(2022, 1, 1)) 13 | model_evaluation_context = ModelEvaluationContext(model=m1, context=context) 14 | model_evaluation_context2 = ModelEvaluationContext(model=m1, context=context, fn="__call__") 15 | 16 | out = model_evaluation_context() 17 | self.assertEqual(out, m1(context)) 18 | out2 = model_evaluation_context2() 19 | self.assertEqual(out, out2) 20 | 21 | evaluator = Evaluator() 22 | out2 = evaluator(model_evaluation_context) 23 | self.assertEqual(out2, out) 24 | 25 | def test_evaluator_deps(self): 26 | m1 = MyDateCallable(offset=1) 27 | context = DateContext(date=date(2022, 1, 1)) 28 | model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__deps__") 29 | out = model_evaluation_context() 30 | self.assertEqual(out, 
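            # Evaluating the context with fn="__deps__" should be equivalent to
            # calling __deps__ on the model directly: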
m1.__deps__(context)) 31 | 32 | evaluator = Evaluator() 33 | out2 = evaluator.__deps__(model_evaluation_context) 34 | self.assertEqual(out2, out) 35 | -------------------------------------------------------------------------------- /ccflow/result/narwhals.py: -------------------------------------------------------------------------------- 1 | import narwhals.stable.v1 as nw 2 | from pydantic import Field, model_validator 3 | 4 | from ..base import ResultBase 5 | from ..exttypes.narwhals import DataFrameT, FrameT 6 | 7 | __all__ = ( 8 | "NarwhalsFrameResult", 9 | "NarwhalsDataFrameResult", 10 | ) 11 | 12 | 13 | class NarwhalsFrameResult(ResultBase): 14 | """Result that holds a Narwhals DataFrame or LazyFrame.""" 15 | 16 | df: FrameT = Field(description="Narwhals DataFrame or LazyFrame") 17 | 18 | def collect(self) -> "NarwhalsDataFrameResult": 19 | """Collects the result into a NarwhalsDataFrameResult.""" 20 | if isinstance(self.df, nw.LazyFrame): 21 | return NarwhalsDataFrameResult(df=self.df.collect(), **self.model_dump(exclude={"df", "type_"})) 22 | return NarwhalsDataFrameResult(df=self.df, **self.model_dump(exclude={"df", "type_"})) 23 | 24 | @model_validator(mode="wrap") 25 | def _validate(cls, v, handler, info): 26 | if not isinstance(v, NarwhalsFrameResult) and not (isinstance(v, dict) and "df" in v): 27 | v = {"df": v} 28 | return handler(v) 29 | 30 | 31 | class NarwhalsDataFrameResult(NarwhalsFrameResult): 32 | df: DataFrameT = Field(description="Narwhals eager DataFrame") 33 | 34 | def collect(self) -> "NarwhalsDataFrameResult": 35 | """Collects the result into a NarwhalsDataFrameResult.""" 36 | return self 37 | -------------------------------------------------------------------------------- /ccflow/exttypes/jinja.py: -------------------------------------------------------------------------------- 1 | """This module contains extension types for pydantic.""" 2 | 3 | from typing import Any 4 | 5 | import jinja2 6 | from pydantic import TypeAdapter 7 | from pydantic_core import core_schema 8 | from typing_extensions import Self 9 | 10 | 11 | class JinjaTemplate(str): 12 | """String that is validated as a jinja2 template.""" 13 | 14 | @property 15 | def template(self) -> jinja2.Template: 16 | """Return the jinja2.Template constructed from this string.""" 17 | return jinja2.Template(str(self)) 18 | 19 | @classmethod 20 | def __get_pydantic_core_schema__(cls, source_type, handler): 21 | return core_schema.no_info_plain_validator_function(cls._validate) 22 | 23 | @classmethod 24 | def _validate(cls, value: Any) -> Self: 25 | if isinstance(value, JinjaTemplate): 26 | return value 27 | 28 | if isinstance(value, str): 29 | value = cls(value) 30 | try: 31 | value.template 32 | except Exception as e: 33 | raise ValueError(f"ensure this value contains a valid Jinja2 template string: {e}") 34 | 35 | return value 36 | 37 | @classmethod 38 | def validate(cls, value: Any) -> Self: 39 | """Try to convert/validate an arbitrary value to a JinjaTemplate.""" 40 | return _TYPE_ADAPTER.validate_python(value) 41 | 42 | 43 | _TYPE_ADAPTER = TypeAdapter(JinjaTemplate) 44 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q5.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 |
region_ds: FrameT, 14 | nation_ds: FrameT, 15 | customer_ds: FrameT, 16 | line_item_ds: FrameT, 17 | orders_ds: FrameT, 18 | supplier_ds: FrameT, 19 | ) -> FrameT: 20 | var_1 = "ASIA" 21 | var_2 = datetime(1994, 1, 1) 22 | var_3 = datetime(1995, 1, 1) 23 | 24 | return ( 25 | region_ds.join(nation_ds, left_on="r_regionkey", right_on="n_regionkey") 26 | .join(customer_ds, left_on="n_nationkey", right_on="c_nationkey") 27 | .join(orders_ds, left_on="c_custkey", right_on="o_custkey") 28 | .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") 29 | .join( 30 | supplier_ds, 31 | left_on=["l_suppkey", "n_nationkey"], 32 | right_on=["s_suppkey", "s_nationkey"], 33 | ) 34 | .filter( 35 | nw.col("r_name") == var_1, 36 | nw.col("o_orderdate").is_between(var_2, var_3, closed="left"), 37 | ) 38 | .with_columns( 39 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") 40 | ) 41 | .group_by("n_name") 42 | .agg([nw.sum("revenue")]) 43 | .sort(by="revenue", descending=True) 44 | ) 45 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | 15 | 16 | **Describe the bug** 17 | 18 | 19 | **To Reproduce** 20 | 30 | 31 | 32 | 33 | **Expected behavior** 34 | 35 | 36 | **Error Message** 37 | 39 | 40 | **Runtime Environment** 41 | 45 | 46 | **Additional context** 47 | 50 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q9.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query( 12 | part_ds: FrameT, 13 | partsupp_ds: FrameT, 14 | nation_ds: FrameT, 15 | lineitem_ds: FrameT, 16 | orders_ds: FrameT, 17 | supplier_ds: FrameT, 18 | ) -> FrameT: 19 | return ( 20 | part_ds.join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") 21 | .join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") 22 | .join( 23 | lineitem_ds, 24 | left_on=["p_partkey", "ps_suppkey"], 25 | right_on=["l_partkey", "l_suppkey"], 26 | ) 27 | .join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") 28 | .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") 29 | .filter(nw.col("p_name").str.contains("green")) 30 | .select( 31 | nw.col("n_name").alias("nation"), 32 | nw.col("o_orderdate").dt.year().alias("o_year"), 33 | ( 34 | nw.col("l_extendedprice") * (1 - nw.col("l_discount")) 35 | - nw.col("ps_supplycost") * nw.col("l_quantity") 36 | ).alias("amount"), 37 | ) 38 | .group_by("nation", "o_year") 39 | .agg(nw.sum("amount").alias("sum_profit")) 40 | .sort(by=["nation", "o_year"], descending=[False, True]) 41 | ) 42 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q21.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query( 12 | lineitem: FrameT, 13 | nation: FrameT, 14 | orders: FrameT, 15 | supplier: FrameT, 16 | ) -> FrameT: 17 | var1 
= "SAUDI ARABIA" 18 | 19 | q1 = ( 20 | lineitem.group_by("l_orderkey") 21 | .agg(nw.len().alias("n_supp_by_order")) 22 | .filter(nw.col("n_supp_by_order") > 1) 23 | .join( 24 | lineitem.filter(nw.col("l_receiptdate") > nw.col("l_commitdate")), 25 | left_on="l_orderkey", 26 | right_on="l_orderkey", 27 | ) 28 | ) 29 | 30 | return ( 31 | q1.group_by("l_orderkey") 32 | .agg(nw.len().alias("n_supp_by_order")) 33 | .join( 34 | q1, 35 | left_on="l_orderkey", 36 | right_on="l_orderkey", 37 | ) 38 | .join(supplier, left_on="l_suppkey", right_on="s_suppkey") 39 | .join(nation, left_on="s_nationkey", right_on="n_nationkey") 40 | .join(orders, left_on="l_orderkey", right_on="o_orderkey") 41 | .filter(nw.col("n_supp_by_order") == 1) 42 | .filter(nw.col("n_name") == var1) 43 | .filter(nw.col("o_orderstatus") == "F") 44 | .group_by("s_name") 45 | .agg(nw.len().alias("numwait")) 46 | .sort(by=["numwait", "s_name"], descending=[True, False]) 47 | .head(100) 48 | ) 49 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q10.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 | customer_ds: FrameT, 14 | nation_ds: FrameT, 15 | lineitem_ds: FrameT, 16 | orders_ds: FrameT, 17 | ) -> FrameT: 18 | var1 = datetime(1993, 10, 1) 19 | var2 = datetime(1994, 1, 1) 20 | 21 | return ( 22 | customer_ds.join(orders_ds, left_on="c_custkey", right_on="o_custkey") 23 | .join(lineitem_ds, left_on="o_orderkey", right_on="l_orderkey") 24 | .join(nation_ds, left_on="c_nationkey", right_on="n_nationkey") 25 | .filter(nw.col("o_orderdate").is_between(var1, var2, closed="left")) 26 | .filter(nw.col("l_returnflag") == "R") 27 | .with_columns( 28 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") 29 | ) 30 | .group_by( 31 | "c_custkey", 32 | "c_name", 33 | "c_acctbal", 34 | "c_phone", 35 | "n_name", 36 | "c_address", 37 | "c_comment", 38 | ) 39 | .agg(nw.sum("revenue")) 40 | .select( 41 | "c_custkey", 42 | "c_name", 43 | "revenue", 44 | "c_acctbal", 45 | "n_name", 46 | "c_address", 47 | "c_phone", 48 | "c_comment", 49 | ) 50 | .sort(by="revenue", descending=True) 51 | .head(20) 52 | ) 53 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q20.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 | part_ds: FrameT, 14 | partsupp_ds: FrameT, 15 | nation_ds: FrameT, 16 | lineitem_ds: FrameT, 17 | supplier_ds: FrameT, 18 | ) -> FrameT: 19 | var1 = datetime(1994, 1, 1) 20 | var2 = datetime(1995, 1, 1) 21 | var3 = "CANADA" 22 | var4 = "forest" 23 | 24 | query1 = ( 25 | lineitem_ds.filter(nw.col("l_shipdate").is_between(var1, var2, closed="left")) 26 | .group_by("l_partkey", "l_suppkey") 27 | .agg((nw.col("l_quantity").sum()).alias("sum_quantity")) 28 | .with_columns(sum_quantity=nw.col("sum_quantity") * 0.5) 29 | ) 30 | query2 = nation_ds.filter(nw.col("n_name") == var3) 31 | query3 = supplier_ds.join(query2, left_on="s_nationkey", right_on="n_nationkey") 32 | 33 | 
return ( 34 | part_ds.filter(nw.col("p_name").str.starts_with(var4)) 35 | .select("p_partkey") 36 | .unique("p_partkey") 37 | .join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") 38 | .join( 39 | query1, 40 | left_on=["ps_suppkey", "p_partkey"], 41 | right_on=["l_suppkey", "l_partkey"], 42 | ) 43 | .filter(nw.col("ps_availqty") > nw.col("sum_quantity")) 44 | .select("ps_suppkey") 45 | .unique("ps_suppkey") 46 | .join(query3, left_on="ps_suppkey", right_on="s_suppkey") 47 | .select("s_name", "s_address") 48 | .sort("s_name") 49 | ) 50 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q19.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query(lineitem_ds: FrameT, part_ds: FrameT) -> FrameT: 12 | return ( 13 | part_ds.join(lineitem_ds, left_on="p_partkey", right_on="l_partkey") 14 | .filter(nw.col("l_shipmode").is_in(["AIR", "AIR REG"])) 15 | .filter(nw.col("l_shipinstruct") == "DELIVER IN PERSON") 16 | .filter( 17 | ( 18 | (nw.col("p_brand") == "Brand#12") 19 | & nw.col("p_container").is_in(["SM CASE", "SM BOX", "SM PACK", "SM PKG"]) 20 | & (nw.col("l_quantity").is_between(1, 11)) 21 | & (nw.col("p_size").is_between(1, 5)) 22 | ) 23 | | ( 24 | (nw.col("p_brand") == "Brand#23") 25 | & nw.col("p_container").is_in( 26 | ["MED BAG", "MED BOX", "MED PKG", "MED PACK"] 27 | ) 28 | & (nw.col("l_quantity").is_between(10, 20)) 29 | & (nw.col("p_size").is_between(1, 10)) 30 | ) 31 | | ( 32 | (nw.col("p_brand") == "Brand#34") 33 | & nw.col("p_container").is_in(["LG CASE", "LG BOX", "LG PACK", "LG PKG"]) 34 | & (nw.col("l_quantity").is_between(20, 30)) 35 | & (nw.col("p_size").is_between(1, 15)) 36 | ) 37 | ) 38 | .select( 39 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) 40 | .sum() 41 | .round(2) 42 | .alias("revenue") 43 | ) 44 | ) 45 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import narwhals as nw 6 | 7 | if TYPE_CHECKING: 8 | from narwhals.typing import FrameT 9 | 10 | 11 | def query( 12 | region_ds: FrameT, 13 | nation_ds: FrameT, 14 | supplier_ds: FrameT, 15 | part_ds: FrameT, 16 | part_supp_ds: FrameT, 17 | ) -> FrameT: 18 | var_1 = 15 19 | var_2 = "BRASS" 20 | var_3 = "EUROPE" 21 | 22 | result_q2 = ( 23 | part_ds.join(part_supp_ds, left_on="p_partkey", right_on="ps_partkey") 24 | .join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") 25 | .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") 26 | .join(region_ds, left_on="n_regionkey", right_on="r_regionkey") 27 | .filter( 28 | nw.col("p_size") == var_1, 29 | nw.col("p_type").str.ends_with(var_2), 30 | nw.col("r_name") == var_3, 31 | ) 32 | ) 33 | 34 | final_cols = [ 35 | "s_acctbal", 36 | "s_name", 37 | "n_name", 38 | "p_partkey", 39 | "p_mfgr", 40 | "s_address", 41 | "s_phone", 42 | "s_comment", 43 | ] 44 | 45 | return ( 46 | result_q2.group_by("p_partkey") 47 | .agg(nw.col("ps_supplycost").min().alias("ps_supplycost")) 48 | .join( 49 | result_q2, 50 | left_on=["p_partkey", "ps_supplycost"], 51 | right_on=["p_partkey", "ps_supplycost"], 52 | ) 53 | .select(final_cols) 54 | 
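        # TPC-H Q2 returns the top suppliers ordered by account balance
        # (descending), then nation, supplier name and part key: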
.sort( 55 | ["s_acctbal", "n_name", "s_name", "p_partkey"], 56 | descending=[True, False, False, False], 57 | ) 58 | .head(100) 59 | ) 60 | -------------------------------------------------------------------------------- /ccflow/tests/examples/test_tpch.py: -------------------------------------------------------------------------------- 1 | from typing import get_args 2 | 3 | import pytest 4 | from polars.testing import assert_frame_equal 5 | 6 | from ccflow.examples.tpch import TPCHAnswerGenerator, TPCHDataGenerator, TPCHQueryContext, TPCHQueryRunner, TPCHTable, TPCHTableContext 7 | 8 | 9 | @pytest.fixture(scope="module") 10 | def scale_factor(): 11 | return 0.1 12 | 13 | 14 | @pytest.fixture(scope="module") 15 | def tpch_answer_generator(scale_factor): 16 | return TPCHAnswerGenerator(scale_factor=scale_factor) 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def tpch_data_generator(scale_factor): 21 | return TPCHDataGenerator(scale_factor=scale_factor) 22 | 23 | 24 | @pytest.mark.parametrize("query_id", range(1, 23)) 25 | def test_tpch_answer_generation(tpch_answer_generator, query_id): 26 | context = TPCHQueryContext(query_id=query_id) 27 | out = tpch_answer_generator(context) 28 | assert out is not None 29 | assert len(out.df) > 0 30 | 31 | 32 | @pytest.mark.parametrize("table", get_args(TPCHTable)) 33 | def test_tpch_data_generation(tpch_data_generator, table): 34 | context = TPCHTableContext(table=table) 35 | out = tpch_data_generator(context) 36 | assert out is not None 37 | assert len(out.df) > 0 38 | 39 | 40 | @pytest.mark.parametrize("query_id", range(1, 23)) 41 | def test_tpch_queries(tpch_answer_generator, tpch_data_generator, query_id): 42 | runner = TPCHQueryRunner(table_provider=tpch_data_generator) 43 | context = TPCHQueryContext(query_id=query_id) 44 | answer = tpch_answer_generator(context) 45 | out = runner(context) 46 | assert out is not None 47 | assert answer is not None 48 | assert_frame_equal(out.df.to_polars(), answer.df.to_polars(), check_dtypes=False) 49 | -------------------------------------------------------------------------------- /ccflow/tests/test_lazy_result.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar 2 | 3 | from pydantic import ConfigDict, PrivateAttr, model_validator 4 | 5 | from ccflow.base import ResultBase, make_lazy_result 6 | 7 | 8 | class MyResult(ResultBase): 9 | total: ClassVar[int] = 0 # To track instantiations 10 | value: bool = False 11 | model_config = ConfigDict(extra="allow") 12 | _private: str = PrivateAttr(default="bar") 13 | 14 | @model_validator(mode="after") 15 | def _validate(self): 16 | # Track construction by incrementing the total each time the validation is called 17 | MyResult.total += 1 18 | return self 19 | 20 | 21 | def test_make_lazy_result(): 22 | assert MyResult.total == 0 23 | result = MyResult() 24 | assert MyResult.total == 1 25 | assert not result.value 26 | 27 | lazy_result = make_lazy_result(MyResult, lambda: MyResult(value=True)) 28 | assert isinstance(lazy_result, MyResult) 29 | assert MyResult.total == 1 # Constructing the lazy result did not increment the total 30 | assert lazy_result.value 31 | assert MyResult.total == 2 # Accessing the value in the line above did increment the total 32 | assert lazy_result == MyResult(value=True) 33 | 34 | result = MyResult(value=True, extra_field="foo") 35 | assert "value" in result.__pydantic_fields_set__ 36 | assert result.__pydantic_extra__["extra_field"] == "foo" 37 | assert 
result.__pydantic_private__["_private"] == "bar" 38 | 39 | lazy_result = make_lazy_result(MyResult, lambda: MyResult(value=True, extra_field="foo")) 40 | assert "value" in lazy_result.__pydantic_fields_set__ 41 | assert lazy_result.__pydantic_extra__["extra_field"] == "foo" 42 | assert lazy_result.__pydantic_private__["_private"] == "bar" 43 | -------------------------------------------------------------------------------- /ccflow/tests/utils/test_formatter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from io import StringIO 3 | from unittest import TestCase 4 | 5 | import narwhals as nw 6 | import pyarrow as pa 7 | 8 | from ccflow.result.narwhals import NarwhalsDataFrameResult 9 | from ccflow.result.pyarrow import ArrowResult 10 | from ccflow.utils.formatter import PolarsTableFormatter 11 | 12 | 13 | class TestPolarsTableFormatter(TestCase): 14 | """Test the PolarsTableFormatter.""" 15 | 16 | def setUp(self): 17 | # Set up a logger with a StringIO stream to capture output 18 | self.logger = logging.getLogger("test_logger") 19 | self.logger.setLevel(logging.DEBUG) 20 | 21 | self.log_stream = StringIO() 22 | handler = logging.StreamHandler(self.log_stream) 23 | 24 | # Use the custom formatter 25 | formatter = PolarsTableFormatter("%(name)s - %(levelname)s - %(message)s", polars_config={"tbl_rows": 5}) 26 | handler.setFormatter(formatter) 27 | 28 | self.logger.addHandler(handler) 29 | 30 | def tearDown(self): 31 | # Remove handlers after each test 32 | self.logger.handlers.clear() 33 | 34 | def test_arrow(self): 35 | table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) 36 | self.logger.debug("Result:", extra={"result": ArrowResult(table=table)}) 37 | log = self.log_stream.getvalue().strip() 38 | assert log.startswith("test_logger - DEBUG - Result:\n{'_target_': 'ccflow.result.pyarrow.ArrowResult',\n 'table': shape: (3, 2)\n") 39 | 40 | def test_narwhals(self): 41 | df = nw.from_dict({"a": [1, 2, 3], "b": ["x", "y", "z"]}, backend="polars") 42 | self.logger.debug("Result:", extra={"result": NarwhalsDataFrameResult(df=df)}) 43 | log = self.log_stream.getvalue().strip() 44 | assert log.startswith("test_logger - DEBUG - Result:\n{'_target_': 'ccflow.result.narwhals.NarwhalsDataFrameResult',\n 'df': shape: (3, 2)\n") 45 | -------------------------------------------------------------------------------- /ccflow/publisher.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Dict, TypeVar 3 | 4 | from pydantic import ConfigDict, Field 5 | from typing_extensions import override 6 | 7 | from .base import BaseModel 8 | from .exttypes import JinjaTemplate 9 | 10 | __all__ = ("BasePublisher", "NullPublisher", "PublisherType") 11 | 12 | 13 | class BasePublisher(BaseModel, abc.ABC): 14 | """A publisher is a configurable object (flow base model) that knows how to "publish" typed python objects. 15 | 16 | We use pydantic's type declarations to define the "type" of data that the publisher knows how to publish. 17 | The naming convention for publishers is WhatWherePublisher or just WherePublisher if Any type is supported. 
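    A minimal sketch of a concrete publisher (hypothetical, for illustration
    only; real publishers typically narrow the type of the "data" field):

        class StdoutPublisher(BasePublisher):
            def __call__(self) -> Any:
                print(f"{self.get_name()}: {self.data}")
                return self.data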
18 | """ 19 | 20 | model_config = ConfigDict( 21 | # Want to validate assignment so that when new data is set on a publisher, validation gets applied 22 | validate_assignment=True, 23 | # Many publishers will require arbitrary types set on data 24 | arbitrary_types_allowed=True, 25 | ) 26 | name: JinjaTemplate = Field(None, description="The 'name' by which to publish that data element") 27 | name_params: Dict[str, Any] = Field(default_factory=dict, description="The parameters for the name template") 28 | data: Any = Field(None, description="The data we are going to publish") 29 | 30 | def get_name(self): 31 | """Get the name with the template parameters filled in.""" 32 | if not self.name: 33 | raise ValueError("Name must be set") 34 | return self.name.template.render(**self.name_params) 35 | 36 | @abc.abstractmethod 37 | def __call__(self) -> Any: 38 | """Publish the data.""" 39 | 40 | 41 | PublisherType = TypeVar("CallableModelType", bound=BasePublisher) 42 | 43 | 44 | class NullPublisher(BasePublisher): 45 | """A publisher which does nothing!""" 46 | 47 | @override 48 | def __call__(self) -> Any: 49 | pass 50 | -------------------------------------------------------------------------------- /ccflow/exttypes/exprtk.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pydantic import TypeAdapter 4 | from pydantic_core import core_schema 5 | from typing_extensions import Self 6 | 7 | __all__ = ("ExprTkExpression",) 8 | 9 | 10 | def _import_cexprtk(): 11 | try: 12 | import cexprtk 13 | 14 | return cexprtk 15 | except ImportError: 16 | raise ValueError("Unable to import cexprtk. Please make sure you have it installed.") 17 | 18 | 19 | class ExprTkExpression(str): 20 | """Wrapper around a string that represents an ExprTk expression.""" 21 | 22 | @classmethod 23 | def __get_pydantic_core_schema__(cls, source_type, handler): 24 | return core_schema.no_info_plain_validator_function(cls._validate) 25 | 26 | @classmethod 27 | def _validate(cls, value: Any) -> Self: 28 | if isinstance(value, ExprTkExpression): 29 | return value 30 | 31 | if isinstance(value, str): 32 | cexprtk = _import_cexprtk() 33 | try: 34 | cexprtk.check_expression(value) 35 | except cexprtk.ParseException as e: 36 | raise ValueError(f"Error parsing expression {value}. {e}") 37 | 38 | return cls(value) 39 | 40 | raise ValueError(f"{value} cannot be converted into an ExprTkExpression.") 41 | 42 | @classmethod 43 | def validate(cls, value: Any) -> Self: 44 | """Try to convert/validate an arbitrary value to a Frequency.""" 45 | return _TYPE_ADAPTER.validate_python(value) 46 | 47 | def expression(self, symbol_table: Any) -> Any: 48 | """Make a cexprtk.Expression from a symbol table. 49 | 50 | Args: 51 | symbol_table: cexprtk.Symbol_Table 52 | 53 | Returns: 54 | An cexprtk.Expression. 
55 | """ 56 | cexprtk = _import_cexprtk() 57 | return cexprtk.Expression(str(self), symbol_table) 58 | 59 | 60 | _TYPE_ADAPTER = TypeAdapter(ExprTkExpression) 61 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build Status 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - v* 9 | paths-ignore: 10 | - LICENSE 11 | - README.md 12 | pull_request: 13 | workflow_dispatch: 14 | 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 17 | cancel-in-progress: true 18 | 19 | permissions: 20 | checks: write 21 | contents: read 22 | pull-requests: write 23 | 24 | jobs: 25 | build: 26 | runs-on: ${{ matrix.os }} 27 | 28 | strategy: 29 | matrix: 30 | os: [ubuntu-latest, macos-latest] 31 | python-version: ['3.10'] 32 | 33 | steps: 34 | - uses: actions/checkout@v6 35 | 36 | - uses: actions-ext/python/setup@main 37 | with: 38 | version: ${{ matrix.python-version }} 39 | 40 | - name: Install dependencies 41 | run: make develop 42 | 43 | - name: Lint 44 | run: make lint 45 | 46 | - name: Checks 47 | run: make checks 48 | 49 | - name: Build 50 | run: make build 51 | 52 | - name: Test 53 | run: make coverage 54 | 55 | - name: Upload test results (Python) 56 | uses: actions/upload-artifact@v6 57 | with: 58 | name: py-test-results-${{ matrix.os }}-${{ matrix.python-version }}- 59 | path: junit.xml 60 | if: ${{ matrix.os == 'ubuntu-latest' }} 61 | 62 | - name: Publish Unit Test Results 63 | uses: EnricoMi/publish-unit-test-result-action@v2 64 | with: 65 | files: '**/junit.xml' 66 | if: ${{ matrix.os == 'ubuntu-latest' }} 67 | 68 | - name: Upload coverage 69 | uses: codecov/codecov-action@v5 70 | with: 71 | token: ${{ secrets.CODECOV_TOKEN }} 72 | if: matrix.os == 'ubuntu-latest' 73 | 74 | - name: Twine check 75 | run: make dist 76 | 77 | - uses: actions/upload-artifact@v6 78 | with: 79 | name: dist-${{matrix.os}} 80 | path: dist 81 | if: matrix.os == 'ubuntu-latest' 82 | -------------------------------------------------------------------------------- /docs/wiki/First-Steps.md: -------------------------------------------------------------------------------- 1 | # First Steps 2 | 3 | This short example shows some of the key features of the configuration framework in `ccflow`: 4 | 5 | ```python 6 | from ccflow import BaseModel, ModelRegistry 7 | 8 | # Define config objects 9 | class MyFileConfig(BaseModel): 10 | file: str 11 | description: str = "" 12 | 13 | class MyTransformConfig(BaseModel): 14 | x: MyFileConfig 15 | y: MyFileConfig = None 16 | param: float = 0. 
17 | 18 | 19 | # Define json configs 20 | configs = { 21 | "data": { 22 | "source1": { 23 | "_target_": "__main__.MyFileConfig", 24 | "file": "source1.csv", 25 | "description": "First", 26 | }, 27 | "source2": { 28 | "_target_": "__main__.MyFileConfig", 29 | "file": "source2.csv", 30 | "description": "Second", 31 | }, 32 | "source3": { 33 | "_target_": "__main__.MyFileConfig", 34 | "file": "source3.csv", 35 | "description": "Third", 36 | }, 37 | }, 38 | "transform": { 39 | "_target_": "__main__.MyTransformConfig", 40 | "x": "data/source1", 41 | "y": "data/source2", 42 | }, 43 | } 44 | 45 | # Register configs 46 | root = ModelRegistry.root().clear() 47 | root.load_config(configs) 48 | 49 | # List the keys in the registry 50 | print(list(root)) 51 | #> ['data', 'data/source1', 'data/source2', 'data/source3', 'transform'] 52 | 53 | # Access configs from the registry 54 | print(root["transform"]) 55 | #> MyTransformConfig( 56 | # x=MyFileConfig(file='source1.csv', description='First'), 57 | # y=MyFileConfig(file='source2.csv', description='Second'), 58 | # param=0) 59 | 60 | # Assign config objects by name 61 | root["transform"].x = "data/source3" 62 | print(root["transform"].x) 63 | #> MyFileConfig(file='source3.csv', description='Third') 64 | 65 | # Propagate low-level changes to the top 66 | root["data/source3"].file = "source3_amended.csv" 67 | # See that it changes in the "transform" definition 68 | print(root["transform"].x.file) 69 | #> source3_amended.csv 70 | ``` 71 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q8.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import date 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 | part_ds: FrameT, 14 | supplier_ds: FrameT, 15 | line_item_ds: FrameT, 16 | orders_ds: FrameT, 17 | customer_ds: FrameT, 18 | nation_ds: FrameT, 19 | region_ds: FrameT, 20 | ) -> FrameT: 21 | nation = "BRAZIL" 22 | region = "AMERICA" 23 | type = "ECONOMY ANODIZED STEEL" 24 | date1 = date(1995, 1, 1) 25 | date2 = date(1996, 12, 31) 26 | 27 | n1 = nation_ds.select("n_nationkey", "n_regionkey") 28 | n2 = nation_ds.select("n_nationkey", "n_name") 29 | 30 | return ( 31 | part_ds.join(line_item_ds, left_on="p_partkey", right_on="l_partkey") 32 | .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") 33 | .join(orders_ds, left_on="l_orderkey", right_on="o_orderkey") 34 | .join(customer_ds, left_on="o_custkey", right_on="c_custkey") 35 | .join(n1, left_on="c_nationkey", right_on="n_nationkey") 36 | .join(region_ds, left_on="n_regionkey", right_on="r_regionkey") 37 | .filter(nw.col("r_name") == region) 38 | .join(n2, left_on="s_nationkey", right_on="n_nationkey") 39 | .filter(nw.col("o_orderdate").is_between(date1, date2)) 40 | .filter(nw.col("p_type") == type) 41 | .select( 42 | nw.col("o_orderdate").dt.year().alias("o_year"), 43 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("volume"), 44 | nw.col("n_name").alias("nation"), 45 | ) 46 | .with_columns( 47 | nw.when(nw.col("nation") == nation) 48 | .then(nw.col("volume")) 49 | .otherwise(0) 50 | .alias("_tmp") 51 | ) 52 | .group_by("o_year") 53 | .agg(_tmp_sum=nw.sum("_tmp"), volume_sum=nw.sum("volume")) 54 | .select("o_year", mkt_share=nw.col("_tmp_sum") / nw.col("volume_sum")) 55 | .sort("o_year") 56 | ) 57 | 
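
# The final projection above computes each year's market share for the chosen
# nation: its discounted revenue (_tmp_sum) divided by the total discounted
# revenue across the region (volume_sum).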
-------------------------------------------------------------------------------- /ccflow/examples/tpch/queries/q7.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING 5 | 6 | import narwhals as nw 7 | 8 | if TYPE_CHECKING: 9 | from narwhals.typing import FrameT 10 | 11 | 12 | def query( 13 | nation_ds: FrameT, 14 | customer_ds: FrameT, 15 | line_item_ds: FrameT, 16 | orders_ds: FrameT, 17 | supplier_ds: FrameT, 18 | ) -> FrameT: 19 | n1 = nation_ds.filter(nw.col("n_name") == "FRANCE") 20 | n2 = nation_ds.filter(nw.col("n_name") == "GERMANY") 21 | 22 | var_1 = datetime(1995, 1, 1) 23 | var_2 = datetime(1996, 12, 31) 24 | 25 | df1 = ( 26 | customer_ds.join(n1, left_on="c_nationkey", right_on="n_nationkey") 27 | .join(orders_ds, left_on="c_custkey", right_on="o_custkey") 28 | .rename({"n_name": "cust_nation"}) 29 | .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") 30 | .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") 31 | .join(n2, left_on="s_nationkey", right_on="n_nationkey") 32 | .rename({"n_name": "supp_nation"}) 33 | ) 34 | 35 | df2 = ( 36 | customer_ds.join(n2, left_on="c_nationkey", right_on="n_nationkey") 37 | .join(orders_ds, left_on="c_custkey", right_on="o_custkey") 38 | .rename({"n_name": "cust_nation"}) 39 | .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") 40 | .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey") 41 | .join(n1, left_on="s_nationkey", right_on="n_nationkey") 42 | .rename({"n_name": "supp_nation"}) 43 | ) 44 | 45 | return ( 46 | nw.concat([df1, df2]) 47 | .filter(nw.col("l_shipdate").is_between(var_1, var_2)) 48 | .with_columns( 49 | (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("volume") 50 | ) 51 | .with_columns(nw.col("l_shipdate").dt.year().alias("l_year")) 52 | .group_by("supp_nation", "cust_nation", "l_year") 53 | .agg(nw.sum("volume").alias("revenue")) 54 | .sort(by=["supp_nation", "cust_nation", "l_year"]) 55 | ) 56 | -------------------------------------------------------------------------------- /ccflow/tests/examples/test_etl.py: -------------------------------------------------------------------------------- 1 | from tempfile import NamedTemporaryFile 2 | from unittest.mock import patch 3 | 4 | from ccflow.examples.etl.__main__ import main 5 | from ccflow.examples.etl.explain import explain 6 | from ccflow.examples.etl.models import DBModel, LinksModel, RestModel, SiteContext 7 | 8 | 9 | class TestEtl: 10 | def test_rest_model(self): 11 | rest = RestModel() 12 | context = SiteContext(site="https://news.ycombinator.com") 13 | result = rest(context) 14 | assert result.value is not None 15 | assert "Hacker News" in result.value 16 | 17 | def test_links_model(self): 18 | with NamedTemporaryFile(suffix=".html") as file: 19 | file.write(b""" 20 | 21 |
22 | <a href="https://example.com/page1">Page 1</a> 23 | <a href="https://example.com/page2">Page 2</a> 24 | 25 | 26 | """) 27 | file.flush() 28 | links = LinksModel(file=file.name) 29 | result = links() 30 | assert result.value is not None 31 | assert "name,url" in result.value # Check for CSV header 32 | 33 | def test_db_model(self): 34 | with NamedTemporaryFile(suffix=".csv", mode="w+", delete=False) as file: 35 | file.write("name,url\nPage 1,https://example.com/page1\nPage 2,https://example.com/page2\n") 36 | file.flush() 37 | db = DBModel(file=file.name, db_file=":memory:", table="links") 38 | result = db() 39 | assert result.value == "Data loaded into database" 40 | 41 | def test_cli(self): 42 | with patch("ccflow.examples.etl.__main__.cfg_run") as mock_cfg_run: 43 | with patch("sys.argv", ["etl", "+callable=extract", "+context=[]"]): 44 | main() 45 | mock_cfg_run.assert_called_once() 46 | 47 | def test_explain(self): 48 | with patch("ccflow.examples.etl.explain.cfg_explain_cli") as mock_cfg_explain: 49 | explain() 50 | mock_cfg_explain.assert_called_once() 51 | -------------------------------------------------------------------------------- /ccflow/tests/exttypes/test_exprtk.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | from ccflow import BaseModel, ExprTkExpression 5 | 6 | 7 | def _has_cexprtk() -> bool: 8 | try: 9 | import cexprtk # noqa: F401 10 | 11 | return True 12 | except ImportError: 13 | return False 14 | 15 | 16 | class MyModel(BaseModel): 17 | expression: ExprTkExpression 18 | 19 | 20 | class TestExprTkExpression(TestCase): 21 | @unittest.skipIf(_has_cexprtk(), "Requires cexprtk to not be installed") 22 | def test_no_cexprtk(self): 23 | self.assertRaisesRegex(ValueError, "Unable to import cexprtk. Please make sure you have it installed.", ExprTkExpression.validate, "1.0") 24 | 25 | @unittest.skipIf(not _has_cexprtk(), "Requires cexprtk to be installed") 26 | def test(self): 27 | import cexprtk 28 | 29 | symbol_table = cexprtk.Symbol_Table({"a": 1.0, "b": 2.0}) 30 | 31 | # Constant 32 | e = ExprTkExpression.validate("1.0") 33 | self.assertAlmostEqual(1.0, e.expression(symbol_table)()) 34 | 35 | # Valid 36 | e = ExprTkExpression.validate("1.0 + a * b") 37 | self.assertAlmostEqual(3.0, e.expression(symbol_table)()) 38 | 39 | # Valid 40 | e = ExprTkExpression.validate("-a * b") 41 | self.assertAlmostEqual(-2.0, e.expression(symbol_table)()) 42 | 43 | # Invalid 44 | self.assertRaisesRegex(ValueError, "Error parsing expression.*", ExprTkExpression.validate, "1a++") 45 | 46 | # Wrong types 47 | self.assertRaisesRegex(ValueError, ".*cannot be converted.*", ExprTkExpression.validate, None) 48 | 49 | def test_model(self): 50 | expression = "1.0 + a" 51 | if _has_cexprtk(): 52 | import cexprtk 53 | 54 | m = MyModel(expression=expression) 55 | symbol_table = cexprtk.Symbol_Table({"a": 1.0}) 56 | self.assertAlmostEqual(2.0, m.expression.expression(symbol_table)()) 57 | else: 58 | self.assertRaisesRegex(ValueError, "Unable to import cexprtk.
Please make sure you have it installed.", MyModel, expression=expression) 59 | -------------------------------------------------------------------------------- /ccflow/exttypes/frequency.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from datetime import timedelta 3 | from functools import cached_property 4 | 5 | 6 | import pandas as pd 7 | from pandas.tseries.frequencies import to_offset 8 | from pydantic import TypeAdapter 9 | from pydantic_core import core_schema 10 | 11 | 12 | class Frequency(str): 13 | """Represents a frequency string that can be converted to a pandas offset.""" 14 | 15 | validate_always = True 16 | 17 | @cached_property 18 | def offset(self) -> pd.offsets.DateOffset: 19 | """Return the underlying pandas DateOffset object.""" 20 | return to_offset(str(self)) 21 | 22 | @cached_property 23 | def timedelta(self) -> timedelta: 24 | return pd.to_timedelta(self.offset).to_pytimedelta() 25 | 26 | @classmethod 27 | def __get_pydantic_core_schema__(cls, source_type, handler): 28 | return core_schema.no_info_plain_validator_function(cls._validate) 29 | 30 | @classmethod 31 | def _validate(cls, value) -> "Frequency": 32 | if isinstance(value, cls): 33 | return cls._validate(str(value)) 34 | 35 | if isinstance(value, (timedelta, str)): 36 | try: 37 | with warnings.catch_warnings(): 38 | # Because pandas 2.2 deprecated many frequency strings (e.g. "Y", "M", "T" still in common use) 39 | # We should consider switching away from pandas on this and supporting ISO 40 | warnings.simplefilter("ignore", category=FutureWarning) 41 | value = to_offset(value) 42 | except ValueError as e: 43 | raise ValueError(f"ensure this value can be converted to a pandas offset: {e}") 44 | 45 | if isinstance(value, pd.offsets.DateOffset): 46 | return cls(f"{value.n}{value.base.freqstr}") 47 | 48 | raise ValueError(f"ensure this value can be converted to a pandas offset: {value}") 49 | 50 | @classmethod 51 | def validate(cls, value) -> "Frequency": 52 | """Try to convert/validate an arbitrary value to a Frequency.""" 53 | return _TYPE_ADAPTER.validate_python(value) 54 | 55 | 56 | _TYPE_ADAPTER = TypeAdapter(Frequency) 57 | -------------------------------------------------------------------------------- /ccflow/tests/test_global_state.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import pytest 4 | 5 | from ccflow import BaseModel, EvaluatorBase, FlowOptionsOverride, GenericResult, GlobalState, ModelEvaluationContext, ModelRegistry 6 | 7 | 8 | @pytest.fixture 9 | def root_registry(): 10 | r = ModelRegistry.root() 11 | r.clear() 12 | yield r 13 | r.clear() 14 | 15 | 16 | class DummyModel(BaseModel): 17 | name: str 18 | 19 | 20 | class DummyEvaluator(EvaluatorBase): 21 | def __call__(self, context: ModelEvaluationContext): 22 | return GenericResult(value="test") 23 | 24 | 25 | def test_global_state(root_registry): 26 | root_registry.add("foo", DummyModel(name="foo")) 27 | evaluator = DummyEvaluator() 28 | with FlowOptionsOverride(options=dict(evaluator=evaluator)) as override: 29 | state = GlobalState() 30 | 31 | # Now clear the registry, and add a different model 32 | root_registry.clear() 33 | root_registry.add("bar", DummyModel(name="bar")) 34 | assert "foo" in state.root_registry.models 35 | assert "bar" not in state.root_registry.models 36 | assert state.open_overrides == {id(override): override} 37 | 38 | with state: 39 | state2 = GlobalState() 40 | assert "foo" in
state2.root_registry.models 41 | assert "bar" not in state2.root_registry.models 42 | assert state2.open_overrides == {id(override): override} 43 | 44 | # Test that global state doesn't persist outside of the context manager 45 | state3 = GlobalState() 46 | assert "foo" not in state3.root_registry.models 47 | assert "bar" in state3.root_registry.models 48 | assert state3.open_overrides == {} 49 | 50 | 51 | def test_global_state_pickle(): 52 | r = ModelRegistry.root() 53 | r.add("foo", DummyModel(name="foo")) 54 | evaluator = DummyEvaluator() 55 | with FlowOptionsOverride(options=dict(evaluator=evaluator)) as override: 56 | state = GlobalState() 57 | 58 | # Now pickle and unpickle the state 59 | state_pickled = pickle.dumps(state) 60 | state_unpickled = pickle.loads(state_pickled) 61 | 62 | assert "foo" in state_unpickled.root_registry.models 63 | assert state_unpickled.open_overrides == {id(override): override} 64 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_pyarrow.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from unittest import TestCase 3 | 4 | import pandas as pd 5 | import polars as pl 6 | import pyarrow as pa 7 | 8 | from ccflow.context import DateRangeContext 9 | from ccflow.result.pyarrow import ArrowDateRangeResult, ArrowResult 10 | 11 | 12 | class TestResult(TestCase): 13 | def test_arrow_from_pandas(self): 14 | df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) 15 | r = ArrowResult.model_validate({"table": df}) 16 | self.assertIsInstance(r.table, pa.Table) 17 | 18 | r = ArrowResult(table=df) 19 | self.assertIsInstance(r.table, pa.Table) 20 | 21 | r = ArrowResult.model_validate(df) 22 | self.assertIsInstance(r.table, pa.Table) 23 | self.assertIsInstance(r.df.to_native(), pa.Table) 24 | 25 | def test_arrow_from_polars(self): 26 | df = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) 27 | r = ArrowResult.model_validate({"table": df}) 28 | self.assertIsInstance(r.table, pa.Table) 29 | 30 | r = ArrowResult(table=df) 31 | self.assertIsInstance(r.table, pa.Table) 32 | 33 | r = ArrowResult.model_validate(df) 34 | self.assertIsInstance(r.table, pa.Table) 35 | self.assertIsInstance(r.df.to_native(), pa.Table) 36 | 37 | def test_arrow_date_range(self): 38 | df = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "D": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]}) 39 | context = DateRangeContext(start_date=date(2020, 1, 1), end_date=date(2020, 1, 3)) 40 | r = ArrowDateRangeResult.model_validate({"table": df, "date_col": "D", "context": context}) 41 | self.assertIsInstance(r.table, pa.Table) 42 | self.assertIsInstance(r.df.to_native(), pa.Table) 43 | 44 | self.assertRaises(ValueError, ArrowDateRangeResult.model_validate, {"table": df, "date_col": "B", "context": context}) 45 | self.assertRaises(ValueError, ArrowDateRangeResult.model_validate, {"table": df, "date_col": "E", "context": context}) 46 | context = DateRangeContext(start_date=date(2020, 1, 1), end_date=date(2020, 1, 2)) 47 | self.assertRaises(ValueError, ArrowDateRangeResult.model_validate, {"table": df, "date_col": "E", "context": context}) 48 | -------------------------------------------------------------------------------- /ccflow/examples/tpch/query.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from typing import Dict, Tuple 3 | 4 | from pydantic import Field 5 | 6 | from ccflow import CallableModel, 
CallableModelGenericType, Flow 7 | from ccflow.result.narwhals import NarwhalsFrameResult 8 | 9 | from .base import TPCHQueryContext, TPCHTable, TPCHTableContext 10 | 11 | __all__ = ("TPCHQueryRunner",) 12 | 13 | 14 | _QUERY_TABLE_MAP: Dict[int, Tuple[TPCHTable, ...]] = { 15 | 1: ("lineitem",), 16 | 2: ("region", "nation", "supplier", "part", "partsupp"), 17 | 3: ("customer", "lineitem", "orders"), 18 | 4: ("lineitem", "orders"), 19 | 5: ("region", "nation", "customer", "lineitem", "orders", "supplier"), 20 | 6: ("lineitem",), 21 | 7: ("nation", "customer", "lineitem", "orders", "supplier"), 22 | 8: ("part", "supplier", "lineitem", "orders", "customer", "nation", "region"), 23 | 9: ("part", "partsupp", "nation", "lineitem", "orders", "supplier"), 24 | 10: ("customer", "nation", "lineitem", "orders"), 25 | 11: ("nation", "partsupp", "supplier"), 26 | 12: ("lineitem", "orders"), 27 | 13: ("customer", "orders"), 28 | 14: ("lineitem", "part"), 29 | 15: ("lineitem", "supplier"), 30 | 16: ("part", "partsupp", "supplier"), 31 | 17: ("lineitem", "part"), 32 | 18: ("customer", "lineitem", "orders"), 33 | 19: ("lineitem", "part"), 34 | 20: ("part", "partsupp", "nation", "lineitem", "supplier"), 35 | 21: ("lineitem", "nation", "orders", "supplier"), 36 | 22: ("customer", "orders"), 37 | } 38 | 39 | 40 | class TPCHQueryRunner(CallableModel): 41 | """Generically runs TPC-H queries from a pre-packaged repository of queries (courtesy of narwhals).""" 42 | 43 | table_provider: CallableModelGenericType[TPCHTableContext, NarwhalsFrameResult] 44 | query_table_map: Dict[int, Tuple[TPCHTable, ...]] = Field(_QUERY_TABLE_MAP, validate_default=True) 45 | 46 | @Flow.call 47 | def __call__(self, context: TPCHQueryContext) -> NarwhalsFrameResult: 48 | query_module = import_module(f"ccflow.examples.tpch.queries.q{context.query_id}") 49 | inputs = (self.table_provider(TPCHTableContext(table=table)).df for table in self.query_table_map[context.query_id]) 50 | result = query_module.query(*inputs) 51 | return NarwhalsFrameResult(df=result) 52 | -------------------------------------------------------------------------------- /ccflow/models/publisher.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, Type 2 | 3 | from pydantic import Field 4 | from typing_extensions import override 5 | 6 | from ..callable import CallableModelType, ContextType, Flow, ResultType, WrapperModel 7 | from ..publisher import PublisherType 8 | from ..result import GenericResult 9 | 10 | __all__ = ("PublisherModel",) 11 | 12 | 13 | class PublisherModel( 14 | WrapperModel[CallableModelType], 15 | Generic[CallableModelType, PublisherType], 16 | ): 17 | """Model that chains together a callable model and a publisher to publish the results of the callable model.""" 18 | 19 | publisher: PublisherType 20 | field: str = Field(None, description="Specific field on model output to publish") 21 | return_data: bool = Field( 22 | False, 23 | description="Whether to return the underlying model result as the output instead of the publisher output", 24 | ) 25 | 26 | @property 27 | def result_type(self) -> Type[ResultType]: 28 | """Result type that will be returned. 
Could be over-ridden by child class.""" 29 | if self.return_data: 30 | return self.model.result_type 31 | else: 32 | return GenericResult 33 | 34 | def _get_publisher(self, context): 35 | publisher = self.publisher.model_copy() 36 | # Set the name, if needed 37 | if not publisher.name and self.meta.name: 38 | publisher.name = self.meta.name 39 | # Augment any existing name parameters with the context parameters 40 | name_params = publisher.name_params.copy() 41 | name_params.update(context.model_dump(exclude={"type_"})) 42 | publisher.name_params = name_params 43 | return publisher 44 | 45 | @override 46 | @Flow.call 47 | def __call__(self, context: ContextType) -> ResultType: 48 | """This method gets the result from the underlying model, and publishes it.""" 49 | publisher = self._get_publisher(context) 50 | data = self.model(context) 51 | if self.field: 52 | pub_data = getattr(data, self.field) 53 | else: 54 | pub_data = data 55 | publisher.data = pub_data 56 | out = publisher() 57 | if self.return_data: 58 | return data 59 | else: 60 | return self.result_type(value=out) 61 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_generic.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from ccflow import GenericContext, GenericResult 4 | 5 | 6 | class TestGenericResult(TestCase): 7 | def test_generic(self): 8 | v = {"a": 1, "b": [2, 3]} 9 | result = GenericResult(value=v) 10 | self.assertEqual(GenericResult.model_validate(v), result) 11 | self.assertIs(GenericResult.model_validate(result), result) 12 | 13 | v = {"value": 5} 14 | self.assertEqual(GenericResult.model_validate(v), GenericResult(value=5)) 15 | self.assertEqual(GenericResult[int].model_validate(v), GenericResult[int](value=5)) 16 | self.assertEqual(GenericResult[str].model_validate(v), GenericResult[str](value="5")) 17 | 18 | self.assertEqual(GenericResult.model_validate("foo"), GenericResult(value="foo")) 19 | self.assertEqual(GenericResult[str].model_validate(5), GenericResult[str](value="5")) 20 | 21 | result = GenericResult(value=5) 22 | # Note that this will work, even though GenericResult is not a subclass of GenericResult[str] 23 | self.assertEqual(GenericResult[str].model_validate(result), GenericResult[str](value="5")) 24 | 25 | def test_generics_conversion(self): 26 | v = (1, [2, 3], {4, 5, 6}) 27 | self.assertEqual(GenericResult(value=GenericContext(value=v)), GenericResult(value=v)) 28 | 29 | v = 5 30 | self.assertEqual(GenericResult[str](value=GenericContext(value=v)), GenericResult[str](value=v)) 31 | self.assertEqual(GenericResult[str](value=GenericContext[str](value=v)), GenericResult[str](value=v)) 32 | self.assertEqual(GenericResult[int](value=GenericContext[str](value=v)), GenericResult[int](value=v)) 33 | self.assertEqual(GenericResult[int](value=GenericContext[int](value=v)), GenericResult[int](value=v)) 34 | 35 | v = "5" 36 | self.assertEqual(GenericResult[str](value=GenericContext(value=v)), GenericResult[str](value=v)) 37 | self.assertEqual(GenericResult[str](value=GenericContext[str](value=v)), GenericResult[str](value=v)) 38 | self.assertEqual(GenericResult[int](value=GenericContext[str](value=v)), GenericResult[int](value=v)) 39 | self.assertEqual(GenericResult[int](value=GenericContext[int](value=v)), GenericResult[int](value=v)) 40 | 41 | self.assertEqual(GenericResult[str].model_validate(GenericContext(value=5)), GenericResult[str](value="5")) 42 | 
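
# The assertions above rely on GenericResult's validators unwrapping
# GenericContext values, including casting across parametrizations.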
-------------------------------------------------------------------------------- /ccflow/serialization.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | import numpy as np 4 | import orjson 5 | 6 | from .enums import Enum 7 | 8 | 9 | def _remove_dict_enums(obj: Any) -> Dict: 10 | if isinstance(obj, Enum): 11 | return obj.name 12 | elif isinstance(obj, dict): 13 | return {_remove_dict_enums(k): _remove_dict_enums(v) for k, v in obj.items()} 14 | return obj 15 | 16 | 17 | def orjson_dumps(v, default=None, *args, **kwargs) -> str: 18 | """Robust wrapping of orjson dumps to help implement serialization.""" 19 | # orjson.dumps returns bytes, to match standard json.dumps we need to decode 20 | # The default passed to orjson seems to be a partial function 21 | # with the json_encoders as the first argument. We try to perform the 22 | # conversion 23 | options = orjson.OPT_NON_STR_KEYS | orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY 24 | try: 25 | return orjson.dumps( 26 | v, 27 | default=default, 28 | option=options, 29 | ).decode() 30 | except orjson.JSONEncodeError: 31 | # if we fail, we try to remove the enums because 32 | # orjson serialization fails when csp enums are 33 | # used as dict keys. See https://github.com/ijl/orjson/issues/445 34 | return orjson.dumps( 35 | _remove_dict_enums(v), 36 | default=default, 37 | option=options, 38 | ).decode() 39 | 40 | 41 | def make_ndarray_orjson_valid(arr: np.ndarray) -> Union[List[Any], np.ndarray]: 42 | """Returns a numpy array or list that is compatible with orjson serialization.""" 43 | if not isinstance(arr, np.ndarray): 44 | raise TypeError(f"Expected np.ndarray instance, got {type(arr)}") 45 | # orjson supports these types: 46 | # https://github.com/ijl/orjson#numpy 47 | # Which types to check: 48 | # https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.number 49 | is_number = np.issubdtype(arr.dtype, np.number) 50 | is_bool = np.issubdtype(arr.dtype, np.bool_) 51 | is_complex = np.issubdtype(arr.dtype, np.complexfloating) 52 | if not (is_bool or is_number) or is_complex: 53 | return arr.tolist() 54 | # Now, we have to make the numpy array c-contiguous.
Why: 55 | # https://github.com/ijl/orjson/issues/100 56 | try: 57 | return np.ascontiguousarray(arr) 58 | except MemoryError: 59 | return arr.tolist() 60 | -------------------------------------------------------------------------------- /ccflow/tests/exttypes/test_pyobjectpath.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Generic, TypeVar 3 | from unittest import TestCase 4 | 5 | from ccflow import PyObjectPath 6 | 7 | 8 | class A: 9 | pass 10 | 11 | 12 | T = TypeVar("T") 13 | 14 | 15 | class B(Generic[T]): 16 | t: T 17 | 18 | 19 | class TestPyObjectPath(TestCase): 20 | def test_basic(self): 21 | p = PyObjectPath("ccflow.tests.exttypes.test_pyobjectpath.A") 22 | self.assertIsInstance(p, str) 23 | self.assertEqual(p.object, A) 24 | 25 | p = PyObjectPath("builtins.list") 26 | self.assertIsInstance(p, str) 27 | self.assertEqual(p.object, list) 28 | 29 | def test_validate(self): 30 | self.assertRaises(ValueError, PyObjectPath.validate, None) 31 | self.assertRaises(ValueError, PyObjectPath.validate, "foo") 32 | self.assertRaises(ValueError, PyObjectPath.validate, A()) 33 | 34 | p = PyObjectPath("ccflow.tests.exttypes.test_pyobjectpath.A") 35 | self.assertEqual(PyObjectPath.validate(p), p) 36 | self.assertEqual(PyObjectPath.validate(str(p)), p) 37 | self.assertEqual(PyObjectPath.validate(A), p) 38 | 39 | p = PyObjectPath("builtins.list") 40 | self.assertEqual(PyObjectPath.validate(p), p) 41 | self.assertEqual(PyObjectPath.validate(str(p)), p) 42 | self.assertEqual(PyObjectPath.validate(list), p) 43 | 44 | def test_generics(self): 45 | p = PyObjectPath("ccflow.tests.exttypes.test_pyobjectpath.B") 46 | self.assertEqual(PyObjectPath.validate(p), p) 47 | self.assertEqual(PyObjectPath.validate(str(p)), p) 48 | self.assertEqual(PyObjectPath.validate(B), p) 49 | 50 | p2 = PyObjectPath("ccflow.tests.exttypes.test_pyobjectpath.B[float]") 51 | self.assertRaises(ValueError, PyObjectPath.validate, p2) 52 | # Note that the type information gets stripped from the class, i.e. 
we compare with p, not p2 53 | self.assertEqual(PyObjectPath.validate(B[float]), p) 54 | # Re-creating the object from the path loses the type information at the moment 55 | self.assertEqual(PyObjectPath.validate(B[float]).object, B) 56 | 57 | def test_pickle(self): 58 | p = PyObjectPath("ccflow.tests.exttypes.test_pyobjectpath.A") 59 | self.assertEqual(p, pickle.loads(pickle.dumps(p))) 60 | p = PyObjectPath.validate("ccflow.tests.exttypes.test_pyobjectpath.A") 61 | self.assertEqual(p, pickle.loads(pickle.dumps(p))) 62 | self.assertIsNotNone(p.object) 63 | self.assertEqual(p, pickle.loads(pickle.dumps(p))) 64 | self.assertEqual(p.object, pickle.loads(pickle.dumps(p.object))) 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.a 8 | *.so 9 | *.obj 10 | *.dll 11 | *.exp 12 | *.lib 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | junit.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | Pipfile.lock 87 | 88 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 89 | __pypackages__/ 90 | 91 | # Celery stuff 92 | celerybeat-schedule 93 | celerybeat.pid 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # Documentation 126 | /site 127 | index.md 128 | docs/_build/ 129 | docs/src/_build/ 130 | docs/api 131 | docs/index.md 132 | docs/html 133 | docs/jupyter_execute 134 | index.md 135 | 136 | # JS 137 | js/coverage 138 | js/dist 139 | js/lib 140 | js/node_modules 141 | js/test-results 142 | js/playwright-report 143 | js/*.tgz 144 | ccflow/extension 145 | 146 | # Jupyter 147 | .ipynb_checkpoints 148 | .autoversion 149 | Untitled*.ipynb 150 | !ccflow/extension/ccflow.json 151 | !ccflow/extension/install.json 152 | ccflow/nbextension 153 | ccflow/labextension 154 | 155 | # Mac 156 | .DS_Store 157 | 158 | # Rust 159 | target 160 | 161 | # Examples 162 | outputs 163 | raw.html 164 | extracted.csv 165 | etl.db 166 | lobsters.html 167 | lobsters.csv 168 | -------------------------------------------------------------------------------- /ccflow/tests/test_base_load_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from ccflow.base import load_config 6 | 7 | 8 | @pytest.fixture 9 | def basepath(): 10 | # Because os.cwd may change depending on how tests are run 11 | return str(Path(__file__).resolve().parent) 12 | 13 | 14 | def test_root_config(basepath): 15 | root_config_dir = str(Path(__file__).resolve().parent / "config") 16 | r = load_config( 17 | root_config_dir=root_config_dir, 18 | root_config_name="conf", 19 | overwrite=True, 20 | basepath=basepath, 21 | ) 22 | try: 23 | assert len(r.models) 24 | assert "foo" in r.models 25 | assert "bar" in r.models 26 | finally: 27 | r.clear() 28 | 29 | 30 | def test_config_dir(basepath): 31 | root_config_dir = str(Path(__file__).resolve().parent / "config") 32 | config_dir = str(Path(__file__).resolve().parent / "config_user") 33 | r = load_config( 34 | root_config_dir=root_config_dir, 35 | root_config_name="conf", 36 | config_dir=config_dir, 37 | overwrite=True, 38 | basepath=basepath, 39 | ) 40 | try: 41 | assert len(r.models) 42 | assert "foo" in r.models 43 | assert "bar" in r.models 44 | finally: 45 | r.clear() 46 | 47 | 48 | def test_config_name(basepath): 49 | root_config_dir = str(Path(__file__).resolve().parent / "config") 50 | config_dir = str(Path(__file__).resolve().parent / "config_user") 51 | r = load_config( 52 | root_config_dir=root_config_dir, 53 | root_config_name="conf", 54 | config_dir=config_dir, 55 | config_name="sample", 56 | overwrite=True, 57 | basepath=basepath, 58 | ) 59 | try: 60 | assert len(r.models) 61 | assert "foo" in r.models 62 | assert "bar" in r.models 63 | assert "config_user" in r.models 64 | assert "user_foo" in r["config_user"] 65 | finally: 66 | r.clear() 67 | 68 | 69 | def test_config_dir_with_overrides(basepath): 70 | root_config_dir = str(Path(__file__).resolve().parent / "config") 71 | config_dir = str(Path(__file__).resolve().parent) 72 | r = load_config( 73 | root_config_dir=root_config_dir, 74 | root_config_name="conf", 75 | 
config_dir=config_dir, 76 | overrides=["+config_user=sample"], 77 | overwrite=True, 78 | basepath=basepath, 79 | ) 80 | try: 81 | assert len(r.models) 82 | assert "foo" in r.models 83 | assert "bar" in r.models 84 | assert "config_user" in r.models 85 | assert "user_foo" in r["config_user"] 86 | finally: 87 | r.clear() 88 | -------------------------------------------------------------------------------- /ccflow/utils/chunker.py: -------------------------------------------------------------------------------- 1 | """Functionality to chunk query parameters or simple data structures. 2 | 3 | This is useful as part of creating caching schemes for time series data, where one wants to choose a chunk/page size, 4 | i.e. monthly, daily, etc, and any time any data is needed from the chunk, to load and cache the entire chunk. 5 | Control over the chunk size is important: if it's too big, too much unnecessary data gets loaded, but if it's too 6 | small, performance suffers from too many repeated trips to the underlying data store for long-range queries. 7 | """ 8 | 9 | import warnings 10 | from datetime import date 11 | from typing import List, Tuple 12 | 13 | import pandas as pd 14 | 15 | _MIN_END_DATE = date(1969, 12, 31) 16 | 17 | __all__ = ("dates_to_chunks",) 18 | 19 | 20 | def dates_to_chunks(start: date, end: date, chunk_size: str = "ME", trim: bool = False) -> List[Tuple[date, date]]: 21 | """ 22 | Chunks a date range in a consistent way (i.e. the same middle chunks will always be generated for overlapping 23 | ranges). 24 | 25 | Args: 26 | start: The start date of the time interval to convert to chunks 27 | end: The end date of the time interval to convert to chunks 28 | chunk_size: Any valid Pandas frequency string, e.g. 'D', '2W', 'M'. 29 | trim: Whether to trim the ends to match start and end date exactly (versus standard-sized chunks that cover the interval) 30 | 31 | Returns: 32 | List of tuples of (start date, end date) for each of the chunks 33 | """ 34 | with warnings.catch_warnings(): 35 | # Because pandas 2.2 deprecated many frequency strings (e.g. "Y", "M", "T" still in common use) 36 | # We should consider switching away from pandas on this and supporting ISO 37 | warnings.simplefilter("ignore", category=FutureWarning) 38 | offset = pd.tseries.frequencies.to_offset(chunk_size) 39 | if offset.n == 1: 40 | end_dates = pd.date_range(start - offset, end + offset, freq=chunk_size) 41 | else: 42 | # Need to anchor the timeline at some absolute date, because otherwise chunks might depend on the start date 43 | # and end up overlapping each other, e.g. with 2M, we would end up with 44 | # (Jan-Feb) or (Feb-Mar) depending on whether the start date was in Jan or Feb, 45 | # instead of always returning (Jan-Feb) for any start date in either of those two months.
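# Illustrative example (an assumption, not from the source): with chunk_size="2M", any start
# date in Jan or Feb 2024 maps to the same absolute chunk (2024-01-01, 2024-02-29), because
# the date_range below is anchored at _MIN_END_DATE rather than at the caller's start date.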
46 | end_dates = pd.date_range(_MIN_END_DATE, end + offset, freq=chunk_size) 47 | start_dates = end_dates + pd.DateOffset(1) 48 | chunks = [(s, e) for s, e in zip(start_dates[:-1].date, end_dates[1:].date) if e >= start and s <= end] 49 | if trim: 50 | if chunks[0][0] < start: 51 | chunks[0] = (start, chunks[0][1]) 52 | if chunks[-1][-1] > end: 53 | chunks[-1] = (chunks[-1][0], end) 54 | return chunks 55 | -------------------------------------------------------------------------------- /ccflow/tests/result/test_narwhals.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | import narwhals.stable.v1 as nw 4 | import polars as pl 5 | import pytest 6 | 7 | from ccflow.exttypes.narwhals import ( 8 | DataFrameT, 9 | SchemaValidator, 10 | ) 11 | from ccflow.result.narwhals import NarwhalsDataFrameResult, NarwhalsFrameResult 12 | 13 | 14 | @pytest.fixture 15 | def data(): 16 | return { 17 | "a": [1.0, 2.0, 3.0], 18 | "b": [4, 5, 6], 19 | "c": ["foo", "bar", "baz"], 20 | "d": [0, 0, 0], 21 | } 22 | 23 | 24 | @pytest.fixture 25 | def schema(): 26 | return { 27 | "a": nw.Float64, 28 | "b": nw.Int64, 29 | "c": nw.String, 30 | "d": nw.Float64, 31 | } 32 | 33 | 34 | def test_narwhals_frame_result(data): 35 | df = pl.DataFrame(data) 36 | result = NarwhalsFrameResult(df=df) 37 | assert isinstance(result.df, nw.DataFrame) 38 | assert result.df.to_native() is df 39 | 40 | df = pl.DataFrame(data).lazy() 41 | result = NarwhalsFrameResult(df=df) 42 | assert isinstance(result.df, nw.LazyFrame) 43 | assert result.df.to_native() is df 44 | 45 | 46 | def test_narwhals_frame_result_validation(data, schema): 47 | # Test that we can automatically validate a dataframe into a result type for convenience 48 | df = pl.DataFrame(data) 49 | result = NarwhalsFrameResult.model_validate(df) 50 | assert isinstance(result.df, nw.DataFrame) 51 | assert result.df.to_native() is df 52 | 53 | result = NarwhalsFrameResult.model_validate(dict(df=df)) 54 | assert isinstance(result.df, nw.DataFrame) 55 | assert result.df.to_native() is df 56 | 57 | df = pl.DataFrame(data).lazy() 58 | result = NarwhalsFrameResult.model_validate(df) 59 | assert isinstance(result.df, nw.LazyFrame) 60 | assert result.df.to_native() is df 61 | 62 | result = NarwhalsFrameResult.model_validate(dict(df=df)) 63 | assert isinstance(result.df, nw.LazyFrame) 64 | assert result.df.to_native() is df 65 | 66 | 67 | def test_narwhals_dataframe_result(data): 68 | df = pl.DataFrame(data) 69 | result = NarwhalsDataFrameResult(df=df) 70 | assert isinstance(result.df, nw.DataFrame) 71 | assert result.df.to_native() is df 72 | 73 | df = pl.DataFrame(data).lazy() 74 | result = NarwhalsDataFrameResult(df=df) 75 | assert isinstance(result.df, nw.DataFrame) 76 | 77 | 78 | def test_collect(data): 79 | df = pl.DataFrame(data) 80 | result = NarwhalsFrameResult(df=df) 81 | result2 = result.collect() 82 | assert isinstance(result2, NarwhalsDataFrameResult) 83 | assert isinstance(result2.df, nw.DataFrame) 84 | 85 | 86 | def test_custom(data, schema): 87 | class MyNarwhalsResult(NarwhalsDataFrameResult): 88 | df: Annotated[DataFrameT, SchemaValidator(schema, cast=True)] 89 | 90 | df = pl.DataFrame(data) 91 | result = MyNarwhalsResult(df=df) 92 | assert result.df.schema["d"] == nw.Float64() 93 | -------------------------------------------------------------------------------- /ccflow/examples/etl/models.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | from 
csv import DictReader, DictWriter 3 | from io import StringIO 4 | from typing import Optional 5 | 6 | from bs4 import BeautifulSoup 7 | from httpx import Client 8 | from pydantic import Field 9 | 10 | from ccflow import CallableModel, ContextBase, Flow, GenericResult, NullContext 11 | 12 | __all__ = ("RestModel", "LinksModel", "DBModel", "SiteContext") 13 | 14 | 15 | class SiteContext(ContextBase): 16 | """An example of a context object, passed into and between callable models from the command line.""" 17 | 18 | site: str = Field(default="https://news.ycombinator.com") 19 | 20 | 21 | class RestModel(CallableModel): 22 | """Example callable model that fetches a URL and returns the HTML content.""" 23 | 24 | @Flow.call 25 | def __call__(self, context: Optional[SiteContext] = None) -> GenericResult[str]: 26 | context = context or SiteContext() 27 | resp = Client().get(context.site, headers={"User-Agent": "Safari/537.36"}, follow_redirects=True) 28 | resp.raise_for_status() 29 | 30 | return GenericResult[str](value=resp.text) 31 | 32 | 33 | class LinksModel(CallableModel): 34 | """Example callable model that transforms HTML content into CSV of links.""" 35 | 36 | file: str 37 | 38 | @Flow.call 39 | def __call__(self, context: Optional[NullContext] = None) -> GenericResult[str]: 40 | context = context or NullContext() 41 | 42 | with open(self.file, "r") as f: 43 | html = f.read() 44 | 45 | # Use beautifulsoup to convert links into csv of name, url 46 | soup = BeautifulSoup(html, "html.parser") 47 | links = [{"name": a.text, "url": href} for a in soup.find_all("a", href=True) if (href := a["href"]).startswith("http")] 48 | 49 | io = StringIO() 50 | writer = DictWriter(io, fieldnames=["name", "url"]) 51 | writer.writeheader() 52 | writer.writerows(links) 53 | output = io.getvalue() 54 | return GenericResult[str](value=output) 55 | 56 | 57 | class DBModel(CallableModel): 58 | """Example callable model that loads CSV data into a SQLite database.""" 59 | 60 | file: str 61 | db_file: str = Field(default="etl.db") 62 | table: str = Field(default="links") 63 | 64 | @Flow.call 65 | def __call__(self, context: Optional[NullContext] = None) -> GenericResult[str]: 66 | context = context or NullContext() 67 | 68 | conn = sqlite3.connect(self.db_file) 69 | cursor = conn.cursor() 70 | cursor.execute(f"CREATE TABLE IF NOT EXISTS {self.table} (name TEXT, url TEXT)") 71 | with open(self.file, "r") as f: 72 | reader = DictReader(f) 73 | for row in reader: 74 | cursor.execute(f"INSERT INTO {self.table} (name, url) VALUES (?, ?)", (row["name"], row["url"])) 75 | conn.commit() 76 | return GenericResult[str](value="Data loaded into database") 77 | -------------------------------------------------------------------------------- /ccflow/exttypes/pydantic_numpy/ndarray.py: -------------------------------------------------------------------------------- 1 | """Code adapted from MIT-licensed open source library https://github.com/cheind/pydantic-numpy 2 | 3 | MIT License 4 | 5 | Copyright (c) 2022 Christoph Heindl 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this 
permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | 26 | import sys 27 | from typing import Any, Generic, TypeVar 28 | 29 | import numpy as np 30 | from numpy.lib import NumpyVersion 31 | from typing_extensions import get_args 32 | 33 | T = TypeVar("T", bound=np.generic) 34 | 35 | if sys.version_info < (3, 9) or NumpyVersion(np.__version__) < "1.22.0": 36 | nd_array_type = np.ndarray 37 | else: 38 | nd_array_type = np.ndarray[Any, T] 39 | 40 | 41 | class NDArray(Generic[T], nd_array_type): 42 | @classmethod 43 | def _serialize(cls, v, nxt): 44 | # Not as efficient as using orjson, but we need a list type to pass to pydantic, 45 | # and orjson produces us a string. 46 | if v is not None: 47 | v = v.tolist() 48 | return nxt(v) 49 | 50 | @classmethod 51 | def __get_pydantic_core_schema__(cls, source_type, handler): 52 | from pydantic_core import core_schema 53 | 54 | def _validate(v): 55 | subtypes = get_args(source_type) 56 | dtype = subtypes[0] if subtypes and subtypes[0] != Any else None 57 | try: 58 | if dtype: 59 | return np.asarray(v, dtype=dtype) 60 | return np.asarray(v) 61 | 62 | except TypeError: 63 | raise ValueError(f"Unable to convert {v} to an array.") 64 | 65 | return core_schema.no_info_before_validator_function( 66 | _validate, 67 | core_schema.any_schema(), 68 | serialization=core_schema.wrap_serializer_function_ser_schema( 69 | cls._serialize, 70 | info_arg=False, 71 | return_schema=core_schema.list_schema(), 72 | ), 73 | ) 74 | -------------------------------------------------------------------------------- /ccflow/tests/publishers/test_print.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from unittest import TestCase 3 | from unittest.mock import MagicMock, patch 4 | 5 | from pydantic import BaseModel as PydanticBaseModel 6 | 7 | from ccflow.exttypes import NDArray 8 | from ccflow.publishers import ( 9 | LogPublisher, 10 | PrintJSONPublisher, 11 | PrintPublisher, 12 | PrintPydanticJSONPublisher, 13 | PrintYAMLPublisher, 14 | ) 15 | 16 | 17 | class MyTestModel(PydanticBaseModel): 18 | foo: int 19 | bar: date 20 | baz: NDArray[float] 21 | 22 | 23 | class TestPrintPublishers(TestCase): 24 | def test_print(self): 25 | with patch("ccflow.publishers.print.print") as mock_print: 26 | p = PrintPublisher( 27 | name="test_{{param}}", 28 | name_params={"param": "JSON"}, 29 | ) 30 | p.data = {"foo": 5, "bar": date(2020, 1, 1)} 31 | p() 32 | assert mock_print.call_count == 1 33 | mock_print.assert_called_with(p.data) 34 | 35 | def test_log(self): 36 | with patch("logging.getLogger") as mock_getLogger: 37 | mock_getLogger.return_value = MagicMock() 38 | p = LogPublisher( 39 | name="test_{{param}}", 40 | name_params={"param": "JSON"}, 41 | ) 42 | p.data = {"foo": 5, "bar": date(2020, 1, 1)} 43 | p() 44 | assert mock_getLogger.return_value.log.call_count == 1 45 | mock_getLogger.return_value.log.assert_called_with(level=20, msg=p.data) 46 | 47 | def test_json(self): 48 
| with patch("ccflow.publishers.print.print") as mock_print: 49 | p = PrintJSONPublisher( 50 | name="test_{{param}}", 51 | name_params={"param": "JSON"}, 52 | kwargs=dict(default=str), 53 | ) 54 | p.data = {"foo": 5, "bar": date(2020, 1, 1)} 55 | p() 56 | assert mock_print.call_count == 1 57 | mock_print.assert_called_with('{"foo":5,"bar":"2020-01-01"}') 58 | 59 | def test_yaml(self): 60 | with patch("ccflow.publishers.print.print") as mock_print: 61 | p = PrintYAMLPublisher( 62 | name="test_{{param}}", 63 | name_params={"param": "JSON"}, 64 | ) 65 | p.data = {"foo": 5, "bar": date(2020, 1, 1)} 66 | p() 67 | assert mock_print.call_count == 1 68 | mock_print.assert_called_with("bar: 2020-01-01\nfoo: 5\n") 69 | 70 | def test_json_pydantic(self): 71 | with patch("ccflow.publishers.print.print") as mock_print: 72 | p = PrintPydanticJSONPublisher( 73 | name="test_{{param}}", 74 | name_params={"param": "JSON"}, 75 | ) 76 | p.data = {"foo": 5, "bar": date(2020, 1, 1)} 77 | p() 78 | assert mock_print.call_count == 1 79 | mock_print.assert_called_with('{"foo":5,"bar":"2020-01-01"}') 80 | -------------------------------------------------------------------------------- /docs/wiki/Build-from-Source.md: -------------------------------------------------------------------------------- 1 | `ccflow` is written in Python. While prebuilt wheels are provided for end users, it is also straightforward to build `ccflow` from either the Python [source distribution](https://packaging.python.org/en/latest/specifications/source-distribution-format/) or the GitHub repository. 2 | 3 | - [Make commands](#make-commands) 4 | - [Prerequisites](#prerequisites) 5 | - [Clone](#clone) 6 | - [Install Python dependencies](#install-python-dependencies) 7 | - [Build](#build) 8 | - [Lint and Autoformat](#lint-and-autoformat) 9 | - [Testing](#testing) 10 | 11 | ## Make commands 12 | 13 | As a convenience, `ccflow` uses a `Makefile` for commonly used commands. You can print the main available commands by running `make` with no arguments 14 | 15 | ```bash 16 | > make 17 | 18 | build build the library 19 | clean clean the repository 20 | fix run autofixers 21 | install install library 22 | lint run lints 23 | test run the tests 24 | ``` 25 | 26 | ## Prerequisites 27 | 28 | `ccflow` has a few system-level dependencies which you can install from your machine package manager. Other package managers like `conda`, `nix`, etc, should also work fine. 29 | 30 | ## Clone 31 | 32 | Clone the repo with: 33 | 34 | ```bash 35 | git clone https://github.com/Point72/ccflow.git 36 | cd ccflow 37 | ``` 38 | 39 | ## Install Python dependencies 40 | 41 | Python build and develop dependencies are specified in the `pyproject.toml`, but you can manually install them: 42 | 43 | ```bash 44 | make requirements 45 | ``` 46 | 47 | Note that these dependencies would otherwise be installed normally as part of [PEP517](https://peps.python.org/pep-0517/) / [PEP518](https://peps.python.org/pep-0518/). 48 | 49 | ## Build 50 | 51 | Build the python project in the usual manner: 52 | 53 | ```bash 54 | make build 55 | ``` 56 | 57 | ## Lint and Autoformat 58 | 59 | `ccflow` has linting and auto formatting. 
60 | 61 | | Language | Linter | Autoformatter | Description | 62 | | :------- | :---------- | :------------ | :---------- | 63 | | Python | `ruff` | `ruff` | Style | 64 | | Markdown | `mdformat` | `mdformat` | Style | 65 | | Markdown | `codespell` | | Spelling | 66 | 67 | **Python Linting** 68 | 69 | ```bash 70 | make lint-py 71 | ``` 72 | 73 | **Python Autoformatting** 74 | 75 | ```bash 76 | make fix-py 77 | ``` 78 | 79 | **Documentation Linting** 80 | 81 | ```bash 82 | make lint-docs 83 | ``` 84 | 85 | **Documentation Autoformatting** 86 | 87 | ```bash 88 | make fix-docs 89 | ``` 90 | 91 | ## Testing 92 | 93 | `ccflow` has extensive Python tests. The tests can be run via `pytest`. First, install the Python development dependencies with 94 | 95 | ```bash 96 | make develop 97 | ``` 98 | 99 | **Python** 100 | 101 | ```bash 102 | make test 103 | ``` 104 | -------------------------------------------------------------------------------- /docs/wiki/contribute/Build-from-Source.md: -------------------------------------------------------------------------------- 1 | `ccflow` is written in Python. While prebuilt wheels are provided for end users, it is also straightforward to build `ccflow` from either the Python [source distribution](https://packaging.python.org/en/latest/specifications/source-distribution-format/) or the GitHub repository. 2 | 3 | - [Make commands](#make-commands) 4 | - [Prerequisites](#prerequisites) 5 | - [Clone](#clone) 6 | - [Install Python dependencies](#install-python-dependencies) 7 | - [Build](#build) 8 | - [Lint and Autoformat](#lint-and-autoformat) 9 | - [Testing](#testing) 10 | 11 | ## Make commands 12 | 13 | As a convenience, `ccflow` uses a `Makefile` for commonly used commands. You can print the main available commands by running `make` with no arguments 14 | 15 | ```bash 16 | > make 17 | 18 | build build the library 19 | clean clean the repository 20 | fix run autofixers 21 | install install library 22 | lint run lints 23 | test run the tests 24 | ``` 25 | 26 | ## Prerequisites 27 | 28 | `ccflow` has a few system-level dependencies which you can install from your machine package manager. Other package managers like `conda`, `nix`, etc, should also work fine. 29 | 30 | ## Clone 31 | 32 | Clone the repo with: 33 | 34 | ```bash 35 | git clone https://github.com/Point72/ccflow.git 36 | cd ccflow 37 | ``` 38 | 39 | ## Install Python dependencies 40 | 41 | Python build and develop dependencies are specified in the `pyproject.toml`, but you can manually install them: 42 | 43 | ```bash 44 | make requirements 45 | ``` 46 | 47 | Note that these dependencies would otherwise be installed normally as part of [PEP517](https://peps.python.org/pep-0517/) / [PEP518](https://peps.python.org/pep-0518/). 48 | 49 | ## Build 50 | 51 | Build the python project in the usual manner: 52 | 53 | ```bash 54 | make build 55 | ``` 56 | 57 | ## Lint and Autoformat 58 | 59 | `ccflow` has linting and auto formatting. 
60 | 61 | | Language | Linter | Autoformatter | Description | 62 | | :------- | :---------- | :------------ | :---------- | 63 | | Python | `ruff` | `ruff` | Style | 64 | | Markdown | `mdformat` | `mdformat` | Style | 65 | | Markdown | `codespell` | | Spelling | 66 | 67 | **Python Linting** 68 | 69 | ```bash 70 | make lint-py 71 | ``` 72 | 73 | **Python Autoformatting** 74 | 75 | ```bash 76 | make fix-py 77 | ``` 78 | 79 | **Documentation Linting** 80 | 81 | ```bash 82 | make lint-docs 83 | ``` 84 | 85 | **Documentation Autoformatting** 86 | 87 | ```bash 88 | make fix-docs 89 | ``` 90 | 91 | ## Testing 92 | 93 | `ccflow` has extensive Python tests. The tests can be run via `pytest`. First, install the Python development dependencies with 94 | 95 | ```bash 96 | make develop 97 | ``` 98 | 99 | **Python** 100 | 101 | ```bash 102 | make test 103 | ``` 104 | -------------------------------------------------------------------------------- /ccflow/tests/exttypes/test_frequency.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from unittest import TestCase 3 | 4 | import pandas as pd 5 | from packaging.version import parse 6 | from pandas.tseries.frequencies import to_offset 7 | 8 | from ccflow.exttypes.frequency import Frequency 9 | 10 | IS_PD_22 = parse(pd.__version__) >= parse("2.2") 11 | 12 | 13 | class TestFrequency(TestCase): 14 | def test_basic(self): 15 | f = Frequency("5min") 16 | self.assertIsInstance(f, str) 17 | self.assertEqual(f.offset, to_offset("5min")) 18 | self.assertEqual(f.timedelta, timedelta(minutes=5)) 19 | 20 | def test_validate_bad(self): 21 | self.assertRaises(ValueError, Frequency.validate, None) 22 | self.assertRaises(ValueError, Frequency.validate, "foo") 23 | 24 | def test_validate_1D(self): 25 | f = Frequency("1D") 26 | self.assertEqual(Frequency.validate(f), f) 27 | self.assertEqual(Frequency.validate(str(f)), f) 28 | self.assertEqual(Frequency.validate(f.offset), f) 29 | self.assertEqual(Frequency.validate("1d"), f) 30 | self.assertEqual(Frequency.validate(Frequency("1d")), f) 31 | self.assertEqual(Frequency.validate(timedelta(days=1)), f) 32 | 33 | def test_validate_5T(self): 34 | if IS_PD_22: 35 | f = Frequency("5min") 36 | else: 37 | f = Frequency("5T") 38 | self.assertEqual(Frequency.validate(f), f) 39 | self.assertEqual(Frequency.validate(str(f)), f) 40 | self.assertEqual(Frequency.validate(f.offset), f) 41 | self.assertEqual(Frequency.validate("5T"), f) 42 | self.assertEqual(Frequency.validate("5min"), f) 43 | self.assertEqual(Frequency.validate(Frequency("5T")), f) 44 | self.assertEqual(Frequency.validate(Frequency("5min")), f) 45 | self.assertEqual(Frequency.validate(timedelta(minutes=5)), f) 46 | 47 | def test_validate_1M(self): 48 | if IS_PD_22: 49 | f = Frequency("1ME") 50 | else: 51 | f = Frequency("1M") 52 | self.assertEqual(Frequency.validate(f), f) 53 | self.assertEqual(Frequency.validate(str(f)), f) 54 | self.assertEqual(Frequency.validate(f.offset), f) 55 | self.assertEqual(Frequency.validate("1m"), f) 56 | self.assertEqual(Frequency.validate("1M"), f) 57 | self.assertEqual(Frequency.validate(Frequency("1m")), f) 58 | self.assertEqual(Frequency.validate(Frequency("1M")), f) 59 | 60 | def test_validate_1Y(self): 61 | if IS_PD_22: 62 | f = Frequency("1YE-DEC") 63 | else: 64 | f = Frequency("1A-DEC") 65 | self.assertEqual(Frequency.validate(f), f) 66 | self.assertEqual(Frequency.validate(str(f)), f) 67 | self.assertEqual(Frequency.validate(f.offset), f) 68 | 
self.assertEqual(Frequency.validate("1A-DEC"), f) 69 | self.assertEqual(Frequency.validate("1y"), f) 70 | self.assertEqual(Frequency.validate(Frequency("1A-DEC")), f) 71 | self.assertEqual(Frequency.validate(Frequency("1y")), f) 72 | -------------------------------------------------------------------------------- /ccflow/exttypes/polars.py: -------------------------------------------------------------------------------- 1 | import math 2 | from io import StringIO 3 | from typing import Annotated, Any 4 | 5 | import numpy as np 6 | import orjson 7 | import polars as pl 8 | from packaging import version 9 | from typing_extensions import Self 10 | 11 | __all__ = ("PolarsExpression",) 12 | 13 | 14 | class _PolarsExprPydanticAnnotation: 15 | """Provides a polars expressions from a string""" 16 | 17 | @classmethod 18 | def __get_pydantic_core_schema__(cls, source_type, handler): 19 | from pydantic_core import core_schema 20 | 21 | return core_schema.json_or_python_schema( 22 | json_schema=core_schema.no_info_plain_validator_function(function=cls._decode), 23 | python_schema=core_schema.no_info_plain_validator_function(function=cls._validate), 24 | serialization=core_schema.plain_serializer_function_ser_schema(cls._encode, return_schema=core_schema.dict_schema()), 25 | ) 26 | 27 | @staticmethod 28 | def _decode(obj): 29 | # We embed polars expressions as a dict, so we need to convert to a full json string first 30 | json_str = orjson.dumps(obj).decode("utf-8", "ignore") 31 | if version.parse(pl.__version__) < version.parse("1.0.0"): 32 | return pl.Expr.deserialize(StringIO(json_str)) 33 | else: 34 | # polars deserializes from a binary format by default. 35 | return pl.Expr.deserialize(StringIO(json_str), format="json") 36 | 37 | @staticmethod 38 | def _encode(obj, info=None): 39 | # obj.meta.serialize produces a string containing a dict, but we just want to return the dict. 40 | if version.parse(pl.__version__) < version.parse("1.0.0"): 41 | return orjson.loads(obj.meta.serialize()) 42 | else: 43 | # polars serializes into a binary format by default. 44 | return orjson.loads(obj.meta.serialize(format="json")) 45 | 46 | @classmethod 47 | def _validate(cls, value: Any) -> Self: 48 | if isinstance(value, pl.Expr): 49 | return value 50 | 51 | if isinstance(value, str): 52 | try: 53 | local_vars = {"col": pl.col, "c": pl.col, "np": np, "numpy": np, "pl": pl, "polars": pl, "math": math} 54 | try: 55 | import scipy as sp # Optional dependency. 
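# Illustrative expression strings this validator would accept (an assumption based on the
# aliases above): "pl.col('x') > 10", "c('a') + np.e * col('b')", "math.pi * polars.col('y')"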
56 | 57 | local_vars.update({"scipy": sp, "sp": sp, "sc": sp}) 58 | except ImportError: 59 | pass 60 | expression = eval(value, local_vars, {}) 61 | except Exception as ex: 62 | raise ValueError(f"Error encountered constructing expression - {str(ex)}") 63 | 64 | if not issubclass(type(expression), pl.Expr): 65 | raise ValueError(f"Supplied value '{value}' does not evaluate to a Polars expression") 66 | return expression 67 | 68 | raise ValueError(f"Supplied value '{value}' cannot be converted to a Polars expression") 69 | 70 | 71 | # Public annotated type for Polars expressions 72 | PolarsExpression = Annotated[pl.Expr, _PolarsExprPydanticAnnotation] 73 | -------------------------------------------------------------------------------- /ccflow/tests/test_validators.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import date, datetime, timedelta 3 | from unittest import TestCase 4 | from zoneinfo import ZoneInfo 5 | 6 | from ccflow.validators import eval_or_load_object, load_object, normalize_date, normalize_datetime, str_to_log_level 7 | 8 | 9 | class A: 10 | pass 11 | 12 | 13 | class TestValidators(TestCase): 14 | def test_normalize_date(self): 15 | c = date.today() 16 | self.assertEqual(normalize_date(c), c) 17 | self.assertEqual(normalize_date("0d"), c) 18 | c1 = date.today() - timedelta(1) 19 | self.assertEqual(normalize_date("-1d"), c1) 20 | 21 | self.assertEqual(normalize_date(datetime.now()), c) 22 | self.assertEqual(normalize_date(datetime.now().isoformat()), c) 23 | 24 | self.assertEqual(normalize_date("foo"), "foo") 25 | self.assertEqual(normalize_date(None), None) 26 | 27 | def test_normalize_datetime(self): 28 | today = datetime.today() 29 | now = datetime.now() 30 | c = datetime(today.year, today.month, today.day) 31 | 32 | self.assertEqual(normalize_datetime(c), c) 33 | self.assertEqual(normalize_datetime("0d"), c) 34 | 35 | c1 = c - timedelta(1) 36 | self.assertEqual(normalize_datetime("-1d"), c1) 37 | 38 | self.assertEqual(normalize_datetime(now), now) 39 | 40 | # check passthrough validation error 41 | self.assertEqual(normalize_datetime("foo"), "foo") 42 | self.assertEqual(normalize_datetime(None), None) 43 | 44 | # check dict 45 | self.assertEqual( 46 | normalize_datetime({"dt": now.isoformat(), "tz": "US/Hawaii"}), 47 | now.astimezone(tz=ZoneInfo("US/Hawaii")), 48 | ) 49 | # check list 50 | self.assertEqual( 51 | normalize_datetime([now.isoformat(), "US/Hawaii"]), 52 | now.astimezone(tz=ZoneInfo("US/Hawaii")), 53 | ) 54 | 55 | def test_load_object(self): 56 | self.assertEqual(load_object("ccflow.tests.test_validators.A"), A) 57 | self.assertIsNone(load_object(None)) 58 | self.assertEqual(load_object(5), 5) 59 | 60 | # Special case, if the object to load is string, you might want to load it from an object path 61 | # or you might want to provide it explicitly. 
Thus, if no object path found, return the value 62 | self.assertEqual(load_object("foo"), "foo") 63 | 64 | def test_eval_or_load_object(self): 65 | f1 = eval_or_load_object("lambda x: x+1") 66 | self.assertEqual(f1(2), 3) 67 | self.assertEqual(eval_or_load_object("A", {"locals": {"A": A}}), A) 68 | 69 | self.assertEqual(eval_or_load_object("ccflow.tests.test_validators.A"), A) 70 | self.assertIsNone(eval_or_load_object(None)) 71 | self.assertEqual(eval_or_load_object(5), 5) 72 | 73 | def test_str_to_log_level(self): 74 | self.assertEqual(str_to_log_level("INFO"), logging.INFO) 75 | self.assertEqual(str_to_log_level("debug"), logging.DEBUG) 76 | self.assertEqual(str_to_log_level(logging.WARNING), logging.WARNING) 77 | -------------------------------------------------------------------------------- /ccflow/tests/utils/test_compose_hydra.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from ccflow import BaseModel, ModelRegistry 6 | 7 | 8 | def _config_path(name: str) -> str: 9 | return os.path.join(os.path.dirname(__file__), "..", "config", name) 10 | 11 | 12 | @pytest.fixture 13 | def registry(): 14 | ModelRegistry.root().clear() 15 | yield ModelRegistry.root() 16 | ModelRegistry.root().clear() 17 | 18 | 19 | def test_hydra_conf_registry_reference_identity(): 20 | # Config supplies registry names for nested BaseModel arguments; identity should be preserved 21 | from ccflow.tests.data.python_object_samples import Consumer, SharedHolder, SharedModel 22 | 23 | path = _config_path("conf_from_python.yaml") 24 | cfg = ModelRegistry.root().create_config_from_path(path=path) 25 | ModelRegistry.root().load_config(cfg, overwrite=True) 26 | 27 | shared = ModelRegistry.root()["shared_model"] 28 | consumer = ModelRegistry.root()["consumer"] 29 | consumer_updated = ModelRegistry.root()["consumer_updated"] 30 | 31 | assert isinstance(shared, SharedModel) 32 | assert isinstance(consumer, Consumer) 33 | # Identity: consumer.shared should be the same instance as registry shared_model 34 | assert consumer.shared is shared 35 | # update_from_template preserves shared identity and applies field updates 36 | assert consumer_updated.shared is shared 37 | assert consumer_updated.tag == "consumer2" 38 | 39 | # Also check dict-returning from_python works and holder is constructed 40 | holder = ModelRegistry.root()["holder"] 41 | assert isinstance(holder, SharedHolder) 42 | assert isinstance(holder.cfg, dict) 43 | 44 | 45 | def test_update_from_template_shared_identity(): 46 | # Ensure shared sub-fields remain identical objects when alias-update is used 47 | from hydra.utils import instantiate 48 | 49 | ModelRegistry.root().clear() 50 | 51 | class Shared(BaseModel): 52 | val: int = 1 53 | 54 | class A(BaseModel): 55 | s: Shared 56 | x: int = 0 57 | 58 | # Register a base object and a shared object by name 59 | shared = Shared(val=5) 60 | base = A(s=shared, x=10) 61 | ModelRegistry.root().add("shared", shared, overwrite=True) 62 | ModelRegistry.root().add("base", base, overwrite=True) 63 | 64 | # Compose a config that uses update_from_template to update only a primitive field 65 | cfg = { 66 | "updated": { 67 | "_target_": "ccflow.compose.update_from_template", 68 | "base": {"_target_": "ccflow.compose.model_alias", "model_name": "base"}, 69 | "update": {"x": 99}, 70 | } 71 | } 72 | 73 | # Hydra instantiate calls the function, which uses model_copy(update=...) 
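# (Illustration, not from the source: the call is roughly equivalent to
# base.model_copy(update={"x": 99}); pydantic's model_copy is shallow by default, so the
# untouched `s` field keeps pointing at the same Shared instance, which the identity
# assertions below verify.)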
74 | obj = instantiate(cfg["updated"], _convert_="all") 75 | assert isinstance(obj, A) 76 | assert obj.x == 99 77 | # Ensure the shared sub-field refers to the same object as in the registry 78 | assert obj.s is shared 79 | 80 | # Additional: Using update_from_template without changing shared should preserve identity 81 | obj2 = instantiate(cfg["updated"], _convert_="all") 82 | assert obj2.s is shared 83 | -------------------------------------------------------------------------------- /ccflow/publishers/print.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, Generic 3 | 4 | import yaml 5 | from pydantic import Field, field_validator 6 | from typing_extensions import Literal, override 7 | 8 | from ..publisher import BasePublisher 9 | from ..serialization import orjson_dumps 10 | from ..utils import PydanticDictOptions, PydanticModelType, dict_to_model 11 | 12 | __all__ = ( 13 | "PrintPublisher", 14 | "LogPublisher", 15 | "PrintJSONPublisher", 16 | "PrintYAMLPublisher", 17 | "PrintPydanticJSONPublisher", 18 | ) 19 | 20 | 21 | def _orjson_dumps(data: Any, **kwargs): 22 | default = kwargs.pop("default", None) 23 | return orjson_dumps(data, default=default) 24 | 25 | 26 | class PrintPublisher(BasePublisher): 27 | """Print data using python standard print.""" 28 | 29 | @override 30 | def __call__(self) -> Any: 31 | if self.data is None: 32 | raise ValueError("'data' field must be set before publishing") 33 | print(self.data) 34 | return self.data 35 | 36 | 37 | class LogPublisher(BasePublisher): 38 | """Print data using python standard logging.""" 39 | 40 | level: Literal["debug", "info", "warning", "error", "critical"] = Field( 41 | "info", 42 | description="The log level to use for logging the data", 43 | ) 44 | logger_name: str = Field( 45 | "ccflow", 46 | description="The name of the logger to use for logging the data", 47 | ) 48 | 49 | @override 50 | def __call__(self) -> Any: 51 | if self.data is None: 52 | raise ValueError("'data' field must be set before publishing") 53 | 54 | logging.getLogger(self.logger_name).log( 55 | level=getattr(logging, self.level.upper()), 56 | msg=self.data, 57 | ) 58 | 59 | return self.data 60 | 61 | 62 | class PrintJSONPublisher(BasePublisher): 63 | """Print data in JSON format.""" 64 | 65 | kwargs: Dict[str, Any] = Field(default_factory=dict) 66 | 67 | @override 68 | def __call__(self) -> Any: 69 | print(_orjson_dumps(self.data, **self.kwargs)) 70 | return self.data 71 | 72 | 73 | class PrintYAMLPublisher(BasePublisher): 74 | """Print data in YAML format.""" 75 | 76 | kwargs: Dict[str, Any] = Field(default_factory=dict) 77 | 78 | @override 79 | def __call__(self) -> Any: 80 | print(yaml.dump(self.data, **self.kwargs)) 81 | return self.data 82 | 83 | 84 | class PrintPydanticJSONPublisher(BasePublisher, Generic[PydanticModelType]): 85 | """Print pydantic model as json. 
86 | 87 | See https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json 88 | """ 89 | 90 | data: PydanticModelType = None 91 | options: PydanticDictOptions = Field(default_factory=PydanticDictOptions) 92 | kwargs: Dict[str, Any] = Field(default_factory=dict) 93 | 94 | _normalize_data = field_validator("data", mode="before")(dict_to_model) 95 | 96 | @override 97 | def __call__(self) -> Any: 98 | kwargs = self.options.model_dump(mode="python") 99 | kwargs.update(self.kwargs) 100 | print(self.data.model_dump_json(**kwargs)) 101 | return self.data 102 | -------------------------------------------------------------------------------- /ccflow/result/pyarrow.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | import pyarrow as pa 4 | from pydantic import Field, model_validator 5 | 6 | from ..base import ResultBase 7 | from ..context import DateRangeContext 8 | from ..exttypes import ArrowTable 9 | 10 | __all__ = ( 11 | "ArrowResult", 12 | "ArrowDateRangeResult", 13 | ) 14 | 15 | if TYPE_CHECKING: 16 | from narwhals.stable.v1.typing import DataFrameT 17 | 18 | 19 | class ArrowResult(ResultBase): 20 | """Result that holds an Arrow Table.""" 21 | 22 | table: ArrowTable 23 | 24 | @model_validator(mode="wrap") 25 | def _validate(cls, v, handler, info): 26 | if not isinstance(v, ArrowResult) and not (isinstance(v, dict) and "table" in v): 27 | v = {"table": v} 28 | return handler(v) 29 | 30 | @property 31 | def df(self) -> "DataFrameT": 32 | """Return the Arrow table as a narwhals DataFrame.""" 33 | # For duck-type compatibility with NarwhalsDataFrameResult (but not for serialization) 34 | import narwhals.stable.v1 as nw 35 | 36 | return nw.from_native(self.table) 37 | 38 | 39 | class ArrowDateRangeResult(ArrowResult): 40 | """Extension of ArrowResult for representing a table over a date range that can be divided by date, 41 | such that generation of any sub-range of dates gives the same results as the original table filtered for those dates. 42 | 43 | Use of this ResultType assumes the data satisfies the condition above! 44 | 45 | This is useful for representing the results of queries of daily data. Furthermore, because the identity of the column 46 | containing the underlying date is known, it can be used to partition the data for future queries and caching. 47 | With the generic ArrowResult there is no way to know which column might correspond to the dates in the date range. 48 | """ 49 | 50 | date_col: str = Field(description="The column corresponding to the date of the record. It must align with the context dates.") 51 | context: DateRangeContext = Field( 52 | description="The context that generated the result. Validation will check that all the dates in the date_col are within the context range." 
53 | ) 54 | 55 | @model_validator(mode="after") 56 | def _validate_date_col(self): 57 | import pyarrow.compute 58 | 59 | if self.date_col not in self.table.column_names: 60 | raise ValueError("date_col must be a column in table") 61 | col_type = self.table.schema.field(self.date_col).type 62 | if not pa.types.is_date(col_type): 63 | raise ValueError(f"date_col must be of date type, not {col_type}") 64 | dates = self.table[self.date_col] 65 | if len(dates): 66 | min_date = pyarrow.compute.min(dates).as_py() 67 | max_date = pyarrow.compute.max(dates).as_py() 68 | start_date = self.context.start_date 69 | end_date = self.context.end_date 70 | if min_date < start_date: 71 | raise ValueError(f"The min date value ({min_date}) is smaller than the start date of the context ({start_date})") 72 | if max_date > end_date: 73 | raise ValueError(f"The max date value ({max_date}) is larger than the end date of the context ({end_date})") 74 | return self 75 | -------------------------------------------------------------------------------- /ccflow/tests/exttypes/test_polars.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import polars as pl 5 | import pytest 6 | import scipy 7 | from packaging import version 8 | from pydantic import TypeAdapter, ValidationError 9 | 10 | from ccflow import BaseModel 11 | from ccflow.exttypes.polars import PolarsExpression 12 | 13 | 14 | def test_expression_passthrough(): 15 | adapter = TypeAdapter(PolarsExpression) 16 | expression = pl.col("Col1") + pl.col("Col2") 17 | result = adapter.validate_python(expression) 18 | assert result.meta.serialize() == expression.meta.serialize() 19 | 20 | 21 | def test_expression_from_string(): 22 | adapter = TypeAdapter(PolarsExpression) 23 | expected_result = pl.col("Col1") + pl.col("Col2") 24 | expression = adapter.validate_python("pl.col('Col1') + pl.col('Col2')") 25 | assert expression.meta.serialize() == expected_result.meta.serialize() 26 | 27 | 28 | def test_expression_complex(): 29 | adapter = TypeAdapter(PolarsExpression) 30 | expected_result = pl.col("Col1") + (scipy.linalg.det(np.eye(2, dtype=int)) - 1) * math.pi * pl.col("Col2") + pl.col("Col2") 31 | expression = adapter.validate_python("col('Col1') + (sp.linalg.det(numpy.eye(2, dtype=int)) - 1 ) * math.pi * c('Col2') + polars.col('Col2')") 32 | assert expression.meta.serialize() == expected_result.meta.serialize() 33 | 34 | 35 | def test_validation_failure(): 36 | adapter = TypeAdapter(PolarsExpression) 37 | with pytest.raises(ValidationError): 38 | adapter.validate_python(None) 39 | with pytest.raises(ValidationError): 40 | adapter.validate_python("pl.DataFrame()") 41 | 42 | 43 | def test_validation_eval_failure(): 44 | adapter = TypeAdapter(PolarsExpression) 45 | with pytest.raises(ValidationError): 46 | adapter.validate_python("invalid_statement") 47 | 48 | 49 | def test_json_serialization_roundtrip(): 50 | adapter = TypeAdapter(PolarsExpression) 51 | expression = pl.col("Col1") + pl.col("Col2") 52 | json_result = adapter.dump_json(expression) 53 | if version.parse(pl.__version__) < version.parse("1.0.0"): 54 | assert json_result.decode("utf-8") == expression.meta.serialize() 55 | else: 56 | assert json_result.decode("utf-8") == expression.meta.serialize(format="json") 57 | 58 | expected_result = adapter.validate_json(json_result) 59 | assert expected_result.meta.serialize() == expression.meta.serialize() 60 | 61 | 62 | def test_model_field_and_dataframe_filter(): 63 | class
DummyExprModel(BaseModel): 64 | expr: PolarsExpression 65 | 66 | m = DummyExprModel(expr="pl.col('x') > 10") 67 | assert isinstance(m.expr, pl.Expr) 68 | 69 | df = pl.DataFrame({"x": [5, 10, 11, 20], "y": [1, 2, 3, 4]}) 70 | filtered = df.filter(m.expr) 71 | assert filtered.select("x").to_series().to_list() == [11, 20] 72 | 73 | 74 | def test_model_field_and_dataframe_with_columns(): 75 | class DummyExprModel(BaseModel): 76 | expr: PolarsExpression 77 | 78 | raw_expr = 'pl.col("x").rolling_max(window_size=2)' 79 | m = DummyExprModel(expr=raw_expr) 80 | assert isinstance(m.expr, pl.Expr) 81 | 82 | df = pl.DataFrame({"x": [5, 19, 17, 13, 8, 20], "y": [1, 2, 3, 4, 5, 6]}) 83 | transformed = df.select(m.expr) 84 | assert transformed.to_series().to_list() == [None, 19, 19, 17, 13, 20] 85 | -------------------------------------------------------------------------------- /ccflow/utils/arrow.py: -------------------------------------------------------------------------------- 1 | """Various arrow tools""" 2 | 3 | from typing import Any, Dict, Optional 4 | 5 | import orjson 6 | import pyarrow as pa 7 | 8 | from ccflow.serialization import orjson_dumps 9 | 10 | __all__ = ( 11 | "convert_decimal_types_to_float", 12 | "convert_large_types", 13 | "add_field_metadata", 14 | "get_field_metadata", 15 | ) 16 | 17 | 18 | def convert_decimal_types_to_float(table: pa.table, target_type: Optional[pa.DataType] = None) -> pa.Table: 19 | """Converts decimal types to float or other user-provided type 20 | 21 | Args: 22 | table: The table to convert schema for 23 | target_type: The target type to convert decimal types to. If not supplied, will default to Float64 24 | 25 | Returns: 26 | A pyarrow table whose decimal types have been converted to the specified target type 27 | """ 28 | 29 | if target_type is None: 30 | target_type = pa.float64() 31 | 32 | fields = [] 33 | for field in table.schema: 34 | if pa.types.is_decimal(field.type): 35 | new_field = pa.field(field.name, target_type, field.nullable) 36 | else: 37 | new_field = field 38 | fields.append(new_field) 39 | 40 | schema = pa.schema(fields) 41 | return table.cast(schema) 42 | 43 | 44 | def convert_large_types(table: pa.Table) -> pa.Table: 45 | """Converts the large types to their regular counterparts in pyarrow. 46 | 47 | This is necessary because polars always uses large list, but pyarrow 48 | recommends using the regular one, as it is more accepted (e.g. 
by csp) 49 | https://arrow.apache.org/docs/python/generated/pyarrow.large_list.html 50 | """ 51 | fields = [] 52 | for field in table.schema: 53 | if pa.types.is_large_list(field.type): 54 | new_field = pa.field(field.name, pa.list_(field.type.value_type), field.nullable) 55 | elif pa.types.is_large_binary(field.type): 56 | new_field = pa.field(field.name, pa.binary(), field.nullable) 57 | elif pa.types.is_large_string(field.type): 58 | new_field = pa.field(field.name, pa.string(), field.nullable) 59 | else: 60 | new_field = field 61 | fields.append(new_field) 62 | schema = pa.schema(fields) 63 | return table.cast(schema) 64 | 65 | 66 | def add_field_metadata(table: pa.Table, metadata: Dict[str, Any]): 67 | """Helper function to add column-level meta data to an arrow table for multiple columns at once.""" 68 | # There does not seem to be a pyarrow function to do this easily 69 | new_schema = [] 70 | for field in table.schema: 71 | if field.name in metadata: 72 | field_metadata = {k: orjson_dumps(v) for k, v in metadata[field.name].items()} 73 | new_field = field.with_metadata(field_metadata) 74 | else: 75 | new_field = field 76 | new_schema.append(new_field) 77 | return table.cast(pa.schema(new_schema)) 78 | 79 | 80 | def get_field_metadata(table: pa.Table) -> Dict[str, Any]: 81 | """Helper function to retrieve all the field level metadata in an arrow table.""" 82 | metadata = {} 83 | for field in table.schema: 84 | raw_metadata = field.metadata 85 | if raw_metadata: 86 | metadata[field.name] = {k.decode("UTF-8"): orjson.loads(v) for k, v in raw_metadata.items()} 87 | return metadata 88 | -------------------------------------------------------------------------------- /docs/wiki/Local-Development-Setup.md: -------------------------------------------------------------------------------- 1 | - [Step 1: Build from Source](#step-1-build-from-source) 2 | - [Step 2: Configuring Git and GitHub for Development](#step-2-configuring-git-and-github-for-development) 3 | - [Create your fork](#create-your-fork) 4 | - [Configure remotes](#configure-remotes) 5 | - [Authenticating with GitHub](#authenticating-with-github) 6 | - [Guidelines](#guidelines) 7 | 8 | ## Step 1: Build from Source 9 | 10 | To work on `ccflow`, you are going to need to build it from source. See 11 | [Build from Source](Build-from-Source) for 12 | detailed build instructions. 13 | 14 | Once you've built `ccflow` from a `git` clone, you will also need to 15 | configure `git` and your GitHub account for `ccflow` development. 16 | 17 | ## Step 2: Configuring Git and GitHub for Development 18 | 19 | ### Create your fork 20 | 21 | The first step is to create a personal fork of `ccflow`. To do so, click 22 | the "fork" button at https://github.com/Point72/ccflow, or just navigate 23 | [here](https://github.com/Point72/ccflow/fork) in your browser. Set the 24 | owner of the repository to your personal GitHub account if it is not 25 | already set that way and click "Create fork". 26 | 27 | ### Configure remotes 28 | 29 | Next, you should set some names for the `git` remotes corresponding to 30 | main Point72 repository and your fork. See the [GitHub Docs](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/configuring-a-remote-repository-for-a-fork) for more information. 
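For example, one common convention (the remote names below are illustrative, not mandated by the project) is to keep the main repository as `upstream` and your fork as `origin`:

```bash
# Point "upstream" at the main repository and "origin" at your fork
# (replace <your-username> with your GitHub username)
git remote rename origin upstream
git remote add origin git@github.com:<your-username>/ccflow.git
git remote -v  # verify the configuration
```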
31 | 32 | ### Authenticating with GitHub 33 | 34 | If you have not already configured `ssh` access to GitHub, you can find 35 | instructions to do so 36 | [here](https://docs.github.com/en/authentication/connecting-to-github-with-ssh), 37 | including instructions to create an SSH key if you have not done 38 | so. Authenticating with SSH is usually the easiest route. If you are working in 39 | an environment that does not allow SSH connections to GitHub, you can look into 40 | [configuring a hardware 41 | passkey](https://docs.github.com/en/authentication/authenticating-with-a-passkey/about-passkeys) 42 | or adding a [personal access 43 | token](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) 44 | to avoid the need to type in your password every time you push to your fork. 45 | 46 | ## Guidelines 47 | 48 | After developing a change locally, ensure that both [lints](Build-from-Source#lint-and-autoformat) and [tests](Build-from-Source#testing) pass. Commits should be squashed into logical units, and all commits must be signed (e.g. with the `-s` git flag). We require a [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) for all contributions. 49 | 50 | If your work is still in progress, open a [draft pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests#draft-pull-requests). Otherwise, open a normal pull request. It might take a few days for a maintainer to review and provide feedback, so please be patient. If a maintainer asks for changes, please make the requested changes and squash your commits if necessary. If everything looks good to go, a maintainer will approve and merge your changes for inclusion in the next release. 51 | 52 | Please note that non-substantive changes, large changes made without prior discussion, and the like are not accepted, and such pull requests may be closed. 53 | -------------------------------------------------------------------------------- /ccflow/validators.py: -------------------------------------------------------------------------------- 1 | """This module contains common validators.""" 2 | 3 | import logging 4 | from datetime import date, datetime 5 | from typing import Any, Dict, Optional 6 | from zoneinfo import ZoneInfo 7 | 8 | import pandas as pd 9 | from pydantic import TypeAdapter, ValidationError 10 | 11 | from .exttypes import PyObjectPath 12 | 13 | _DatetimeAdapter = TypeAdapter(datetime) 14 | 15 | __all__ = ( 16 | "normalize_date", 17 | "normalize_datetime", 18 | "load_object", 19 | "eval_or_load_object", 20 | "str_to_log_level", 21 | ) 22 | 23 | 24 | def normalize_date(v: Any) -> Any: 25 | """Validator that will convert string offsets to date based on today, and convert datetime to date.""" 26 | if isinstance(v, str): # Check case where it's an offset 27 | try: 28 | timestamp = pd.tseries.frequencies.to_offset(v) + date.today() 29 | return timestamp.date() 30 | except ValueError: 31 | pass 32 | # Convert anything that can be converted to a datetime into a date, via datetime. 33 | # This is not normally allowed by pydantic.
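    # For example, an ISO string like "2024-01-02T09:30:00" validates to a datetime, which is then truncated to date(2024, 1, 2).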
34 | try: 35 | v = _DatetimeAdapter.validate_python(v) 36 | if isinstance(v, datetime): 37 | return v.date() 38 | except ValidationError: 39 | pass 40 | return v 41 | 42 | 43 | def normalize_datetime(v: Any) -> Any: 44 | """Validator that will convert string offsets to datetime based on today, and normalize dict or list inputs (a datetime value plus an optional timezone name) to timezone-aware datetimes.""" 45 | if isinstance(v, str): # Check case where it's an offset 46 | try: 47 | return (pd.tseries.frequencies.to_offset(v) + date.today()).to_pydatetime() 48 | except ValueError: 49 | pass 50 | if isinstance(v, dict): 51 | # e.g. DatetimeContext object, {"dt": datetime(...)} 52 | dt = list(v.values())[0] 53 | tz = list(v.values())[1] if len(v) > 1 else None 54 | elif isinstance(v, list): 55 | dt = v[0] 56 | tz = v[1] if len(v) > 1 else None 57 | else: 58 | dt = v 59 | tz = None 60 | try: 61 | dt = _DatetimeAdapter.validate_python(dt) 62 | if tz and isinstance(tz, str): 63 | tz = ZoneInfo(tz) 64 | if tz: 65 | dt = dt.astimezone(tz) 66 | return dt 67 | except ValidationError: 68 | return v 69 | 70 | 71 | def load_object(v: Any) -> Any: 72 | """Validator that loads an object from an import path if a string is provided.""" 73 | if isinstance(v, str): 74 | try: 75 | return PyObjectPath(v).object 76 | except (ImportError, ValidationError): 77 | pass 78 | return v 79 | 80 | 81 | def eval_or_load_object(v: Any, values: Optional[Dict[str, Any]] = None) -> Any: 82 | """Validator that evaluates or loads an object from an import path if a string is provided. 83 | 84 | Useful for fields that could be either inline expressions (e.g. lambdas) or importable callables. 85 | """ 86 | if isinstance(v, str): 87 | try: 88 | return eval(v, (values or {}).get("locals", {})) 89 | except NameError: 90 | return PyObjectPath(v).object 91 | return v 92 | 93 | 94 | def str_to_log_level(v: Any) -> Any: 95 | """Validator to convert a string to a log level.""" 96 | if isinstance(v, str): 97 | return getattr(logging, v.upper()) 98 | return v 99 | -------------------------------------------------------------------------------- /ccflow/object_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from pydantic import ConfigDict, PrivateAttr, model_validator 4 | from pydantic.fields import Field 5 | 6 | from .base import BaseModel 7 | from .exttypes.pyobjectpath import PyObjectPath 8 | 9 | __all__ = ( 10 | "ObjectConfig", 11 | "LazyObjectConfig", 12 | ) 13 | 14 | 15 | class ObjectConfig(BaseModel): # TODO: Generic model version for type checking 16 | """Small class to help wrap an arbitrary python object as a BaseModel. 17 | 18 | This allows such objects to be registered by name in the registry, 19 | without having to define a custom pydantic wrapper for them.
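    Example (an illustrative sketch; the dotted path is hypothetical, and extra keyword fields are uplifted into object_kwargs by the validator below):

        config = ObjectConfig(
            object_type="some.module.SomeClass",
            object_kwargs={"p": "foo"},
        )
        obj = config.object  # the SomeClass instance, eagerly constructed at validation time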
20 | """ 21 | 22 | model_config = ConfigDict( 23 | ignored_types=(property,), 24 | extra="allow", 25 | frozen=True, # Because we cache _object 26 | ) 27 | object_type: PyObjectPath = Field( 28 | None, 29 | description="The type of the object this model wraps.", 30 | ) 31 | object_kwargs: Dict[str, Any] = {} 32 | _object: Any = PrivateAttr(None) 33 | 34 | @model_validator(mode="wrap") 35 | def _kwarg_validator(cls, values, handler, info): 36 | if isinstance(values, dict): 37 | # Uplift extra fields into object_kwargs 38 | obj_kwargs = values.get("object_kwargs", {}) 39 | for field in list(values): 40 | if field not in cls.model_fields: 41 | obj_kwargs[field] = values.pop(field) 42 | values["object_kwargs"] = obj_kwargs 43 | return handler(values) 44 | 45 | def __getstate__(self): 46 | """Override pickling to ignore _object, so that configs can be pickled even if the underlying object cannot.""" 47 | state_dict = self.__dict__.copy() 48 | state_dict.pop("_object", None) 49 | return { 50 | "__dict__": state_dict, 51 | "__pydantic_fields_set__": self.__pydantic_fields_set__, 52 | "__pydantic_extra__": self.__pydantic_extra__, 53 | "__pydantic_private__": self.__pydantic_private__, 54 | } 55 | 56 | def __setstate__(self, state): 57 | super().__setstate__(state) 58 | self._object = self.object_type.object(**self.object_kwargs) 59 | 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(*args, **kwargs) 62 | # Eagerly construct object. This way, if it fails to construct, it will be known immediately. 63 | self._object = self.object_type.object(**self.object_kwargs) 64 | 65 | @property 66 | def object(self): 67 | """Returns the pre-constructed object corresponding to the config.""" 68 | return self._object 69 | 70 | 71 | class LazyObjectConfig(ObjectConfig): 72 | """Like ObjectConfig, but the object is constructed lazily (on first access). 73 | 74 | One loses upfront validation that it's a valid config, but potentially gains performance benefits 75 | of not constructing unneeded objects. 
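    Example (an illustrative sketch; the dotted path is hypothetical, and invalid kwargs only fail on first access):

        config = LazyObjectConfig(
            object_type="some.module.SomeClass",
            object_kwargs={"p": "foo"},
        )
        obj = config.object  # the object is constructed here, on first access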
76 | """ 77 | 78 | def __init__(self, *args, **kwargs): 79 | # Deliberately skip ObjectConfig.__init__ so that the object is not constructed eagerly. 80 | super(ObjectConfig, self).__init__(*args, **kwargs) 81 | 82 | @property 83 | def object(self): 84 | """Returns the lazily-constructed object corresponding to the config.""" 85 | if self._object is None: 86 | self._object = self.object_type.object(**self.object_kwargs) 87 | return self._object 88 | 89 | def __setstate__(self, state): 90 | super().__setstate__(state) 91 | self._object = None 92 | -------------------------------------------------------------------------------- /ccflow/tests/test_object_config.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from unittest import TestCase 3 | 4 | import pytest 5 | from pydantic import ValidationError 6 | 7 | from ccflow import BaseModel, LazyObjectConfig, ModelRegistry, ObjectConfig 8 | 9 | 10 | class MyClass: 11 | def __init__(self, p="p", q=10.0): 12 | self.p = p 13 | self.q = q 14 | 15 | 16 | class ContainerClass(BaseModel): 17 | config: ObjectConfig 18 | 19 | 20 | class TestObjectConfig(TestCase): 21 | def tearDown(self): 22 | ModelRegistry.root().clear() 23 | 24 | def test_construction(self): 25 | config = ObjectConfig( 26 | object_type="ccflow.tests.test_object_config.MyClass", 27 | object_kwargs=dict(p="foo"), 28 | ) 29 | self.assertIsInstance(config.object, MyClass) 30 | self.assertEqual(config.object.p, "foo") 31 | 32 | with pytest.raises((TypeError, ValidationError)): 33 | config.p = "bar" 34 | 35 | with pytest.raises(TypeError): 36 | config = ObjectConfig( 37 | object_type="ccflow.tests.test_object_config.MyClass", 38 | object_kwargs=dict(garbage="foo"), 39 | ) 40 | 41 | def test_lazy_construction(self): 42 | config = LazyObjectConfig( 43 | object_type="ccflow.tests.test_object_config.MyClass", 44 | object_kwargs=dict(p="foo"), 45 | ) 46 | self.assertIsInstance(config.object, MyClass) 47 | self.assertEqual(config.object.p, "foo") 48 | 49 | with pytest.raises((TypeError, ValidationError)): 50 | config.p = "bar" 51 | 52 | config = LazyObjectConfig( 53 | object_type="ccflow.tests.test_object_config.MyClass", 54 | object_kwargs=dict(garbage="foo"), 55 | ) 56 | with pytest.raises(TypeError): 57 | config.object 58 | 59 | def test_validation(self): 60 | for Config in [ObjectConfig, LazyObjectConfig]: 61 | config = Config( 62 | object_type="ccflow.tests.test_object_config.MyClass", 63 | p="foo", 64 | q=5, 65 | ) 66 | self.assertIsInstance(config.object, MyClass) 67 | 68 | # Check the result is as expected 69 | self.assertEqual(config.object.p, "foo") 70 | self.assertEqual(config.object.q, 5) 71 | 72 | def test_pickling(self): 73 | for Config in [ObjectConfig, LazyObjectConfig]: 74 | config = Config( 75 | object_type="ccflow.tests.test_object_config.MyClass", 76 | object_kwargs=dict(p="foo", q=5), 77 | ) 78 | # Insert pickling step 79 | config = pickle.loads(pickle.dumps(config)) 80 | 81 | self.assertIsInstance(config.object, MyClass) 82 | self.assertEqual(config.object.p, "foo") 83 | self.assertEqual(config.object.q, 5) 84 | 85 | def test_registration(self): 86 | """Test that validators on config objects don't interfere with BaseModel validators""" 87 | for Config in [ObjectConfig, LazyObjectConfig]: 88 | config = Config( 89 | object_type="ccflow.tests.test_object_config.MyClass", 90 | object_kwargs=dict(p="foo", q=5), 91 | ) 92 | r = ModelRegistry.root() 93 | r.add("foo", config, overwrite=True) 94 | 95 | cc = ContainerClass(config="foo") 96 | self.assertEqual(cc.config, config) 97 |
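# --- Illustrative note (not a test) ---------------------------------------
# Because of the wrap-validator on ObjectConfig, extra keyword fields are
# uplifted into object_kwargs, so the following two configs are equivalent:
#
#   ObjectConfig(object_type="ccflow.tests.test_object_config.MyClass", p="foo", q=5)
#   ObjectConfig(object_type="ccflow.tests.test_object_config.MyClass", object_kwargs={"p": "foo", "q": 5})
#
# This is convenient when configs are written as flat mappings (e.g. in YAML).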
-------------------------------------------------------------------------------- /docs/wiki/contribute/Local-Development-Setup.md: -------------------------------------------------------------------------------- 1 | ## Table of Contents 2 | 3 | - [Table of Contents](#table-of-contents) 4 | - [Step 1: Build from Source](#step-1-build-from-source) 5 | - [Step 2: Configuring Git and GitHub for Development](#step-2-configuring-git-and-github-for-development) 6 | - [Create your fork](#create-your-fork) 7 | - [Configure remotes](#configure-remotes) 8 | - [Authenticating with GitHub](#authenticating-with-github) 9 | - [Guidelines](#guidelines) 10 | 11 | ## Step 1: Build from Source 12 | 13 | To work on `ccflow`, you are going to need to build it from source. See 14 | [Build from Source](Build-from-Source) for 15 | detailed build instructions. 16 | 17 | Once you've built `ccflow` from a `git` clone, you will also need to 18 | configure `git` and your GitHub account for `ccflow` development. 19 | 20 | ## Step 2: Configuring Git and GitHub for Development 21 | 22 | ### Create your fork 23 | 24 | The first step is to create a personal fork of `ccflow`. To do so, click 25 | the "fork" button at https://github.com/Point72/ccflow, or just navigate 26 | [here](https://github.com/Point72/ccflow/fork) in your browser. Set the 27 | owner of the repository to your personal GitHub account if it is not 28 | already set that way and click "Create fork". 29 | 30 | ### Configure remotes 31 | 32 | Next, you should set some names for the `git` remotes corresponding to the 33 | main Point72 repository and your fork. See the [GitHub Docs](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/configuring-a-remote-repository-for-a-fork) for more information. 34 | 35 | ### Authenticating with GitHub 36 | 37 | If you have not already configured `ssh` access to GitHub, you can find 38 | instructions to do so 39 | [here](https://docs.github.com/en/authentication/connecting-to-github-with-ssh), 40 | including instructions to create an SSH key if you have not done 41 | so. Authenticating with SSH is usually the easiest route. If you are working in 42 | an environment that does not allow SSH connections to GitHub, you can look into 43 | [configuring a hardware 44 | passkey](https://docs.github.com/en/authentication/authenticating-with-a-passkey/about-passkeys) 45 | or adding a [personal access 46 | token](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) 47 | to avoid the need to type in your password every time you push to your fork. 48 | 49 | ## Guidelines 50 | 51 | After developing a change locally, ensure that both [lints](Build-from-Source#lint-and-autoformat) and [tests](Build-from-Source#testing) pass. Commits should be squashed into logical units, and all commits must be signed (e.g. with the `-s` git flag). We require a [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) for all contributions. 52 | 53 | If your work is still in progress, open a [draft pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests#draft-pull-requests). Otherwise, open a normal pull request. It might take a few days for a maintainer to review and provide feedback, so please be patient. If a maintainer asks for changes, please make the requested changes and squash your commits if necessary.
If everything looks good to go, a maintainer will approve and merge your changes for inclusion in the next release. 54 | 55 | Please note that non-substantive changes, large changes made without prior discussion, and the like are not accepted, and such pull requests may be closed. 56 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at OpenSource@point72.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ######### 2 | # BUILD # 3 | ######### 4 | .PHONY: develop requirements build install 5 | 6 | develop: ## install dependencies and build library 7 | uv pip install -e .[develop] 8 | 9 | requirements: ## install prerequisite python build requirements 10 | python -m pip install --upgrade pip toml 11 | python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["build-system"]["requires"]))'` 12 | python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print(" ".join(c["project"]["optional-dependencies"]["develop"]))'` 13 | 14 | build: ## build the python library 15 | python -m build -n 16 | 17 | install: ## install library 18 | uv pip install . 19 | 20 | ######### 21 | # LINTS # 22 | ######### 23 | .PHONY: lint-py lint-docs fix-py fix-docs lint lints fix format 24 | 25 | lint-py: ## lint python with ruff 26 | python -m ruff check ccflow 27 | python -m ruff format --check ccflow 28 | 29 | lint-docs: ## lint docs with mdformat and codespell 30 | python -m mdformat --check README.md docs/wiki/ 31 | python -m codespell_lib README.md docs/wiki/ 32 | 33 | fix-py: ## autoformat python code with ruff 34 | python -m ruff check --fix ccflow 35 | python -m ruff format ccflow 36 | 37 | fix-docs: ## autoformat docs with mdformat and codespell 38 | python -m mdformat README.md docs/wiki/ 39 | python -m codespell_lib --write README.md docs/wiki/ 40 | 41 | lint: lint-py lint-docs ## run all linters 42 | lints: lint 43 | fix: fix-py fix-docs ## run all autoformatters 44 | format: fix 45 | 46 | ################ 47 | # Other Checks # 48 | ################ 49 | .PHONY: check-manifest checks check 50 | 51 | check-manifest: ## check python sdist manifest with check-manifest 52 | check-manifest -v 53 | 54 | checks: check-manifest 55 | 56 | # Alias 57 | check: checks 58 | 59 | ######### 60 | # TESTS # 61 | ######### 62 | .PHONY: test coverage tests 63 | 64 | test: ## run python tests 65 | python -m pytest -v ccflow/tests 66 | 67 | coverage: ## run tests and collect test coverage 68 | python -m pytest -v ccflow/tests --cov=ccflow --cov-report term-missing --cov-report xml 69 | 70 | # Alias 71 | tests: test 72 | 73 | ########### 74 | # VERSION # 75 | ########### 76 | .PHONY: show-version patch minor major 77 | 78 | show-version: ## show current library version 79 | @bump-my-version show current_version 80 | 81 | patch: ## bump a patch version 82 | @bump-my-version bump patch 83 | 84 | minor: ## bump a minor version 85 | @bump-my-version bump minor 86 | 87 | major: ## bump a major version 88 | @bump-my-version bump major 89 | 90 | ######## 91 | # DIST # 92 | ######## 93 | .PHONY: dist dist-build dist-check
publish 94 | 95 | dist-build: ## build python dists 96 | python -m build -w -s 97 | 98 | dist-check: ## run python dist checker with twine 99 | python -m twine check dist/* 100 | 101 | dist: clean dist-build dist-check ## build all dists 102 | 103 | publish: dist ## publish python assets 104 | 105 | ######### 106 | # CLEAN # 107 | ######### 108 | .PHONY: deep-clean clean 109 | 110 | deep-clean: ## clean everything from the repository 111 | git clean -fdx 112 | 113 | clean: ## clean the repository 114 | rm -rf .coverage coverage cover htmlcov logs build dist *.egg-info 115 | 116 | ############################################################################################ 117 | 118 | .PHONY: help 119 | 120 | # Thanks to Francoise at marmelab.com for this 121 | .DEFAULT_GOAL := help 122 | help: 123 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 124 | 125 | print-%: 126 | @echo '$*=$($*)' 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
5 |