├── trufflehog-ignore.txt ├── tests ├── test_rich_progress_hooks.py ├── conftest.py └── utilities │ ├── test_formatting_utils.py │ └── test_catalog_utils.py ├── requirements.txt ├── kedro_rich ├── __init__.py ├── constants.py ├── rich_init.py ├── rich_cli.py ├── utilities │ ├── formatting_utils.py │ ├── catalog_utils.py │ └── kedro_override_utils.py └── rich_progress_hooks.py ├── static └── list-datasets.png ├── setup.cfg ├── .coveragerc ├── pyproject.toml ├── test_requirements.txt ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── ISSUE_TEMPLATE │ ├── feature-request.md │ └── bug-report.md └── workflows │ └── ci.yml ├── Makefile ├── setup.py ├── .gitignore ├── .pre-commit-config.yaml └── README.md /trufflehog-ignore.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_rich_progress_hooks.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kedro~=0.17.3 2 | rich-click~=1.2.1 3 | rich~=11.2.0 4 | -------------------------------------------------------------------------------- /kedro_rich/__init__.py: -------------------------------------------------------------------------------- 1 | """Rich plugin to make your Kedro snazzy""" 2 | 3 | __version__ = "0.0.14" 4 | -------------------------------------------------------------------------------- /static/list-datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datajoely/kedro-rich/HEAD/static/list-datasets.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E126,E203,E231,E266,E501,W503 3 | max-line-length = 88 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | fail_under=100 3 | show_missing=True 4 | omit = *tests* 5 | exclude_lines = 6 | pragma: no cover 7 | raise NotImplementedError 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | multi_line_output = 3 3 | include_trailing_comma = true 4 | force_grid_wrap = 0 5 | use_parentheses = true 6 | line_length = 88 7 | known_third_party = "kedro" 8 | 9 | [tool.black] 10 | -------------------------------------------------------------------------------- /test_requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | bandit>=1.6.2, <2.0 3 | black==v19.10.b0 4 | flake8 5 | isort>=4.3.21, <5.0 6 | pre-commit>=1.17.0, <2.0 7 | psutil==5.6.6 8 | pylint>=2.5.2, <3.0 9 | pytest 10 | pytest-cov 11 | pytest-mock 12 | trufflehog>=2.1.0, <3.0 13 | wheel 14 | -------------------------------------------------------------------------------- /kedro_rich/constants.py: -------------------------------------------------------------------------------- 1 | """Simple module for maintaining configuration""" 2 | 3 | KEDRO_RICH_PROGRESS_ENV_VAR_KEY = 
"KEDRO_RICH_SHOW_PROGRESS" 4 | KEDRO_RICH_SHOW_DATASET_PROGRESS = True 5 | 6 | KEDRO_RICH_LOGGING_HANDLER = { 7 | "class": "rich.logging.RichHandler", 8 | "level": "INFO", 9 | "markup": True, 10 | "log_time_format": "[%X]", 11 | } 12 | 13 | KEDRO_RICH_CATALOG_LIST_THRESHOLD = 10 14 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Development notes 5 | 6 | 7 | ## Checklist 8 | 9 | - [ ] Read the [contributing](https://github.com/kedro-org/kedro/blob/main/CONTRIBUTING.md) guidelines 10 | - [ ] Opened this PR as a 'Draft Pull Request' if it is work-in-progress 11 | - [ ] Updated the documentation to reflect the code changes 12 | - [ ] Added a description of this change in the [`RELEASE.md`](https://github.com/kedro-org/kedro/blob/main/RELEASE.md) file 13 | - [ ] Added tests to cover my changes 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Let us know if you have a feature request or enhancement 4 | title: '' 5 | labels: 'Issue: Feature Request' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | Is your feature request related to a problem? A clear and concise description of what the problem is: "I'm always frustrated when ..." 12 | 13 | ## Context 14 | Why is this change important to you? How would you use it? How can it benefit other users? 15 | 16 | ## Possible Implementation 17 | (Optional) Suggest an idea for implementing the addition or change. 18 | 19 | ## Possible Alternatives 20 | (Optional) Describe any alternative solutions or features you've considered. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: If something isn't working 4 | title: '<Title>' 5 | labels: 'Issue: Bug Report' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | Short description of the problem here. 12 | 13 | ## Context 14 | How has this bug affected you? What were you trying to accomplish? 15 | 16 | ## Steps to Reproduce 17 | 1. [First Step] 18 | 2. [Second Step] 19 | 3. [And so on...] 20 | 21 | ## Expected Result 22 | Tell us what should happen. 23 | 24 | ## Actual Result 25 | Tell us what happens instead. 26 | 27 | ``` 28 | -- If you received an error, place it here. 29 | ``` 30 | 31 | ``` 32 | -- Separate them if you have more than one. 
33 | ``` 34 | 35 | ## Your Environment 36 | Include as many relevant details about the environment in which you experienced the bug: 37 | 38 | * Kedro-rich version used (`pip show kedro-rich`): 39 | * Kedro version used (`pip show kedro` or `kedro -V`): 40 | * Python version used (`python -V`): 41 | * Operating system and version: 42 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from kedro.io.data_catalog import DataCatalog 2 | from kedro.io import MemoryDataSet 3 | from kedro.pipeline import Pipeline, node 4 | from kedro.extras.datasets.pickle import PickleDataSet 5 | import pytest 6 | 7 | 8 | @pytest.fixture() 9 | def data_catalog_fixture() -> DataCatalog: 10 | 11 | return DataCatalog( 12 | data_sets={ 13 | "dataset_1": PickleDataSet(filepath="test"), 14 | "dataset_2": MemoryDataSet(), 15 | "dataset_3": PickleDataSet(filepath="test"), 16 | }, 17 | feed_dict={"params.modelling_params": {"test_size": 0.3, "split_ratio": 0.7}}, 18 | ) 19 | 20 | 21 | @pytest.fixture() 22 | def pipeline_fixture() -> Pipeline: 23 | return Pipeline( 24 | nodes=[ 25 | node(func=lambda x: x, inputs="dataset_1", outputs="dataset_2"), 26 | node( 27 | func=lambda x: x, 28 | inputs="dataset_2", 29 | outputs="dataset_2.5", 30 | namespace="test", 31 | ), 32 | node(func=lambda x: x, inputs="dataset_2.5", outputs="dataset_3"), 33 | ] 34 | ) 35 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: kedro_rich_run 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | jobs: 8 | kedro_rich: 9 | name: kedro_rich_run_and_test 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Check out 14 | uses: actions/checkout@main 15 | with: 16 | ref: main 17 | 18 | - uses: actions/setup-python@v1 19 | with: 20 | python-version: "3.8.x" 21 | 22 | - name: Install dependencies 23 | run: | 24 | pip install -r requirements.txt 25 | 26 | - name: Pull spaceflights and install kedro-rich in editable mode 27 | run: make test-project 28 | 29 | - name: Kedro run (sequential) 30 | run: cd test_project; kedro run 31 | 32 | - name: Kedro run (parallel) 33 | run: cd test_project; kedro run --parallel 34 | 35 | - name: Kedro catalog list (json) 36 | run: cd test_project; kedro catalog list --format=json 37 | 38 | - name: Kedro catalog list (yaml) 39 | run: cd test_project; kedro catalog list --format=yaml 40 | 41 | - name: Kedro catalog list (table) 42 | run: cd test_project; kedro catalog list --format=table 43 | 44 | - name: Run PyTest 45 | run: make test 46 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | package: 2 | rm -Rf dist 3 | python setup.py sdist bdist_wheel 4 | 5 | install: package 6 | pip uninstall kedro-rich -y 7 | pip install -U dist/*.whl 8 | 9 | dev-install: 10 | pip install -e . 11 | 12 | install-pip-setuptools: 13 | python -m pip install -U pip setuptools wheel 14 | 15 | lint: 16 | pre-commit run -a --hook-stage manual 17 | 18 | test: 19 | pytest -vv tests --cov 20 | 21 | clean: 22 | rm -rf build dist pip-wheel-metadata .pytest_cache 23 | find . -regex ".*/__pycache__" -exec rm -rf {} + 24 | find . 
-regex ".*\.egg-info" -exec rm -rf {} + 25 | rm -rf test_project/ 26 | 27 | install-test-requirements: 28 | pip install -r test_requirements.txt 29 | 30 | install-pre-commit: install-test-requirements 31 | pre-commit install --install-hooks 32 | 33 | uninstall-pre-commit: 34 | pre-commit uninstall 35 | pre-commit uninstall --hook-type pre-push 36 | 37 | test-project: 38 | pip install -e . 39 | rm -rf test_project/ 40 | yes test_project | kedro new --starter=spaceflights 41 | pip install -r test_project/src/requirements.txt 42 | touch .telemetry 43 | echo "consent: false" >> .telemetry 44 | mv .telemetry test_project/ 45 | 46 | test-run: 47 | cd test_project; kedro run 48 | 49 | clear-test-run: 50 | rm -rf test_project/ 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | from codecs import open 3 | from os import path 4 | 5 | from setuptools import setup 6 | 7 | name = "kedro-rich" 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # get package version 11 | package_name = name.replace("-", "_") 12 | with open(path.join(here, package_name, "__init__.py"), encoding="utf-8") as f: 13 | version = re.search(r'__version__ = ["\']([^"\']+)', f.read()).group(1) 14 | 15 | # get the dependencies and installs 16 | with open("requirements.txt", "r", encoding="utf-8") as f: 17 | requires = [x.strip() for x in f if x.strip()] 18 | 19 | # Get the long description from the README file 20 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 21 | readme = f.read() 22 | 23 | setup( 24 | name=name, 25 | version=version, 26 | description="Kedro-Rich", 27 | long_description=readme, 28 | long_description_content_type="text/markdown", 29 | author="datajoely", 30 | python_requires=">=3.6, <3.9", 31 | install_requires=requires, 32 | license="Apache Software License (Apache 2.0)", 33 | packages=["kedro_rich"], 34 | include_package_data=True, 35 | zip_safe=False, 36 | entry_points={ 37 | "kedro.project_commands": ["kedro_rich_project = kedro_rich.rich_cli:commands"], 38 | "kedro.hooks" : ["kedro_rich_progress = kedro_rich.rich_progress_hooks:rich_hooks"], 39 | "kedro.init" : ["kedro_rich_init = kedro_rich.rich_init:start_up"] 40 | }, 41 | ) 42 | -------------------------------------------------------------------------------- /kedro_rich/rich_init.py: -------------------------------------------------------------------------------- 1 | """This module ensures that the rich logging and exceptions handlers are used""" 2 | import click 3 | import rich 4 | import rich.traceback 5 | from kedro.io.data_catalog import DataCatalog 6 | from kedro.pipeline.node import Node 7 | 8 | from kedro_rich.utilities.kedro_override_utils import ( 9 | override_catalog_load, 10 | override_catalog_save, 11 | override_kedro_cli_get_command, 12 | override_kedro_proj_logging_handler, 13 | override_node_str, 14 | ) 15 | 16 | 17 | def override_kedro_lib_logging(): 18 | """This method overrides default Kedro methods to prettify the logging 19 | output, longer term this could just involve changes to Kedro core. 
20 | """ 21 | 22 | Node.__str__ = override_node_str 23 | DataCatalog.load = override_catalog_load 24 | DataCatalog.save = override_catalog_save 25 | 26 | 27 | def apply_rich_tracebacks(): 28 | """ 29 | This method ensures that tracebacks raised by the Kedro project 30 | go through the rich traceback method 31 | 32 | The `suppress=[click]` argument means that exceptions will not 33 | show the frames related to the CLI framework and only the actual 34 | logic the user defines. 35 | """ 36 | rich.traceback.install(show_locals=False, suppress=[click]) 37 | 38 | 39 | def start_up(): 40 | """This method runs the setup methods needed to override 41 | certain defaults at start up 42 | """ 43 | override_kedro_proj_logging_handler() 44 | override_kedro_lib_logging() 45 | override_kedro_cli_get_command() 46 | apply_rich_tracebacks() 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # CMake 2 | cmake-build-debug/ 3 | 4 | ## File-based project format: 5 | *.iws 6 | 7 | ## Plugin-specific files: 8 | 9 | # IntelliJ 10 | .idea/ 11 | *.iml 12 | out/ 13 | 14 | ### macOS 15 | *.DS_Store 16 | .AppleDouble 17 | .LSOverride 18 | .Trashes 19 | 20 | # mpeltonen/sbt-idea plugin 21 | .idea_modules/ 22 | 23 | # JIRA plugin 24 | atlassian-ide-plugin.xml 25 | 26 | # Crashlytics plugin (for Android Studio and IntelliJ) 27 | com_crashlytics_export_strings.xml 28 | crashlytics.properties 29 | crashlytics-build.properties 30 | fabric.properties 31 | 32 | ### Python template 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | *.egg-info/ 56 | .installed.cfg 57 | *.egg 58 | MANIFEST 59 | 60 | # PyInstaller 61 | # Usually these files are written by a python script from a template 62 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
63 | *.manifest 64 | *.spec 65 | 66 | # Installer logs 67 | pip-log.txt 68 | pip-delete-this-directory.txt 69 | 70 | # Unit test / coverage reports 71 | htmlcov/ 72 | .tox/ 73 | .coverage 74 | .coverage.* 75 | .cache 76 | nosetests.xml 77 | coverage.xml 78 | *.cover 79 | .hypothesis/ 80 | 81 | # Translations 82 | *.mo 83 | *.pot 84 | 85 | # Django 86 | *.log 87 | .static_storage/ 88 | .media/ 89 | local_settings.py 90 | 91 | # Flask 92 | instance/ 93 | .webassets-cache 94 | 95 | # Scrapy 96 | .scrapy 97 | 98 | # PyBuilder 99 | target/ 100 | 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # Celery beat schedule file 108 | celerybeat-schedule 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # MkDocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | 135 | # Visual Studio Code 136 | .vscode/ 137 | # end to end tests assets 138 | 139 | # Vim 140 | *~ 141 | .*.swo 142 | .*.swp 143 | 144 | .pytest_cache/ 145 | docs/tmp-build-artifacts 146 | docs/build 147 | 148 | test_project/ 149 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | 4 | default_stages: [commit, manual] 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v2.2.3 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | - id: check-yaml # Checks yaml files for parseable syntax. 13 | - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems 14 | - id: check-merge-conflict # Check for files that contain merge conflict strings. 15 | - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source. 16 | - id: requirements-txt-fixer # Sorts entries in requirements.txt 17 | - id: flake8 18 | files: ^kedro_rich/ 19 | args: 20 | - "--max-line-length=88" 21 | - "--max-complexity=18" 22 | - "--select=B,C,E,F,W,T4,B9" 23 | - "--ignore=E126,E203,E231,E266,E501,W503" 24 | 25 | 26 | - repo: local 27 | hooks: 28 | # It's impossible to specify per-directory configuration, so we just run it many times. 29 | # https://github.com/PyCQA/pylint/issues/618 30 | # The first set of pylint checks if for local pre-commit, it only runs on the files changed. 31 | - id: pylint-quick-kedro-rich 32 | name: "Quick PyLint on kedro_rich/*" 33 | language: system 34 | types: [file, python] 35 | files: ^kedro_rich/ 36 | entry: pylint --disable=unnecessary-pass,too-many-locals 37 | stages: [commit] 38 | - id: pylint-quick-tests 39 | name: "Quick PyLint on tests/*" 40 | language: system 41 | types: [file, python] 42 | files: ^kedro_rich/tests/ 43 | entry: pylint --disable=missing-docstring,redefined-outer-name,no-self-use,invalid-name,protected-access,too-many-arguments 44 | stages: [commit] 45 | # The same pylint checks, but running on all files. 
It's for manual run with `make lint` 46 | - id: pylint-kedro_rich 47 | name: "PyLint on kedro_rich/*" 48 | language: system 49 | pass_filenames: false 50 | stages: [manual] 51 | entry: pylint --disable=unnecessary-pass kedro_rich 52 | - id: pylint-tests 53 | name: "PyLint on tests/*" 54 | language: system 55 | pass_filenames: false 56 | stages: [manual] 57 | entry: pylint --disable=missing-docstring,redefined-outer-name,no-self-use,invalid-name,protected-access,too-many-arguments tests 58 | - id: isort 59 | name: "Sort imports" 60 | language: system 61 | types: [ file, python ] 62 | files: ^kedro_rich/ 63 | entry: isort 64 | - id: black 65 | name: "Black" 66 | language: system 67 | pass_filenames: false 68 | entry: black kedro_rich tests 69 | -------------------------------------------------------------------------------- /tests/utilities/test_formatting_utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from kedro.io import MemoryDataSet 3 | from kedro_rich.utilities.catalog_utils import ( 4 | get_catalog_datasets, 5 | get_datasets_by_pipeline, 6 | report_datasets_as_list, 7 | ) 8 | from kedro_rich.utilities.formatting_utils import ( 9 | prepare_rich_table, 10 | get_kedro_logo, 11 | print_kedro_pipeline_init_screen, 12 | ) 13 | from kedro_rich.constants import KEDRO_RICH_CATALOG_LIST_THRESHOLD 14 | from kedro.pipeline import Pipeline 15 | import kedro 16 | 17 | 18 | def test_get_kedro_logo(): 19 | logo = get_kedro_logo() 20 | assert len(logo) == 9 21 | assert max(len(x) for x in logo) == 43 22 | 23 | 24 | def test_prepare_rich_table_no_namespace(data_catalog_fixture, pipeline_fixture): 25 | registry = {"__default__": pipeline_fixture} 26 | catalog_datasets = get_catalog_datasets(data_catalog_fixture, drop_params=True) 27 | pipeline_datasets = get_datasets_by_pipeline(data_catalog_fixture, registry) 28 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets) 29 | table = prepare_rich_table(records=mapped_datasets, registry=registry) 30 | assert len(table.columns) == 2 + len(registry) 31 | assert len(table.rows) == len(mapped_datasets) 32 | 33 | 34 | def test_prepare_rich_table_namespace(data_catalog_fixture, pipeline_fixture): 35 | registry = {"__default__": pipeline_fixture} 36 | data_catalog = deepcopy(data_catalog_fixture) 37 | data_catalog.add("namespace.dataset_name", MemoryDataSet()) 38 | catalog_datasets = get_catalog_datasets(data_catalog, drop_params=True) 39 | pipeline_datasets = get_datasets_by_pipeline(data_catalog, registry) 40 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets) 41 | table = prepare_rich_table(records=mapped_datasets, registry=registry) 42 | assert len(table.columns) == 3 + len(registry) 43 | assert len(table.rows) == len(mapped_datasets) 44 | 45 | 46 | def test_prepare_rich_table_threshold(data_catalog_fixture, pipeline_fixture): 47 | registry = {"__default__": pipeline_fixture} 48 | custom_registry = deepcopy(registry) 49 | for i in range(KEDRO_RICH_CATALOG_LIST_THRESHOLD): 50 | custom_registry[f"custom_{i}"] = Pipeline([]) 51 | catalog_datasets = get_catalog_datasets(data_catalog_fixture, drop_params=True) 52 | pipeline_datasets = get_datasets_by_pipeline(data_catalog_fixture, custom_registry) 53 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets) 54 | table = prepare_rich_table(records=mapped_datasets, registry=custom_registry) 55 | assert len(table.columns) == 3 56 | assert len(table.rows) == 
len(mapped_datasets)
57 | 
58 | 
59 | def test_print_kedro_pipeline_init_screen(capsys):
60 | 
61 |     print_kedro_pipeline_init_screen()
62 |     captured = capsys.readouterr()
63 | 
64 |     assert kedro.__version__ in captured.out
65 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kedro-rich
2 | 
3 | ## Make your Kedro snazzy
4 | 
5 | This is a very early, work-in-progress Kedro plugin that utilises the awesome rich library.
6 | 
7 | The intention of this piece of work is to battle-test the idea and iron out the creases, potentially integrating this as a first-class plugin hosted at kedro-org/plugins or, if we're lucky, as native functionality within Kedro itself.
8 | 
9 | I'm very much looking for help developing/testing this project, so if you want to get involved please get in touch.
10 | 
11 | ## Current Functionality
12 | 
13 | ### Overridden `kedro run` command
14 | 
15 | - Does exactly the same as a regular Kedro run, but kicks the progress bars into action.
16 | - The load/save progress tasks focus purely on persisted data and ignore ephemeral `MemoryDataSet`s.
17 | 
18 | ![kedro-rich-run](https://user-images.githubusercontent.com/35801847/159065139-1f98e136-7725-480a-8a26-974bac1687bf.gif)
19 | 
20 | - The progress bars are currently disabled when using `ParallelRunner`, since `multiprocessing` is causing issues between `kedro` and `rich`. Further investigation is needed to see whether some sort of `Lock()` mechanism will allow this to work.
21 | 
22 | ### Logging via `rich.logging.RichHandler`
23 | 
24 | - This plugin swaps out the default `stdout` console logging handler for the class provided by `rich`.
25 | - This is actually required to make the progress bars work without being broken onto new lines every time a new log message appears.
26 | - At this point we also enable the [rich traceback handler](https://rich.readthedocs.io/en/stable/traceback.html).
27 | - In order to enable this purely plugin-side (i.e. without making the user change `logging.yml`) I've had to do an ugly bit of monkey patching. Keen to come up with a better solution here.
28 | 
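For reference, the handler the plugin patches in (defined as `KEDRO_RICH_LOGGING_HANDLER` in `kedro_rich/constants.py`) is roughly equivalent to declaring the following `console` handler in your project's `logging.yml` yourself. This is a sketch for illustration only; the plugin applies it programmatically, so no config change is needed:

```yaml
handlers:
  console:
    class: rich.logging.RichHandler
    level: INFO
    markup: true
    log_time_format: "[%X]"
```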
29 | ### Overridden `kedro catalog list` command
30 | 
31 | Accepts the following options:
32 | 
33 | - `--format=yaml` provides a YAML representation to stdout that can be piped into other utilities
34 | - `--format=json` provides a JSON representation to stdout that can be piped into other utilities
35 | - `--format=table` provides a pretty representation to the console for human consumption
36 | 
37 | ![list of datasets](static/list-datasets.png)
38 | 
39 | ## Install the plug-in
40 | 
41 | ### (Option 1) Cloning the repository
42 | 
43 | The plug-in is in its very early days, so it will be a while before (if ever) this makes it to PyPI.
44 | 
45 | 1. Clone the repository
46 | 2. Run `make dev-install` to install this to your environment
47 | 3. Go to any Kedro 0.17.x project and see if it works! (Please let me know if it doesn't).
48 | 
49 | ### (Option 2) Direct from GitHub
50 | 
51 | 1. Run `pip install git+https://github.com/datajoely/kedro-rich` to install this to your environment.
52 | 2. Go to any Kedro 0.17.x project and see if it works! (Please let me know if it doesn't).
53 | 
54 | ## Run end to end example
55 | 
56 | Running `make test-project` then `make test-run` will...
57 | 
58 | - Install the `kedro-rich` package into the environment
59 | - Pull the 'spaceflights' `kedro-starter`
60 | - Install requirements
61 | - Execute `kedro run`
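If you would rather run the individual steps than the Make targets, the equivalent shell session is roughly as follows (a sketch based on the Makefile and CI workflow in this repo; it assumes a fresh virtual environment):

```sh
pip install -e .                                      # install kedro-rich in editable mode
yes test_project | kedro new --starter=spaceflights   # pull the starter project
pip install -r test_project/src/requirements.txt      # install project requirements
cd test_project
kedro run                                             # progress bars should appear
kedro catalog list --format=table                     # pretty catalog summary
```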
--------------------------------------------------------------------------------
/tests/utilities/test_catalog_utils.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import pytest
3 | 
4 | from kedro_rich.utilities.catalog_utils import (
5 |     filter_datasets_by_pipeline,
6 |     get_datasets_by_pipeline,
7 |     resolve_pipeline_namespace,
8 |     resolve_catalog_namespace,
9 |     get_catalog_datasets,
10 |     split_catalog_namespace_key,
11 |     report_datasets_as_list,
12 | )
13 | 
14 | 
15 | @pytest.mark.parametrize(
16 |     "given_key,expected_key",
17 |     [
18 |         ("ds.something.something_else.key", "ds__something__something_else__key"),
19 |         ("ds.key", "ds__key"),
20 |         ("key", "key"),
21 |     ],
22 | )
23 | def test_resolve_pipeline_namespace(given_key: str, expected_key: str):
24 |     assert resolve_pipeline_namespace(given_key) == expected_key
25 | 
26 | 
27 | @pytest.mark.parametrize(
28 |     "given_key,expected_key",
29 |     [
30 |         ("ds__something__something_else__key", "ds.something.something_else.key"),
31 |         ("ds__key", "ds.key"),
32 |         ("key", "key"),
33 |     ],
34 | )
35 | def test_resolve_catalog_namespace(given_key: str, expected_key: str):
36 |     assert resolve_catalog_namespace(given_key) == expected_key
37 | 
38 | 
39 | def test_get_catalog_datasets(data_catalog_fixture):
40 |     datasets_remaining = get_catalog_datasets(
41 |         catalog=data_catalog_fixture, exclude_types=("MemoryDataSet",), drop_params=True
42 |     )
43 |     assert len(datasets_remaining) == 2
44 |     assert all(x == "PickleDataSet" for x in datasets_remaining.values())
45 | 
46 |     datasets_remaining_keep_everything = get_catalog_datasets(
47 |         catalog=data_catalog_fixture, drop_params=False
48 |     )
49 | 
50 |     assert len(datasets_remaining_keep_everything) == 4
51 |     assert "MemoryDataSet" in datasets_remaining_keep_everything.values()
52 | 
53 | 
54 | def test_filter_datasets_by_pipeline(data_catalog_fixture, pipeline_fixture):
55 |     pipe_datasets = filter_datasets_by_pipeline(
56 |         datasets=get_catalog_datasets(
57 |             catalog=data_catalog_fixture,
58 |             exclude_types=("MemoryDataSet",),
59 |             drop_params=True,
60 |         ),
61 |         pipeline=pipeline_fixture,
62 |     )
63 |     all_dataset_keys = reduce(
64 |         lambda x, y: x | y, [set(x.keys()) for x in pipe_datasets]
65 |     )
66 |     all_dataset_types = reduce(
67 |         lambda x, y: x | y, [set(x.values()) for x in pipe_datasets]
68 |     )
69 |     assert all_dataset_keys == {"dataset_1", "dataset_3"}
70 |     assert all_dataset_types == {"PickleDataSet"}
71 | 
72 | 
73 | def test_get_datasets_by_pipeline(data_catalog_fixture, pipeline_fixture):
74 |     dataset_pipes = get_datasets_by_pipeline(
75 |         catalog=data_catalog_fixture, registry={"__default__": pipeline_fixture}
76 |     )
77 |     assert set(reduce(lambda x, y: x + y, dataset_pipes.values())) == {"__default__"}
78 |     assert set(dataset_pipes.keys()) == {"dataset_1", "dataset_2", "dataset_3"}
79 | 
80 | 
81 | @pytest.mark.parametrize(
82 |     "given_key,expected_namespace,expected_key",
83 |     [
84 |         ("ds.something.something_else.key", "ds.something.something_else", "key"),
85 |         ("ds.key", "ds", "key"),
86 |         ("key", None, "key"),
87 |     ],
88 | )
89 | def test_split_catalog_namespace_key(given_key, expected_namespace, expected_key):
90 |     namespace, key = split_catalog_namespace_key(given_key)
91 |     assert (namespace, key) == (expected_namespace, expected_key)
92 | 
93 | 
94 | def test_report_datasets_as_list(data_catalog_fixture, pipeline_fixture):
95 |     reported_list = report_datasets_as_list(
96 |         catalog_datasets=get_catalog_datasets(
97 |             catalog=data_catalog_fixture,
98 |             exclude_types=("MemoryDataSet",),
99 |             drop_params=True,
100 |         ),
101 |         pipeline_datasets=get_datasets_by_pipeline(
102 |             catalog=data_catalog_fixture, registry={"__default__": pipeline_fixture}
103 |         ),
104 |     )
105 |     assert [
106 |         {
107 |             "dataset_type": "PickleDataSet",
108 |             "pipelines": ["__default__"],
109 |             "namespace": None,
110 |             "key": "dataset_1",
111 |         },
112 |         {
113 |             "dataset_type": "PickleDataSet",
114 |             "pipelines": ["__default__"],
115 |             "namespace": None,
116 |             "key": "dataset_3",
117 |         },
118 |     ] == reported_list
119 | 
--------------------------------------------------------------------------------
/kedro_rich/rich_cli.py:
--------------------------------------------------------------------------------
1 | """Command line tools for manipulating a Kedro project.
2 | Intended to be invoked via `kedro`."""
3 | import importlib
4 | import json
5 | import os
6 | from typing import Callable, Dict
7 | 
8 | import click
9 | import rich_click
10 | import yaml
11 | from kedro.framework.cli.catalog import _create_session, create_catalog
12 | from kedro.framework.cli.project import run
13 | from kedro.framework.cli.utils import CONTEXT_SETTINGS, env_option
14 | from kedro.framework.startup import ProjectMetadata
15 | from kedro.pipeline import Pipeline
16 | from rich import box
17 | from rich.panel import Panel
18 | 
19 | from kedro_rich.constants import KEDRO_RICH_PROGRESS_ENV_VAR_KEY
20 | from kedro_rich.utilities.catalog_utils import (
21 |     get_catalog_datasets,
22 |     get_datasets_by_pipeline,
23 |     report_datasets_as_list,
24 | )
25 | from kedro_rich.utilities.formatting_utils import prepare_rich_table
26 | 
27 | 
28 | @click.group(context_settings=CONTEXT_SETTINGS, name="kedro-rich")
29 | def commands():
30 |     """Command line tools for manipulating a Kedro project."""
31 | 
32 | 
33 | def handle_parallel_run(func: Callable) -> Callable:
34 |     """
35 |     This function wraps the run command callback so that progress
36 |     bars are disabled when running under the parallel runner
37 |     """
38 | 
39 |     def wrapped(*args, **kwargs):
40 |         """
41 |         This method sets an environment variable enabling progress bars
42 |         unless the user selects a parallel run; in that situation kedro-rich
43 |         disables the progress bar
44 |         """
45 | 
46 |         # Only run progress bars if (1) NOT in simple mode (2) NOT ParallelRunner
47 |         if not kwargs["simple"]:
48 |             if not kwargs["parallel"]:
49 |                 os.environ[KEDRO_RICH_PROGRESS_ENV_VAR_KEY] = "1"
50 | 
51 |         # drop 'simple' kwarg as that doesn't exist in the original func
52 |         original_kwargs = {k: v for k, v in kwargs.items() if "simple" != k}
53 |         result = func(*args, **original_kwargs)
54 | 
55 |         if os.environ.get(KEDRO_RICH_PROGRESS_ENV_VAR_KEY):
56 |             del os.environ[KEDRO_RICH_PROGRESS_ENV_VAR_KEY]
57 |         return result
58 | 
59 |     return wrapped
60 | 
61 | 
62 | run.callback = handle_parallel_run(run.callback)
63 | run.__class__ = rich_click.RichCommand
64 | commands.add_command(
65 |     click.option(
66 |         "--simple",
67 |         "-s",
68 |         default=False,
69 |         is_flag=True,
70 |         help="Disable rich progress bars (simple mode)",
71 |     )(run)
72 | )
73 | 
74 | 
75 | @commands.group(cls=rich_click.RichGroup)
76 | def catalog():
77 |     """Commands for working with catalog."""
78 |     pass
79 | 
80 | 
81 | catalog.add_command(create_catalog)
82 | 
83 | 
84 | @catalog.command(cls=rich_click.RichCommand, name="list")
85 | @env_option
86 | @click.option(
87 |     "--format",
| "-f", 89 | "fmt", 90 | default="yaml", 91 | type=click.Choice(["yaml", "json", "table"], case_sensitive=False), 92 | help="Output the 'yaml' (default) / 'json' results to stdout or pretty" 93 | " print 'table' to console", 94 | ) 95 | @click.pass_obj 96 | def list_datasets(metadata: ProjectMetadata, fmt: str, env: str): 97 | """Detail datasets by type.""" 98 | 99 | # Needed to avoid circular reference 100 | from rich.console import Console # pylint: disable=import-outside-toplevel 101 | 102 | pipelines = _get_pipeline_registry(metadata) 103 | session = _create_session(metadata.package_name, env=env) 104 | context = session.load_context() 105 | catalog_datasets = get_catalog_datasets(context.catalog, drop_params=True) 106 | pipeline_datasets = get_datasets_by_pipeline(context.catalog, pipelines) 107 | mapped_datasets = report_datasets_as_list(pipeline_datasets, catalog_datasets) 108 | console = Console() 109 | 110 | if fmt == "yaml": 111 | struct = { 112 | f"{x['namespace']}.{x['key']}" if x["namespace"] else x["key"]: x 113 | for x in mapped_datasets 114 | } 115 | console.out(yaml.safe_dump(struct)) 116 | if fmt == "json": 117 | console.out(json.dumps(mapped_datasets, indent=2)) 118 | elif fmt == "table": 119 | table = prepare_rich_table(records=mapped_datasets, registry=pipelines) 120 | console.print( 121 | "\n", 122 | Panel( 123 | table, 124 | expand=False, 125 | title=f"Catalog contains [b][cyan]{len(mapped_datasets)}[/][/] persisted datasets", 126 | padding=(1, 1), 127 | box=box.MINIMAL, 128 | ), 129 | ) 130 | 131 | 132 | def _get_pipeline_registry(proj_metadata: ProjectMetadata) -> Dict[str, Pipeline]: 133 | """ 134 | This method retrieves the pipelines registered in the project where 135 | the plugin in installed 136 | """ 137 | # is this the right 0.18.x version of doing this? 
138 | # The object is no longer in the context 139 | registry = importlib.import_module( 140 | f"{proj_metadata.package_name}.pipeline_registry" 141 | ) 142 | return registry.register_pipelines() 143 | -------------------------------------------------------------------------------- /kedro_rich/utilities/formatting_utils.py: -------------------------------------------------------------------------------- 1 | """This module provides utilities that are helpful 2 | for printing to the rich console""" 3 | 4 | from typing import Any, Dict, List, Optional, Tuple 5 | 6 | import kedro 7 | from kedro.pipeline import Pipeline 8 | from rich import box 9 | from rich.console import Console 10 | from rich.style import Style 11 | from rich.table import Table 12 | 13 | from kedro_rich.constants import KEDRO_RICH_CATALOG_LIST_THRESHOLD 14 | 15 | 16 | def prepare_rich_table( 17 | records: List[Dict[str, Any]], registry: Dict[str, Pipeline] 18 | ) -> Table: 19 | """This method will build a rich.Table object based on the 20 | a given list of records and a dictionary of registered pipelines 21 | 22 | Args: 23 | records (List[Dict[str, Any]]): The catalog records 24 | pipes (Dict[str, Pipeline]): The pipelines to map to linked datasets 25 | 26 | Returns: 27 | Table: The table to render 28 | """ 29 | 30 | table = Table(show_header=True, header_style=Style(color="white"), box=box.ROUNDED) 31 | # only include namespace if at least one present in catalog 32 | includes_namespaces = any(x["namespace"] for x in records) 33 | collapse_pipes = len(registry.keys()) > KEDRO_RICH_CATALOG_LIST_THRESHOLD 34 | 35 | # define table headers 36 | namespace_columns = ["namespace"] if includes_namespaces else [] 37 | pipe_columns = ["pipeline_count"] if collapse_pipes else list(registry.keys()) 38 | columns_to_add = namespace_columns + ["dataset_name", "dataset_type"] + pipe_columns 39 | 40 | # add table headers 41 | for column in columns_to_add: 42 | table.add_column(column, justify="center") 43 | 44 | # add table rows 45 | for index, row in enumerate(records): 46 | 47 | def _describe_boundary( 48 | index: int, records: List[Dict[str, Any]], key: str, current_value: str 49 | ) -> Tuple[bool, bool]: 50 | """ 51 | Give a list of dictionaries, key and current value this method will 52 | return two booleans detailing if the sequence has changed or not 53 | """ 54 | same_section = ( 55 | index + 1 < len(records) and records[index + 1][key] == current_value 56 | ) 57 | new_section = index == 0 or records[index - 1][key] != current_value 58 | 59 | return same_section, new_section 60 | 61 | # work out if the dataset_type is the same / different to next row 62 | same_section, new_section = _describe_boundary( 63 | index=index, 64 | records=records, 65 | key="dataset_type", 66 | current_value=row["dataset_type"], 67 | ) 68 | 69 | # add namespace if present 70 | if includes_namespaces: 71 | table_namespace = ( 72 | [row["namespace"]] if row["namespace"] else ["[bright_black]n/a[/]"] 73 | ) 74 | else: 75 | table_namespace = [] 76 | 77 | # get catalog key 78 | table_dataset_name = [row["key"]] 79 | 80 | # get dataset_type, only show if different from the last record 81 | table_dataset_type = ( 82 | [f"[magenta][b]{row['dataset_type']}[/][/]"] if new_section else [""] 83 | ) 84 | 85 | # get pipelines attached to this dataset 86 | dataset_pipes = row["pipelines"] 87 | # get pipelines registered in this project 88 | proj_pipes = sorted(registry.keys()) 89 | 90 | # if too many pipelines registered, simply show the count 91 | if collapse_pipes: 92 | 
table_pipes = [str(len(dataset_pipes))] 93 | else: 94 | 95 | # show ✓ and ✘ if present 96 | table_pipes = [ 97 | "[bold green]✓[/]" 98 | if (pipe in (set(proj_pipes) & set(dataset_pipes))) 99 | else "[bold red]✘[/]" 100 | for pipe in proj_pipes 101 | ] 102 | 103 | # build full row 104 | renderables = ( 105 | table_namespace + table_dataset_name + table_dataset_type + table_pipes 106 | ) 107 | 108 | # add row to table 109 | table.add_row(*renderables, end_section=not same_section) 110 | return table 111 | 112 | 113 | def get_kedro_logo(color: str = "orange1") -> Optional[List[str]]: 114 | """This method constructs an ascii Kedro logo""" 115 | diamond = """ 116 | - 117 | ·===· 118 | ·==: :==· 119 | ·==: :==· 120 | ·==: :==· 121 | ·===· 122 | - 123 | """.split( 124 | "\n" 125 | ) 126 | color_rows = [ 127 | f"[{color}][b]{x}[/b][/{color}]" if x.strip() else "" for x in diamond 128 | ] 129 | 130 | return color_rows 131 | 132 | 133 | def print_kedro_pipeline_init_screen( 134 | title_color: str = "orange1", tagline_color: str = "gray" 135 | ): 136 | """This method prints the Kedro logo and package metadata""" 137 | 138 | tagline_text = "Reproducible, maintainable and modular data science code" 139 | lib_info = dict( 140 | title=f"[{title_color}][b]KEDRO[/][/{title_color}] ({kedro.__version__})", 141 | tagline=f"[{tagline_color}][i]{tagline_text}[/][/]", 142 | github="https://github.com/kedro-org/kedro", 143 | rtd="https://kedro.readthedocs.io", 144 | ) 145 | 146 | logo_rows = get_kedro_logo() 147 | mapping = ((2, "title"), (3, "tagline"), (-3, "github"), (-2, "rtd")) 148 | for index, key in mapping: 149 | spacing = (51 - len(logo_rows[index])) * " " 150 | logo_rows[index] = logo_rows[index] + spacing + lib_info[key] 151 | 152 | str_rows = "\n".join(logo_rows) 153 | Console().print(str_rows, no_wrap=True) 154 | -------------------------------------------------------------------------------- /kedro_rich/utilities/catalog_utils.py: -------------------------------------------------------------------------------- 1 | """This module includes helper functions for managing the data catalog""" 2 | import operator 3 | from functools import reduce 4 | from itertools import groupby 5 | from typing import Any, Dict, List, Optional, Set, Tuple 6 | 7 | from kedro.io import DataCatalog 8 | from kedro.pipeline import Pipeline 9 | 10 | 11 | def get_catalog_datasets( 12 | catalog: DataCatalog, exclude_types: Tuple[str] = (), drop_params: bool = False 13 | ) -> Dict[str, str]: 14 | """Filter to only persisted datasets""" 15 | datasets_filtered = { 16 | k: type(v).__name__ 17 | for k, v in catalog.datasets.__dict__.items() 18 | if type(v).__name__ not in exclude_types 19 | } 20 | if drop_params: 21 | datasets_w_param_filter = { 22 | k: v 23 | for k, v in datasets_filtered.items() 24 | if not k.startswith("params") and not k == "parameters" 25 | } 26 | return datasets_w_param_filter 27 | return datasets_filtered 28 | 29 | 30 | def filter_datasets_by_pipeline( 31 | datasets: Dict[str, str], pipeline: Pipeline 32 | ) -> Tuple[Dict[str, str], Dict[str, str]]: 33 | """ 34 | Retrieve datasets (inputs and outputs) which intersect against 35 | a given pipeline object 36 | 37 | This function also ensures namespaces are correctly rationalised. 
38 | """ 39 | 40 | def _clean_names(datasets: List[str], namespace: Optional[str]) -> Set[str]: 41 | if namespace: 42 | return {resolve_pipeline_namespace(x) for x in datasets} 43 | return set(datasets) 44 | 45 | inputs = reduce( 46 | lambda a, x: a | _clean_names(x.inputs, x.namespace), pipeline.nodes, set() 47 | ) 48 | outputs = reduce( 49 | lambda a, x: a | _clean_names(x.outputs, x.namespace), pipeline.nodes, set() 50 | ) 51 | 52 | pipeline_inputs = { 53 | k: v for k, v in datasets.items() if any(x.endswith(k) for x in inputs) 54 | } 55 | pipeline_outputs = { 56 | k: v for k, v in datasets.items() if any(x.endswith(k) for x in outputs) 57 | } 58 | return pipeline_inputs, pipeline_outputs 59 | 60 | 61 | def get_datasets_by_pipeline( 62 | catalog: DataCatalog, registry: Dict[str, Pipeline] 63 | ) -> Dict[str, List[str]]: 64 | """This method will return a dictionary of datasets mapped to the 65 | list of pipelines they are used within 66 | 67 | Args: 68 | catalog (DataCatalog): The data catalog object 69 | pipelines (Dict[str, Pipeline]): The pipelines in this project 70 | 71 | Returns: 72 | Dict[str, List[str]]: The dataset to pipeline groups 73 | """ 74 | # get non parameter dataset 75 | catalog_datasets = get_catalog_datasets(catalog=catalog, drop_params=True) 76 | 77 | # get node input and outputs 78 | pipeline_input_output_datasets = { 79 | pipeline_name: filter_datasets_by_pipeline(catalog_datasets, pipeline) 80 | for pipeline_name, pipeline in registry.items() 81 | } 82 | 83 | # get those that overlap with pipelines 84 | pipeline_datasets = { 85 | pipeline_name: reduce( 86 | lambda input, output: input.keys() | output.keys(), input_outputs 87 | ) 88 | for pipeline_name, input_outputs in pipeline_input_output_datasets.items() 89 | } 90 | 91 | # get dataset to pipeline pairs 92 | dataset_pipeline_pairs = reduce( 93 | lambda x, y: x + y, 94 | ( 95 | [(dataset, pipeline) for dataset in datasets] 96 | for pipeline, datasets in pipeline_datasets.items() 97 | ), 98 | ) 99 | 100 | # get dataset to pipeline groups 101 | sorter = sorted(dataset_pipeline_pairs, key=operator.itemgetter(0)) 102 | 103 | grouper = groupby(sorter, key=operator.itemgetter(0)) 104 | 105 | dataset_pipeline_groups = { 106 | k: list(map(operator.itemgetter(1), v)) for k, v in grouper 107 | } 108 | return dataset_pipeline_groups 109 | 110 | 111 | def resolve_pipeline_namespace(dataset_name: str) -> str: 112 | """Resolves the dot to double underscore namespace 113 | discrepancy between pipeline inputs/outputs and catalog keys 114 | """ 115 | return dataset_name.replace(".", "__") 116 | 117 | 118 | def resolve_catalog_namespace(dataset_name: str) -> str: 119 | """Resolves the double underscore to dot namespace 120 | discrepancy between catalog keys and pipeline inputs/outputs 121 | """ 122 | return dataset_name.replace("__", ".") 123 | 124 | 125 | def split_catalog_namespace_key(dataset_name: str) -> Tuple[Optional[str], str]: 126 | """This method splits out a catalog name from it's namespace""" 127 | dataset_split = dataset_name.split(".") 128 | namespace = ".".join(dataset_split[:-1]) 129 | if namespace: 130 | dataset_name = dataset_split[-1] 131 | return namespace, dataset_name 132 | return None, dataset_name 133 | 134 | 135 | def report_datasets_as_list( 136 | pipeline_datasets: Dict[str, List[str]], catalog_datasets: Dict[str, str] 137 | ) -> List[Dict[str, Any]]: 138 | """This method accepts the datasets present in the pipeline registry 139 | as well as the full data catalog and produces a list of records 140 | 
which include key metadata such as the type, namespace, linked pipelines 141 | and dataset name (ordered by type) 142 | """ 143 | return sorted( 144 | ( 145 | { 146 | **{ 147 | "dataset_type": v, 148 | "pipelines": pipeline_datasets.get(k, []), 149 | }, 150 | **dict( 151 | zip( 152 | ("namespace", "key"), 153 | split_catalog_namespace_key( 154 | dataset_name=resolve_catalog_namespace(k) 155 | ), 156 | ) 157 | ), 158 | } 159 | for k, v in catalog_datasets.items() 160 | ), 161 | key=lambda x: x["dataset_type"], 162 | ) 163 | -------------------------------------------------------------------------------- /kedro_rich/utilities/kedro_override_utils.py: -------------------------------------------------------------------------------- 1 | """This module provides methods which we use to override default Kedro methods""" 2 | # pylint: disable=protected-access 3 | import logging 4 | from typing import Any, Callable, Optional, Set 5 | 6 | import rich 7 | from click.core import _check_multicommand 8 | from kedro.framework.cli.cli import KedroCLI 9 | from kedro.framework.session import KedroSession 10 | from kedro.io.core import AbstractVersionedDataSet, Version 11 | from rich.panel import Panel 12 | 13 | from kedro_rich.constants import KEDRO_RICH_LOGGING_HANDLER 14 | 15 | 16 | def override_node_str(self) -> str: 17 | """This method rich-ifies the node.__str__ method""" 18 | 19 | def _drop_namespaces(xset: Set[str]) -> Optional[Set]: 20 | """This method cleans up the namesapces""" 21 | split = {x.split(".")[-1] for x in xset} 22 | if split: 23 | return split 24 | return None 25 | 26 | func_name = f"[magenta]𝑓𝑥 {self._func_name}([/]" 27 | inputs = _drop_namespaces(self.inputs) 28 | bridge = "[magenta])[/] [cyan]➡[/] " 29 | outputs = _drop_namespaces(self.outputs) 30 | return f"{func_name}{inputs}{bridge}{outputs}" 31 | 32 | 33 | def override_catalog_load(self, name: str, version: str = None) -> Any: 34 | """Loads a registered data set (Rich-ified output). 35 | 36 | Args: 37 | name: A data set to be loaded. 38 | version: Optional argument for concrete data version to be loaded. 39 | Works only with versioned datasets. 40 | 41 | Returns: 42 | The loaded data as configured. 43 | 44 | Raises: 45 | DataSetNotFoundError: When a data set with the given name 46 | has not yet been registered. 47 | 48 | Example: 49 | :: 50 | 51 | >>> from kedro.io import DataCatalog 52 | >>> from kedro.extras.datasets.pandas import CSVDataSet 53 | >>> 54 | >>> cars = CSVDataSet(filepath="cars.csv", 55 | >>> load_args=None, 56 | >>> save_args={"index": False}) 57 | >>> io = DataCatalog(data_sets={'cars': cars}) 58 | >>> 59 | >>> df = io.load("cars") 60 | """ 61 | load_version = Version(version, None) if version else None 62 | dataset = self._get_dataset(name, version=load_version) 63 | 64 | self._logger.info( 65 | "Loading data from [bright_blue]%s[/] ([bright_blue][b]%s[/][/])...", 66 | name, 67 | type(dataset).__name__, 68 | ) 69 | 70 | func = self._get_transformed_dataset_function(name, "load", dataset) 71 | result = func() 72 | 73 | version = ( 74 | dataset.resolve_load_version() 75 | if isinstance(dataset, AbstractVersionedDataSet) 76 | else None 77 | ) 78 | 79 | # Log only if versioning is enabled for the data set 80 | if self._journal and version: 81 | self._journal.log_catalog(name, "load", version) 82 | return result 83 | 84 | 85 | def override_catalog_save(self, name: str, data: Any) -> None: 86 | """Save data to a registered data set. 87 | 88 | Args: 89 | name: A data set to be saved to. 
90 | data: A data object to be saved as configured in the registered 91 | data set. 92 | 93 | Raises: 94 | DataSetNotFoundError: When a data set with the given name 95 | has not yet been registered. 96 | 97 | Example: 98 | :: 99 | 100 | >>> import pandas as pd 101 | >>> 102 | >>> from kedro.extras.datasets.pandas import CSVDataSet 103 | >>> 104 | >>> cars = CSVDataSet(filepath="cars.csv", 105 | >>> load_args=None, 106 | >>> save_args={"index": False}) 107 | >>> io = DataCatalog(data_sets={'cars': cars}) 108 | >>> 109 | >>> df = pd.DataFrame({'col1': [1, 2], 110 | >>> 'col2': [4, 5], 111 | >>> 'col3': [5, 6]}) 112 | >>> io.save("cars", df) 113 | """ 114 | dataset = self._get_dataset(name) 115 | 116 | self._logger.info( 117 | "Saving data to [bright_blue]%s[/] ([bright_blue][b]%s[/][/])...", 118 | name, 119 | type(dataset).__name__, 120 | ) 121 | 122 | func = self._get_transformed_dataset_function(name, "save", dataset) 123 | func(data) 124 | 125 | version = ( 126 | dataset.resolve_save_version() 127 | if isinstance(dataset, AbstractVersionedDataSet) 128 | else None 129 | ) 130 | 131 | # Log only if versioning is enabled for the data set 132 | if self._journal and version: 133 | self._journal.log_catalog(name, "save", version) 134 | 135 | 136 | def override_kedro_proj_logging_handler(): 137 | """ 138 | This function does two things: 139 | 140 | (1) It mutates the dictionary provided by `logging.yml` to 141 | use the `rich.logging.RichHandler` instead of the standard output one 142 | (2) It enables the rich.Traceback handler so that exceptions are prettier 143 | """ 144 | 145 | # ensure warnings are caught by logger not stout 146 | logging.captureWarnings(True) 147 | 148 | def _replace_console_handler(func: Callable) -> Callable: 149 | """This function mutates the dictionary returned by reading logging.yml""" 150 | 151 | def wrapped(*args, **kwargs): 152 | logging_config = func(*args, **kwargs) 153 | logging_config["handlers"]["console"] = KEDRO_RICH_LOGGING_HANDLER 154 | return logging_config 155 | 156 | return wrapped 157 | 158 | # pylint: disable=protected-access 159 | KedroSession._get_logging_config = _replace_console_handler( 160 | KedroSession._get_logging_config 161 | ) 162 | 163 | 164 | def override_kedro_cli_get_command(): 165 | """This method overrides the Click get_command() method 166 | so that we can give the user a useful message if they try to do a Kedro 167 | project command outside of a project directory 168 | """ 169 | 170 | # pylint: disable=invalid-name 171 | # pylint: disable=inconsistent-return-statements 172 | def _get_command(self, ctx, cmd_name): 173 | for source in self.sources: 174 | rv = source.get_command(ctx, cmd_name) 175 | if rv is not None: 176 | if self.chain: 177 | _check_multicommand(self, cmd_name, rv) 178 | return rv 179 | if not self._metadata: 180 | 181 | warn = "[orange1][b]You are not in a Kedro project[/]![/]" 182 | result = "Project specific commands such as '[bright_cyan]run[/]' or \ 183 | '[bright_cyan]jupyter[/]' are only available within a project directory." 
184 | solution = "[bright_black][b]Hint:[/] [i]Kedro is looking for a file called \ 185 | '[magenta]pyproject.toml[/]', is one present in your current working directory?[/][/]" 186 | msg = f"{warn} {result}\n\n{solution}" 187 | console = rich.console.Console() 188 | panel = Panel( 189 | msg, 190 | title=f"Command '{cmd_name}' not found", 191 | expand=False, 192 | border_style="dim", 193 | title_align="left", 194 | ) 195 | console.print("\n", panel, "\n") 196 | 197 | KedroCLI.get_command = _get_command 198 | -------------------------------------------------------------------------------- /kedro_rich/rich_progress_hooks.py: -------------------------------------------------------------------------------- 1 | """This module provides lifecycle hooks to track progress""" 2 | import logging 3 | import os 4 | import time 5 | from datetime import timedelta 6 | from typing import Any, Dict 7 | 8 | from kedro.framework.hooks import hook_impl 9 | from kedro.io import DataCatalog 10 | from kedro.pipeline import Pipeline 11 | from kedro.pipeline.node import Node 12 | from rich.progress import ( 13 | BarColumn, 14 | Progress, 15 | ProgressColumn, 16 | SpinnerColumn, 17 | Task, 18 | TaskID, 19 | ) 20 | from rich.text import Text 21 | 22 | from kedro_rich.constants import ( 23 | KEDRO_RICH_PROGRESS_ENV_VAR_KEY, 24 | KEDRO_RICH_SHOW_DATASET_PROGRESS, 25 | ) 26 | from kedro_rich.utilities.catalog_utils import ( 27 | filter_datasets_by_pipeline, 28 | get_catalog_datasets, 29 | resolve_pipeline_namespace, 30 | split_catalog_namespace_key, 31 | ) 32 | from kedro_rich.utilities.formatting_utils import print_kedro_pipeline_init_screen 33 | 34 | 35 | class RichProgressHooks: 36 | """These set of hooks add progress information to the output of a Kedro run""" 37 | 38 | def __init__(self): 39 | """This constructor initialises the variables used to manage state""" 40 | self.progress = None 41 | self.task_count = 0 42 | self.io_datasets_in_catalog = {} 43 | self.pipeline_inputs = {} 44 | self.pipeline_outputs = {} 45 | self.tasks = {} 46 | 47 | @hook_impl 48 | def before_pipeline_run( 49 | self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog 50 | ): 51 | """ 52 | This method initialises the variables needed to track pipeline process. 
This 53 | will be disabled under parallel runner 54 | """ 55 | if self._check_if_progress_bar_enabled(): 56 | progress_desc_format = "[progress.description]{task.description}" 57 | progress_percentage_format = "[progress.percentage]{task.percentage:>3.0f}%" 58 | progress_activity_format = "{task.fields[activity]}" 59 | self.progress = Progress( 60 | _KedroElapsedColumn(), 61 | progress_desc_format, 62 | SpinnerColumn(), 63 | BarColumn(), 64 | progress_percentage_format, 65 | progress_activity_format, 66 | ) 67 | 68 | # Get pipeline goals 69 | self._init_progress_tasks(pipeline, catalog) 70 | 71 | # Init tasks 72 | pipe_name = run_params.get("pipeline_name") or "__default__" 73 | input_cnt = len(self.pipeline_inputs) 74 | output_cnt = len(self.pipeline_outputs) 75 | 76 | dataset_tasks = ( 77 | { 78 | "loads": self._add_task(desc="Loading datasets", count=input_cnt), 79 | "saves": self._add_task(desc="Saving datasets", count=output_cnt), 80 | } 81 | if KEDRO_RICH_SHOW_DATASET_PROGRESS 82 | else {} 83 | ) 84 | 85 | overall_task = { 86 | "overall": self._add_task( 87 | desc=f"Running [bright_magenta]'{pipe_name}'[/] pipeline", 88 | count=self.task_count, 89 | ) 90 | } 91 | 92 | self.tasks = {**dataset_tasks, **overall_task} 93 | 94 | print_kedro_pipeline_init_screen() 95 | 96 | # Start process 97 | self.progress.start() 98 | else: 99 | logger = logging.getLogger(__name__) 100 | logger.warning( 101 | "[orange1 bold]Progress bars are incompatible with ParallelRunner[/]", 102 | ) 103 | 104 | @hook_impl 105 | def before_dataset_loaded(self, dataset_name: str): 106 | """ 107 | Add the last dataset loaded (from persistent storage) 108 | to progress display 109 | """ 110 | if KEDRO_RICH_SHOW_DATASET_PROGRESS: 111 | if self.progress: 112 | dataset_name_namespaced = resolve_pipeline_namespace(dataset_name) 113 | if dataset_name in self.pipeline_inputs: 114 | dataset_type = self.io_datasets_in_catalog[dataset_name_namespaced] 115 | dataset_desc = ( 116 | f"📂{' ':<5}[i]{dataset_name}[/] ([bold cyan]{dataset_type}[/])" 117 | ) 118 | self.progress.update( 119 | self.tasks["loads"], advance=1, activity=dataset_desc 120 | ) 121 | 122 | @hook_impl 123 | def after_dataset_saved(self, dataset_name: str): 124 | """Add the last dataset persisted to progress display""" 125 | if KEDRO_RICH_SHOW_DATASET_PROGRESS: 126 | if self.progress: 127 | dataset_name_namespaced = resolve_pipeline_namespace(dataset_name) 128 | 129 | if dataset_name_namespaced in self.pipeline_outputs: 130 | namespace, key = split_catalog_namespace_key(dataset_name) 131 | 132 | data_string = ( 133 | f"[blue]{namespace}[/].{key}" if namespace else f"{key}" 134 | ) 135 | 136 | dataset_type = self.io_datasets_in_catalog[dataset_name_namespaced] 137 | dataset_desc = ( 138 | f"💾{' ':<5}[i]{data_string}[/] ([bold cyan]{dataset_type}[/])" 139 | ) 140 | self.progress.update( 141 | self.tasks["saves"], advance=1, activity=dataset_desc 142 | ) 143 | 144 | @hook_impl 145 | def before_node_run(self, node: Node): 146 | """Add the current function name to progress display""" 147 | if self.progress: 148 | self.progress.update( 149 | self.tasks["overall"], 150 | activity=f"[violet]𝑓𝑥[/]{' ':<5}[orange1]{node.func.__name__}[/]()", 151 | ) 152 | 153 | @hook_impl 154 | def after_node_run(self): 155 | """Increment the task count on node completion""" 156 | if self.progress: 157 | self.progress.update(self.tasks["overall"], advance=1) 158 | 159 | @hook_impl 160 | def after_pipeline_run(self): 161 | """Hook to complete and clean up progress information on pipeline 
completion""" 162 | if self.progress: 163 | self.progress.update( 164 | self.tasks["overall"], 165 | visible=True, 166 | activity="[bold green]✓ Pipeline complete[/] ", 167 | ) 168 | if KEDRO_RICH_SHOW_DATASET_PROGRESS: 169 | self.progress.update(self.tasks["saves"], completed=100, visible=False) 170 | self.progress.update(self.tasks["loads"], completed=100, visible=False) 171 | time.sleep(0.1) # allows the UI to clean up after the process ends 172 | 173 | def _init_progress_tasks(self, pipeline: Pipeline, catalog: DataCatalog): 174 | """This method initialises the key Hook constructor attributes""" 175 | self.task_count = len(pipeline.nodes) 176 | self.io_datasets_in_catalog = get_catalog_datasets( 177 | catalog=catalog, exclude_types=("MemoryDataSet",) 178 | ) 179 | (self.pipeline_inputs, self.pipeline_outputs,) = filter_datasets_by_pipeline( 180 | datasets=self.io_datasets_in_catalog, pipeline=pipeline 181 | ) 182 | 183 | def _add_task(self, desc: str, count: int) -> TaskID: 184 | """This method adds a task to the progress bar""" 185 | return self.progress.add_task(desc, total=count, activity="") 186 | 187 | @staticmethod 188 | def _check_if_progress_bar_enabled() -> bool: 189 | """Convert env variable into boolean""" 190 | return bool(int(os.environ.get(KEDRO_RICH_PROGRESS_ENV_VAR_KEY, "0"))) 191 | 192 | 193 | class _KedroElapsedColumn(ProgressColumn): 194 | """Renders time elapsed for top task only""" 195 | 196 | def render(self, task: Task) -> Text: 197 | """Show time remaining.""" 198 | if task.id == 0: 199 | elapsed = task.finished_time if task.finished else task.elapsed 200 | if elapsed is None: 201 | return Text("-:--:--", style="cyan") 202 | delta = timedelta(seconds=int(elapsed)) 203 | return Text(str(delta), style="green") 204 | return None 205 | 206 | 207 | rich_hooks = RichProgressHooks() 208 | --------------------------------------------------------------------------------