├── .github ├── CODEOWNERS └── workflows │ ├── format.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── conf.py ├── efficient_run_on_multi_workers.rst ├── for_pandas.rst ├── gokart.rst ├── gokart_logo_side_isolation.svg ├── index.rst ├── intro_to_gokart.rst ├── logging.rst ├── make.bat ├── mypy_plugin.rst ├── requirements.txt ├── setting_task_parameters.rst ├── slack_notification.rst ├── task_information.rst ├── task_on_kart.rst ├── task_parameters.rst ├── task_settings.rst ├── tutorial.rst └── using_task_task_conflict_prevention_lock.rst ├── examples ├── gokart_notebook_example.ipynb ├── logging.ini └── param.ini ├── gokart ├── __init__.py ├── build.py ├── build_process_task_info.py ├── config_params.py ├── conflict_prevention_lock │ ├── task_lock.py │ └── task_lock_wrappers.py ├── errors │ └── __init__.py ├── file_processor.py ├── gcs_config.py ├── gcs_obj_metadata_client.py ├── gcs_zip_client.py ├── in_memory │ ├── __init__.py │ ├── data.py │ ├── repository.py │ └── target.py ├── info.py ├── mypy.py ├── object_storage.py ├── pandas_type_config.py ├── parameter.py ├── py.typed ├── required_task_output.py ├── run.py ├── s3_config.py ├── s3_zip_client.py ├── slack │ ├── __init__.py │ ├── event_aggregator.py │ ├── slack_api.py │ └── slack_config.py ├── target.py ├── task.py ├── task_complete_check.py ├── testing │ ├── __init__.py │ ├── check_if_run_with_empty_data_frame.py │ └── pandas_assert.py ├── tree │ ├── task_info.py │ └── task_info_formatter.py ├── utils.py ├── worker.py ├── workspace_management.py ├── zip_client.py └── zip_client_util.py ├── luigi.cfg ├── pyproject.toml ├── test ├── __init__.py ├── config │ ├── __init__.py │ ├── pyproject.toml │ ├── pyproject_disallow_missing_parameters.toml │ └── test_config.ini ├── conflict_prevention_lock │ ├── __init__.py │ ├── test_task_lock.py │ └── test_task_lock_wrappers.py ├── in_memory │ ├── test_in_memory_target.py │ └── test_repository.py 
├── slack │ ├── __init__.py │ └── test_slack_api.py ├── test_build.py ├── test_cache_unique_id.py ├── test_config_params.py ├── test_explicit_bool_parameter.py ├── test_file_processor.py ├── test_gcs_config.py ├── test_gcs_obj_metadata_client.py ├── test_info.py ├── test_large_data_fram_processor.py ├── test_list_task_instance_parameter.py ├── test_mypy.py ├── test_pandas_type_check_framework.py ├── test_pandas_type_config.py ├── test_restore_task_by_id.py ├── test_run.py ├── test_s3_config.py ├── test_s3_zip_client.py ├── test_serializable_parameter.py ├── test_target.py ├── test_task_instance_parameter.py ├── test_task_on_kart.py ├── test_utils.py ├── test_worker.py ├── test_zoned_date_second_parameter.py ├── testing │ ├── __init__.py │ └── test_pandas_assert.py ├── tree │ ├── __init__.py │ ├── test_task_info.py │ └── test_task_info_formatter.py └── util.py ├── tox.ini └── uv.lock /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @Hi-king 2 | * @yokomotod 3 | * @hirosassa 4 | * @mski-iksm 5 | * @kitagry 6 | * @ujiuji1259 7 | * @mamo3gr 8 | * @hiro-o918 9 | -------------------------------------------------------------------------------- /.github/workflows/format.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | 8 | 9 | jobs: 10 | formatting-check: 11 | 12 | name: Lint 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up the latest version of uv 18 | uses: astral-sh/setup-uv@v5 19 | with: 20 | enable-cache: true 21 | - name: Install dependencies 22 | run: | 23 | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv 24 | - name: Run ruff and mypy 25 | run: | 26 | uvx --with tox-uv tox run -e ruff,mypy 27 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: 
-------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: '*' 6 | 7 | jobs: 8 | deploy: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up the latest version of uv 15 | uses: astral-sh/setup-uv@v5 16 | with: 17 | enable-cache: true 18 | - name: Build and publish 19 | env: 20 | UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 21 | run: | 22 | uv build 23 | uv publish 24 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | 8 | jobs: 9 | tests: 10 | runs-on: ${{ matrix.platform }} 11 | strategy: 12 | max-parallel: 7 13 | matrix: 14 | platform: ["ubuntu-latest"] 15 | tox-env: ["py39", "py310", "py311", "py312", "py313"] 16 | include: 17 | - platform: macos-13 18 | tox-env: "py313" 19 | - platform: macos-latest 20 | tox-env: "py313" 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up the latest version of uv 24 | uses: astral-sh/setup-uv@v5 25 | with: 26 | enable-cache: true 27 | - name: Install dependencies 28 | run: | 29 | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv 30 | - name: Test with tox 31 | run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }} 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 
25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # pycharm 107 | .idea 108 | 109 | # gokart 110 | resources 111 | examples/resources 112 | 113 | # poetry 114 | dist 115 | 116 | # temporary data 117 | temporary.zip -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: 
"3.12" 12 | 13 | # Build from the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optional but recommended, declare the Python requirements required 18 | # to build your documentation 19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: docs/requirements.txt 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 M3, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gokart 2 | 3 |

4 | 5 |

