├── .github
├── CODEOWNERS
└── workflows
│ ├── format.yml
│ ├── publish.yml
│ └── test.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
├── Makefile
├── conf.py
├── efficient_run_on_multi_workers.rst
├── for_pandas.rst
├── gokart.rst
├── gokart_logo_side_isolation.svg
├── index.rst
├── intro_to_gokart.rst
├── logging.rst
├── make.bat
├── mypy_plugin.rst
├── requirements.txt
├── setting_task_parameters.rst
├── slack_notification.rst
├── task_information.rst
├── task_on_kart.rst
├── task_parameters.rst
├── task_settings.rst
├── tutorial.rst
└── using_task_task_conflict_prevention_lock.rst
├── examples
├── gokart_notebook_example.ipynb
├── logging.ini
└── param.ini
├── gokart
├── __init__.py
├── build.py
├── build_process_task_info.py
├── config_params.py
├── conflict_prevention_lock
│ ├── task_lock.py
│ └── task_lock_wrappers.py
├── errors
│ └── __init__.py
├── file_processor.py
├── gcs_config.py
├── gcs_obj_metadata_client.py
├── gcs_zip_client.py
├── in_memory
│ ├── __init__.py
│ ├── data.py
│ ├── repository.py
│ └── target.py
├── info.py
├── mypy.py
├── object_storage.py
├── pandas_type_config.py
├── parameter.py
├── py.typed
├── required_task_output.py
├── run.py
├── s3_config.py
├── s3_zip_client.py
├── slack
│ ├── __init__.py
│ ├── event_aggregator.py
│ ├── slack_api.py
│ └── slack_config.py
├── target.py
├── task.py
├── task_complete_check.py
├── testing
│ ├── __init__.py
│ ├── check_if_run_with_empty_data_frame.py
│ └── pandas_assert.py
├── tree
│ ├── task_info.py
│ └── task_info_formatter.py
├── utils.py
├── worker.py
├── workspace_management.py
├── zip_client.py
└── zip_client_util.py
├── luigi.cfg
├── pyproject.toml
├── test
├── __init__.py
├── config
│ ├── __init__.py
│ ├── pyproject.toml
│ ├── pyproject_disallow_missing_parameters.toml
│ └── test_config.ini
├── conflict_prevention_lock
│ ├── __init__.py
│ ├── test_task_lock.py
│ └── test_task_lock_wrappers.py
├── in_memory
│ ├── test_in_memory_target.py
│ └── test_repository.py
├── slack
│ ├── __init__.py
│ └── test_slack_api.py
├── test_build.py
├── test_cache_unique_id.py
├── test_config_params.py
├── test_explicit_bool_parameter.py
├── test_file_processor.py
├── test_gcs_config.py
├── test_gcs_obj_metadata_client.py
├── test_info.py
├── test_large_data_fram_processor.py
├── test_list_task_instance_parameter.py
├── test_mypy.py
├── test_pandas_type_check_framework.py
├── test_pandas_type_config.py
├── test_restore_task_by_id.py
├── test_run.py
├── test_s3_config.py
├── test_s3_zip_client.py
├── test_serializable_parameter.py
├── test_target.py
├── test_task_instance_parameter.py
├── test_task_on_kart.py
├── test_utils.py
├── test_worker.py
├── test_zoned_date_second_parameter.py
├── testing
│ ├── __init__.py
│ └── test_pandas_assert.py
├── tree
│ ├── __init__.py
│ ├── test_task_info.py
│ └── test_task_info_formatter.py
└── util.py
├── tox.ini
└── uv.lock
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @Hi-king
2 | * @yokomotod
3 | * @hirosassa
4 | * @mski-iksm
5 | * @kitagry
6 | * @ujiuji1259
7 | * @mamo3gr
8 | * @hiro-o918
9 |
--------------------------------------------------------------------------------
/.github/workflows/format.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 |
8 |
9 | jobs:
10 | formatting-check:
11 |
12 | name: Lint
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up the latest version of uv
18 | uses: astral-sh/setup-uv@v5
19 | with:
20 | enable-cache: true
21 | - name: Install dependencies
22 | run: |
23 | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv
24 | - name: Run ruff and mypy
25 | run: |
26 | uvx --with tox-uv tox run -e ruff,mypy
27 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | push:
5 | tags: '*'
6 |
7 | jobs:
8 | deploy:
9 |
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Set up the latest version of uv
15 | uses: astral-sh/setup-uv@v5
16 | with:
17 | enable-cache: true
18 | - name: Build and publish
19 | env:
20 | UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
21 | run: |
22 | uv build
23 | uv publish
24 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 |
8 | jobs:
9 | tests:
10 | runs-on: ${{ matrix.platform }}
11 | strategy:
12 | max-parallel: 7
13 | matrix:
14 | platform: ["ubuntu-latest"]
15 | tox-env: ["py39", "py310", "py311", "py312", "py313"]
16 | include:
17 | - platform: macos-13
18 | tox-env: "py313"
19 | - platform: macos-latest
20 | tox-env: "py313"
21 | steps:
22 | - uses: actions/checkout@v4
23 | - name: Set up the latest version of uv
24 | uses: astral-sh/setup-uv@v5
25 | with:
26 | enable-cache: true
27 | - name: Install dependencies
28 | run: |
29 | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv
30 | - name: Test with tox
31 | run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # pycharm
107 | .idea
108 |
109 | # gokart
110 | resources
111 | examples/resources
112 |
113 | # poetry
114 | dist
115 |
116 | # temporary data
117 | temporary.zip
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file for Sphinx projects
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version and other tools you might need
8 | build:
9 | os: ubuntu-24.04
10 | tools:
11 | python: "3.12"
12 |
13 | # Build from the docs/ directory with Sphinx
14 | sphinx:
15 | configuration: docs/conf.py
16 |
17 | # Optional but recommended, declare the Python requirements required
18 | # to build your documentation
19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
20 | python:
21 | install:
22 | - requirements: docs/requirements.txt
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 M3, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gokart
2 |
3 |
4 |
5 |
6 |
7 | [](https://github.com/m3dev/gokart/actions?query=workflow%3ATest)
8 | [](https://gokart.readthedocs.io/en/latest/)
9 | [](https://pypi.org/project/gokart/)
10 | [](https://pypi.org/project/gokart/)
11 | 
12 |
13 | Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline.
14 |
15 |
16 | [Documentation](https://gokart.readthedocs.io/en/latest/) for the latest release is hosted on readthedocs.
17 |
18 |
19 | # About gokart
20 |
21 | Here are some good things about gokart.
22 |
23 | - The following meta data for each Task is stored separately in a `pkl` file with hash value
24 | - task output data
25 | - imported all module versions
26 | - task processing time
27 | - random seed in task
28 | - displayed log
29 | - all parameters set as class variables in the task
30 | - Automatically rerun the pipeline if parameters of Tasks are changed.
31 | - Support GCS and S3 as a data store for intermediate results of Tasks in the pipeline.
32 | - The above output is exchanged between tasks as an intermediate file, which is memory-friendly
33 | - `pandas.DataFrame` type and column checking during I/O
34 | - Directory structure of saved files is automatically determined from structure of script
35 | - Seeds for numpy and random are automatically fixed
36 | - Can code while adhering to [SOLID](https://en.wikipedia.org/wiki/SOLID) principles as much as possible
37 | - Tasks are locked via redis even if they run in parallel
38 |
39 | **All the functions above were created for constructing Machine Learning batches. They provide an excellent environment for reproducibility and team development.**
40 |
41 |
42 | Here are some non-goals / downsides of gokart.
43 | - Batch execution in parallel is supported, but parallel and concurrent execution of tasks in memory is not.
44 | - Gokart is focused on reproducibility. So, I/O and capacity of data storage can become a bottleneck.
45 | - No support for task visualization.
46 | - Gokart is not an experiment management tool. The management of the execution result is cut out as [Thunderbolt](https://github.com/m3dev/thunderbolt).
47 | - Gokart does not recommend writing pipelines in toml, yaml, json, and so on. Gokart prefers that they be written in Python.
48 |
49 | # Getting Started
50 |
51 | Within the activated Python environment, use the following command to install gokart.
52 |
53 | ```
54 | pip install gokart
55 | ```
56 |
57 |
58 | # Quickstart
59 |
60 | ## Minimal Example
61 |
62 | A minimal gokart task looks something like this:
63 |
64 |
65 | ```python
66 | import gokart
67 |
68 | class Example(gokart.TaskOnKart):
69 | def run(self):
70 | self.dump('Hello, world!')
71 |
72 | task = Example()
73 | output = gokart.build(task)
74 | print(output)
75 | ```
76 |
77 | `gokart.build` returns the result dumped by `gokart.TaskOnKart`. The example will output the following.
78 |
79 |
80 | ```
81 | Hello, world!
82 | ```
83 |
84 | ## Type-Safe Pipeline Example
85 |
86 | We introduce type-annotations to make a gokart pipeline robust.
87 | Please check the following example to see how to use type-annotations on gokart.
88 | Before using this feature, make sure to enable the [mypy plugin](https://gokart.readthedocs.io/en/latest/mypy_plugin.html) in your project.
89 |
90 | ```python
91 | import gokart
92 |
93 | # `gokart.TaskOnKart[str]` means that the task dumps `str`
94 | class StrDumpTask(gokart.TaskOnKart[str]):
95 | def run(self):
96 | self.dump('Hello, world!')
97 |
98 | # `gokart.TaskOnKart[int]` means that the task dumps `int`
99 | class OneDumpTask(gokart.TaskOnKart[int]):
100 | def run(self):
101 | self.dump(1)
102 |
103 | # `gokart.TaskOnKart[int]` means that the task dumps `int`
104 | class TwoDumpTask(gokart.TaskOnKart[int]):
105 | def run(self):
106 | self.dump(2)
107 |
108 | class AddTask(gokart.TaskOnKart[int]):
109 | # `a` requires a task to dump `int`
110 | a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
111 | # `b` requires a task to dump `int`
112 | b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
113 |
114 | def requires(self):
115 | return dict(a=self.a, b=self.b)
116 |
117 | def run(self):
118 | # loading by instance parameter,
119 | # `a` and `b` are treated as `int`
120 | # because they are declared as `gokart.TaskOnKart[int]`
121 | a = self.load(self.a)
122 | b = self.load(self.b)
123 | self.dump(a + b)
124 |
125 |
126 | valid_task = AddTask(a=OneDumpTask(), b=TwoDumpTask())
127 | # the next line will show type error by mypy
128 | # because `StrDumpTask` dumps `str` and `AddTask` requires `int`
129 | invalid_task = AddTask(a=OneDumpTask(), b=StrDumpTask())
130 | ```
131 |
132 | This is an introduction to some of gokart's features.
133 | There are still more useful features.
134 |
135 | Please see the [Documentation](https://gokart.readthedocs.io/en/latest/).
136 |
137 | Have a good gokart life.
138 |
139 | # Achievements
140 |
141 | Gokart is a proven product.
142 |
143 | - It's actually been used by [m3.inc](https://corporate.m3.com/en) for over 3 years
144 | - Natural Language Processing Competition by [Nishika.inc](https://nishika.com) 2nd prize : [Solution Repository](https://github.com/vaaaaanquish/nishika_akutagawa_2nd_prize)
145 |
146 |
147 | # Thanks
148 |
149 | gokart is a wrapper for luigi. Thanks to luigi and dependent projects!
150 |
151 | - [luigi](https://github.com/spotify/luigi)
152 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # https://github.com/sphinx-doc/sphinx/issues/6211
2 | import luigi
3 |
4 | import gokart
5 |
6 | luigi.task.Task.requires.__doc__ = gokart.task.TaskOnKart.requires.__doc__
7 | luigi.task.Task.output.__doc__ = gokart.task.TaskOnKart.output.__doc__
8 |
9 | #
10 | # Configuration file for the Sphinx documentation builder.
11 | #
12 | # This file does only contain a selection of the most common options. For a
13 | # full list see the documentation:
14 | # http://www.sphinx-doc.org/en/master/config
15 |
16 | # -- Path setup --------------------------------------------------------------
17 |
18 | # If extensions (or modules to document with autodoc) are in another directory,
19 | # add these directories to sys.path here. If the directory is relative to the
20 | # documentation root, use os.path.abspath to make it absolute, like shown here.
21 |
22 | # import os
23 | # import sys
24 | # sys.path.insert(0, os.path.abspath('../gokart/'))
25 |
26 | # -- Project information -----------------------------------------------------
27 |
28 | project = 'gokart'
29 | copyright = '2019, Masahiro Nishiba'
30 | author = 'Masahiro Nishiba'
31 |
32 | # The short X.Y version
33 | version = ''
34 | # The full version, including alpha/beta/rc tags
35 | release = ''
36 |
37 | # -- General configuration ---------------------------------------------------
38 |
39 | # If your documentation needs a minimal Sphinx version, state it here.
40 | #
41 | # needs_sphinx = '1.0'
42 |
43 | # Add any Sphinx extension module names here, as strings. They can be
44 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
45 | # ones.
46 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode']
47 |
48 | # Add any paths that contain templates here, relative to this directory.
49 | templates_path = ['_templates']
50 |
51 | # The suffix(es) of source filenames.
52 | # You can specify multiple suffix as a list of string:
53 | #
54 | # source_suffix = ['.rst', '.md']
55 | source_suffix = '.rst'
56 |
57 | # The master toctree document.
58 | master_doc = 'index'
59 |
60 | # The language for content autogenerated by Sphinx. Refer to documentation
61 | # for a list of supported languages.
62 | #
63 | # This is also used if you do content translation via gettext catalogs.
64 | # Usually you set "language" from the command line for these cases.
65 | language = None
66 |
67 | # List of patterns, relative to source directory, that match files and
68 | # directories to ignore when looking for source files.
69 | # This pattern also affects html_static_path and html_extra_path.
70 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
71 |
72 | # The name of the Pygments (syntax highlighting) style to use.
73 | pygments_style = None
74 |
75 | # -- Options for HTML output -------------------------------------------------
76 |
77 | # The theme to use for HTML and HTML Help pages. See the documentation for
78 | # a list of builtin themes.
79 | #
80 | html_theme = 'sphinx_rtd_theme'
81 |
82 | # Theme options are theme-specific and customize the look and feel of a theme
83 | # further. For a list of options available for each theme, see the
84 | # documentation.
85 | #
86 | # html_theme_options = {}
87 |
88 | # Add any paths that contain custom static files (such as style sheets) here,
89 | # relative to this directory. They are copied after the builtin static files,
90 | # so a file named "default.css" will overwrite the builtin "default.css".
91 | html_static_path = []
92 |
93 | # Custom sidebar templates, must be a dictionary that maps document names
94 | # to template names.
95 | #
96 | # The default sidebars (for documents that don't match any pattern) are
97 | # defined by theme itself. Builtin themes are using these templates by
98 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
99 | # 'searchbox.html']``.
100 |
101 | # html_sidebars = {}
102 |
103 | # -- Options for HTMLHelp output ---------------------------------------------
104 |
105 | # Output file base name for HTML help builder.
106 | htmlhelp_basename = 'gokartdoc'
107 |
108 | # -- Options for LaTeX output ------------------------------------------------
109 |
110 | latex_elements = {
111 | # The paper size ('letterpaper' or 'a4paper').
112 | #
113 | # 'papersize': 'letterpaper',
114 | # The font size ('10pt', '11pt' or '12pt').
115 | #
116 | # 'pointsize': '10pt',
117 | # Additional stuff for the LaTeX preamble.
118 | #
119 | # 'preamble': '',
120 | # Latex figure (float) alignment
121 | #
122 | # 'figure_align': 'htbp',
123 | }
124 |
125 | # Grouping the document tree into LaTeX files. List of tuples
126 | # (source start file, target name, title,
127 | # author, documentclass [howto, manual, or own class]).
128 | latex_documents = [
129 | (master_doc, 'gokart.tex', 'gokart Documentation', 'Masahiro Nishiba', 'manual'),
130 | ]
131 |
132 | # -- Options for manual page output ------------------------------------------
133 |
134 | # One entry per manual page. List of tuples
135 | # (source start file, name, description, authors, manual section).
136 | man_pages = [(master_doc, 'gokart', 'gokart Documentation', [author], 1)]
137 |
138 | # -- Options for Texinfo output ----------------------------------------------
139 |
140 | # Grouping the document tree into Texinfo files. List of tuples
141 | # (source start file, target name, title, author,
142 | # dir menu entry, description, category)
143 | texinfo_documents = [
144 | (master_doc, 'gokart', 'gokart Documentation', author, 'gokart', 'One line description of project.', 'Miscellaneous'),
145 | ]
146 |
147 | # -- Options for Epub output -------------------------------------------------
148 |
149 | # Bibliographic Dublin Core info.
150 | epub_title = project
151 |
152 | # The unique identifier of the text. This can be a ISBN number
153 | # or the project homepage.
154 | #
155 | # epub_identifier = ''
156 |
157 | # A unique identification for the text.
158 | #
159 | # epub_uid = ''
160 |
161 | # A list of files that should not be packed into the epub file.
162 | epub_exclude_files = ['search.html']
163 |
--------------------------------------------------------------------------------
/docs/efficient_run_on_multi_workers.rst:
--------------------------------------------------------------------------------
1 | How to improve efficiency when running on multiple workers
2 | ===========================================================
3 |
4 | If multiple worker nodes are running similar gokart pipelines in parallel, it is possible that the exact same task may be executed by multiple workers.
5 | (For example, when training multiple machine learning models with different parameters, the feature creation task in the first stage is expected to be exactly the same.)
6 |
7 | It is inefficient to execute the same task on each of multiple worker nodes, so we want to avoid this.
8 | Here we introduce the `should_lock_run` feature to reduce this inefficiency.
9 |
10 |
11 |
12 | Suppress run() of the same task with `should_lock_run`
13 | ------------------------------------------------------
14 | When `gokart.TaskOnKart.should_lock_run` is set to True, the task will fail if the same task is already being run by another worker.
15 | By failing the task, other tasks that can be executed at that time are given priority.
16 | After that, the failed task is automatically re-executed.
17 |
18 | .. code:: python
19 |
20 | class SampleTask2(gokart.TaskOnKart):
21 | should_lock_run = True
22 |
23 |
24 | Additional Option
25 | ------------------
26 |
27 | Skip completed tasks with `complete_check_at_run`
28 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29 | By setting `gokart.TaskOnKart.complete_check_at_run` to True, the existence of the cache can be rechecked at run() time.
30 |
31 | The default is True, but if the check takes too much time, you can set it to False to disable the check.
32 |
33 | .. code:: python
34 |
35 | class SampleTask1(gokart.TaskOnKart):
36 | complete_check_at_run = False
37 |
38 |
--------------------------------------------------------------------------------
/docs/for_pandas.rst:
--------------------------------------------------------------------------------
1 | For Pandas
2 | ==========
3 |
4 | Gokart has several features for Pandas.
5 |
6 |
7 | Pandas Type Config
8 | ------------------
9 |
10 | Pandas has a feature that converts the type of column(s) automatically. This feature sometimes causes wrong results. To avoid unintentional type conversion by pandas, we can specify a column name to check the type of Task input and output in gokart.
11 |
12 |
13 | .. code:: python
14 |
15 | from typing import Any, Dict
16 | import pandas as pd
17 | import gokart
18 |
19 |
20 | # Please define a class which inherits `gokart.PandasTypeConfig`.
21 | class SamplePandasTypeConfig(gokart.PandasTypeConfig):
22 |
23 | @classmethod
24 | def type_dict(cls) -> Dict[str, Any]:
25 | return {'int_column': int}
26 |
27 |
28 | class SampleTask(gokart.TaskOnKart[pd.DataFrame]):
29 |
30 | def run(self):
31 | # [PandasTypeError] because expected type is `int`, but `str` is passed.
32 | df = pd.DataFrame(dict(int_column=['a']))
33 | self.dump(df)
34 |
35 | This is useful when a dataframe has nullable columns, because pandas auto-conversion often fails in such cases.
36 |
37 | Easy to Load DataFrame
38 | ----------------------
39 |
40 | The :func:`~gokart.task.TaskOnKart.load` method is used to load input ``pandas.DataFrame``.
41 |
42 | .. code:: python
43 |
44 | def requires(self):
45 | return MakeDataFrameTask()
46 |
47 | def run(self):
48 | df = self.load()
49 |
50 | Please refer to :func:`~gokart.task.TaskOnKart.load`.
51 |
52 |
53 | Fail on empty DataFrame
54 | -----------------------
55 |
56 | When the :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump` parameter is true, the :func:`~gokart.task.TaskOnKart.dump()` method raises :class:`~gokart.errors.EmptyDumpError` on trying to dump empty ``pandas.DataFrame``.
57 |
58 |
59 | .. code:: python
60 |
61 | import gokart
62 |
63 |
64 | class EmptyTask(gokart.TaskOnKart):
65 | def run(self):
66 | df = pd.DataFrame()
67 | self.dump(df)
68 |
69 |
70 | ::
71 |
72 | $ python main.py EmptyTask --fail-on-empty-dump true
73 | # EmptyDumpError
74 | $ python main.py EmptyTask
75 | # Task will be run and will output an empty dataframe
76 |
77 |
78 | Empty caches sometimes hide bugs and make us spend a lot of time debugging. This feature notifies us of some bugs (including wrong data sources) at an early stage.
79 |
80 | Please refer to :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump`.
81 |
--------------------------------------------------------------------------------
/docs/gokart.rst:
--------------------------------------------------------------------------------
1 | gokart package
2 | ==============
3 |
4 | Submodules
5 | ----------
6 |
7 | gokart.file\_processor module
8 | -----------------------------
9 |
10 | .. automodule:: gokart.file_processor
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | gokart.info module
16 | ------------------
17 |
18 | .. automodule:: gokart.info
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | gokart.parameter module
24 | -----------------------
25 |
26 | .. automodule:: gokart.parameter
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | gokart.run module
32 | -----------------
33 |
34 | .. automodule:: gokart.run
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | gokart.s3\_config module
40 | ------------------------
41 |
42 | .. automodule:: gokart.s3_config
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | gokart.target module
48 | --------------------
49 |
50 | .. automodule:: gokart.target
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | gokart.task module
56 | ------------------
57 |
58 | .. automodule:: gokart.task
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | gokart.workspace\_management module
64 | -----------------------------------
65 |
66 | .. automodule:: gokart.workspace_management
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | gokart.zip\_client module
72 | -------------------------
73 |
74 | .. automodule:: gokart.zip_client
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 |
80 | Module contents
81 | ---------------
82 |
83 | .. automodule:: gokart
84 | :members:
85 | :undoc-members:
86 | :show-inheritance:
87 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. gokart documentation master file, created by
2 | sphinx-quickstart on Fri Jan 11 07:59:25 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to gokart's documentation!
7 | ==================================
8 |
9 | Useful links: `GitHub <https://github.com/m3dev/gokart>`_ | `cookiecutter gokart <https://github.com/m3dev/cookiecutter-gokart>`_
10 |
11 | `Gokart <https://github.com/m3dev/gokart>`_ is a wrapper of the data pipeline library `luigi <https://github.com/spotify/luigi>`_. Gokart solves "**reproducibility**", "**task dependencies**", "**constraints of good code**", and "**ease of use**" for Machine Learning Pipeline.
12 |
13 |
14 | Good thing about gokart
15 | -----------------------
16 |
17 | Here are some good things about gokart.
18 |
19 | - The following data for each Task is stored separately in a pkl file with hash value
20 | - task output data
21 | - imported all module versions
22 | - task processing time
23 | - random seed in task
24 | - displayed log
25 | - all parameters set as class variables in the task
26 | - If a parameter of a Task is changed, the Task is rerun automatically.
27 | - The above file will be generated with a different hash value
28 | - The hash value of dependent task will also change and both will be rerun
29 | - Support GCS or S3
30 | - The above output is exchanged between tasks as an intermediate file, which is memory-friendly
31 | - pandas.DataFrame type and column checking during I/O
32 | - Directory structure of saved files is automatically determined from structure of script
33 | - Seeds for numpy and random are automatically fixed
34 | - Can code while adhering to SOLID principles as much as possible
35 | - Tasks are locked via redis even if they run in parallel
36 |
37 | **All of the functions above were created for building Machine Learning batches. They provide an excellent environment for reproducibility and team development.**
38 |
39 |
40 |
41 | Getting started
42 | -----------------
43 |
44 | .. toctree::
45 | :maxdepth: 2
46 |
47 | intro_to_gokart
48 | tutorial
49 |
50 | User Guide
51 | -----------------
52 |
53 | .. toctree::
54 | :maxdepth: 2
55 |
56 | task_on_kart
57 | task_parameters
58 | setting_task_parameters
59 | task_settings
60 | task_information
61 | logging
62 | slack_notification
63 | using_task_task_conflict_prevention_lock
64 | efficient_run_on_multi_workers
65 | for_pandas
66 | mypy_plugin
67 |
68 | API References
69 | --------------
70 | .. toctree::
71 | :maxdepth: 2
72 |
73 | gokart
74 |
75 |
76 | Indices and tables
77 | -------------------
78 |
79 | * :ref:`genindex`
80 | * :ref:`modindex`
81 | * :ref:`search`
82 |
--------------------------------------------------------------------------------
/docs/intro_to_gokart.rst:
--------------------------------------------------------------------------------
1 | Intro To Gokart
2 | ===============
3 |
4 |
5 | Installation
6 | ------------
7 |
8 | Within the activated Python environment, use the following command to install gokart.
9 |
10 | .. code:: sh
11 |
12 | pip install gokart
13 |
14 |
15 |
16 | Quickstart
17 | ----------
18 |
19 | A minimal gokart task looks something like this:
20 |
21 |
22 | .. code:: python
23 |
24 | import gokart
25 |
26 | class Example(gokart.TaskOnKart[str]):
27 | def run(self):
28 | self.dump('Hello, world!')
29 |
30 | task = Example()
31 | output = gokart.build(task)
32 | print(output)
33 |
34 |
35 | ``gokart.build`` returns the result dumped by ``gokart.TaskOnKart``. The example will output the following.
36 |
37 |
38 | .. code:: sh
39 |
40 | Hello, world!
41 |
42 |
43 | ``gokart`` records all the information needed for Machine Learning. By default, ``resources`` will be generated in the same directory as the script.
44 |
45 | .. code:: sh
46 |
47 | $ tree resources/
48 | resources/
49 | ├── __main__
50 | │ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
51 | └── log
52 | ├── module_versions
53 | │ └── Example_8441c59b5ce0113396d53509f19371fb.txt
54 | ├── processing_time
55 | │ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
56 | ├── random_seed
57 | │ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
58 | ├── task_log
59 | │ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
60 | └── task_params
61 | └── Example_8441c59b5ce0113396d53509f19371fb.pkl
62 |
63 |
64 | The result of dumping the task will be saved in the ``__name__`` directory.
65 |
66 |
67 | .. code:: python
68 |
69 | import pickle
70 |
71 | with open('resources/__main__/Example_8441c59b5ce0113396d53509f19371fb.pkl', 'rb') as f:
72 | print(pickle.load(f)) # Hello, world!
73 |
74 |
75 | The file name is given a hash value that depends on the parameters of the task. This means that if you change a parameter of the task, the hash value changes and the output is written to a different file. This is very useful when changing parameters and experimenting. Please refer to the :doc:`task_parameters` section for task parameters. Also see the :doc:`task_on_kart` section for information on how to return this output destination.
76 |
77 |
78 | In addition, the following files are automatically saved as ``log``.
79 |
80 | - ``module_versions``: The versions of all modules that were imported when the script was executed. For reproducibility.
81 | - ``processing_time``: The execution time of the task.
82 | - ``random_seed``: This is random seed of python and numpy. For reproducibility in Machine Learning. Please refer to :doc:`task_settings` section.
83 | - ``task_log``: This is the output of the task logger.
84 | - ``task_params``: This is task's parameters. Please refer to :doc:`task_parameters` section.
85 |
86 |
87 | How to run a task
88 | -------------------
89 |
90 | Gokart has ``run`` and ``build`` functions for running tasks. Each has a different purpose.
91 |
92 | - ``gokart.run``: uses arguments given on the shell. Returns the return code.
93 | - ``gokart.build``: is called from inline code in Jupyter Notebook, IPython, and so on. Returns the task output.
94 |
95 |
96 | .. note::
97 |
98 | It is not recommended to use ``gokart.run`` and ``gokart.build`` together in the same script, because ``gokart.build`` clears the contents of ``luigi.register``. Clearing it is the only way to handle duplicate tasks.
99 |
100 |
101 | gokart.run
102 | ~~~~~~~~~~
103 |
104 | The :func:`~gokart.run` function runs tasks from the shell.
105 |
106 | .. code:: python
107 |
108 | import gokart
109 | import luigi
110 |
111 | class SampleTask(gokart.TaskOnKart[str]):
112 | param = luigi.Parameter()
113 |
114 | def run(self):
115 | self.dump(self.param)
116 |
117 | gokart.run()
118 |
119 |
120 | .. code:: sh
121 |
122 | python sample.py SampleTask --local-scheduler --param=hello
123 |
124 |
125 | If you were to write it in Python, it would be the same as the following behavior.
126 |
127 |
128 | .. code:: python
129 |
130 | gokart.run(['SampleTask', '--local-scheduler', '--param=hello'])
131 |
132 |
133 | gokart.build
134 | ~~~~~~~~~~~~
135 |
136 | The :func:`~gokart.build` function runs tasks from inline code.
137 |
138 | .. code:: python
139 |
140 | import gokart
141 | import luigi
142 |
143 | class SampleTask(gokart.TaskOnKart[str]):
144 | param = luigi.Parameter()
145 |
146 | def run(self):
147 | self.dump(self.param)
148 |
149 | gokart.build(SampleTask(param='hello'), return_value=False)
150 |
151 |
152 | To output the logs of each task, you can pass the ``log_level`` parameter to :func:`~gokart.build` as follows:
153 |
154 | .. code:: python
155 |
156 | gokart.build(SampleTask(param='hello'), return_value=False, log_level=logging.DEBUG)
157 |
158 |
159 | This feature is very useful for running ``gokart`` on Jupyter Notebook.
160 | When some tasks fail, ``gokart.build`` raises ``GokartBuildError``. If you need the tracebacks, set ``log_level`` to ``logging.DEBUG``.
161 |
--------------------------------------------------------------------------------
/docs/logging.rst:
--------------------------------------------------------------------------------
1 | Logging
2 | =======
3 |
4 | How to set up a common logger for gokart.
5 |
6 |
7 | Core settings
8 | -------------
9 |
10 | Please write a configuration file similar to the following:
11 |
12 | ::
13 |
14 | # base.ini
15 | [core]
16 | logging_conf_file=./conf/logging.ini
17 |
18 | .. code:: python
19 |
20 | import gokart
21 | gokart.add_config('base.ini')
22 |
23 |
24 | Logger ini file
25 | ---------------
26 |
27 | It is the same as a general logging.ini file.
28 |
29 | ::
30 |
31 | [loggers]
32 | keys=root,luigi,luigi-interface,gokart,gokart.file_processor
33 |
34 | [handlers]
35 | keys=stderrHandler
36 |
37 | [formatters]
38 | keys=simpleFormatter
39 |
40 | [logger_root]
41 | level=INFO
42 | handlers=stderrHandler
43 |
44 | [logger_gokart]
45 | level=INFO
46 | handlers=stderrHandler
47 | qualname=gokart
48 | propagate=0
49 |
50 | [logger_luigi]
51 | level=INFO
52 | handlers=stderrHandler
53 | qualname=luigi
54 | propagate=0
55 |
56 | [logger_luigi-interface]
57 | level=INFO
58 | handlers=stderrHandler
59 | qualname=luigi-interface
60 | propagate=0
61 |
62 | [logger_gokart.file_processor]
63 | level=CRITICAL
64 | handlers=stderrHandler
65 | qualname=gokart.file_processor
66 |
67 | [handler_stderrHandler]
68 | class=StreamHandler
69 | formatter=simpleFormatter
70 | args=(sys.stdout,)
71 |
72 | [formatter_simpleFormatter]
73 | format=[%(asctime)s][%(name)s][%(levelname)s](%(filename)s:%(lineno)s) %(message)s
74 | datefmt=%Y/%m/%d %H:%M:%S
75 |
76 | Please refer to the `Python logging documentation <https://docs.python.org/3/library/logging.config.html>`_.
77 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/mypy_plugin.rst:
--------------------------------------------------------------------------------
1 | [Experimental] Mypy plugin
2 | ===========================
3 |
4 | Mypy plugin provides type checking for gokart tasks using Mypy.
5 | This feature is experimental.
6 |
7 | How to use
8 | --------------
9 |
10 | Configure Mypy to use this plugin by adding the following to your ``mypy.ini`` file:
11 |
12 | .. code:: ini
13 |
14 | [mypy]
15 | plugins = gokart.mypy:plugin
16 |
17 | or by adding the following to your ``pyproject.toml`` file:
18 |
19 | .. code:: toml
20 |
21 | [tool.mypy]
22 | plugins = ["gokart.mypy"]
23 |
24 | Then, run Mypy as usual.
25 |
26 | Examples
27 | --------
28 |
29 | For example, the following code is linted by Mypy:
30 |
31 | .. code:: python
32 |
33 | import gokart
34 | import luigi
35 |
36 |
37 | class Foo(gokart.TaskOnKart):
38 | # NOTE: all the parameters must be annotated
39 | foo: int = luigi.IntParameter(default=1)
40 | bar: str = luigi.Parameter()
41 |
42 |
43 |
44 | Foo(foo=1, bar='2') # OK
45 | Foo(foo='1') # NG because foo is not int and bar is missing
46 |
47 |
48 | Mypy plugin checks TaskOnKart generic types.
49 |
50 | .. code:: python
51 |
52 | class SampleTask(gokart.TaskOnKart):
53 | str_task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter()
54 | int_task: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
55 |
56 | def requires(self):
57 | return dict(str=self.str_task, int=self.int_task)
58 |
59 | def run(self):
60 | s = self.load(self.str_task) # This type is inferred with "str"
61 | i = self.load(self.int_task) # This type is inferred with "int"
62 |
63 | SampleTask(
64 | str_task=StrTask(), # mypy ok
65 | int_task=StrTask(), # mypy error: Argument "int_task" to "SampleTask" has incompatible type "StrTask"; expected "TaskOnKart[int]"
66 | )
67 |
68 | Configurations (only pyproject.toml)
69 | -----------------------------------
70 |
71 | You can configure the Mypy plugin using the ``pyproject.toml`` file.
72 | The following options are available:
73 |
74 | .. code:: toml
75 |
76 | [tool.gokart-mypy]
77 | # If true, Mypy will raise an error if a task is missing required parameters.
78 | # Note that this also causes an error for parameters that are set via `luigi.Config()`.
79 | # Default: false
80 | disallow_missing_parameters = true
81 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | Sphinx
2 | gokart
3 | sphinx-rtd-theme
4 |
--------------------------------------------------------------------------------
/docs/setting_task_parameters.rst:
--------------------------------------------------------------------------------
1 | ============================
2 | Setting Task Parameters
3 | ============================
4 |
5 | There are several ways to set task parameters.
6 |
7 | - Set parameter from command line
8 | - Set parameter at config file
9 | - Set parameter at upstream task
10 | - Inherit parameter from other task
11 |
12 |
13 | Set parameter from command line
14 | ==================================
15 | .. code:: sh
16 |
17 | python main.py sample.SomeTask --SomeTask-param=Hello
18 |
19 | Parameter of each task can be set as a command line parameter in ``--[task name]-[parameter name]=[value]`` format.
20 |
21 |
22 | Set parameter at config file
23 | ==================================
24 | ::
25 |
26 | [sample.SomeTask]
27 | param = Hello
28 |
29 | Above config file (``config.ini``) must be read before ``gokart.run()`` as the following code:
30 |
31 | .. code:: python
32 |
33 | if __name__ == '__main__':
34 | gokart.add_config('./conf/config.ini')
35 | gokart.run()
36 |
37 |
38 | It can also be loaded from environment variable as the following code:
39 |
40 | ::
41 |
42 | [sample.SomeTask]
43 | param=${PARAMS}
44 |
45 | [TaskOnKart]
46 | workspace_directory=${WORKSPACE_DIRECTORY}
47 |
48 | The advantages of using environment variables are 1) important information will not be logged 2) common settings can be used.
49 |
50 |
51 | Set parameter at upstream task
52 | ==================================
53 |
54 | Parameters can be set at the upstream task, as in a typical pipeline.
55 |
56 | .. code:: python
57 |
58 | class UpstreamTask(gokart.TaskOnKart):
59 | def requires(self):
60 | return dict(sometask=SomeTask(param='Hello'))
61 |
62 |
63 | Inherit parameter from other task
64 | ==================================
65 |
66 | Parameter values can be inherited from other task using ``@inherits_config_params`` decorator.
67 |
68 | .. code:: python
69 |
70 | class MasterConfig(luigi.Config):
71 | param: str = luigi.Parameter()
72 | param2: str = luigi.Parameter()
73 |
74 | @inherits_config_params(MasterConfig)
75 | class SomeTask(gokart.TaskOnKart):
76 | param: str = luigi.Parameter()
77 |
78 |
79 | This is useful when multiple tasks have the same parameter. In the above example, the parameter settings of ``MasterConfig`` will be inherited by all tasks decorated with ``@inherits_config_params(MasterConfig)``, such as ``SomeTask``.
80 |
81 | Note that only parameters which exist in both ``MasterConfig`` and ``SomeTask`` will be inherited.
82 | In the above example, ``param2`` will not be available in ``SomeTask``, since ``SomeTask`` does not have ``param2`` parameter.
83 |
84 | .. code:: python
85 |
86 | class MasterConfig(luigi.Config):
87 | param: str = luigi.Parameter()
88 | param2: str = luigi.Parameter()
89 |
90 | @inherits_config_params(MasterConfig, parameter_alias={'param2': 'param3'})
91 | class SomeTask(gokart.TaskOnKart):
92 | param3: str = luigi.Parameter()
93 |
94 |
95 | You may also set a parameter name alias by setting ``parameter_alias``.
96 | ``parameter_alias`` must be a dictionary of key: inheriting task's parameter name, value: decorating task's parameter name.
97 |
98 | In the above example, ``SomeTask.param3`` will be set to same value as ``MasterConfig.param2``.
99 |
--------------------------------------------------------------------------------
/docs/slack_notification.rst:
--------------------------------------------------------------------------------
1 | Slack notification
2 | =========================
3 |
4 | Prerequisites
5 | -------------
6 |
7 | Prepare the following environment variables:
8 |
9 | .. code:: sh
10 |
11 | export SLACK_TOKEN=xoxb-your-token # should use a token that starts with "xoxb-" (a bot token is preferable)
12 | export SLACK_CHANNEL=channel-name # not "#channel-name", just "channel-name"
13 |
14 |
15 | A Slack bot token can be obtained from the `slack app document <https://api.slack.com/apps>`_.
16 |
17 | A bot token needs the following scopes:
18 |
19 | - `channels:read`
20 | - `chat:write`
21 | - `files:write`
22 |
23 | More about scopes can be found in the `slack scopes document <https://api.slack.com/scopes>`_.
24 |
25 | Implement Slack notification
26 | ----------------------------
27 |
28 | Write the following code to pass arguments to your gokart workflow.
29 |
30 | .. code:: python
31 |
32 | cmdline_args = sys.argv[1:]
33 | if 'SLACK_CHANNEL' in os.environ:
34 | cmdline_args.append(f'--SlackConfig-channel={os.environ["SLACK_CHANNEL"]}')
35 | if 'SLACK_TO_USER' in os.environ:
36 | cmdline_args.append(f'--SlackConfig-to-user={os.environ["SLACK_TO_USER"]}')
37 | gokart.run(cmdline_args)
38 |
39 |
--------------------------------------------------------------------------------
/docs/task_parameters.rst:
--------------------------------------------------------------------------------
1 | =================
2 | Task Parameters
3 | =================
4 |
5 | Luigi Parameter
6 | ================
7 |
8 | We can set parameters for tasks.
9 | Also please refer to :doc:`task_settings` section.
10 |
11 | .. code:: python
12 |
13 | class Task(gokart.TaskOnKart):
14 | param_a = luigi.Parameter()
15 | param_c = luigi.ListParameter()
16 | param_d = luigi.IntParameter(default=1)
17 |
18 | Please refer to the `luigi document <https://luigi.readthedocs.io/en/stable/api/luigi.parameter.html>`_ for a list of parameter types.
19 |
20 |
21 | Gokart Parameter
22 | ================
23 |
24 | There are also parameters provided by gokart.
25 |
26 | - gokart.TaskInstanceParameter
27 | - gokart.ListTaskInstanceParameter
28 | - gokart.ExplicitBoolParameter
29 |
30 |
31 | gokart.TaskInstanceParameter
32 | --------------------------------
33 |
34 | The :func:`~gokart.parameter.TaskInstanceParameter` executes a task using the results of a task as dynamic parameters.
35 |
36 |
37 | .. code:: python
38 |
39 | class TaskA(gokart.TaskOnKart[str]):
40 | def run(self):
41 | self.dump('Hello')
42 |
43 |
44 | class TaskB(gokart.TaskOnKart[str]):
45 | require_task = gokart.TaskInstanceParameter()
46 |
47 | def requires(self):
48 | return self.require_task
49 |
50 | def run(self):
51 | task_a = self.load()
52 | self.dump(','.join([task_a, 'world']))
53 |
54 | task = TaskB(require_task=TaskA())
55 | print(gokart.build(task)) # Hello,world
56 |
57 |
58 | This helps to create a pipeline.
59 |
60 |
61 | gokart.ListTaskInstanceParameter
62 | -------------------------------------
63 |
64 | The :func:`~gokart.parameter.ListTaskInstanceParameter` is a list of TaskInstanceParameter.
65 |
66 |
67 | gokart.ExplicitBoolParameter
68 | -----------------------------------
69 |
70 | The :func:`~gokart.parameter.ExplicitBoolParameter` is a parameter whose value must be specified explicitly.
71 |
72 | ``luigi.BoolParameter`` already has an "explicit parsing" feature, but it still has implicit behavior as follows.
73 |
74 | ::
75 |
76 | $ python main.py Task --param
77 | # param will be set as True
78 | $ python main.py Task
79 | # param will be set as False
80 |
81 | ``ExplicitBoolParameter`` solves these problems on parameters from command line.
82 |
83 |
84 | gokart.SerializableParameter
85 | ----------------------------
86 |
87 | The :func:`~gokart.parameter.SerializableParameter` is a parameter for any object that can be serialized and deserialized.
88 | This parameter is particularly useful when you want to pass a complex object or a set of parameters to a task.
89 |
90 | The object must implement the following methods:
91 |
92 | - ``gokart_serialize``: Serialize the object to a string. This serialized string must uniquely identify the object to enable task caching.
93 | Note that it is not required for deserialization.
94 | - ``gokart_deserialize``: Deserialize the object from a string, typically used for CLI arguments.
95 |
96 | Example
97 | ^^^^^^^
98 |
99 | .. code-block:: python
100 |
101 | import json
102 | from dataclasses import dataclass
103 |
104 | import gokart
105 |
106 | @dataclass(frozen=True)
107 | class Config:
108 | foo: int
109 | # The `bar` field does not affect the result of the task.
110 | # Similar to `luigi.Parameter(significant=False)`.
111 | bar: str
112 |
113 | def gokart_serialize(self) -> str:
114 | # Serialize only the `foo` field since `bar` is irrelevant for caching.
115 | return json.dumps({'foo': self.foo})
116 |
117 | @classmethod
118 | def gokart_deserialize(cls, s: str) -> 'Config':
119 | # Deserialize the object from the provided string.
120 | return cls(**json.loads(s))
121 |
122 | class DummyTask(gokart.TaskOnKart):
123 | config: Config = gokart.SerializableParameter(object_type=Config)
124 |
125 | def run(self):
126 | # Save the `config` object as part of the task result.
127 | self.dump(self.config)
128 |
--------------------------------------------------------------------------------
/docs/task_settings.rst:
--------------------------------------------------------------------------------
1 | Task Settings
2 | =============
3 |
4 | Task settings. Also please refer to :doc:`task_parameters` section.
5 |
6 |
7 | Directory to Save Outputs
8 | -------------------------
9 |
10 | We can use both a local directory and the S3 to save outputs.
11 | If you would like to use local directory, please set a local directory path to :attr:`~gokart.task.TaskOnKart.workspace_directory`. Please refer to :doc:`task_parameters` for how to set it up.
12 |
13 | It is recommended to use the config file since it does not change much.
14 |
15 | ::
16 |
17 | # base.ini
18 | [TaskOnKart]
19 | workspace_directory=${TASK_WORKSPACE_DIRECTORY}
20 |
21 | .. code:: python
22 |
23 | # main.py
24 | import gokart
25 | gokart.add_config('base.ini')
26 |
27 |
28 | To use the S3 or GCS repository, please set the bucket path as ``s3://{YOUR_REPOSITORY_NAME}`` or ``gs://{YOUR_REPOSITORY_NAME}`` respectively.
29 |
30 | If you use S3 or GCS, please set the credential information in environment variables.
31 |
32 | .. code:: sh
33 |
34 | # S3
35 | export AWS_ACCESS_KEY_ID='~~~' # AWS access key
36 | export AWS_SECRET_ACCESS_KEY='~~~' # AWS secret access key
37 |
38 | # GCS
39 | export GCS_CREDENTIAL='~~~' # GCS credential
40 | export DISCOVER_CACHE_LOCAL_PATH='~~~' # The local file path of discover api cache.
41 |
42 |
43 | Rerun task
44 | ----------
45 |
46 | There are times when we want to rerun a task, such as when the script has changed or when running in a batch. Please use the ``rerun`` parameter or add an arbitrary parameter.
47 |
48 |
49 | Set ``rerun`` in code as follows:
50 |
51 | .. code:: python
52 |
53 | # rerun Task
54 | gokart.build(Task(rerun=True))
55 |
56 |
57 | When used from an argument as follows:
58 |
59 | .. code:: python
60 |
61 | # main.py
62 | class Task(gokart.TaskOnKart[str]):
63 | def run(self):
64 | self.dump('hello')
65 |
66 | .. code:: sh
67 |
68 | python main.py Task --local-scheduler --rerun
69 |
70 |
71 | ``rerun`` parameter will look at the dependent tasks up to one level.
72 |
73 | Example: Suppose we have a straight line pipeline composed of TaskA, TaskB and TaskC, and TaskC is an endpoint of this pipeline. We also suppose that all the tasks have already been executed.
74 |
75 | - TaskA(rerun=True) -> TaskB -> TaskC # not rerunning
76 | - TaskA -> TaskB(rerun=True) -> TaskC # rerunning TaskB and TaskC
77 |
78 | This is due to the way intermediate files are handled. The ``rerun`` parameter is ``significant=False``, so it does not affect the hash value. It is very important to understand this difference.
79 |
80 |
81 | If you want to change the parameter of TaskA and rerun TaskB and TaskC, we recommend adding an arbitrary parameter.
82 |
83 | .. code:: python
84 |
85 | class TaskA(gokart.TaskOnKart):
86 | __version = luigi.IntParameter(default=1)
87 |
88 | If the hash value of TaskA changes, the dependent tasks (in this case, TaskB and TaskC) will rerun.
89 |
90 |
91 | Fix random seed
92 | ---------------
93 |
94 | Every task has parameters named :attr:`~gokart.task.TaskOnKart.fix_random_seed_methods` and :attr:`~gokart.task.TaskOnKart.fix_random_seed_value`. These can be used to fix the random seed.
95 |
96 |
97 | .. code:: python
98 |
99 | import gokart
100 | import random
101 | import numpy as np
102 | import torch
103 |
104 | class Task(gokart.TaskOnKart[dict[str, Any]]):
105 | def run(self):
106 | x = [random.randint(0, 100) for _ in range(0, 10)]
107 | y = [np.random.randint(0, 100) for _ in range(0, 10)]
108 | z = [torch.randn(1).tolist()[0] for _ in range(0, 5)]
109 | self.dump({'random': x, 'numpy': y, 'torch': z})
110 |
111 | gokart.build(
112 | Task(
113 | fix_random_seed_methods=[
114 | "random.seed",
115 | "numpy.random.seed",
116 | "torch.random.manual_seed"],
117 | fix_random_seed_value=57))
118 |
119 | ::
120 |
121 | # //--- The output is as follows every time. ---
122 | # {'random': [65, 41, 61, 37, 55, 81, 48, 2, 94, 21],
123 | # 'numpy': [79, 86, 5, 22, 79, 98, 56, 40, 81, 37],
124 | # 'torch': [0.14460121095180511, -0.11649507284164429,
125 | # 0.6928958296775818, -0.916053831577301, 0.7317505478858948]}
126 |
127 | This will be useful for using Machine Learning Libraries.
128 |
--------------------------------------------------------------------------------
/docs/using_task_task_conflict_prevention_lock.rst:
--------------------------------------------------------------------------------
1 | Task conflict prevention lock
2 | =============================
3 |
4 | If there is a possibility of multiple worker nodes executing the same task, task cache conflict may happen.
5 | Specifically, while node A is loading the cache of a task, node B may be writing to it.
6 | This can lead to reading inappropriate data and other unwanted behaviors.
7 |
8 | The redis lock introduced in this page is a feature to prevent such cache collisions.
9 |
10 | Requires
11 | --------
12 |
13 | You need to install `redis <https://redis.io/>`_ to use this advanced feature.
14 |
15 |
16 | How to use
17 | -----------
18 |
19 |
20 | 1. Set up a redis server somewhere accessible from gokart/luigi jobs.
21 |
22 | e.g. Following script will run redis at your localhost.
23 |
24 | .. code:: bash
25 |
26 | $ redis-server
27 |
28 | 2. Set redis server hostname and port number as parameters of gokart.TaskOnKart().
29 |
30 | You can set it by adding ``--redis-host=[your-redis-localhost] --redis-port=[redis-port-number]`` options to gokart python script.
31 |
32 | e.g.
33 |
34 | .. code:: bash
35 |
36 | python main.py sample.SomeTask --local-scheduler --redis-host=localhost --redis-port=6379
37 |
38 |
39 | Alternatively, you may set parameters at config file.
40 |
41 | e.g.
42 |
43 | .. code::
44 |
45 | [TaskOnKart]
46 | redis_host=localhost
47 | redis_port=6379
48 |
49 | 3. Done
50 |
51 | With the above configuration, all tasks that inherit gokart.TaskOnKart will ask the redis server whether any other node is trying to access the same cache file at the same time whenever they access the file with dump or load.
52 |
--------------------------------------------------------------------------------
/examples/gokart_notebook_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "gokart 1.0.2\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "!pip list | grep gokart"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 3,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import gokart\n",
27 | "import luigi"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "# Examples of using gokart at jupyter notebook"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "## Basic Usage\n",
42 | "This is a very basic usage, just to dump a run result of ExampleTaskA."
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 4,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "DONE example_2\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "class ExampleTaskA(gokart.TaskOnKart):\n",
60 | " param = luigi.Parameter()\n",
61 | " int_param = luigi.IntParameter(default=2)\n",
62 | "\n",
63 | " def run(self):\n",
64 | " self.dump(f'DONE {self.param}_{self.int_param}')\n",
65 | "\n",
66 | " \n",
67 | "task_a = ExampleTaskA(param='example')\n",
68 | "output = gokart.build(task=task_a)\n",
69 | "print(output)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "## Make tasks dependencies with `requires()`\n",
77 | "ExampleTaskB is dependent on ExampleTaskC and ExampleTaskD. They are defined in `ExampleTaskB.requires()`."
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 5,
83 | "metadata": {},
84 | "outputs": [
85 | {
86 | "name": "stdout",
87 | "output_type": "stream",
88 | "text": [
89 | "DONE example_TASKC_TASKD\n"
90 | ]
91 | }
92 | ],
93 | "source": [
94 | "class ExampleTaskC(gokart.TaskOnKart):\n",
95 | " def run(self):\n",
96 | " self.dump('TASKC')\n",
97 | " \n",
98 | "class ExampleTaskD(gokart.TaskOnKart):\n",
99 | " def run(self):\n",
100 | " self.dump('TASKD')\n",
101 | "\n",
102 | "class ExampleTaskB(gokart.TaskOnKart):\n",
103 | " param = luigi.Parameter()\n",
104 | "\n",
105 | " def requires(self):\n",
106 | " return dict(task_c=ExampleTaskC(), task_d=ExampleTaskD())\n",
107 | "\n",
108 | " def run(self):\n",
109 | " task_c = self.load('task_c')\n",
110 | " task_d = self.load('task_d')\n",
111 | " self.dump(f'DONE {self.param}_{task_c}_{task_d}')\n",
112 | " \n",
113 | "task_b = ExampleTaskB(param='example')\n",
114 | "output = gokart.build(task=task_b)\n",
115 | "print(output)"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "## Make tasks dependencies with TaskInstanceParameter\n",
123 | "The dependencies are the same as in the previous example; however, they are defined outside of the task instead of being defined in `ExampleTaskB.requires()`."
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "name": "stdout",
133 | "output_type": "stream",
134 | "text": [
135 | "DONE example_TASKC_TASKD\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "class ExampleTaskC(gokart.TaskOnKart):\n",
141 | " def run(self):\n",
142 | " self.dump('TASKC')\n",
143 | " \n",
144 | "class ExampleTaskD(gokart.TaskOnKart):\n",
145 | " def run(self):\n",
146 | " self.dump('TASKD')\n",
147 | "\n",
148 | "class ExampleTaskB(gokart.TaskOnKart):\n",
149 | " param = luigi.Parameter()\n",
150 | " task_1 = gokart.TaskInstanceParameter()\n",
151 | " task_2 = gokart.TaskInstanceParameter()\n",
152 | "\n",
153 | " def requires(self):\n",
154 | " return dict(task_1=self.task_1, task_2=self.task_2) # required tasks are decided from the task parameters `task_1` and `task_2`\n",
155 | "\n",
156 | " def run(self):\n",
157 | " task_1 = self.load('task_1')\n",
158 | " task_2 = self.load('task_2')\n",
159 | " self.dump(f'DONE {self.param}_{task_1}_{task_2}')\n",
160 | " \n",
161 | "task_b = ExampleTaskB(param='example', task_1=ExampleTaskC(), task_2=ExampleTaskD()) # Dependent tasks are defined here\n",
162 | "output = gokart.build(task=task_b)\n",
163 | "print(output)"
164 | ]
165 | }
166 | ],
167 | "metadata": {
168 | "kernelspec": {
169 | "display_name": "Python 3.8.8 64-bit ('3.8.8': pyenv)",
170 | "name": "python388jvsc74a57bd026997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca"
171 | },
172 | "language_info": {
173 | "codemirror_mode": {
174 | "name": "ipython",
175 | "version": 3
176 | },
177 | "file_extension": ".py",
178 | "mimetype": "text/x-python",
179 | "name": "python",
180 | "nbconvert_exporter": "python",
181 | "pygments_lexer": "ipython3",
182 | "version": "3.8.8"
183 | },
184 | "metadata": {
185 | "interpreter": {
186 | "hash": "26997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca"
187 | }
188 | },
189 | "orig_nbformat": 3
190 | },
191 | "nbformat": 4,
192 | "nbformat_minor": 2
193 | }
--------------------------------------------------------------------------------
/examples/logging.ini:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root,luigi,luigi-interface,gokart
3 |
4 | [handlers]
5 | keys=stderrHandler
6 |
7 | [formatters]
8 | keys=simpleFormatter
9 |
10 | [logger_root]
11 | level=INFO
12 | handlers=stderrHandler
13 |
14 | [logger_gokart]
15 | level=INFO
16 | handlers=stderrHandler
17 | qualname=gokart
18 | propagate=0
19 |
20 | [logger_luigi]
21 | level=INFO
22 | handlers=stderrHandler
23 | qualname=luigi
24 | propagate=0
25 |
26 | [logger_luigi-interface]
27 | level=INFO
28 | handlers=stderrHandler
29 | qualname=luigi-interface
30 | propagate=0
31 |
32 | [handler_stderrHandler]
33 | class=StreamHandler
34 | formatter=simpleFormatter
35 | args=(sys.stdout,)
36 |
37 | [formatter_simpleFormatter]
38 | format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s
39 | datefmt=%Y/%m/%d %H:%M:%S
40 | class=logging.Formatter
41 |
--------------------------------------------------------------------------------
/examples/param.ini:
--------------------------------------------------------------------------------
1 | [TaskOnKart]
2 | workspace_directory=./resource
3 | local_temporary_directory=./resource/tmp
4 |
5 | [core]
6 | logging_conf_file=logging.ini
7 |
8 |
--------------------------------------------------------------------------------
/gokart/__init__.py:
--------------------------------------------------------------------------------
1 | from gokart.build import WorkerSchedulerFactory, build # noqa:F401
2 | from gokart.info import make_tree_info, tree_info # noqa:F401
3 | from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
4 | from gokart.parameter import ( # noqa:F401
5 | ExplicitBoolParameter,
6 | ListTaskInstanceParameter,
7 | SerializableParameter,
8 | TaskInstanceParameter,
9 | ZonedDateSecondParameter,
10 | )
11 | from gokart.run import run # noqa:F401
12 | from gokart.task import TaskOnKart # noqa:F401
13 | from gokart.testing import test_run # noqa:F401
14 | from gokart.tree.task_info import make_task_info_as_tree_str # noqa:F401
15 | from gokart.utils import add_config # noqa:F401
16 | from gokart.workspace_management import delete_local_unnecessary_outputs # noqa:F401
17 |
--------------------------------------------------------------------------------
/gokart/build.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import enum
4 | import logging
5 | import sys
6 | from dataclasses import dataclass
7 | from functools import partial
8 | from logging import getLogger
9 | from typing import Literal, Protocol, TypeVar, cast, overload
10 |
11 | import backoff
12 | import luigi
13 | from luigi import LuigiStatusCode, rpc, scheduler
14 |
15 | import gokart
16 | import gokart.tree.task_info
17 | from gokart import worker
18 | from gokart.conflict_prevention_lock.task_lock import TaskLockException
19 | from gokart.target import TargetOnKart
20 | from gokart.task import TaskOnKart
21 |
22 | T = TypeVar('T')
23 |
24 | logger: logging.Logger = logging.getLogger(__name__)
25 |
26 |
class LoggerConfig:
    """Context manager that temporarily raises the logging threshold.

    On entry, records below ``level`` are disabled process-wide through
    ``logging.disable`` and this module's logger is set to ``level``.
    On exit, the process-wide threshold that was active before entry and the
    logger level captured at construction time are both restored.
    """

    def __init__(self, level: int):
        self.logger = getLogger(__name__)
        # Level of this module's logger at construction time; restored on exit.
        self.default_level = self.logger.level
        self.level = level
        # Process-wide disable threshold, captured in ``__enter__``.
        self._previous_disable_level = logging.NOTSET

    def __enter__(self):
        # Remember the global threshold so that ``__exit__`` can restore it
        # instead of deriving a new one from the logger level.
        self._previous_disable_level = logging.root.manager.disable
        logging.disable(self.level - 10)  # subtract 10 to disable below self.level
        self.logger.setLevel(self.level)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        logging.disable(self._previous_disable_level)
        self.logger.setLevel(self.default_level)
41 |
42 |
class GokartBuildError(Exception):
    """Error raised when ``gokart.build`` fails.

    The exceptions collected during task execution are available on
    ``raised_exceptions``.
    """

    def __init__(self, message, raised_exceptions: dict[str, list[Exception]]):
        # Collected exceptions, grouped under a string key per task.
        self.raised_exceptions = raised_exceptions
        super().__init__(message)
49 |
50 |
class HasLockedTaskException(Exception):
    """Signals that a task could not acquire its lock while it was being executed."""
53 |
54 |
class TaskLockExceptionRaisedFlag:
    """Mutable holder for a single boolean; ``flag`` starts out ``False``."""

    def __init__(self):
        # Flipped to True by the failure handler when a TaskLockException is seen.
        self.flag: bool = False
58 |
59 |
class WorkerProtocol(Protocol):
    """Structural type of a worker.

    The required members follow ``luigi.worker.Worker``: a worker accepts
    tasks through ``add``, executes them through ``run`` and can be used as a
    context manager.
    """

    def add(self, task: TaskOnKart) -> bool:
        ...

    def run(self) -> bool:
        ...

    def __enter__(self) -> WorkerProtocol:
        ...

    def __exit__(self, type, value, traceback) -> Literal[False]:
        ...
72 |
73 |
class WorkerSchedulerFactory:
    """Factory for the luigi scheduler and worker objects."""

    def create_local_scheduler(self) -> scheduler.Scheduler:
        """Return an in-process scheduler that prunes on ``get_work`` and keeps no task history."""
        local_scheduler = scheduler.Scheduler(record_task_history=False, prune_on_get_work=True)
        return local_scheduler

    def create_remote_scheduler(self, url) -> rpc.RemoteScheduler:
        """Return a client for the remote scheduler reachable at ``url``."""
        remote_scheduler = rpc.RemoteScheduler(url)
        return remote_scheduler

    def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> WorkerProtocol:
        """Return a gokart worker bound to ``scheduler``."""
        worker_options = {
            'scheduler': scheduler,
            'worker_processes': worker_processes,
            'assistant': assistant,
        }
        return worker.Worker(**worker_options)
83 |
84 |
def _get_output(task: TaskOnKart[T]) -> T:
    """Load the data behind ``task.output()``.

    A list or tuple of targets is loaded into a list and a dict of targets
    into a dict with the same keys; entries that are not ``TargetOnKart`` are
    left out. A single ``TargetOnKart`` is loaded directly.

    Raises:
        ValueError: if the output is none of the supported shapes.
    """
    targets = task.output()
    # FIXME: currently, nested output is not supported
    if isinstance(targets, (list, tuple)):
        loaded_items = [target.load() for target in targets if isinstance(target, TargetOnKart)]
        return cast(T, loaded_items)
    if isinstance(targets, dict):
        loaded_mapping = {key: target.load() for key, target in targets.items() if isinstance(target, TargetOnKart)}
        return cast(T, loaded_mapping)
    if not isinstance(targets, TargetOnKart):
        raise ValueError(f'output type is not supported: {type(targets)}')
    return targets.load()
95 |
96 |
def _reset_register(keep=frozenset({'gokart', 'luigi'})):
    """Reset ``luigi.task_register.Register._reg`` to avoid TaskClassAmbigiousException.

    Called by ``gokart.build`` when ``reset_register`` is true. Only classes
    whose top-level module name is in ``keep``, plus every subclass of
    ``gokart.PandasTypeConfig``, stay registered.

    Args:
        keep: top-level module names whose registered classes are preserved.
            The default is immutable so that it cannot be altered between calls.
    """
    luigi.task_register.Register._reg = [
        task_class
        for task_class in luigi.task_register.Register._reg
        if task_class.__module__.split('.')[0] in keep  # keep luigi and gokart
        or issubclass(task_class, gokart.PandasTypeConfig)  # PandasTypeConfig should be kept
    ]
107 |
108 |
class TaskDumpMode(enum.Enum):
    """Format in which task information is rendered."""

    TREE = 'tree'  # dependency tree text
    TABLE = 'table'  # tabular task information
    NONE = 'none'  # no task information
113 |
114 |
class TaskDumpOutputType(enum.Enum):
    """Destination of the rendered task information."""

    PRINT = 'print'  # emit through the logger
    DUMP = 'dump'  # write to a target
    NONE = 'none'  # emit nothing
119 |
120 |
@dataclass
class TaskDumpConfig:
    """Pairs a rendering mode with an output destination; both default to ``NONE``."""

    # How the task information is rendered.
    mode: TaskDumpMode = TaskDumpMode.NONE
    # Where the rendered task information goes.
    output_type: TaskDumpOutputType = TaskDumpOutputType.NONE
125 |
126 |
if sys.version_info >= (3, 10):
    # FIXME: after dropping python3.9 support, change this import to direct implementation
    from .build_process_task_info import process_task_info
else:

    def process_task_info(task: TaskOnKart, task_dump_config: TaskDumpConfig = TaskDumpConfig()):
        """Fallback for interpreters older than 3.10: only logs a warning."""
        logger.warning('process_task_info is not supported in Python 3.9 or lower.')
134 |
135 |
@overload
def build(
    task: TaskOnKart[T],
    return_value: Literal[True] = True,
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = TaskDumpConfig(),
    **env_params,
) -> T: ...


@overload
def build(
    task: TaskOnKart[T],
    return_value: Literal[False],
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = TaskDumpConfig(),
    **env_params,
) -> None: ...


def build(
    task: TaskOnKart[T],
    return_value: bool = True,
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = TaskDumpConfig(),
    **env_params,
) -> T | None:
    """
    Run gokart task for local interpreter.
    Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build)

    Args:
        task: root task to build.
        return_value: when true, the loaded output of ``task`` is returned; otherwise ``None``.
        reset_register: when true, the luigi task register is reset before building.
        log_level: logging level applied while building and forwarded to ``luigi.build``.
        task_lock_exception_max_tries: maximum number of attempts when a task lock collision occurs.
        task_lock_exception_max_wait_seconds: upper bound of the exponential backoff wait between attempts.
        task_dump_config: selects whether and how task information is emitted before the build.
        **env_params: forwarded to ``luigi.build``.

    Raises:
        HasLockedTaskException: if a task lock collision was still observed on the last attempt.
        GokartBuildError: if luigi reports a failed or scheduling-failed status.
    """
    if reset_register:
        _reset_register()
    with LoggerConfig(level=log_level):
        log_handler_before_run = logging.StreamHandler()
        logger.addHandler(log_handler_before_run)
        try:
            process_task_info(task, task_dump_config)
        finally:
            # Always detach the temporary handler, even when dumping task info fails.
            logger.removeHandler(log_handler_before_run)
            log_handler_before_run.close()

        task_lock_exception_raised = TaskLockExceptionRaisedFlag()
        raised_exceptions: dict[str, list[Exception]] = dict()

        # NOTE(review): this handler is registered on TaskOnKart on every call to build();
        # confirm whether luigi keeps handlers from previous calls.
        @TaskOnKart.event_handler(luigi.Event.FAILURE)
        def when_failure(task, exception):
            if isinstance(exception, TaskLockException):
                task_lock_exception_raised.flag = True
            else:
                raised_exceptions.setdefault(task.make_unique_id(), []).append(exception)

        @backoff.on_exception(
            partial(backoff.expo, max_value=task_lock_exception_max_wait_seconds), HasLockedTaskException, max_tries=task_lock_exception_max_tries
        )
        def _build_task():
            # Reset per attempt so that a collision in an earlier attempt does not leak into this one.
            task_lock_exception_raised.flag = False
            result = luigi.build(
                [task],
                local_scheduler=True,
                detailed_summary=True,
                log_level=logging.getLevelName(log_level),
                **env_params,
            )
            if task_lock_exception_raised.flag:
                raise HasLockedTaskException()
            if result.status in (LuigiStatusCode.FAILED, LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED, LuigiStatusCode.SCHEDULING_FAILED):
                raise GokartBuildError(result.summary_text, raised_exceptions=raised_exceptions)
            return _get_output(task) if return_value else None

        return _build_task()
212 |
--------------------------------------------------------------------------------
/gokart/build_process_task_info.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 |
5 | import gokart
6 | import gokart.tree.task_info
7 | from gokart.build import TaskDumpConfig, TaskDumpMode, TaskDumpOutputType
8 | from gokart.task import TaskOnKart
9 |
10 | from .build import logger
11 |
12 |
def process_task_info(task: TaskOnKart, task_dump_config: TaskDumpConfig = TaskDumpConfig()):
    """Emit information about ``task`` as selected by ``task_dump_config``.

    * NONE / NONE: nothing happens.
    * TREE / PRINT: the tree text is logged at INFO level.
    * TABLE / PRINT: the table is logged at INFO level as tab-separated text.
    * TREE / DUMP: the tree text is dumped to ``log/task_info/<TaskClass>.txt``.
    * TABLE / DUMP: the table is dumped to ``log/task_info/<TaskClass>.pkl``.

    Raises:
        ValueError: for any other configuration.
    """
    if not isinstance(task_dump_config, TaskDumpConfig):
        raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}')

    selection = (task_dump_config.mode, task_dump_config.output_type)

    if selection == (TaskDumpMode.NONE, TaskDumpOutputType.NONE):
        return

    if selection == (TaskDumpMode.TREE, TaskDumpOutputType.PRINT):
        logger.info(gokart.make_tree_info(task))
        return

    if selection == (TaskDumpMode.TABLE, TaskDumpOutputType.PRINT):
        table = gokart.tree.task_info.make_task_info_as_table(task)
        buffer = io.StringIO()
        table.to_csv(buffer, index=False, sep='\t')
        logger.info(buffer.getvalue())
        return

    if selection == (TaskDumpMode.TREE, TaskDumpOutputType.DUMP):
        tree = gokart.make_tree_info(task)
        tree_target = gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.txt')
        tree_target.dump(tree)
        return

    if selection == (TaskDumpMode.TABLE, TaskDumpOutputType.DUMP):
        table = gokart.tree.task_info.make_task_info_as_table(task)
        table_target = gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.pkl')
        table_target.dump(table)
        return

    raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}')
34 |
--------------------------------------------------------------------------------
/gokart/config_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import luigi
4 |
5 | import gokart
6 |
7 |
class inherits_config_params:
    """Class decorator that fills task parameters from a ``luigi.Config`` class."""

    def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str, str] | None = None):
        """
        Decorates task to inherit parameter value of `config_class`.

        * config_class: Inherit parameter value of this task to decorated task. Only parameter values exist in both tasks are inherited.
        * parameter_alias: Dictionary to map parameter names between config_class task and decorated task.
            key: config_class's parameter name. value: decorated task's parameter name.
        """
        if parameter_alias is None:
            parameter_alias = {}
        self._config_class: type[luigi.Config] = config_class
        self._parameter_alias: dict[str, str] = parameter_alias

    def __call__(self, task_class: type[gokart.TaskOnKart]):
        decorator = self

        # wrap task to prevent task name from being changed
        @luigi.task._task_wraps(task_class)
        class Wrapped(task_class):  # type: ignore
            @classmethod
            def get_param_values(cls, params, args, kwargs):
                config_values = decorator._config_class().param_kwargs
                for config_key, config_value in config_values.items():
                    task_key = decorator._parameter_alias.get(config_key, config_key)
                    # Explicitly passed values win; unknown parameters are ignored.
                    if not hasattr(cls, task_key) or task_key in kwargs:
                        continue
                    kwargs[task_key] = config_value
                return super().get_param_values(params, args, kwargs)

        return Wrapped
35 |
--------------------------------------------------------------------------------
/gokart/conflict_prevention_lock/task_lock.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | import os
5 | from logging import getLogger
6 | from typing import NamedTuple
7 |
8 | import redis
9 | from apscheduler.schedulers.background import BackgroundScheduler
10 |
11 | logger = getLogger(__name__)
12 |
13 |
class TaskLockParams(NamedTuple):
    """Immutable bundle of the settings used by the redis task lock helpers."""

    # Connection settings; ``None`` when not configured.
    redis_host: str | None
    redis_port: int | None
    # Passed as the lock timeout and used when extending the lock.
    redis_timeout: int | None
    # Name of the lock.
    redis_key: str
    # Whether locking is applied at all.
    should_task_lock: bool
    # When true, a collision raises instead of blocking.
    raise_task_lock_exception_on_collision: bool
    # Interval, in seconds, of the lock-extension job.
    lock_extend_seconds: int
22 |
23 |
class TaskLockException(Exception):
    """Raised when the task failed to acquire the lock in the task execution. Only used internally."""
27 |
28 |
class RedisClient:
    """Holder of a ``redis.Redis`` client, shared per distinct constructor arguments.

    Instances are cached per class and per ``(args, sorted kwargs)``, so
    constructing with the same arguments again returns the same object.
    """

    _instances: dict = {}

    def __new__(cls, *args, **kwargs):
        cache_key = (args, tuple(sorted(kwargs.items())))
        instances_of_class = cls._instances.setdefault(cls, {})
        if cache_key in instances_of_class:
            return instances_of_class[cache_key]
        instance = super().__new__(cls)
        instances_of_class[cache_key] = instance
        return instance

    def __init__(self, host: str | None, port: int | None) -> None:
        # ``__init__`` runs on every construction; connect only the first time.
        if hasattr(self, '_redis_client'):
            return
        self._redis_client = redis.Redis(host=host or 'localhost', port=port or 6379)

    def get_redis_client(self):
        """Return the underlying redis client."""
        return self._redis_client
48 |
49 |
50 | def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int):
51 | task_lock.extend(additional_time=redis_timeout, replace_ttl=True)
52 |
53 |
def set_task_lock(task_lock_params: TaskLockParams) -> redis.lock.Lock:
    """Acquire and return the redis lock described by ``task_lock_params``.

    The call blocks while the lock is held elsewhere, unless
    ``raise_task_lock_exception_on_collision`` is set.

    Raises:
        TaskLockException: if the lock could not be acquired.
    """
    client = RedisClient(host=task_lock_params.redis_host, port=task_lock_params.redis_port).get_redis_client()
    should_block = not task_lock_params.raise_task_lock_exception_on_collision
    task_lock = redis.lock.Lock(
        redis=client,
        name=task_lock_params.redis_key,
        timeout=task_lock_params.redis_timeout,
        thread_local=False,
    )
    acquired = task_lock.acquire(blocking=should_block)
    if acquired:
        return task_lock
    raise TaskLockException('Lock already taken by other task.')
61 |
62 |
def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockParams) -> BackgroundScheduler:
    """Start a background job that extends ``task_lock`` every ``lock_extend_seconds``.

    Returns the started scheduler; the caller is responsible for shutting it down.
    """
    # A missing timeout is passed to the extension job as 0.
    extension_timeout = task_lock_params.redis_timeout or 0
    extend_job = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=extension_timeout)

    background_scheduler = BackgroundScheduler()
    job_options = {
        'seconds': task_lock_params.lock_extend_seconds,
        'max_instances': 999999999,
        'misfire_grace_time': task_lock_params.redis_timeout,
        'coalesce': False,
    }
    background_scheduler.add_job(extend_job, 'interval', **job_options)
    background_scheduler.start()
    return background_scheduler
