├── trufflehog-ignore.txt
├── tests
│   ├── test_rich_progress_hooks.py
│   ├── conftest.py
│   └── utilities
│       ├── test_formatting_utils.py
│       └── test_catalog_utils.py
├── requirements.txt
├── kedro_rich
│   ├── __init__.py
│   ├── constants.py
│   ├── rich_init.py
│   ├── rich_cli.py
│   ├── utilities
│   │   ├── formatting_utils.py
│   │   ├── catalog_utils.py
│   │   └── kedro_override_utils.py
│   └── rich_progress_hooks.py
├── static
│   └── list-datasets.png
├── setup.cfg
├── .coveragerc
├── pyproject.toml
├── test_requirements.txt
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── ISSUE_TEMPLATE
│   │   ├── feature-request.md
│   │   └── bug-report.md
│   └── workflows
│       └── ci.yml
├── Makefile
├── setup.py
├── .gitignore
├── .pre-commit-config.yaml
└── README.md
/trufflehog-ignore.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/test_rich_progress_hooks.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | kedro~=0.17.3
2 | rich-click~=1.2.1
3 | rich~=11.2.0
4 |
--------------------------------------------------------------------------------
/kedro_rich/__init__.py:
--------------------------------------------------------------------------------
1 | """Rich plugin to make your Kedro snazzy"""
2 |
3 | __version__ = "0.0.14"
4 |
--------------------------------------------------------------------------------
/static/list-datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datajoely/kedro-rich/HEAD/static/list-datasets.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E126,E203,E231,E266,E501,W503
3 | max-line-length = 88
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
6 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | fail_under=100
3 | show_missing=True
4 | omit = *tests*
5 | exclude_lines =
6 | pragma: no cover
7 | raise NotImplementedError
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | multi_line_output = 3
3 | include_trailing_comma = true
4 | force_grid_wrap = 0
5 | use_parentheses = true
6 | line_length = 88
7 | known_third_party = "kedro"
8 |
9 | [tool.black]
10 |
--------------------------------------------------------------------------------
/test_requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | bandit>=1.6.2, <2.0
3 | black==v19.10.b0
4 | flake8
5 | isort>=4.3.21, <5.0
6 | pre-commit>=1.17.0, <2.0
7 | psutil==5.6.6
8 | pylint>=2.5.2, <3.0
9 | pytest
10 | pytest-cov
11 | pytest-mock
12 | trufflehog>=2.1.0, <3.0
13 | wheel
14 |
--------------------------------------------------------------------------------
/kedro_rich/constants.py:
--------------------------------------------------------------------------------
1 | """Simple module for maintaining configuration"""
2 |
3 | KEDRO_RICH_PROGRESS_ENV_VAR_KEY = "KEDRO_RICH_SHOW_PROGRESS"
4 | KEDRO_RICH_SHOW_DATASET_PROGRESS = True
5 |
6 | KEDRO_RICH_LOGGING_HANDLER = {
7 | "class": "rich.logging.RichHandler",
8 | "level": "INFO",
9 | "markup": True,
10 | "log_time_format": "[%X]",
11 | }
12 |
13 | KEDRO_RICH_CATALOG_LIST_THRESHOLD = 10
14 |
--------------------------------------------------------------------------------
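
Note: KEDRO_RICH_LOGGING_HANDLER above is a standard `logging` dictConfig handler entry; the plugin splices it into the project's logging config (see `kedro_override_utils.py`), which Kedro then applies via `logging.config.dictConfig`. A minimal standalone sketch of the same mechanism, assuming only that `rich` is installed:

    import logging.config

    from kedro_rich.constants import KEDRO_RICH_LOGGING_HANDLER

    logging.config.dictConfig(
        {
            "version": 1,
            "handlers": {"console": KEDRO_RICH_LOGGING_HANDLER},
            "root": {"level": "INFO", "handlers": ["console"]},
        }
    )
    # markup=True in the handler entry means rich tags are rendered
    logging.getLogger(__name__).info("Hello [bold magenta]kedro-rich[/]!")
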
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
3 |
4 | ## Development notes
5 |
6 |
7 | ## Checklist
8 |
9 | - [ ] Read the [contributing](https://github.com/kedro-org/kedro/blob/main/CONTRIBUTING.md) guidelines
10 | - [ ] Opened this PR as a 'Draft Pull Request' if it is work-in-progress
11 | - [ ] Updated the documentation to reflect the code changes
12 | - [ ] Added a description of this change in the [`RELEASE.md`](https://github.com/kedro-org/kedro/blob/main/RELEASE.md) file
13 | - [ ] Added tests to cover my changes
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Let us know if you have a feature request or enhancement
4 | title: ''
5 | labels: 'Issue: Feature Request'
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Description
11 | Is your feature request related to a problem? A clear and concise description of what the problem is: "I'm always frustrated when ..."
12 |
13 | ## Context
14 | Why is this change important to you? How would you use it? How can it benefit other users?
15 |
16 | ## Possible Implementation
17 | (Optional) Suggest an idea for implementing the addition or change.
18 |
19 | ## Possible Alternatives
20 | (Optional) Describe any alternative solutions or features you've considered.
21 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: If something isn't working
4 | title: ''
5 | labels: 'Issue: Bug Report'
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Description
11 | Short description of the problem here.
12 |
13 | ## Context
14 | How has this bug affected you? What were you trying to accomplish?
15 |
16 | ## Steps to Reproduce
17 | 1. [First Step]
18 | 2. [Second Step]
19 | 3. [And so on...]
20 |
21 | ## Expected Result
22 | Tell us what should happen.
23 |
24 | ## Actual Result
25 | Tell us what happens instead.
26 |
27 | ```
28 | -- If you received an error, place it here.
29 | ```
30 |
31 | ```
32 | -- Separate them if you have more than one.
33 | ```
34 |
35 | ## Your Environment
36 | Include as many relevant details about the environment in which you experienced the bug:
37 |
38 | * Kedro-rich version used (`pip show kedro-rich`):
39 | * Kedro version used (`pip show kedro` or `kedro -V`):
40 | * Python version used (`python -V`):
41 | * Operating system and version:
42 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from kedro.io.data_catalog import DataCatalog
2 | from kedro.io import MemoryDataSet
3 | from kedro.pipeline import Pipeline, node
4 | from kedro.extras.datasets.pickle import PickleDataSet
5 | import pytest
6 |
7 |
8 | @pytest.fixture()
9 | def data_catalog_fixture() -> DataCatalog:
10 |
11 | return DataCatalog(
12 | data_sets={
13 | "dataset_1": PickleDataSet(filepath="test"),
14 | "dataset_2": MemoryDataSet(),
15 | "dataset_3": PickleDataSet(filepath="test"),
16 | },
17 | feed_dict={"params.modelling_params": {"test_size": 0.3, "split_ratio": 0.7}},
18 | )
19 |
20 |
21 | @pytest.fixture()
22 | def pipeline_fixture() -> Pipeline:
23 | return Pipeline(
24 | nodes=[
25 | node(func=lambda x: x, inputs="dataset_1", outputs="dataset_2"),
26 | node(
27 | func=lambda x: x,
28 | inputs="dataset_2",
29 | outputs="dataset_2.5",
30 | namespace="test",
31 | ),
32 | node(func=lambda x: x, inputs="dataset_2.5", outputs="dataset_3"),
33 | ]
34 | )
35 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: kedro_rich_run
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 |
7 | jobs:
8 | kedro_rich:
9 | name: kedro_rich_run_and_test
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Check out
14 | uses: actions/checkout@main
15 | with:
16 | ref: main
17 |
18 | - uses: actions/setup-python@v1
19 | with:
20 | python-version: "3.8.x"
21 |
22 | - name: Install dependencies
23 | run: |
24 | pip install -r requirements.txt
25 |
26 | - name: Pull spaceflights and install kedro-rich in editable mode
27 | run: make test-project
28 |
29 | - name: Kedro run (sequential)
30 | run: cd test_project; kedro run
31 |
32 | - name: Kedro run (parallel)
33 | run: cd test_project; kedro run --parallel
34 |
35 | - name: Kedro catalog list (json)
36 | run: cd test_project; kedro catalog list --format=json
37 |
38 | - name: Kedro catalog list (yaml)
39 | run: cd test_project; kedro catalog list --format=yaml
40 |
41 | - name: Kedro catalog list (table)
42 | run: cd test_project; kedro catalog list --format=table
43 |
44 | - name: Run PyTest
45 | run: make test
46 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | package:
2 | rm -Rf dist
3 | python setup.py sdist bdist_wheel
4 |
5 | install: package
6 | pip uninstall kedro-rich -y
7 | pip install -U dist/*.whl
8 |
9 | dev-install:
10 | pip install -e .
11 |
12 | install-pip-setuptools:
13 | python -m pip install -U pip setuptools wheel
14 |
15 | lint:
16 | pre-commit run -a --hook-stage manual
17 |
18 | test:
19 | pytest -vv tests --cov
20 |
21 | clean:
22 | rm -rf build dist pip-wheel-metadata .pytest_cache
23 | find . -regex ".*/__pycache__" -exec rm -rf {} +
24 | find . -regex ".*\.egg-info" -exec rm -rf {} +
25 | rm -rf test_project/
26 |
27 | install-test-requirements:
28 | pip install -r test_requirements.txt
29 |
30 | install-pre-commit: install-test-requirements
31 | pre-commit install --install-hooks
32 |
33 | uninstall-pre-commit:
34 | pre-commit uninstall
35 | pre-commit uninstall --hook-type pre-push
36 |
37 | test-project:
38 | pip install -e .
39 | rm -rf test_project/
40 | yes test_project | kedro new --starter=spaceflights
41 | pip install -r test_project/src/requirements.txt
42 | touch .telemetry
43 | echo "consent: false" >> .telemetry
44 | mv .telemetry test_project/
45 |
46 | test-run:
47 | cd test_project; kedro run
48 |
49 | clear-test-run:
50 | rm -rf test_project/
51 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | from codecs import open
3 | from os import path
4 |
5 | from setuptools import setup
6 |
7 | name = "kedro-rich"
8 | here = path.abspath(path.dirname(__file__))
9 |
10 | # get package version
11 | package_name = name.replace("-", "_")
12 | with open(path.join(here, package_name, "__init__.py"), encoding="utf-8") as f:
13 | version = re.search(r'__version__ = ["\']([^"\']+)', f.read()).group(1)
14 |
15 | # get the dependencies and installs
16 | with open("requirements.txt", "r", encoding="utf-8") as f:
17 | requires = [x.strip() for x in f if x.strip()]
18 |
19 | # Get the long description from the README file
20 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
21 | readme = f.read()
22 |
23 | setup(
24 | name=name,
25 | version=version,
26 | description="Kedro-Rich",
27 | long_description=readme,
28 | long_description_content_type="text/markdown",
29 | author="datajoely",
30 | python_requires=">=3.6, <3.9",
31 | install_requires=requires,
32 | license="Apache Software License (Apache 2.0)",
33 | packages=["kedro_rich"],
34 | include_package_data=True,
35 | zip_safe=False,
36 | entry_points={
37 | "kedro.project_commands": ["kedro_rich_project = kedro_rich.rich_cli:commands"],
38 | "kedro.hooks" : ["kedro_rich_progress = kedro_rich.rich_progress_hooks:rich_hooks"],
39 | "kedro.init" : ["kedro_rich_init = kedro_rich.rich_init:start_up"]
40 | },
41 | )
42 |
--------------------------------------------------------------------------------
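
Note: the `entry_points` block is what wires the plugin into Kedro without any user configuration: `kedro.project_commands` contributes the CLI group, `kedro.hooks` registers the hook instance, and `kedro.init` runs `start_up()` when the CLI loads. A rough sketch of how such entry points are discovered (illustrative only; uses the Python 3.8/3.9 `importlib.metadata` API, whereas Kedro has its own plugin loader):

    from importlib.metadata import entry_points

    for ep in entry_points().get("kedro.hooks", []):
        hooks = ep.load()  # e.g. kedro_rich.rich_progress_hooks:rich_hooks
        print(f"would register {ep.name} -> {hooks}")
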
/kedro_rich/rich_init.py:
--------------------------------------------------------------------------------
1 | """This module ensures that the rich logging and exceptions handlers are used"""
2 | import click
3 | import rich
4 | import rich.traceback
5 | from kedro.io.data_catalog import DataCatalog
6 | from kedro.pipeline.node import Node
7 |
8 | from kedro_rich.utilities.kedro_override_utils import (
9 | override_catalog_load,
10 | override_catalog_save,
11 | override_kedro_cli_get_command,
12 | override_kedro_proj_logging_handler,
13 | override_node_str,
14 | )
15 |
16 |
17 | def override_kedro_lib_logging():
18 | """This method overrides default Kedro methods to prettify the logging
19 | output; longer term this could just involve changes to Kedro core.
20 | """
21 |
22 | Node.__str__ = override_node_str
23 | DataCatalog.load = override_catalog_load
24 | DataCatalog.save = override_catalog_save
25 |
26 |
27 | def apply_rich_tracebacks():
28 | """
29 | This method ensures that tracebacks raised by the Kedro project
30 | go through the rich traceback method
31 |
32 | The `suppress=[click]` argument means that exceptions will not
33 | show the frames related to the CLI framework and only the actual
34 | logic the user defines.
35 | """
36 | rich.traceback.install(show_locals=False, suppress=[click])
37 |
38 |
39 | def start_up():
40 | """This method runs the setup methods needed to override
41 | certain defaults at start up
42 | """
43 | override_kedro_proj_logging_handler()
44 | override_kedro_lib_logging()
45 | override_kedro_cli_get_command()
46 | apply_rich_tracebacks()
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # CMake
2 | cmake-build-debug/
3 |
4 | ## File-based project format:
5 | *.iws
6 |
7 | ## Plugin-specific files:
8 |
9 | # IntelliJ
10 | .idea/
11 | *.iml
12 | out/
13 |
14 | ### macOS
15 | *.DS_Store
16 | .AppleDouble
17 | .LSOverride
18 | .Trashes
19 |
20 | # mpeltonen/sbt-idea plugin
21 | .idea_modules/
22 |
23 | # JIRA plugin
24 | atlassian-ide-plugin.xml
25 |
26 | # Crashlytics plugin (for Android Studio and IntelliJ)
27 | com_crashlytics_export_strings.xml
28 | crashlytics.properties
29 | crashlytics-build.properties
30 | fabric.properties
31 |
32 | ### Python template
33 | # Byte-compiled / optimized / DLL files
34 | __pycache__/
35 | *.py[cod]
36 | *$py.class
37 |
38 | # C extensions
39 | *.so
40 |
41 | # Distribution / packaging
42 | .Python
43 | build/
44 | develop-eggs/
45 | dist/
46 | downloads/
47 | eggs/
48 | .eggs/
49 | lib/
50 | lib64/
51 | parts/
52 | sdist/
53 | var/
54 | wheels/
55 | *.egg-info/
56 | .installed.cfg
57 | *.egg
58 | MANIFEST
59 |
60 | # PyInstaller
61 | # Usually these files are written by a python script from a template
62 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
63 | *.manifest
64 | *.spec
65 |
66 | # Installer logs
67 | pip-log.txt
68 | pip-delete-this-directory.txt
69 |
70 | # Unit test / coverage reports
71 | htmlcov/
72 | .tox/
73 | .coverage
74 | .coverage.*
75 | .cache
76 | nosetests.xml
77 | coverage.xml
78 | *.cover
79 | .hypothesis/
80 |
81 | # Translations
82 | *.mo
83 | *.pot
84 |
85 | # Django
86 | *.log
87 | .static_storage/
88 | .media/
89 | local_settings.py
90 |
91 | # Flask
92 | instance/
93 | .webassets-cache
94 |
95 | # Scrapy
96 | .scrapy
97 |
98 | # PyBuilder
99 | target/
100 |
101 | # Jupyter Notebook
102 | .ipynb_checkpoints
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # Celery beat schedule file
108 | celerybeat-schedule
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # MkDocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 |
135 | # Visual Studio Code
136 | .vscode/
137 | # end to end tests assets
138 |
139 | # Vim
140 | *~
141 | .*.swo
142 | .*.swp
143 |
144 | .pytest_cache/
145 | docs/tmp-build-artifacts
146 | docs/build
147 |
148 | test_project/
149 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 |
4 | default_stages: [commit, manual]
5 |
6 | repos:
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v2.2.3
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: end-of-file-fixer
12 | - id: check-yaml # Checks yaml files for parseable syntax.
13 | - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems
14 | - id: check-merge-conflict # Check for files that contain merge conflict strings.
15 | - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source.
16 | - id: requirements-txt-fixer # Sorts entries in requirements.txt
17 | - id: flake8
18 | files: ^kedro_rich/
19 | args:
20 | - "--max-line-length=88"
21 | - "--max-complexity=18"
22 | - "--select=B,C,E,F,W,T4,B9"
23 | - "--ignore=E126,E203,E231,E266,E501,W503"
24 |
25 |
26 | - repo: local
27 | hooks:
28 | # It's impossible to specify per-directory configuration, so we just run it many times.
29 | # https://github.com/PyCQA/pylint/issues/618
30 | # The first set of pylint checks is for local pre-commit; it only runs on the files changed.
31 | - id: pylint-quick-kedro-rich
32 | name: "Quick PyLint on kedro_rich/*"
33 | language: system
34 | types: [file, python]
35 | files: ^kedro_rich/
36 | entry: pylint --disable=unnecessary-pass,too-many-locals
37 | stages: [commit]
38 | - id: pylint-quick-tests
39 | name: "Quick PyLint on tests/*"
40 | language: system
41 | types: [file, python]
42 | files: ^tests/
43 | entry: pylint --disable=missing-docstring,redefined-outer-name,no-self-use,invalid-name,protected-access,too-many-arguments
44 | stages: [commit]
45 | # The same pylint checks, but running on all files. It's for manual run with `make lint`
46 | - id: pylint-kedro_rich
47 | name: "PyLint on kedro_rich/*"
48 | language: system
49 | pass_filenames: false
50 | stages: [manual]
51 | entry: pylint --disable=unnecessary-pass kedro_rich
52 | - id: pylint-tests
53 | name: "PyLint on tests/*"
54 | language: system
55 | pass_filenames: false
56 | stages: [manual]
57 | entry: pylint --disable=missing-docstring,redefined-outer-name,no-self-use,invalid-name,protected-access,too-many-arguments tests
58 | - id: isort
59 | name: "Sort imports"
60 | language: system
61 | types: [ file, python ]
62 | files: ^kedro_rich/
63 | entry: isort
64 | - id: black
65 | name: "Black"
66 | language: system
67 | pass_filenames: false
68 | entry: black kedro_rich tests
69 |
--------------------------------------------------------------------------------
/tests/utilities/test_formatting_utils.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from kedro.io import MemoryDataSet
3 | from kedro_rich.utilities.catalog_utils import (
4 | get_catalog_datasets,
5 | get_datasets_by_pipeline,
6 | report_datasets_as_list,
7 | )
8 | from kedro_rich.utilities.formatting_utils import (
9 | prepare_rich_table,
10 | get_kedro_logo,
11 | print_kedro_pipeline_init_screen,
12 | )
13 | from kedro_rich.constants import KEDRO_RICH_CATALOG_LIST_THRESHOLD
14 | from kedro.pipeline import Pipeline
15 | import kedro
16 |
17 |
18 | def test_get_kedro_logo():
19 | logo = get_kedro_logo()
20 | assert len(logo) == 9
21 | assert max(len(x) for x in logo) == 43
22 |
23 |
24 | def test_prepare_rich_table_no_namespace(data_catalog_fixture, pipeline_fixture):
25 | registry = {"__default__": pipeline_fixture}
26 | catalog_datasets = get_catalog_datasets(data_catalog_fixture, drop_params=True)
27 | pipeline_datasets = get_datasets_by_pipeline(data_catalog_fixture, registry)
28 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets)
29 | table = prepare_rich_table(records=mapped_datasets, registry=registry)
30 | assert len(table.columns) == 2 + len(registry)
31 | assert len(table.rows) == len(mapped_datasets)
32 |
33 |
34 | def test_prepare_rich_table_namespace(data_catalog_fixture, pipeline_fixture):
35 | registry = {"__default__": pipeline_fixture}
36 | data_catalog = deepcopy(data_catalog_fixture)
37 | data_catalog.add("namespace.dataset_name", MemoryDataSet())
38 | catalog_datasets = get_catalog_datasets(data_catalog, drop_params=True)
39 | pipeline_datasets = get_datasets_by_pipeline(data_catalog, registry)
40 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets)
41 | table = prepare_rich_table(records=mapped_datasets, registry=registry)
42 | assert len(table.columns) == 3 + len(registry)
43 | assert len(table.rows) == len(mapped_datasets)
44 |
45 |
46 | def test_prepare_rich_table_threshold(data_catalog_fixture, pipeline_fixture):
47 | registry = {"__default__": pipeline_fixture}
48 | custom_registry = deepcopy(registry)
49 | for i in range(KEDRO_RICH_CATALOG_LIST_THRESHOLD):
50 | custom_registry[f"custom_{i}"] = Pipeline([])
51 | catalog_datasets = get_catalog_datasets(data_catalog_fixture, drop_params=True)
52 | pipeline_datasets = get_datasets_by_pipeline(data_catalog_fixture, custom_registry)
53 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets)
54 | table = prepare_rich_table(records=mapped_datasets, registry=custom_registry)
55 | assert len(table.columns) == 3
56 | assert len(table.rows) == len(mapped_datasets)
57 |
58 |
59 | def test_print_kedro_pipeline_init_screen(capsys):
60 |
61 | print_kedro_pipeline_init_screen()
62 | captured = capsys.readouterr()
63 |
64 | assert kedro.__version__ in captured.out
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kedro-rich
2 |
3 | ## Make your Kedro snazzy
4 |
5 | This is a very early, work-in-progress Kedro plugin that utilises the awesome `rich` library.
6 |
7 | The intention with this piece of work is to battle-test the idea and iron out the creases, potentially integrating this as a first-class plugin hosted at kedro-org/plugins or, if we're lucky, as native functionality within Kedro itself.
8 |
9 | I'm very much looking for help developing and testing this project, so if you want to get involved please get in touch.
10 |
11 | ## Current Functionality
12 |
13 | ### Overridden `kedro run` command
14 |
15 | - Does exactly the same as a regular Kedro run, but kicks rich progress bars into action.
16 | - The load/save progress tasks focus purely on persisted data and ignore ephemeral `MemoryDataSets`.
17 |
18 | 
19 |
20 | - The progress bars are currently disabled when using `ParallelRunner`, since multiprocessing causes issues between `kedro` and `rich`. Further investigation is needed to see whether some sort of `Lock()` mechanism will allow this to work.
21 |
22 | ### Logging via `rich.logging.RichHandler`
23 |
24 | - This plugin replaces the default `stdout` console logging handler with the class provided by `rich`.
25 | - This is actually required to make the progress bars work without being broken onto new lines every time a new log message appears.
26 | - At this point we also enable the [rich traceback handler](https://rich.readthedocs.io/en/stable/traceback.html).
27 | - In order to enable this purely on the plug-in side (i.e. not making the user change `logging.yml`) I've had to do an ugly bit of monkey patching. Keen to come up with a better solution here.
28 |
29 | ### Overridden `kedro catalog list` command
30 |
31 | Accepts the following options:
32 |
33 | - `--format=yaml` provides YAML representation to stdout that can be piped into other utilities
34 | - `--format=json` provides JSON representation to stdout that can be piped into other utilities
35 | - `--format=table` provides pretty representation to console for human consumption
36 |
37 | 
38 |
39 | ## Install the plug-in
40 |
41 | ### (Option 1) Cloning the repository
42 |
43 | The plug-in is in its very early days, so it will be a while before (if ever) this makes it to PyPI.
44 |
45 | 1. Clone the repository
46 | 2. Run `make dev-install` to install this to your environment
47 | 3. Go to any Kedro 0.17.x project and see if it works! (Please let me know if it doesn't).
48 |
49 | ### (Option 2) Direct from GitHub
50 |
51 | 1. Run `pip install git+https://github.com/datajoely/kedro-rich` to install this to your environment.
52 | 2. Go to any Kedro 0.17.x project and see if it works! (Please let me know if it doesn't).
53 |
54 | ## Run end to end example
55 |
56 | Running `make test-project` then `make test-run` will...
57 |
58 | - Install the `kedro-rich` package into the environment
59 | - Pull the 'spaceflights' `kedro-starter`
60 | - Install requirements
61 | - Execute `kedro run`
62 |
--------------------------------------------------------------------------------
/tests/utilities/test_catalog_utils.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import pytest
3 |
4 | from kedro_rich.utilities.catalog_utils import (
5 | filter_datasets_by_pipeline,
6 | get_datasets_by_pipeline,
7 | resolve_pipeline_namespace,
8 | resolve_catalog_namespace,
9 | get_catalog_datasets,
10 | split_catalog_namespace_key,
11 | report_datasets_as_list,
12 | )
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "given_key,expected_key",
17 | [
18 | ("ds.something.something_else.key", "ds__something__something_else__key"),
19 | ("ds.key", "ds__key"),
20 | ("key", "key"),
21 | ],
22 | )
23 | def test_resolve_pipeline_namespace(given_key: str, expected_key: str):
24 | assert resolve_pipeline_namespace(given_key) == expected_key
25 |
26 |
27 | @pytest.mark.parametrize(
28 | "given_key,expected_key",
29 | [
30 | ("ds__something__something_else__key", "ds.something.something_else.key"),
31 | ("ds__key", "ds.key"),
32 | ("key", "key"),
33 | ],
34 | )
35 | def test_resolve_catalog_namespace(given_key: str, expected_key: str):
36 | assert resolve_catalog_namespace(given_key) == expected_key
37 |
38 |
39 | def test_get_catalog_datasets(data_catalog_fixture):
40 | datasets_remaining = get_catalog_datasets(
41 | catalog=data_catalog_fixture, exclude_types=("MemoryDataSet",), drop_params=True
42 | )
43 | assert len(datasets_remaining) == 2
44 | assert all(x == "PickleDataSet" for x in datasets_remaining.values())
45 |
46 | datasets_remaining_keep_everything = get_catalog_datasets(
47 | catalog=data_catalog_fixture, drop_params=False
48 | )
49 |
50 | assert len(datasets_remaining_keep_everything) == 4
51 | assert "MemoryDataSet" in datasets_remaining_keep_everything.values()
52 |
53 |
54 | def test_filter_datasets_by_pipeline(data_catalog_fixture, pipeline_fixture):
55 | pipe_datasets = filter_datasets_by_pipeline(
56 | datasets=get_catalog_datasets(
57 | catalog=data_catalog_fixture,
58 | exclude_types=("MemoryDataSet",),
59 | drop_params=True,
60 | ),
61 | pipeline=pipeline_fixture,
62 | )
63 | all_dataset_keys = reduce(
64 | lambda x, y: x | y, [set(x.keys()) for x in pipe_datasets]
65 | )
66 | all_dataset_types = reduce(
67 | lambda x, y: x | y, [set(x.values()) for x in pipe_datasets]
68 | )
69 | assert all_dataset_keys == {"dataset_1", "dataset_3"}
70 | assert all_dataset_types == {"PickleDataSet"}
71 |
72 |
73 | def test_get_datasets_by_pipeline(data_catalog_fixture, pipeline_fixture):
74 | dataset_pipes = get_datasets_by_pipeline(
75 | catalog=data_catalog_fixture, registry={"__default__": pipeline_fixture}
76 | )
77 | assert set(reduce(lambda x, y: x + y, dataset_pipes.values())) == {"__default__"}
78 | assert set(dataset_pipes.keys()) == {"dataset_1", "dataset_2", "dataset_3"}
79 |
80 |
81 | @pytest.mark.parametrize(
82 | "given_key,expected_namespace,expected_key",
83 | [
84 | ("ds.something.something_else.key", "ds.something.something_else", "key"),
85 | ("ds.key", "ds", "key"),
86 | ("key", None, "key"),
87 | ],
88 | )
89 | def test_split_catalog_namespace_key(given_key, expected_namespace, expected_key):
90 | namespace, key = split_catalog_namespace_key(given_key)
91 | assert (namespace, key) == (expected_namespace, expected_key)
92 |
93 |
94 | def test_report_datasets_as_list(data_catalog_fixture, pipeline_fixture):
95 | reported_list = report_datasets_as_list(
96 | catalog_datasets=get_catalog_datasets(
97 | catalog=data_catalog_fixture,
98 | exclude_types=("MemoryDataSet",),
99 | drop_params=True,
100 | ),
101 | pipeline_datasets=get_datasets_by_pipeline(
102 | catalog=data_catalog_fixture, registry={"__default__": pipeline_fixture}
103 | ),
104 | )
105 | assert [
106 | {
107 | "dataset_type": "PickleDataSet",
108 | "pipelines": ["__default__"],
109 | "namespace": None,
110 | "key": "dataset_1",
111 | },
112 | {
113 | "dataset_type": "PickleDataSet",
114 | "pipelines": ["__default__"],
115 | "namespace": None,
116 | "key": "dataset_3",
117 | },
118 | ] == reported_list
119 |
--------------------------------------------------------------------------------
/kedro_rich/rich_cli.py:
--------------------------------------------------------------------------------
1 | """Command line tools for manipulating a Kedro project.
2 | Intended to be invoked via `kedro`."""
3 | import importlib
4 | import json
5 | import os
6 | from typing import Callable, Dict
7 |
8 | import click
9 | import rich_click
10 | import yaml
11 | from kedro.framework.cli.catalog import _create_session, create_catalog
12 | from kedro.framework.cli.project import run
13 | from kedro.framework.cli.utils import CONTEXT_SETTINGS, env_option
14 | from kedro.framework.startup import ProjectMetadata
15 | from kedro.pipeline import Pipeline
16 | from rich import box
17 | from rich.panel import Panel
18 |
19 | from kedro_rich.constants import KEDRO_RICH_PROGRESS_ENV_VAR_KEY
20 | from kedro_rich.utilities.catalog_utils import (
21 | get_catalog_datasets,
22 | get_datasets_by_pipeline,
23 | report_datasets_as_list,
24 | )
25 | from kedro_rich.utilities.formatting_utils import prepare_rich_table
26 |
27 |
28 | @click.group(context_settings=CONTEXT_SETTINGS, name="kedro-rich")
29 | def commands():
30 | """Command line tools for manipulating a Kedro project."""
31 |
32 |
33 | def handle_parallel_run(func: Callable) -> Callable:
34 | """
35 | This function wraps the run command callback so that the
36 | progress bar is disabled under the parallel runner
37 | """
38 |
39 | def wrapped(*args, **kwargs):
40 | """
41 | This wrapper sets an environment variable to enable progress
42 | bars unless the user selects a simple or parallel run; kedro-rich
43 | disables the progress bar in those situations
44 | """
45 |
46 | # Only run progress bars if (1) In complex mode (2) NOT ParallelRunner
47 | if not kwargs["simple"]:
48 | if not kwargs["parallel"]:
49 | os.environ[KEDRO_RICH_PROGRESS_ENV_VAR_KEY] = "1"
50 |
51 | # drop 'simple' kwarg as that doesn't exist in the original func
52 | original_kwargs = {k: v for k, v in kwargs.items() if "simple" != k}
53 | result = func(*args, **original_kwargs)
54 |
55 | if os.environ.get(KEDRO_RICH_PROGRESS_ENV_VAR_KEY):
56 | del os.environ[KEDRO_RICH_PROGRESS_ENV_VAR_KEY]
57 | return result
58 |
59 | return wrapped
60 |
61 |
62 | run.callback = handle_parallel_run(run.callback)
63 | run.__class__ = rich_click.RichCommand
64 | commands.add_command(
65 | click.option(
66 | "--simple",
67 | "-s",
68 | default=False,
69 | is_flag=True,
70 | help="Disable rich progress bars (simple mode)",
71 | )(run)
72 | )
73 |
74 |
75 | @commands.group(cls=rich_click.RichGroup)
76 | def catalog():
77 | """Commands for working with catalog."""
78 | pass
79 |
80 |
81 | catalog.add_command(create_catalog)
82 |
83 |
84 | @catalog.command(cls=rich_click.RichCommand, name="list")
85 | @env_option
86 | @click.option(
87 | "--format",
88 | "-f",
89 | "fmt",
90 | default="yaml",
91 | type=click.Choice(["yaml", "json", "table"], case_sensitive=False),
92 | help="Output the 'yaml' (default) / 'json' results to stdout or pretty"
93 | " print 'table' to console",
94 | )
95 | @click.pass_obj
96 | def list_datasets(metadata: ProjectMetadata, fmt: str, env: str):
97 | """Detail datasets by type."""
98 |
99 | # Needed to avoid circular reference
100 | from rich.console import Console # pylint: disable=import-outside-toplevel
101 |
102 | pipelines = _get_pipeline_registry(metadata)
103 | session = _create_session(metadata.package_name, env=env)
104 | context = session.load_context()
105 | catalog_datasets = get_catalog_datasets(context.catalog, drop_params=True)
106 | pipeline_datasets = get_datasets_by_pipeline(context.catalog, pipelines)
107 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets)
108 | console = Console()
109 |
110 | if fmt == "yaml":
111 | struct = {
112 | f"{x['namespace']}.{x['key']}" if x["namespace"] else x["key"]: x
113 | for x in mapped_datasets
114 | }
115 | console.out(yaml.safe_dump(struct))
116 | elif fmt == "json":
117 | console.out(json.dumps(mapped_datasets, indent=2))
118 | elif fmt == "table":
119 | table = prepare_rich_table(records=mapped_datasets, registry=pipelines)
120 | console.print(
121 | "\n",
122 | Panel(
123 | table,
124 | expand=False,
125 | title=f"Catalog contains [b][cyan]{len(mapped_datasets)}[/][/] persisted datasets",
126 | padding=(1, 1),
127 | box=box.MINIMAL,
128 | ),
129 | )
130 |
131 |
132 | def _get_pipeline_registry(proj_metadata: ProjectMetadata) -> Dict[str, Pipeline]:
133 | """
134 | This method retrieves the pipelines registered in the project where
135 | the plugin is installed
136 | """
137 | # is this the right 0.18.x version of doing this?
138 | # The object is no longer in the context
139 | registry = importlib.import_module(
140 | f"{proj_metadata.package_name}.pipeline_registry"
141 | )
142 | return registry.register_pipelines()
143 |
--------------------------------------------------------------------------------
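
Note: the `--simple` flag above is bolted onto Kedro's existing `run` command rather than a freshly defined one: the callback is wrapped, the extra kwarg is stripped before delegating, and the option is attached to the already-built command object. A self-contained sketch of that click pattern, with hypothetical command and flag names:

    import click

    @click.command()
    @click.option("--name", default="world")
    def hello(name):
        """An existing command we do not control."""
        click.echo(f"hello {name}")

    def with_shout(func):
        """Wrap the original callback so it tolerates the new flag."""
        def wrapped(*args, **kwargs):
            shout = kwargs.pop("shout")  # original callback has no 'shout'
            result = func(*args, **kwargs)
            if shout:
                click.echo("HELLO!")
            return result
        return wrapped

    hello.callback = with_shout(hello.callback)
    hello = click.option("--shout", is_flag=True, default=False)(hello)
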
/kedro_rich/utilities/formatting_utils.py:
--------------------------------------------------------------------------------
1 | """This module provides utilities that are helpful
2 | for printing to the rich console"""
3 |
4 | from typing import Any, Dict, List, Tuple
5 |
6 | import kedro
7 | from kedro.pipeline import Pipeline
8 | from rich import box
9 | from rich.console import Console
10 | from rich.style import Style
11 | from rich.table import Table
12 |
13 | from kedro_rich.constants import KEDRO_RICH_CATALOG_LIST_THRESHOLD
14 |
15 |
16 | def prepare_rich_table(
17 | records: List[Dict[str, Any]], registry: Dict[str, Pipeline]
18 | ) -> Table:
19 | """This method will build a rich.Table object based on the
20 | a given list of records and a dictionary of registered pipelines
21 |
22 | Args:
23 | records (List[Dict[str, Any]]): The catalog records
24 | registry (Dict[str, Pipeline]): The pipelines to map to linked datasets
25 |
26 | Returns:
27 | Table: The table to render
28 | """
29 |
30 | table = Table(show_header=True, header_style=Style(color="white"), box=box.ROUNDED)
31 | # only include namespace if at least one present in catalog
32 | includes_namespaces = any(x["namespace"] for x in records)
33 | collapse_pipes = len(registry.keys()) > KEDRO_RICH_CATALOG_LIST_THRESHOLD
34 |
35 | # define table headers
36 | namespace_columns = ["namespace"] if includes_namespaces else []
37 | pipe_columns = ["pipeline_count"] if collapse_pipes else list(registry.keys())
38 | columns_to_add = namespace_columns + ["dataset_name", "dataset_type"] + pipe_columns
39 |
40 | # add table headers
41 | for column in columns_to_add:
42 | table.add_column(column, justify="center")
43 |
44 | # add table rows
45 | for index, row in enumerate(records):
46 |
47 | def _describe_boundary(
48 | index: int, records: List[Dict[str, Any]], key: str, current_value: str
49 | ) -> Tuple[bool, bool]:
50 | """
51 | Given a list of dictionaries, a key and the current value, this method
52 | returns two booleans detailing whether the sequence has changed or not
53 | """
54 | same_section = (
55 | index + 1 < len(records) and records[index + 1][key] == current_value
56 | )
57 | new_section = index == 0 or records[index - 1][key] != current_value
58 |
59 | return same_section, new_section
60 |
61 | # work out if the dataset_type is the same / different to next row
62 | same_section, new_section = _describe_boundary(
63 | index=index,
64 | records=records,
65 | key="dataset_type",
66 | current_value=row["dataset_type"],
67 | )
68 |
69 | # add namespace if present
70 | if includes_namespaces:
71 | table_namespace = (
72 | [row["namespace"]] if row["namespace"] else ["[bright_black]n/a[/]"]
73 | )
74 | else:
75 | table_namespace = []
76 |
77 | # get catalog key
78 | table_dataset_name = [row["key"]]
79 |
80 | # get dataset_type, only show if different from the last record
81 | table_dataset_type = (
82 | [f"[magenta][b]{row['dataset_type']}[/][/]"] if new_section else [""]
83 | )
84 |
85 | # get pipelines attached to this dataset
86 | dataset_pipes = row["pipelines"]
87 | # get pipelines registered in this project
88 | proj_pipes = sorted(registry.keys())
89 |
90 | # if too many pipelines registered, simply show the count
91 | if collapse_pipes:
92 | table_pipes = [str(len(dataset_pipes))]
93 | else:
94 |
95 | # show ✓ and ✘ if present
96 | table_pipes = [
97 | "[bold green]✓[/]"
98 | if (pipe in (set(proj_pipes) & set(dataset_pipes)))
99 | else "[bold red]✘[/]"
100 | for pipe in proj_pipes
101 | ]
102 |
103 | # build full row
104 | renderables = (
105 | table_namespace + table_dataset_name + table_dataset_type + table_pipes
106 | )
107 |
108 | # add row to table
109 | table.add_row(*renderables, end_section=not same_section)
110 | return table
111 |
112 |
113 | def get_kedro_logo(color: str = "orange1") -> List[str]:
114 | """This method constructs an ascii Kedro logo"""
115 | diamond = """
116 | -
117 | ·===·
118 | ·==: :==·
119 | ·==: :==·
120 | ·==: :==·
121 | ·===·
122 | -
123 | """.split(
124 | "\n"
125 | )
126 | color_rows = [
127 | f"[{color}][b]{x}[/b][/{color}]" if x.strip() else "" for x in diamond
128 | ]
129 |
130 | return color_rows
131 |
132 |
133 | def print_kedro_pipeline_init_screen(
134 | title_color: str = "orange1", tagline_color: str = "gray"
135 | ):
136 | """This method prints the Kedro logo and package metadata"""
137 |
138 | tagline_text = "Reproducible, maintainable and modular data science code"
139 | lib_info = dict(
140 | title=f"[{title_color}][b]KEDRO[/][/{title_color}] ({kedro.__version__})",
141 | tagline=f"[{tagline_color}][i]{tagline_text}[/][/]",
142 | github="https://github.com/kedro-org/kedro",
143 | rtd="https://kedro.readthedocs.io",
144 | )
145 |
146 | logo_rows = get_kedro_logo()
147 | mapping = ((2, "title"), (3, "tagline"), (-3, "github"), (-2, "rtd"))
148 | for index, key in mapping:
149 | spacing = (51 - len(logo_rows[index])) * " "
150 | logo_rows[index] = logo_rows[index] + spacing + lib_info[key]
151 |
152 | str_rows = "\n".join(logo_rows)
153 | Console().print(str_rows, no_wrap=True)
154 |
--------------------------------------------------------------------------------
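
Note: `prepare_rich_table` groups rows by `dataset_type` using `end_section=True`, which makes rich draw a rule under the last row of each group. A tiny standalone sketch of that table pattern (dataset names are hypothetical):

    from rich import box
    from rich.console import Console
    from rich.table import Table

    table = Table(show_header=True, box=box.ROUNDED)
    table.add_column("dataset_name")
    table.add_column("dataset_type")
    table.add_row("dataset_1", "[magenta]PickleDataSet[/]")
    table.add_row("dataset_3", "", end_section=True)  # same type: left blank
    table.add_row("dataset_2", "[magenta]MemoryDataSet[/]")
    Console().print(table)
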
/kedro_rich/utilities/catalog_utils.py:
--------------------------------------------------------------------------------
1 | """This module includes helper functions for managing the data catalog"""
2 | import operator
3 | from functools import reduce
4 | from itertools import groupby
5 | from typing import Any, Dict, List, Optional, Set, Tuple
6 |
7 | from kedro.io import DataCatalog
8 | from kedro.pipeline import Pipeline
9 |
10 |
11 | def get_catalog_datasets(
12 | catalog: DataCatalog, exclude_types: Tuple[str, ...] = (), drop_params: bool = False
13 | ) -> Dict[str, str]:
14 | """Filter to only persisted datasets"""
15 | datasets_filtered = {
16 | k: type(v).__name__
17 | for k, v in catalog.datasets.__dict__.items()
18 | if type(v).__name__ not in exclude_types
19 | }
20 | if drop_params:
21 | datasets_w_param_filter = {
22 | k: v
23 | for k, v in datasets_filtered.items()
24 | if not k.startswith("params") and not k == "parameters"
25 | }
26 | return datasets_w_param_filter
27 | return datasets_filtered
28 |
29 |
30 | def filter_datasets_by_pipeline(
31 | datasets: Dict[str, str], pipeline: Pipeline
32 | ) -> Tuple[Dict[str, str], Dict[str, str]]:
33 | """
34 | Retrieve datasets (inputs and outputs) which intersect with
35 | a given pipeline object
36 |
37 | This function also ensures namespaces are correctly rationalised.
38 | """
39 |
40 | def _clean_names(datasets: List[str], namespace: Optional[str]) -> Set[str]:
41 | if namespace:
42 | return {resolve_pipeline_namespace(x) for x in datasets}
43 | return set(datasets)
44 |
45 | inputs = reduce(
46 | lambda a, x: a | _clean_names(x.inputs, x.namespace), pipeline.nodes, set()
47 | )
48 | outputs = reduce(
49 | lambda a, x: a | _clean_names(x.outputs, x.namespace), pipeline.nodes, set()
50 | )
51 |
52 | pipeline_inputs = {
53 | k: v for k, v in datasets.items() if any(x.endswith(k) for x in inputs)
54 | }
55 | pipeline_outputs = {
56 | k: v for k, v in datasets.items() if any(x.endswith(k) for x in outputs)
57 | }
58 | return pipeline_inputs, pipeline_outputs
59 |
60 |
61 | def get_datasets_by_pipeline(
62 | catalog: DataCatalog, registry: Dict[str, Pipeline]
63 | ) -> Dict[str, List[str]]:
64 | """This method will return a dictionary of datasets mapped to the
65 | list of pipelines they are used within
66 |
67 | Args:
68 | catalog (DataCatalog): The data catalog object
69 | registry (Dict[str, Pipeline]): The pipelines in this project
70 |
71 | Returns:
72 | Dict[str, List[str]]: The dataset to pipeline groups
73 | """
74 | # get non-parameter datasets
75 | catalog_datasets = get_catalog_datasets(catalog=catalog, drop_params=True)
76 |
77 | # get node input and outputs
78 | pipeline_input_output_datasets = {
79 | pipeline_name: filter_datasets_by_pipeline(catalog_datasets, pipeline)
80 | for pipeline_name, pipeline in registry.items()
81 | }
82 |
83 | # get those that overlap with pipelines
84 | pipeline_datasets = {
85 | pipeline_name: reduce(
86 | lambda input, output: input.keys() | output.keys(), input_outputs
87 | )
88 | for pipeline_name, input_outputs in pipeline_input_output_datasets.items()
89 | }
90 |
91 | # get dataset to pipeline pairs
92 | dataset_pipeline_pairs = reduce(
93 | lambda x, y: x + y,
94 | (
95 | [(dataset, pipeline) for dataset in datasets]
96 | for pipeline, datasets in pipeline_datasets.items()
97 | ),
98 | )
99 |
100 | # get dataset to pipeline groups
101 | sorter = sorted(dataset_pipeline_pairs, key=operator.itemgetter(0))
102 |
103 | grouper = groupby(sorter, key=operator.itemgetter(0))
104 |
105 | dataset_pipeline_groups = {
106 | k: list(map(operator.itemgetter(1), v)) for k, v in grouper
107 | }
108 | return dataset_pipeline_groups
109 |
110 |
111 | def resolve_pipeline_namespace(dataset_name: str) -> str:
112 | """Resolves the dot to double underscore namespace
113 | discrepancy between pipeline inputs/outputs and catalog keys
114 | """
115 | return dataset_name.replace(".", "__")
116 |
117 |
118 | def resolve_catalog_namespace(dataset_name: str) -> str:
119 | """Resolves the double underscore to dot namespace
120 | discrepancy between catalog keys and pipeline inputs/outputs
121 | """
122 | return dataset_name.replace("__", ".")
123 |
124 |
125 | def split_catalog_namespace_key(dataset_name: str) -> Tuple[Optional[str], str]:
126 | """This method splits out a catalog name from it's namespace"""
127 | dataset_split = dataset_name.split(".")
128 | namespace = ".".join(dataset_split[:-1])
129 | if namespace:
130 | dataset_name = dataset_split[-1]
131 | return namespace, dataset_name
132 | return None, dataset_name
133 |
134 |
135 | def report_datasets_as_list(
136 | pipeline_datasets: Dict[str, List[str]], catalog_datasets: Dict[str, str]
137 | ) -> List[Dict[str, Any]]:
138 | """This method accepts the datasets present in the pipeline registry
139 | as well as the full data catalog and produces a list of records
140 | which include key metadata such as the type, namespace, linked pipelines
141 | and dataset name (ordered by type)
142 | """
143 | return sorted(
144 | (
145 | {
146 | **{
147 | "dataset_type": v,
148 | "pipelines": pipeline_datasets.get(k, []),
149 | },
150 | **dict(
151 | zip(
152 | ("namespace", "key"),
153 | split_catalog_namespace_key(
154 | dataset_name=resolve_catalog_namespace(k)
155 | ),
156 | )
157 | ),
158 | }
159 | for k, v in catalog_datasets.items()
160 | ),
161 | key=lambda x: x["dataset_type"],
162 | )
163 |
--------------------------------------------------------------------------------
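
Note: for orientation, a minimal sketch of how these helpers compose, using a toy in-memory catalog and a single-node pipeline (names are hypothetical):

    from kedro.io import DataCatalog, MemoryDataSet
    from kedro.pipeline import Pipeline, node

    from kedro_rich.utilities.catalog_utils import get_datasets_by_pipeline

    catalog = DataCatalog(data_sets={"a": MemoryDataSet(), "b": MemoryDataSet()})
    pipe = Pipeline([node(func=lambda x: x, inputs="a", outputs="b")])

    # maps each catalog dataset to the pipelines that read or write it
    print(get_datasets_by_pipeline(catalog, {"__default__": pipe}))
    # expected: {'a': ['__default__'], 'b': ['__default__']}
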
/kedro_rich/utilities/kedro_override_utils.py:
--------------------------------------------------------------------------------
1 | """This module provides methods which we use to override default Kedro methods"""
2 | # pylint: disable=protected-access
3 | import logging
4 | from typing import Any, Callable, Optional, Set
5 |
6 | import rich
7 | from click.core import _check_multicommand
8 | from kedro.framework.cli.cli import KedroCLI
9 | from kedro.framework.session import KedroSession
10 | from kedro.io.core import AbstractVersionedDataSet, Version
11 | from rich.panel import Panel
12 |
13 | from kedro_rich.constants import KEDRO_RICH_LOGGING_HANDLER
14 |
15 |
16 | def override_node_str(self) -> str:
17 | """This method rich-ifies the node.__str__ method"""
18 |
19 | def _drop_namespaces(xset: Set[str]) -> Optional[Set]:
20 | """This method cleans up the namesapces"""
21 | split = {x.split(".")[-1] for x in xset}
22 | if split:
23 | return split
24 | return None
25 |
26 | func_name = f"[magenta]𝑓𝑥 {self._func_name}([/]"
27 | inputs = _drop_namespaces(self.inputs)
28 | bridge = "[magenta])[/] [cyan]➡[/] "
29 | outputs = _drop_namespaces(self.outputs)
30 | return f"{func_name}{inputs}{bridge}{outputs}"
31 |
32 |
33 | def override_catalog_load(self, name: str, version: Optional[str] = None) -> Any:
34 | """Loads a registered data set (Rich-ified output).
35 |
36 | Args:
37 | name: A data set to be loaded.
38 | version: Optional argument for concrete data version to be loaded.
39 | Works only with versioned datasets.
40 |
41 | Returns:
42 | The loaded data as configured.
43 |
44 | Raises:
45 | DataSetNotFoundError: When a data set with the given name
46 | has not yet been registered.
47 |
48 | Example:
49 | ::
50 |
51 | >>> from kedro.io import DataCatalog
52 | >>> from kedro.extras.datasets.pandas import CSVDataSet
53 | >>>
54 | >>> cars = CSVDataSet(filepath="cars.csv",
55 | >>> load_args=None,
56 | >>> save_args={"index": False})
57 | >>> io = DataCatalog(data_sets={'cars': cars})
58 | >>>
59 | >>> df = io.load("cars")
60 | """
61 | load_version = Version(version, None) if version else None
62 | dataset = self._get_dataset(name, version=load_version)
63 |
64 | self._logger.info(
65 | "Loading data from [bright_blue]%s[/] ([bright_blue][b]%s[/][/])...",
66 | name,
67 | type(dataset).__name__,
68 | )
69 |
70 | func = self._get_transformed_dataset_function(name, "load", dataset)
71 | result = func()
72 |
73 | version = (
74 | dataset.resolve_load_version()
75 | if isinstance(dataset, AbstractVersionedDataSet)
76 | else None
77 | )
78 |
79 | # Log only if versioning is enabled for the data set
80 | if self._journal and version:
81 | self._journal.log_catalog(name, "load", version)
82 | return result
83 |
84 |
85 | def override_catalog_save(self, name: str, data: Any) -> None:
86 | """Save data to a registered data set.
87 |
88 | Args:
89 | name: A data set to be saved to.
90 | data: A data object to be saved as configured in the registered
91 | data set.
92 |
93 | Raises:
94 | DataSetNotFoundError: When a data set with the given name
95 | has not yet been registered.
96 |
97 | Example:
98 | ::
99 |
100 | >>> import pandas as pd
101 | >>>
102 | >>> from kedro.extras.datasets.pandas import CSVDataSet
103 | >>>
104 | >>> cars = CSVDataSet(filepath="cars.csv",
105 | >>> load_args=None,
106 | >>> save_args={"index": False})
107 | >>> io = DataCatalog(data_sets={'cars': cars})
108 | >>>
109 | >>> df = pd.DataFrame({'col1': [1, 2],
110 | >>> 'col2': [4, 5],
111 | >>> 'col3': [5, 6]})
112 | >>> io.save("cars", df)
113 | """
114 | dataset = self._get_dataset(name)
115 |
116 | self._logger.info(
117 | "Saving data to [bright_blue]%s[/] ([bright_blue][b]%s[/][/])...",
118 | name,
119 | type(dataset).__name__,
120 | )
121 |
122 | func = self._get_transformed_dataset_function(name, "save", dataset)
123 | func(data)
124 |
125 | version = (
126 | dataset.resolve_save_version()
127 | if isinstance(dataset, AbstractVersionedDataSet)
128 | else None
129 | )
130 |
131 | # Log only if versioning is enabled for the data set
132 | if self._journal and version:
133 | self._journal.log_catalog(name, "save", version)
134 |
135 |
136 | def override_kedro_proj_logging_handler():
137 | """
138 | This function does two things:
139 |
140 | (1) It mutates the dictionary provided by `logging.yml` to
141 | use the `rich.logging.RichHandler` instead of the standard output one
142 | (2) It enables the rich.Traceback handler so that exceptions are prettier
143 | """
144 |
145 | # ensure warnings are caught by the logger, not stdout
146 | logging.captureWarnings(True)
147 |
148 | def _replace_console_handler(func: Callable) -> Callable:
149 | """This function mutates the dictionary returned by reading logging.yml"""
150 |
151 | def wrapped(*args, **kwargs):
152 | logging_config = func(*args, **kwargs)
153 | logging_config["handlers"]["console"] = KEDRO_RICH_LOGGING_HANDLER
154 | return logging_config
155 |
156 | return wrapped
157 |
158 | # pylint: disable=protected-access
159 | KedroSession._get_logging_config = _replace_console_handler(
160 | KedroSession._get_logging_config
161 | )
162 |
163 |
164 | def override_kedro_cli_get_command():
165 | """This method overrides the Click get_command() method
166 | so that we can give the user a useful message if they try to run a Kedro
167 | project command outside of a project directory
168 | """
169 |
170 | # pylint: disable=invalid-name
171 | # pylint: disable=inconsistent-return-statements
172 | def _get_command(self, ctx, cmd_name):
173 | for source in self.sources:
174 | rv = source.get_command(ctx, cmd_name)
175 | if rv is not None:
176 | if self.chain:
177 | _check_multicommand(self, cmd_name, rv)
178 | return rv
179 | if not self._metadata:
180 |
181 | warn = "[orange1][b]You are not in a Kedro project[/]![/]"
182 | result = "Project specific commands such as '[bright_cyan]run[/]' or \
183 | '[bright_cyan]jupyter[/]' are only available within a project directory."
184 | solution = "[bright_black][b]Hint:[/] [i]Kedro is looking for a file called \
185 | '[magenta]pyproject.toml[/]', is one present in your current working directory?[/][/]"
186 | msg = f"{warn} {result}\n\n{solution}"
187 | console = rich.console.Console()
188 | panel = Panel(
189 | msg,
190 | title=f"Command '{cmd_name}' not found",
191 | expand=False,
192 | border_style="dim",
193 | title_align="left",
194 | )
195 | console.print("\n", panel, "\n")
196 |
197 | KedroCLI.get_command = _get_command
198 |
--------------------------------------------------------------------------------
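
Note: every override in this module relies on the same technique: re-assigning an attribute on an imported class. A generic sketch of the pattern, here wrapping rather than replacing the original (kedro-rich replaces the bodies outright):

    from kedro.io import DataCatalog

    _original_load = DataCatalog.load

    def noisy_load(self, name, version=None):
        """Log before delegating to the untouched Kedro implementation."""
        print(f"about to load {name!r}")
        return _original_load(self, name, version)

    DataCatalog.load = noisy_load
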
/kedro_rich/rich_progress_hooks.py:
--------------------------------------------------------------------------------
1 | """This module provides lifecycle hooks to track progress"""
2 | import logging
3 | import os
4 | import time
5 | from datetime import timedelta
6 | from typing import Any, Dict
7 |
8 | from kedro.framework.hooks import hook_impl
9 | from kedro.io import DataCatalog
10 | from kedro.pipeline import Pipeline
11 | from kedro.pipeline.node import Node
12 | from rich.progress import (
13 | BarColumn,
14 | Progress,
15 | ProgressColumn,
16 | SpinnerColumn,
17 | Task,
18 | TaskID,
19 | )
20 | from rich.text import Text
21 |
22 | from kedro_rich.constants import (
23 | KEDRO_RICH_PROGRESS_ENV_VAR_KEY,
24 | KEDRO_RICH_SHOW_DATASET_PROGRESS,
25 | )
26 | from kedro_rich.utilities.catalog_utils import (
27 | filter_datasets_by_pipeline,
28 | get_catalog_datasets,
29 | resolve_pipeline_namespace,
30 | split_catalog_namespace_key,
31 | )
32 | from kedro_rich.utilities.formatting_utils import print_kedro_pipeline_init_screen
33 |
34 |
35 | class RichProgressHooks:
36 | """These set of hooks add progress information to the output of a Kedro run"""
37 |
38 | def __init__(self):
39 | """This constructor initialises the variables used to manage state"""
40 | self.progress = None
41 | self.task_count = 0
42 | self.io_datasets_in_catalog = {}
43 | self.pipeline_inputs = {}
44 | self.pipeline_outputs = {}
45 | self.tasks = {}
46 |
47 | @hook_impl
48 | def before_pipeline_run(
49 | self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
50 | ):
51 | """
52 | This method initialises the variables needed to track pipeline progress. This
53 | will be disabled under the parallel runner
54 | """
55 | if self._check_if_progress_bar_enabled():
56 | progress_desc_format = "[progress.description]{task.description}"
57 | progress_percentage_format = "[progress.percentage]{task.percentage:>3.0f}%"
58 | progress_activity_format = "{task.fields[activity]}"
59 | self.progress = Progress(
60 | _KedroElapsedColumn(),
61 | progress_desc_format,
62 | SpinnerColumn(),
63 | BarColumn(),
64 | progress_percentage_format,
65 | progress_activity_format,
66 | )
67 |
68 | # Get pipeline goals
69 | self._init_progress_tasks(pipeline, catalog)
70 |
71 | # Init tasks
72 | pipe_name = run_params.get("pipeline_name") or "__default__"
73 | input_cnt = len(self.pipeline_inputs)
74 | output_cnt = len(self.pipeline_outputs)
75 |
76 | dataset_tasks = (
77 | {
78 | "loads": self._add_task(desc="Loading datasets", count=input_cnt),
79 | "saves": self._add_task(desc="Saving datasets", count=output_cnt),
80 | }
81 | if KEDRO_RICH_SHOW_DATASET_PROGRESS
82 | else {}
83 | )
84 |
85 | overall_task = {
86 | "overall": self._add_task(
87 | desc=f"Running [bright_magenta]'{pipe_name}'[/] pipeline",
88 | count=self.task_count,
89 | )
90 | }
91 |
92 | self.tasks = {**dataset_tasks, **overall_task}
93 |
94 | print_kedro_pipeline_init_screen()
95 |
96 | # Start process
97 | self.progress.start()
98 | else:
99 | logger = logging.getLogger(__name__)
100 | logger.warning(
101 | "[orange1 bold]Progress bars are incompatible with ParallelRunner[/]",
102 | )
103 |
104 | @hook_impl
105 | def before_dataset_loaded(self, dataset_name: str):
106 | """
107 | Add the last dataset loaded (from persistent storage)
108 | to progress display
109 | """
110 | if KEDRO_RICH_SHOW_DATASET_PROGRESS:
111 | if self.progress:
112 | dataset_name_namespaced = resolve_pipeline_namespace(dataset_name)
113 | if dataset_name_namespaced in self.pipeline_inputs:
114 | dataset_type = self.io_datasets_in_catalog[dataset_name_namespaced]
115 | dataset_desc = (
116 | f"📂{' ':<5}[i]{dataset_name}[/] ([bold cyan]{dataset_type}[/])"
117 | )
118 | self.progress.update(
119 | self.tasks["loads"], advance=1, activity=dataset_desc
120 | )
121 |
122 | @hook_impl
123 | def after_dataset_saved(self, dataset_name: str):
124 | """Add the last dataset persisted to progress display"""
125 | if KEDRO_RICH_SHOW_DATASET_PROGRESS:
126 | if self.progress:
127 | dataset_name_namespaced = resolve_pipeline_namespace(dataset_name)
128 |
129 | if dataset_name_namespaced in self.pipeline_outputs:
130 | namespace, key = split_catalog_namespace_key(dataset_name)
131 |
132 | data_string = (
133 | f"[blue]{namespace}[/].{key}" if namespace else f"{key}"
134 | )
135 |
136 | dataset_type = self.io_datasets_in_catalog[dataset_name_namespaced]
137 | dataset_desc = (
138 | f"💾{' ':<5}[i]{data_string}[/] ([bold cyan]{dataset_type}[/])"
139 | )
140 | self.progress.update(
141 | self.tasks["saves"], advance=1, activity=dataset_desc
142 | )
143 |
144 | @hook_impl
145 | def before_node_run(self, node: Node):
146 | """Add the current function name to progress display"""
147 | if self.progress:
148 | self.progress.update(
149 | self.tasks["overall"],
150 | activity=f"[violet]𝑓𝑥[/]{' ':<5}[orange1]{node.func.__name__}[/]()",
151 | )
152 |
153 | @hook_impl
154 | def after_node_run(self):
155 | """Increment the task count on node completion"""
156 | if self.progress:
157 | self.progress.update(self.tasks["overall"], advance=1)
158 |
159 | @hook_impl
160 | def after_pipeline_run(self):
161 | """Hook to complete and clean up progress information on pipeline completion"""
162 | if self.progress:
163 | self.progress.update(
164 | self.tasks["overall"],
165 | visible=True,
166 | activity="[bold green]✓ Pipeline complete[/] ",
167 | )
168 | if KEDRO_RICH_SHOW_DATASET_PROGRESS:
169 | self.progress.update(self.tasks["saves"], completed=100, visible=False)
170 | self.progress.update(self.tasks["loads"], completed=100, visible=False)
171 | time.sleep(0.1) # allows the UI to clean up after the process ends
172 |
173 | def _init_progress_tasks(self, pipeline: Pipeline, catalog: DataCatalog):
174 | """This method initialises the key Hook constructor attributes"""
175 | self.task_count = len(pipeline.nodes)
176 | self.io_datasets_in_catalog = get_catalog_datasets(
177 | catalog=catalog, exclude_types=("MemoryDataSet",)
178 | )
179 | (self.pipeline_inputs, self.pipeline_outputs,) = filter_datasets_by_pipeline(
180 | datasets=self.io_datasets_in_catalog, pipeline=pipeline
181 | )
182 |
183 | def _add_task(self, desc: str, count: int) -> TaskID:
184 | """This method adds a task to the progress bar"""
185 | return self.progress.add_task(desc, total=count, activity="")
186 |
187 | @staticmethod
188 | def _check_if_progress_bar_enabled() -> bool:
189 | """Convert env variable into boolean"""
190 | return bool(int(os.environ.get(KEDRO_RICH_PROGRESS_ENV_VAR_KEY, "0")))
191 |
192 |
193 | class _KedroElapsedColumn(ProgressColumn):
194 | """Renders time elapsed for top task only"""
195 |
196 | def render(self, task: Task) -> Text:
197 | """Show time remaining."""
198 | if task.id == 0:
199 | elapsed = task.finished_time if task.finished else task.elapsed
200 | if elapsed is None:
201 | return Text("-:--:--", style="cyan")
202 | delta = timedelta(seconds=int(elapsed))
203 | return Text(str(delta), style="green")
204 | return None
205 |
206 |
207 | rich_hooks = RichProgressHooks()
208 |
--------------------------------------------------------------------------------
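
Note: the hooks above lean on a rich feature worth calling out: plain format-string columns can read arbitrary task fields (here `activity`), which are supplied through `add_task` and `update`. A standalone sketch of that mechanism, independent of Kedro:

    import time

    from rich.progress import BarColumn, Progress

    columns = (
        "[progress.description]{task.description}",
        BarColumn(),
        "{task.fields[activity]}",  # custom field, as in the hooks above
    )
    with Progress(*columns) as progress:
        task = progress.add_task("Loading datasets", total=3, activity="")
        for name in ("dataset_1", "dataset_2", "dataset_3"):
            progress.update(task, advance=1, activity=f"[i]{name}[/]")
            time.sleep(0.2)
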