6 | 7 | [![Test](https://github.com/m3dev/gokart/workflows/Test/badge.svg)](https://github.com/m3dev/gokart/actions?query=workflow%3ATest) 8 | [![](https://readthedocs.org/projects/gokart/badge/?version=latest)](https://gokart.readthedocs.io/en/latest/) 9 | [![Python Versions](https://img.shields.io/pypi/pyversions/gokart.svg)](https://pypi.org/project/gokart/) 10 | [![](https://img.shields.io/pypi/v/gokart)](https://pypi.org/project/gokart/) 11 | ![](https://img.shields.io/pypi/l/gokart) 12 | 13 | Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline. 14 | 15 | 16 | [Documentation](https://gokart.readthedocs.io/en/latest/) for the latest release is hosted on readthedocs. 17 | 18 | 19 | # About gokart 20 | 21 | Here are some good things about gokart. 22 | 23 | - The following meta data for each Task is stored separately in a `pkl` file with hash value 24 | - task output data 25 | - imported all module versions 26 | - task processing time 27 | - random seed in task 28 | - displayed log 29 | - all parameters set as class variables in the task 30 | - Automatically rerun the pipeline if parameters of Tasks are changed. 31 | - Support GCS and S3 as a data store for intermediate results of Tasks in the pipeline. 32 | - The above output is exchanged between tasks as an intermediate file, which is memory-friendly 33 | - `pandas.DataFrame` type and column checking during I/O 34 | - Directory structure of saved files is automatically determined from structure of script 35 | - Seeds for numpy and random are automatically fixed 36 | - Can code while adhering to [SOLID](https://en.wikipedia.org/wiki/SOLID) principles as much as possible 37 | - Tasks are locked via redis even if they run in parallel 38 | 39 | **All the functions above are created for constructing Machine Learning batches. 
Provides an excellent environment for reproducibility and team development.** 40 | 41 | 42 | Here are some non-goals / downsides of gokart. 43 | - Batch execution in parallel is supported, but parallel and concurrent execution of tasks in memory is not. 44 | - Gokart is focused on reproducibility, so I/O and data storage capacity can become a bottleneck. 45 | - No support for task visualization. 46 | - Gokart is not an experiment management tool. Management of execution results is split out into [Thunderbolt](https://github.com/m3dev/thunderbolt). 47 | - Gokart does not recommend writing pipelines in toml, yaml, json, and similar formats. Gokart prefers writing them in Python. 48 | 49 | # Getting Started 50 | 51 | Within the activated Python environment, use the following command to install gokart. 52 | 53 | ``` 54 | pip install gokart 55 | ``` 56 | 57 | 58 | # Quickstart 59 | 60 | ## Minimal Example 61 | 62 | A minimal gokart task looks something like this: 63 | 64 | 65 | ```python 66 | import gokart 67 | 68 | class Example(gokart.TaskOnKart): 69 | def run(self): 70 | self.dump('Hello, world!') 71 | 72 | task = Example() 73 | output = gokart.build(task) 74 | print(output) 75 | ``` 76 | 77 | `gokart.build` returns the result dumped by the `gokart.TaskOnKart`. The example will output the following. 78 | 79 | 80 | ``` 81 | Hello, world! 82 | ``` 83 | 84 | ## Type-Safe Pipeline Example 85 | 86 | We introduce type-annotations to make a gokart pipeline robust. 87 | Please check the following example to see how to use type-annotations on gokart. 88 | Before using this feature, make sure to enable the [mypy plugin](https://gokart.readthedocs.io/en/latest/mypy_plugin.html) feature in your project.
89 | 90 | ```python 91 | import gokart 92 | 93 | # `gokart.TaskOnKart[str]` means that the task dumps `str` 94 | class StrDumpTask(gokart.TaskOnKart[str]): 95 | def run(self): 96 | self.dump('Hello, world!') 97 | 98 | # `gokart.TaskOnKart[int]` means that the task dumps `int` 99 | class OneDumpTask(gokart.TaskOnKart[int]): 100 | def run(self): 101 | self.dump(1) 102 | 103 | # `gokart.TaskOnKart[int]` means that the task dumps `int` 104 | class TwoDumpTask(gokart.TaskOnKart[int]): 105 | def run(self): 106 | self.dump(2) 107 | 108 | class AddTask(gokart.TaskOnKart[int]): 109 | # `a` requires a task to dump `int` 110 | a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() 111 | # `b` requires a task to dump `int` 112 | b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() 113 | 114 | def requires(self): 115 | return dict(a=self.a, b=self.b) 116 | 117 | def run(self): 118 | # loading by instance parameter, 119 | # `a` and `b` are treated as `int` 120 | # because they are declared as `gokart.TaskOnKart[int]` 121 | a = self.load(self.a) 122 | b = self.load(self.b) 123 | self.dump(a + b) 124 | 125 | 126 | valid_task = AddTask(a=OneDumpTask(), b=TwoDumpTask()) 127 | # the next line will show type error by mypy 128 | # because `StrDumpTask` dumps `str` and `AddTask` requires `int` 129 | invalid_task = AddTask(a=OneDumpTask(), b=StrDumpTask()) 130 | ``` 131 | 132 | This is an introduction to some of gokart's features. 133 | There are still more useful features. 134 | 135 | Please see the [Documentation](https://gokart.readthedocs.io/en/latest/). 136 | 137 | Have a good gokart life. 138 | 139 | # Achievements 140 | 141 | Gokart is a proven product.
142 | 143 | - It's actually been used by [m3.inc](https://corporate.m3.com/en) for over 3 years 144 | - Natural Language Processing Competition by [Nishika.inc](https://nishika.com) 2nd prize : [Solution Repository](https://github.com/vaaaaanquish/nishika_akutagawa_2nd_prize) 145 | 146 | 147 | # Thanks 148 | 149 | gokart is a wrapper for luigi. Thanks to luigi and dependent projects! 150 | 151 | - [luigi](https://github.com/spotify/luigi) 152 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # https://github.com/sphinx-doc/sphinx/issues/6211 2 | import luigi 3 | 4 | import gokart 5 | 6 | luigi.task.Task.requires.__doc__ = gokart.task.TaskOnKart.requires.__doc__ 7 | luigi.task.Task.output.__doc__ = gokart.task.TaskOnKart.output.__doc__ 8 | 9 | # 10 | # Configuration file for the Sphinx documentation builder. 11 | # 12 | # This file does only contain a selection of the most common options. 
For a 13 | # full list see the documentation: 14 | # http://www.sphinx-doc.org/en/master/config 15 | 16 | # -- Path setup -------------------------------------------------------------- 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 21 | 22 | # import os 23 | # import sys 24 | # sys.path.insert(0, os.path.abspath('../gokart/')) 25 | 26 | # -- Project information ----------------------------------------------------- 27 | 28 | project = 'gokart' 29 | copyright = '2019, Masahiro Nishiba' 30 | author = 'Masahiro Nishiba' 31 | 32 | # The short X.Y version 33 | version = '' 34 | # The full version, including alpha/beta/rc tags 35 | release = '' 36 | 37 | # -- General configuration --------------------------------------------------- 38 | 39 | # If your documentation needs a minimal Sphinx version, state it here. 40 | # 41 | # needs_sphinx = '1.0' 42 | 43 | # Add any Sphinx extension module names here, as strings. They can be 44 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 45 | # ones. 46 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # The suffix(es) of source filenames. 52 | # You can specify multiple suffix as a list of string: 53 | # 54 | # source_suffix = ['.rst', '.md'] 55 | source_suffix = '.rst' 56 | 57 | # The master toctree document. 58 | master_doc = 'index' 59 | 60 | # The language for content autogenerated by Sphinx. Refer to documentation 61 | # for a list of supported languages. 62 | # 63 | # This is also used if you do content translation via gettext catalogs. 64 | # Usually you set "language" from the command line for these cases. 
65 | language = None 66 | 67 | # List of patterns, relative to source directory, that match files and 68 | # directories to ignore when looking for source files. 69 | # This pattern also affects html_static_path and html_extra_path. 70 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 71 | 72 | # The name of the Pygments (syntax highlighting) style to use. 73 | pygments_style = None 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 79 | # 80 | html_theme = 'sphinx_rtd_theme' 81 | 82 | # Theme options are theme-specific and customize the look and feel of a theme 83 | # further. For a list of options available for each theme, see the 84 | # documentation. 85 | # 86 | # html_theme_options = {} 87 | 88 | # Add any paths that contain custom static files (such as style sheets) here, 89 | # relative to this directory. They are copied after the builtin static files, 90 | # so a file named "default.css" will overwrite the builtin "default.css". 91 | html_static_path = [] 92 | 93 | # Custom sidebar templates, must be a dictionary that maps document names 94 | # to template names. 95 | # 96 | # The default sidebars (for documents that don't match any pattern) are 97 | # defined by theme itself. Builtin themes are using these templates by 98 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 99 | # 'searchbox.html']``. 100 | 101 | # html_sidebars = {} 102 | 103 | # -- Options for HTMLHelp output --------------------------------------------- 104 | 105 | # Output file base name for HTML help builder. 106 | htmlhelp_basename = 'gokartdoc' 107 | 108 | # -- Options for LaTeX output ------------------------------------------------ 109 | 110 | latex_elements = { 111 | # The paper size ('letterpaper' or 'a4paper'). 
112 | # 113 | # 'papersize': 'letterpaper', 114 | # The font size ('10pt', '11pt' or '12pt'). 115 | # 116 | # 'pointsize': '10pt', 117 | # Additional stuff for the LaTeX preamble. 118 | # 119 | # 'preamble': '', 120 | # Latex figure (float) alignment 121 | # 122 | # 'figure_align': 'htbp', 123 | } 124 | 125 | # Grouping the document tree into LaTeX files. List of tuples 126 | # (source start file, target name, title, 127 | # author, documentclass [howto, manual, or own class]). 128 | latex_documents = [ 129 | (master_doc, 'gokart.tex', 'gokart Documentation', 'Masahiro Nishiba', 'manual'), 130 | ] 131 | 132 | # -- Options for manual page output ------------------------------------------ 133 | 134 | # One entry per manual page. List of tuples 135 | # (source start file, name, description, authors, manual section). 136 | man_pages = [(master_doc, 'gokart', 'gokart Documentation', [author], 1)] 137 | 138 | # -- Options for Texinfo output ---------------------------------------------- 139 | 140 | # Grouping the document tree into Texinfo files. List of tuples 141 | # (source start file, target name, title, author, 142 | # dir menu entry, description, category) 143 | texinfo_documents = [ 144 | (master_doc, 'gokart', 'gokart Documentation', author, 'gokart', 'One line description of project.', 'Miscellaneous'), 145 | ] 146 | 147 | # -- Options for Epub output ------------------------------------------------- 148 | 149 | # Bibliographic Dublin Core info. 150 | epub_title = project 151 | 152 | # The unique identifier of the text. This can be a ISBN number 153 | # or the project homepage. 154 | # 155 | # epub_identifier = '' 156 | 157 | # A unique identification for the text. 158 | # 159 | # epub_uid = '' 160 | 161 | # A list of files that should not be packed into the epub file. 
162 | epub_exclude_files = ['search.html'] 163 | -------------------------------------------------------------------------------- /docs/efficient_run_on_multi_workers.rst: -------------------------------------------------------------------------------- 1 | How to improve efficiency when running on multiple workers 2 | =========================================================== 3 | 4 | If multiple worker nodes are running similar gokart pipelines in parallel, the exact same task may be executed by multiple workers. 5 | (For example, when training multiple machine learning models with different parameters, the feature creation task in the first stage is expected to be exactly the same.) 6 | 7 | It is inefficient to execute the same task on each of multiple worker nodes, so we want to avoid this. 8 | Here we introduce the `should_lock_run` feature to reduce this inefficiency. 9 | 10 | 11 | 12 | Suppress run() of the same task with `should_lock_run` 13 | ------------------------------------------------------ 14 | When `gokart.TaskOnKart.should_lock_run` is set to True, the task will fail if the same task is currently being run by another worker. 15 | By failing the task, other tasks that can be executed at that time are given priority. 16 | After that, the failed task is automatically re-executed. 17 | 18 | .. code:: python 19 | 20 | class SampleTask2(gokart.TaskOnKart): 21 | should_lock_run = True 22 | 23 | 24 | Additional Option 25 | ------------------ 26 | 27 | Skip completed tasks with `complete_check_at_run` 28 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 29 | By setting `gokart.TaskOnKart.complete_check_at_run` to True, the existence of the cache can be rechecked at run() time. 30 | 31 | The default is True, but if the check takes too much time, you can set it to False to disable the check. 32 | 33 | ..
code:: python 34 | 35 | class SampleTask1(gokart.TaskOnKart): 36 | complete_check_at_run = False 37 | 38 | -------------------------------------------------------------------------------- /docs/for_pandas.rst: -------------------------------------------------------------------------------- 1 | For Pandas 2 | ========== 3 | 4 | Gokart has several features for Pandas. 5 | 6 | 7 | Pandas Type Config 8 | ------------------ 9 | 10 | Pandas has a feature that converts the type of column(s) automatically. This feature sometimes causes wrong results. To avoid unintentional type conversion by pandas, we can specify column names whose types are checked on Task input and output in gokart. 11 | 12 | 13 | .. code:: python 14 | 15 | from typing import Any, Dict 16 | import pandas as pd 17 | import gokart 18 | 19 | 20 | # Please define a class which inherits `gokart.PandasTypeConfig`. 21 | class SamplePandasTypeConfig(gokart.PandasTypeConfig): 22 | 23 | @classmethod 24 | def type_dict(cls) -> Dict[str, Any]: 25 | return {'int_column': int} 26 | 27 | 28 | class SampleTask(gokart.TaskOnKart[pd.DataFrame]): 29 | 30 | def run(self): 31 | # [PandasTypeError] because expected type is `int`, but `str` is passed. 32 | df = pd.DataFrame(dict(int_column=['a'])) 33 | self.dump(df) 34 | 35 | This is useful when a DataFrame has nullable columns, because pandas auto-conversion often fails in such cases. 36 | 37 | Easy to Load DataFrame 38 | ---------------------- 39 | 40 | The :func:`~gokart.task.TaskOnKart.load` method is used to load input ``pandas.DataFrame``. 41 | 42 | .. code:: python 43 | 44 | def requires(self): 45 | return MakeDataFrameTask() 46 | 47 | def run(self): 48 | df = self.load() 49 | 50 | Please refer to :func:`~gokart.task.TaskOnKart.load`.
51 | 52 | 53 | Fail on empty DataFrame 54 | ----------------------- 55 | 56 | When the :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump` parameter is true, the :func:`~gokart.task.TaskOnKart.dump()` method raises :class:`~gokart.errors.EmptyDumpError` on trying to dump empty ``pandas.DataFrame``. 57 | 58 | 59 | .. code:: python 60 | 61 | import gokart 62 | 63 | 64 | class EmptyTask(gokart.TaskOnKart): 65 | def run(self): 66 | df = pd.DataFrame() 67 | self.dump(df) 68 | 69 | 70 | :: 71 | 72 | $ python main.py EmptyTask --fail-on-empty-dump true 73 | # EmptyDumpError 74 | $ python main.py EmptyTask 75 | # Task will be ran and outputs an empty dataframe 76 | 77 | 78 | Empty caches sometimes hide bugs and let us spend much time debugging. This feature notifies us some bugs (including wrong datasources) in the early stage. 79 | 80 | Please refer to :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump`. 81 | -------------------------------------------------------------------------------- /docs/gokart.rst: -------------------------------------------------------------------------------- 1 | gokart package 2 | ============== 3 | 4 | Submodules 5 | ---------- 6 | 7 | gokart.file\_processor module 8 | ----------------------------- 9 | 10 | .. automodule:: gokart.file_processor 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | gokart.info module 16 | ------------------ 17 | 18 | .. automodule:: gokart.info 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | gokart.parameter module 24 | ----------------------- 25 | 26 | .. automodule:: gokart.parameter 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | gokart.run module 32 | ----------------- 33 | 34 | .. automodule:: gokart.run 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | gokart.s3\_config module 40 | ------------------------ 41 | 42 | .. 
automodule:: gokart.s3_config 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | gokart.target module 48 | -------------------- 49 | 50 | .. automodule:: gokart.target 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | gokart.task module 56 | ------------------ 57 | 58 | .. automodule:: gokart.task 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | gokart.workspace\_management module 64 | ----------------------------------- 65 | 66 | .. automodule:: gokart.workspace_management 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | gokart.zip\_client module 72 | ------------------------- 73 | 74 | .. automodule:: gokart.zip_client 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | 80 | Module contents 81 | --------------- 82 | 83 | .. automodule:: gokart 84 | :members: 85 | :undoc-members: 86 | :show-inheritance: 87 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. gokart documentation master file, created by 2 | sphinx-quickstart on Fri Jan 11 07:59:25 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to gokart's documentation! 7 | ================================== 8 | 9 | Useful links: `GitHub `_ | `cookiecutter gokart `_ 10 | 11 | `Gokart `_ is a wrapper of the data pipeline library `luigi `_. Gokart solves "**reproducibility**", "**task dependencies**", "**constraints of good code**", and "**ease of use**" for Machine Learning Pipeline. 12 | 13 | 14 | Good thing about gokart 15 | ----------------------- 16 | 17 | Here are some good things about gokart. 
18 | 19 | - The following data for each Task is stored separately in a pkl file with hash value 20 | - task output data 21 | - versions of all imported modules 22 | - task processing time 23 | - random seed in task 24 | - displayed log 25 | - all parameters set as class variables in the task 26 | - If a Task's parameters change, the task is rerun automatically. 27 | - The above file will be generated with a different hash value 28 | - The hash value of dependent tasks will also change and both will be rerun 29 | - Support GCS and S3 30 | - The above output is exchanged between tasks as an intermediate file, which is memory-friendly 31 | - pandas.DataFrame type and column checking during I/O 32 | - Directory structure of saved files is automatically determined from structure of script 33 | - Seeds for numpy and random are automatically fixed 34 | - Can code while adhering to SOLID principles as much as possible 35 | - Tasks are locked via redis even if they run in parallel 36 | 37 | **These functions are all designed for creating Machine Learning batches, providing an excellent environment for reproducibility and team development.** 38 | 39 | 40 | 41 | Getting started 42 | ----------------- 43 | 44 | .. toctree:: 45 | :maxdepth: 2 46 | 47 | intro_to_gokart 48 | tutorial 49 | 50 | User Guide 51 | ----------------- 52 | 53 | .. toctree:: 54 | :maxdepth: 2 55 | 56 | task_on_kart 57 | task_parameters 58 | setting_task_parameters 59 | task_settings 60 | task_information 61 | logging 62 | slack_notification 63 | using_task_task_conflict_prevention_lock 64 | efficient_run_on_multi_workers 65 | for_pandas 66 | mypy_plugin 67 | 68 | API References 69 | -------------- 70 | ..
toctree:: 71 | :maxdepth: 2 72 | 73 | gokart 74 | 75 | 76 | Indices and tables 77 | ------------------- 78 | 79 | * :ref:`genindex` 80 | * :ref:`modindex` 81 | * :ref:`search` 82 | -------------------------------------------------------------------------------- /docs/intro_to_gokart.rst: -------------------------------------------------------------------------------- 1 | Intro To Gokart 2 | =============== 3 | 4 | 5 | Installation 6 | ------------ 7 | 8 | Within the activated Python environment, use the following command to install gokart. 9 | 10 | .. code:: sh 11 | 12 | pip install gokart 13 | 14 | 15 | 16 | Quickstart 17 | ---------- 18 | 19 | A minimal gokart tasks looks something like this: 20 | 21 | 22 | .. code:: python 23 | 24 | import gokart 25 | 26 | class Example(gokart.TaskOnKart[str]): 27 | def run(self): 28 | self.dump('Hello, world!') 29 | 30 | task = Example() 31 | output = gokart.build(task) 32 | print(output) 33 | 34 | 35 | ``gokart.build`` return the result of dump by ``gokart.TaskOnKart``. The example will output the following. 36 | 37 | 38 | .. code:: sh 39 | 40 | Hello, world! 41 | 42 | 43 | ``gokart`` records all the information needed for Machine Learning. By default, ``resources`` will be generated in the same directory as the script. 44 | 45 | .. code:: sh 46 | 47 | $ tree resources/ 48 | resources/ 49 | ├── __main__ 50 | │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl 51 | └── log 52 | ├── module_versions 53 | │   └── Example_8441c59b5ce0113396d53509f19371fb.txt 54 | ├── processing_time 55 | │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl 56 | ├── random_seed 57 | │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl 58 | ├── task_log 59 | │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl 60 | └── task_params 61 | └── Example_8441c59b5ce0113396d53509f19371fb.pkl 62 | 63 | 64 | The result of dumping the task will be saved in the ``__name__`` directory. 65 | 66 | 67 | .. 
code:: python 68 | 69 | import pickle 70 | 71 | with open('resources/__main__/Example_8441c59b5ce0113396d53509f19371fb.pkl', 'rb') as f: 72 | print(pickle.load(f)) # Hello, world! 73 | 74 | 75 | The file name contains a hash value that depends on the parameters of the task. This means that if you change a parameter of the task, the hash value changes, and so does the output file. This is very useful when changing parameters and experimenting. Please refer to the :doc:`task_parameters` section for task parameters. Also see the :doc:`task_on_kart` section for information on how to return this output destination. 76 | 77 | 78 | In addition, the following files are automatically saved as ``log``. 79 | 80 | - ``module_versions``: The versions of all modules that were imported when the script was executed. For reproducibility. 81 | - ``processing_time``: The execution time of the task. 82 | - ``random_seed``: The random seeds of Python and numpy. For reproducibility in Machine Learning. Please refer to the :doc:`task_settings` section. 83 | - ``task_log``: The output of the task logger. 84 | - ``task_params``: The task's parameters. Please refer to the :doc:`task_parameters` section. 85 | 86 | 87 | How to run tasks 88 | ------------------- 89 | 90 | Gokart has ``run`` and ``build`` methods for running tasks. Each has a different purpose. 91 | 92 | - ``gokart.run``: takes its arguments from the shell (command line) and returns a retcode. 93 | - ``gokart.build``: called inline from code in Jupyter notebooks, IPython, and more; returns the task output. 94 | 95 | 96 | .. note:: 97 | 98 | It is not recommended to use ``gokart.run`` and ``gokart.build`` together in the same script, because ``gokart.build`` resets the task class registry (``luigi.task_register.Register``). This is the only way to handle duplicate tasks. 99 | 100 | 101 | gokart.run 102 | ~~~~~~~~~~ 103 | 104 | :func:`~gokart.run` is meant to be run from the shell. 105 | 106 | ..
code:: python 107 | 108 | import gokart 109 | import luigi 110 | 111 | class SampleTask(gokart.TaskOnKart[str]): 112 | param = luigi.Parameter() 113 | 114 | def run(self): 115 | self.dump(self.param) 116 | 117 | gokart.run() 118 | 119 | 120 | .. code:: sh 121 | 122 | python sample.py SampleTask --local-scheduler --param=hello 123 | 124 | 125 | Written in Python, this is equivalent to the following. 126 | 127 | 128 | .. code:: python 129 | 130 | gokart.run(['SampleTask', '--local-scheduler', '--param=hello']) 131 | 132 | 133 | gokart.build 134 | ~~~~~~~~~~~~ 135 | 136 | :func:`~gokart.build` is called inline from Python code. 137 | 138 | .. code:: python 139 | 140 | import gokart 141 | import luigi 142 | 143 | class SampleTask(gokart.TaskOnKart[str]): 144 | param = luigi.Parameter() 145 | 146 | def run(self): 147 | self.dump(self.param) 148 | 149 | gokart.build(SampleTask(param='hello'), return_value=False) 150 | 151 | 152 | To output the logs of each task, pass the ``log_level`` parameter to :func:`~gokart.build` as follows: 153 | 154 | .. code:: python 155 | 156 | gokart.build(SampleTask(param='hello'), return_value=False, log_level=logging.DEBUG) 157 | 158 | 159 | This feature is very useful for running ``gokart`` in Jupyter notebooks. 160 | When some tasks fail, ``gokart.build`` raises ``GokartBuildError``. If you need the tracebacks, set ``log_level`` to ``logging.DEBUG``. 161 | -------------------------------------------------------------------------------- /docs/logging.rst: -------------------------------------------------------------------------------- 1 | Logging 2 | ======= 3 | 4 | How to set up a common logger for gokart. 5 | 6 | 7 | Core settings 8 | ------------- 9 | 10 | Please write a configuration file similar to the following: 11 | 12 | :: 13 | 14 | # base.ini 15 | [core] 16 | logging_conf_file=./conf/logging.ini 17 | 18 | ..
code:: python 19 | 20 | import gokart 21 | gokart.add_config('base.ini') 22 | 23 | 24 | Logger ini file 25 | --------------- 26 | 27 | It is the same as a general logging.ini file. 28 | 29 | :: 30 | 31 | [loggers] 32 | keys=root,luigi,luigi-interface,gokart,gokart.file_processor 33 | 34 | [handlers] 35 | keys=stderrHandler 36 | 37 | [formatters] 38 | keys=simpleFormatter 39 | 40 | [logger_root] 41 | level=INFO 42 | handlers=stderrHandler 43 | 44 | [logger_gokart] 45 | level=INFO 46 | handlers=stderrHandler 47 | qualname=gokart 48 | propagate=0 49 | 50 | [logger_luigi] 51 | level=INFO 52 | handlers=stderrHandler 53 | qualname=luigi 54 | propagate=0 55 | 56 | [logger_luigi-interface] 57 | level=INFO 58 | handlers=stderrHandler 59 | qualname=luigi-interface 60 | propagate=0 61 | 62 | [logger_gokart.file_processor] 63 | level=CRITICAL 64 | handlers=stderrHandler 65 | qualname=gokart.file_processor 66 | 67 | [handler_stderrHandler] 68 | class=StreamHandler 69 | formatter=simpleFormatter 70 | args=(sys.stdout,) 71 | 72 | [formatter_simpleFormatter] 73 | format=[%(asctime)s][%(name)s][%(levelname)s](%(filename)s:%(lineno)s) %(message)s 74 | datefmt=%Y/%m/%d %H:%M:%S 75 | 76 | Please refer to `Python logging documentation `_ 77 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/mypy_plugin.rst: -------------------------------------------------------------------------------- 1 | [Experimental] Mypy plugin 2 | =========================== 3 | 4 | Mypy plugin provides type checking for gokart tasks using Mypy. 5 | This feature is experimental. 6 | 7 | How to use 8 | -------------- 9 | 10 | Configure Mypy to use this plugin by adding the following to your ``mypy.ini`` file: 11 | 12 | .. code:: ini 13 | 14 | [mypy] 15 | plugins = gokart.mypy:plugin 16 | 17 | or by adding the following to your ``pyproject.toml`` file: 18 | 19 | .. code:: toml 20 | 21 | [tool.mypy] 22 | plugins = ["gokart.mypy"] 23 | 24 | Then, run Mypy as usual. 25 | 26 | Examples 27 | -------- 28 | 29 | For example the following code linted by Mypy: 30 | 31 | .. code:: python 32 | 33 | import gokart 34 | import luigi 35 | 36 | 37 | class Foo(gokart.TaskOnKart): 38 | # NOTE: must all the parameters be annotated 39 | foo: int = luigi.IntParameter(default=1) 40 | bar: str = luigi.Parameter() 41 | 42 | 43 | 44 | Foo(foo=1, bar='2') # OK 45 | Foo(foo='1') # NG because foo is not int and bar is missing 46 | 47 | 48 | Mypy plugin checks TaskOnKart generic types. 49 | 50 | .. 
code:: python 51 | 52 | class SampleTask(gokart.TaskOnKart): 53 | str_task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter() 54 | int_task: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() 55 | 56 | def requires(self): 57 | return dict(str=self.str_task, int=self.int_task) 58 | 59 | def run(self): 60 | s = self.load(self.str_task) # This type is inferred with "str" 61 | i = self.load(self.int_task) # This type is inferred with "int" 62 | 63 | SampleTask( 64 | str_task=StrTask(), # mypy ok 65 | int_task=StrTask(), # mypy error: Argument "int_task" to "SampleTask" has incompatible type "StrTask"; expected "TaskOnKart[int]" 66 | ) 67 | 68 | Configurations (only pyproject.toml) 69 | ------------------------------------ 70 | 71 | You can configure the Mypy plugin using the ``pyproject.toml`` file. 72 | The following options are available: 73 | 74 | .. code:: toml 75 | 76 | [tool.gokart-mypy] 77 | # If true, Mypy will raise an error if a task is missing required parameters. 78 | # Note: parameters supplied only via `luigi.Config()` (e.g. config files) are also reported as missing. 79 | # Default: false 80 | disallow_missing_parameters = true 81 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx 2 | gokart 3 | sphinx-rtd-theme 4 | -------------------------------------------------------------------------------- /docs/setting_task_parameters.rst: -------------------------------------------------------------------------------- 1 | ============================ 2 | Setting Task Parameters 3 | ============================ 4 | 5 | There are several ways to set task parameters. 6 | 7 | - Set parameter from command line 8 | - Set parameter at config file 9 | - Set parameter at upstream task 10 | - Inherit parameter from other task 11 | 12 | 13 | Set parameter from command line 14 | ================================== 15 | ..
code:: sh 16 | 17 | python main.py sample.SomeTask --SomeTask-param=Hello 18 | 19 | Parameters of each task can be set as command line parameters in the ``--[task name]-[parameter name]=[value]`` format. 20 | 21 | 22 | Set parameter at config file 23 | ================================== 24 | :: 25 | 26 | [sample.SomeTask] 27 | param = Hello 28 | 29 | The above config file (``config.ini``) must be loaded before ``gokart.run()``, as in the following code: 30 | 31 | .. code:: python 32 | 33 | if __name__ == '__main__': 34 | gokart.add_config('./conf/config.ini') 35 | gokart.run() 36 | 37 | 38 | Values can also be loaded from environment variables, as follows: 39 | 40 | :: 41 | 42 | [sample.SomeTask] 43 | param=${PARAMS} 44 | 45 | [TaskOnKart] 46 | workspace_directory=${WORKSPACE_DIRECTORY} 47 | 48 | The advantages of using environment variables are that 1) important information will not be logged and 2) common settings can be shared. 49 | 50 | 51 | Set parameter at upstream task 52 | ================================== 53 | 54 | Parameters can be set at the upstream task, as in a typical pipeline. 55 | 56 | .. code:: python 57 | 58 | class UpstreamTask(gokart.TaskOnKart): 59 | def requires(self): 60 | return dict(sometask=SomeTask(param='Hello')) 61 | 62 | 63 | Inherit parameter from other task 64 | ================================== 65 | 66 | Parameter values can be inherited from another task or config class using the ``@inherits_config_params`` decorator. 67 | 68 | .. code:: python 69 | 70 | class MasterConfig(luigi.Config): 71 | param: str = luigi.Parameter() 72 | param2: str = luigi.Parameter() 73 | 74 | @inherits_config_params(MasterConfig) 75 | class SomeTask(gokart.TaskOnKart): 76 | param: str = luigi.Parameter() 77 | 78 | 79 | This is useful when multiple tasks have the same parameter. In the above example, the parameter settings of ``MasterConfig`` will be inherited by all tasks decorated with ``@inherits_config_params(MasterConfig)``, such as ``SomeTask``.
80 | 81 | Note that only parameters which exist in both ``MasterConfig`` and ``SomeTask`` will be inherited. 82 | In the above example, ``param2`` will not be available in ``SomeTask``, since ``SomeTask`` does not have a ``param2`` parameter. 83 | 84 | .. code:: python 85 | 86 | class MasterConfig(luigi.Config): 87 | param: str = luigi.Parameter() 88 | param2: str = luigi.Parameter() 89 | 90 | @inherits_config_params(MasterConfig, parameter_alias={'param2': 'param3'}) 91 | class SomeTask(gokart.TaskOnKart): 92 | param3: str = luigi.Parameter() 93 | 94 | 95 | You may also set a parameter name alias by setting ``parameter_alias``. 96 | ``parameter_alias`` must be a dictionary whose keys are parameter names of the inherited config class (``MasterConfig`` above) and whose values are the corresponding parameter names of the decorated task (``SomeTask`` above). 97 | 98 | In the above example, ``SomeTask.param3`` will be set to the same value as ``MasterConfig.param2``. 99 | -------------------------------------------------------------------------------- /docs/slack_notification.rst: -------------------------------------------------------------------------------- 1 | Slack notification 2 | ========================= 3 | 4 | Prerequisites 5 | ------------- 6 | 7 | Prepare the following environment variables: 8 | 9 | .. code:: sh 10 | 11 | export SLACK_TOKEN=xoxb-your-token # should be a token that starts with "xoxb-" (a bot token is preferable) 12 | export SLACK_CHANNEL=channel-name # not "#channel-name", just "channel-name" 13 | 14 | 15 | A Slack bot token can be obtained by following the `slack app document `_. 16 | 17 | A bot token needs the following scopes: 18 | 19 | - `channels:read` 20 | - `chat:write` 21 | - `files:write` 22 | 23 | For more about scopes, see the `slack scopes document `_. 24 | 25 | Implement Slack notification 26 | ---------------------------- 27 | 28 | Write the following code to pass arguments to your gokart workflow. 29 | 30 | ..
code:: python 31 | 32 | cmdline_args = sys.argv[1:] 33 | if 'SLACK_CHANNEL' in os.environ: 34 | cmdline_args.append(f'--SlackConfig-channel={os.environ["SLACK_CHANNEL"]}') 35 | if 'SLACK_TO_USER' in os.environ: 36 | cmdline_args.append(f'--SlackConfig-to-user={os.environ["SLACK_TO_USER"]}') 37 | gokart.run(cmdline_args) 38 | 39 | -------------------------------------------------------------------------------- /docs/task_parameters.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Task Parameters 3 | ================= 4 | 5 | Luigi Parameter 6 | ================ 7 | 8 | We can set parameters for tasks. 9 | Also please refer to :doc:`task_settings` section. 10 | 11 | .. code:: python 12 | 13 | class Task(gokart.TaskOnKart): 14 | param_a = luigi.Parameter() 15 | param_c = luigi.ListParameter() 16 | param_d = luigi.IntParameter(default=1) 17 | 18 | Please refer to `luigi document `_ for a list of parameter types. 19 | 20 | 21 | Gokart Parameter 22 | ================ 23 | 24 | There are also parameters provided by gokart. 25 | 26 | - gokart.TaskInstanceParameter 27 | - gokart.ListTaskInstanceParameter 28 | - gokart.ExplicitBoolParameter 29 | 30 | 31 | gokart.TaskInstanceParameter 32 | -------------------------------- 33 | 34 | The :func:`~gokart.parameter.TaskInstanceParameter` executes a task using the results of a task as dynamic parameters. 35 | 36 | 37 | .. code:: python 38 | 39 | class TaskA(gokart.TaskOnKart[str]): 40 | def run(self): 41 | self.dump('Hello') 42 | 43 | 44 | class TaskB(gokart.TaskOnKart[str]): 45 | require_task = gokart.TaskInstanceParameter() 46 | 47 | def requires(self): 48 | return self.require_task 49 | 50 | def run(self): 51 | task_a = self.load() 52 | self.dump(','.join([task_a, 'world'])) 53 | 54 | task = TaskB(require_task=TaskA()) 55 | print(gokart.build(task)) # Hello,world 56 | 57 | 58 | Helps to create a pipeline. 
59 | 60 | 61 | gokart.ListTaskInstanceParameter 62 | ------------------------------------- 63 | 64 | The :func:`~gokart.parameter.ListTaskInstanceParameter` is list of TaskInstanceParameter. 65 | 66 | 67 | gokart.ExplicitBoolParameter 68 | ----------------------------------- 69 | 70 | The :func:`~gokart.parameter.ExplicitBoolParameter` is parameter for explicitly specified value. 71 | 72 | ``luigi.BoolParameter`` already has "explicit parsing" feature, but also still has implicit behavior like follows. 73 | 74 | :: 75 | 76 | $ python main.py Task --param 77 | # param will be set as True 78 | $ python main.py Task 79 | # param will be set as False 80 | 81 | ``ExplicitBoolParameter`` solves these problems on parameters from command line. 82 | 83 | 84 | gokart.SerializableParameter 85 | ---------------------------- 86 | 87 | The :func:`~gokart.parameter.SerializableParameter` is a parameter for any object that can be serialized and deserialized. 88 | This parameter is particularly useful when you want to pass a complex object or a set of parameters to a task. 89 | 90 | The object must implement the following methods: 91 | 92 | - ``gokart_serialize``: Serialize the object to a string. This serialized string must uniquely identify the object to enable task caching. 93 | Note that it is not required for deserialization. 94 | - ``gokart_deserialize``: Deserialize the object from a string, typically used for CLI arguments. 95 | 96 | Example 97 | ^^^^^^^ 98 | 99 | .. code-block:: python 100 | 101 | import json 102 | from dataclasses import dataclass 103 | 104 | import gokart 105 | 106 | @dataclass(frozen=True) 107 | class Config: 108 | foo: int 109 | # The `bar` field does not affect the result of the task. 110 | # Similar to `luigi.Parameter(significant=False)`. 111 | bar: str 112 | 113 | def gokart_serialize(self) -> str: 114 | # Serialize only the `foo` field since `bar` is irrelevant for caching. 
115 | return json.dumps({'foo': self.foo}) 116 | 117 | @classmethod 118 | def gokart_deserialize(cls, s: str) -> 'Config': 119 | # Deserialize the object from the provided string. 120 | return cls(**json.loads(s)) 121 | 122 | class DummyTask(gokart.TaskOnKart): 123 | config: Config = gokart.SerializableParameter(object_type=Config) 124 | 125 | def run(self): 126 | # Save the `config` object as part of the task result. 127 | self.dump(self.config) 128 | -------------------------------------------------------------------------------- /docs/task_settings.rst: -------------------------------------------------------------------------------- 1 | Task Settings 2 | ============= 3 | 4 | Task settings. Also please refer to :doc:`task_parameters` section. 5 | 6 | 7 | Directory to Save Outputs 8 | ------------------------- 9 | 10 | We can use both a local directory and the S3 to save outputs. 11 | If you would like to use local directory, please set a local directory path to :attr:`~gokart.task.TaskOnKart.workspace_directory`. Please refer to :doc:`task_parameters` for how to set it up. 12 | 13 | It is recommended to use the config file since it does not change much. 14 | 15 | :: 16 | 17 | # base.ini 18 | [TaskOnKart] 19 | workspace_directory=${TASK_WORKSPACE_DIRECTORY} 20 | 21 | .. code:: python 22 | 23 | # main.py 24 | import gokart 25 | gokart.add_config('base.ini') 26 | 27 | 28 | To use the S3 or GCS repository, please set the bucket path as ``s3://{YOUR_REPOSITORY_NAME}`` or ``gs://{YOUR_REPOSITORY_NAME}`` respectively. 29 | 30 | If use S3 or GCS, please set credential information to Environment Variables. 31 | 32 | .. code:: sh 33 | 34 | # S3 35 | export AWS_ACCESS_KEY_ID='~~~' # AWS access key 36 | export AWS_SECRET_ACCESS_KEY='~~~' # AWS secret access key 37 | 38 | # GCS 39 | export GCS_CREDENTIAL='~~~' # GCS credential 40 | export DISCOVER_CACHE_LOCAL_PATH='~~~' # The local file path of discover api cache. 
41 | 42 | 43 | Rerun task 44 | ---------- 45 | 46 | There are times when we want to rerun a task, such as when the script has changed or when running in a batch. Please use the ``rerun`` parameter or add an arbitrary parameter. 47 | 48 | 49 | Set ``rerun`` directly as follows: 50 | 51 | .. code:: python 52 | 53 | # rerun Task 54 | gokart.build(Task(rerun=True)) 55 | 56 | 57 | Or pass it as a command-line argument as follows: 58 | 59 | .. code:: python 60 | 61 | # main.py 62 | class Task(gokart.TaskOnKart[str]): 63 | def run(self): 64 | self.dump('hello') 65 | 66 | .. code:: sh 67 | 68 | python main.py Task --local-scheduler --rerun 69 | 70 | 71 | The ``rerun`` parameter will look at the dependent tasks up to one level. 72 | 73 | Example: Suppose we have a straight line pipeline composed of TaskA, TaskB and TaskC, and TaskC is an endpoint of this pipeline. We also suppose that all the tasks have already been executed. 74 | 75 | - TaskA(rerun=True) -> TaskB -> TaskC # not rerunning 76 | - TaskA -> TaskB(rerun=True) -> TaskC # rerunning TaskB and TaskC 77 | 78 | This is due to the way intermediate files are handled. Because the ``rerun`` parameter is ``significant=False``, it does not affect the hash value. It is very important to understand this difference. 79 | 80 | 81 | If you want to change the parameter of TaskA and rerun TaskB and TaskC, we recommend adding an arbitrary parameter. 82 | 83 | .. code:: python 84 | 85 | class TaskA(gokart.TaskOnKart): 86 | __version = luigi.IntParameter(default=1) 87 | 88 | When the hash value of TaskA changes, the dependent tasks (in this case, TaskB and TaskC) will be rerun. 89 | 90 | 91 | Fix random seed 92 | --------------- 93 | 94 | Every task has parameters named :attr:`~gokart.task.TaskOnKart.fix_random_seed_methods` and :attr:`~gokart.task.TaskOnKart.fix_random_seed_value`. These can be used to fix the random seeds. 95 | 96 | 97 | ..
code:: python 98 | 99 | import gokart 100 | import random 101 | import numpy as np 102 | import torch 103 | 104 | class Task(gokart.TaskOnKart[dict[str, Any]]): 105 | def run(self): 106 | x = [random.randint(0, 100) for _ in range(0, 10)] 107 | y = [np.random.randint(0, 100) for _ in range(0, 10)] 108 | z = [torch.randn(1).tolist()[0] for _ in range(0, 5)] 109 | self.dump({'random': x, 'numpy': y, 'torch': z}) 110 | 111 | gokart.build( 112 | Task( 113 | fix_random_seed_methods=[ 114 | "random.seed", 115 | "numpy.random.seed", 116 | "torch.random.manual_seed"], 117 | fix_random_seed_value=57)) 118 | 119 | :: 120 | 121 | # //--- The output is as follows every time. --- 122 | # {'random': [65, 41, 61, 37, 55, 81, 48, 2, 94, 21], 123 | # 'numpy': [79, 86, 5, 22, 79, 98, 56, 40, 81, 37], 124 | # 'torch': [0.14460121095180511, -0.11649507284164429, 125 | # 0.6928958296775818, -0.916053831577301, 0.7317505478858948]} 126 | 127 | This will be useful when using Machine Learning libraries. 128 | -------------------------------------------------------------------------------- /docs/using_task_task_conflict_prevention_lock.rst: -------------------------------------------------------------------------------- 1 | Task conflict prevention lock 2 | ============================= 3 | 4 | If there is a possibility of multiple worker nodes executing the same task, task cache conflicts may happen. 5 | Specifically, while node A is loading the cache of a task, node B may be writing to it. 6 | This can lead to reading inappropriate data and other unwanted behaviors. 7 | 8 | The redis lock introduced on this page is a feature to prevent such cache collisions. 9 | 10 | Requires 11 | -------- 12 | 13 | You need to install `redis `_ to use this advanced feature. 14 | 15 | 16 | How to use 17 | ----------- 18 | 19 | 20 | 1. Set up a redis server somewhere accessible from gokart/luigi jobs. 21 | 22 | e.g. The following command will run redis on your localhost. 23 | 24 | ..
code:: bash 25 | 26 | $ redis-server 27 | 28 | 2. Set redis server hostname and port number as parameters of gokart.TaskOnKart(). 29 | 30 | You can set it by adding ``--redis-host=[your-redis-localhost] --redis-port=[redis-port-number]`` options to gokart python script. 31 | 32 | e.g. 33 | 34 | .. code:: bash 35 | 36 | python main.py sample.SomeTask --local-scheduler --redis-host=localhost --redis-port=6379 37 | 38 | 39 | Alternatively, you may set parameters at config file. 40 | 41 | e.g. 42 | 43 | .. code:: 44 | 45 | [TaskOnKart] 46 | redis_host=localhost 47 | redis_port=6379 48 | 49 | 3. Done 50 | 51 | With the above configuration, all tasks that inherits gokart.TaskOnKart will ask the redis server if any other node is not trying to access the same cache file at the same time whenever they access the file with dump or load. 52 | -------------------------------------------------------------------------------- /examples/gokart_notebook_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "gokart 1.0.2\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "!pip list | grep gokart" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 3, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import gokart\n", 27 | "import luigi" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "# Examples of using gokart at jupyter notebook" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Basic Usage\n", 42 | "This is a very basic usage, just to dump a run result of ExampleTaskA." 
43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "DONE example_2\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "class ExampleTaskA(gokart.TaskOnKart):\n", 60 | " param = luigi.Parameter()\n", 61 | " int_param = luigi.IntParameter(default=2)\n", 62 | "\n", 63 | " def run(self):\n", 64 | " self.dump(f'DONE {self.param}_{self.int_param}')\n", 65 | "\n", 66 | " \n", 67 | "task_a = ExampleTaskA(param='example')\n", 68 | "output = gokart.build(task=task_a)\n", 69 | "print(output)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "## Make task dependencies with `requires()`\n", 77 | "ExampleTaskB is dependent on ExampleTaskC and ExampleTaskD. They are defined in `ExampleTaskB.requires()`." 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "DONE example_TASKC_TASKD\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "class ExampleTaskC(gokart.TaskOnKart):\n", 95 | " def run(self):\n", 96 | " self.dump('TASKC')\n", 97 | " \n", 98 | "class ExampleTaskD(gokart.TaskOnKart):\n", 99 | " def run(self):\n", 100 | " self.dump('TASKD')\n", 101 | "\n", 102 | "class ExampleTaskB(gokart.TaskOnKart):\n", 103 | " param = luigi.Parameter()\n", 104 | "\n", 105 | " def requires(self):\n", 106 | " return dict(task_c=ExampleTaskC(), task_d=ExampleTaskD())\n", 107 | "\n", 108 | " def run(self):\n", 109 | " task_c = self.load('task_c')\n", 110 | " task_d = self.load('task_d')\n", 111 | " self.dump(f'DONE {self.param}_{task_c}_{task_d}')\n", 112 | " \n", 113 | "task_b = ExampleTaskB(param='example')\n", 114 | "output = gokart.build(task=task_b)\n", 115 | "print(output)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "##
Make task dependencies with TaskInstanceParameter\n", 123 | "The dependencies are the same as in the previous example; however, they are defined outside of the task instead of in `ExampleTaskB.requires()`." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "DONE example_TASKC_TASKD\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "class ExampleTaskC(gokart.TaskOnKart):\n", 141 | " def run(self):\n", 142 | " self.dump('TASKC')\n", 143 | " \n", 144 | "class ExampleTaskD(gokart.TaskOnKart):\n", 145 | " def run(self):\n", 146 | " self.dump('TASKD')\n", 147 | "\n", 148 | "class ExampleTaskB(gokart.TaskOnKart):\n", 149 | " param = luigi.Parameter()\n", 150 | " task_1 = gokart.TaskInstanceParameter()\n", 151 | " task_2 = gokart.TaskInstanceParameter()\n", 152 | "\n", 153 | " def requires(self):\n", 154 | " return dict(task_1=self.task_1, task_2=self.task_2) # required tasks are decided from the task parameters `task_1` and `task_2`\n", 155 | "\n", 156 | " def run(self):\n", 157 | " task_1 = self.load('task_1')\n", 158 | " task_2 = self.load('task_2')\n", 159 | " self.dump(f'DONE {self.param}_{task_1}_{task_2}')\n", 160 | " \n", 161 | "task_b = ExampleTaskB(param='example', task_1=ExampleTaskC(), task_2=ExampleTaskD()) # Dependent tasks are defined here\n", 162 | "output = gokart.build(task=task_b)\n", 163 | "print(output)" 164 | ] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "Python 3.8.8 64-bit ('3.8.8': pyenv)", 170 | "name": "python388jvsc74a57bd026997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer":
"ipython3", 182 | "version": "3.8.8" 183 | }, 184 | "metadata": { 185 | "interpreter": { 186 | "hash": "26997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca" 187 | } 188 | }, 189 | "orig_nbformat": 3 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 2 193 | } -------------------------------------------------------------------------------- /examples/logging.ini: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root,luigi,luigi-interface,gokart 3 | 4 | [handlers] 5 | keys=stderrHandler 6 | 7 | [formatters] 8 | keys=simpleFormatter 9 | 10 | [logger_root] 11 | level=INFO 12 | handlers=stderrHandler 13 | 14 | [logger_gokart] 15 | level=INFO 16 | handlers=stderrHandler 17 | qualname=gokart 18 | propagate=0 19 | 20 | [logger_luigi] 21 | level=INFO 22 | handlers=stderrHandler 23 | qualname=luigi 24 | propagate=0 25 | 26 | [logger_luigi-interface] 27 | level=INFO 28 | handlers=stderrHandler 29 | qualname=luigi-interface 30 | propagate=0 31 | 32 | [handler_stderrHandler] 33 | class=StreamHandler 34 | formatter=simpleFormatter 35 | args=(sys.stdout,) 36 | 37 | [formatter_simpleFormatter] 38 | format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s 39 | datefmt=%Y/%m/%d %H:%M:%S 40 | class=logging.Formatter 41 | -------------------------------------------------------------------------------- /examples/param.ini: -------------------------------------------------------------------------------- 1 | [TaskOnKart] 2 | workspace_directory=./resource 3 | local_temporary_directory=./resource/tmp 4 | 5 | [core] 6 | logging_conf_file=logging.ini 7 | 8 | -------------------------------------------------------------------------------- /gokart/__init__.py: -------------------------------------------------------------------------------- 1 | from gokart.build import WorkerSchedulerFactory, build # noqa:F401 2 | from gokart.info import make_tree_info, tree_info # 
noqa:F401 3 | from gokart.pandas_type_config import PandasTypeConfig # noqa:F401 4 | from gokart.parameter import ( # noqa:F401 5 | ExplicitBoolParameter, 6 | ListTaskInstanceParameter, 7 | SerializableParameter, 8 | TaskInstanceParameter, 9 | ZonedDateSecondParameter, 10 | ) 11 | from gokart.run import run # noqa:F401 12 | from gokart.task import TaskOnKart # noqa:F401 13 | from gokart.testing import test_run # noqa:F401 14 | from gokart.tree.task_info import make_task_info_as_tree_str # noqa:F401 15 | from gokart.utils import add_config # noqa:F401 16 | from gokart.workspace_management import delete_local_unnecessary_outputs # noqa:F401 17 | -------------------------------------------------------------------------------- /gokart/build.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import enum 4 | import logging 5 | import sys 6 | from dataclasses import dataclass 7 | from functools import partial 8 | from logging import getLogger 9 | from typing import Literal, Protocol, TypeVar, cast, overload 10 | 11 | import backoff 12 | import luigi 13 | from luigi import LuigiStatusCode, rpc, scheduler 14 | 15 | import gokart 16 | import gokart.tree.task_info 17 | from gokart import worker 18 | from gokart.conflict_prevention_lock.task_lock import TaskLockException 19 | from gokart.target import TargetOnKart 20 | from gokart.task import TaskOnKart 21 | 22 | T = TypeVar('T') 23 | 24 | logger: logging.Logger = logging.getLogger(__name__) 25 | 26 | 27 | class LoggerConfig: # Context manager: on enter, globally disables logging below `level` and sets this module's logger to `level`; on exit, resets both from the logger level saved at construction. 28 | def __init__(self, level: int): 29 | self.logger = getLogger(__name__) 30 | self.default_level = self.logger.level # this logger's own level at construction time 31 | self.level = level 32 | 33 | def __enter__(self): 34 | logging.disable(self.level - 10) # subtract 10 to disable below self.level 35 | self.logger.setLevel(self.level) 36 | return self 37 | 38 | def __exit__(self, exception_type, exception_value, traceback): 39 | logging.disable(self.default_level - 10)
# subtract 10 to disable below self.default_level; NOTE(review): the global logging.disable() threshold in effect before __enter__ is not saved, so it is re-derived from the logger's level rather than restored 40 | self.logger.setLevel(self.default_level) 41 | 42 | 43 | class GokartBuildError(Exception): 44 | """Raised when ``gokart.build`` failed. This exception contains raised exceptions in the task execution.""" 45 | 46 | def __init__(self, message, raised_exceptions: dict[str, list[Exception]]): 47 | super().__init__(message) 48 | self.raised_exceptions = raised_exceptions 49 | 50 | 51 | class HasLockedTaskException(Exception): 52 | """Raised when the task failed to acquire the lock in the task execution.""" 53 | 54 | 55 | class TaskLockExceptionRaisedFlag: # mutable holder for a single bool flag, initially False 56 | def __init__(self): 57 | self.flag: bool = False 58 | 59 | 60 | class WorkerProtocol(Protocol): 61 | """Protocol for Worker. 62 | This protocol is determined by luigi.worker.Worker. 63 | """ 64 | 65 | def add(self, task: TaskOnKart) -> bool: ... 66 | 67 | def run(self) -> bool: ... 68 | 69 | def __enter__(self) -> WorkerProtocol: ... 70 | 71 | def __exit__(self, type, value, traceback) -> Literal[False]: ...
72 | 73 | 74 | class WorkerSchedulerFactory: # creates the luigi scheduler (local or remote) and the gokart worker 75 | def create_local_scheduler(self) -> scheduler.Scheduler: 76 | return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False) 77 | 78 | def create_remote_scheduler(self, url) -> rpc.RemoteScheduler: 79 | return rpc.RemoteScheduler(url) 80 | 81 | def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> WorkerProtocol: # NOTE: the `scheduler` argument shadows the luigi `scheduler` module inside this method 82 | return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant) 83 | 84 | 85 | def _get_output(task: TaskOnKart[T]) -> T: # loads task.output(): a list for list/tuple outputs, a dict for dict outputs, the value itself for a single TargetOnKart; non-TargetOnKart entries are skipped; raises ValueError otherwise 86 | output = task.output() 87 | # FIXME: currently, nested output is not supported 88 | if isinstance(output, list) or isinstance(output, tuple): 89 | return cast(T, [t.load() for t in output if isinstance(t, TargetOnKart)]) 90 | if isinstance(output, dict): 91 | return cast(T, {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)}) 92 | if isinstance(output, TargetOnKart): 93 | return output.load() 94 | raise ValueError(f'output type is not supported: {type(output)}') 95 | 96 | 97 | def _reset_register(keep={'gokart', 'luigi'}): # `keep`: top-level package names whose registered classes survive the reset; the default set is only read, never mutated 98 | """reset luigi.task_register.Register._reg everytime gokart.build called to avoid TaskClassAmbigiousException""" 99 | luigi.task_register.Register._reg = [ 100 | x 101 | for x in luigi.task_register.Register._reg 102 | if ( 103 | (x.__module__.split('.')[0] in keep) # keep luigi and gokart 104 | or (issubclass(x, gokart.PandasTypeConfig)) 105 | ) # PandasTypeConfig should be kept 106 | ] 107 | 108 | 109 | class TaskDumpMode(enum.Enum): # NOTE(review): presumably selects the format for dumping task info (tree/table/none); its consumer is not visible here 110 | TREE = 'tree' 111 | TABLE = 'table' 112 | NONE = 'none' 113 | 114 | 115 | class TaskDumpOutputType(enum.Enum): # NOTE(review): presumably selects where dumped task info goes (print/dump/none); its consumer is not visible here 116 | PRINT = 'print' 117 | DUMP = 'dump' 118 | NONE = 'none' 119 | 120 | 121 | @dataclass 122 | class TaskDumpConfig: # pairs a TaskDumpMode with a TaskDumpOutputType; both default to NONE 123 | mode: TaskDumpMode = TaskDumpMode.NONE 124 | output_type: TaskDumpOutputType = TaskDumpOutputType.NONE 125 | 126 | 127 | if sys.version_info < (3, 10): 128 | 129 | def
process_task_info(task: TaskOnKart, task_dump_config: TaskDumpConfig = TaskDumpConfig()): 130 | logger.warning('process_task_info is not supported in Python 3.9 or lower.') 131 | else: 132 | # FIXME: after dropping python3.9 support, change this import to direct implementation 133 | from .build_process_task_info import process_task_info 134 | 135 | 136 | @overload 137 | def build( 138 | task: TaskOnKart[T], 139 | return_value: Literal[True] = True, 140 | reset_register: bool = True, 141 | log_level: int = logging.ERROR, 142 | task_lock_exception_max_tries: int = 10, 143 | task_lock_exception_max_wait_seconds: int = 600, 144 | **env_params, 145 | ) -> T: ... 146 | 147 | 148 | @overload 149 | def build( 150 | task: TaskOnKart[T], 151 | return_value: Literal[False], 152 | reset_register: bool = True, 153 | log_level: int = logging.ERROR, 154 | task_lock_exception_max_tries: int = 10, 155 | task_lock_exception_max_wait_seconds: int = 600, 156 | **env_params, 157 | ) -> None: ... 158 | 159 | 160 | def build( 161 | task: TaskOnKart[T], 162 | return_value: bool = True, 163 | reset_register: bool = True, 164 | log_level: int = logging.ERROR, 165 | task_lock_exception_max_tries: int = 10, 166 | task_lock_exception_max_wait_seconds: int = 600, 167 | task_dump_config: TaskDumpConfig = TaskDumpConfig(), 168 | **env_params, 169 | ) -> T | None: 170 | """ 171 | Run gokart task for local interpreter. 
172 | Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build) 173 | """ 174 | if reset_register: 175 | _reset_register() 176 | with LoggerConfig(level=log_level): 177 | log_handler_before_run = logging.StreamHandler() 178 | logger.addHandler(log_handler_before_run) 179 | process_task_info(task, task_dump_config) 180 | logger.removeHandler(log_handler_before_run) 181 | log_handler_before_run.close() 182 | 183 | task_lock_exception_raised = TaskLockExceptionRaisedFlag() 184 | raised_exceptions: dict[str, list[Exception]] = dict() 185 | 186 | @TaskOnKart.event_handler(luigi.Event.FAILURE) 187 | def when_failure(task, exception): 188 | if isinstance(exception, TaskLockException): 189 | task_lock_exception_raised.flag = True 190 | else: 191 | raised_exceptions.setdefault(task.make_unique_id(), []).append(exception) 192 | 193 | @backoff.on_exception( 194 | partial(backoff.expo, max_value=task_lock_exception_max_wait_seconds), HasLockedTaskException, max_tries=task_lock_exception_max_tries 195 | ) 196 | def _build_task(): 197 | task_lock_exception_raised.flag = False 198 | result = luigi.build( 199 | [task], 200 | local_scheduler=True, 201 | detailed_summary=True, 202 | log_level=logging.getLevelName(log_level), 203 | **env_params, 204 | ) 205 | if task_lock_exception_raised.flag: 206 | raise HasLockedTaskException() 207 | if result.status in (LuigiStatusCode.FAILED, LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED, LuigiStatusCode.SCHEDULING_FAILED): 208 | raise GokartBuildError(result.summary_text, raised_exceptions=raised_exceptions) 209 | return _get_output(task) if return_value else None 210 | 211 | return _build_task() 212 | -------------------------------------------------------------------------------- /gokart/build_process_task_info.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | 5 | 
import gokart 6 | import gokart.tree.task_info 7 | from gokart.build import TaskDumpConfig, TaskDumpMode, TaskDumpOutputType 8 | from gokart.task import TaskOnKart 9 | 10 | from .build import logger 11 | 12 | 13 | def process_task_info(task: TaskOnKart, task_dump_config: TaskDumpConfig = TaskDumpConfig()): 14 | match task_dump_config: 15 | case TaskDumpConfig(mode=TaskDumpMode.NONE, output_type=TaskDumpOutputType.NONE): 16 | pass 17 | case TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.PRINT): 18 | tree = gokart.make_tree_info(task) 19 | logger.info(tree) 20 | case TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.PRINT): 21 | table = gokart.tree.task_info.make_task_info_as_table(task) 22 | output = io.StringIO() 23 | table.to_csv(output, index=False, sep='\t') 24 | output.seek(0) 25 | logger.info(output.read()) 26 | case TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.DUMP): 27 | tree = gokart.make_tree_info(task) 28 | gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.txt').dump(tree) 29 | case TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.DUMP): 30 | table = gokart.tree.task_info.make_task_info_as_table(task) 31 | gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.pkl').dump(table) 32 | case _: 33 | raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}') 34 | -------------------------------------------------------------------------------- /gokart/config_params.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import luigi 4 | 5 | import gokart 6 | 7 | 8 | class inherits_config_params: 9 | def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str, str] | None = None): 10 | """ 11 | Decorates task to inherit parameter value of `config_class`. 12 | 13 | * config_class: Inherit parameter value of this task to decorated task. 
Only parameter values exist in both tasks are inherited. 14 | * parameter_alias: Dictionary to map paramter names between config_class task and decorated task. 15 | key: config_class's parameter name. value: decorated task's parameter name. 16 | """ 17 | 18 | self._config_class: type[luigi.Config] = config_class 19 | self._parameter_alias: dict[str, str] = parameter_alias if parameter_alias is not None else {} 20 | 21 | def __call__(self, task_class: type[gokart.TaskOnKart]): 22 | # wrap task to prevent task name from being changed 23 | @luigi.task._task_wraps(task_class) 24 | class Wrapped(task_class): # type: ignore 25 | @classmethod 26 | def get_param_values(cls, params, args, kwargs): 27 | for param_key, param_value in self._config_class().param_kwargs.items(): 28 | task_param_key = self._parameter_alias.get(param_key, param_key) 29 | 30 | if hasattr(cls, task_param_key) and task_param_key not in kwargs: 31 | kwargs[task_param_key] = param_value 32 | return super().get_param_values(params, args, kwargs) 33 | 34 | return Wrapped 35 | -------------------------------------------------------------------------------- /gokart/conflict_prevention_lock/task_lock.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | import os 5 | from logging import getLogger 6 | from typing import NamedTuple 7 | 8 | import redis 9 | from apscheduler.schedulers.background import BackgroundScheduler 10 | 11 | logger = getLogger(__name__) 12 | 13 | 14 | class TaskLockParams(NamedTuple): 15 | redis_host: str | None 16 | redis_port: int | None 17 | redis_timeout: int | None 18 | redis_key: str 19 | should_task_lock: bool 20 | raise_task_lock_exception_on_collision: bool 21 | lock_extend_seconds: int 22 | 23 | 24 | class TaskLockException(Exception): 25 | pass 26 | """Raised when the task failed to acquire the lock in the task execution. 
Only used internally.""" 27 | 28 | 29 | class RedisClient: 30 | _instances: dict = {} 31 | 32 | def __new__(cls, *args, **kwargs): 33 | key = (args, tuple(sorted(kwargs.items()))) 34 | if cls not in cls._instances: 35 | cls._instances[cls] = {} 36 | if key not in cls._instances[cls]: 37 | cls._instances[cls][key] = super().__new__(cls) 38 | return cls._instances[cls][key] 39 | 40 | def __init__(self, host: str | None, port: int | None) -> None: 41 | if not hasattr(self, '_redis_client'): 42 | host = host or 'localhost' 43 | port = port or 6379 44 | self._redis_client = redis.Redis(host=host, port=port) 45 | 46 | def get_redis_client(self): 47 | return self._redis_client 48 | 49 | 50 | def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int): 51 | task_lock.extend(additional_time=redis_timeout, replace_ttl=True) 52 | 53 | 54 | def set_task_lock(task_lock_params: TaskLockParams) -> redis.lock.Lock: 55 | redis_client = RedisClient(host=task_lock_params.redis_host, port=task_lock_params.redis_port).get_redis_client() 56 | blocking = not task_lock_params.raise_task_lock_exception_on_collision 57 | task_lock = redis.lock.Lock(redis=redis_client, name=task_lock_params.redis_key, timeout=task_lock_params.redis_timeout, thread_local=False) 58 | if not task_lock.acquire(blocking=blocking): 59 | raise TaskLockException('Lock already taken by other task.') 60 | return task_lock 61 | 62 | 63 | def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockParams) -> BackgroundScheduler: 64 | scheduler = BackgroundScheduler() 65 | extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=task_lock_params.redis_timeout or 0) 66 | scheduler.add_job( 67 | extend_lock, 68 | 'interval', 69 | seconds=task_lock_params.lock_extend_seconds, 70 | max_instances=999999999, 71 | misfire_grace_time=task_lock_params.redis_timeout, 72 | coalesce=False, 73 | ) 74 | scheduler.start() 75 | return scheduler 76 | 77 | 78 | def 
make_task_lock_key(file_path: str, unique_id: str | None): 79 | basename_without_ext = os.path.splitext(os.path.basename(file_path))[0] 80 | return f'{basename_without_ext}_{unique_id}' 81 | 82 | 83 | def make_task_lock_params( 84 | file_path: str, 85 | unique_id: str | None, 86 | redis_host: str | None = None, 87 | redis_port: int | None = None, 88 | redis_timeout: int | None = None, 89 | raise_task_lock_exception_on_collision: bool = False, 90 | lock_extend_seconds: int = 10, 91 | ) -> TaskLockParams: 92 | redis_key = make_task_lock_key(file_path, unique_id) 93 | should_task_lock = redis_host is not None and redis_port is not None 94 | if redis_timeout is not None: 95 | assert redis_timeout > lock_extend_seconds, f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.' 96 | task_lock_params = TaskLockParams( 97 | redis_host=redis_host, 98 | redis_port=redis_port, 99 | redis_key=redis_key, 100 | should_task_lock=should_task_lock, 101 | redis_timeout=redis_timeout, 102 | raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision, 103 | lock_extend_seconds=lock_extend_seconds, 104 | ) 105 | return task_lock_params 106 | 107 | 108 | def make_task_lock_params_for_run(task_self, lock_extend_seconds: int = 10) -> TaskLockParams: 109 | task_path_name = os.path.join(task_self.__module__.replace('.', '/'), f'{type(task_self).__name__}') 110 | unique_id = task_self.make_unique_id() + '-run' 111 | task_lock_key = make_task_lock_key(file_path=task_path_name, unique_id=unique_id) 112 | 113 | should_task_lock = task_self.redis_host is not None and task_self.redis_port is not None 114 | return TaskLockParams( 115 | redis_host=task_self.redis_host, 116 | redis_port=task_self.redis_port, 117 | redis_key=task_lock_key, 118 | should_task_lock=should_task_lock, 119 | redis_timeout=task_self.redis_timeout, 120 | raise_task_lock_exception_on_collision=True, 121 | lock_extend_seconds=lock_extend_seconds, 122 | ) 123 | 
-------------------------------------------------------------------------------- /gokart/conflict_prevention_lock/task_lock_wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from logging import getLogger 5 | from typing import Any, Callable 6 | 7 | from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock 8 | 9 | logger = getLogger(__name__) 10 | 11 | 12 | def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable): 13 | """Redis lock wrapper function for TargetOnKart.dump(). 14 | When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check. 15 | https://github.com/m3dev/gokart/issues/265 16 | """ 17 | 18 | if not task_lock_params.should_task_lock: 19 | return func 20 | 21 | def wrapper(*args, **kwargs): 22 | task_lock = set_task_lock(task_lock_params=task_lock_params) 23 | scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) 24 | 25 | try: 26 | logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} locked.') 27 | if not exist_check(): 28 | func(*args, **kwargs) 29 | finally: 30 | logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} released.') 31 | task_lock.release() 32 | scheduler.shutdown() 33 | 34 | return wrapper 35 | 36 | 37 | def wrap_load_with_lock(func, task_lock_params: TaskLockParams): 38 | """Redis lock wrapper function for TargetOnKart.load(). 39 | When TargetOnKart.load() is called, redis lock will be locked and released before load(). 
40 | https://github.com/m3dev/gokart/issues/265 41 | """ 42 | 43 | if not task_lock_params.should_task_lock: 44 | return func 45 | 46 | def wrapper(*args, **kwargs): 47 | task_lock = set_task_lock(task_lock_params=task_lock_params) 48 | scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) 49 | 50 | logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} locked.') 51 | task_lock.release() 52 | logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} released.') 53 | scheduler.shutdown() 54 | result = func(*args, **kwargs) 55 | return result 56 | 57 | return wrapper 58 | 59 | 60 | def wrap_remove_with_lock(func, task_lock_params: TaskLockParams): 61 | """Redis lock wrapper function for TargetOnKart.remove(). 62 | When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock. 63 | https://github.com/m3dev/gokart/issues/265 64 | """ 65 | if not task_lock_params.should_task_lock: 66 | return func 67 | 68 | def wrapper(*args, **kwargs): 69 | task_lock = set_task_lock(task_lock_params=task_lock_params) 70 | scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) 71 | 72 | try: 73 | logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} locked.') 74 | result = func(*args, **kwargs) 75 | task_lock.release() 76 | logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released.') 77 | scheduler.shutdown() 78 | return result 79 | except BaseException as e: 80 | logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released with BaseException.') 81 | task_lock.release() 82 | scheduler.shutdown() 83 | raise e 84 | 85 | return wrapper 86 | 87 | 88 | def wrap_run_with_lock(run_func: Callable[[], Any], task_lock_params: TaskLockParams): 89 | @functools.wraps(run_func) 90 | def wrapped(): 91 | task_lock = set_task_lock(task_lock_params=task_lock_params) 92 | scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) 93 | 94 | 
try: 95 | logger.debug(f'Task RUN lock of {task_lock_params.redis_key} locked.') 96 | result = run_func() 97 | task_lock.release() 98 | logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released.') 99 | scheduler.shutdown() 100 | return result 101 | except BaseException as e: 102 | logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released with BaseException.') 103 | task_lock.release() 104 | scheduler.shutdown() 105 | raise e 106 | 107 | return wrapped 108 | -------------------------------------------------------------------------------- /gokart/errors/__init__.py: -------------------------------------------------------------------------------- 1 | from gokart.build import GokartBuildError, HasLockedTaskException 2 | from gokart.pandas_type_config import PandasTypeError 3 | from gokart.task import EmptyDumpError 4 | 5 | __all__ = [ 6 | 'GokartBuildError', 7 | 'HasLockedTaskException', 8 | 'PandasTypeError', 9 | 'EmptyDumpError', 10 | ] 11 | -------------------------------------------------------------------------------- /gokart/gcs_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | 6 | import luigi 7 | import luigi.contrib.gcs 8 | from google.oauth2.service_account import Credentials 9 | 10 | 11 | class GCSConfig(luigi.Config): 12 | gcs_credential_name: str = luigi.Parameter(default='GCS_CREDENTIAL', description='GCS credential environment variable.') 13 | _client = None 14 | 15 | def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient: 16 | if self._client is None: # use cache as like singleton object 17 | self._client = self._get_gcs_client() 18 | return self._client 19 | 20 | def _get_gcs_client(self) -> luigi.contrib.gcs.GCSClient: 21 | return luigi.contrib.gcs.GCSClient(oauth_credentials=self._load_oauth_credentials()) 22 | 23 | def _load_oauth_credentials(self) -> Credentials | None: 24 | json_str = 
os.environ.get(self.gcs_credential_name) 25 | if not json_str: 26 | return None 27 | 28 | if os.path.isfile(json_str): 29 | return Credentials.from_service_account_file(json_str) 30 | 31 | return Credentials.from_service_account_info(json.loads(json_str)) 32 | -------------------------------------------------------------------------------- /gokart/gcs_zip_client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import shutil 5 | 6 | from gokart.gcs_config import GCSConfig 7 | from gokart.zip_client import ZipClient, _unzip_file 8 | 9 | 10 | class GCSZipClient(ZipClient): 11 | def __init__(self, file_path: str, temporary_directory: str) -> None: 12 | self._file_path = file_path 13 | self._temporary_directory = temporary_directory 14 | self._client = GCSConfig().get_gcs_client() 15 | 16 | def exists(self) -> bool: 17 | return self._client.exists(self._file_path) 18 | 19 | def make_archive(self) -> None: 20 | extension = os.path.splitext(self._file_path)[1] 21 | shutil.make_archive(base_name=self._temporary_directory, format=extension[1:], root_dir=self._temporary_directory) 22 | self._client.put(self._temporary_file_path(), self._file_path) 23 | 24 | def unpack_archive(self) -> None: 25 | os.makedirs(self._temporary_directory, exist_ok=True) 26 | file_pointer = self._client.download(self._file_path) 27 | _unzip_file(fp=file_pointer, extract_dir=self._temporary_directory) 28 | 29 | def remove(self) -> None: 30 | self._client.remove(self._file_path) 31 | 32 | @property 33 | def path(self) -> str: 34 | return self._file_path 35 | 36 | def _temporary_file_path(self): 37 | extension = os.path.splitext(self._file_path)[1] 38 | base_name = self._temporary_directory 39 | if base_name.endswith('/'): 40 | base_name = base_name[:-1] 41 | return base_name + extension 42 | -------------------------------------------------------------------------------- /gokart/in_memory/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .repository import InMemoryCacheRepository # noqa:F401 2 | from .target import InMemoryTarget, make_in_memory_target # noqa:F401 3 | -------------------------------------------------------------------------------- /gokart/in_memory/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from datetime import datetime 5 | from typing import Any 6 | 7 | 8 | @dataclass 9 | class InMemoryData: 10 | value: Any 11 | last_modification_time: datetime 12 | 13 | @classmethod 14 | def create_data(self, value: Any) -> InMemoryData: 15 | return InMemoryData(value=value, last_modification_time=datetime.now()) 16 | -------------------------------------------------------------------------------- /gokart/in_memory/repository.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Iterator 4 | from typing import Any 5 | 6 | from .data import InMemoryData 7 | 8 | 9 | class InMemoryCacheRepository: 10 | _cache: dict[str, InMemoryData] = {} 11 | 12 | def __init__(self): 13 | pass 14 | 15 | def get_value(self, key: str) -> Any: 16 | return self._get_data(key).value 17 | 18 | def get_last_modification_time(self, key: str): 19 | return self._get_data(key).last_modification_time 20 | 21 | def _get_data(self, key: str) -> InMemoryData: 22 | return self._cache[key] 23 | 24 | def set_value(self, key: str, obj: Any) -> None: 25 | data = InMemoryData.create_data(obj) 26 | self._cache[key] = data 27 | 28 | def has(self, key: str) -> bool: 29 | return key in self._cache 30 | 31 | def remove(self, key: str) -> None: 32 | assert self.has(key), f'{key} does not exist.' 
33 | del self._cache[key] 34 | 35 | def empty(self) -> bool: 36 | return not self._cache 37 | 38 | def clear(self) -> None: 39 | self._cache.clear() 40 | 41 | def get_gen(self) -> Iterator[tuple[str, Any]]: 42 | for key, data in self._cache.items(): 43 | yield key, data.value 44 | 45 | @property 46 | def size(self) -> int: 47 | return len(self._cache) 48 | -------------------------------------------------------------------------------- /gokart/in_memory/target.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import Any 5 | 6 | from gokart.in_memory.repository import InMemoryCacheRepository 7 | from gokart.required_task_output import RequiredTaskOutput 8 | from gokart.target import TargetOnKart, TaskLockParams 9 | from gokart.utils import FlattenableItems 10 | 11 | _repository = InMemoryCacheRepository() 12 | 13 | 14 | class InMemoryTarget(TargetOnKart): 15 | def __init__(self, data_key: str, task_lock_param: TaskLockParams): 16 | if task_lock_param.should_task_lock: 17 | raise ValueError('Redis with `InMemoryTarget` is not currently supported.') 18 | 19 | self._data_key = data_key 20 | self._task_lock_params = task_lock_param 21 | 22 | def _exists(self) -> bool: 23 | return _repository.has(self._data_key) 24 | 25 | def _get_task_lock_params(self) -> TaskLockParams: 26 | return self._task_lock_params 27 | 28 | def _load(self) -> Any: 29 | return _repository.get_value(self._data_key) 30 | 31 | def _dump( 32 | self, 33 | obj: Any, 34 | task_params: dict[str, str] | None = None, 35 | custom_labels: dict[str, str] | None = None, 36 | required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, 37 | ) -> None: 38 | return _repository.set_value(self._data_key, obj) 39 | 40 | def _remove(self) -> None: 41 | _repository.remove(self._data_key) 42 | 43 | def _last_modification_time(self) -> datetime: 44 | if not 
_repository.has(self._data_key): 45 | raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') 46 | time = _repository.get_last_modification_time(self._data_key) 47 | return time 48 | 49 | def _path(self) -> str: 50 | # TODO: this module name `_path` migit not be appropriate 51 | return self._data_key 52 | 53 | 54 | def make_in_memory_target(target_key: str, task_lock_params: TaskLockParams) -> InMemoryTarget: 55 | return InMemoryTarget(target_key, task_lock_params) 56 | -------------------------------------------------------------------------------- /gokart/info.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from logging import getLogger 4 | 5 | import luigi 6 | 7 | from gokart.task import TaskOnKart 8 | from gokart.tree.task_info import make_task_info_as_tree_str 9 | 10 | logger = getLogger(__name__) 11 | 12 | 13 | def make_tree_info( 14 | task: TaskOnKart, 15 | indent: str = '', 16 | last: bool = True, 17 | details: bool = False, 18 | abbr: bool = True, 19 | visited_tasks: set[str] | None = None, 20 | ignore_task_names: list[str] | None = None, 21 | ) -> str: 22 | """ 23 | Return a string representation of the tasks, their statuses/parameters in a dependency tree format 24 | 25 | This function has moved to `gokart.tree.task_info.make_task_info_as_tree_str`. 26 | This code is remained for backward compatibility. 27 | 28 | Parameters 29 | ---------- 30 | - task: TaskOnKart 31 | Root task. 32 | - details: bool 33 | Whether or not to output details. 34 | - abbr: bool 35 | Whether or not to simplify tasks information that has already appeared. 36 | - ignore_task_names: list[str] | None 37 | List of task names to ignore. 38 | Returns 39 | ------- 40 | - tree_info : str 41 | Formatted task dependency tree. 
42 | """ 43 | return make_task_info_as_tree_str(task=task, details=details, abbr=abbr, ignore_task_names=ignore_task_names) 44 | 45 | 46 | class tree_info(TaskOnKart): 47 | mode: str = luigi.Parameter(default='', description='This must be in ["simple", "all"].') 48 | output_path: str = luigi.Parameter(default='tree.txt', description='Output file path.') 49 | 50 | def output(self): 51 | return self.make_target(self.output_path, use_unique_id=False) 52 | -------------------------------------------------------------------------------- /gokart/object_storage.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | import luigi 6 | import luigi.contrib.gcs 7 | import luigi.contrib.s3 8 | from luigi.format import Format 9 | 10 | from gokart.gcs_config import GCSConfig 11 | from gokart.gcs_zip_client import GCSZipClient 12 | from gokart.s3_config import S3Config 13 | from gokart.s3_zip_client import S3ZipClient 14 | from gokart.zip_client import ZipClient 15 | 16 | object_storage_path_prefix = ['s3://', 'gs://'] 17 | 18 | 19 | class ObjectStorage: 20 | @staticmethod 21 | def if_object_storage_path(path: str) -> bool: 22 | for prefix in object_storage_path_prefix: 23 | if path.startswith(prefix): 24 | return True 25 | return False 26 | 27 | @staticmethod 28 | def get_object_storage_target(path: str, format: Format) -> luigi.target.FileSystemTarget: 29 | if path.startswith('s3://'): 30 | return luigi.contrib.s3.S3Target(path, client=S3Config().get_s3_client(), format=format) 31 | elif path.startswith('gs://'): 32 | return luigi.contrib.gcs.GCSTarget(path, client=GCSConfig().get_gcs_client(), format=format) 33 | else: 34 | raise 35 | 36 | @staticmethod 37 | def exists(path: str) -> bool: 38 | if path.startswith('s3://'): 39 | return S3Config().get_s3_client().exists(path) 40 | elif path.startswith('gs://'): 41 | return GCSConfig().get_gcs_client().exists(path) 42 
| else: 43 | raise 44 | 45 | @staticmethod 46 | def get_timestamp(path: str) -> datetime: 47 | if path.startswith('s3://'): 48 | return S3Config().get_s3_client().get_key(path).last_modified 49 | elif path.startswith('gs://'): 50 | # for gcs object 51 | # should PR to luigi 52 | bucket, obj = GCSConfig().get_gcs_client()._path_to_bucket_and_key(path) 53 | result = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute() 54 | return result['updated'] 55 | else: 56 | raise 57 | 58 | @staticmethod 59 | def get_zip_client(file_path: str, temporary_directory: str) -> ZipClient: 60 | if file_path.startswith('s3://'): 61 | return S3ZipClient(file_path=file_path, temporary_directory=temporary_directory) 62 | elif file_path.startswith('gs://'): 63 | return GCSZipClient(file_path=file_path, temporary_directory=temporary_directory) 64 | else: 65 | raise 66 | 67 | @staticmethod 68 | def is_buffered_reader(file: object): 69 | return not isinstance(file, luigi.contrib.s3.ReadableS3File) 70 | -------------------------------------------------------------------------------- /gokart/pandas_type_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import abstractmethod 4 | from logging import getLogger 5 | from typing import Any 6 | 7 | import luigi 8 | import numpy as np 9 | import pandas as pd 10 | from luigi.task_register import Register 11 | 12 | logger = getLogger(__name__) 13 | 14 | 15 | class PandasTypeError(Exception): 16 | """Raised when the type of the pandas DataFrame column is not as expected.""" 17 | 18 | 19 | class PandasTypeConfig(luigi.Config): 20 | @classmethod 21 | @abstractmethod 22 | def type_dict(cls) -> dict[str, Any]: 23 | pass 24 | 25 | @classmethod 26 | def check(cls, df: pd.DataFrame): 27 | for column_name, column_type in cls.type_dict().items(): 28 | cls._check_column(df, column_name, column_type) 29 | 30 | @classmethod 31 | def 
_check_column(cls, df, column_name, column_type): 32 | if column_name not in df.columns: 33 | return 34 | 35 | if not np.all(list(map(lambda x: isinstance(x, column_type), df[column_name]))): 36 | not_match = next(filter(lambda x: not isinstance(x, column_type), df[column_name])) 37 | raise PandasTypeError(f'expected type is "{column_type}", but "{type(not_match)}" is passed in column "{column_name}".') 38 | 39 | 40 | class PandasTypeConfigMap(luigi.Config): 41 | """To initialize this class only once, this inherits luigi.Config.""" 42 | 43 | def __init__(self, *args, **kwargs) -> None: 44 | super().__init__(*args, **kwargs) 45 | task_names = Register.task_names() 46 | task_classes = [Register.get_task_cls(task_name) for task_name in task_names] 47 | self._map = { 48 | task_class.task_namespace: task_class for task_class in task_classes if issubclass(task_class, PandasTypeConfig) and task_class != PandasTypeConfig 49 | } 50 | 51 | def check(self, obj, task_namespace: str): 52 | if isinstance(obj, pd.DataFrame) and task_namespace in self._map: 53 | self._map[task_namespace].check(obj) 54 | -------------------------------------------------------------------------------- /gokart/parameter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import bz2 4 | import datetime 5 | import json 6 | from logging import getLogger 7 | from typing import Generic, Protocol, TypeVar 8 | from warnings import warn 9 | 10 | import luigi 11 | from luigi import task_register 12 | 13 | import gokart 14 | 15 | logger = getLogger(__name__) 16 | 17 | 18 | class TaskInstanceParameter(luigi.Parameter): 19 | def __init__(self, expected_type=None, *args, **kwargs): 20 | if expected_type is None: 21 | self.expected_type: type = gokart.TaskOnKart 22 | elif isinstance(expected_type, type): 23 | self.expected_type = expected_type 24 | else: 25 | raise TypeError(f'expected_type must be a type, not {type(expected_type)}') 
26 | super().__init__(*args, **kwargs) 27 | 28 | @staticmethod 29 | def _recursive(param_dict): 30 | params = param_dict['params'] 31 | task_cls = task_register.Register.get_task_cls(param_dict['type']) 32 | for key, value in task_cls.get_params(): 33 | if key in params: 34 | params[key] = value.parse(params[key]) 35 | return task_cls(**params) 36 | 37 | @staticmethod 38 | def _recursive_decompress(s): 39 | s = dict(luigi.DictParameter().parse(s)) 40 | if 'params' in s: 41 | s['params'] = TaskInstanceParameter._recursive_decompress(bz2.decompress(bytes.fromhex(s['params'])).decode()) 42 | return s 43 | 44 | def parse(self, s): 45 | if isinstance(s, str): 46 | s = self._recursive_decompress(s) 47 | return self._recursive(s) 48 | 49 | def serialize(self, x): 50 | params = bz2.compress(json.dumps(x.to_str_params(only_significant=True)).encode()).hex() 51 | values = dict(type=x.get_task_family(), params=params) 52 | return luigi.DictParameter().serialize(values) 53 | 54 | def _warn_on_wrong_param_type(self, param_name, param_value): 55 | if not isinstance(param_value, self.expected_type): 56 | raise TypeError(f'{param_value} is not an instance of {self.expected_type}') 57 | 58 | 59 | class _TaskInstanceEncoder(json.JSONEncoder): 60 | def default(self, obj): 61 | if isinstance(obj, luigi.Task): 62 | return TaskInstanceParameter().serialize(obj) 63 | # Let the base class default method raise the TypeError 64 | return json.JSONEncoder.default(self, obj) 65 | 66 | 67 | class ListTaskInstanceParameter(luigi.Parameter): 68 | def __init__(self, expected_elements_type=None, *args, **kwargs): 69 | if expected_elements_type is None: 70 | self.expected_elements_type: type = gokart.TaskOnKart 71 | elif isinstance(expected_elements_type, type): 72 | self.expected_elements_type = expected_elements_type 73 | else: 74 | raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}') 75 | super().__init__(*args, **kwargs) 76 | 77 | def parse(self, s): 78 | 
return [TaskInstanceParameter().parse(x) for x in list(json.loads(s))] 79 | 80 | def serialize(self, x): 81 | return json.dumps(x, cls=_TaskInstanceEncoder) 82 | 83 | def _warn_on_wrong_param_type(self, param_name, param_value): 84 | for v in param_value: 85 | if not isinstance(v, self.expected_elements_type): 86 | raise TypeError(f'{v} is not an instance of {self.expected_elements_type}') 87 | 88 | 89 | class ExplicitBoolParameter(luigi.BoolParameter): 90 | def __init__(self, *args, **kwargs): 91 | luigi.Parameter.__init__(self, *args, **kwargs) 92 | 93 | def _parser_kwargs(self, *args, **kwargs): # type: ignore 94 | return luigi.Parameter._parser_kwargs(*args, *kwargs) 95 | 96 | 97 | T = TypeVar('T') 98 | 99 | 100 | class Serializable(Protocol): 101 | def gokart_serialize(self) -> str: 102 | """Implement this method to serialize the object as an parameter 103 | You can omit some fields from results of serialization if you want to ignore changes of them 104 | """ 105 | ... 106 | 107 | @classmethod 108 | def gokart_deserialize(cls: type[T], s: str) -> T: 109 | """Implement this method to deserialize the object from a string""" 110 | ... 111 | 112 | 113 | S = TypeVar('S', bound=Serializable) 114 | 115 | 116 | class SerializableParameter(luigi.Parameter, Generic[S]): 117 | def __init__(self, object_type: type[S], *args, **kwargs): 118 | self._object_type = object_type 119 | super().__init__(*args, **kwargs) 120 | 121 | def parse(self, s: str) -> S: 122 | return self._object_type.gokart_deserialize(s) 123 | 124 | def serialize(self, x: S) -> str: 125 | return x.gokart_serialize() 126 | 127 | 128 | class ZonedDateSecondParameter(luigi.Parameter): 129 | """ 130 | ZonedDateSecondParameter supports a datetime.datetime object with timezone information. 131 | 132 | A ZonedDateSecondParameter is a `ISO 8601 `_ formatted 133 | date, time specified to the second and timezone. For example, ``2013-07-10T19:07:38+09:00`` specifies July 10, 2013 at 134 | 19:07:38 +09:00. 
The separator `:` can be omitted for Python3.11 and later. 135 | """ 136 | 137 | def __init__(self, **kwargs): 138 | super().__init__(**kwargs) 139 | 140 | def parse(self, s): 141 | # special character 'Z' is replaced with '+00:00' 142 | # because Python 3.11 and later support fromisoformat with Z at the end of the string. 143 | if s.endswith('Z'): 144 | s = s[:-1] + '+00:00' 145 | dt = datetime.datetime.fromisoformat(s) 146 | if dt.tzinfo is None: 147 | warn('The input does not have timezone information. Please consider using luigi.DateSecondParameter instead.', stacklevel=1) 148 | return dt 149 | 150 | def serialize(self, dt): 151 | return dt.isoformat() 152 | 153 | def normalize(self, dt): 154 | # override _DatetimeParameterBase.normalize to avoid do nothing to normalize except removing microsecond. 155 | # microsecond is removed because the number of digits of microsecond is not fixed. 156 | # See also luigi's implementation https://github.com/spotify/luigi/blob/v3.6.0/luigi/parameter.py#L612 157 | return dt.replace(microsecond=0) 158 | -------------------------------------------------------------------------------- /gokart/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/gokart/py.typed -------------------------------------------------------------------------------- /gokart/required_task_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class RequiredTaskOutput: 6 | task_name: str 7 | output_path: str 8 | 9 | def serialize(self) -> dict[str, str]: 10 | return {'__gokart_task_name': self.task_name, '__gokart_output_path': self.output_path} 11 | -------------------------------------------------------------------------------- /gokart/run.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import annotations 2 | 3 | import os 4 | import sys 5 | from logging import getLogger 6 | 7 | import luigi 8 | import luigi.cmdline 9 | import luigi.retcodes 10 | from luigi.cmdline_parser import CmdlineParser 11 | 12 | import gokart 13 | import gokart.slack 14 | from gokart.object_storage import ObjectStorage 15 | 16 | logger = getLogger(__name__) 17 | 18 | 19 | def _run_tree_info(cmdline_args, details): 20 | with CmdlineParser.global_instance(cmdline_args) as cp: 21 | gokart.tree_info().output().dump(gokart.make_tree_info(cp.get_task_obj(), details=details)) 22 | 23 | 24 | def _try_tree_info(cmdline_args): 25 | with CmdlineParser.global_instance(cmdline_args): 26 | mode = gokart.tree_info().mode 27 | output_path = gokart.tree_info().output().path() 28 | 29 | # do nothing if `mode` is empty. 30 | if mode == '': 31 | return 32 | 33 | # output tree info and exit. 34 | if mode == 'simple': 35 | _run_tree_info(cmdline_args, details=False) 36 | elif mode == 'all': 37 | _run_tree_info(cmdline_args, details=True) 38 | else: 39 | raise ValueError(f'--tree-info-mode must be "simple" or "all", but "{mode}" is passed.') 40 | logger.info(f'output tree info: {output_path}') 41 | sys.exit() 42 | 43 | 44 | def _try_to_delete_unnecessary_output_file(cmdline_args: list[str]): 45 | with CmdlineParser.global_instance(cmdline_args) as cp: 46 | task = cp.get_task_obj() # type: gokart.TaskOnKart 47 | if task.delete_unnecessary_output_files: 48 | if ObjectStorage.if_object_storage_path(task.workspace_directory): 49 | logger.info('delete-unnecessary-output-files is not support s3/gcs.') 50 | else: 51 | gokart.delete_local_unnecessary_outputs(task) 52 | sys.exit() 53 | 54 | 55 | def _try_get_slack_api(cmdline_args: list[str]) -> gokart.slack.SlackAPI | None: 56 | with CmdlineParser.global_instance(cmdline_args): 57 | config = gokart.slack.SlackConfig() 58 | token = os.getenv(config.token_name, '') 59 | channel = config.channel 60 | to_user = config.to_user 61 | if token and 
channel: 62 | logger.info('Slack notification is activated.') 63 | return gokart.slack.SlackAPI(token=token, channel=channel, to_user=to_user) 64 | logger.info('Slack notification is not activated.') 65 | return None 66 | 67 | 68 | def _try_to_send_event_summary_to_slack(slack_api: gokart.slack.SlackAPI | None, event_aggregator: gokart.slack.EventAggregator, cmdline_args: list[str]): 69 | if slack_api is None: 70 | # do nothing 71 | return 72 | options = gokart.slack.SlackConfig() 73 | with CmdlineParser.global_instance(cmdline_args) as cp: 74 | task = cp.get_task_obj() 75 | tree_info = gokart.make_tree_info(task, details=True) if options.send_tree_info else 'Please add SlackConfig.send_tree_info to include tree-info' 76 | task_name = type(task).__name__ 77 | 78 | comment = f'Report of {task_name}' + os.linesep + event_aggregator.get_summary() 79 | content = os.linesep.join(['===== Event List ====', event_aggregator.get_event_list(), os.linesep, '==== Tree Info ====', tree_info]) 80 | slack_api.send_snippet(comment=comment, title='event.txt', content=content) 81 | 82 | 83 | def run(cmdline_args=None, set_retcode=True): 84 | cmdline_args = cmdline_args or sys.argv[1:] 85 | 86 | if set_retcode: 87 | luigi.retcodes.retcode.already_running = 10 88 | luigi.retcodes.retcode.missing_data = 20 89 | luigi.retcodes.retcode.not_run = 30 90 | luigi.retcodes.retcode.task_failed = 40 91 | luigi.retcodes.retcode.scheduling_error = 50 92 | 93 | _try_tree_info(cmdline_args) 94 | _try_to_delete_unnecessary_output_file(cmdline_args) 95 | gokart.testing.try_to_run_test_for_empty_data_frame(cmdline_args) 96 | 97 | slack_api = _try_get_slack_api(cmdline_args) 98 | event_aggregator = gokart.slack.EventAggregator() 99 | try: 100 | event_aggregator.set_handlers() 101 | luigi.cmdline.luigi_run(cmdline_args) 102 | except SystemExit as e: 103 | _try_to_send_event_summary_to_slack(slack_api, event_aggregator, cmdline_args) 104 | sys.exit(e.code) 105 | 
-------------------------------------------------------------------------------- /gokart/s3_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | 5 | import luigi 6 | import luigi.contrib.s3 7 | 8 | 9 | class S3Config(luigi.Config): 10 | aws_access_key_id_name = luigi.Parameter(default='AWS_ACCESS_KEY_ID', description='AWS access key id environment variable.') 11 | aws_secret_access_key_name = luigi.Parameter(default='AWS_SECRET_ACCESS_KEY', description='AWS secret access key environment variable.') 12 | 13 | _client = None 14 | 15 | def get_s3_client(self) -> luigi.contrib.s3.S3Client: 16 | if self._client is None: # use cache as like singleton object 17 | self._client = self._get_s3_client() 18 | return self._client 19 | 20 | def _get_s3_client(self) -> luigi.contrib.s3.S3Client: 21 | return luigi.contrib.s3.S3Client( 22 | aws_access_key_id=os.environ.get(self.aws_access_key_id_name), aws_secret_access_key=os.environ.get(self.aws_secret_access_key_name) 23 | ) 24 | -------------------------------------------------------------------------------- /gokart/s3_zip_client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import shutil 5 | 6 | from gokart.s3_config import S3Config 7 | from gokart.zip_client import ZipClient, _unzip_file 8 | 9 | 10 | class S3ZipClient(ZipClient): 11 | def __init__(self, file_path: str, temporary_directory: str) -> None: 12 | self._file_path = file_path 13 | self._temporary_directory = temporary_directory 14 | self._client = S3Config().get_s3_client() 15 | 16 | def exists(self) -> bool: 17 | return self._client.exists(self._file_path) 18 | 19 | def make_archive(self) -> None: 20 | extension = os.path.splitext(self._file_path)[1] 21 | if not os.path.exists(self._temporary_directory): 22 | # Check path existence since shutil.make_archive() of python 
3.10+ does not check it. 23 | raise FileNotFoundError(f'Temporary directory {self._temporary_directory} is not found.') 24 | shutil.make_archive(base_name=self._temporary_directory, format=extension[1:], root_dir=self._temporary_directory) 25 | self._client.put(self._temporary_file_path(), self._file_path) 26 | 27 | def unpack_archive(self) -> None: 28 | os.makedirs(self._temporary_directory, exist_ok=True) 29 | self._client.get(self._file_path, self._temporary_file_path()) 30 | _unzip_file(fp=self._temporary_file_path(), extract_dir=self._temporary_directory) 31 | 32 | def remove(self) -> None: 33 | self._client.remove(self._file_path) 34 | 35 | @property 36 | def path(self) -> str: 37 | return self._file_path 38 | 39 | def _temporary_file_path(self): 40 | extension = os.path.splitext(self._file_path)[1] 41 | base_name = self._temporary_directory 42 | if base_name.endswith('/'): 43 | base_name = base_name[:-1] 44 | return base_name + extension 45 | -------------------------------------------------------------------------------- /gokart/slack/__init__.py: -------------------------------------------------------------------------------- 1 | from gokart.slack.event_aggregator import EventAggregator 2 | from gokart.slack.slack_api import SlackAPI 3 | from gokart.slack.slack_config import SlackConfig 4 | 5 | from .slack_api import ChannelListNotLoadedError, ChannelNotFoundError, FileNotUploadedError 6 | 7 | __all__ = [ 8 | 'ChannelListNotLoadedError', 9 | 'ChannelNotFoundError', 10 | 'FileNotUploadedError', 11 | 'EventAggregator', 12 | 'SlackAPI', 13 | 'SlackConfig', 14 | ] 15 | -------------------------------------------------------------------------------- /gokart/slack/event_aggregator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from logging import getLogger 5 | from typing import TypedDict 6 | 7 | import luigi 8 | 9 | logger = getLogger(__name__) 10 | 11 | 12 | 
class FailureEvent(TypedDict): 13 | task: str 14 | exception: str 15 | 16 | 17 | class EventAggregator: 18 | def __init__(self) -> None: 19 | self._success_events: list[str] = [] 20 | self._failure_events: list[FailureEvent] = [] 21 | 22 | def set_handlers(self): 23 | handlers = [(luigi.Event.SUCCESS, self._success), (luigi.Event.FAILURE, self._failure)] 24 | for event, handler in handlers: 25 | luigi.Task.event_handler(event)(handler) 26 | 27 | def get_summary(self) -> str: 28 | return f'Success: {len(self._success_events)}; Failure: {len(self._failure_events)}' 29 | 30 | def get_event_list(self) -> str: 31 | message = '' 32 | if len(self._failure_events) != 0: 33 | failure_message = os.linesep.join([f'Task: {failure["task"]}; Exception: {failure["exception"]}' for failure in self._failure_events]) 34 | message += '---- Failure Tasks ----' + os.linesep + failure_message 35 | if len(self._success_events) != 0: 36 | success_message = os.linesep.join(self._success_events) 37 | message += '---- Success Tasks ----' + os.linesep + success_message 38 | if message == '': 39 | message = 'Tasks were not executed.' 
40 | return message 41 | 42 | def _success(self, task): 43 | self._success_events.append(self._task_to_str(task)) 44 | 45 | def _failure(self, task, exception): 46 | failure: FailureEvent = {'task': self._task_to_str(task), 'exception': str(exception)} 47 | self._failure_events.append(failure) 48 | 49 | @staticmethod 50 | def _task_to_str(task) -> str: 51 | return f'{type(task).__name__}:[{task.make_unique_id()}]' 52 | -------------------------------------------------------------------------------- /gokart/slack/slack_api.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from logging import getLogger 4 | 5 | import slack_sdk 6 | 7 | logger = getLogger(__name__) 8 | 9 | 10 | class ChannelListNotLoadedError(RuntimeError): 11 | pass 12 | 13 | 14 | class ChannelNotFoundError(RuntimeError): 15 | pass 16 | 17 | 18 | class FileNotUploadedError(RuntimeError): 19 | pass 20 | 21 | 22 | class SlackAPI: 23 | def __init__(self, token, channel: str, to_user: str) -> None: 24 | self._client = slack_sdk.WebClient(token=token) 25 | self._channel_id = self._get_channel_id(channel) 26 | self._to_user = to_user if to_user == '' or to_user.startswith('@') else '@' + to_user 27 | 28 | def _get_channel_id(self, channel_name): 29 | params = {'exclude_archived': True, 'limit': 100} 30 | try: 31 | for channels in self._client.conversations_list(params=params): 32 | if not channels: 33 | raise ChannelListNotLoadedError('Channel list is empty.') 34 | for channel in channels.get('channels', []): 35 | if channel['name'] == channel_name: 36 | return channel['id'] 37 | raise ChannelNotFoundError(f'Channel {channel_name} is not found in public channels.') 38 | except Exception as e: 39 | logger.warning(f'The job will start without slack notification: {e}') 40 | 41 | def send_snippet(self, comment, title, content): 42 | try: 43 | request_body = dict( 44 | channels=self._channel_id, initial_comment=f'<{self._to_user}> 
{comment}' if self._to_user else comment, content=content, title=title 45 | ) 46 | response = self._client.api_call('files.upload', data=request_body) 47 | if not response['ok']: 48 | raise FileNotUploadedError(f'Error while uploading file. The error reason is "{response["error"]}".') 49 | except Exception as e: 50 | logger.warning(f'Failed to send slack notification: {e}') 51 | -------------------------------------------------------------------------------- /gokart/slack/slack_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import luigi 4 | 5 | 6 | class SlackConfig(luigi.Config): 7 | token_name = luigi.Parameter(default='SLACK_TOKEN', description='slack token environment variable.') 8 | channel = luigi.Parameter(default='', significant=False, description='channel name for notification.') 9 | to_user = luigi.Parameter(default='', significant=False, description='Optional; user name who is supposed to be mentioned.') 10 | send_tree_info = luigi.BoolParameter( 11 | default=False, 12 | significant=False, 13 | description='When this option is true, the dependency tree of tasks is included in send message.' 
14 | 'It is recommended to set false to this option when notification takes long time.', 15 | ) 16 | -------------------------------------------------------------------------------- /gokart/task_complete_check.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from logging import getLogger 5 | from typing import Callable 6 | 7 | logger = getLogger(__name__) 8 | 9 | 10 | def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable): 11 | @functools.wraps(run_func) 12 | def wrapper(*args, **kwargs): 13 | if complete_check_func(): 14 | logger.warning(f'{run_func.__name__} is skipped because the task is already completed.') 15 | return 16 | return run_func(*args, **kwargs) 17 | 18 | return wrapper 19 | -------------------------------------------------------------------------------- /gokart/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from gokart.testing.check_if_run_with_empty_data_frame import test_run, try_to_run_test_for_empty_data_frame # noqa:F401 2 | from gokart.testing.pandas_assert import assert_frame_contents_equal # noqa:F401 3 | -------------------------------------------------------------------------------- /gokart/testing/check_if_run_with_empty_data_frame.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import sys 5 | 6 | import luigi 7 | from luigi.cmdline_parser import CmdlineParser 8 | 9 | import gokart 10 | from gokart.utils import flatten 11 | 12 | test_logger = logging.getLogger(__name__) 13 | test_logger.addHandler(logging.StreamHandler()) 14 | test_logger.setLevel(logging.INFO) 15 | 16 | 17 | class test_run(gokart.TaskOnKart): 18 | pandas: bool = luigi.BoolParameter() 19 | namespace: str | None = luigi.OptionalParameter( 20 | default=None, description='When task 
namespace is not defined explicitly, please use "__not_user_specified".' 21 | ) 22 | 23 | 24 | class _TestStatus: 25 | def __init__(self, task: gokart.TaskOnKart) -> None: 26 | self.namespace = task.task_namespace 27 | self.name = type(task).__name__ 28 | self.task_id = task.make_unique_id() 29 | self.status = 'OK' 30 | self.message: Exception | None = None 31 | 32 | def format(self) -> str: 33 | s = f'status={self.status}; namespace={self.namespace}; name={self.name}; id={self.task_id};' 34 | if self.message: 35 | s += f' message={type(self.message)}: {", ".join(map(str, self.message.args))}' 36 | return s 37 | 38 | def fail(self) -> bool: 39 | return self.status != 'OK' 40 | 41 | 42 | def _get_all_tasks(task: gokart.TaskOnKart) -> list[gokart.TaskOnKart]: 43 | result = [task] 44 | for o in flatten(task.requires()): 45 | result.extend(_get_all_tasks(o)) 46 | return result 47 | 48 | 49 | def _run_with_test_status(task: gokart.TaskOnKart): 50 | test_message = _TestStatus(task) 51 | try: 52 | task.run() # type: ignore 53 | except Exception as e: 54 | test_message.status = 'NG' 55 | test_message.message = e 56 | return test_message 57 | 58 | 59 | def _test_run_with_empty_data_frame(cmdline_args: list[str], test_run_params: test_run): 60 | from unittest.mock import patch 61 | 62 | try: 63 | gokart.run(cmdline_args=cmdline_args) 64 | except SystemExit as e: 65 | assert e.code == 0, f'original workflow does not run properly. It exited with error code {e}.' 
66 | 67 | with CmdlineParser.global_instance(cmdline_args) as cp: 68 | all_tasks = _get_all_tasks(cp.get_task_obj()) 69 | 70 | if test_run_params.namespace is not None: 71 | all_tasks = [t for t in all_tasks if t.task_namespace == test_run_params.namespace] 72 | 73 | with patch('gokart.TaskOnKart.dump', new=lambda *args, **kwargs: None): 74 | test_status_list = [_run_with_test_status(t) for t in all_tasks] 75 | 76 | test_logger.info('gokart test results:\n' + '\n'.join(s.format() for s in test_status_list)) 77 | if any(s.fail() for s in test_status_list): 78 | sys.exit(1) 79 | 80 | 81 | def try_to_run_test_for_empty_data_frame(cmdline_args: list[str]): 82 | with CmdlineParser.global_instance(cmdline_args): 83 | test_run_params = test_run() 84 | 85 | if test_run_params.pandas: 86 | cmdline_args = [a for a in cmdline_args if not a.startswith('--test-run-')] 87 | _test_run_with_empty_data_frame(cmdline_args=cmdline_args, test_run_params=test_run_params) 88 | sys.exit(0) 89 | -------------------------------------------------------------------------------- /gokart/testing/pandas_assert.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pandas as pd 4 | 5 | 6 | def assert_frame_contents_equal(actual: pd.DataFrame, expected: pd.DataFrame, **kwargs): 7 | """ 8 | Assert that two DataFrames are equal. 9 | This function is mostly same as pandas.testing.assert_frame_equal, however 10 | - this fuction ignores the order of index and columns. 11 | - this function fails when duplicated index or columns are found. 12 | 13 | Parameters 14 | ---------- 15 | - actual, expected: pd.DataFrame 16 | DataFrames to be compared. 17 | - kwargs: Any 18 | Parameters passed to pandas.testing.assert_frame_equal. 
19 | """ 20 | assert isinstance(actual, pd.DataFrame), 'actual is not a DataFrame' 21 | assert isinstance(expected, pd.DataFrame), 'expected is not a DataFrame' 22 | 23 | assert actual.index.is_unique, 'actual index is not unique' 24 | assert expected.index.is_unique, 'expected index is not unique' 25 | assert actual.columns.is_unique, 'actual columns is not unique' 26 | assert expected.columns.is_unique, 'expected columns is not unique' 27 | 28 | assert set(actual.columns) == set(expected.columns), 'columns are not equal' 29 | assert set(actual.index) == set(expected.index), 'indexes are not equal' 30 | 31 | expected_reindexed = expected.reindex(actual.index)[actual.columns] 32 | pd.testing.assert_frame_equal(actual, expected_reindexed, **kwargs) 33 | -------------------------------------------------------------------------------- /gokart/tree/task_info.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | 5 | import pandas as pd 6 | 7 | from gokart.target import make_target 8 | from gokart.task import TaskOnKart 9 | from gokart.tree.task_info_formatter import make_task_info_tree, make_tree_info, make_tree_info_table_list 10 | 11 | 12 | def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None): 13 | """ 14 | Return a string representation of the tasks, their statuses/parameters in a dependency tree format 15 | 16 | Parameters 17 | ---------- 18 | - task: TaskOnKart 19 | Root task. 20 | - details: bool 21 | Whether or not to output details. 22 | - abbr: bool 23 | Whether or not to simplify tasks information that has already appeared. 24 | - ignore_task_names: list[str] | None 25 | List of task names to ignore. 26 | Returns 27 | ------- 28 | - tree_info : str 29 | Formatted task dependency tree. 
30 | """ 31 | task_info = make_task_info_tree(task, ignore_task_names=ignore_task_names) 32 | result = make_tree_info(task_info=task_info, indent='', last=True, details=details, abbr=abbr, visited_tasks=set()) 33 | return result 34 | 35 | 36 | def make_task_info_as_table(task: TaskOnKart, ignore_task_names: list[str] | None = None): 37 | """Return a table containing information about dependent tasks. 38 | 39 | Parameters 40 | ---------- 41 | - task: TaskOnKart 42 | Root task. 43 | - ignore_task_names: list[str] | None 44 | List of task names to ignore. 45 | Returns 46 | ------- 47 | - task_info_table : pandas.DataFrame 48 | Formatted task dependency table. 49 | """ 50 | 51 | task_info = make_task_info_tree(task, ignore_task_names=ignore_task_names) 52 | task_info_table = pd.DataFrame(make_tree_info_table_list(task_info=task_info, visited_tasks=set())) 53 | 54 | return task_info_table 55 | 56 | 57 | def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None): 58 | """Dump a table containing information about dependent tasks. 59 | 60 | Parameters 61 | ---------- 62 | - task: TaskOnKart 63 | Root task. 64 | - task_info_dump_path: str 65 | Output target file path. Path destination can be `local`, `S3`, or `GCS`. 66 | File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`. 67 | See `TaskOnKart.make_target module ` for details. 68 | - ignore_task_names: list[str] | None 69 | List of task names to ignore. 
70 | Returns 71 | ------- 72 | None 73 | """ 74 | task_info_table = make_task_info_as_table(task=task, ignore_task_names=ignore_task_names) 75 | 76 | unique_id = task.make_unique_id() 77 | 78 | task_info_target = make_target(file_path=task_info_dump_path, unique_id=unique_id) 79 | task_info_target.dump(obj=task_info_table, lock_at_dump=False) 80 | 81 | 82 | def dump_task_info_tree(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True): 83 | """Dump the task info tree object (TaskInfo) to a pickle file. 84 | 85 | Parameters 86 | ---------- 87 | - task: TaskOnKart 88 | Root task. 89 | - task_info_dump_path: str 90 | Output target file path. Path destination can be `local`, `S3`, or `GCS`. 91 | File extension must be '.pkl'. 92 | - ignore_task_names: list[str] | None 93 | List of task names to ignore. 94 | - use_unique_id: bool = True 95 | Whether to use unique id to dump target file. Default is True. 96 | Returns 97 | ------- 98 | None 99 | """ 100 | extension = os.path.splitext(task_info_dump_path)[1] 101 | assert extension == '.pkl', f'File extention must be `.pkl`, not `{extension}`.' 
102 | 103 | task_info_tree = make_task_info_tree(task, ignore_task_names=ignore_task_names) 104 | 105 | unique_id = task.make_unique_id() if use_unique_id else None 106 | 107 | task_info_target = make_target(file_path=task_info_dump_path, unique_id=unique_id) 108 | task_info_target.dump(obj=task_info_tree, lock_at_dump=False) 109 | -------------------------------------------------------------------------------- /gokart/tree/task_info_formatter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import NamedTuple 7 | 8 | from gokart.task import TaskOnKart 9 | from gokart.utils import FlattenableItems, flatten 10 | 11 | 12 | @dataclass 13 | class TaskInfo: 14 | name: str 15 | unique_id: str 16 | output_paths: list[str] 17 | params: dict 18 | processing_time: str 19 | is_complete: str 20 | task_log: dict 21 | requires: FlattenableItems[RequiredTask] 22 | children_task_infos: list[TaskInfo] 23 | 24 | def get_task_id(self): 25 | return f'{self.name}_{self.unique_id}' 26 | 27 | def get_task_title(self): 28 | return f'({self.is_complete}) {self.name}[{self.unique_id}]' 29 | 30 | def get_task_detail(self): 31 | return f'(parameter={self.params}, output={self.output_paths}, time={self.processing_time}, task_log={self.task_log})' 32 | 33 | def task_info_dict(self): 34 | return dict( 35 | name=self.name, 36 | unique_id=self.unique_id, 37 | output_paths=self.output_paths, 38 | params=self.params, 39 | processing_time=self.processing_time, 40 | is_complete=self.is_complete, 41 | task_log=self.task_log, 42 | requires=self.requires, 43 | ) 44 | 45 | 46 | class RequiredTask(NamedTuple): 47 | name: str 48 | unique_id: str 49 | 50 | 51 | def _make_requires_info(requires): 52 | if isinstance(requires, TaskOnKart): 53 | return RequiredTask(name=requires.__class__.__name__, unique_id=requires.make_unique_id()) 54 | elif 
isinstance(requires, dict): 55 | return {key: _make_requires_info(requires=item) for key, item in requires.items()} 56 | elif isinstance(requires, typing.Iterable): 57 | return [_make_requires_info(requires=item) for item in requires] 58 | 59 | raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterarble[TaskOnKart]`, or `Dict[str, TaskOnKart]`') 60 | 61 | 62 | def make_task_info_tree(task: TaskOnKart, ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo: 63 | with warnings.catch_warnings(): 64 | warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method') 65 | is_task_complete = task.complete() 66 | 67 | name = task.__class__.__name__ 68 | unique_id = task.make_unique_id() 69 | output_paths: list[str] = [t.path() for t in flatten(task.output())] 70 | 71 | cache = {} if cache is None else cache 72 | cache_id = f'{name}_{unique_id}_{is_task_complete}' 73 | if cache_id in cache: 74 | return cache[cache_id] 75 | 76 | params = task.get_info(only_significant=True) 77 | processing_time = task.get_processing_time() 78 | if isinstance(processing_time, float): 79 | processing_time = str(processing_time) + 's' 80 | is_complete = 'COMPLETE' if is_task_complete else 'PENDING' 81 | task_log = dict(task.get_task_log()) 82 | requires = _make_requires_info(task.requires()) 83 | 84 | children = flatten(task.requires()) 85 | children_task_infos: list[TaskInfo] = [] 86 | for child in children: 87 | if ignore_task_names is None or child.__class__.__name__ not in ignore_task_names: 88 | children_task_infos.append(make_task_info_tree(child, ignore_task_names=ignore_task_names, cache=cache)) 89 | task_info = TaskInfo( 90 | name=name, 91 | unique_id=unique_id, 92 | output_paths=output_paths, 93 | params=params, 94 | processing_time=processing_time, 95 | is_complete=is_complete, 96 | task_log=task_log, 97 | requires=requires, 98 | 
children_task_infos=children_task_infos, 99 | ) 100 | cache[cache_id] = task_info 101 | return task_info 102 | 103 | 104 | def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: set[str]): 105 | result = '\n' + indent 106 | if last: 107 | result += '└─-' 108 | indent += ' ' 109 | else: 110 | result += '|--' 111 | indent += '| ' 112 | result += task_info.get_task_title() 113 | 114 | if abbr: 115 | task_id = task_info.get_task_id() 116 | if task_id not in visited_tasks: 117 | visited_tasks.add(task_id) 118 | else: 119 | result += f'\n{indent}└─- ...' 120 | return result 121 | 122 | if details: 123 | result += task_info.get_task_detail() 124 | 125 | children = task_info.children_task_infos 126 | for index, child in enumerate(children): 127 | result += make_tree_info(child, indent, (index + 1) == len(children), details=details, abbr=abbr, visited_tasks=visited_tasks) 128 | return result 129 | 130 | 131 | def make_tree_info_table_list(task_info: TaskInfo, visited_tasks: set[str]): 132 | task_id = task_info.get_task_id() 133 | if task_id in visited_tasks: 134 | return [] 135 | visited_tasks.add(task_id) 136 | 137 | result = [task_info.task_info_dict()] 138 | 139 | children = task_info.children_task_infos 140 | for child in children: 141 | result += make_tree_info_table_list(task_info=child, visited_tasks=visited_tasks) 142 | return result 143 | -------------------------------------------------------------------------------- /gokart/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | from collections.abc import Iterable 6 | from io import BytesIO 7 | from typing import Any, Callable, Protocol, TypeVar, Union 8 | 9 | import dill 10 | import luigi 11 | import pandas as pd 12 | 13 | 14 | class FileLike(Protocol): 15 | def read(self, n: int) -> bytes: ... 16 | 17 | def readline(self) -> bytes: ... 
18 | 19 | def seek(self, offset: int) -> None: ... 20 | 21 | def seekable(self) -> bool: ... 22 | 23 | 24 | def add_config(file_path: str): 25 | _, ext = os.path.splitext(file_path) 26 | luigi.configuration.core.parser = ext # type: ignore 27 | assert luigi.configuration.add_config_path(file_path) 28 | 29 | 30 | T = TypeVar('T') 31 | if sys.version_info >= (3, 10): 32 | from typing import TypeAlias 33 | 34 | FlattenableItems: TypeAlias = T | Iterable['FlattenableItems[T]'] | dict[str, 'FlattenableItems[T]'] 35 | else: 36 | FlattenableItems = Union[T, Iterable['FlattenableItems[T]'], dict[str, 'FlattenableItems[T]']] 37 | 38 | 39 | def flatten(targets: FlattenableItems[T]) -> list[T]: 40 | """ 41 | Creates a flat list of all items in structured output (dicts, lists, items): 42 | 43 | .. code-block:: python 44 | 45 | >>> sorted(flatten({'a': 'foo', 'b': 'bar'})) 46 | ['bar', 'foo'] 47 | >>> sorted(flatten(['foo', ['bar', 'troll']])) 48 | ['bar', 'foo', 'troll'] 49 | >>> flatten('foo') 50 | ['foo'] 51 | >>> flatten(42) 52 | [42] 53 | 54 | This method is copied and modified from [luigi.task.flatten](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/luigi/task.py#L958) in accordance with [Apache License 2.0](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/LICENSE). 
55 | """ 56 | if targets is None: 57 | return [] 58 | flat = [] 59 | if isinstance(targets, dict): 60 | for _, result in targets.items(): 61 | flat += flatten(result) 62 | return flat 63 | 64 | if isinstance(targets, str): 65 | return [targets] # type: ignore 66 | 67 | if not isinstance(targets, Iterable): 68 | return [targets] 69 | 70 | for result in targets: 71 | flat += flatten(result) 72 | return flat 73 | 74 | 75 | K = TypeVar('K') 76 | 77 | 78 | def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]: 79 | if isinstance(items, dict): 80 | return {k: map_flattenable_items(func, v) for k, v in items.items()} 81 | if isinstance(items, tuple): 82 | return tuple(map_flattenable_items(func, i) for i in items) 83 | if isinstance(items, str): 84 | return func(items) # type: ignore 85 | if isinstance(items, Iterable): 86 | return list(map(lambda item: map_flattenable_items(func, item), items)) 87 | return func(items) 88 | 89 | 90 | def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any: 91 | """Load binary dumped by dill with pandas backward compatibility. 92 | pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle. 93 | It is unclear whether all objects dumped by dill can be loaded by pd.read_pickle, we use dill.load as a fallback. 94 | """ 95 | try: 96 | return dill.load(file) 97 | except Exception: 98 | assert file.seekable(), f'{file} is not seekable.' 
99 | file.seek(0) 100 | return pd.read_pickle(file) 101 | -------------------------------------------------------------------------------- /gokart/workspace_management.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import itertools 4 | import os 5 | import pathlib 6 | from logging import getLogger 7 | 8 | import gokart 9 | from gokart.utils import flatten 10 | 11 | logger = getLogger(__name__) 12 | 13 | 14 | def _get_all_output_file_paths(task: gokart.TaskOnKart): 15 | output_paths = [t.path() for t in flatten(task.output())] 16 | children = flatten(task.requires()) 17 | output_paths.extend(itertools.chain.from_iterable([_get_all_output_file_paths(child) for child in children])) 18 | return output_paths 19 | 20 | 21 | def delete_local_unnecessary_outputs(task: gokart.TaskOnKart): 22 | task.make_unique_id() # this is required to make unique ids. 23 | all_files = {str(path) for path in pathlib.Path(task.workspace_directory).rglob('*.*')} 24 | log_files = {str(path) for path in pathlib.Path(os.path.join(task.workspace_directory, 'log')).rglob('*.*')} 25 | necessary_files = set(_get_all_output_file_paths(task)) 26 | unnecessary_files = all_files - necessary_files - log_files 27 | if len(unnecessary_files) == 0: 28 | logger.info('all files are necessary for this task.') 29 | else: 30 | logger.info(f'remove following files: {os.linesep} {os.linesep.join(unnecessary_files)}') 31 | for file in unnecessary_files: 32 | os.remove(file) 33 | -------------------------------------------------------------------------------- /gokart/zip_client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import shutil 5 | import zipfile 6 | from abc import abstractmethod 7 | from typing import IO 8 | 9 | 10 | def _unzip_file(fp: str | IO | os.PathLike, extract_dir: str) -> None: 11 | zip_file = zipfile.ZipFile(fp) 
12 | zip_file.extractall(extract_dir) 13 | zip_file.close() 14 | 15 | 16 | class ZipClient: 17 | @abstractmethod 18 | def exists(self) -> bool: 19 | pass 20 | 21 | @abstractmethod 22 | def make_archive(self) -> None: 23 | pass 24 | 25 | @abstractmethod 26 | def unpack_archive(self) -> None: 27 | pass 28 | 29 | @abstractmethod 30 | def remove(self) -> None: 31 | pass 32 | 33 | @property 34 | @abstractmethod 35 | def path(self) -> str: 36 | pass 37 | 38 | 39 | class LocalZipClient(ZipClient): 40 | def __init__(self, file_path: str, temporary_directory: str) -> None: 41 | self._file_path = file_path 42 | self._temporary_directory = temporary_directory 43 | 44 | def exists(self) -> bool: 45 | return os.path.exists(self._file_path) 46 | 47 | def make_archive(self) -> None: 48 | [base, extension] = os.path.splitext(self._file_path) 49 | shutil.make_archive(base_name=base, format=extension[1:], root_dir=self._temporary_directory) 50 | 51 | def unpack_archive(self) -> None: 52 | _unzip_file(fp=self._file_path, extract_dir=self._temporary_directory) 53 | 54 | def remove(self) -> None: 55 | shutil.rmtree(self._file_path, ignore_errors=True) 56 | 57 | @property 58 | def path(self) -> str: 59 | return self._file_path 60 | -------------------------------------------------------------------------------- /gokart/zip_client_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from gokart.object_storage import ObjectStorage 4 | from gokart.zip_client import LocalZipClient, ZipClient 5 | 6 | 7 | def make_zip_client(file_path: str, temporary_directory: str) -> ZipClient: 8 | if ObjectStorage.if_object_storage_path(file_path): 9 | return ObjectStorage.get_zip_client(file_path=file_path, temporary_directory=temporary_directory) 10 | return LocalZipClient(file_path=file_path, temporary_directory=temporary_directory) 11 | -------------------------------------------------------------------------------- 
/luigi.cfg: -------------------------------------------------------------------------------- 1 | [core] 2 | autoload_range: false 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "uv-dynamic-versioning"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "gokart" 7 | description="Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline. [Documentation](https://gokart.readthedocs.io/en/latest/)" 8 | authors = [ 9 | {name = "M3, inc."} 10 | ] 11 | license = "MIT" 12 | readme = "README.md" 13 | requires-python = ">=3.9, <4" 14 | dependencies = [ 15 | "luigi", 16 | "boto3", 17 | "slack-sdk", 18 | "pandas", 19 | "numpy", 20 | "google-auth", 21 | "pyarrow", 22 | "uritemplate", 23 | "google-api-python-client", 24 | "APScheduler", 25 | "redis", 26 | "dill", 27 | "backoff", 28 | "typing-extensions>=4.11.0; python_version<'3.13'", 29 | ] 30 | classifiers = [ 31 | "Development Status :: 5 - Production/Stable", 32 | "License :: OSI Approved :: MIT License", 33 | "Programming Language :: Python :: 3", 34 | "Programming Language :: Python :: 3.9", 35 | "Programming Language :: Python :: 3.10", 36 | "Programming Language :: Python :: 3.11", 37 | "Programming Language :: Python :: 3.12", 38 | "Programming Language :: Python :: 3.13", 39 | 40 | ] 41 | dynamic = ["version"] 42 | 43 | [project.urls] 44 | Homepage = "https://github.com/m3dev/gokart" 45 | Repository = "https://github.com/m3dev/gokart" 46 | Documentation = "https://gokart.readthedocs.io/en/latest/" 47 | 48 | [dependency-groups] 49 | test = [ 50 | "fakeredis", 51 | "lupa", 52 | "matplotlib", 53 | "moto", 54 | "mypy", 55 | "pytest", 56 | "pytest-cov", 57 | "pytest-xdist", 58 | "testfixtures", 59 | "toml", 60 | "types-redis", 61 | "typing-extensions>=4.11.0", 62 | 
] 63 | 64 | lint = [ 65 | "ruff", 66 | "mypy", 67 | ] 68 | 69 | [tool.uv] 70 | default-groups = ['test', 'lint'] 71 | cache-keys = [ { file = "pyproject.toml" }, { git = true } ] 72 | 73 | [tool.hatch.version] 74 | source = "uv-dynamic-versioning" 75 | 76 | [tool.uv-dynamic-versioning] 77 | enable = true 78 | 79 | [tool.hatch.build.targets.sdist] 80 | include = [ 81 | "/LICENSE", 82 | "/README.md", 83 | "/examples", 84 | "/gokart", 85 | "/test", 86 | ] 87 | 88 | [tool.ruff] 89 | line-length = 160 90 | exclude = ["venv/*", "tox/*", "examples/*"] 91 | 92 | [tool.ruff.lint] 93 | # All the rules are listed on https://docs.astral.sh/ruff/rules/ 94 | extend-select = [ 95 | "B", # bugbear 96 | "I", # isort 97 | "UP", # pyupgrade, upgrade syntax for newer versions of the language. 98 | ] 99 | 100 | # B006: Do not use mutable data structures for argument defaults. They are created during function definition time. All calls to the function reuse this one instance of that data structure, persisting changes between them. 101 | # B008 Do not perform function calls in argument defaults. The call is performed only once at function definition time. All calls to your function will reuse the result of that definition-time function call. If this is intended, assign the function call to a module-level variable and use that variable as a default value. 
102 | ignore = ["B006", "B008"] 103 | 104 | [tool.ruff.format] 105 | quote-style = "single" 106 | 107 | [tool.mypy] 108 | ignore_missing_imports = true 109 | plugins = ["gokart.mypy:plugin"] 110 | 111 | check_untyped_defs = true 112 | 113 | [tool.pytest.ini_options] 114 | testpaths = ["test"] 115 | addopts = "-n auto -s -v --durations=0" 116 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/__init__.py -------------------------------------------------------------------------------- /test/config/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Final 3 | 4 | CONFIG_DIR: Final[Path] = Path(__file__).parent.resolve() 5 | PYPROJECT_TOML: Final[Path] = CONFIG_DIR / 'pyproject.toml' 6 | PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS: Final[Path] = CONFIG_DIR / 'pyproject_disallow_missing_parameters.toml' 7 | TEST_CONFIG_INI: Final[Path] = CONFIG_DIR / 'test_config.ini' 8 | -------------------------------------------------------------------------------- /test/config/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.mypy] 2 | plugins = ["gokart.mypy"] 3 | 4 | [[tool.mypy.overrides]] 5 | ignore_missing_imports = true 6 | module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"] 7 | -------------------------------------------------------------------------------- /test/config/pyproject_disallow_missing_parameters.toml: -------------------------------------------------------------------------------- 1 | [tool.mypy] 2 | plugins = ["gokart.mypy"] 3 | 4 | [[tool.mypy.overrides]] 5 | ignore_missing_imports = true 6 | module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", 
"testfixtures.*", "luigi.*"] 7 | 8 | [tool.gokart-mypy] 9 | disallow_missing_parameters = true 10 | -------------------------------------------------------------------------------- /test/config/test_config.ini: -------------------------------------------------------------------------------- 1 | [test_read_config._DummyTask] 2 | param = ${test_param} 3 | 4 | [test_build._DummyTask] 5 | param = ${test_param} -------------------------------------------------------------------------------- /test/conflict_prevention_lock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/conflict_prevention_lock/__init__.py -------------------------------------------------------------------------------- /test/conflict_prevention_lock/test_task_lock.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from unittest.mock import patch 4 | 5 | import gokart 6 | from gokart.conflict_prevention_lock.task_lock import RedisClient, TaskLockParams, make_task_lock_key, make_task_lock_params, make_task_lock_params_for_run 7 | 8 | 9 | class TestRedisClient(unittest.TestCase): 10 | @staticmethod 11 | def _get_randint(host, port): 12 | return random.randint(0, 100000) 13 | 14 | def test_redis_client_is_singleton(self): 15 | with patch('redis.Redis') as mock: 16 | mock.side_effect = self._get_randint 17 | 18 | redis_client_0_0 = RedisClient(host='host_0', port=123) 19 | redis_client_1 = RedisClient(host='host_1', port=123) 20 | redis_client_0_1 = RedisClient(host='host_0', port=123) 21 | 22 | self.assertNotEqual(redis_client_0_0, redis_client_1) 23 | self.assertEqual(redis_client_0_0, redis_client_0_1) 24 | 25 | self.assertEqual(redis_client_0_0.get_redis_client(), redis_client_0_1.get_redis_client()) 26 | 27 | 28 | class TestMakeRedisKey(unittest.TestCase): 29 | def test_make_redis_key(self): 30 
| result = make_task_lock_key(file_path='gs://test_ll/dir/fname.pkl', unique_id='12345') 31 | self.assertEqual(result, 'fname_12345') 32 | 33 | 34 | class TestMakeRedisParams(unittest.TestCase): 35 | def test_make_task_lock_params_with_valid_host(self): 36 | result = make_task_lock_params( 37 | file_path='gs://aaa.pkl', unique_id='123', redis_host='0.0.0.0', redis_port=12345, redis_timeout=180, raise_task_lock_exception_on_collision=False 38 | ) 39 | expected = TaskLockParams( 40 | redis_host='0.0.0.0', 41 | redis_port=12345, 42 | redis_key='aaa_123', 43 | should_task_lock=True, 44 | redis_timeout=180, 45 | raise_task_lock_exception_on_collision=False, 46 | lock_extend_seconds=10, 47 | ) 48 | self.assertEqual(result, expected) 49 | 50 | def test_make_task_lock_params_with_no_host(self): 51 | result = make_task_lock_params( 52 | file_path='gs://aaa.pkl', unique_id='123', redis_host=None, redis_port=12345, redis_timeout=180, raise_task_lock_exception_on_collision=False 53 | ) 54 | expected = TaskLockParams( 55 | redis_host=None, 56 | redis_port=12345, 57 | redis_key='aaa_123', 58 | should_task_lock=False, 59 | redis_timeout=180, 60 | raise_task_lock_exception_on_collision=False, 61 | lock_extend_seconds=10, 62 | ) 63 | self.assertEqual(result, expected) 64 | 65 | def test_assert_when_redis_timeout_is_too_short(self): 66 | with self.assertRaises(AssertionError): 67 | make_task_lock_params( 68 | file_path='test_dir/test_file.pkl', 69 | unique_id='123abc', 70 | redis_host='0.0.0.0', 71 | redis_port=12345, 72 | redis_timeout=2, 73 | ) 74 | 75 | 76 | class TestMakeTaskLockParamsForRun(unittest.TestCase): 77 | def test_make_task_lock_params_for_run(self): 78 | class _SampleDummyTask(gokart.TaskOnKart): 79 | pass 80 | 81 | task_self = _SampleDummyTask( 82 | redis_host='0.0.0.0', 83 | redis_port=12345, 84 | redis_timeout=180, 85 | ) 86 | 87 | result = make_task_lock_params_for_run(task_self=task_self, lock_extend_seconds=10) 88 | expected = TaskLockParams( 89 | 
redis_host='0.0.0.0', 90 | redis_port=12345, 91 | redis_timeout=180, 92 | redis_key='_SampleDummyTask_7e857f231830ca0fd6cf829d99f43961-run', 93 | should_task_lock=True, 94 | raise_task_lock_exception_on_collision=True, 95 | lock_extend_seconds=10, 96 | ) 97 | 98 | self.assertEqual(result, expected) 99 | -------------------------------------------------------------------------------- /test/in_memory/test_in_memory_target.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from time import sleep 3 | 4 | import pytest 5 | 6 | from gokart.conflict_prevention_lock.task_lock import TaskLockParams 7 | from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget, make_in_memory_target 8 | 9 | 10 | class TestInMemoryTarget: 11 | @pytest.fixture 12 | def task_lock_params(self) -> TaskLockParams: 13 | return TaskLockParams( 14 | redis_host=None, 15 | redis_port=None, 16 | redis_timeout=None, 17 | redis_key='dummy', 18 | should_task_lock=False, 19 | raise_task_lock_exception_on_collision=False, 20 | lock_extend_seconds=0, 21 | ) 22 | 23 | @pytest.fixture 24 | def target(self, task_lock_params: TaskLockParams) -> InMemoryTarget: 25 | return make_in_memory_target(target_key='dummy_key', task_lock_params=task_lock_params) 26 | 27 | @pytest.fixture(autouse=True) 28 | def clear_repo(self) -> None: 29 | InMemoryCacheRepository().clear() 30 | 31 | def test_dump_and_load_data(self, target: InMemoryTarget): 32 | dumped = 'dummy_data' 33 | target.dump(dumped) 34 | loaded = target.load() 35 | assert loaded == dumped 36 | 37 | def test_exist(self, target: InMemoryTarget): 38 | assert not target.exists() 39 | target.dump('dummy_data') 40 | assert target.exists() 41 | 42 | def test_last_modified_time(self, target: InMemoryTarget): 43 | input = 'dummy_data' 44 | target.dump(input) 45 | time = target.last_modification_time() 46 | assert isinstance(time, datetime) 47 | 48 | sleep(0.1) 49 | another_input = 
'another_data' 50 | target.dump(another_input) 51 | another_time = target.last_modification_time() 52 | assert time < another_time 53 | 54 | target.remove() 55 | with pytest.raises(ValueError): 56 | assert target.last_modification_time() 57 | -------------------------------------------------------------------------------- /test/in_memory/test_repository.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | 5 | from gokart.in_memory import InMemoryCacheRepository as Repo 6 | 7 | dummy_num = 100 8 | 9 | 10 | class TestInMemoryCacheRepository: 11 | @pytest.fixture 12 | def repo(self) -> Repo: 13 | repo = Repo() 14 | repo.clear() 15 | return repo 16 | 17 | def test_set(self, repo: Repo): 18 | repo.set_value('dummy_key', dummy_num) 19 | assert repo.size == 1 20 | for key, value in repo.get_gen(): 21 | assert (key, value) == ('dummy_key', dummy_num) 22 | 23 | repo.set_value('another_key', 'another_value') 24 | assert repo.size == 2 25 | 26 | def test_get(self, repo: Repo): 27 | repo.set_value('dummy_key', dummy_num) 28 | repo.set_value('another_key', 'another_value') 29 | 30 | """Raise Error when key doesn't exist.""" 31 | with pytest.raises(KeyError): 32 | repo.get_value('not_exist_key') 33 | 34 | assert repo.get_value('dummy_key') == dummy_num 35 | assert repo.get_value('another_key') == 'another_value' 36 | 37 | def test_empty(self, repo: Repo): 38 | assert repo.empty() 39 | repo.set_value('dummmy_key', dummy_num) 40 | assert not repo.empty() 41 | 42 | def test_has(self, repo: Repo): 43 | assert not repo.has('dummy_key') 44 | repo.set_value('dummy_key', dummy_num) 45 | assert repo.has('dummy_key') 46 | assert not repo.has('not_exist_key') 47 | 48 | def test_remove(self, repo: Repo): 49 | repo.set_value('dummy_key', dummy_num) 50 | 51 | with pytest.raises(AssertionError): 52 | repo.remove('not_exist_key') 53 | 54 | repo.remove('dummy_key') 55 | assert not repo.has('dummy_key') 56 | 57 | def 
test_last_modification_time(self, repo: Repo): 58 | repo.set_value('dummy_key', dummy_num) 59 | date1 = repo.get_last_modification_time('dummy_key') 60 | time.sleep(0.1) 61 | repo.set_value('dummy_key', dummy_num) 62 | date2 = repo.get_last_modification_time('dummy_key') 63 | assert date1 < date2 64 | -------------------------------------------------------------------------------- /test/slack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/slack/__init__.py -------------------------------------------------------------------------------- /test/slack/test_slack_api.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from logging import getLogger 3 | from unittest import mock 4 | from unittest.mock import MagicMock 5 | 6 | from slack_sdk import WebClient 7 | from slack_sdk.web.slack_response import SlackResponse 8 | from testfixtures import LogCapture 9 | 10 | import gokart.slack 11 | 12 | logger = getLogger(__name__) 13 | 14 | 15 | def _slack_response(token, data): 16 | return SlackResponse( 17 | client=WebClient(token=token), http_verb='POST', api_url='http://localhost:3000/api.test', req_args={}, data=data, headers={}, status_code=200 18 | ) 19 | 20 | 21 | class TestSlackAPI(unittest.TestCase): 22 | @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') 23 | def test_initialization_with_invalid_token(self, patch): 24 | def _conversations_list(params={}): 25 | return _slack_response(token='invalid', data={'ok': False, 'error': 'error_reason'}) 26 | 27 | mock_client = MagicMock() 28 | mock_client.conversations_list = MagicMock(side_effect=_conversations_list) 29 | patch.return_value = mock_client 30 | 31 | with LogCapture() as log: 32 | gokart.slack.SlackAPI(token='invalid', channel='test', to_user='test user') 33 | log.check(('gokart.slack.slack_api', 'WARNING', 'The job 
will start without slack notification: Channel test is not found in public channels.')) 34 | 35 | @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') 36 | def test_invalid_channel(self, patch): 37 | def _conversations_list(params={}): 38 | return _slack_response( 39 | token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}} 40 | ) 41 | 42 | mock_client = MagicMock() 43 | mock_client.conversations_list = MagicMock(side_effect=_conversations_list) 44 | patch.return_value = mock_client 45 | 46 | with LogCapture() as log: 47 | gokart.slack.SlackAPI(token='valid', channel='invalid_channel', to_user='test user') 48 | log.check( 49 | ('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel invalid_channel is not found in public channels.') 50 | ) 51 | 52 | @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') 53 | def test_send_snippet_with_invalid_token(self, patch): 54 | def _conversations_list(params={}): 55 | return _slack_response( 56 | token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}} 57 | ) 58 | 59 | def _api_call(method, data={}): 60 | assert method == 'files.upload' 61 | return {'ok': False, 'error': 'error_reason'} 62 | 63 | mock_client = MagicMock() 64 | mock_client.conversations_list = MagicMock(side_effect=_conversations_list) 65 | mock_client.api_call = MagicMock(side_effect=_api_call) 66 | patch.return_value = mock_client 67 | 68 | with LogCapture() as log: 69 | api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user') 70 | api.send_snippet(comment='test', title='title', content='content') 71 | log.check( 72 | ('gokart.slack.slack_api', 'WARNING', 'Failed to send slack notification: Error while uploading file. 
The error reason is "error_reason".') 73 | ) 74 | 75 | @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') 76 | def test_send(self, patch): 77 | def _conversations_list(params={}): 78 | return _slack_response( 79 | token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}} 80 | ) 81 | 82 | def _api_call(method, data={}): 83 | assert method == 'files.upload' 84 | return {'ok': False, 'error': 'error_reason'} 85 | 86 | mock_client = MagicMock() 87 | mock_client.conversations_list = MagicMock(side_effect=_conversations_list) 88 | mock_client.api_call = MagicMock(side_effect=_api_call) 89 | patch.return_value = mock_client 90 | 91 | api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user') 92 | api.send_snippet(comment='test', title='title', content='content') 93 | 94 | 95 | if __name__ == '__main__': 96 | unittest.main() 97 | -------------------------------------------------------------------------------- /test/test_cache_unique_id.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import luigi 5 | import luigi.mock 6 | 7 | import gokart 8 | 9 | 10 | class _DummyTask(gokart.TaskOnKart): 11 | def requires(self): 12 | return _DummyTaskDep() 13 | 14 | def run(self): 15 | self.dump(self.load()) 16 | 17 | 18 | class _DummyTaskDep(gokart.TaskOnKart): 19 | param = luigi.Parameter() 20 | 21 | def run(self): 22 | self.dump(self.param) 23 | 24 | 25 | class CacheUniqueIDTest(unittest.TestCase): 26 | def setUp(self): 27 | luigi.configuration.LuigiConfigParser._instance = None 28 | luigi.mock.MockFileSystem().clear() 29 | os.environ.clear() 30 | 31 | def test_cache_unique_id_true(self): 32 | _DummyTaskDep.param = luigi.Parameter(default='original_param') 33 | 34 | output1 = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False) 35 | 36 | _DummyTaskDep.param = 
luigi.Parameter(default='updated_param') 37 | output2 = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False) 38 | self.assertEqual(output1, output2) 39 | 40 | def test_cache_unique_id_false(self): 41 | _DummyTaskDep.param = luigi.Parameter(default='original_param') 42 | 43 | output1 = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False) 44 | 45 | _DummyTaskDep.param = luigi.Parameter(default='updated_param') 46 | output2 = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False) 47 | self.assertNotEqual(output1, output2) 48 | 49 | 50 | if __name__ == '__main__': 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /test/test_config_params.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import luigi 4 | from luigi.cmdline_parser import CmdlineParser 5 | 6 | import gokart 7 | from gokart.config_params import inherits_config_params 8 | 9 | 10 | def in_parse(cmds, deferred_computation): 11 | """function copied from luigi: https://github.com/spotify/luigi/blob/e2228418eec60b68ca09a30c878ab26413846847/test/helpers.py""" 12 | with CmdlineParser.global_instance(cmds) as cp: 13 | deferred_computation(cp.get_task_obj()) 14 | 15 | 16 | class ConfigClass(luigi.Config): 17 | param_a = luigi.Parameter(default='config a') 18 | param_b = luigi.Parameter(default='config b') 19 | param_c = luigi.Parameter(default='config c') 20 | 21 | 22 | @inherits_config_params(ConfigClass) 23 | class Inherited(gokart.TaskOnKart): 24 | param_a = luigi.Parameter() 25 | param_b = luigi.Parameter(default='overrided') 26 | 27 | 28 | @inherits_config_params(ConfigClass, parameter_alias={'param_a': 'param_d'}) 29 | class Inherited2(gokart.TaskOnKart): 30 | param_c = luigi.Parameter() 31 | param_d = luigi.Parameter() 32 | 33 | 34 | class ChildTask(Inherited): 35 | pass 36 | 37 | 38 | class ChildTaskWithNewParam(Inherited): 39 | param_new = 
luigi.Parameter() 40 | 41 | 42 | class ConfigClass2(luigi.Config): 43 | param_a = luigi.Parameter(default='config a from config class 2') 44 | 45 | 46 | @inherits_config_params(ConfigClass2) 47 | class ChildTaskWithNewConfig(Inherited): 48 | pass 49 | 50 | 51 | class TestInheritsConfigParam(unittest.TestCase): 52 | def test_inherited_params(self): 53 | # test fill values 54 | in_parse(['Inherited'], lambda task: self.assertEqual(task.param_a, 'config a')) 55 | 56 | # test overrided 57 | in_parse(['Inherited'], lambda task: self.assertEqual(task.param_b, 'config b')) 58 | 59 | # Command line argument takes precedence over config param 60 | in_parse(['Inherited', '--param-a', 'command line arg'], lambda task: self.assertEqual(task.param_a, 'command line arg')) 61 | 62 | # Parameters which is not a member of the task will not be set 63 | with self.assertRaises(AttributeError): 64 | in_parse(['Inherited'], lambda task: task.param_c) 65 | 66 | # test parameter name alias 67 | in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_c, 'config c')) 68 | in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_d, 'config a')) 69 | 70 | def test_child_task(self): 71 | in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_a, 'config a')) 72 | in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_b, 'config b')) 73 | in_parse(['ChildTask', '--param-a', 'command line arg'], lambda task: self.assertEqual(task.param_a, 'command line arg')) 74 | with self.assertRaises(AttributeError): 75 | in_parse(['ChildTask'], lambda task: task.param_c) 76 | 77 | def test_child_override(self): 78 | in_parse(['ChildTaskWithNewConfig'], lambda task: self.assertEqual(task.param_a, 'config a from config class 2')) 79 | in_parse(['ChildTaskWithNewConfig'], lambda task: self.assertEqual(task.param_b, 'config b')) 80 | -------------------------------------------------------------------------------- /test/test_explicit_bool_parameter.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import luigi 4 | import luigi.mock 5 | from luigi.cmdline_parser import CmdlineParser 6 | 7 | import gokart 8 | 9 | 10 | def in_parse(cmds, deferred_computation): 11 | with CmdlineParser.global_instance(cmds) as cp: 12 | deferred_computation(cp.get_task_obj()) 13 | 14 | 15 | class WithDefaultTrue(gokart.TaskOnKart): 16 | param = gokart.ExplicitBoolParameter(default=True) 17 | 18 | 19 | class WithDefaultFalse(gokart.TaskOnKart): 20 | param = gokart.ExplicitBoolParameter(default=False) 21 | 22 | 23 | class ExplicitParsing(gokart.TaskOnKart): 24 | param = gokart.ExplicitBoolParameter() 25 | 26 | def run(self): 27 | ExplicitParsing._param = self.param # type: ignore 28 | 29 | 30 | class TestExplicitBoolParameter(unittest.TestCase): 31 | def test_bool_default(self): 32 | self.assertTrue(WithDefaultTrue().param) 33 | self.assertFalse(WithDefaultFalse().param) 34 | 35 | def test_parse_param(self): 36 | in_parse(['ExplicitParsing', '--param', 'true'], lambda task: self.assertTrue(task.param)) 37 | in_parse(['ExplicitParsing', '--param', 'false'], lambda task: self.assertFalse(task.param)) 38 | in_parse(['ExplicitParsing', '--param', 'True'], lambda task: self.assertTrue(task.param)) 39 | in_parse(['ExplicitParsing', '--param', 'False'], lambda task: self.assertFalse(task.param)) 40 | 41 | def test_missing_parameter(self): 42 | with self.assertRaises(luigi.parameter.MissingParameterException): 43 | in_parse(['ExplicitParsing'], lambda: True) 44 | 45 | def test_value_error(self): 46 | with self.assertRaises(ValueError): 47 | in_parse(['ExplicitParsing', '--param', 'Foo'], lambda: True) 48 | 49 | def test_expected_one_argment_error(self): 50 | # argparse throw "expected one argument" error 51 | with self.assertRaises(SystemExit): 52 | in_parse(['ExplicitParsing', '--param'], lambda: True) 53 | 
-------------------------------------------------------------------------------- /test/test_gcs_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from unittest.mock import MagicMock, patch 4 | 5 | from gokart.gcs_config import GCSConfig 6 | 7 | 8 | class TestGCSConfig(unittest.TestCase): 9 | def test_get_gcs_client_without_gcs_credential_name(self): 10 | mock = MagicMock() 11 | os.environ['env_name'] = '' 12 | with patch('luigi.contrib.gcs.GCSClient', mock): 13 | GCSConfig(gcs_credential_name='env_name')._get_gcs_client() 14 | self.assertEqual(dict(oauth_credentials=None), mock.call_args[1]) 15 | 16 | def test_get_gcs_client_with_file_path(self): 17 | mock = MagicMock() 18 | file_path = 'test.json' 19 | os.environ['env_name'] = file_path 20 | with patch('luigi.contrib.gcs.GCSClient'): 21 | with patch('google.oauth2.service_account.Credentials.from_service_account_file', mock): 22 | with patch('os.path.isfile', return_value=True): 23 | GCSConfig(gcs_credential_name='env_name')._get_gcs_client() 24 | self.assertEqual(file_path, mock.call_args[0][0]) 25 | 26 | def test_get_gcs_client_with_json(self): 27 | mock = MagicMock() 28 | json_str = '{"test": 1}' 29 | os.environ['env_name'] = json_str 30 | with patch('luigi.contrib.gcs.GCSClient'): 31 | with patch('google.oauth2.service_account.Credentials.from_service_account_info', mock): 32 | GCSConfig(gcs_credential_name='env_name')._get_gcs_client() 33 | self.assertEqual(dict(test=1), mock.call_args[0][0]) 34 | -------------------------------------------------------------------------------- /test/test_gcs_obj_metadata_client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import datetime 4 | import unittest 5 | from typing import Any 6 | from unittest.mock import MagicMock, patch 7 | 8 | import gokart 9 | from gokart.gcs_obj_metadata_client import 
GCSObjectMetadataClient 10 | from gokart.required_task_output import RequiredTaskOutput 11 | from gokart.target import TargetOnKart 12 | 13 | 14 | class _DummyTaskOnKart(gokart.TaskOnKart): 15 | task_namespace = __name__ 16 | 17 | def run(self): 18 | self.dump('Dummy TaskOnKart') 19 | 20 | 21 | class TestGCSObjectMetadataClient(unittest.TestCase): 22 | def setUp(self): 23 | self.task_params: dict[str, str] = { 24 | 'param1': 'a' * 1000, 25 | 'param2': str(1000), 26 | 'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}), 27 | 'param4': str([1, 2, 3, 4, 5]), 28 | 'param5': str(datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)), 29 | 'param6': '', 30 | } 31 | self.custom_labels: dict[str, Any] = { 32 | 'created_at': datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5), 33 | 'created_by': 'hoge fuga', 34 | 'empty': True, 35 | 'try_num': 3, 36 | } 37 | 38 | self.task_params_with_conflicts = { 39 | 'empty': 'False', 40 | 'created_by': 'fuga hoge', 41 | 'param1': 'a' * 10, 42 | } 43 | 44 | def test_normalize_labels_not_empty(self): 45 | got = GCSObjectMetadataClient._normalize_labels(None) 46 | self.assertEqual(got, {}) 47 | 48 | def test_normalize_labels_has_value(self): 49 | got = GCSObjectMetadataClient._normalize_labels(self.task_params) 50 | 51 | self.assertIsInstance(got, dict) 52 | self.assertIsInstance(got, dict) 53 | self.assertIn('param1', got) 54 | self.assertIn('param2', got) 55 | self.assertIn('param3', got) 56 | self.assertIn('param4', got) 57 | self.assertIn('param5', got) 58 | self.assertIn('param6', got) 59 | 60 | def test_get_patched_obj_metadata_only_task_params(self): 61 | got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=None) 62 | 63 | self.assertIsInstance(got, dict) 64 | self.assertIn('param1', got) 65 | self.assertIn('param2', got) 66 | self.assertIn('param3', got) 67 | self.assertIn('param4', got) 68 | self.assertIn('param5', got) 69 | 
self.assertNotIn('param6', got) 70 | 71 | def test_get_patched_obj_metadata_only_custom_labels(self): 72 | got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=None, custom_labels=self.custom_labels) 73 | 74 | self.assertIsInstance(got, dict) 75 | self.assertIn('created_at', got) 76 | self.assertIn('created_by', got) 77 | self.assertIn('empty', got) 78 | self.assertIn('try_num', got) 79 | 80 | def test_get_patched_obj_metadata_with_both_task_params_and_custom_labels(self): 81 | got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=self.custom_labels) 82 | 83 | self.assertIsInstance(got, dict) 84 | self.assertIn('param1', got) 85 | self.assertIn('param2', got) 86 | self.assertIn('param3', got) 87 | self.assertIn('param4', got) 88 | self.assertIn('param5', got) 89 | self.assertNotIn('param6', got) 90 | self.assertIn('created_at', got) 91 | self.assertIn('created_by', got) 92 | self.assertIn('empty', got) 93 | self.assertIn('try_num', got) 94 | 95 | def test_get_patched_obj_metadata_with_exceeded_size_metadata(self): 96 | size_exceeded_task_params = { 97 | 'param1': 'a' * 5000, 98 | 'param2': 'b' * 5000, 99 | } 100 | want = { 101 | 'param1': 'a' * 5000, 102 | } 103 | got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=size_exceeded_task_params) 104 | self.assertEqual(got, want) 105 | 106 | def test_get_patched_obj_metadata_with_conflicts(self): 107 | got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params_with_conflicts, custom_labels=self.custom_labels) 108 | self.assertIsInstance(got, dict) 109 | self.assertIn('created_at', got) 110 | self.assertIn('created_by', got) 111 | self.assertIn('empty', got) 112 | self.assertIn('try_num', got) 113 | self.assertEqual(got['empty'], 'True') 114 | self.assertEqual(got['created_by'], 'hoge fuga') 115 | self.assertEqual(got['param1'], 'a' * 10) 116 | 117 | def 
test_get_patched_obj_metadata_with_required_task_outputs(self): 118 | got = GCSObjectMetadataClient._get_patched_obj_metadata( 119 | {}, 120 | required_task_outputs=[ 121 | RequiredTaskOutput(task_name='task1', output_path='path/to/output1'), 122 | ], 123 | ) 124 | 125 | self.assertIsInstance(got, dict) 126 | self.assertIn('__required_task_outputs', got) 127 | self.assertEqual(got['__required_task_outputs'], '[{"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}]') 128 | 129 | def test_get_patched_obj_metadata_with_nested_required_task_outputs(self): 130 | got = GCSObjectMetadataClient._get_patched_obj_metadata( 131 | {}, 132 | required_task_outputs={ 133 | 'nested_task': {'nest': RequiredTaskOutput(task_name='task1', output_path='path/to/output1')}, 134 | }, 135 | ) 136 | 137 | self.assertIsInstance(got, dict) 138 | self.assertIn('__required_task_outputs', got) 139 | self.assertEqual( 140 | got['__required_task_outputs'], '{"nested_task": {"nest": {"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}}}' 141 | ) 142 | 143 | 144 | class TestGokartTask(unittest.TestCase): 145 | @patch.object(_DummyTaskOnKart, '_get_output_target') 146 | def test_mock_target_on_kart(self, mock_get_output_target): 147 | mock_target = MagicMock(spec=TargetOnKart) 148 | mock_get_output_target.return_value = mock_target 149 | 150 | task = _DummyTaskOnKart() 151 | task.dump({'key': 'value'}, mock_target) 152 | 153 | mock_target.dump.assert_called_once_with( 154 | {'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None, required_task_outputs=[] 155 | ) 156 | 157 | 158 | if __name__ == '__main__': 159 | unittest.main() 160 | -------------------------------------------------------------------------------- /test/test_info.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch 3 | 4 | import luigi 5 | import luigi.mock 6 | from luigi.mock 
import MockFileSystem, MockTarget 7 | 8 | import gokart 9 | import gokart.info 10 | from test.tree.test_task_info import _DoubleLoadSubTask, _SubTask, _Task 11 | 12 | 13 | class TestInfo(unittest.TestCase): 14 | def setUp(self) -> None: 15 | MockFileSystem().clear() 16 | luigi.setup_logging.DaemonLogging._configured = False 17 | luigi.setup_logging.InterfaceLogging._configured = False 18 | 19 | def tearDown(self) -> None: 20 | luigi.setup_logging.DaemonLogging._configured = False 21 | luigi.setup_logging.InterfaceLogging._configured = False 22 | 23 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) 24 | def test_make_tree_info_pending(self): 25 | task = _Task(param=1, sub=_SubTask(param=2)) 26 | 27 | # check before running 28 | tree = gokart.info.make_tree_info(task) 29 | expected = r""" 30 | └─-\(PENDING\) _Task\[[a-z0-9]*\] 31 | └─-\(PENDING\) _SubTask\[[a-z0-9]*\]$""" 32 | self.assertRegex(tree, expected) 33 | 34 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) 35 | def test_make_tree_info_complete(self): 36 | task = _Task(param=1, sub=_SubTask(param=2)) 37 | 38 | # check after sub task runs 39 | gokart.build(task, reset_register=False) 40 | tree = gokart.info.make_tree_info(task) 41 | expected = r""" 42 | └─-\(COMPLETE\) _Task\[[a-z0-9]*\] 43 | └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$""" 44 | self.assertRegex(tree, expected) 45 | 46 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) 47 | def test_make_tree_info_abbreviation(self): 48 | task = _DoubleLoadSubTask( 49 | sub1=_Task(param=1, sub=_SubTask(param=2)), 50 | sub2=_Task(param=1, sub=_SubTask(param=2)), 51 | ) 52 | 53 | # check after sub task runs 54 | gokart.build(task, reset_register=False) 55 | tree = gokart.info.make_tree_info(task) 56 | expected = r""" 57 | └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\] 58 | \|--\(COMPLETE\) _Task\[[a-z0-9]*\] 59 | \| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\] 60 | 
└─-\(COMPLETE\) _Task\[[a-z0-9]*\] 61 | └─- \.\.\.$""" 62 | self.assertRegex(tree, expected) 63 | 64 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) 65 | def test_make_tree_info_not_compress(self): 66 | task = _DoubleLoadSubTask( 67 | sub1=_Task(param=1, sub=_SubTask(param=2)), 68 | sub2=_Task(param=1, sub=_SubTask(param=2)), 69 | ) 70 | 71 | # check after sub task runs 72 | gokart.build(task, reset_register=False) 73 | tree = gokart.info.make_tree_info(task, abbr=False) 74 | expected = r""" 75 | └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\] 76 | \|--\(COMPLETE\) _Task\[[a-z0-9]*\] 77 | \| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\] 78 | └─-\(COMPLETE\) _Task\[[a-z0-9]*\] 79 | └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$""" 80 | self.assertRegex(tree, expected) 81 | 82 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) 83 | def test_make_tree_info_not_compress_ignore_task(self): 84 | task = _DoubleLoadSubTask( 85 | sub1=_Task(param=1, sub=_SubTask(param=2)), 86 | sub2=_Task(param=1, sub=_SubTask(param=2)), 87 | ) 88 | 89 | # check after sub task runs 90 | gokart.build(task, reset_register=False) 91 | tree = gokart.info.make_tree_info(task, abbr=False, ignore_task_names=['_Task']) 92 | expected = r""" 93 | └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$""" 94 | self.assertRegex(tree, expected) 95 | 96 | 97 | if __name__ == '__main__': 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /test/test_large_data_fram_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import unittest 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from gokart.target import LargeDataFrameProcessor 9 | from test.util import _get_temporary_directory 10 | 11 | 12 | class LargeDataFrameProcessorTest(unittest.TestCase): 13 | def setUp(self): 14 | self.temporary_directory = 
_get_temporary_directory() 15 | 16 | def tearDown(self): 17 | shutil.rmtree(self.temporary_directory, ignore_errors=True) 18 | 19 | def test_save_and_load(self): 20 | file_path = os.path.join(self.temporary_directory, 'test.zip') 21 | df = pd.DataFrame(dict(data=np.random.uniform(0, 1, size=int(1e6)))) 22 | processor = LargeDataFrameProcessor(max_byte=int(1e6)) 23 | processor.save(df, file_path) 24 | loaded = processor.load(file_path) 25 | 26 | pd.testing.assert_frame_equal(loaded, df, check_like=True) 27 | 28 | def test_save_and_load_empty(self): 29 | file_path = os.path.join(self.temporary_directory, 'test_with_empty.zip') 30 | df = pd.DataFrame() 31 | processor = LargeDataFrameProcessor(max_byte=int(1e6)) 32 | processor.save(df, file_path) 33 | loaded = processor.load(file_path) 34 | 35 | pd.testing.assert_frame_equal(loaded, df, check_like=True) 36 | 37 | 38 | if __name__ == '__main__': 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /test/test_list_task_instance_parameter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import luigi 4 | 5 | import gokart 6 | from gokart import TaskOnKart 7 | 8 | 9 | class _DummySubTask(TaskOnKart): 10 | task_namespace = __name__ 11 | pass 12 | 13 | 14 | class _DummyTask(TaskOnKart): 15 | task_namespace = __name__ 16 | param = luigi.IntParameter() 17 | task = gokart.TaskInstanceParameter(default=_DummySubTask()) 18 | 19 | 20 | class ListTaskInstanceParameterTest(unittest.TestCase): 21 | def setUp(self): 22 | _DummyTask.clear_instance_cache() 23 | 24 | def test_serialize_and_parse(self): 25 | original = [_DummyTask(param=3), _DummyTask(param=3)] 26 | s = gokart.ListTaskInstanceParameter().serialize(original) 27 | parsed = gokart.ListTaskInstanceParameter().parse(s) 28 | self.assertEqual(parsed[0].task_id, original[0].task_id) 29 | self.assertEqual(parsed[1].task_id, original[1].task_id) 30 | 31 | 32 | if 
__name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /test/test_mypy.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import unittest 3 | 4 | from mypy import api 5 | 6 | from test.config import PYPROJECT_TOML, PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS 7 | 8 | 9 | class TestMyMypyPlugin(unittest.TestCase): 10 | def test_plugin_no_issue(self): 11 | test_code = """ 12 | import luigi 13 | from luigi import Parameter 14 | import gokart 15 | import datetime 16 | 17 | 18 | class MyTask(gokart.TaskOnKart): 19 | foo: int = luigi.IntParameter() # type: ignore 20 | bar: str = luigi.Parameter() # type: ignore 21 | baz: bool = gokart.ExplicitBoolParameter() 22 | qux: str = Parameter() 23 | # https://github.com/m3dev/gokart/issues/395 24 | datetime: datetime.datetime = luigi.DateMinuteParameter(interval=10, default=datetime.datetime(2021, 1, 1)) 25 | 26 | 27 | 28 | # TaskOnKart parameters: 29 | # - `complete_check_at_run` 30 | MyTask(foo=1, bar='bar', baz=False, qux='qux', complete_check_at_run=False) 31 | """ 32 | 33 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 34 | test_file.write(test_code.encode('utf-8')) 35 | test_file.flush() 36 | result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) 37 | self.assertIn('Success: no issues found', result[0]) 38 | 39 | def test_plugin_invalid_arg(self): 40 | test_code = """ 41 | import luigi 42 | import gokart 43 | 44 | 45 | class MyTask(gokart.TaskOnKart): 46 | foo: int = luigi.IntParameter() # type: ignore 47 | bar: str = luigi.Parameter() # type: ignore 48 | baz: bool = gokart.ExplicitBoolParameter() 49 | 50 | # issue: foo is int 51 | # not issue: bar is missing, because it can be set by config file. 
52 | # TaskOnKart parameters: 53 | # - `complete_check_at_run` 54 | MyTask(foo='1', baz='not bool', complete_check_at_run='not bool') 55 | """ 56 | 57 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 58 | test_file.write(test_code.encode('utf-8')) 59 | test_file.flush() 60 | result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) 61 | self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) 62 | self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0]) 63 | self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0]) 64 | self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0]) 65 | 66 | def test_parameter_has_default_type_invalid_pattern(self): 67 | """ 68 | If user doesn't set the type of the parameter, mypy infer the default type from Parameter types. 
69 | """ 70 | test_code = """ 71 | import enum 72 | import luigi 73 | import gokart 74 | 75 | 76 | class MyEnum(enum.Enum): 77 | FOO = enum.auto() 78 | 79 | class MyTask(gokart.TaskOnKart): 80 | foo = luigi.IntParameter() 81 | bar = luigi.DateParameter() 82 | baz = gokart.TaskInstanceParameter() 83 | qux = luigi.NumericalParameter(var_type=int) 84 | quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int) 85 | corge = luigi.EnumParameter(enum=MyEnum) 86 | 87 | MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1) 88 | """ 89 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 90 | test_file.write(test_code.encode('utf-8')) 91 | test_file.flush() 92 | result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) 93 | self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) 94 | self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', result[0]) 95 | self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', result[0]) 96 | self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) 97 | self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) 98 | self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', result[0]) 99 | 100 | def test_parameter_has_default_type_no_issue_pattern(self): 101 | """ 102 | If user doesn't set the type of the parameter, mypy infer the default type from Parameter types. 
103 | """ 104 | test_code = """ 105 | from datetime import date 106 | import luigi 107 | import gokart 108 | 109 | class MyTask(gokart.TaskOnKart): 110 | foo = luigi.IntParameter() 111 | bar = luigi.DateParameter() 112 | baz = gokart.TaskInstanceParameter() 113 | 114 | MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart()) 115 | """ 116 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 117 | test_file.write(test_code.encode('utf-8')) 118 | test_file.flush() 119 | result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) 120 | self.assertIn('Success: no issues found', result[0]) 121 | 122 | def test_no_issue_found_when_missing_parameter_when_default_option(self): 123 | """ 124 | If `disallow_missing_parameters` is False (or default), mypy doesn't show any error when missing parameters. 125 | """ 126 | test_code = """ 127 | import luigi 128 | import gokart 129 | 130 | class MyTask(gokart.TaskOnKart): 131 | foo = luigi.IntParameter() 132 | bar = luigi.Parameter(default="bar") 133 | 134 | MyTask() 135 | """ 136 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 137 | test_file.write(test_code.encode('utf-8')) 138 | test_file.flush() 139 | result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) 140 | self.assertIn('Success: no issues found', result[0]) 141 | 142 | def test_issue_found_when_missing_parameter_when_disallow_missing_parameters_set_true(self): 143 | """ 144 | If `disallow_missing_parameters` is True, mypy shows an error when missing parameters. 145 | """ 146 | test_code = """ 147 | import luigi 148 | import gokart 149 | 150 | class MyTask(gokart.TaskOnKart): 151 | # issue: foo is missing 152 | foo = luigi.IntParameter() 153 | # bar has default value, so it is not required to set it. 
154 | bar = luigi.Parameter(default="bar") 155 | 156 | MyTask() 157 | """ 158 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 159 | test_file.write(test_code.encode('utf-8')) 160 | test_file.flush() 161 | result = api.run( 162 | [ 163 | '--show-traceback', 164 | '--no-incremental', 165 | '--cache-dir=/dev/null', 166 | '--config-file', 167 | str(PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS), 168 | test_file.name, 169 | ] 170 | ) 171 | self.assertIn('error: Missing named argument "foo" for "MyTask" [call-arg]', result[0]) 172 | self.assertIn('Found 1 error in 1 file (checked 1 source file)', result[0]) 173 | -------------------------------------------------------------------------------- /test/test_pandas_type_check_framework.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import unittest 5 | from logging import getLogger 6 | from typing import Any 7 | from unittest.mock import patch 8 | 9 | import luigi 10 | import pandas as pd 11 | from luigi.mock import MockFileSystem, MockTarget 12 | 13 | import gokart 14 | from gokart.build import GokartBuildError 15 | from gokart.pandas_type_config import PandasTypeConfig 16 | 17 | logger = getLogger(__name__) 18 | 19 | 20 | class TestPandasTypeConfig(PandasTypeConfig): 21 | task_namespace = 'test_pandas_type_check_framework' 22 | 23 | @classmethod 24 | def type_dict(cls) -> dict[str, Any]: 25 | return {'system_cd': int} 26 | 27 | 28 | class _DummyFailTask(gokart.TaskOnKart): 29 | task_namespace = 'test_pandas_type_check_framework' 30 | rerun = True 31 | 32 | def output(self): 33 | return self.make_target('dummy.pkl') 34 | 35 | def run(self): 36 | df = pd.DataFrame(dict(system_cd=['1'])) 37 | self.dump(df) 38 | 39 | 40 | class _DummyFailWithNoneTask(gokart.TaskOnKart): 41 | task_namespace = 'test_pandas_type_check_framework' 42 | rerun = True 43 | 44 | def output(self): 45 | return 
self.make_target('dummy.pkl') 46 | 47 | def run(self): 48 | df = pd.DataFrame(dict(system_cd=[1, None])) 49 | self.dump(df) 50 | 51 | 52 | class _DummySuccessTask(gokart.TaskOnKart): 53 | task_namespace = 'test_pandas_type_check_framework' 54 | rerun = True 55 | 56 | def output(self): 57 | return self.make_target('dummy.pkl') 58 | 59 | def run(self): 60 | df = pd.DataFrame(dict(system_cd=[1])) 61 | self.dump(df) 62 | 63 | 64 | class TestPandasTypeCheckFramework(unittest.TestCase): 65 | def setUp(self) -> None: 66 | luigi.setup_logging.DaemonLogging._configured = False 67 | luigi.setup_logging.InterfaceLogging._configured = False 68 | MockFileSystem().clear() 69 | # same way as luigi https://github.com/spotify/luigi/blob/fe7ecf4acf7cf4c084bd0f32162c8e0721567630/test/helpers.py#L175 70 | self._stashed_reg = luigi.task_register.Register._get_reg() 71 | 72 | def tearDown(self) -> None: 73 | luigi.setup_logging.DaemonLogging._configured = False 74 | luigi.setup_logging.InterfaceLogging._configured = False 75 | luigi.task_register.Register._set_reg(self._stashed_reg) 76 | 77 | @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock']) 78 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) 79 | def test_fail_with_gokart_run(self): 80 | with self.assertRaises(SystemExit) as exit_code: 81 | gokart.run() 82 | self.assertNotEqual(exit_code.exception.code, 0) # raise Error 83 | 84 | def test_fail(self): 85 | with self.assertRaises(GokartBuildError): 86 | gokart.build(_DummyFailTask(), log_level=logging.CRITICAL) 87 | 88 | def test_fail_with_None(self): 89 | with self.assertRaises(GokartBuildError): 90 | gokart.build(_DummyFailWithNoneTask(), log_level=logging.CRITICAL) 91 | 92 | def test_success(self): 93 | gokart.build(_DummySuccessTask()) 94 | # no error 95 | -------------------------------------------------------------------------------- 
/test/test_pandas_type_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import date, datetime 4 | from typing import Any 5 | from unittest import TestCase 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from gokart import PandasTypeConfig 11 | from gokart.pandas_type_config import PandasTypeError 12 | 13 | 14 | class _DummyPandasTypeConfig(PandasTypeConfig): 15 | @classmethod 16 | def type_dict(cls) -> dict[str, Any]: 17 | return {'int_column': int, 'datetime_column': datetime, 'array_column': np.ndarray} 18 | 19 | 20 | class TestPandasTypeConfig(TestCase): 21 | def test_int_fail(self): 22 | df = pd.DataFrame(dict(int_column=['1'])) 23 | with self.assertRaises(PandasTypeError): 24 | _DummyPandasTypeConfig().check(df) 25 | 26 | def test_int_success(self): 27 | df = pd.DataFrame(dict(int_column=[1])) 28 | _DummyPandasTypeConfig().check(df) 29 | 30 | def test_datetime_fail(self): 31 | df = pd.DataFrame(dict(datetime_column=[date(2019, 1, 12)])) 32 | with self.assertRaises(PandasTypeError): 33 | _DummyPandasTypeConfig().check(df) 34 | 35 | def test_datetime_success(self): 36 | df = pd.DataFrame(dict(datetime_column=[datetime(2019, 1, 12, 0, 0, 0)])) 37 | _DummyPandasTypeConfig().check(df) 38 | 39 | def test_array_fail(self): 40 | df = pd.DataFrame(dict(array_column=[[1, 2]])) 41 | with self.assertRaises(PandasTypeError): 42 | _DummyPandasTypeConfig().check(df) 43 | 44 | def test_array_success(self): 45 | df = pd.DataFrame(dict(array_column=[np.array([1, 2])])) 46 | _DummyPandasTypeConfig().check(df) 47 | -------------------------------------------------------------------------------- /test/test_restore_task_by_id.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch 3 | 4 | import luigi 5 | import luigi.mock 6 | 7 | import gokart 8 | 9 | 10 | class 
_SubDummyTask(gokart.TaskOnKart): 11 | task_namespace = __name__ 12 | param = luigi.IntParameter() 13 | 14 | def run(self): 15 | self.dump('test') 16 | 17 | 18 | class _DummyTask(gokart.TaskOnKart): 19 | task_namespace = __name__ 20 | sub_task = gokart.TaskInstanceParameter() 21 | 22 | def output(self): 23 | return self.make_target('test.txt') 24 | 25 | def run(self): 26 | self.dump('test') 27 | 28 | 29 | class RestoreTaskByIDTest(unittest.TestCase): 30 | def setUp(self) -> None: 31 | luigi.mock.MockFileSystem().clear() 32 | 33 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs)) 34 | def test(self): 35 | task = _DummyTask(sub_task=_SubDummyTask(param=10)) 36 | luigi.build([task], local_scheduler=True, log_level='CRITICAL') 37 | 38 | unique_id = task.make_unique_id() 39 | restored = _DummyTask.restore(unique_id) 40 | self.assertTrue(task.make_unique_id(), restored.make_unique_id()) 41 | 42 | 43 | if __name__ == '__main__': 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /test/test_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from unittest.mock import MagicMock, patch 4 | 5 | import luigi 6 | import luigi.mock 7 | 8 | import gokart 9 | from gokart.run import _try_to_send_event_summary_to_slack 10 | 11 | 12 | class _DummyTask(gokart.TaskOnKart): 13 | task_namespace = __name__ 14 | param = luigi.Parameter() 15 | 16 | 17 | class RunTest(unittest.TestCase): 18 | def setUp(self): 19 | luigi.configuration.LuigiConfigParser._instance = None 20 | luigi.mock.MockFileSystem().clear() 21 | os.environ.clear() 22 | 23 | @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler']) 24 | def test_run(self): 25 | config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini') 26 | 
luigi.configuration.LuigiConfigParser.add_config_path(config_file_path) 27 | os.environ.setdefault('test_param', 'test') 28 | with self.assertRaises(SystemExit) as exit_code: 29 | gokart.run() 30 | self.assertEqual(exit_code.exception.code, 0) 31 | 32 | @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--log-level=CRITICAL', '--local-scheduler']) 33 | def test_run_with_undefined_environ(self): 34 | config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini') 35 | luigi.configuration.LuigiConfigParser.add_config_path(config_file_path) 36 | with self.assertRaises(luigi.parameter.MissingParameterException): 37 | gokart.run() 38 | 39 | @patch( 40 | 'sys.argv', 41 | new=[ 42 | 'main', 43 | '--tree-info-mode=simple', 44 | '--tree-info-output-path=tree.txt', 45 | f'{__name__}._DummyTask', 46 | '--param', 47 | 'test', 48 | '--log-level=CRITICAL', 49 | '--local-scheduler', 50 | ], 51 | ) 52 | @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs)) 53 | def test_run_tree_info(self): 54 | config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini') 55 | luigi.configuration.LuigiConfigParser.add_config_path(config_file_path) 56 | os.environ.setdefault('test_param', 'test') 57 | tree_info = gokart.tree_info(mode='simple', output_path='tree.txt') 58 | with self.assertRaises(SystemExit): 59 | gokart.run() 60 | self.assertTrue(gokart.make_tree_info(_DummyTask(param='test')), tree_info.output().load()) 61 | 62 | @patch('gokart.make_tree_info') 63 | def test_try_to_send_event_summary_to_slack(self, make_tree_info_mock: MagicMock): 64 | event_aggregator_mock = MagicMock() 65 | event_aggregator_mock.get_summury.return_value = f'{__name__}._DummyTask' 66 | event_aggregator_mock.get_event_list.return_value = f'{__name__}._DummyTask:[]' 67 | make_tree_info_mock.return_value = 'tree' 68 | 69 | def get_content(content: str, **kwargs): 70 | self.output = content 71 | 72 | slack_api_mock = 
MagicMock() 73 | slack_api_mock.send_snippet.side_effect = get_content 74 | 75 | cmdline_args = [f'{__name__}._DummyTask', '--param', 'test'] 76 | with patch('gokart.slack.SlackConfig.send_tree_info', True): 77 | _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args) 78 | expects = os.linesep.join(['===== Event List ====', event_aggregator_mock.get_event_list(), os.linesep, '==== Tree Info ====', 'tree']) 79 | 80 | results = self.output 81 | self.assertEqual(expects, results) 82 | 83 | cmdline_args = [f'{__name__}._DummyTask', '--param', 'test'] 84 | with patch('gokart.slack.SlackConfig.send_tree_info', False): 85 | _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args) 86 | expects = os.linesep.join( 87 | [ 88 | '===== Event List ====', 89 | event_aggregator_mock.get_event_list(), 90 | os.linesep, 91 | '==== Tree Info ====', 92 | 'Please add SlackConfig.send_tree_info to include tree-info', 93 | ] 94 | ) 95 | 96 | results = self.output 97 | self.assertEqual(expects, results) 98 | 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /test/test_s3_config.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from gokart.s3_config import S3Config 4 | 5 | 6 | class TestS3Config(unittest.TestCase): 7 | def test_get_same_s3_client(self): 8 | client_a = S3Config().get_s3_client() 9 | client_b = S3Config().get_s3_client() 10 | 11 | self.assertEqual(client_a, client_b) 12 | -------------------------------------------------------------------------------- /test/test_s3_zip_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import unittest 4 | 5 | import boto3 6 | from moto import mock_aws 7 | 8 | from gokart.s3_zip_client import S3ZipClient 9 | from test.util import 
_get_temporary_directory 10 | 11 | 12 | class TestS3ZipClient(unittest.TestCase): 13 | def setUp(self): 14 | self.temporary_directory = _get_temporary_directory() 15 | 16 | def tearDown(self): 17 | shutil.rmtree(self.temporary_directory, ignore_errors=True) 18 | 19 | # remove temporary zip archive if exists. 20 | if os.path.exists(f'{self.temporary_directory}.zip'): 21 | os.remove(f'{self.temporary_directory}.zip') 22 | 23 | @mock_aws 24 | def test_make_archive(self): 25 | conn = boto3.resource('s3', region_name='us-east-1') 26 | conn.create_bucket(Bucket='test') 27 | 28 | file_path = os.path.join('s3://test/', 'test.zip') 29 | temporary_directory = self.temporary_directory 30 | 31 | zip_client = S3ZipClient(file_path=file_path, temporary_directory=temporary_directory) 32 | # raise error if temporary directory does not exist. 33 | with self.assertRaises(FileNotFoundError): 34 | zip_client.make_archive() 35 | 36 | # run without error because temporary directory exists. 37 | os.makedirs(temporary_directory, exist_ok=True) 38 | zip_client.make_archive() 39 | 40 | @mock_aws 41 | def test_unpack_archive(self): 42 | conn = boto3.resource('s3', region_name='us-east-1') 43 | conn.create_bucket(Bucket='test') 44 | 45 | file_path = os.path.join('s3://test/', 'test.zip') 46 | in_temporary_directory = os.path.join(self.temporary_directory, 'in', 'dummy') 47 | out_temporary_directory = os.path.join(self.temporary_directory, 'out', 'dummy') 48 | 49 | # make dummy zip file. 50 | os.makedirs(in_temporary_directory, exist_ok=True) 51 | in_zip_client = S3ZipClient(file_path=file_path, temporary_directory=in_temporary_directory) 52 | in_zip_client.make_archive() 53 | 54 | # load dummy zip file. 
55 | out_zip_client = S3ZipClient(file_path=file_path, temporary_directory=out_temporary_directory) 56 | self.assertFalse(os.path.exists(out_temporary_directory)) 57 | out_zip_client.unpack_archive() 58 | -------------------------------------------------------------------------------- /test/test_serializable_parameter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | from dataclasses import asdict, dataclass 4 | 5 | import luigi 6 | import pytest 7 | from luigi.cmdline_parser import CmdlineParser 8 | from mypy import api 9 | 10 | from gokart import SerializableParameter, TaskOnKart 11 | from test.config import PYPROJECT_TOML 12 | 13 | 14 | @dataclass(frozen=True) 15 | class Config: 16 | foo: int 17 | bar: str 18 | 19 | def gokart_serialize(self) -> str: 20 | # dict is ordered in Python 3.7+ 21 | return json.dumps(asdict(self)) 22 | 23 | @classmethod 24 | def gokart_deserialize(cls, s: str) -> 'Config': 25 | return cls(**json.loads(s)) 26 | 27 | 28 | class SerializableParameterWithOutDefault(TaskOnKart): 29 | task_namespace = __name__ 30 | config: Config = SerializableParameter(object_type=Config) 31 | 32 | def run(self): 33 | self.dump(self.config) 34 | 35 | 36 | class SerializableParameterWithDefault(TaskOnKart): 37 | task_namespace = __name__ 38 | config: Config = SerializableParameter(object_type=Config, default=Config(foo=1, bar='bar')) 39 | 40 | def run(self): 41 | self.dump(self.config) 42 | 43 | 44 | class TestSerializableParameter: 45 | def test_default(self): 46 | with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithDefault']) as cp: 47 | assert cp.get_task_obj().config == Config(foo=1, bar='bar') 48 | 49 | def test_parse_param(self): 50 | with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp: 51 | assert cp.get_task_obj().config == Config(foo=100, bar='val') 52 | 53 | def 
test_missing_parameter(self): 54 | with pytest.raises(luigi.parameter.MissingParameterException): 55 | with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault']) as cp: 56 | cp.get_task_obj() 57 | 58 | def test_value_error(self): 59 | with pytest.raises(ValueError): 60 | with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']) as cp: 61 | cp.get_task_obj() 62 | 63 | def test_expected_one_argument_error(self): 64 | with pytest.raises(SystemExit): 65 | with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config']) as cp: 66 | cp.get_task_obj() 67 | 68 | def test_mypy(self): 69 | """check invalid object cannot used for SerializableParameter""" 70 | 71 | test_code = """ 72 | import gokart 73 | 74 | class InvalidClass: 75 | ... 76 | 77 | gokart.SerializableParameter(object_type=InvalidClass) 78 | """ 79 | with tempfile.NamedTemporaryFile(suffix='.py') as test_file: 80 | test_file.write(test_code.encode('utf-8')) 81 | test_file.flush() 82 | result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) 83 | assert 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass" [type-var]' in result[0] 84 | -------------------------------------------------------------------------------- /test/test_task_instance_parameter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import luigi 4 | 5 | import gokart 6 | from gokart import ListTaskInstanceParameter, TaskInstanceParameter, TaskOnKart 7 | 8 | 9 | class _DummySubTask(TaskOnKart): 10 | task_namespace = __name__ 11 | pass 12 | 13 | 14 | class _DummyCorrectSubClassTask(_DummySubTask): 15 | task_namespace = __name__ 16 | pass 17 | 18 | 19 | class _DummyInvalidSubClassTask(TaskOnKart): 20 | task_namespace = __name__ 21 | pass 22 | 23 | 24 | class _DummyTask(TaskOnKart): 25 | 
task_namespace = __name__ 26 | param = luigi.IntParameter() 27 | task = TaskInstanceParameter(default=_DummySubTask()) 28 | 29 | 30 | class _DummyListTask(TaskOnKart): 31 | task_namespace = __name__ 32 | param = luigi.IntParameter() 33 | task = ListTaskInstanceParameter(default=[_DummySubTask(), _DummySubTask()]) 34 | 35 | 36 | class TaskInstanceParameterTest(unittest.TestCase): 37 | def setUp(self): 38 | _DummyTask.clear_instance_cache() 39 | 40 | def test_serialize_and_parse(self): 41 | original = _DummyTask(param=2) 42 | s = gokart.TaskInstanceParameter().serialize(original) 43 | parsed = gokart.TaskInstanceParameter().parse(s) 44 | self.assertEqual(parsed.task_id, original.task_id) 45 | 46 | def test_serialize_and_parse_list_params(self): 47 | original = _DummyListTask(param=2) 48 | s = gokart.TaskInstanceParameter().serialize(original) 49 | parsed = gokart.TaskInstanceParameter().parse(s) 50 | self.assertEqual(parsed.task_id, original.task_id) 51 | 52 | def test_invalid_class(self): 53 | self.assertRaises(TypeError, lambda: gokart.TaskInstanceParameter(expected_type=1)) # not type instance 54 | 55 | def test_params_with_correct_param_type(self): 56 | class _DummyPipelineA(TaskOnKart): 57 | task_namespace = __name__ 58 | subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask) 59 | 60 | task = _DummyPipelineA(subtask=_DummyCorrectSubClassTask()) 61 | self.assertEqual(task.requires()['subtask'], _DummyCorrectSubClassTask()) # type: ignore 62 | 63 | def test_params_with_invalid_param_type(self): 64 | class _DummyPipelineB(TaskOnKart): 65 | task_namespace = __name__ 66 | subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask) 67 | 68 | with self.assertRaises(TypeError): 69 | _DummyPipelineB(subtask=_DummyInvalidSubClassTask()) 70 | 71 | 72 | class ListTaskInstanceParameterTest(unittest.TestCase): 73 | def setUp(self): 74 | _DummyTask.clear_instance_cache() 75 | 76 | def test_invalid_class(self): 77 | self.assertRaises(TypeError, lambda: 
gokart.ListTaskInstanceParameter(expected_elements_type=1)) # not type instance 78 | 79 | def test_list_params_with_correct_param_types(self): 80 | class _DummyPipelineC(TaskOnKart): 81 | task_namespace = __name__ 82 | subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) 83 | 84 | task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) 85 | self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(),)) # type: ignore 86 | 87 | def test_list_params_with_invalid_param_types(self): 88 | class _DummyPipelineD(TaskOnKart): 89 | task_namespace = __name__ 90 | subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) 91 | 92 | with self.assertRaises(TypeError): 93 | _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()]) 94 | 95 | 96 | if __name__ == '__main__': 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from gokart.utils import flatten, map_flattenable_items 4 | 5 | 6 | class TestFlatten(unittest.TestCase): 7 | def test_flatten_dict(self): 8 | self.assertEqual(flatten({'a': 'foo', 'b': 'bar'}), ['foo', 'bar']) 9 | 10 | def test_flatten_list(self): 11 | self.assertEqual(flatten(['foo', ['bar', 'troll']]), ['foo', 'bar', 'troll']) 12 | 13 | def test_flatten_str(self): 14 | self.assertEqual(flatten('foo'), ['foo']) 15 | 16 | def test_flatten_int(self): 17 | self.assertEqual(flatten(42), [42]) 18 | 19 | def test_flatten_none(self): 20 | self.assertEqual(flatten(None), []) 21 | 22 | 23 | class TestMapFlatten(unittest.TestCase): 24 | def test_map_flattenable_items(self): 25 | self.assertEqual(map_flattenable_items(lambda x: str(x), {'a': 1, 'b': 2}), {'a': '1', 'b': '2'}) 26 | self.assertEqual( 27 | map_flattenable_items(lambda x: str(x), (1, 2, 3, (4, 5, (6, 7, {'a': (8, 9, 0)})))), 28 | 
('1', '2', '3', ('4', '5', ('6', '7', {'a': ('8', '9', '0')}))), 29 | ) 30 | self.assertEqual( 31 | map_flattenable_items( 32 | lambda x: str(x), 33 | {'a': [1, 2, 3, '4'], 'b': {'c': True, 'd': {'e': 5}}}, 34 | ), 35 | {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}}, 36 | ) 37 | -------------------------------------------------------------------------------- /test/test_worker.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from unittest.mock import Mock 3 | 4 | import luigi 5 | import luigi.worker 6 | import pytest 7 | from luigi import scheduler 8 | 9 | import gokart 10 | from gokart.worker import Worker, gokart_worker 11 | 12 | 13 | class _DummyTask(gokart.TaskOnKart): 14 | task_namespace = __name__ 15 | random_id = luigi.Parameter() 16 | 17 | def _run(self): ... 18 | 19 | def run(self): 20 | self._run() 21 | self.dump('test') 22 | 23 | 24 | class TestWorkerRun: 25 | def test_run(self, monkeypatch: pytest.MonkeyPatch): 26 | """Check run is called when the task is not completed""" 27 | sch = scheduler.Scheduler() 28 | worker = Worker(scheduler=sch) 29 | 30 | task = _DummyTask(random_id=uuid.uuid4().hex) 31 | mock_run = Mock() 32 | monkeypatch.setattr(task, '_run', mock_run) 33 | with worker: 34 | assert worker.add(task) 35 | assert worker.run() 36 | mock_run.assert_called_once() 37 | 38 | 39 | class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]): 40 | task_namespace = __name__ 41 | 42 | def _run(self): ... 
43 | 44 | def run(self): 45 | self._run() 46 | self.dump(None) 47 | 48 | def complete(self) -> bool: 49 | return False 50 | 51 | 52 | class TestWorkerSkipIfCompletedPreRun: 53 | @pytest.mark.parametrize( 54 | 'task_completion_check_at_run,is_completed,expect_skipped', 55 | [ 56 | pytest.param(True, True, True, id='skipped when completed and task_completion_check_at_run is True'), 57 | pytest.param(True, False, False, id='not skipped when not completed and task_completion_check_at_run is True'), 58 | pytest.param(False, True, False, id='not skipped when completed and task_completion_check_at_run is False'), 59 | pytest.param(False, False, False, id='not skipped when not completed and task_completion_check_at_run is False'), 60 | ], 61 | ) 62 | def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, task_completion_check_at_run: bool, is_completed: bool, expect_skipped: bool): 63 | sch = scheduler.Scheduler() 64 | worker = Worker(scheduler=sch, config=gokart_worker(task_completion_check_at_run=task_completion_check_at_run)) 65 | 66 | mock_complete = Mock(return_value=is_completed) 67 | # NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic. 68 | task = _DummyTaskToCheckSkip(complete_check_at_run=False) 69 | mock_run = Mock() 70 | monkeypatch.setattr(task, '_run', mock_run) 71 | 72 | with worker: 73 | assert worker.add(task) 74 | # NOTE: mock `complete` after `add` because `add` calls `complete` 75 | # to check if the task is already completed. 
76 | monkeypatch.setattr(task, 'complete', mock_complete) 77 | assert worker.run() 78 | 79 | if expect_skipped: 80 | mock_run.assert_not_called() 81 | else: 82 | mock_run.assert_called_once() 83 | -------------------------------------------------------------------------------- /test/test_zoned_date_second_parameter.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import unittest 3 | 4 | from luigi.cmdline_parser import CmdlineParser 5 | 6 | from gokart import TaskOnKart, ZonedDateSecondParameter 7 | 8 | 9 | class ZonedDateSecondParameterTaskWithoutDefault(TaskOnKart): 10 | task_namespace = __name__ 11 | dt: datetime.datetime = ZonedDateSecondParameter() 12 | 13 | def run(self): 14 | self.dump(self.dt) 15 | 16 | 17 | class ZonedDateSecondParameterTaskWithDefault(TaskOnKart): 18 | task_namespace = __name__ 19 | dt: datetime.datetime = ZonedDateSecondParameter(default=datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9)))) 20 | 21 | def run(self): 22 | self.dump(self.dt) 23 | 24 | 25 | class ZonedDateSecondParameterTest(unittest.TestCase): 26 | def setUp(self): 27 | self.default_datetime = datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9))) 28 | self.default_datetime_str = '2025-02-21T12:00:00+09:00' 29 | 30 | def test_default(self): 31 | with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault']) as cp: 32 | assert cp.get_task_obj().dt == self.default_datetime 33 | 34 | def test_parse_param_with_tz_suffix(self): 35 | with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00+09:00']) as cp: 36 | assert cp.get_task_obj().dt == datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9))) 37 | 38 | def test_parse_param_with_Z_suffix(self): 39 | with 
CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00Z']) as cp: 40 | assert cp.get_task_obj().dt == datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=0))) 41 | 42 | def test_parse_param_without_timezone_input(self): 43 | with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithoutDefault', '--dt', '2025-02-21T12:00:00']) as cp: 44 | assert cp.get_task_obj().dt == datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=None) 45 | 46 | def test_parse_method(self): 47 | actual = ZonedDateSecondParameter().parse(self.default_datetime_str) 48 | expected = self.default_datetime 49 | self.assertEqual(actual, expected) 50 | 51 | def test_serialize_task(self): 52 | task = ZonedDateSecondParameterTaskWithoutDefault(dt=self.default_datetime) 53 | actual = str(task) 54 | expected = f'(dt={self.default_datetime_str})' 55 | self.assertTrue(actual.endswith(expected)) 56 | 57 | 58 | if __name__ == '__main__': 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /test/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/testing/__init__.py -------------------------------------------------------------------------------- /test/testing/test_pandas_assert.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pandas as pd 4 | 5 | import gokart 6 | 7 | 8 | class TestPandasAssert(unittest.TestCase): 9 | def test_assert_frame_contents_equal(self): 10 | expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2]) 11 | resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2]) 12 | 13 | gokart.testing.assert_frame_contents_equal(resulted, 
expected) 14 | 15 | def test_assert_frame_contents_equal_with_small_error(self): 16 | expected = pd.DataFrame(data=dict(f1=[1.0001, 2.0001, 3.0001], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2]) 17 | resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2.0002, 1.0002, 3.0002], f3=[222, 111, 333]), index=[1, 0, 2]) 18 | 19 | gokart.testing.assert_frame_contents_equal(resulted, expected, atol=1e-1) 20 | 21 | def test_assert_frame_contents_equal_with_duplicated_columns(self): 22 | expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2]) 23 | expected.columns = ['f1', 'f1', 'f2'] 24 | resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2]) 25 | resulted.columns = ['f2', 'f1', 'f1'] 26 | 27 | with self.assertRaises(AssertionError): 28 | gokart.testing.assert_frame_contents_equal(resulted, expected) 29 | 30 | def test_assert_frame_contents_equal_with_duplicated_indexes(self): 31 | expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2]) 32 | expected.index = [0, 1, 1] 33 | resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2]) 34 | expected.index = [1, 0, 1] 35 | 36 | with self.assertRaises(AssertionError): 37 | gokart.testing.assert_frame_contents_equal(resulted, expected) 38 | -------------------------------------------------------------------------------- /test/tree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/tree/__init__.py -------------------------------------------------------------------------------- /test/tree/test_task_info_formatter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import gokart 4 | from gokart.tree.task_info_formatter import RequiredTask, _make_requires_info 5 | 6 
| 7 | class _RequiredTaskExampleTaskA(gokart.TaskOnKart): 8 | task_namespace = __name__ 9 | 10 | 11 | class TestMakeRequiresInfo(unittest.TestCase): 12 | def test_make_requires_info_with_task_on_kart(self): 13 | requires = _RequiredTaskExampleTaskA() 14 | resulted = _make_requires_info(requires=requires) 15 | expected = RequiredTask(name=requires.__class__.__name__, unique_id=requires.make_unique_id()) 16 | self.assertEqual(resulted, expected) 17 | 18 | def test_make_requires_info_with_list(self): 19 | requires = [_RequiredTaskExampleTaskA()] 20 | resulted = _make_requires_info(requires=requires) 21 | expected = [RequiredTask(name=require.__class__.__name__, unique_id=require.make_unique_id()) for require in requires] 22 | self.assertEqual(resulted, expected) 23 | 24 | def test_make_requires_info_with_generator(self): 25 | def _requires_gen(): 26 | return (_RequiredTaskExampleTaskA() for _ in range(2)) 27 | 28 | resulted = _make_requires_info(requires=_requires_gen()) 29 | expected = [RequiredTask(name=require.__class__.__name__, unique_id=require.make_unique_id()) for require in _requires_gen()] 30 | self.assertEqual(resulted, expected) 31 | 32 | def test_make_requires_info_with_dict(self): 33 | requires = dict(taskA=_RequiredTaskExampleTaskA()) 34 | resulted = _make_requires_info(requires=requires) 35 | expected = {key: RequiredTask(name=require.__class__.__name__, unique_id=require.make_unique_id()) for key, require in requires.items()} 36 | self.assertEqual(resulted, expected) 37 | 38 | def test_make_requires_info_with_invalid(self): 39 | requires = [1, 2] 40 | with self.assertRaises(TypeError): 41 | _make_requires_info(requires=requires) 42 | -------------------------------------------------------------------------------- /test/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | 4 | 5 | # TODO: use pytest.fixture to share this functionality with other tests 6 | def 
_get_temporary_directory(): 7 | _uuid = str(uuid.uuid4()) 8 | return os.path.abspath(os.path.join(os.path.dirname(__name__), f'temporary-{_uuid}')) 9 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py{39,310,311,312,313},ruff,mypy 3 | skipsdist = True 4 | 5 | [testenv] 6 | runner = uv-venv-lock-runner 7 | dependency_groups = test 8 | commands = 9 | {envpython} -m pytest --cov=gokart --cov-report=xml -vv {posargs:} 10 | 11 | [testenv:ruff] 12 | dependency_groups = lint 13 | commands = 14 | ruff check {posargs:} 15 | ruff format --check {posargs:} 16 | 17 | [testenv:mypy] 18 | dependency_groups = lint 19 | commands = mypy gokart test {posargs:} 20 | --------------------------------------------------------------------------------