76 |
77 |
def make_task_lock_key(file_path: str, unique_id: str | None):
    """Build the lock key ``<file name without extension>_<unique_id>``."""
    file_name = os.path.basename(file_path)
    stem, _extension = os.path.splitext(file_name)
    return '{}_{}'.format(stem, unique_id)
81 |
82 |
def make_task_lock_params(
    file_path: str,
    unique_id: str | None,
    redis_host: str | None = None,
    redis_port: int | None = None,
    redis_timeout: int | None = None,
    raise_task_lock_exception_on_collision: bool = False,
    lock_extend_seconds: int = 10,
) -> TaskLockParams:
    """Assemble ``TaskLockParams`` for the target at ``file_path``.

    Locking is enabled only when both ``redis_host`` and ``redis_port`` are
    given. When ``redis_timeout`` is given it must be greater than
    ``lock_extend_seconds``; this is checked with an assertion.
    """
    lock_key = make_task_lock_key(file_path, unique_id)
    redis_is_configured = not (redis_host is None or redis_port is None)
    if redis_timeout is not None:
        assert redis_timeout > lock_extend_seconds, f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.'
    return TaskLockParams(
        redis_host=redis_host,
        redis_port=redis_port,
        redis_timeout=redis_timeout,
        redis_key=lock_key,
        should_task_lock=redis_is_configured,
        raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision,
        lock_extend_seconds=lock_extend_seconds,
    )
106 |
107 |
def make_task_lock_params_for_run(task_self, lock_extend_seconds: int = 10) -> TaskLockParams:
    """Assemble ``TaskLockParams`` for locking the run of ``task_self``.

    The key is derived from the task's module path, class name and unique id
    suffixed with ``-run``. Collisions always raise instead of blocking.
    """
    module_path = task_self.__module__.replace('.', '/')
    class_name = f'{type(task_self).__name__}'
    run_unique_id = task_self.make_unique_id() + '-run'
    lock_key = make_task_lock_key(file_path=os.path.join(module_path, class_name), unique_id=run_unique_id)

    redis_is_configured = not (task_self.redis_host is None or task_self.redis_port is None)
    return TaskLockParams(
        redis_host=task_self.redis_host,
        redis_port=task_self.redis_port,
        redis_timeout=task_self.redis_timeout,
        redis_key=lock_key,
        should_task_lock=redis_is_configured,
        raise_task_lock_exception_on_collision=True,
        lock_extend_seconds=lock_extend_seconds,
    )
123 |
--------------------------------------------------------------------------------
/gokart/conflict_prevention_lock/task_lock_wrappers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from logging import getLogger
5 | from typing import Any, Callable
6 |
7 | from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock
8 |
9 | logger = getLogger(__name__)
10 |
11 |
def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable):
    """Redis lock wrapper function for TargetOnKart.dump().
    When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existence check.
    https://github.com/m3dev/gokart/issues/265

    Args:
        func: the dump callable to protect.
        task_lock_params: lock settings; when ``should_task_lock`` is false, ``func`` is returned unchanged.
        exist_check: callable returning a truthy value when the output already exists, in which case ``func`` is skipped.

    Returns:
        ``func`` itself or a wrapper around it. The wrapper returns ``None``.
    """

    if not task_lock_params.should_task_lock:
        return func

    def wrapper(*args, **kwargs):
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

        try:
            logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} locked.')
            if not exist_check():
                func(*args, **kwargs)
        finally:
            logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} released.')
            try:
                task_lock.release()
            finally:
                # Stop the lock-extension job even when releasing the lock fails.
                scheduler.shutdown()

    return wrapper
35 |
36 |
def wrap_load_with_lock(func, task_lock_params: TaskLockParams):
    """Redis lock wrapper function for TargetOnKart.load().
    When TargetOnKart.load() is called, redis lock will be locked and released before load().
    https://github.com/m3dev/gokart/issues/265

    Args:
        func: the load callable to protect.
        task_lock_params: lock settings; when ``should_task_lock`` is false, ``func`` is returned unchanged.

    Returns:
        ``func`` itself or a wrapper around it that returns the result of ``func``.
    """

    if not task_lock_params.should_task_lock:
        return func

    def wrapper(*args, **kwargs):
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

        logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} locked.')
        try:
            # The lock is released before loading, not held during it.
            task_lock.release()
            logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} released.')
        finally:
            # Stop the lock-extension job even when releasing the lock fails.
            scheduler.shutdown()
        return func(*args, **kwargs)

    return wrapper
58 |
59 |
def wrap_remove_with_lock(func, task_lock_params: TaskLockParams):
    """Redis lock wrapper function for TargetOnKart.remove().
    When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock.
    https://github.com/m3dev/gokart/issues/265

    Args:
        func: the remove callable to protect.
        task_lock_params: lock settings; when ``should_task_lock`` is false, ``func`` is returned unchanged.

    Returns:
        ``func`` itself or a wrapper around it that returns the result of ``func``.
    """
    if not task_lock_params.should_task_lock:
        return func

    def wrapper(*args, **kwargs):
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

        def release_lock():
            # Release exactly once and stop the lock-extension job even when releasing fails.
            try:
                task_lock.release()
            finally:
                scheduler.shutdown()

        try:
            logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} locked.')
            result = func(*args, **kwargs)
        except BaseException:
            logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released with BaseException.')
            release_lock()
            raise
        release_lock()
        logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released.')
        return result

    return wrapper
86 |
87 |
def wrap_run_with_lock(run_func: Callable[[], Any], task_lock_params: TaskLockParams):
    """Wrap ``run_func`` so that it executes while holding the redis lock.

    The lock is acquired before ``run_func`` is called and released afterwards,
    whether ``run_func`` returns or raises. A background job keeps extending
    the lock while ``run_func`` is running.

    Returns:
        A zero-argument callable that returns the result of ``run_func``.
    """

    @functools.wraps(run_func)
    def wrapped():
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

        def release_lock():
            # Release exactly once and stop the lock-extension job even when releasing fails.
            try:
                task_lock.release()
            finally:
                scheduler.shutdown()

        try:
            logger.debug(f'Task RUN lock of {task_lock_params.redis_key} locked.')
            result = run_func()
        except BaseException:
            logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released with BaseException.')
            release_lock()
            raise
        release_lock()
        logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released.')
        return result

    return wrapped
108 |
--------------------------------------------------------------------------------
/gokart/errors/__init__.py:
--------------------------------------------------------------------------------
1 | from gokart.build import GokartBuildError, HasLockedTaskException
2 | from gokart.pandas_type_config import PandasTypeError
3 | from gokart.task import EmptyDumpError
4 |
5 | __all__ = [
6 | 'GokartBuildError',
7 | 'HasLockedTaskException',
8 | 'PandasTypeError',
9 | 'EmptyDumpError',
10 | ]
11 |
--------------------------------------------------------------------------------
/gokart/gcs_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 |
6 | import luigi
7 | import luigi.contrib.gcs
8 | from google.oauth2.service_account import Credentials
9 |
10 |
class GCSConfig(luigi.Config):
    """Configuration that builds a GCS client and caches it on the instance."""

    gcs_credential_name: str = luigi.Parameter(default='GCS_CREDENTIAL', description='GCS credential environment variable.')
    _client = None

    def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
        """Return the cached client, creating it on first use."""
        if self._client is not None:  # use cache as like singleton object
            return self._client
        self._client = self._get_gcs_client()
        return self._client

    def _get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
        """Create a new client from the configured credentials."""
        credentials = self._load_oauth_credentials()
        return luigi.contrib.gcs.GCSClient(oauth_credentials=credentials)

    def _load_oauth_credentials(self) -> Credentials | None:
        """Read credentials from the environment variable named by ``gcs_credential_name``.

        The value is treated as a path to a service account file when such a
        file exists, and as JSON text otherwise. Returns ``None`` when the
        variable is unset or empty.
        """
        credential_source = os.environ.get(self.gcs_credential_name)
        if not credential_source:
            return None

        if not os.path.isfile(credential_source):
            return Credentials.from_service_account_info(json.loads(credential_source))

        return Credentials.from_service_account_file(credential_source)
32 |
--------------------------------------------------------------------------------
/gokart/gcs_zip_client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import shutil
5 |
6 | from gokart.gcs_config import GCSConfig
7 | from gokart.zip_client import ZipClient, _unzip_file
8 |
9 |
class GCSZipClient(ZipClient):
    """``ZipClient`` that stores the archive of a temporary directory on GCS."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory
        self._client = GCSConfig().get_gcs_client()

    def exists(self) -> bool:
        """Return whether the remote archive exists."""
        return self._client.exists(self._file_path)

    def make_archive(self) -> None:
        """Archive the temporary directory and upload the archive to ``file_path``."""
        extension = os.path.splitext(self._file_path)[1]
        # Use the same base name as ``_temporary_file_path`` so that the file
        # that is uploaded is the file that was just written, even when the
        # temporary directory ends with a slash.
        shutil.make_archive(base_name=self._archive_base_name(), format=extension[1:], root_dir=self._temporary_directory)
        self._client.put(self._temporary_file_path(), self._file_path)

    def unpack_archive(self) -> None:
        """Download the remote archive and extract it into the temporary directory."""
        os.makedirs(self._temporary_directory, exist_ok=True)
        file_pointer = self._client.download(self._file_path)
        _unzip_file(fp=file_pointer, extract_dir=self._temporary_directory)

    def remove(self) -> None:
        """Remove the remote archive."""
        self._client.remove(self._file_path)

    @property
    def path(self) -> str:
        return self._file_path

    def _archive_base_name(self) -> str:
        """Return the temporary directory path without one trailing slash."""
        base_name = self._temporary_directory
        if base_name.endswith('/'):
            base_name = base_name[:-1]
        return base_name

    def _temporary_file_path(self):
        """Return the local path of the archive written by ``make_archive``."""
        extension = os.path.splitext(self._file_path)[1]
        return self._archive_base_name() + extension
42 |
--------------------------------------------------------------------------------
/gokart/in_memory/__init__.py:
--------------------------------------------------------------------------------
1 | from .repository import InMemoryCacheRepository # noqa:F401
2 | from .target import InMemoryTarget, make_in_memory_target # noqa:F401
3 |
--------------------------------------------------------------------------------
/gokart/in_memory/data.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from datetime import datetime
5 | from typing import Any
6 |
7 |
@dataclass
class InMemoryData:
    """A cached value together with the time it was stored."""

    value: Any
    last_modification_time: datetime

    @classmethod
    def create_data(cls, value: Any) -> InMemoryData:
        """Wrap ``value`` in an instance of ``cls`` stamped with the current local time."""
        return cls(value=value, last_modification_time=datetime.now())
16 |
--------------------------------------------------------------------------------
/gokart/in_memory/repository.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Iterator
4 | from typing import Any
5 |
6 | from .data import InMemoryData
7 |
8 |
class InMemoryCacheRepository:
    """Key-value store kept in a class-level dict.

    The dict is a class attribute that is only mutated in place, so every
    instance reads and writes the same entries.
    """

    _cache: dict[str, InMemoryData] = {}

    def __init__(self):
        pass

    def _get_data(self, key: str) -> InMemoryData:
        """Return the stored entry for ``key``; raises ``KeyError`` when absent."""
        return self._cache[key]

    def get_value(self, key: str) -> Any:
        """Return the value stored under ``key``."""
        entry = self._get_data(key)
        return entry.value

    def get_last_modification_time(self, key: str):
        """Return the time recorded when ``key`` was last stored."""
        entry = self._get_data(key)
        return entry.last_modification_time

    def set_value(self, key: str, obj: Any) -> None:
        """Store ``obj`` under ``key``, replacing any previous entry."""
        self._cache[key] = InMemoryData.create_data(obj)

    def has(self, key: str) -> bool:
        """Return whether ``key`` is stored."""
        return key in self._cache

    def remove(self, key: str) -> None:
        """Delete ``key``; asserts that it exists."""
        assert self.has(key), f'{key} does not exist.'
        self._cache.pop(key)

    def empty(self) -> bool:
        """Return whether nothing is stored."""
        return len(self._cache) == 0

    def clear(self) -> None:
        """Delete every entry."""
        self._cache.clear()

    def get_gen(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(key, value)`` pairs for every stored entry."""
        yield from ((key, entry.value) for key, entry in self._cache.items())

    @property
    def size(self) -> int:
        """Number of stored entries."""
        return len(self._cache)
48 |
--------------------------------------------------------------------------------
/gokart/in_memory/target.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from datetime import datetime
4 | from typing import Any
5 |
6 | from gokart.in_memory.repository import InMemoryCacheRepository
7 | from gokart.required_task_output import RequiredTaskOutput
8 | from gokart.target import TargetOnKart, TaskLockParams
9 | from gokart.utils import FlattenableItems
10 |
11 | _repository = InMemoryCacheRepository()
12 |
13 |
class InMemoryTarget(TargetOnKart):
    """``TargetOnKart`` whose data lives in the module-level in-memory repository."""

    def __init__(self, data_key: str, task_lock_param: TaskLockParams):
        """Bind the target to ``data_key``.

        Raises:
            ValueError: if ``task_lock_param`` requests locking.
        """
        if task_lock_param.should_task_lock:
            raise ValueError('Redis with `InMemoryTarget` is not currently supported.')

        self._task_lock_params = task_lock_param
        self._data_key = data_key

    def _exists(self) -> bool:
        return _repository.has(self._data_key)

    def _get_task_lock_params(self) -> TaskLockParams:
        return self._task_lock_params

    def _load(self) -> Any:
        return _repository.get_value(self._data_key)

    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        # Only ``obj`` is stored; the remaining arguments are not used here.
        _repository.set_value(self._data_key, obj)

    def _remove(self) -> None:
        _repository.remove(self._data_key)

    def _last_modification_time(self) -> datetime:
        if _repository.has(self._data_key):
            return _repository.get_last_modification_time(self._data_key)
        raise ValueError(f'No object(s) which id is {self._data_key} are stored before.')

    def _path(self) -> str:
        # TODO: the name `_path` might not be appropriate for an in-memory key
        return self._data_key
52 |
53 |
def make_in_memory_target(target_key: str, task_lock_params: TaskLockParams) -> InMemoryTarget:
    """Create an ``InMemoryTarget`` stored under ``target_key``."""
    target = InMemoryTarget(target_key, task_lock_params)
    return target
56 |
--------------------------------------------------------------------------------
/gokart/info.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from logging import getLogger
4 |
5 | import luigi
6 |
7 | from gokart.task import TaskOnKart
8 | from gokart.tree.task_info import make_task_info_as_tree_str
9 |
10 | logger = getLogger(__name__)
11 |
12 |
def make_tree_info(
    task: TaskOnKart,
    indent: str = '',
    last: bool = True,
    details: bool = False,
    abbr: bool = True,
    visited_tasks: set[str] | None = None,
    ignore_task_names: list[str] | None = None,
) -> str:
    """
    Return a string representation of the tasks, their statuses/parameters in a dependency tree format

    This function has moved to `gokart.tree.task_info.make_task_info_as_tree_str`.
    This wrapper remains for backward compatibility. `indent`, `last` and
    `visited_tasks` are accepted but not used.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - details: bool
        Whether or not to output details.
    - abbr: bool
        Whether or not to simplify tasks information that has already appeared.
    - ignore_task_names: list[str] | None
        List of task names to ignore.
    Returns
    -------
    - tree_info : str
        Formatted task dependency tree.
    """
    forwarded_options = {
        'task': task,
        'details': details,
        'abbr': abbr,
        'ignore_task_names': ignore_task_names,
    }
    return make_task_info_as_tree_str(**forwarded_options)
44 |
45 |
class tree_info(TaskOnKart):
    """Task whose output target is ``output_path``, created without a unique id."""

    mode: str = luigi.Parameter(default='', description='This must be in ["simple", "all"].')
    output_path: str = luigi.Parameter(default='tree.txt', description='Output file path.')

    def output(self):
        target = self.make_target(self.output_path, use_unique_id=False)
        return target
52 |
--------------------------------------------------------------------------------
/gokart/object_storage.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from datetime import datetime
4 |
5 | import luigi
6 | import luigi.contrib.gcs
7 | import luigi.contrib.s3
8 | from luigi.format import Format
9 |
10 | from gokart.gcs_config import GCSConfig
11 | from gokart.gcs_zip_client import GCSZipClient
12 | from gokart.s3_config import S3Config
13 | from gokart.s3_zip_client import S3ZipClient
14 | from gokart.zip_client import ZipClient
15 |
object_storage_path_prefix = ['s3://', 'gs://']


class ObjectStorage:
    """Helpers that dispatch on the ``s3://`` / ``gs://`` prefix of a path."""

    @staticmethod
    def _unsupported_path_error(path: str) -> ValueError:
        """Build the error raised for paths that are neither S3 nor GCS paths."""
        return ValueError(f'Unsupported object storage path: {path!r}. Expected a path starting with one of {object_storage_path_prefix}.')

    @staticmethod
    def if_object_storage_path(path: str) -> bool:
        """Return whether ``path`` starts with a known object storage prefix."""
        return path.startswith(tuple(object_storage_path_prefix))

    @staticmethod
    def get_object_storage_target(path: str, format: Format) -> luigi.target.FileSystemTarget:
        """Return the luigi target for ``path``.

        Raises:
            ValueError: if ``path`` is neither an S3 nor a GCS path.
        """
        if path.startswith('s3://'):
            return luigi.contrib.s3.S3Target(path, client=S3Config().get_s3_client(), format=format)
        elif path.startswith('gs://'):
            return luigi.contrib.gcs.GCSTarget(path, client=GCSConfig().get_gcs_client(), format=format)
        else:
            raise ObjectStorage._unsupported_path_error(path)

    @staticmethod
    def exists(path: str) -> bool:
        """Return whether the object at ``path`` exists.

        Raises:
            ValueError: if ``path`` is neither an S3 nor a GCS path.
        """
        if path.startswith('s3://'):
            return S3Config().get_s3_client().exists(path)
        elif path.startswith('gs://'):
            return GCSConfig().get_gcs_client().exists(path)
        else:
            raise ObjectStorage._unsupported_path_error(path)

    @staticmethod
    def get_timestamp(path: str) -> datetime:
        """Return the last modification timestamp of the object at ``path``.

        Raises:
            ValueError: if ``path`` is neither an S3 nor a GCS path.
        """
        if path.startswith('s3://'):
            return S3Config().get_s3_client().get_key(path).last_modified
        elif path.startswith('gs://'):
            # for gcs object
            # should PR to luigi
            gcs_client = GCSConfig().get_gcs_client()
            bucket, obj = gcs_client._path_to_bucket_and_key(path)
            result = gcs_client.client.objects().get(bucket=bucket, object=obj).execute()
            # NOTE(review): 'updated' is returned as-is from the API response;
            # confirm that it matches the annotated ``datetime`` return type.
            return result['updated']
        else:
            raise ObjectStorage._unsupported_path_error(path)

    @staticmethod
    def get_zip_client(file_path: str, temporary_directory: str) -> ZipClient:
        """Return the zip client for ``file_path``.

        Raises:
            ValueError: if ``file_path`` is neither an S3 nor a GCS path.
        """
        if file_path.startswith('s3://'):
            return S3ZipClient(file_path=file_path, temporary_directory=temporary_directory)
        elif file_path.startswith('gs://'):
            return GCSZipClient(file_path=file_path, temporary_directory=temporary_directory)
        else:
            raise ObjectStorage._unsupported_path_error(file_path)

    @staticmethod
    def is_buffered_reader(file: object):
        """Return ``True`` unless ``file`` is a ``luigi.contrib.s3.ReadableS3File``."""
        return not isinstance(file, luigi.contrib.s3.ReadableS3File)
70 |
--------------------------------------------------------------------------------
/gokart/pandas_type_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import abstractmethod
4 | from logging import getLogger
5 | from typing import Any
6 |
7 | import luigi
8 | import numpy as np
9 | import pandas as pd
10 | from luigi.task_register import Register
11 |
12 | logger = getLogger(__name__)
13 |
14 |
class PandasTypeError(Exception):
    """Signals that a pandas DataFrame column holds a value of an unexpected type."""
17 |
18 |
class PandasTypeConfig(luigi.Config):
    """Base class for declaring the expected Python type of DataFrame columns."""

    @classmethod
    @abstractmethod
    def type_dict(cls) -> dict[str, Any]:
        """Return a mapping from column name to the type its values must have."""
        pass

    @classmethod
    def check(cls, df: pd.DataFrame):
        """Validate every column listed in ``type_dict`` against ``df``."""
        for name, expected in cls.type_dict().items():
            cls._check_column(df, name, expected)

    @classmethod
    def _check_column(cls, df, column_name, column_type):
        """Raise ``PandasTypeError`` for the first value in the column that is not a ``column_type``.

        Columns that are absent from ``df`` are skipped.
        """
        if column_name in df.columns:
            mismatches = [value for value in df[column_name] if not isinstance(value, column_type)]
            if mismatches:
                not_match = mismatches[0]
                raise PandasTypeError(f'expected type is "{column_type}", but "{type(not_match)}" is passed in column "{column_name}".')
38 |
39 |
class PandasTypeConfigMap(luigi.Config):
    """Lookup from task namespace to its ``PandasTypeConfig`` subclass.

    To initialize this class only once, this inherits luigi.Config.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        registered = [Register.get_task_cls(name) for name in Register.task_names()]
        self._map = {}
        for candidate in registered:
            # Keep only real subclasses; the base class itself declares no types.
            if issubclass(candidate, PandasTypeConfig) and candidate != PandasTypeConfig:
                self._map[candidate.task_namespace] = candidate

    def check(self, obj, task_namespace: str):
        """Run the namespace's type check on ``obj`` when it is a DataFrame and a config is registered."""
        if not isinstance(obj, pd.DataFrame):
            return
        if task_namespace not in self._map:
            return
        self._map[task_namespace].check(obj)
54 |
--------------------------------------------------------------------------------
/gokart/parameter.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import bz2
4 | import datetime
5 | import json
6 | from logging import getLogger
7 | from typing import Generic, Protocol, TypeVar
8 | from warnings import warn
9 |
10 | import luigi
11 | from luigi import task_register
12 |
13 | import gokart
14 |
15 | logger = getLogger(__name__)
16 |
17 |
class TaskInstanceParameter(luigi.Parameter):
    """Parameter whose value is a task instance, serialized as its family plus bz2-compressed params."""

    def __init__(self, expected_type=None, *args, **kwargs):
        if expected_type is not None and not isinstance(expected_type, type):
            raise TypeError(f'expected_type must be a type, not {type(expected_type)}')
        # Default to the gokart base task when no explicit type is requested.
        self.expected_type: type = gokart.TaskOnKart if expected_type is None else expected_type
        super().__init__(*args, **kwargs)

    @staticmethod
    def _recursive(param_dict):
        """Instantiate the task described by ``param_dict`` after parsing each known parameter."""
        raw_params = param_dict['params']
        task_cls = task_register.Register.get_task_cls(param_dict['type'])
        for name, param_obj in task_cls.get_params():
            if name in raw_params:
                raw_params[name] = param_obj.parse(raw_params[name])
        return task_cls(**raw_params)

    @staticmethod
    def _recursive_decompress(s):
        """Parse ``s`` as a dict and expand its hex-encoded, bz2-compressed 'params' entry recursively."""
        decoded = dict(luigi.DictParameter().parse(s))
        if 'params' in decoded:
            compressed = bytes.fromhex(decoded['params'])
            inner_text = bz2.decompress(compressed).decode()
            decoded['params'] = TaskInstanceParameter._recursive_decompress(inner_text)
        return decoded

    def parse(self, s):
        """Build a task instance from its serialized string (or already-decoded dict) form."""
        description = self._recursive_decompress(s) if isinstance(s, str) else s
        return self._recursive(description)

    def serialize(self, x):
        """Serialize task ``x`` as a dict of its task family and compressed significant params."""
        params_json = json.dumps(x.to_str_params(only_significant=True))
        params = bz2.compress(params_json.encode()).hex()
        values = dict(type=x.get_task_family(), params=params)
        return luigi.DictParameter().serialize(values)

    def _warn_on_wrong_param_type(self, param_name, param_value):
        """Raise ``TypeError`` when the value is not an instance of ``expected_type``."""
        if isinstance(param_value, self.expected_type):
            return
        raise TypeError(f'{param_value} is not an instance of {self.expected_type}')
57 |
58 |
class _TaskInstanceEncoder(json.JSONEncoder):
    """JSON encoder that writes luigi tasks using ``TaskInstanceParameter.serialize``."""

    def default(self, obj):
        if not isinstance(obj, luigi.Task):
            # Let the base class default method raise the TypeError
            return super().default(obj)
        return TaskInstanceParameter().serialize(obj)
65 |
66 |
class ListTaskInstanceParameter(luigi.Parameter):
    """Parameter whose value is a list of task instances, stored as a JSON array."""

    def __init__(self, expected_elements_type=None, *args, **kwargs):
        if expected_elements_type is not None and not isinstance(expected_elements_type, type):
            raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}')
        # Default to the gokart base task when no explicit element type is requested.
        self.expected_elements_type: type = gokart.TaskOnKart if expected_elements_type is None else expected_elements_type
        super().__init__(*args, **kwargs)

    def parse(self, s):
        """Decode the JSON array ``s`` and parse each element as a task instance."""
        tasks = []
        for element in list(json.loads(s)):
            tasks.append(TaskInstanceParameter().parse(element))
        return tasks

    def serialize(self, x):
        """Dump ``x`` to JSON, encoding tasks through ``_TaskInstanceEncoder``."""
        return json.dumps(x, cls=_TaskInstanceEncoder)

    def _warn_on_wrong_param_type(self, param_name, param_value):
        """Raise ``TypeError`` for the first element that is not an ``expected_elements_type``."""
        for element in param_value:
            if isinstance(element, self.expected_elements_type):
                continue
            raise TypeError(f'{element} is not an instance of {self.expected_elements_type}')
87 |
88 |
class ExplicitBoolParameter(luigi.BoolParameter):
    """Bool parameter that uses ``luigi.Parameter``'s initializer and parser kwargs.

    Both methods deliberately bypass the ``luigi.BoolParameter`` overrides and
    delegate to the base ``luigi.Parameter`` implementation.
    NOTE(review): presumably this makes the flag require an explicit value on
    the command line -- confirm against luigi's BoolParameter behavior.
    """

    def __init__(self, *args, **kwargs):
        luigi.Parameter.__init__(self, *args, **kwargs)

    def _parser_kwargs(self, *args, **kwargs):  # type: ignore
        # Forward keyword arguments as keywords; unpacking them with a single
        # star would pass the keyword *names* as positional values.
        return luigi.Parameter._parser_kwargs(*args, **kwargs)
95 |
96 |
# Unconstrained type variable; ties the `cls` argument of `gokart_deserialize` to its return type.
T = TypeVar('T')
98 |
99 |
class Serializable(Protocol):
    """Structural interface for objects usable as a ``SerializableParameter`` value."""

    def gokart_serialize(self) -> str:
        """Implement this method to serialize the object as a parameter.

        You can omit some fields from the serialized result if you want changes to them to be ignored.
        """
        ...

    @classmethod
    def gokart_deserialize(cls: type[T], s: str) -> T:
        """Implement this method to rebuild the object from its serialized string."""
        ...
111 |
112 |
# Type variable restricted to classes that satisfy the `Serializable` protocol.
S = TypeVar('S', bound=Serializable)
114 |
115 |
class SerializableParameter(luigi.Parameter, Generic[S]):
    """Parameter that stores an object through its ``gokart_serialize`` / ``gokart_deserialize`` pair."""

    def __init__(self, object_type: type[S], *args, **kwargs):
        # Remember the concrete class so `parse` knows which deserializer to call.
        self._object_type = object_type
        super().__init__(*args, **kwargs)

    def parse(self, s: str) -> S:
        """Rebuild the object from ``s`` using the configured class."""
        deserialize = self._object_type.gokart_deserialize
        return deserialize(s)

    def serialize(self, x: S) -> str:
        """Return the string form produced by the object itself."""
        serialized = x.gokart_serialize()
        return serialized
126 |
127 |
class ZonedDateSecondParameter(luigi.Parameter):
    """
    ZonedDateSecondParameter supports a datetime.datetime object with timezone information.

    A ZonedDateSecondParameter is an ISO 8601 formatted date and time, specified to the second,
    with a timezone. For example, ``2013-07-10T19:07:38+09:00`` specifies July 10, 2013 at
    19:07:38 +09:00. The separator `:` can be omitted for Python 3.11 and later.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def parse(self, s):
        """Parse an ISO 8601 string into a datetime, warning when it carries no timezone."""
        # A trailing 'Z' is rewritten to '+00:00' because only Python 3.11 and later
        # accept 'Z' at the end of the string passed to fromisoformat.
        iso_text = s[:-1] + '+00:00' if s.endswith('Z') else s
        parsed = datetime.datetime.fromisoformat(iso_text)
        if parsed.tzinfo is None:
            warn('The input does not have timezone information. Please consider using luigi.DateSecondParameter instead.', stacklevel=1)
        return parsed

    def serialize(self, dt):
        """Return the ISO 8601 representation of ``dt``."""
        return dt.isoformat()

    def normalize(self, dt):
        """Drop the microsecond component of ``dt`` and change nothing else."""
        # Normalization only removes the microsecond, because the number of
        # microsecond digits in the serialized form is not fixed.
        # See also luigi's implementation https://github.com/spotify/luigi/blob/v3.6.0/luigi/parameter.py#L612
        return dt.replace(microsecond=0)
158 |
--------------------------------------------------------------------------------
/gokart/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/gokart/py.typed
--------------------------------------------------------------------------------
/gokart/required_task_output.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
@dataclass
class RequiredTaskOutput:
    """Pairs a task name with the path of its output."""

    task_name: str
    output_path: str

    def serialize(self) -> dict[str, str]:
        """Return both fields as a dict keyed with the ``__gokart_`` prefix."""
        serialized = {}
        serialized['__gokart_task_name'] = self.task_name
        serialized['__gokart_output_path'] = self.output_path
        return serialized
11 |
--------------------------------------------------------------------------------
/gokart/run.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import sys
5 | from logging import getLogger
6 |
7 | import luigi
8 | import luigi.cmdline
9 | import luigi.retcodes
10 | from luigi.cmdline_parser import CmdlineParser
11 |
12 | import gokart
13 | import gokart.slack
14 | from gokart.object_storage import ObjectStorage
15 |
16 | logger = getLogger(__name__)
17 |
18 |
def _run_tree_info(cmdline_args, details):
    """Dump the dependency tree of the task named on the command line to the tree_info output."""
    with CmdlineParser.global_instance(cmdline_args) as parser:
        # Resolve the dump target first, then build the tree text from the root task.
        dump = gokart.tree_info().output().dump
        dump(gokart.make_tree_info(parser.get_task_obj(), details=details))
22 |
23 |
def _try_tree_info(cmdline_args):
    """Write the tree info and exit the process when a tree-info mode is set.

    Returns without doing anything when the mode is empty; raises ``ValueError``
    for any mode other than "simple" or "all".
    """
    with CmdlineParser.global_instance(cmdline_args):
        mode = gokart.tree_info().mode
        output_path = gokart.tree_info().output().path()

    # do nothing if `mode` is empty.
    if mode == '':
        return

    # output tree info and exit.
    details_by_mode = {'simple': False, 'all': True}
    if mode not in details_by_mode:
        raise ValueError(f'--tree-info-mode must be "simple" or "all", but "{mode}" is passed.')
    _run_tree_info(cmdline_args, details=details_by_mode[mode])
    logger.info(f'output tree info: {output_path}')
    sys.exit()
42 |
43 |
def _try_to_delete_unnecessary_output_file(cmdline_args: list[str]):
    """Delete unnecessary local outputs and exit when the task requests it.

    For S3/GCS workspaces nothing is deleted; a message is logged instead.
    """
    with CmdlineParser.global_instance(cmdline_args) as parser:
        task = parser.get_task_obj()  # type: gokart.TaskOnKart
        if not task.delete_unnecessary_output_files:
            return
        is_remote_workspace = ObjectStorage.if_object_storage_path(task.workspace_directory)
        if is_remote_workspace:
            logger.info('delete-unnecessary-output-files is not support s3/gcs.')
        else:
            gokart.delete_local_unnecessary_outputs(task)
        sys.exit()
53 |
54 |
def _try_get_slack_api(cmdline_args: list[str]) -> gokart.slack.SlackAPI | None:
    """Return a ``SlackAPI`` when both a token and a channel are configured, otherwise ``None``."""
    with CmdlineParser.global_instance(cmdline_args):
        slack_config = gokart.slack.SlackConfig()
        # The token itself is read from the environment variable named in the config.
        token = os.getenv(slack_config.token_name, '')
        channel, to_user = slack_config.channel, slack_config.to_user
        if token and channel:
            logger.info('Slack notification is activated.')
            return gokart.slack.SlackAPI(token=token, channel=channel, to_user=to_user)
    logger.info('Slack notification is not activated.')
    return None
66 |
67 |
def _try_to_send_event_summary_to_slack(slack_api: gokart.slack.SlackAPI | None, event_aggregator: gokart.slack.EventAggregator, cmdline_args: list[str]):
    """Send the aggregated success/failure report as a Slack snippet, if Slack is enabled."""
    if slack_api is None:
        # do nothing
        return

    options = gokart.slack.SlackConfig()
    with CmdlineParser.global_instance(cmdline_args) as parser:
        task = parser.get_task_obj()
        if options.send_tree_info:
            tree_info = gokart.make_tree_info(task, details=True)
        else:
            tree_info = 'Please add SlackConfig.send_tree_info to include tree-info'
        task_name = type(task).__name__

    comment = f'Report of {task_name}' + os.linesep + event_aggregator.get_summary()
    content_parts = ['===== Event List ====', event_aggregator.get_event_list(), os.linesep, '==== Tree Info ====', tree_info]
    content = os.linesep.join(content_parts)
    slack_api.send_snippet(comment=comment, title='event.txt', content=content)
81 |
82 |
def run(cmdline_args=None, set_retcode=True):
    """Run luigi with gokart's extra command-line behaviors.

    Parameters
    ----------
    - cmdline_args: list[str] | None
        Command-line arguments; ``sys.argv[1:]`` is used when empty or None.
    - set_retcode: bool
        Whether to overwrite luigi's return codes with gokart's values.
    """
    cmdline_args = cmdline_args or sys.argv[1:]

    if set_retcode:
        exit_codes = {
            'already_running': 10,
            'missing_data': 20,
            'not_run': 30,
            'task_failed': 40,
            'scheduling_error': 50,
        }
        for attribute, code in exit_codes.items():
            setattr(luigi.retcodes.retcode, attribute, code)

    # Each of these may terminate the process itself via sys.exit().
    _try_tree_info(cmdline_args)
    _try_to_delete_unnecessary_output_file(cmdline_args)
    gokart.testing.try_to_run_test_for_empty_data_frame(cmdline_args)

    slack_api = _try_get_slack_api(cmdline_args)
    event_aggregator = gokart.slack.EventAggregator()
    try:
        event_aggregator.set_handlers()
        luigi.cmdline.luigi_run(cmdline_args)
    except SystemExit as exit_signal:
        # Report to Slack before propagating the original exit code.
        _try_to_send_event_summary_to_slack(slack_api, event_aggregator, cmdline_args)
        sys.exit(exit_signal.code)
105 |
--------------------------------------------------------------------------------
/gokart/s3_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 |
5 | import luigi
6 | import luigi.contrib.s3
7 |
8 |
class S3Config(luigi.Config):
    """Configuration naming the environment variables that hold the AWS credentials."""

    aws_access_key_id_name = luigi.Parameter(default='AWS_ACCESS_KEY_ID', description='AWS access key id environment variable.')
    aws_secret_access_key_name = luigi.Parameter(default='AWS_SECRET_ACCESS_KEY', description='AWS secret access key environment variable.')

    _client = None

    def get_s3_client(self) -> luigi.contrib.s3.S3Client:
        """Return the S3 client, creating it on first use and reusing it afterwards."""
        cached = self._client
        if cached is None:  # use cache as like singleton object
            cached = self._client = self._get_s3_client()
        return cached

    def _get_s3_client(self) -> luigi.contrib.s3.S3Client:
        """Create an S3 client from the credentials found in the configured environment variables."""
        access_key_id = os.environ.get(self.aws_access_key_id_name)
        secret_access_key = os.environ.get(self.aws_secret_access_key_name)
        return luigi.contrib.s3.S3Client(aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key)
24 |
--------------------------------------------------------------------------------
/gokart/s3_zip_client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import shutil
5 |
6 | from gokart.s3_config import S3Config
7 | from gokart.zip_client import ZipClient, _unzip_file
8 |
9 |
class S3ZipClient(ZipClient):
    """Zip client that archives a local temporary directory to, and restores it from, an S3 path."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory
        self._client = S3Config().get_s3_client()

    def exists(self) -> bool:
        """Return whether the remote archive exists."""
        return self._client.exists(self._file_path)

    def make_archive(self) -> None:
        """Archive the temporary directory and upload the result to the remote path."""
        if not os.path.exists(self._temporary_directory):
            # Check path existence since shutil.make_archive() of python 3.10+ does not check it.
            raise FileNotFoundError(f'Temporary directory {self._temporary_directory} is not found.')
        # The archive format is the remote file's extension without the leading dot.
        archive_format = os.path.splitext(self._file_path)[1][1:]
        shutil.make_archive(base_name=self._temporary_directory, format=archive_format, root_dir=self._temporary_directory)
        self._client.put(self._temporary_file_path(), self._file_path)

    def unpack_archive(self) -> None:
        """Download the remote archive and extract it into the temporary directory."""
        os.makedirs(self._temporary_directory, exist_ok=True)
        local_archive = self._temporary_file_path()
        self._client.get(self._file_path, local_archive)
        _unzip_file(fp=self._temporary_file_path(), extract_dir=self._temporary_directory)

    def remove(self) -> None:
        """Delete the remote archive."""
        self._client.remove(self._file_path)

    @property
    def path(self) -> str:
        """The remote path of the archive."""
        return self._file_path

    def _temporary_file_path(self):
        """Return the local archive path: the temporary directory (minus one trailing '/') plus the remote extension."""
        directory = self._temporary_directory
        stem = directory[:-1] if directory.endswith('/') else directory
        return stem + os.path.splitext(self._file_path)[1]
45 |
--------------------------------------------------------------------------------
/gokart/slack/__init__.py:
--------------------------------------------------------------------------------
1 | from gokart.slack.event_aggregator import EventAggregator
2 | from gokart.slack.slack_api import SlackAPI
3 | from gokart.slack.slack_config import SlackConfig
4 |
5 | from .slack_api import ChannelListNotLoadedError, ChannelNotFoundError, FileNotUploadedError
6 |
# Names exported by this package: the three Slack error types plus the aggregator, API and config classes.
__all__ = [
    'ChannelListNotLoadedError',
    'ChannelNotFoundError',
    'FileNotUploadedError',
    'EventAggregator',
    'SlackAPI',
    'SlackConfig',
]
15 |
--------------------------------------------------------------------------------
/gokart/slack/event_aggregator.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from logging import getLogger
5 | from typing import TypedDict
6 |
7 | import luigi
8 |
9 | logger = getLogger(__name__)
10 |
11 |
class FailureEvent(TypedDict):
    """Record of one failed task: its display string and the stringified exception."""

    task: str
    exception: str
15 |
16 |
class EventAggregator:
    """Collects luigi SUCCESS and FAILURE events so they can be reported as text."""

    def __init__(self) -> None:
        self._success_events: list[str] = []
        self._failure_events: list[FailureEvent] = []

    def set_handlers(self):
        """Register this aggregator's callbacks for luigi's SUCCESS and FAILURE events."""
        handlers = [(luigi.Event.SUCCESS, self._success), (luigi.Event.FAILURE, self._failure)]
        for event, handler in handlers:
            luigi.Task.event_handler(event)(handler)

    def get_summary(self) -> str:
        """Return the number of recorded successes and failures on one line."""
        return f'Success: {len(self._success_events)}; Failure: {len(self._failure_events)}'

    def get_event_list(self) -> str:
        """Return the recorded events, failures first, one per line under a section header.

        The failure and success sections are separated by a line break.
        Returns 'Tasks were not executed.' when nothing was recorded.
        """
        sections = []
        if self._failure_events:
            failure_lines = [f'Task: {failure["task"]}; Exception: {failure["exception"]}' for failure in self._failure_events]
            sections.append(os.linesep.join(['---- Failure Tasks ----', *failure_lines]))
        if self._success_events:
            sections.append(os.linesep.join(['---- Success Tasks ----', *self._success_events]))
        if not sections:
            return 'Tasks were not executed.'
        # Join with a separator so the success header starts on its own line.
        return os.linesep.join(sections)

    def _success(self, task):
        """Record a successful task."""
        self._success_events.append(self._task_to_str(task))

    def _failure(self, task, exception):
        """Record a failed task together with the text of its exception."""
        failure: FailureEvent = {'task': self._task_to_str(task), 'exception': str(exception)}
        self._failure_events.append(failure)

    @staticmethod
    def _task_to_str(task) -> str:
        """Format a task as ``ClassName:[unique_id]``."""
        return f'{type(task).__name__}:[{task.make_unique_id()}]'
52 |
--------------------------------------------------------------------------------
/gokart/slack/slack_api.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from logging import getLogger
4 |
5 | import slack_sdk
6 |
7 | logger = getLogger(__name__)
8 |
9 |
class ChannelListNotLoadedError(RuntimeError):
    """Raised when a page of the Slack channel list comes back empty."""
12 |
13 |
class ChannelNotFoundError(RuntimeError):
    """Raised when the requested Slack channel name is not in the channel list."""
16 |
17 |
class FileNotUploadedError(RuntimeError):
    """Raised when Slack answers a file upload with a not-ok response."""
20 |
21 |
class SlackAPI:
    """Thin wrapper around ``slack_sdk.WebClient`` for posting a text snippet to one channel."""

    def __init__(self, token, channel: str, to_user: str) -> None:
        self._client = slack_sdk.WebClient(token=token)
        self._channel_id = self._get_channel_id(channel)
        # Make sure a non-empty user name carries the leading '@'.
        needs_prefix = to_user != '' and not to_user.startswith('@')
        self._to_user = '@' + to_user if needs_prefix else to_user

    def _get_channel_id(self, channel_name):
        """Return the id of ``channel_name``, or ``None`` (after a warning) when the lookup fails."""
        params = {'exclude_archived': True, 'limit': 100}
        try:
            for page in self._client.conversations_list(params=params):
                if not page:
                    raise ChannelListNotLoadedError('Channel list is empty.')
                for entry in page.get('channels', []):
                    if entry['name'] == channel_name:
                        return entry['id']
            raise ChannelNotFoundError(f'Channel {channel_name} is not found in public channels.')
        except Exception as e:
            # Best effort: any failure only disables notification, it never stops the job.
            logger.warning(f'The job will start without slack notification: {e}')

    def send_snippet(self, comment, title, content):
        """Upload ``content`` as a snippet with ``comment``; failures are logged, not raised."""
        try:
            initial_comment = f'<{self._to_user}> {comment}' if self._to_user else comment
            request_body = {
                'channels': self._channel_id,
                'initial_comment': initial_comment,
                'content': content,
                'title': title,
            }
            response = self._client.api_call('files.upload', data=request_body)
            if not response['ok']:
                raise FileNotUploadedError(f'Error while uploading file. The error reason is "{response["error"]}".')
        except Exception as e:
            logger.warning(f'Failed to send slack notification: {e}')
51 |
--------------------------------------------------------------------------------
/gokart/slack/slack_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import luigi
4 |
5 |
class SlackConfig(luigi.Config):
    """Settings for Slack notification: token variable name, channel, mention target and tree-info switch."""

    token_name = luigi.Parameter(default='SLACK_TOKEN', description='slack token environment variable.')
    channel = luigi.Parameter(default='', significant=False, description='channel name for notification.')
    to_user = luigi.Parameter(default='', significant=False, description='Optional; user name who is supposed to be mentioned.')
    send_tree_info = luigi.BoolParameter(
        default=False,
        significant=False,
        # The trailing space keeps the two sentences apart once the literals are concatenated.
        description='When this option is true, the dependency tree of tasks is included in send message. '
        'It is recommended to set false to this option when notification takes long time.',
    )
16 |
--------------------------------------------------------------------------------
/gokart/task_complete_check.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from logging import getLogger
5 | from typing import Callable
6 |
7 | logger = getLogger(__name__)
8 |
9 |
def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable):
    """Wrap ``run_func`` so it is skipped, with a warning, when ``complete_check_func()`` is truthy.

    The check is evaluated on every call of the returned wrapper; a skipped call returns ``None``.
    """

    @functools.wraps(run_func)
    def wrapper(*args, **kwargs):
        if not complete_check_func():
            return run_func(*args, **kwargs)
        logger.warning(f'{run_func.__name__} is skipped because the task is already completed.')
        return None

    return wrapper
19 |
--------------------------------------------------------------------------------
/gokart/testing/__init__.py:
--------------------------------------------------------------------------------
1 | from gokart.testing.check_if_run_with_empty_data_frame import test_run, try_to_run_test_for_empty_data_frame # noqa:F401
2 | from gokart.testing.pandas_assert import assert_frame_contents_equal # noqa:F401
3 |
--------------------------------------------------------------------------------
/gokart/testing/check_if_run_with_empty_data_frame.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import sys
5 |
6 | import luigi
7 | from luigi.cmdline_parser import CmdlineParser
8 |
9 | import gokart
10 | from gokart.utils import flatten
11 |
# Logger used to report test-run results. At import time it gets its own
# StreamHandler and its level is set to INFO.
test_logger = logging.getLogger(__name__)
test_logger.addHandler(logging.StreamHandler())
test_logger.setLevel(logging.INFO)
15 |
16 |
class test_run(gokart.TaskOnKart):
    """Holds the ``--test-run-*`` options that switch on the empty-DataFrame test run."""

    # When true, the test run is executed (see `try_to_run_test_for_empty_data_frame`).
    pandas: bool = luigi.BoolParameter()
    # When set, only tasks in this namespace are run during the test.
    namespace: str | None = luigi.OptionalParameter(
        default=None,
        description='When task namespace is not defined explicitly, please use "__not_user_specified".',
    )
22 |
23 |
24 | class _TestStatus:
25 | def __init__(self, task: gokart.TaskOnKart) -> None:
26 | self.namespace = task.task_namespace
27 | self.name = type(task).__name__
28 | self.task_id = task.make_unique_id()
29 | self.status = 'OK'
30 | self.message: Exception | None = None
31 |
32 | def format(self) -> str:
33 | s = f'status={self.status}; namespace={self.namespace}; name={self.name}; id={self.task_id};'
34 | if self.message:
35 | s += f' message={type(self.message)}: {", ".join(map(str, self.message.args))}'
36 | return s
37 |
38 | def fail(self) -> bool:
39 | return self.status != 'OK'
40 |
41 |
def _get_all_tasks(task: gokart.TaskOnKart) -> list[gokart.TaskOnKart]:
    """Return ``task`` followed by all of its required tasks, collected recursively in depth-first order."""
    collected = [task]
    for required in flatten(task.requires()):
        collected += _get_all_tasks(required)
    return collected
47 |
48 |
def _run_with_test_status(task: gokart.TaskOnKart):
    """Run ``task`` and return a ``_TestStatus``; an exception marks it 'NG' instead of propagating."""
    outcome = _TestStatus(task)
    try:
        task.run()  # type: ignore
    except Exception as error:
        outcome.status, outcome.message = 'NG', error
    return outcome
57 |
58 |
def _test_run_with_empty_data_frame(cmdline_args: list[str], test_run_params: test_run):
    """Run the workflow once, then re-run each task with dumping disabled and report the results.

    Exits the process with status 1 when any task fails in the second pass.
    """
    from unittest.mock import patch

    # First pass: the original workflow must finish with exit code 0.
    try:
        gokart.run(cmdline_args=cmdline_args)
    except SystemExit as e:
        assert e.code == 0, f'original workflow does not run properly. It exited with error code {e}.'

    with CmdlineParser.global_instance(cmdline_args) as parser:
        all_tasks = _get_all_tasks(parser.get_task_obj())

    target_namespace = test_run_params.namespace
    if target_namespace is not None:
        all_tasks = [candidate for candidate in all_tasks if candidate.task_namespace == target_namespace]

    # Second pass: run every task with `dump` replaced by a no-op.
    with patch('gokart.TaskOnKart.dump', new=lambda *args, **kwargs: None):
        test_status_list = [_run_with_test_status(candidate) for candidate in all_tasks]

    report_lines = '\n'.join(status.format() for status in test_status_list)
    test_logger.info('gokart test results:\n' + report_lines)
    if any(status.fail() for status in test_status_list):
        sys.exit(1)
79 |
80 |
def try_to_run_test_for_empty_data_frame(cmdline_args: list[str]):
    """Run the empty-DataFrame test and exit with status 0 when ``test_run.pandas`` is set; otherwise return."""
    with CmdlineParser.global_instance(cmdline_args):
        test_run_params = test_run()

    if not test_run_params.pandas:
        return

    # Strip the test-run options so the workflow itself is run with its normal arguments.
    workflow_args = [arg for arg in cmdline_args if not arg.startswith('--test-run-')]
    _test_run_with_empty_data_frame(cmdline_args=workflow_args, test_run_params=test_run_params)
    sys.exit(0)
89 |
--------------------------------------------------------------------------------
/gokart/testing/pandas_assert.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pandas as pd
4 |
5 |
def assert_frame_contents_equal(actual: pd.DataFrame, expected: pd.DataFrame, **kwargs):
    """
    Assert that two DataFrames hold the same contents.
    This function is mostly the same as pandas.testing.assert_frame_equal, however
    - this function ignores the order of index and columns.
    - this function fails when duplicated index or columns are found.

    Parameters
    ----------
    - actual, expected: pd.DataFrame
        DataFrames to be compared.
    - kwargs: Any
        Parameters passed to pandas.testing.assert_frame_equal.
    """
    # Both inputs must be DataFrames.
    assert isinstance(actual, pd.DataFrame), 'actual is not a DataFrame'
    assert isinstance(expected, pd.DataFrame), 'expected is not a DataFrame'

    # Duplicated labels would make the realignment below ambiguous.
    actual_index, expected_index = actual.index, expected.index
    assert actual_index.is_unique, 'actual index is not unique'
    assert expected_index.is_unique, 'expected index is not unique'
    actual_columns, expected_columns = actual.columns, expected.columns
    assert actual_columns.is_unique, 'actual columns is not unique'
    assert expected_columns.is_unique, 'expected columns is not unique'

    # The label sets must match, whatever their order.
    assert set(actual_columns) == set(expected_columns), 'columns are not equal'
    assert set(actual_index) == set(expected_index), 'indexes are not equal'

    # Bring `expected` into the row and column order of `actual` before the strict comparison.
    rows_aligned = expected.reindex(actual_index)
    expected_reindexed = rows_aligned[actual_columns]
    pd.testing.assert_frame_equal(actual, expected_reindexed, **kwargs)
33 |
--------------------------------------------------------------------------------
/gokart/tree/task_info.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 |
5 | import pandas as pd
6 |
7 | from gokart.target import make_target
8 | from gokart.task import TaskOnKart
9 | from gokart.tree.task_info_formatter import make_task_info_tree, make_tree_info, make_tree_info_table_list
10 |
11 |
def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None):
    """
    Return the tasks, with their statuses/parameters, rendered as a dependency tree string.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - details: bool
        Whether or not to output details.
    - abbr: bool
        Whether or not to simplify tasks information that has already appeared.
    - ignore_task_names: list[str] | None
        List of task names to ignore.
    Returns
    -------
    - tree_info : str
        Formatted task dependency tree.
    """
    root_info = make_task_info_tree(task, ignore_task_names=ignore_task_names)
    # Rendering starts at the root: no indent, and the root is the last node at its level.
    return make_tree_info(task_info=root_info, indent='', last=True, details=details, abbr=abbr, visited_tasks=set())
34 |
35 |
def make_task_info_as_table(task: TaskOnKart, ignore_task_names: list[str] | None = None):
    """Return a table describing the root task and the tasks it depends on.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - ignore_task_names: list[str] | None
        List of task names to ignore.
    Returns
    -------
    - task_info_table : pandas.DataFrame
        Formatted task dependency table.
    """
    root_info = make_task_info_tree(task, ignore_task_names=ignore_task_names)
    table_rows = make_tree_info_table_list(task_info=root_info, visited_tasks=set())
    return pd.DataFrame(table_rows)
55 |
56 |
def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None):
    """Write the table of dependent-task information to a file.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - task_info_dump_path: str
        Output target file path. Path destination can be `local`, `S3`, or `GCS`.
        File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`.
        See `TaskOnKart.make_target` for details.
    - ignore_task_names: list[str] | None
        List of task names to ignore.
    Returns
    -------
    None
    """
    table = make_task_info_as_table(task=task, ignore_task_names=ignore_task_names)
    # The target is keyed by the root task's unique id.
    target = make_target(file_path=task_info_dump_path, unique_id=task.make_unique_id())
    target.dump(obj=table, lock_at_dump=False)
80 |
81 |
def dump_task_info_tree(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True):
    """Write the task info tree object (TaskInfo) to a pickle file.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - task_info_dump_path: str
        Output target file path. Path destination can be `local`, `S3`, or `GCS`.
        File extension must be '.pkl'.
    - ignore_task_names: list[str] | None
        List of task names to ignore.
    - use_unique_id: bool = True
        Whether to use unique id to dump target file. Default is True.
    Returns
    -------
    None
    """
    _, extension = os.path.splitext(task_info_dump_path)
    assert extension == '.pkl', f'File extention must be `.pkl`, not `{extension}`.'

    tree = make_task_info_tree(task, ignore_task_names=ignore_task_names)

    if use_unique_id:
        unique_id = task.make_unique_id()
    else:
        unique_id = None

    target = make_target(file_path=task_info_dump_path, unique_id=unique_id)
    target.dump(obj=tree, lock_at_dump=False)
109 |
--------------------------------------------------------------------------------
/gokart/tree/task_info_formatter.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import typing
4 | import warnings
5 | from dataclasses import dataclass
6 | from typing import NamedTuple
7 |
8 | from gokart.task import TaskOnKart
9 | from gokart.utils import FlattenableItems, flatten
10 |
11 |
@dataclass
class TaskInfo:
    """Snapshot of a single task plus the snapshots of the tasks it requires."""

    name: str  # class name of the task
    unique_id: str
    output_paths: list[str]
    params: dict
    processing_time: str
    is_complete: str  # status label shown in the tree title
    task_log: dict
    requires: FlattenableItems[RequiredTask]
    children_task_infos: list[TaskInfo]

    def get_task_id(self):
        """Return `<name>_<unique_id>`."""
        return '{}_{}'.format(self.name, self.unique_id)

    def get_task_title(self):
        """Return the one-line title used when rendering the tree."""
        return '({}) {}[{}]'.format(self.is_complete, self.name, self.unique_id)

    def get_task_detail(self):
        """Return the parenthesised detail text (parameters, outputs, time, log)."""
        return '(parameter={}, output={}, time={}, task_log={})'.format(
            self.params,
            self.output_paths,
            self.processing_time,
            self.task_log,
        )

    def task_info_dict(self):
        """Return every field except `children_task_infos` as a plain dict."""
        return {
            'name': self.name,
            'unique_id': self.unique_id,
            'output_paths': self.output_paths,
            'params': self.params,
            'processing_time': self.processing_time,
            'is_complete': self.is_complete,
            'task_log': self.task_log,
            'requires': self.requires,
        }
44 |
45 |
class RequiredTask(NamedTuple):
    """Lightweight reference to a required task: its class name and its unique id."""

    # Class name of the required task.
    name: str
    # Value returned by the required task's `make_unique_id()`.
    unique_id: str
49 |
50 |
def _make_requires_info(requires):
    """Mirror the shape of a `requires()` value, replacing every task with a `RequiredTask`.

    A single `TaskOnKart` becomes one `RequiredTask`, a dict keeps its keys, and any
    other iterable becomes a list. Nested containers are converted recursively.

    Raises
    ------
    TypeError
        If `requires` is neither a task, a dict, nor an iterable.
    """
    if isinstance(requires, TaskOnKart):
        return RequiredTask(name=requires.__class__.__name__, unique_id=requires.make_unique_id())
    if isinstance(requires, dict):
        return {label: _make_requires_info(requires=value) for label, value in requires.items()}
    if not isinstance(requires, typing.Iterable):
        raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterarble[TaskOnKart]`, or `Dict[str, TaskOnKart]`')
    return [_make_requires_info(requires=element) for element in requires]
60 |
61 |
def make_task_info_tree(task: TaskOnKart, ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo:
    """Recursively build the `TaskInfo` tree for `task` and the tasks it requires.

    Parameters
    ----------
    - task: TaskOnKart
        Task to describe.
    - ignore_task_names: list[str] | None
        Required tasks whose class name is listed here are not added as children.
    - cache: dict[str, TaskInfo] | None
        Memo shared across the recursion, keyed by name, unique id and completion state,
        so a task reached through several paths is only described once.

    Returns
    -------
    TaskInfo
    """
    with warnings.catch_warnings():
        # Tasks without outputs make `complete()` emit this warning; it is silenced here.
        warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method')
        is_task_complete = task.complete()

    name = task.__class__.__name__
    unique_id = task.make_unique_id()
    output_paths: list[str] = [target.path() for target in flatten(task.output())]

    if cache is None:
        cache = {}
    cache_id = f'{name}_{unique_id}_{is_task_complete}'
    if cache_id in cache:
        return cache[cache_id]

    params = task.get_info(only_significant=True)
    processing_time = task.get_processing_time()
    if isinstance(processing_time, float):
        processing_time = f'{processing_time}s'
    if is_task_complete:
        is_complete = 'COMPLETE'
    else:
        is_complete = 'PENDING'
    task_log = dict(task.get_task_log())
    requires = _make_requires_info(task.requires())

    children_task_infos: list[TaskInfo] = [
        make_task_info_tree(child, ignore_task_names=ignore_task_names, cache=cache)
        for child in flatten(task.requires())
        if ignore_task_names is None or child.__class__.__name__ not in ignore_task_names
    ]

    task_info = TaskInfo(
        name=name,
        unique_id=unique_id,
        output_paths=output_paths,
        params=params,
        processing_time=processing_time,
        is_complete=is_complete,
        task_log=task_log,
        requires=requires,
        children_task_infos=children_task_infos,
    )
    cache[cache_id] = task_info
    return task_info
102 |
103 |
def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: set[str]):
    """Render `task_info` and its descendants as an indented text tree.

    Parameters
    ----------
    - task_info: TaskInfo
        Node to render.
    - indent: str
        Prefix placed before this node's branch marker.
    - last: bool
        Whether this node is the last child of its parent (selects the branch marker).
    - details: bool
        Append `get_task_detail()` after the title.
    - abbr: bool
        When True, a task id already present in `visited_tasks` is rendered as its title
        followed by a `...` line, without descending again.
    - visited_tasks: set[str]
        Task ids rendered so far; updated in place when `abbr` is True.

    Returns
    -------
    str
        The rendered subtree, starting with a newline.
    """
    if last:
        branch, child_indent = '└─-', indent + ' '
    else:
        branch, child_indent = '|--', indent + '| '
    pieces = ['\n' + indent, branch, task_info.get_task_title()]

    if abbr:
        task_id = task_info.get_task_id()
        if task_id in visited_tasks:
            pieces.append(f'\n{child_indent}└─- ...')
            return ''.join(pieces)
        visited_tasks.add(task_id)

    if details:
        pieces.append(task_info.get_task_detail())

    children = task_info.children_task_infos
    child_count = len(children)
    for position, child in enumerate(children, start=1):
        pieces.append(make_tree_info(child, child_indent, position == child_count, details=details, abbr=abbr, visited_tasks=visited_tasks))
    return ''.join(pieces)
129 |
130 |
def make_tree_info_table_list(task_info: TaskInfo, visited_tasks: set[str]):
    """Flatten a `TaskInfo` tree into a list of `task_info_dict()` rows, depth first.

    Each task id contributes at most one row: a node whose id is already in
    `visited_tasks` yields an empty list. `visited_tasks` is updated in place.
    """
    task_id = task_info.get_task_id()
    if task_id in visited_tasks:
        return []
    visited_tasks.add(task_id)

    rows = [task_info.task_info_dict()]
    for child in task_info.children_task_infos:
        rows.extend(make_tree_info_table_list(task_info=child, visited_tasks=visited_tasks))
    return rows
143 |
--------------------------------------------------------------------------------
/gokart/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import sys
5 | from collections.abc import Iterable
6 | from io import BytesIO
7 | from typing import Any, Callable, Protocol, TypeVar, Union
8 |
9 | import dill
10 | import luigi
11 | import pandas as pd
12 |
13 |
class FileLike(Protocol):
    """Structural type for the binary, seekable stream methods this module relies on."""

    def read(self, n: int) -> bytes:
        """Read up to `n` bytes."""

    def readline(self) -> bytes:
        """Read one line."""

    def seek(self, offset: int) -> None:
        """Move the stream position to `offset`."""

    def seekable(self) -> bool:
        """Tell whether `seek` is supported."""
22 |
23 |
def add_config(file_path: str):
    """Register `file_path` as an additional luigi configuration file.

    The file extension is assigned to `luigi.configuration.core.parser` before the
    path is handed to `luigi.configuration.add_config_path`.

    Raises
    ------
    AssertionError
        If luigi reports that the config path was not added.
    """
    _, ext = os.path.splitext(file_path)
    luigi.configuration.core.parser = ext  # type: ignore
    # The call must not live inside an `assert`: under `python -O` asserts are stripped,
    # which would silently skip registering the config file altogether.
    is_added = luigi.configuration.add_config_path(file_path)
    if not is_added:
        raise AssertionError(f'Failed to add config file: {file_path}')
28 |
29 |
T = TypeVar('T')

# Recursive alias: a bare item, an iterable of such structures, or a str-keyed dict of them.
# `X | Y` on typing objects needs Python 3.10+, so older interpreters use `Union`.
if sys.version_info < (3, 10):
    FlattenableItems = Union[T, Iterable['FlattenableItems[T]'], dict[str, 'FlattenableItems[T]']]
else:
    from typing import TypeAlias

    FlattenableItems: TypeAlias = T | Iterable['FlattenableItems[T]'] | dict[str, 'FlattenableItems[T]']
37 |
38 |
def flatten(targets: FlattenableItems[T]) -> list[T]:
    """
    Collapse a nested structure of dicts, iterables and single items into one flat list:

    .. code-block:: python

        >>> sorted(flatten({'a': 'foo', 'b': 'bar'}))
        ['bar', 'foo']
        >>> sorted(flatten(['foo', ['bar', 'troll']]))
        ['bar', 'foo', 'troll']
        >>> flatten('foo')
        ['foo']
        >>> flatten(42)
        [42]

    `None` flattens to an empty list, dict keys are dropped, and strings are treated
    as single items rather than iterated character by character.

    This method is copied and modified from [luigi.task.flatten](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/luigi/task.py#L958) in accordance with [Apache License 2.0](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/LICENSE).
    """
    if targets is None:
        return []
    if isinstance(targets, dict):
        return [leaf for value in targets.values() for leaf in flatten(value)]
    # Strings are iterable but must stay whole; everything non-iterable is a leaf.
    if isinstance(targets, str) or not isinstance(targets, Iterable):
        return [targets]  # type: ignore
    return [leaf for element in targets for leaf in flatten(element)]
73 |
74 |
K = TypeVar('K')


def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
    """Apply `func` to every leaf of `items` while keeping the container structure.

    Dicts keep their keys, tuples stay tuples, and any other iterable becomes a list.
    Strings and non-iterables are leaves and are passed to `func` directly.
    """
    if isinstance(items, dict):
        return {key: map_flattenable_items(func, value) for key, value in items.items()}
    if isinstance(items, tuple):
        return tuple(map_flattenable_items(func, element) for element in items)
    if not isinstance(items, str) and isinstance(items, Iterable):
        return [map_flattenable_items(func, element) for element in items]
    return func(items)  # type: ignore
88 |
89 |
def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any:
    """Load an object dumped by dill, retrying with `pd.read_pickle` when dill fails.

    `dill.load` is tried first. If it raises, the stream is rewound to the start and
    read with `pd.read_pickle`, which can load binaries dumped by older pandas versions
    as well as plain pickles.

    Raises
    ------
    AssertionError
        If the fallback is needed but `file` is not seekable.
    """
    try:
        loaded = dill.load(file)
    except Exception:
        # The failed attempt may have consumed part of the stream; rewind before retrying.
        assert file.seekable(), f'{file} is not seekable.'
        file.seek(0)
        loaded = pd.read_pickle(file)
    return loaded
101 |
--------------------------------------------------------------------------------
/gokart/workspace_management.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import itertools
4 | import os
5 | import pathlib
6 | from logging import getLogger
7 |
8 | import gokart
9 | from gokart.utils import flatten
10 |
logger = getLogger(__name__)


def _get_all_output_file_paths(task: gokart.TaskOnKart):
    """Collect the output paths of `task` followed by those of every task it requires, recursively."""
    collected = [target.path() for target in flatten(task.output())]
    upstream_tasks = flatten(task.requires())
    collected.extend(itertools.chain.from_iterable(_get_all_output_file_paths(upstream) for upstream in upstream_tasks))
    return collected
19 |
20 |
def delete_local_unnecessary_outputs(task: gokart.TaskOnKart):
    """Delete files in the task's workspace directory that no task in its dependency tree outputs.

    Files under the `log` sub-directory are always kept. Only paths matching `*.*`
    are considered, and directories are never passed to `os.remove`.
    """
    task.make_unique_id()  # this is required to make unique ids.
    workspace = pathlib.Path(task.workspace_directory)
    # `*.*` also matches directories whose name contains a dot (e.g. `v1.0/`);
    # `os.remove` raises on directories, so they are filtered out here.
    all_files = {str(path) for path in workspace.rglob('*.*') if not path.is_dir()}
    log_files = {str(path) for path in pathlib.Path(os.path.join(task.workspace_directory, 'log')).rglob('*.*')}
    necessary_files = set(_get_all_output_file_paths(task))
    unnecessary_files = all_files - necessary_files - log_files
    if not unnecessary_files:
        logger.info('all files are necessary for this task.')
        return
    logger.info(f'remove following files: {os.linesep} {os.linesep.join(unnecessary_files)}')
    for file in unnecessary_files:
        os.remove(file)
33 |
--------------------------------------------------------------------------------
/gokart/zip_client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import shutil
5 | import zipfile
6 | from abc import abstractmethod
7 | from typing import IO
8 |
9 |
def _unzip_file(fp: str | IO | os.PathLike, extract_dir: str) -> None:
    """Extract every member of the zip archive `fp` into `extract_dir`.

    `fp` may be a path or an open binary file object. The archive handle is closed
    even when extraction fails.
    """
    # The context manager guarantees `close()` runs if `extractall` raises.
    with zipfile.ZipFile(fp) as zip_file:
        zip_file.extractall(extract_dir)
14 |
15 |
class ZipClient:
    """Interface for clients that archive and restore a temporary directory as a zip file."""

    @abstractmethod
    def exists(self) -> bool:
        """Tell whether the archive exists."""

    @abstractmethod
    def make_archive(self) -> None:
        """Create the archive."""

    @abstractmethod
    def unpack_archive(self) -> None:
        """Unpack the archive."""

    @abstractmethod
    def remove(self) -> None:
        """Delete the archive."""

    @property
    @abstractmethod
    def path(self) -> str:
        """Location of the archive."""
37 |
38 |
class LocalZipClient(ZipClient):
    """`ZipClient` for an archive stored on the local file system."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        """Remember the archive path and the directory that is packed into / unpacked from it."""
        self._file_path = file_path
        self._temporary_directory = temporary_directory

    def exists(self) -> bool:
        """Return True if the archive path exists."""
        return os.path.exists(self._file_path)

    def make_archive(self) -> None:
        """Archive the temporary directory at the archive path; the format is taken from the extension."""
        base, extension = os.path.splitext(self._file_path)
        shutil.make_archive(base_name=base, format=extension[1:], root_dir=self._temporary_directory)

    def unpack_archive(self) -> None:
        """Extract the archive into the temporary directory."""
        _unzip_file(fp=self._file_path, extract_dir=self._temporary_directory)

    def remove(self) -> None:
        """Delete the archive, ignoring errors (best effort)."""
        if os.path.isdir(self._file_path) and not os.path.islink(self._file_path):
            shutil.rmtree(self._file_path, ignore_errors=True)
            return
        # `shutil.rmtree` cannot delete a regular file; with `ignore_errors=True` it
        # silently did nothing, so the archive file was never removed.
        try:
            os.remove(self._file_path)
        except OSError:
            # Keep the best-effort contract of the previous implementation.
            pass

    @property
    def path(self) -> str:
        """Path of the archive file."""
        return self._file_path
60 |
--------------------------------------------------------------------------------
/gokart/zip_client_util.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from gokart.object_storage import ObjectStorage
4 | from gokart.zip_client import LocalZipClient, ZipClient
5 |
6 |
def make_zip_client(file_path: str, temporary_directory: str) -> ZipClient:
    """Return a zip client for `file_path`: an object-storage client if the path is one, else a local client."""
    if not ObjectStorage.if_object_storage_path(file_path):
        return LocalZipClient(file_path=file_path, temporary_directory=temporary_directory)
    return ObjectStorage.get_zip_client(file_path=file_path, temporary_directory=temporary_directory)
11 |
--------------------------------------------------------------------------------
/luigi.cfg:
--------------------------------------------------------------------------------
1 | [core]
2 | autoload_range: false
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling", "uv-dynamic-versioning"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "gokart"
7 | description="Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline. [Documentation](https://gokart.readthedocs.io/en/latest/)"
8 | authors = [
9 | {name = "M3, inc."}
10 | ]
11 | license = "MIT"
12 | readme = "README.md"
13 | requires-python = ">=3.9, <4"
14 | dependencies = [
15 | "luigi",
16 | "boto3",
17 | "slack-sdk",
18 | "pandas",
19 | "numpy",
20 | "google-auth",
21 | "pyarrow",
22 | "uritemplate",
23 | "google-api-python-client",
24 | "APScheduler",
25 | "redis",
26 | "dill",
27 | "backoff",
28 | "typing-extensions>=4.11.0; python_version<'3.13'",
29 | ]
30 | classifiers = [
31 | "Development Status :: 5 - Production/Stable",
32 | "License :: OSI Approved :: MIT License",
33 | "Programming Language :: Python :: 3",
34 | "Programming Language :: Python :: 3.9",
35 | "Programming Language :: Python :: 3.10",
36 | "Programming Language :: Python :: 3.11",
37 | "Programming Language :: Python :: 3.12",
38 | "Programming Language :: Python :: 3.13",
39 |
40 | ]
41 | dynamic = ["version"]
42 |
43 | [project.urls]
44 | Homepage = "https://github.com/m3dev/gokart"
45 | Repository = "https://github.com/m3dev/gokart"
46 | Documentation = "https://gokart.readthedocs.io/en/latest/"
47 |
48 | [dependency-groups]
49 | test = [
50 | "fakeredis",
51 | "lupa",
52 | "matplotlib",
53 | "moto",
54 | "mypy",
55 | "pytest",
56 | "pytest-cov",
57 | "pytest-xdist",
58 | "testfixtures",
59 | "toml",
60 | "types-redis",
61 | "typing-extensions>=4.11.0",
62 | ]
63 |
64 | lint = [
65 | "ruff",
66 | "mypy",
67 | ]
68 |
69 | [tool.uv]
70 | default-groups = ['test', 'lint']
71 | cache-keys = [ { file = "pyproject.toml" }, { git = true } ]
72 |
73 | [tool.hatch.version]
74 | source = "uv-dynamic-versioning"
75 |
76 | [tool.uv-dynamic-versioning]
77 | enable = true
78 |
79 | [tool.hatch.build.targets.sdist]
80 | include = [
81 | "/LICENSE",
82 | "/README.md",
83 | "/examples",
84 | "/gokart",
85 | "/test",
86 | ]
87 |
88 | [tool.ruff]
89 | line-length = 160
90 | exclude = ["venv/*", "tox/*", "examples/*"]
91 |
92 | [tool.ruff.lint]
93 | # All the rules are listed on https://docs.astral.sh/ruff/rules/
94 | extend-select = [
95 | "B", # bugbear
96 | "I", # isort
97 | "UP", # pyupgrade, upgrade syntax for newer versions of the language.
98 | ]
99 |
100 | # B006: Do not use mutable data structures for argument defaults. They are created during function definition time. All calls to the function reuse this one instance of that data structure, persisting changes between them.
101 | # B008 Do not perform function calls in argument defaults. The call is performed only once at function definition time. All calls to your function will reuse the result of that definition-time function call. If this is intended, assign the function call to a module-level variable and use that variable as a default value.
102 | ignore = ["B006", "B008"]
103 |
104 | [tool.ruff.format]
105 | quote-style = "single"
106 |
107 | [tool.mypy]
108 | ignore_missing_imports = true
109 | plugins = ["gokart.mypy:plugin"]
110 |
111 | check_untyped_defs = true
112 |
113 | [tool.pytest.ini_options]
114 | testpaths = ["test"]
115 | addopts = "-n auto -s -v --durations=0"
116 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/__init__.py
--------------------------------------------------------------------------------
/test/config/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Final
3 |
# Directory that holds this package's config fixtures.
CONFIG_DIR: Final[Path] = Path(__file__).parent.resolve()
# Individual fixture files, all located directly inside CONFIG_DIR.
PYPROJECT_TOML: Final[Path] = CONFIG_DIR.joinpath('pyproject.toml')
PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS: Final[Path] = CONFIG_DIR.joinpath('pyproject_disallow_missing_parameters.toml')
TEST_CONFIG_INI: Final[Path] = CONFIG_DIR.joinpath('test_config.ini')
8 |
--------------------------------------------------------------------------------
/test/config/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.mypy]
2 | plugins = ["gokart.mypy"]
3 |
4 | [[tool.mypy.overrides]]
5 | ignore_missing_imports = true
6 | module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"]
7 |
--------------------------------------------------------------------------------
/test/config/pyproject_disallow_missing_parameters.toml:
--------------------------------------------------------------------------------
1 | [tool.mypy]
2 | plugins = ["gokart.mypy"]
3 |
4 | [[tool.mypy.overrides]]
5 | ignore_missing_imports = true
6 | module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"]
7 |
8 | [tool.gokart-mypy]
9 | disallow_missing_parameters = true
10 |
--------------------------------------------------------------------------------
/test/config/test_config.ini:
--------------------------------------------------------------------------------
1 | [test_read_config._DummyTask]
2 | param = ${test_param}
3 |
4 | [test_build._DummyTask]
5 | param = ${test_param}
--------------------------------------------------------------------------------
/test/conflict_prevention_lock/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/conflict_prevention_lock/__init__.py
--------------------------------------------------------------------------------
/test/conflict_prevention_lock/test_task_lock.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 | from unittest.mock import patch
4 |
5 | import gokart
6 | from gokart.conflict_prevention_lock.task_lock import RedisClient, TaskLockParams, make_task_lock_key, make_task_lock_params, make_task_lock_params_for_run
7 |
8 |
class TestRedisClient(unittest.TestCase):
    """`RedisClient` must hand back the same instance for the same host and port."""

    @staticmethod
    def _get_randint(host, port):
        # Stand-in for `redis.Redis(...)`: returns a random value instead of a connection.
        return random.randint(0, 100000)

    def test_redis_client_is_singleton(self):
        with patch('redis.Redis', side_effect=self._get_randint):
            first_host_0 = RedisClient(host='host_0', port=123)
            host_1 = RedisClient(host='host_1', port=123)
            second_host_0 = RedisClient(host='host_0', port=123)

            self.assertNotEqual(first_host_0, host_1)
            self.assertEqual(first_host_0, second_host_0)

            self.assertEqual(first_host_0.get_redis_client(), second_host_0.get_redis_client())
26 |
27 |
class TestMakeRedisKey(unittest.TestCase):
    """The lock key is expected to be `<file stem>_<unique id>`."""

    def test_make_redis_key(self):
        self.assertEqual(make_task_lock_key(file_path='gs://test_ll/dir/fname.pkl', unique_id='12345'), 'fname_12345')
32 |
33 |
class TestMakeRedisParams(unittest.TestCase):
    """Checks of `make_task_lock_params` for a given host, a missing host and a too-short timeout."""

    def test_make_task_lock_params_with_valid_host(self):
        # With a redis host, locking is expected to be enabled.
        expected = TaskLockParams(
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_key='aaa_123',
            should_task_lock=True,
            redis_timeout=180,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=10,
        )
        actual = make_task_lock_params(
            file_path='gs://aaa.pkl',
            unique_id='123',
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=180,
            raise_task_lock_exception_on_collision=False,
        )
        self.assertEqual(actual, expected)

    def test_make_task_lock_params_with_no_host(self):
        # Without a redis host, locking is expected to be disabled.
        expected = TaskLockParams(
            redis_host=None,
            redis_port=12345,
            redis_key='aaa_123',
            should_task_lock=False,
            redis_timeout=180,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=10,
        )
        actual = make_task_lock_params(
            file_path='gs://aaa.pkl',
            unique_id='123',
            redis_host=None,
            redis_port=12345,
            redis_timeout=180,
            raise_task_lock_exception_on_collision=False,
        )
        self.assertEqual(actual, expected)

    def test_assert_when_redis_timeout_is_too_short(self):
        too_short_timeout = dict(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=2,
        )
        with self.assertRaises(AssertionError):
            make_task_lock_params(**too_short_timeout)
74 |
75 |
class TestMakeTaskLockParamsForRun(unittest.TestCase):
    """`make_task_lock_params_for_run` should derive a `-run` suffixed key from the task."""

    def test_make_task_lock_params_for_run(self):
        class _SampleDummyTask(gokart.TaskOnKart):
            pass

        sample_task = _SampleDummyTask(redis_host='0.0.0.0', redis_port=12345, redis_timeout=180)

        expected = TaskLockParams(
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=180,
            redis_key='_SampleDummyTask_7e857f231830ca0fd6cf829d99f43961-run',
            should_task_lock=True,
            raise_task_lock_exception_on_collision=True,
            lock_extend_seconds=10,
        )
        actual = make_task_lock_params_for_run(task_self=sample_task, lock_extend_seconds=10)

        self.assertEqual(actual, expected)
99 |
--------------------------------------------------------------------------------
/test/in_memory/test_in_memory_target.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from time import sleep
3 |
4 | import pytest
5 |
6 | from gokart.conflict_prevention_lock.task_lock import TaskLockParams
7 | from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget, make_in_memory_target
8 |
9 |
class TestInMemoryTarget:
    """Behavioural tests for `InMemoryTarget`: dump/load, existence and modification time."""

    @pytest.fixture
    def task_lock_params(self) -> TaskLockParams:
        # Locking is switched off, so no redis connection details are supplied.
        return TaskLockParams(
            redis_host=None,
            redis_port=None,
            redis_timeout=None,
            redis_key='dummy',
            should_task_lock=False,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=0,
        )

    @pytest.fixture
    def target(self, task_lock_params: TaskLockParams) -> InMemoryTarget:
        return make_in_memory_target(target_key='dummy_key', task_lock_params=task_lock_params)

    @pytest.fixture(autouse=True)
    def clear_repo(self) -> None:
        # Start every test from an empty repository.
        InMemoryCacheRepository().clear()

    def test_dump_and_load_data(self, target: InMemoryTarget):
        payload = 'dummy_data'
        target.dump(payload)
        assert target.load() == payload

    def test_exist(self, target: InMemoryTarget):
        assert not target.exists()
        target.dump('dummy_data')
        assert target.exists()

    def test_last_modified_time(self, target: InMemoryTarget):
        target.dump('dummy_data')
        first_time = target.last_modification_time()
        assert isinstance(first_time, datetime)

        # Wait so the second dump gets a strictly later timestamp.
        sleep(0.1)
        target.dump('another_data')
        second_time = target.last_modification_time()
        assert first_time < second_time

        target.remove()
        with pytest.raises(ValueError):
            target.last_modification_time()
57 |
--------------------------------------------------------------------------------
/test/in_memory/test_repository.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import pytest
4 |
5 | from gokart.in_memory import InMemoryCacheRepository as Repo
6 |
dummy_num = 100


class TestInMemoryCacheRepository:
    """Tests for the basic operations of `InMemoryCacheRepository`."""

    @pytest.fixture
    def repo(self) -> Repo:
        # Hand every test a cleared repository.
        fresh_repo = Repo()
        fresh_repo.clear()
        return fresh_repo

    def test_set(self, repo: Repo):
        repo.set_value('dummy_key', dummy_num)
        assert repo.size == 1
        for stored_key, stored_value in repo.get_gen():
            assert (stored_key, stored_value) == ('dummy_key', dummy_num)

        repo.set_value('another_key', 'another_value')
        assert repo.size == 2

    def test_get(self, repo: Repo):
        repo.set_value('dummy_key', dummy_num)
        repo.set_value('another_key', 'another_value')

        # Raise Error when key doesn't exist.
        with pytest.raises(KeyError):
            repo.get_value('not_exist_key')

        assert repo.get_value('dummy_key') == dummy_num
        assert repo.get_value('another_key') == 'another_value'

    def test_empty(self, repo: Repo):
        assert repo.empty()
        repo.set_value('dummmy_key', dummy_num)
        assert not repo.empty()

    def test_has(self, repo: Repo):
        assert not repo.has('dummy_key')
        repo.set_value('dummy_key', dummy_num)
        assert repo.has('dummy_key')
        assert not repo.has('not_exist_key')

    def test_remove(self, repo: Repo):
        repo.set_value('dummy_key', dummy_num)

        with pytest.raises(AssertionError):
            repo.remove('not_exist_key')

        repo.remove('dummy_key')
        assert not repo.has('dummy_key')

    def test_last_modification_time(self, repo: Repo):
        repo.set_value('dummy_key', dummy_num)
        earlier = repo.get_last_modification_time('dummy_key')
        # Wait so the second write gets a strictly later timestamp.
        time.sleep(0.1)
        repo.set_value('dummy_key', dummy_num)
        later = repo.get_last_modification_time('dummy_key')
        assert earlier < later
64 |
--------------------------------------------------------------------------------
/test/slack/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/slack/__init__.py
--------------------------------------------------------------------------------
/test/slack/test_slack_api.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from logging import getLogger
3 | from unittest import mock
4 | from unittest.mock import MagicMock
5 |
6 | from slack_sdk import WebClient
7 | from slack_sdk.web.slack_response import SlackResponse
8 | from testfixtures import LogCapture
9 |
10 | import gokart.slack
11 |
logger = getLogger(__name__)


def _slack_response(token, data):
    """Build a `SlackResponse` carrying `data`, as if returned with HTTP 200 by a local test endpoint."""
    response_kwargs = dict(
        client=WebClient(token=token),
        http_verb='POST',
        api_url='http://localhost:3000/api.test',
        req_args={},
        data=data,
        headers={},
        status_code=200,
    )
    return SlackResponse(**response_kwargs)
19 |
20 |
class TestSlackAPI(unittest.TestCase):
    """Tests of `gokart.slack.SlackAPI` with `slack_sdk.WebClient` replaced by a mock."""

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_initialization_with_invalid_token(self, patch):
        # `conversations_list` answers with an error payload.
        def _conversations_list(params={}):
            return _slack_response(token='invalid', data={'ok': False, 'error': 'error_reason'})

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        patch.return_value = mock_client

        # Construction must only log a warning, not raise.
        with LogCapture() as log:
            gokart.slack.SlackAPI(token='invalid', channel='test', to_user='test user')
            log.check(('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel test is not found in public channels.'))

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_invalid_channel(self, patch):
        # The channel list is valid but does not contain the requested channel.
        def _conversations_list(params={}):
            return _slack_response(
                token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}}
            )

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        patch.return_value = mock_client

        with LogCapture() as log:
            gokart.slack.SlackAPI(token='valid', channel='invalid_channel', to_user='test user')
            log.check(
                ('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel invalid_channel is not found in public channels.')
            )

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_send_snippet_with_invalid_token(self, patch):
        def _conversations_list(params={}):
            return _slack_response(
                token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}}
            )

        # The upload call reports a failure.
        def _api_call(method, data={}):
            assert method == 'files.upload'
            return {'ok': False, 'error': 'error_reason'}

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        mock_client.api_call = MagicMock(side_effect=_api_call)
        patch.return_value = mock_client

        # A failed upload must be logged as a warning.
        with LogCapture() as log:
            api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user')
            api.send_snippet(comment='test', title='title', content='content')
            log.check(
                ('gokart.slack.slack_api', 'WARNING', 'Failed to send slack notification: Error while uploading file. The error reason is "error_reason".')
            )

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_send(self, patch):
        def _conversations_list(params={}):
            return _slack_response(
                token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}}
            )

        def _api_call(method, data={}):
            assert method == 'files.upload'
            return {'ok': False, 'error': 'error_reason'}

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        mock_client.api_call = MagicMock(side_effect=_api_call)
        patch.return_value = mock_client

        # No assertion on logs here: the test passes as long as sending does not raise.
        api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user')
        api.send_snippet(comment='test', title='title', content='content')


if __name__ == '__main__':
    unittest.main()
97 |
--------------------------------------------------------------------------------
/test/test_cache_unique_id.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | import luigi
5 | import luigi.mock
6 |
7 | import gokart
8 |
9 |
class _DummyTask(gokart.TaskOnKart):
    """Task that requires `_DummyTaskDep` and dumps whatever it loads from it."""

    def requires(self):
        return _DummyTaskDep()

    def run(self):
        # Pass the loaded input straight through as this task's output.
        self.dump(self.load())
16 |
17 |
class _DummyTaskDep(gokart.TaskOnKart):
    """Task whose output is the value of its `param` parameter."""

    param = luigi.Parameter()

    def run(self):
        self.dump(self.param)
23 |
24 |
class CacheUniqueIDTest(unittest.TestCase):
    """Checks how `cache_unique_id` affects the result when a dependency's default parameter changes between builds."""

    def setUp(self):
        # Reset luigi's config singleton, the mock file system and the environment.
        luigi.configuration.LuigiConfigParser._instance = None
        luigi.mock.MockFileSystem().clear()
        os.environ.clear()

    def test_cache_unique_id_true(self):
        _DummyTaskDep.param = luigi.Parameter(default='original_param')
        before_change = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False)

        _DummyTaskDep.param = luigi.Parameter(default='updated_param')
        after_change = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False)

        # With caching the second build yields the same output as the first.
        self.assertEqual(before_change, after_change)

    def test_cache_unique_id_false(self):
        _DummyTaskDep.param = luigi.Parameter(default='original_param')
        before_change = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False)

        _DummyTaskDep.param = luigi.Parameter(default='updated_param')
        after_change = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False)

        # Without caching the changed default shows up in the output.
        self.assertNotEqual(before_change, after_change)


if __name__ == '__main__':
    unittest.main()
52 |
--------------------------------------------------------------------------------
/test/test_config_params.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import luigi
4 | from luigi.cmdline_parser import CmdlineParser
5 |
6 | import gokart
7 | from gokart.config_params import inherits_config_params
8 |
9 |
def in_parse(cmds, deferred_computation):
    """Build the task described by ``cmds`` and pass it to ``deferred_computation``.

    function copied from luigi: https://github.com/spotify/luigi/blob/e2228418eec60b68ca09a30c878ab26413846847/test/helpers.py
    """
    with CmdlineParser.global_instance(cmds) as parser:
        task = parser.get_task_obj()
        deferred_computation(task)
14 |
15 |
class ConfigClass(luigi.Config):
    """Config whose parameter defaults are handed to tasks decorated with ``inherits_config_params``."""

    param_a = luigi.Parameter(default='config a')
    param_b = luigi.Parameter(default='config b')
    param_c = luigi.Parameter(default='config c')
20 |
21 |
@inherits_config_params(ConfigClass)
class Inherited(gokart.TaskOnKart):
    """Task decorated to take its parameter values from ``ConfigClass``."""

    # No default of its own; the tests expect 'config a' from ConfigClass.
    param_a = luigi.Parameter()
    # The tests expect the config value 'config b' rather than this default.
    param_b = luigi.Parameter(default='overrided')
26 |
27 |
@inherits_config_params(ConfigClass, parameter_alias={'param_a': 'param_d'})
class Inherited2(gokart.TaskOnKart):
    """Task that receives the config's ``param_a`` under the alias ``param_d``."""

    # The tests expect 'config c' here.
    param_c = luigi.Parameter()
    # The tests expect 'config a' here, i.e. ConfigClass.param_a through the alias.
    param_d = luigi.Parameter()
32 |
33 |
class ChildTask(Inherited):
    """Plain subclass of ``Inherited`` that adds nothing of its own."""
36 |
37 |
class ChildTaskWithNewParam(Inherited):
    """Subclass of ``Inherited`` that declares one extra parameter."""

    # ConfigClass has no attribute with this name.
    param_new = luigi.Parameter()
40 |
41 |
class ConfigClass2(luigi.Config):
    """Second config that supplies a different value for ``param_a`` only."""

    param_a = luigi.Parameter(default='config a from config class 2')
44 |
45 |
@inherits_config_params(ConfigClass2)
class ChildTaskWithNewConfig(Inherited):
    """Subclass of ``Inherited`` that is additionally decorated with ``ConfigClass2``."""
49 |
50 |
class TestInheritsConfigParam(unittest.TestCase):
    """Tests for parameter values filled in by ``inherits_config_params``."""

    def _assert_param(self, cmds, name, expected):
        """Parse ``cmds`` into a task and check that attribute ``name`` equals ``expected``."""

        def check(task):
            self.assertEqual(getattr(task, name), expected)

        in_parse(cmds, check)

    def _assert_param_missing(self, cmds, name):
        """Parse ``cmds`` into a task and check that reading ``name`` raises ``AttributeError``."""
        with self.assertRaises(AttributeError):
            in_parse(cmds, lambda task: getattr(task, name))

    def test_inherited_params(self):
        # test fill values
        self._assert_param(['Inherited'], 'param_a', 'config a')

        # test overrided
        self._assert_param(['Inherited'], 'param_b', 'config b')

        # Command line argument takes precedence over config param
        self._assert_param(['Inherited', '--param-a', 'command line arg'], 'param_a', 'command line arg')

        # Parameters which is not a member of the task will not be set
        self._assert_param_missing(['Inherited'], 'param_c')

        # test parameter name alias
        self._assert_param(['Inherited2'], 'param_c', 'config c')
        self._assert_param(['Inherited2'], 'param_d', 'config a')

    def test_child_task(self):
        self._assert_param(['ChildTask'], 'param_a', 'config a')
        self._assert_param(['ChildTask'], 'param_b', 'config b')
        self._assert_param(['ChildTask', '--param-a', 'command line arg'], 'param_a', 'command line arg')
        self._assert_param_missing(['ChildTask'], 'param_c')

    def test_child_override(self):
        self._assert_param(['ChildTaskWithNewConfig'], 'param_a', 'config a from config class 2')
        self._assert_param(['ChildTaskWithNewConfig'], 'param_b', 'config b')
80 |
--------------------------------------------------------------------------------
/test/test_explicit_bool_parameter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import luigi
4 | import luigi.mock
5 | from luigi.cmdline_parser import CmdlineParser
6 |
7 | import gokart
8 |
9 |
def in_parse(cmds, deferred_computation):
    """Build the task described by ``cmds`` and pass it to ``deferred_computation``."""
    with CmdlineParser.global_instance(cmds) as parser:
        task = parser.get_task_obj()
        deferred_computation(task)
13 |
14 |
class WithDefaultTrue(gokart.TaskOnKart):
    """Task whose explicit bool parameter defaults to ``True``."""

    param = gokart.ExplicitBoolParameter(default=True)
17 |
18 |
class WithDefaultFalse(gokart.TaskOnKart):
    """Task whose explicit bool parameter defaults to ``False``."""

    param = gokart.ExplicitBoolParameter(default=False)
21 |
22 |
class ExplicitParsing(gokart.TaskOnKart):
    """Task with an explicit bool parameter that has no default."""

    param = gokart.ExplicitBoolParameter()

    def run(self):
        # Record the parsed value on the class itself.
        parsed_value = self.param
        ExplicitParsing._param = parsed_value  # type: ignore
28 |
29 |
class TestExplicitBoolParameter(unittest.TestCase):
    """Tests for default handling and command-line parsing of ``gokart.ExplicitBoolParameter``."""

    def test_bool_default(self):
        self.assertTrue(WithDefaultTrue().param)
        self.assertFalse(WithDefaultFalse().param)

    def test_parse_param(self):
        in_parse(['ExplicitParsing', '--param', 'true'], lambda task: self.assertTrue(task.param))
        in_parse(['ExplicitParsing', '--param', 'false'], lambda task: self.assertFalse(task.param))
        in_parse(['ExplicitParsing', '--param', 'True'], lambda task: self.assertTrue(task.param))
        in_parse(['ExplicitParsing', '--param', 'False'], lambda task: self.assertFalse(task.param))

    # In the three tests below the callback accepts the task argument that
    # ``in_parse`` passes. If parsing ever stops raising, the test then fails
    # on the missing exception rather than on a callback arity ``TypeError``.

    def test_missing_parameter(self):
        with self.assertRaises(luigi.parameter.MissingParameterException):
            in_parse(['ExplicitParsing'], lambda task: True)

    def test_value_error(self):
        with self.assertRaises(ValueError):
            in_parse(['ExplicitParsing', '--param', 'Foo'], lambda task: True)

    def test_expected_one_argment_error(self):
        # argparse exits with an "expected one argument" error
        with self.assertRaises(SystemExit):
            in_parse(['ExplicitParsing', '--param'], lambda task: True)
53 |
--------------------------------------------------------------------------------
/test/test_gcs_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from unittest.mock import MagicMock, patch
4 |
5 | from gokart.gcs_config import GCSConfig
6 |
7 |
class TestGCSConfig(unittest.TestCase):
    """Tests how ``GCSConfig._get_gcs_client`` handles the credential environment variable.

    Each test sets ``env_name`` through ``patch.dict`` so the variable is
    removed again when the test ends and cannot influence other tests.
    """

    def test_get_gcs_client_without_gcs_credential_name(self):
        """An empty variable leads to ``GCSClient(oauth_credentials=None)``."""
        mock = MagicMock()
        with patch.dict(os.environ, {'env_name': ''}):
            with patch('luigi.contrib.gcs.GCSClient', mock):
                GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
        self.assertEqual(dict(oauth_credentials=None), mock.call_args[1])

    def test_get_gcs_client_with_file_path(self):
        """A value that is an existing file path is passed to ``from_service_account_file``."""
        mock = MagicMock()
        file_path = 'test.json'
        with patch.dict(os.environ, {'env_name': file_path}):
            with patch('luigi.contrib.gcs.GCSClient'):
                with patch('google.oauth2.service_account.Credentials.from_service_account_file', mock):
                    # Pretend the path exists so no real file is needed.
                    with patch('os.path.isfile', return_value=True):
                        GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
        self.assertEqual(file_path, mock.call_args[0][0])

    def test_get_gcs_client_with_json(self):
        """A JSON value is decoded and passed to ``from_service_account_info``."""
        mock = MagicMock()
        json_str = '{"test": 1}'
        with patch.dict(os.environ, {'env_name': json_str}):
            with patch('luigi.contrib.gcs.GCSClient'):
                with patch('google.oauth2.service_account.Credentials.from_service_account_info', mock):
                    GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
        self.assertEqual(dict(test=1), mock.call_args[0][0])
34 |
--------------------------------------------------------------------------------
/test/test_gcs_obj_metadata_client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import datetime
4 | import unittest
5 | from typing import Any
6 | from unittest.mock import MagicMock, patch
7 |
8 | import gokart
9 | from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
10 | from gokart.required_task_output import RequiredTaskOutput
11 | from gokart.target import TargetOnKart
12 |
13 |
class _DummyTaskOnKart(gokart.TaskOnKart):
    """Minimal task that dumps a fixed string."""

    task_namespace = __name__

    def run(self):
        payload = 'Dummy TaskOnKart'
        self.dump(payload)
19 |
20 |
class TestGCSObjectMetadataClient(unittest.TestCase):
    """Tests for the label and metadata helpers of ``GCSObjectMetadataClient``."""

    def setUp(self):
        # Task parameters as strings; 'param6' is deliberately empty.
        self.task_params: dict[str, str] = {
            'param1': 'a' * 1000,
            'param2': str(1000),
            'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}),
            'param4': str([1, 2, 3, 4, 5]),
            'param5': str(datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)),
            'param6': '',
        }
        self.custom_labels: dict[str, Any] = {
            'created_at': datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5),
            'created_by': 'hoge fuga',
            'empty': True,
            'try_num': 3,
        }

        # 'empty' and 'created_by' collide with keys of ``custom_labels``.
        self.task_params_with_conflicts = {
            'empty': 'False',
            'created_by': 'fuga hoge',
            'param1': 'a' * 10,
        }

    def test_normalize_labels_not_empty(self):
        """Despite its name, this checks that ``None`` is normalized to an empty dict."""
        got = GCSObjectMetadataClient._normalize_labels(None)
        self.assertEqual(got, {})

    def test_normalize_labels_has_value(self):
        """Normalizing keeps every key, including the one with an empty value."""
        got = GCSObjectMetadataClient._normalize_labels(self.task_params)

        self.assertIsInstance(got, dict)
        for key in ('param1', 'param2', 'param3', 'param4', 'param5', 'param6'):
            self.assertIn(key, got)

    def test_get_patched_obj_metadata_only_task_params(self):
        """Task parameters appear in the metadata, except the empty-valued 'param6'."""
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=None)

        self.assertIsInstance(got, dict)
        for key in ('param1', 'param2', 'param3', 'param4', 'param5'):
            self.assertIn(key, got)
        self.assertNotIn('param6', got)

    def test_get_patched_obj_metadata_only_custom_labels(self):
        """All custom labels appear in the metadata."""
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=None, custom_labels=self.custom_labels)

        self.assertIsInstance(got, dict)
        for key in ('created_at', 'created_by', 'empty', 'try_num'):
            self.assertIn(key, got)

    def test_get_patched_obj_metadata_with_both_task_params_and_custom_labels(self):
        """Task parameters and custom labels are merged into one metadata dict."""
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=self.custom_labels)

        self.assertIsInstance(got, dict)
        for key in ('param1', 'param2', 'param3', 'param4', 'param5'):
            self.assertIn(key, got)
        self.assertNotIn('param6', got)
        for key in ('created_at', 'created_by', 'empty', 'try_num'):
            self.assertIn(key, got)

    def test_get_patched_obj_metadata_with_exceeded_size_metadata(self):
        """With two 5000-character values, only the first one is expected in the result."""
        size_exceeded_task_params = {
            'param1': 'a' * 5000,
            'param2': 'b' * 5000,
        }
        want = {
            'param1': 'a' * 5000,
        }
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=size_exceeded_task_params)
        self.assertEqual(got, want)

    def test_get_patched_obj_metadata_with_conflicts(self):
        """On a key collision the custom label value is expected, not the task parameter value."""
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params_with_conflicts, custom_labels=self.custom_labels)
        self.assertIsInstance(got, dict)
        for key in ('created_at', 'created_by', 'empty', 'try_num'):
            self.assertIn(key, got)
        self.assertEqual(got['empty'], 'True')
        self.assertEqual(got['created_by'], 'hoge fuga')
        self.assertEqual(got['param1'], 'a' * 10)

    def test_get_patched_obj_metadata_with_required_task_outputs(self):
        """A list of required task outputs is stored as a JSON string under '__required_task_outputs'."""
        got = GCSObjectMetadataClient._get_patched_obj_metadata(
            {},
            required_task_outputs=[
                RequiredTaskOutput(task_name='task1', output_path='path/to/output1'),
            ],
        )

        self.assertIsInstance(got, dict)
        self.assertIn('__required_task_outputs', got)
        self.assertEqual(got['__required_task_outputs'], '[{"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}]')

    def test_get_patched_obj_metadata_with_nested_required_task_outputs(self):
        """Nested dicts of required task outputs keep their structure in the JSON string."""
        got = GCSObjectMetadataClient._get_patched_obj_metadata(
            {},
            required_task_outputs={
                'nested_task': {'nest': RequiredTaskOutput(task_name='task1', output_path='path/to/output1')},
            },
        )

        self.assertIsInstance(got, dict)
        self.assertIn('__required_task_outputs', got)
        self.assertEqual(
            got['__required_task_outputs'], '{"nested_task": {"nest": {"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}}}'
        )
142 |
143 |
class TestGokartTask(unittest.TestCase):
    """Checks the arguments ``TaskOnKart.dump`` forwards to the target's ``dump``."""

    @patch.object(_DummyTaskOnKart, '_get_output_target')
    def test_mock_target_on_kart(self, mock_get_output_target):
        target = MagicMock(spec=TargetOnKart)
        mock_get_output_target.return_value = target
        payload = {'key': 'value'}

        task = _DummyTaskOnKart()
        task.dump(payload, target)

        expected_kwargs = {
            'lock_at_dump': task._lock_at_dump,
            'task_params': {},
            'custom_labels': None,
            'required_task_outputs': [],
        }
        target.dump.assert_called_once_with({'key': 'value'}, **expected_kwargs)
156 |
157 |
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
160 |
--------------------------------------------------------------------------------
/test/test_info.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch
3 |
4 | import luigi
5 | import luigi.mock
6 | from luigi.mock import MockFileSystem, MockTarget
7 |
8 | import gokart
9 | import gokart.info
10 | from test.tree.test_task_info import _DoubleLoadSubTask, _SubTask, _Task
11 |
12 |
class TestInfo(unittest.TestCase):
    """Tests for the text tree returned by ``gokart.info.make_tree_info``."""

    def setUp(self) -> None:
        # Start from an empty mock file system and reset luigi's logging "configured" flags.
        MockFileSystem().clear()
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    def tearDown(self) -> None:
        # Reset the same flags again after each test.
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_pending(self):
        """Before anything is built, both tasks are expected to be listed as PENDING."""
        task = _Task(param=1, sub=_SubTask(param=2))

        # check before running
        tree = gokart.info.make_tree_info(task)
        expected = r"""
└─-\(PENDING\) _Task\[[a-z0-9]*\]
└─-\(PENDING\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_complete(self):
        """After a build, both tasks are expected to be listed as COMPLETE."""
        task = _Task(param=1, sub=_SubTask(param=2))

        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task)
        expected = r"""
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_abbreviation(self):
        """By default the second occurrence of an identical subtree is expected to be shown as '...'."""
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )

        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task)
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]
\|--\(COMPLETE\) _Task\[[a-z0-9]*\]
\| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─- \.\.\.$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_not_compress(self):
        """With ``abbr=False`` the repeated subtree is expected to be printed in full."""
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )

        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task, abbr=False)
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]
\|--\(COMPLETE\) _Task\[[a-z0-9]*\]
\| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_not_compress_ignore_task(self):
        """With ``ignore_task_names=['_Task']`` only the root task line is expected."""
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )

        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task, abbr=False, ignore_task_names=['_Task'])
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)
95 |
96 |
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
99 |
--------------------------------------------------------------------------------
/test/test_large_data_fram_processor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import unittest
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from gokart.target import LargeDataFrameProcessor
9 | from test.util import _get_temporary_directory
10 |
11 |
class LargeDataFrameProcessorTest(unittest.TestCase):
    """Round-trip tests for ``LargeDataFrameProcessor.save`` and ``load``."""

    def setUp(self):
        self.temporary_directory = _get_temporary_directory()

    def tearDown(self):
        # Remove whatever the test wrote; ignore a directory that was never created.
        shutil.rmtree(self.temporary_directory, ignore_errors=True)

    def test_save_and_load(self):
        target_path = os.path.join(self.temporary_directory, 'test.zip')
        frame = pd.DataFrame({'data': np.random.uniform(0, 1, size=int(1e6))})
        processor = LargeDataFrameProcessor(max_byte=int(1e6))

        processor.save(frame, target_path)
        restored = processor.load(target_path)

        pd.testing.assert_frame_equal(restored, frame, check_like=True)

    def test_save_and_load_empty(self):
        target_path = os.path.join(self.temporary_directory, 'test_with_empty.zip')
        frame = pd.DataFrame()
        processor = LargeDataFrameProcessor(max_byte=int(1e6))

        processor.save(frame, target_path)
        restored = processor.load(target_path)

        pd.testing.assert_frame_equal(restored, frame, check_like=True)
36 |
37 |
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
40 |
--------------------------------------------------------------------------------
/test/test_list_task_instance_parameter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import luigi
4 |
5 | import gokart
6 | from gokart import TaskOnKart
7 |
8 |
class _DummySubTask(TaskOnKart):
    """Parameterless task used as the default value of ``_DummyTask.task``."""

    task_namespace = __name__
12 |
13 |
class _DummyTask(TaskOnKart):
    """Task with an integer parameter and a task-instance parameter."""

    task_namespace = __name__
    param = luigi.IntParameter()
    # The default sub task instance is created once, when this class body is executed.
    task = gokart.TaskInstanceParameter(default=_DummySubTask())
18 |
19 |
class ListTaskInstanceParameterTest(unittest.TestCase):
    """Round-trip test for ``gokart.ListTaskInstanceParameter``."""

    def setUp(self):
        _DummyTask.clear_instance_cache()

    def test_serialize_and_parse(self):
        original = [_DummyTask(param=3), _DummyTask(param=3)]

        serialized = gokart.ListTaskInstanceParameter().serialize(original)
        parsed = gokart.ListTaskInstanceParameter().parse(serialized)

        # Each parsed task must carry the task id of the task it was serialized from.
        for index in (0, 1):
            self.assertEqual(parsed[index].task_id, original[index].task_id)
30 |
31 |
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
34 |
--------------------------------------------------------------------------------
/test/test_mypy.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import unittest
3 |
4 | from mypy import api
5 |
6 | from test.config import PYPROJECT_TOML, PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS
7 |
8 |
class TestMyMypyPlugin(unittest.TestCase):
    """Runs mypy with the test ``pyproject.toml`` files against small generated task modules."""

    def _run_mypy(self, test_code, config_file=PYPROJECT_TOML, show_traceback=True):
        """Write ``test_code`` to a temporary ``.py`` file and type-check it with mypy.

        Args:
            test_code: Source text of the module to check.
            config_file: Configuration file passed to mypy via ``--config-file``.
            show_traceback: Whether ``--show-traceback`` is passed to mypy.

        Returns:
            The first element of the tuple returned by ``mypy.api.run``.
        """
        options = ['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(config_file)]
        if show_traceback:
            options.insert(0, '--show-traceback')
        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
            test_file.write(test_code.encode('utf-8'))
            # mypy opens the file by name, so the content must be flushed first.
            test_file.flush()
            result = api.run([*options, test_file.name])
        return result[0]

    def test_plugin_no_issue(self):
        """Correctly typed arguments, including a TaskOnKart parameter, produce no errors."""
        test_code = """
import luigi
from luigi import Parameter
import gokart
import datetime


class MyTask(gokart.TaskOnKart):
    foo: int = luigi.IntParameter() # type: ignore
    bar: str = luigi.Parameter() # type: ignore
    baz: bool = gokart.ExplicitBoolParameter()
    qux: str = Parameter()
    # https://github.com/m3dev/gokart/issues/395
    datetime: datetime.datetime = luigi.DateMinuteParameter(interval=10, default=datetime.datetime(2021, 1, 1))



# TaskOnKart parameters:
# - `complete_check_at_run`
MyTask(foo=1, bar='bar', baz=False, qux='qux', complete_check_at_run=False)
        """

        report = self._run_mypy(test_code, show_traceback=False)
        self.assertIn('Success: no issues found', report)

    def test_plugin_invalid_arg(self):
        """Wrongly typed arguments are reported, while a missing argument is not."""
        test_code = """
import luigi
import gokart


class MyTask(gokart.TaskOnKart):
    foo: int = luigi.IntParameter() # type: ignore
    bar: str = luigi.Parameter() # type: ignore
    baz: bool = gokart.ExplicitBoolParameter()

# issue: foo is int
# not issue: bar is missing, because it can be set by config file.
# TaskOnKart parameters:
# - `complete_check_at_run`
MyTask(foo='1', baz='not bool', complete_check_at_run='not bool')
        """

        report = self._run_mypy(test_code, show_traceback=False)
        self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', report)
        self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', report)
        self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', report)
        self.assertIn('Found 3 errors in 1 file (checked 1 source file)', report)

    def test_parameter_has_default_type_invalid_pattern(self):
        """
        If the user doesn't annotate a parameter, mypy infers its type from the Parameter class.
        """
        test_code = """
import enum
import luigi
import gokart


class MyEnum(enum.Enum):
    FOO = enum.auto()

class MyTask(gokart.TaskOnKart):
    foo = luigi.IntParameter()
    bar = luigi.DateParameter()
    baz = gokart.TaskInstanceParameter()
    qux = luigi.NumericalParameter(var_type=int)
    quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int)
    corge = luigi.EnumParameter(enum=MyEnum)

MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1)
        """
        report = self._run_mypy(test_code)
        self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', report)
        self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', report)
        self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', report)
        self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', report)
        self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', report)
        self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', report)

    def test_parameter_has_default_type_no_issue_pattern(self):
        """
        If the user doesn't annotate a parameter, mypy infers its type from the Parameter class.
        """
        test_code = """
from datetime import date
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
    foo = luigi.IntParameter()
    bar = luigi.DateParameter()
    baz = gokart.TaskInstanceParameter()

MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart())
        """
        report = self._run_mypy(test_code)
        self.assertIn('Success: no issues found', report)

    def test_no_issue_found_when_missing_parameter_when_default_option(self):
        """
        If `disallow_missing_parameters` is False (or default), mypy doesn't show any error for missing parameters.
        """
        test_code = """
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
    foo = luigi.IntParameter()
    bar = luigi.Parameter(default="bar")

MyTask()
        """
        report = self._run_mypy(test_code)
        self.assertIn('Success: no issues found', report)

    def test_issue_found_when_missing_parameter_when_disallow_missing_parameters_set_true(self):
        """
        If `disallow_missing_parameters` is True, mypy shows an error for missing parameters.
        """
        test_code = """
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
    # issue: foo is missing
    foo = luigi.IntParameter()
    # bar has default value, so it is not required to set it.
    bar = luigi.Parameter(default="bar")

MyTask()
        """
        report = self._run_mypy(test_code, config_file=PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS)
        self.assertIn('error: Missing named argument "foo" for "MyTask" [call-arg]', report)
        self.assertIn('Found 1 error in 1 file (checked 1 source file)', report)
173 |
--------------------------------------------------------------------------------
/test/test_pandas_type_check_framework.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import unittest
5 | from logging import getLogger
6 | from typing import Any
7 | from unittest.mock import patch
8 |
9 | import luigi
10 | import pandas as pd
11 | from luigi.mock import MockFileSystem, MockTarget
12 |
13 | import gokart
14 | from gokart.build import GokartBuildError
15 | from gokart.pandas_type_config import PandasTypeConfig
16 |
17 | logger = getLogger(__name__)
18 |
19 |
class TestPandasTypeConfig(PandasTypeConfig):
    """Type configuration that maps the ``system_cd`` column to ``int``.

    NOTE(review): presumably applied to the dummy tasks that share this
    ``task_namespace`` -- confirm against ``PandasTypeConfig``.
    """

    task_namespace = 'test_pandas_type_check_framework'

    @classmethod
    def type_dict(cls) -> dict[str, Any]:
        """Return the mapping from column name to expected type."""
        return {'system_cd': int}
26 |
27 |
class _DummyFailTask(gokart.TaskOnKart):
    """Dumps a frame whose ``system_cd`` column holds the string ``'1'``."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun = True

    def output(self):
        target = self.make_target('dummy.pkl')
        return target

    def run(self):
        frame = pd.DataFrame({'system_cd': ['1']})
        self.dump(frame)
38 |
39 |
class _DummyFailWithNoneTask(gokart.TaskOnKart):
    """Dumps a frame whose ``system_cd`` column holds ``1`` and ``None``."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun = True

    def output(self):
        target = self.make_target('dummy.pkl')
        return target

    def run(self):
        frame = pd.DataFrame({'system_cd': [1, None]})
        self.dump(frame)
50 |
51 |
class _DummySuccessTask(gokart.TaskOnKart):
    """Dumps a frame whose ``system_cd`` column holds the integer ``1``."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun = True

    def output(self):
        target = self.make_target('dummy.pkl')
        return target

    def run(self):
        frame = pd.DataFrame({'system_cd': [1]})
        self.dump(frame)
62 |
63 |
class TestPandasTypeCheckFramework(unittest.TestCase):
    """Builds the dummy tasks and checks which of them fail."""

    @staticmethod
    def _reset_luigi_logging() -> None:
        """Clear the ``_configured`` flag on both luigi logging setups."""
        for logging_setup in (luigi.setup_logging.DaemonLogging, luigi.setup_logging.InterfaceLogging):
            logging_setup._configured = False

    def setUp(self) -> None:
        self._reset_luigi_logging()
        MockFileSystem().clear()
        # same way as luigi https://github.com/spotify/luigi/blob/fe7ecf4acf7cf4c084bd0f32162c8e0721567630/test/helpers.py#L175
        self._stashed_reg = luigi.task_register.Register._get_reg()

    def tearDown(self) -> None:
        self._reset_luigi_logging()
        # Put back the task register captured in setUp.
        luigi.task_register.Register._set_reg(self._stashed_reg)

    @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock'])
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_fail_with_gokart_run(self):
        with self.assertRaises(SystemExit) as raised:
            gokart.run()
        self.assertNotEqual(raised.exception.code, 0)  # raise Error

    def test_fail(self):
        failing_task = _DummyFailTask()
        with self.assertRaises(GokartBuildError):
            gokart.build(failing_task, log_level=logging.CRITICAL)

    def test_fail_with_None(self):
        failing_task = _DummyFailWithNoneTask()
        with self.assertRaises(GokartBuildError):
            gokart.build(failing_task, log_level=logging.CRITICAL)

    def test_success(self):
        # no error
        gokart.build(_DummySuccessTask())
95 |
--------------------------------------------------------------------------------
/test/test_pandas_type_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from datetime import date, datetime
4 | from typing import Any
5 | from unittest import TestCase
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from gokart import PandasTypeConfig
11 | from gokart.pandas_type_config import PandasTypeError
12 |
13 |
class _DummyPandasTypeConfig(PandasTypeConfig):
    """Type configuration covering an int, a datetime and a numpy array column."""

    @classmethod
    def type_dict(cls) -> dict[str, Any]:
        """Return the mapping from column name to expected type."""
        expected_types = {
            'int_column': int,
            'datetime_column': datetime,
            'array_column': np.ndarray,
        }
        return expected_types
18 |
19 |
class TestPandasTypeConfig(TestCase):
    """Checks ``_DummyPandasTypeConfig.check`` against matching and mismatching column values."""

    def test_int_fail(self):
        frame = pd.DataFrame({'int_column': ['1']})
        config = _DummyPandasTypeConfig()
        with self.assertRaises(PandasTypeError):
            config.check(frame)

    def test_int_success(self):
        frame = pd.DataFrame({'int_column': [1]})
        _DummyPandasTypeConfig().check(frame)

    def test_datetime_fail(self):
        # A plain ``date`` is used where the config declares ``datetime``.
        frame = pd.DataFrame({'datetime_column': [date(2019, 1, 12)]})
        config = _DummyPandasTypeConfig()
        with self.assertRaises(PandasTypeError):
            config.check(frame)

    def test_datetime_success(self):
        frame = pd.DataFrame({'datetime_column': [datetime(2019, 1, 12, 0, 0, 0)]})
        _DummyPandasTypeConfig().check(frame)

    def test_array_fail(self):
        # A Python list is used where the config declares ``np.ndarray``.
        frame = pd.DataFrame({'array_column': [[1, 2]]})
        config = _DummyPandasTypeConfig()
        with self.assertRaises(PandasTypeError):
            config.check(frame)

    def test_array_success(self):
        frame = pd.DataFrame({'array_column': [np.array([1, 2])]})
        _DummyPandasTypeConfig().check(frame)
47 |
--------------------------------------------------------------------------------
/test/test_restore_task_by_id.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch
3 |
4 | import luigi
5 | import luigi.mock
6 |
7 | import gokart
8 |
9 |
class _SubDummyTask(gokart.TaskOnKart):
    """Sub task with one integer parameter that dumps a fixed string."""

    task_namespace = __name__
    param = luigi.IntParameter()

    def run(self):
        payload = 'test'
        self.dump(payload)
16 |
17 |
class _DummyTask(gokart.TaskOnKart):
    """Task that takes another task as a parameter and dumps a fixed string to 'test.txt'."""

    task_namespace = __name__
    sub_task = gokart.TaskInstanceParameter()

    def output(self):
        target = self.make_target('test.txt')
        return target

    def run(self):
        payload = 'test'
        self.dump(payload)
27 |
28 |
class RestoreTaskByIDTest(unittest.TestCase):
    """Check that ``restore`` rebuilds a task whose unique id matches the task that was run."""

    def setUp(self) -> None:
        # Start every test from an empty in-memory mock file system.
        luigi.mock.MockFileSystem().clear()

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
    def test(self):
        task = _DummyTask(sub_task=_SubDummyTask(param=10))
        luigi.build([task], local_scheduler=True, log_level='CRITICAL')

        unique_id = task.make_unique_id()
        restored = _DummyTask.restore(unique_id)
        # assertEqual, not assertTrue: assertTrue(a, b) only checks that ``a`` is
        # truthy and uses ``b`` as the failure message, so it never compared the ids.
        self.assertEqual(task.make_unique_id(), restored.make_unique_id())
41 |
42 |
43 | if __name__ == '__main__':
44 | unittest.main()
45 |
--------------------------------------------------------------------------------
/test/test_run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from unittest.mock import MagicMock, patch
4 |
5 | import luigi
6 | import luigi.mock
7 |
8 | import gokart
9 | from gokart.run import _try_to_send_event_summary_to_slack
10 |
11 |
class _DummyTask(gokart.TaskOnKart):
    """Task with one required string parameter, addressed from the command line as ``<module>._DummyTask``."""

    task_namespace = __name__
    param = luigi.Parameter()
15 |
16 |
class RunTest(unittest.TestCase):
    """Tests for ``gokart.run`` and ``_try_to_send_event_summary_to_slack``."""

    def setUp(self):
        luigi.configuration.LuigiConfigParser._instance = None
        luigi.mock.MockFileSystem().clear()
        # Run each test with an empty environment, but restore the real one afterwards:
        # a bare os.environ.clear() would wipe the environment for the rest of the process.
        env_patcher = patch.dict(os.environ, clear=True)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler'])
    def test_run(self):
        # NOTE(review): dirname(__name__) works on the dotted module name, not the file
        # path; presumably __file__ was intended -- confirm before changing.
        config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
        luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
        os.environ.setdefault('test_param', 'test')
        with self.assertRaises(SystemExit) as exit_code:
            gokart.run()
        self.assertEqual(exit_code.exception.code, 0)

    @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--log-level=CRITICAL', '--local-scheduler'])
    def test_run_with_undefined_environ(self):
        # No ``--param`` on the command line and an empty environment.
        config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
        luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
        with self.assertRaises(luigi.parameter.MissingParameterException):
            gokart.run()

    @patch(
        'sys.argv',
        new=[
            'main',
            '--tree-info-mode=simple',
            '--tree-info-output-path=tree.txt',
            f'{__name__}._DummyTask',
            '--param',
            'test',
            '--log-level=CRITICAL',
            '--local-scheduler',
        ],
    )
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
    def test_run_tree_info(self):
        config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
        luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
        os.environ.setdefault('test_param', 'test')
        tree_info = gokart.tree_info(mode='simple', output_path='tree.txt')
        with self.assertRaises(SystemExit):
            gokart.run()

        expected = gokart.make_tree_info(_DummyTask(param='test'))
        loaded = tree_info.output().load()
        # The previous assertTrue(a, b) used ``b`` only as the failure message and
        # compared nothing. Compare whitespace-separated tokens so the check does not
        # depend on line breaks.
        # NOTE(review): load() presumably returns a list of lines for a .txt target -- confirm.
        dumped = loaded if isinstance(loaded, str) else '\n'.join(loaded)
        self.assertEqual(expected.split(), dumped.split())

    @patch('gokart.make_tree_info')
    def test_try_to_send_event_summary_to_slack(self, make_tree_info_mock: MagicMock):
        event_aggregator_mock = MagicMock()
        # NOTE(review): 'get_summury' looks like a typo for 'get_summary'; MagicMock accepts
        # any attribute name, so confirm against the real event aggregator before renaming.
        event_aggregator_mock.get_summury.return_value = f'{__name__}._DummyTask'
        event_aggregator_mock.get_event_list.return_value = f'{__name__}._DummyTask:[]'
        make_tree_info_mock.return_value = 'tree'

        def get_content(content: str, **kwargs):
            # Capture the ``content`` passed to send_snippet for the assertions below.
            self.output = content

        slack_api_mock = MagicMock()
        slack_api_mock.send_snippet.side_effect = get_content

        cmdline_args = [f'{__name__}._DummyTask', '--param', 'test']

        # With send_tree_info enabled, the snippet carries the (mocked) tree info.
        with patch('gokart.slack.SlackConfig.send_tree_info', True):
            _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args)
        expects = os.linesep.join(['===== Event List ====', event_aggregator_mock.get_event_list(), os.linesep, '==== Tree Info ====', 'tree'])

        results = self.output
        self.assertEqual(expects, results)

        # With send_tree_info disabled, a hint replaces the tree info.
        with patch('gokart.slack.SlackConfig.send_tree_info', False):
            _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args)
        expects = os.linesep.join(
            [
                '===== Event List ====',
                event_aggregator_mock.get_event_list(),
                os.linesep,
                '==== Tree Info ====',
                'Please add SlackConfig.send_tree_info to include tree-info',
            ]
        )

        results = self.output
        self.assertEqual(expects, results)
98 |
99 |
100 | if __name__ == '__main__':
101 | unittest.main()
102 |
--------------------------------------------------------------------------------
/test/test_s3_config.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from gokart.s3_config import S3Config
4 |
5 |
class TestS3Config(unittest.TestCase):
    """Check that separate ``S3Config`` instances hand out equal S3 clients."""

    def test_get_same_s3_client(self):
        first_client, second_client = S3Config().get_s3_client(), S3Config().get_s3_client()

        self.assertEqual(first_client, second_client)
12 |
--------------------------------------------------------------------------------
/test/test_s3_zip_client.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import unittest
4 |
5 | import boto3
6 | from moto import mock_aws
7 |
8 | from gokart.s3_zip_client import S3ZipClient
9 | from test.util import _get_temporary_directory
10 |
11 |
class TestS3ZipClient(unittest.TestCase):
    """Archive and unpack tests for ``S3ZipClient`` run under ``mock_aws``."""

    def setUp(self):
        self.temporary_directory = _get_temporary_directory()

    def tearDown(self):
        shutil.rmtree(self.temporary_directory, ignore_errors=True)

        # remove temporary zip archive if exists.
        archive_path = f'{self.temporary_directory}.zip'
        if os.path.exists(archive_path):
            os.remove(archive_path)

    @mock_aws
    def test_make_archive(self):
        s3 = boto3.resource('s3', region_name='us-east-1')
        s3.create_bucket(Bucket='test')

        work_dir = self.temporary_directory
        client = S3ZipClient(file_path=os.path.join('s3://test/', 'test.zip'), temporary_directory=work_dir)

        # The temporary directory does not exist yet, so archiving must raise.
        with self.assertRaises(FileNotFoundError):
            client.make_archive()

        # Once the directory exists, archiving runs without error.
        os.makedirs(work_dir, exist_ok=True)
        client.make_archive()

    @mock_aws
    def test_unpack_archive(self):
        s3 = boto3.resource('s3', region_name='us-east-1')
        s3.create_bucket(Bucket='test')

        remote_path = os.path.join('s3://test/', 'test.zip')
        source_dir = os.path.join(self.temporary_directory, 'in', 'dummy')
        destination_dir = os.path.join(self.temporary_directory, 'out', 'dummy')

        # make dummy zip file.
        os.makedirs(source_dir, exist_ok=True)
        S3ZipClient(file_path=remote_path, temporary_directory=source_dir).make_archive()

        # load dummy zip file.
        unpacker = S3ZipClient(file_path=remote_path, temporary_directory=destination_dir)
        self.assertFalse(os.path.exists(destination_dir))
        unpacker.unpack_archive()
58 |
--------------------------------------------------------------------------------
/test/test_serializable_parameter.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tempfile
3 | from dataclasses import asdict, dataclass
4 |
5 | import luigi
6 | import pytest
7 | from luigi.cmdline_parser import CmdlineParser
8 | from mypy import api
9 |
10 | from gokart import SerializableParameter, TaskOnKart
11 | from test.config import PYPROJECT_TOML
12 |
13 |
@dataclass(frozen=True)
class Config:
    """Immutable value object that serializes itself to and from a JSON string."""

    foo: int
    bar: str

    def gokart_serialize(self) -> str:
        """Return this instance's fields as a JSON object string."""
        # dict is ordered in Python 3.7+
        field_values = asdict(self)
        return json.dumps(field_values)

    @classmethod
    def gokart_deserialize(cls, s: str) -> 'Config':
        """Build an instance from a JSON object string whose keys are the field names."""
        field_values = json.loads(s)
        return cls(**field_values)
26 |
27 |
class SerializableParameterWithOutDefault(TaskOnKart):
    """Task whose ``config`` parameter is declared without a default value."""

    task_namespace = __name__
    config: Config = SerializableParameter(object_type=Config)

    def run(self):
        # Dump the parameter value as the task output.
        value = self.config
        self.dump(value)
34 |
35 |
class SerializableParameterWithDefault(TaskOnKart):
    """Task whose ``config`` parameter defaults to ``Config(foo=1, bar='bar')``."""

    task_namespace = __name__
    config: Config = SerializableParameter(
        object_type=Config,
        default=Config(foo=1, bar='bar'),
    )

    def run(self):
        # Dump the parameter value as the task output.
        value = self.config
        self.dump(value)
42 |
43 |
class TestSerializableParameter:
    """Command-line parsing and mypy checks for ``SerializableParameter``."""

    def test_default(self):
        argv = [f'{__name__}.SerializableParameterWithDefault']
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().config == Config(foo=1, bar='bar')

    def test_parse_param(self):
        argv = [f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().config == Config(foo=100, bar='val')

    def test_missing_parameter(self):
        # No ``--config`` and no default.
        argv = [f'{__name__}.SerializableParameterWithOutDefault']
        with pytest.raises(luigi.parameter.MissingParameterException), CmdlineParser.global_instance(argv) as cp:
            cp.get_task_obj()

    def test_value_error(self):
        # 'Foo' is not a JSON document.
        argv = [f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']
        with pytest.raises(ValueError), CmdlineParser.global_instance(argv) as cp:
            cp.get_task_obj()

    def test_expected_one_argument_error(self):
        # ``--config`` is given without a value.
        argv = [f'{__name__}.SerializableParameterWithOutDefault', '--config']
        with pytest.raises(SystemExit), CmdlineParser.global_instance(argv) as cp:
            cp.get_task_obj()

    def test_mypy(self):
        """check invalid object cannot used for SerializableParameter"""

        test_code = """
import gokart

class InvalidClass:
    ...

gokart.SerializableParameter(object_type=InvalidClass)
        """
        expected_error = 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass"  [type-var]'
        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
            test_file.write(test_code.encode('utf-8'))
            test_file.flush()
            mypy_args = ['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]
            result = api.run(mypy_args)
            assert expected_error in result[0]
84 |
--------------------------------------------------------------------------------
/test/test_task_instance_parameter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import luigi
4 |
5 | import gokart
6 | from gokart import ListTaskInstanceParameter, TaskInstanceParameter, TaskOnKart
7 |
8 |
class _DummySubTask(TaskOnKart):
    """Parameterless task used as the expected type and as the default parameter value."""

    task_namespace = __name__
12 |
13 |
class _DummyCorrectSubClassTask(_DummySubTask):
    """Subclass of ``_DummySubTask``; accepted where ``_DummySubTask`` is the expected type."""

    task_namespace = __name__
17 |
18 |
class _DummyInvalidSubClassTask(TaskOnKart):
    """Task that does not derive from ``_DummySubTask``; used to trigger type errors."""

    task_namespace = __name__
22 |
23 |
class _DummyTask(TaskOnKart):
    """Task with an int parameter and a task-instance parameter defaulting to ``_DummySubTask()``."""

    task_namespace = __name__
    param = luigi.IntParameter()
    task = TaskInstanceParameter(
        default=_DummySubTask(),
    )
28 |
29 |
class _DummyListTask(TaskOnKart):
    """Task with an int parameter and a list-of-tasks parameter defaulting to two ``_DummySubTask()``."""

    task_namespace = __name__
    param = luigi.IntParameter()
    task = ListTaskInstanceParameter(
        default=[_DummySubTask(), _DummySubTask()],
    )
34 |
35 |
class TaskInstanceParameterTest(unittest.TestCase):
    """Serialization round trips and ``expected_type`` validation of ``TaskInstanceParameter``."""

    def setUp(self):
        _DummyTask.clear_instance_cache()

    def test_serialize_and_parse(self):
        task = _DummyTask(param=2)
        serialized = gokart.TaskInstanceParameter().serialize(task)
        restored = gokart.TaskInstanceParameter().parse(serialized)
        self.assertEqual(restored.task_id, task.task_id)

    def test_serialize_and_parse_list_params(self):
        task = _DummyListTask(param=2)
        serialized = gokart.TaskInstanceParameter().serialize(task)
        restored = gokart.TaskInstanceParameter().parse(serialized)
        self.assertEqual(restored.task_id, task.task_id)

    def test_invalid_class(self):
        with self.assertRaises(TypeError):
            gokart.TaskInstanceParameter(expected_type=1)  # not type instance

    def test_params_with_correct_param_type(self):
        class _DummyPipelineA(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask)

        pipeline = _DummyPipelineA(subtask=_DummyCorrectSubClassTask())
        required = pipeline.requires()['subtask']  # type: ignore
        self.assertEqual(required, _DummyCorrectSubClassTask())

    def test_params_with_invalid_param_type(self):
        class _DummyPipelineB(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask)

        # Passing a task that is not a _DummySubTask must raise TypeError.
        self.assertRaises(TypeError, lambda: _DummyPipelineB(subtask=_DummyInvalidSubClassTask()))
70 |
71 |
class ListTaskInstanceParameterTest(unittest.TestCase):
    """``expected_elements_type`` validation of ``ListTaskInstanceParameter``."""

    def setUp(self):
        _DummyTask.clear_instance_cache()

    def test_invalid_class(self):
        with self.assertRaises(TypeError):
            gokart.ListTaskInstanceParameter(expected_elements_type=1)  # not type instance

    def test_list_params_with_correct_param_types(self):
        class _DummyPipelineC(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

        pipeline = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()])
        required = pipeline.requires()['subtask']  # type: ignore
        # The list passed in is compared against a one-element tuple.
        self.assertEqual(required, (_DummyCorrectSubClassTask(),))

    def test_list_params_with_invalid_param_types(self):
        class _DummyPipelineD(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

        # One element that is not a _DummySubTask must make construction raise TypeError.
        self.assertRaises(
            TypeError,
            lambda: _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()]),
        )
94 |
95 |
96 | if __name__ == '__main__':
97 | unittest.main()
98 |
--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from gokart.utils import flatten, map_flattenable_items
4 |
5 |
class TestFlatten(unittest.TestCase):
    """Expected results of ``flatten`` for dict, nested list, scalar and ``None`` inputs."""

    def test_flatten_dict(self):
        flattened = flatten({'a': 'foo', 'b': 'bar'})
        self.assertEqual(flattened, ['foo', 'bar'])

    def test_flatten_list(self):
        flattened = flatten(['foo', ['bar', 'troll']])
        self.assertEqual(flattened, ['foo', 'bar', 'troll'])

    def test_flatten_str(self):
        flattened = flatten('foo')
        self.assertEqual(flattened, ['foo'])

    def test_flatten_int(self):
        flattened = flatten(42)
        self.assertEqual(flattened, [42])

    def test_flatten_none(self):
        flattened = flatten(None)
        self.assertEqual(flattened, [])
21 |
22 |
class TestMapFlatten(unittest.TestCase):
    """Expected results of ``map_flattenable_items`` on nested dicts, tuples and lists."""

    def test_map_flattenable_items(self):
        # Flat dict.
        flat_dict = {'a': 1, 'b': 2}
        self.assertEqual(map_flattenable_items(str, flat_dict), {'a': '1', 'b': '2'})

        # Nested tuples ending in a dict.
        nested_tuple = (1, 2, 3, (4, 5, (6, 7, {'a': (8, 9, 0)})))
        self.assertEqual(
            map_flattenable_items(str, nested_tuple),
            ('1', '2', '3', ('4', '5', ('6', '7', {'a': ('8', '9', '0')}))),
        )

        # Dict mixing a list and nested dicts.
        mixed = {'a': [1, 2, 3, '4'], 'b': {'c': True, 'd': {'e': 5}}}
        self.assertEqual(
            map_flattenable_items(str, mixed),
            {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}},
        )
37 |
--------------------------------------------------------------------------------
/test/test_worker.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from unittest.mock import Mock
3 |
4 | import luigi
5 | import luigi.worker
6 | import pytest
7 | from luigi import scheduler
8 |
9 | import gokart
10 | from gokart.worker import Worker, gokart_worker
11 |
12 |
class _DummyTask(gokart.TaskOnKart):
    """Task whose ``run`` delegates to ``_run`` so tests can replace ``_run`` with a mock."""

    task_namespace = __name__
    random_id = luigi.Parameter()

    def _run(self):
        # Intentionally empty; replaced by a mock in the tests.
        ...

    def run(self):
        self._run()
        payload = 'test'
        self.dump(payload)
22 |
23 |
class TestWorkerRun:
    """Check that the worker runs an added task."""

    def test_run(self, monkeypatch: pytest.MonkeyPatch):
        """Check run is called when the task is not completed"""
        worker = Worker(scheduler=scheduler.Scheduler())

        # A fresh random id makes this a task instance that has not been run before.
        target_task = _DummyTask(random_id=uuid.uuid4().hex)
        run_spy = Mock()
        monkeypatch.setattr(target_task, '_run', run_spy)

        with worker:
            assert worker.add(target_task)
            assert worker.run()
            run_spy.assert_called_once()
37 |
38 |
class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]):
    """Task that always reports itself incomplete; ``run`` delegates to a mockable ``_run``."""

    task_namespace = __name__

    def _run(self):
        # Intentionally empty; replaced by a mock in the tests.
        ...

    def run(self):
        self._run()
        self.dump(None)

    def complete(self) -> bool:
        # Always incomplete unless a test patches this method.
        return False
50 |
51 |
class TestWorkerSkipIfCompletedPreRun:
    """Check when the worker skips ``run`` depending on ``task_completion_check_at_run`` and completion."""

    @pytest.mark.parametrize(
        'task_completion_check_at_run,is_completed,expect_skipped',
        [
            pytest.param(True, True, True, id='skipped when completed and task_completion_check_at_run is True'),
            pytest.param(True, False, False, id='not skipped when not completed and task_completion_check_at_run is True'),
            pytest.param(False, True, False, id='not skipped when completed and task_completion_check_at_run is False'),
            pytest.param(False, False, False, id='not skipped when not completed and task_completion_check_at_run is False'),
        ],
    )
    def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, task_completion_check_at_run: bool, is_completed: bool, expect_skipped: bool):
        worker_config = gokart_worker(task_completion_check_at_run=task_completion_check_at_run)
        worker = Worker(scheduler=scheduler.Scheduler(), config=worker_config)

        complete_stub = Mock(return_value=is_completed)
        # NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic.
        target_task = _DummyTaskToCheckSkip(complete_check_at_run=False)
        run_spy = Mock()
        monkeypatch.setattr(target_task, '_run', run_spy)

        with worker:
            assert worker.add(target_task)
            # NOTE: mock `complete` after `add` because `add` calls `complete`
            # to check if the task is already completed.
            monkeypatch.setattr(target_task, 'complete', complete_stub)
            assert worker.run()

            if not expect_skipped:
                run_spy.assert_called_once()
            else:
                run_spy.assert_not_called()
83 |
--------------------------------------------------------------------------------
/test/test_zoned_date_second_parameter.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import unittest
3 |
4 | from luigi.cmdline_parser import CmdlineParser
5 |
6 | from gokart import TaskOnKart, ZonedDateSecondParameter
7 |
8 |
class ZonedDateSecondParameterTaskWithoutDefault(TaskOnKart):
    """Task whose ``dt`` parameter is declared without a default value."""

    task_namespace = __name__
    dt: datetime.datetime = ZonedDateSecondParameter()

    def run(self):
        # Dump the parameter value as the task output.
        value = self.dt
        self.dump(value)
15 |
16 |
class ZonedDateSecondParameterTaskWithDefault(TaskOnKart):
    """Task whose ``dt`` parameter defaults to 2025-02-21 12:00:00 at UTC+09:00."""

    task_namespace = __name__
    dt: datetime.datetime = ZonedDateSecondParameter(
        default=datetime.datetime(
            2025,
            2,
            21,
            12,
            0,
            0,
            tzinfo=datetime.timezone(datetime.timedelta(hours=9)),
        )
    )

    def run(self):
        # Dump the parameter value as the task output.
        value = self.dt
        self.dump(value)
23 |
24 |
class ZonedDateSecondParameterTest(unittest.TestCase):
    """Parsing and serialization checks for ``ZonedDateSecondParameter``."""

    def setUp(self):
        jst = datetime.timezone(datetime.timedelta(hours=9))
        self.default_datetime = datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=jst)
        self.default_datetime_str = '2025-02-21T12:00:00+09:00'

    def test_default(self):
        argv = [f'{__name__}.ZonedDateSecondParameterTaskWithDefault']
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().dt == self.default_datetime

    def test_parse_param_with_tz_suffix(self):
        argv = [f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00+09:00']
        jst = datetime.timezone(datetime.timedelta(hours=9))
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().dt == datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=jst)

    def test_parse_param_with_Z_suffix(self):
        argv = [f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00Z']
        zero_offset = datetime.timezone(datetime.timedelta(hours=0))
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().dt == datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=zero_offset)

    def test_parse_param_without_timezone_input(self):
        # Input without an offset is compared against a naive datetime.
        argv = [f'{__name__}.ZonedDateSecondParameterTaskWithoutDefault', '--dt', '2025-02-21T12:00:00']
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().dt == datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=None)

    def test_parse_method(self):
        parsed = ZonedDateSecondParameter().parse(self.default_datetime_str)
        self.assertEqual(parsed, self.default_datetime)

    def test_serialize_task(self):
        task = ZonedDateSecondParameterTaskWithoutDefault(dt=self.default_datetime)
        # The task's string form must end with the serialized parameter.
        expected_suffix = f'(dt={self.default_datetime_str})'
        self.assertTrue(str(task).endswith(expected_suffix))
56 |
57 |
58 | if __name__ == '__main__':
59 | unittest.main()
60 |
--------------------------------------------------------------------------------
/test/testing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/testing/__init__.py
--------------------------------------------------------------------------------
/test/testing/test_pandas_assert.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import pandas as pd
4 |
5 | import gokart
6 |
7 |
class TestPandasAssert(unittest.TestCase):
    """Tests for ``gokart.testing.assert_frame_contents_equal``."""

    def test_assert_frame_contents_equal(self):
        # Same contents with rows and columns in a different order.
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])

        gokart.testing.assert_frame_contents_equal(resulted, expected)

    def test_assert_frame_contents_equal_with_small_error(self):
        # Float values differ by 1e-4 and are compared with atol=1e-1.
        expected = pd.DataFrame(data=dict(f1=[1.0001, 2.0001, 3.0001], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2.0002, 1.0002, 3.0002], f3=[222, 111, 333]), index=[1, 0, 2])

        gokart.testing.assert_frame_contents_equal(resulted, expected, atol=1e-1)

    def test_assert_frame_contents_equal_with_duplicated_columns(self):
        # Both frames get a duplicated column label.
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        expected.columns = ['f1', 'f1', 'f2']
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])
        resulted.columns = ['f2', 'f1', 'f1']

        with self.assertRaises(AssertionError):
            gokart.testing.assert_frame_contents_equal(resulted, expected)

    def test_assert_frame_contents_equal_with_duplicated_indexes(self):
        # Both frames get a duplicated index label, mirroring the duplicated-columns test.
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        expected.index = [0, 1, 1]
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])
        # Previously this line re-assigned expected.index, leaving ``resulted`` with a unique index.
        resulted.index = [1, 0, 1]

        with self.assertRaises(AssertionError):
            gokart.testing.assert_frame_contents_equal(resulted, expected)
38 |
--------------------------------------------------------------------------------
/test/tree/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m3dev/gokart/855ef94e1ee0936828667c10edebb9305f758c9b/test/tree/__init__.py
--------------------------------------------------------------------------------
/test/tree/test_task_info_formatter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import gokart
4 | from gokart.tree.task_info_formatter import RequiredTask, _make_requires_info
5 |
6 |
class _RequiredTaskExampleTaskA(gokart.TaskOnKart):
    """Parameterless task used as the required task in the tests below."""

    task_namespace = __name__
9 |
10 |
class TestMakeRequiresInfo(unittest.TestCase):
    """Check ``_make_requires_info`` for a single task, list, generator, dict and invalid input."""

    def test_make_requires_info_with_task_on_kart(self):
        task = _RequiredTaskExampleTaskA()
        expected = RequiredTask(name=task.__class__.__name__, unique_id=task.make_unique_id())
        self.assertEqual(_make_requires_info(requires=task), expected)

    def test_make_requires_info_with_list(self):
        tasks = [_RequiredTaskExampleTaskA()]
        expected = []
        for task in tasks:
            expected.append(RequiredTask(name=task.__class__.__name__, unique_id=task.make_unique_id()))
        self.assertEqual(_make_requires_info(requires=tasks), expected)

    def test_make_requires_info_with_generator(self):
        def _requires_gen():
            return (_RequiredTaskExampleTaskA() for _ in range(2))

        # A generator input is compared against a list.
        actual = _make_requires_info(requires=_requires_gen())
        expected = []
        for task in _requires_gen():
            expected.append(RequiredTask(name=task.__class__.__name__, unique_id=task.make_unique_id()))
        self.assertEqual(actual, expected)

    def test_make_requires_info_with_dict(self):
        tasks = {'taskA': _RequiredTaskExampleTaskA()}
        expected = {}
        for key, task in tasks.items():
            expected[key] = RequiredTask(name=task.__class__.__name__, unique_id=task.make_unique_id())
        self.assertEqual(_make_requires_info(requires=tasks), expected)

    def test_make_requires_info_with_invalid(self):
        # A list of ints must raise TypeError.
        invalid_requires = [1, 2]
        self.assertRaises(TypeError, lambda: _make_requires_info(requires=invalid_requires))
42 |
--------------------------------------------------------------------------------
/test/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 |
4 |
5 | # TODO: use pytest.fixture to share this functionality with other tests
def _get_temporary_directory():
    """Return an absolute path named ``temporary-<uuid4>``; the directory itself is not created.

    The base is ``os.path.dirname(__name__)``, i.e. it is derived from the module's
    dotted name rather than from its file location.
    """
    # NOTE(review): presumably __file__ was intended as the base -- confirm before changing.
    base_dir = os.path.dirname(__name__)
    unique_suffix = str(uuid.uuid4())
    relative_path = os.path.join(base_dir, 'temporary-' + unique_suffix)
    return os.path.abspath(relative_path)
9 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py{39,310,311,312,313},ruff,mypy
3 | skipsdist = True
4 |
5 | [testenv]
6 | runner = uv-venv-lock-runner
7 | dependency_groups = test
8 | commands =
9 | {envpython} -m pytest --cov=gokart --cov-report=xml -vv {posargs:}
10 |
11 | [testenv:ruff]
12 | dependency_groups = lint
13 | commands =
14 | ruff check {posargs:}
15 | ruff format --check {posargs:}
16 |
17 | [testenv:mypy]
18 | dependency_groups = lint
19 | commands = mypy gokart test {posargs:}
20 |
--------------------------------------------------------------------------------