├── .github
│   └── workflows
│       ├── ci.yml
│       ├── docs.yml
│       └── publish.yml
├── .gitignore
├── CODEOWNERS
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── VERSION
├── dev-requirements.txt
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── getstarted.rst
│       ├── index.rst
│       └── schema.rst
├── mypy.ini
├── notebooks
│   └── Wicker Hello World.ipynb
├── pyproject.toml
├── setup.cfg
├── setup.py
├── tests
│   ├── .wickerconfig.test.json
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   └── test_persistence.py
│   ├── test_avro_schema.py
│   ├── test_column_files.py
│   ├── test_datasets.py
│   ├── test_filelock.py
│   ├── test_numpy_codec.py
│   ├── test_shuffle.py
│   ├── test_spark.py
│   ├── test_storage.py
│   └── test_wandb.py
├── tox.ini
└── wicker
    ├── __init__.py
    ├── core
    │   ├── __init__.py
    │   ├── abstract.py
    │   ├── column_files.py
    │   ├── config.py
    │   ├── datasets.py
    │   ├── definitions.py
    │   ├── errors.py
    │   ├── filelock.py
    │   ├── persistance.py
    │   ├── shuffle.py
    │   ├── storage.py
    │   ├── utils.py
    │   └── writer.py
    ├── plugins
    │   ├── __init__.py
    │   ├── dynamodb.py
    │   ├── flyte.py
    │   ├── spark.py
    │   └── wandb.py
    ├── schema
    │   ├── __init__.py
    │   ├── codecs.py
    │   ├── dataloading.py
    │   ├── dataparsing.py
    │   ├── schema.py
    │   ├── serialization.py
    │   └── validation.py
    └── testing
        ├── __init__.py
        ├── codecs.py
        └── storage.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 |
3 | # Controls when the workflow will run
4 | on:
5 | # Triggers the workflow on push or pull request events but only for the main branch
6 | push:
7 | branches: [ main ]
8 | pull_request:
9 | branches: [ main ]
10 |
11 | # Allows you to run this workflow manually from the Actions tab
12 | workflow_dispatch:
13 |
14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
15 | jobs:
16 | # This workflow contains a single job called "run-ci"
17 | run-ci:
18 | # The type of runner that the job will run on
19 | runs-on: ubuntu-latest
20 |
21 | # Set environment variables
22 | # ToDo: remove this and make the config hermetic and tunable so each test can run without this file overhead
23 | env:
24 | WICKER_CONFIG_PATH: ${{ github.workspace }}/tests/.wickerconfig.test.json
25 |
26 | # Steps represent a sequence of tasks that will be executed as part of the job
27 | steps:
28 |
29 | - name: Check out repo
30 | uses: actions/checkout@v2
31 |
32 | - name: Set up Python 3.8
33 | uses: actions/setup-python@v1
34 | with:
35 | python-version: 3.8
36 |
37 | - name: Install dependencies
38 | run: |
39 | pip install --upgrade pip
40 | pip install -r dev-requirements.txt
41 |
42 | - name: Run tests
43 | run: make test
44 |
45 | - name: Run lints
46 | run: make lint
47 |
48 | - name: Run type checks
49 | run: make type-check
50 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | # Controls when the workflow will run
4 | on:
5 | # Triggers the workflow on push events but only for the main branch
6 | push:
7 | branches: [ main ]
8 | pull_request:
9 | branches: [ main ]
10 |
11 | # Allows you to run this workflow manually from the Actions tab
12 | workflow_dispatch:
13 |
14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
15 | jobs:
16 |
17 | deploy-docs:
18 | # The type of runner that the job will run on
19 | runs-on: ubuntu-latest
20 |
21 | # Set environment variables
22 | env:
23 | WICKER_CONFIG_PATH: tests/.wickerconfig.test.json
24 |
25 | # Steps represent a sequence of tasks that will be executed as part of the job
26 | steps:
27 |
28 | - name: Check out repo
29 | uses: actions/checkout@v2
30 |
31 | - name: Set up Python 3.8
32 | uses: actions/setup-python@v1
33 | with:
34 | python-version: 3.8
35 |
36 | - name: Install dependencies
37 | run: |
38 | pip install --upgrade pip
39 | pip install -r dev-requirements.txt
40 |
41 | - name: Build docs
42 | run: |
43 | pushd docs
44 | make html
45 | popd
46 | touch docs/build/html/.nojekyll
47 |
48 | - name: Deploy docs
49 | uses: JamesIves/github-pages-deploy-action@4.1.8
50 | with:
51 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
52 | BASE_BRANCH: master # The branch the action should deploy from.
53 | BRANCH: gh-pages # The branch the action should deploy to.
54 | FOLDER: docs/build/html # The folder the action should deploy (only stuff inside will be copied).
55 | # Reactivate when ready
56 | if: github.event_name == 'push'
57 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: publish
2 |
3 | # Controls when the workflow will run
4 | on:
5 | # Triggers the workflow on push events but only for tags starting with v*
6 | push:
7 | tags: ["v*"]
8 |
9 | # Allows you to run this workflow manually from the Actions tab
10 | workflow_dispatch:
11 |
12 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
13 | jobs:
14 |
15 | publish-pypi:
16 | # The type of runner that the job will run on
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - name: Check out repo
21 | uses: actions/checkout@v2
22 |
23 | - name: Build package
24 | run: |
25 | python3 -m pip install --upgrade build
26 | python3 -m build
27 |
28 | - name: Publish package
29 | uses: pypa/gh-action-pypi-publish@release/v1
30 | with:
31 | user: __token__
32 | password: ${{ secrets.PYPI_API_TOKEN }}
33 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | .idea/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # Code env specific
133 | .vscode
134 | *.egg
135 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @aalavian @anantsimran @chrisochoatri @convexquad @marccarre @pickles-bread-and-butter
2 |
3 | CODEOWNERS @aalavian @marccarre
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Woven Planet
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include VERSION
2 | include README.md
3 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | docs:
2 | sphinx-apidoc -f -o docs/source/ .
3 | cd docs && make html
4 | @echo "Check out docs/build/html/index.html for the generated docs."
5 |
6 | lint-fix:
7 | @echo "Running automated lint fixes"
8 | python -m isort tests/* wicker/*
9 | python -m black tests/* wicker/*
10 |
11 | lint:
12 | @echo "Running wicker lints"
13 | # Ignoring slice errors E203: https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices
14 | python -m flake8 tests/* wicker/* --ignore E203,W503
15 | python -m isort tests/* wicker/* --check --diff
16 | python -m black tests/* wicker/* --check
17 |
18 | type-check:
19 | @echo "Running wicker type checking with mypy"
20 | python -m mypy tests
21 | python -m mypy wicker
22 |
23 | test:
24 | @echo "Running wicker tests"
25 | python -m unittest discover
26 |
27 | .PHONY: docs lint-fix lint type-check test
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Wicker
2 |
3 | Wicker is an open source framework for Machine Learning dataset storage and serving developed at Woven Planet L5.
4 |
5 | # Usage
6 |
7 | Refer to the [Wicker documentation's Getting Started page](https://woven-planet.github.io/wicker/getstarted.html) for more information.
8 |
9 | # Development
10 |
11 | To contribute to Wicker, set up your local development environment as follows:
12 |
13 | 1. Create a new virtual environment
14 | 2. Run `pip install -r dev-requirements.txt` to install the development dependencies
15 | 3. Run `make test` to run all unit tests
16 | 4. Run `make lint-fix` to fix all lints and `make lint` to check for any lints that must be fixed manually
17 | 5. Run `make type-check` to check for type errors
18 |
19 | To contribute a new plugin that makes Wicker compatible with other technologies (e.g. Kubernetes, Ray, AWS Batch, etc.):
20 |
21 | 1. Add your plugin into the `wicker.plugins` module as an appropriately named module
22 | 2. If your new plugin requires new external dependencies:
23 | 1. Add a new `extras_require` entry to `setup.cfg`
24 | 2. Update `dev-requirements.txt` with any necessary dependencies to run your module in unit tests
25 | 3. Write a unit test in `tests/` to test your module
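
As a rough sketch of step 3, a plugin test can follow the same `unittest` style as the existing tests in `tests/` (the module and class names below are hypothetical):

```python
# tests/test_myplugin.py (hypothetical file name)
import unittest

from wicker.plugins import myplugin  # hypothetical plugin module


class TestMyPlugin(unittest.TestCase):
    def test_plugin_importable(self) -> None:
        # Replace with assertions that exercise your plugin's behaviour,
        # mocking out any external services it talks to.
        self.assertTrue(hasattr(myplugin, "__name__"))


if __name__ == "__main__":
    unittest.main()
```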
26 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.0.18
2 |
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | alabaster==0.7.12
2 | attrs==21.2.0
3 | avro==1.10.2
4 | Babel==2.9.1
5 | black==21.11b1
6 | boto3==1.20.22
7 | botocore==1.23.22
8 | certifi==2021.10.8
9 | charset-normalizer==2.0.9
10 | click==8.0.3
11 | decorator==5.1.0
12 | docutils==0.17.1
13 | flake8==4.0.1
14 | idna==3.3
15 | imagesize==1.3.0
16 | iniconfig==1.1.1
17 | isort==5.9.3
18 | Jinja2==3.0.3
19 | jmespath==0.10.0
20 | MarkupSafe==2.0.1
21 | mccabe==0.6.1
22 | mypy==0.960
23 | mypy-extensions==0.4.3
24 | numpy==1.21.2
25 | packaging==21.0
26 | pathspec==0.9.0
27 | platformdirs==2.4.0
28 | pluggy==1.0.0
29 | py==1.10.0
30 | py4j==0.10.9.2
31 | pyarrow==5.0.0
32 | pycodestyle==2.8.0
33 | pyflakes==2.4.0
34 | Pygments==2.10.0
35 | pyparsing==2.4.7
36 | pyspark==3.2.0
37 | pytest==6.2.5
38 | python-dateutil==2.8.2
39 | pytz==2021.3
40 | regex==2021.11.10
41 | requests==2.26.0
42 | retry==0.9.2
43 | retrying==1.3.3
44 | s3transfer==0.5.0
45 | six==1.16.0
46 | snowballstemmer==2.2.0
47 | Sphinx==4.3.2
48 | sphinxcontrib-applehelp==1.0.2
49 | sphinxcontrib-devhelp==1.0.2
50 | sphinxcontrib-htmlhelp==2.0.0
51 | sphinxcontrib-jsmath==1.0.1
52 | sphinxcontrib-qthelp==1.0.3
53 | sphinxcontrib-serializinghtml==1.1.5
54 | toml==0.10.2
55 | tomli==1.2.2
56 | types-retry==0.9.2
57 | typing-extensions==3.10.0.2
58 | urllib3==1.26.7
59 | wandb==0.12.21
60 | tqdm
61 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../..'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'Wicker'
21 | copyright = '2021, Woven Planet L5'
22 | author = 'Jay Chia, Alex Bain, Francois Lefevere, Ben Hansen, Jason Zhao'
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = [
31 | "sphinx.ext.autodoc",
32 | ]
33 |
34 | # Add any paths that contain templates here, relative to this directory.
35 | templates_path = ['_templates']
36 |
37 | # List of patterns, relative to source directory, that match files and
38 | # directories to ignore when looking for source files.
39 | # This pattern also affects html_static_path and html_extra_path.
40 | exclude_patterns = [
41 | "tests*",
42 | "setup.py",
43 | ]
44 |
45 |
46 | # -- Options for HTML output -------------------------------------------------
47 |
48 | # The theme to use for HTML and HTML Help pages. See the documentation for
49 | # a list of builtin themes.
50 | #
51 | html_theme = 'alabaster'
52 |
53 | # Add any paths that contain custom static files (such as style sheets) here,
54 | # relative to this directory. They are copied after the builtin static files,
55 | # so a file named "default.css" will overwrite the builtin "default.css".
56 | html_static_path = ['_static']
--------------------------------------------------------------------------------
/docs/source/getstarted.rst:
--------------------------------------------------------------------------------
1 | Getting Started
2 | ===============
3 |
4 | Wicker is an open source framework for Machine Learning dataset storage and serving developed at Woven Planet L5.
5 |
6 | Wicker leverages other open source technologies such as Apache Arrow and Apache Parquet to store and serve data. Operating
7 | Wicker mainly requires users to provide an object store (currently Wicker is only compatible with AWS S3, but integrations with
8 | other cloud object stores are a work-in-progress).
9 |
10 | Out of the box, Wicker provides integrations with several widely used technologies such as Spark, Flyte and DynamoDB to allow users
11 | to write Wicker datasets from these data infrastructures. However, Wicker was built with a high degree of extensibility in mind, and
12 | allows users to build and use their own implementations to easily integrate with their own infrastructure.
13 |
14 | Installation
15 | ------------
16 |
17 | ``pip install wicker``
18 |
19 | Additionally, to use some of the provided integrations with other open-source tooling such as Spark, Flyte and Kubernetes,
20 | users may install the corresponding extras:
21 |
22 | ``pip install wicker[spark,flyte,kubernetes,...]``
23 |
24 | Configuration
25 | -------------
26 |
27 | By default, Wicker searches for a configuration file at ``~/wickerconfig.json``. Users may change this path by setting the
28 | ``WICKER_CONFIG_PATH`` environment variable to point to their configuration JSON file.
29 |
30 | .. code-block:: json
31 |
32 | {
33 | "aws_s3_config": {
34 | "s3_datasets_path": "s3://my-bucket/somepath", // Path to the AWS bucket + prefix to use
35 | "region": "us-west-2", // Region of your bucket
36 | "store_concatenated_bytes_files_in_dataset": true // (Optional) Whether to store concatenated bytes files in the dataset
37 | }
38 | }
39 |
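If you prefer not to keep the file at the default location, one low-friction pattern is to write the JSON somewhere convenient and point ``WICKER_CONFIG_PATH`` at it before using Wicker, mirroring how this repository's own CI points the variable at ``tests/.wickerconfig.test.json``. This is a minimal sketch, assuming the configuration is read when Wicker first loads its config; the bucket and region values are placeholders:

.. code-block:: python3

    import json
    import os
    import tempfile

    # Placeholder configuration values; replace with your own bucket and region.
    config = {
        "aws_s3_config": {
            "s3_datasets_path": "s3://my-bucket/somepath",
            "region": "us-west-2",
        }
    }

    # Write the config to a temporary location and point Wicker at it.
    config_path = os.path.join(tempfile.gettempdir(), "wickerconfig.json")
    with open(config_path, "w") as f:
        json.dump(config, f)
    os.environ["WICKER_CONFIG_PATH"] = config_path
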
40 | Writing your first Dataset
41 | --------------------------
42 |
43 | Wicker allows users to work both locally and in the cloud, by leveraging different compute and storage backends.
44 |
45 | Note that every dataset must have a defined schema. We define schemas using Wicker's schema library:
46 |
47 | .. code-block:: python3
48 |
49 | from wicker import schema
50 |
51 | MY_SCHEMA = schema.DatasetSchema(
52 | primary_keys=["foo"],
53 | fields=[
54 | schema.StringField("foo", description="This is an optional human-readable description of the field"),
55 | schema.NumpyField("arr", shape=(4, 4), dtype="float64"),
56 | ]
57 | )
58 |
59 | The above schema defines a dataset that consists of data that looks like:
60 |
61 | .. code-block:: python3
62 |
63 | {
64 | "foo": "some_string",
65 | "arr": np.array([
66 | [1., 1., 1., 1.],
67 | [1., 1., 1., 1.],
68 | [1., 1., 1., 1.],
69 | [1., 1., 1., 1.],
70 | ])
71 | }
72 |
73 | We have the following guarantees:
74 |
75 | 1. The dataset is sorted by each example's `"foo"` field, as this is the only primary key of the dataset
76 | 2. Each example's `"arr"` field contains a 4-by-4 numpy array of float64 values
77 |
78 | After defining a schema, we can start writing data that conforms to this schema to a dataset.
79 |
80 | Using Spark
81 | ^^^^^^^^^^^
82 |
83 | Spark is a common data engine and Wicker provides integrations to write datasets from Spark.
84 |
85 | .. code-block:: python3
86 |
87 | from wicker.plugins.spark import SparkPersistor
88 |
89 | examples = [
90 | (
91 | "train", # Wicker dataset partition that this row belongs to
92 | {
93 | "foo": f"foo{i}",
94 | "arr": np.ones((4, 4)),
95 | }
96 | ) for i in range(1000)
97 | ]
98 |
99 | rdd = spark_context.parallelize(examples)
100 | persistor = SparkPersistor()
101 | persistor.persist_wicker_dataset(
102 | "my_dataset_name",
103 | "0.0.1",
104 | MY_SCHEMA,
105 | rdd,
106 | )
107 |
108 | And that's it! Wicker will handle all the sorting and persisting of the data for you under the hood.
109 |
110 | Using Non-Data Engine Infrastructures
111 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
112 |
113 | Not all users have access to infrastructure like Spark, or want to spin up something that heavyweight for a
114 | smaller dataset or use case. For these users, Wicker exposes a ``DatasetWriter`` API for adding and committing
115 | examples from any environment.
116 |
117 | To make this work, Wicker needs an intermediate ``MetadataDatabase`` to store and index information about each row before
118 | it commits the dataset. We provide a default integration with DynamoDB, but users can easily implement their own integrations
119 | by implementing the abstract interface ``wicker.core.writer.AbstractDatasetWriterMetadataDatabase`` and using their own
120 | metadata database as intermediate storage for persisting their data. Integrations with other databases for use as a Wicker-compatible
121 | MetadataDatabase are a work-in-progress.
122 |
123 | Below, we provide an example of how we can use `Flyte <https://flyte.org/>`_ to commit our datasets, using DynamoDB as our
124 | MetadataDatabase. More plugins are being written for other commonly used cloud infrastructure such as AWS Batch, Kubernetes etc.
125 |
126 | .. code-block:: python3
127 |
128 | from wicker.schema import serialization
129 | from wicker.core.definitions import DatasetDefinition, DatasetID
130 | from wicker.core.writer import DatasetWriter
131 | from wicker.plugins import dynamodb, flyte
132 |
133 | # First, add the following to our ~/.wickerconfig.json file to enable Wicker's DynamoDB integrations
134 | #
135 | # "dynamodb_config": { // only if users need to use DynamoDB for writing datasets
136 | # "table_name": "my-table", // name of the table to use in dynamodb
137 | # "region": "us-west-2" // region of your table
138 | # }
139 |
140 | metadata_database = dynamodb.DynamodbMetadataDatabase()
141 | dataset_definition = DatasetDefinition(DatasetID(name="my_dataset", version="0.0.1"), MY_SCHEMA)
142 |
143 | # (1): Add examples to your dataset
144 | #
145 | # Note that this can be called from anywhere asynchronously, e.g. in different Flyte workers, from
146 | # a Jupyter notebook, a local Python script etc - as long as the same metadata_database config is used
147 | with DatasetWriter(dataset_definition, metadata_database) as writer:
148 | writer.add_example(
149 | "train", # Name of your Wicker dataset partition (e.g. train, test, eval, unittest, ...)
150 | {
151 | "foo": "foo1",
152 | "arr": np.eye(4).astype("float64"),
153 | }, # Raw data for a single example that conforms to your schema
154 | )
155 |
156 | # (2): When ready, commit the dataset.
157 | #
158 | # Trigger the Flyte workflow to commit the dataset, either from the Flyte UI, Flyte CLI or from a Python script
159 | flyte.WickerDataShufflingWorkflow(
160 | dataset_id=str(dataset_definition.dataset_id),
161 | schema_json_str=serialization.dumps(MY_SCHEMA),
162 | )
163 |
164 | 1. Start adding examples to your dataset. Note:
165 | a) Here we use a ``DynamodbMetadataDatabase`` as the metadata storage for this dataset, but users can use other Metadata Database implementations here as well if they do not have an accessible DynamoDB instance.
166 | b) The ``.add_example(...)`` call writes a single example to the ``"train"`` partition, and can throw a ``WickerSchemaException`` if the data provided does not conform to the schema (see the sketch after this list).
167 | 2. Commit your dataset. Note here that we use the committing functionality provided by ``wicker.plugins.flyte``, but more plugins for other data infrastructures are a work-in-progress (e.g. Kubernetes, AWS Batch)
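
As a minimal sketch of handling validation failures per example rather than aborting the whole run, assuming ``WickerSchemaException`` can be imported from ``wicker.core.errors`` (adjust the import to wherever it lives in your installed version); ``my_examples`` is a hypothetical iterable of ``(partition, example_dict)`` pairs:

.. code-block:: python3

    from wicker.core.errors import WickerSchemaException  # assumed location

    bad_examples = []
    with DatasetWriter(dataset_definition, metadata_database) as writer:
        for partition, raw_example in my_examples:  # hypothetical input iterable
            try:
                writer.add_example(partition, raw_example)
            except WickerSchemaException:
                # Keep track of offending examples instead of failing the whole run
                bad_examples.append(raw_example)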
168 |
169 |
170 | Reading from your Dataset
171 | -------------------------
172 |
173 | .. code-block:: python3
174 |
175 | from wicker.core.datasets import S3Dataset
176 |
177 | ds = S3Dataset("my_new_dataset", "0.0.1", "train", columns_to_load=["foo", "arr"])
178 |
179 | # Check the size of your "train" partition
180 | len(ds)
181 |
182 | # Retrieve a single item, initial access is slow (O(seconds))
183 | x0 = ds[0]
184 |
185 | # Subsequent data accesses are fast (O(us)), data is cached in page buffers
186 | x0_ = ds[0]
187 |
188 | # Access to contiguous indices is also fast (O(ms)), data is cached on disk/in page buffers
189 | x1 = ds[1]
190 |
191 | Reading from your dataset is as simple as indexing on an ``S3Dataset`` handle. Note:
192 |
193 | 1. Wicker is built for high throughput, and initial access times are amortized by reading contiguous chunks of indices. Sampling for distributed ML training should take this into account and provide each worker with a contiguous chunk of indices as its working set for good performance (a minimal sketch follows below).
194 |
195 | 2. Wicker allows users to load only the columns they are interested in, via the ``columns_to_load`` keyword argument
196 |
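As a minimal sketch of the contiguous-chunk access pattern from note 1 (the worker count and rank below are placeholders you would obtain from your own training framework):

.. code-block:: python3

    from wicker.core.datasets import S3Dataset

    ds = S3Dataset("my_new_dataset", "0.0.1", "train", columns_to_load=["foo", "arr"])

    num_workers = 8   # placeholder: total number of training workers
    worker_rank = 0   # placeholder: this worker's rank

    # Give each worker one contiguous slice of indices so reads hit warm caches
    chunk_size = len(ds) // num_workers
    start = worker_rank * chunk_size
    end = len(ds) if worker_rank == num_workers - 1 else start + chunk_size

    for idx in range(start, end):
        example = ds[idx]  # fast after the first access within a chunk
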
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Wicker documentation master file, created by
2 | sphinx-quickstart on Sun Dec 26 08:41:22 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Wicker's documentation!
7 | ==================================
8 |
9 | Wicker is a Python framework for storing and serving Machine Learning datasets. It provides:
10 |
11 | **Abstraction of underlying storage**
12 |
13 | Read and write data as Python dictionaries - don't worry about where/how it is stored.
14 |
15 |
16 | **Schematized Data**
17 |
18 | Enforce a schema when reading and writing your data, providing readability and usability for users of your datasets.
19 |
20 |
21 | **Metadata Indexing**
22 |
23 | Easily store, query and visualize your metadata in your own stores such as BigQuery or Snowflake for downstream analyses.
24 |
25 |
26 | **Efficient Dataloading**
27 |
28 | Built and optimized for high-throughput distributed training of deep learning models.
29 |
30 |
31 | .. toctree::
32 | getstarted
33 | schema
34 |
35 | Indices and tables
36 | ==================
37 |
38 | * :ref:`genindex`
39 | * :ref:`modindex`
40 | * :ref:`search`
41 |
--------------------------------------------------------------------------------
/docs/source/schema.rst:
--------------------------------------------------------------------------------
1 | Wicker Schemas
2 | ==============
3 |
4 | Every Wicker dataset has an associated schema which is declared at schema write-time.
5 |
6 | Wicker schemas are Python objects which are serialized in storage as Avro-compatible JSON files.
7 | When declaring schemas, we use the ``wicker.schema.DatasetSchema`` object:
8 |
9 | .. code-block:: python3
10 |
11 | from wicker.schema import DatasetSchema
12 |
13 | my_schema = DatasetSchema(
14 | primary_keys=["foo", "bar"],
15 | fields=[...],
16 | )
17 |
18 | Your schema must be defined with a set of ``primary_keys``. The primary keys must be the names of
19 | string, float, int or bool fields in your schema, and are used to order your dataset.
20 |
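For example, here is a minimal sketch using field types that appear elsewhere in this repository (``StringField``, ``IntField`` and ``NumpyField``); the field names are purely illustrative:

.. code-block:: python3

    from wicker import schema

    my_schema = schema.DatasetSchema(
        # Examples are ordered by ("scene_id", "frame_index")
        primary_keys=["scene_id", "frame_index"],
        fields=[
            schema.StringField("scene_id"),
            schema.IntField("frame_index"),
            schema.NumpyField("features", shape=(4, 4), dtype="float64"),
        ],
    )
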
21 | Schema Fields
22 | -------------
23 |
24 | Here is a list of Schema fields that Wicker provides. Most notably, users can implement custom fields
25 | by implementing their own codecs and using the ``ObjectField``.
26 |
27 | .. automodule:: wicker.schema.schema
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 | :exclude-members: DatasetSchema, *._accept_visitor
32 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.8
3 | warn_unused_configs = True
4 |
5 | [mypy-boto3.*]
6 | ignore_missing_imports = True
7 |
8 | [mypy-pynamodb.*]
9 | ignore_missing_imports = True
10 |
11 | [mypy-retrying.*]
12 | ignore_missing_imports = True
13 |
14 | [mypy-pyarrow.*]
15 | ignore_missing_imports = True
16 |
17 | [mypy-pyspark.*]
18 | ignore_missing_imports = True
19 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [tool.black]
9 | line-length = 120
10 |
11 | [tool.isort]
12 | profile = "black"
13 | multi_line_output = 3
14 | skip = ["docs"]
15 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = wicker
3 | version = file: VERSION
4 | author = wicker-maintainers
5 | author_email = wicker-maintainers@woven-planet.global
6 | description = An open source framework for Machine Learning dataset storage and serving
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/woven-planet/Wicker
10 | project_urls =
11 | Bug Tracker = https://github.com/woven-planet/Wicker/issues
12 | classifiers =
13 | Programming Language :: Python :: 3
14 | License :: OSI Approved :: MIT License
15 | Operating System :: OS Independent
16 |
17 | [options]
18 | include_package_data = True
19 | package_dir =
20 | =.
21 | packages = find:
22 | python_requires = >=3.6
23 | install_requires =
24 | numpy
25 | pyarrow
26 | boto3
27 |
28 | [options.extras_require]
29 | flyte = flytekit
30 | dynamodb = pynamodb
31 | spark = pyspark
32 | wandb = wandb
33 |
34 | [options.packages.find]
35 | where = .
36 | exclude =
37 | tests
38 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | if __name__ == "__main__":
4 | setup()
5 |
--------------------------------------------------------------------------------
/tests/.wickerconfig.test.json:
--------------------------------------------------------------------------------
1 | {
2 | "aws_s3_config": {
3 | "s3_datasets_path": "s3://fake_data/",
4 | "region": "us-west-2",
5 | "boto_config": {
6 | "max_pool_connections":10,
7 | "read_timeout_s": 140,
8 | "connect_timeout_s": 140
9 | }
10 | },
11 | "filesystem_configs": [
12 | {
13 | "config_name": "filesystem_1",
14 | "prefix_replace_path": "s3://fake_data_1/",
15 | "root_datasets_path": "/mnt/fake_data_1/"
16 | },
17 | {
18 | "config_name": "filesystem_2",
19 | "prefix_replace_path": "s3://fake_data_2/",
20 | "root_datasets_path": "/mnt/fake_data_2/"
21 | }
22 | ],
23 | "dynamodb_config": {
24 | "table_name": "fake_db",
25 | "region": "us-west-2"
26 | },
27 | "storage_download_config":{
28 | "retries": 2,
29 | "timeout": 150,
30 | "retry_backoff":5,
31 | "retry_delay_s": 4
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/tests/__init__.py
--------------------------------------------------------------------------------
/tests/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/tests/core/__init__.py
--------------------------------------------------------------------------------
/tests/core/test_persistence.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import random
4 | import uuid
5 | from typing import Any, Dict, List, Tuple
6 |
7 | import pyarrow.fs as pafs
8 | import pyarrow.parquet as papq
9 | import pytest
10 |
11 | from wicker import schema
12 | from wicker.core.datasets import S3Dataset
13 | from wicker.core.persistance import BasicPersistor
14 | from wicker.core.storage import S3PathFactory
15 | from wicker.schema.schema import DatasetSchema
16 | from wicker.testing.storage import LocalDataStorage
17 |
18 | DATASET_NAME = "dataset"
19 | DATASET_VERSION = "0.0.1"
20 | SCHEMA = schema.DatasetSchema(
21 | primary_keys=["bar", "foo"],
22 | fields=[
23 | schema.IntField("foo"),
24 | schema.StringField("bar"),
25 | schema.BytesField("bytescol"),
26 | ],
27 | )
28 | EXAMPLES = [
29 | (
30 | "train" if i % 2 == 0 else "test",
31 | {
32 | "foo": random.randint(0, 10000),
33 | "bar": str(uuid.uuid4()),
34 | "bytescol": b"0",
35 | },
36 | )
37 | for i in range(10000)
38 | ]
39 | # Examples with a duplicated key
40 | EXAMPLES_DUPES = copy.deepcopy(EXAMPLES)
41 |
42 |
43 | @pytest.fixture
44 | def mock_basic_persistor(request, tmpdir) -> Tuple[BasicPersistor, str]:
45 | storage = request.param.get("storage", LocalDataStorage(root_path=tmpdir))
46 | path_factory = request.param.get("path_factory", S3PathFactory(s3_root_path=os.path.join(tmpdir, "datasets")))
47 | return BasicPersistor(storage, path_factory), tmpdir
48 |
49 |
50 | def assert_written_correctness(persistor: BasicPersistor, data: List[Tuple[str, Dict[str, Any]]], tmpdir: str) -> None:
51 | """Asserts that all files are written as expected by the L5MLDatastore"""
52 | # Check that files are correctly written locally by Spark/Parquet with a _SUCCESS marker file
53 | prefix = persistor.s3_path_factory.root_path
54 | assert DATASET_NAME in os.listdir(os.path.join(tmpdir, prefix))
55 | assert DATASET_VERSION in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME))
56 | for partition in ["train", "test"]:
57 | print(os.listdir(os.path.join(tmpdir, prefix)))
58 | columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__")
59 | all_read_bytes = b""
60 | for filename in os.listdir(columns_path):
61 | concatenated_bytes_filepath = os.path.join(columns_path, filename)
62 | with open(concatenated_bytes_filepath, "rb") as bytescol_file:
63 | all_read_bytes += bytescol_file.read()
64 | assert all_read_bytes == b"0" * 10000
65 |
66 | # Load parquet file and assert ordering of primary_key
67 | assert f"{partition}.parquet" in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION))
68 | tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet"))
69 | foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])]
70 | assert foobar == sorted(foobar)
71 |
72 | # Also load the data from the dataset and check it.
73 | ds = S3Dataset(
74 | DATASET_NAME,
75 | DATASET_VERSION,
76 | partition,
77 | local_cache_path_prefix=tmpdir,
78 | storage=persistor.s3_storage,
79 | s3_path_factory=persistor.s3_path_factory,
80 | pa_filesystem=pafs.LocalFileSystem(),
81 | )
82 |
83 | ds_values = [ds[idx] for idx in range(len(ds))]
84 | # Sort the expected values by the primary keys.
85 | expected_values = sorted([value for p, value in data if p == partition], key=lambda v: (v["bar"], v["foo"]))
86 | assert ds_values == expected_values
87 |
88 |
89 | @pytest.mark.parametrize(
90 | "mock_basic_persistor, dataset_name, dataset_version, dataset_schema, dataset",
91 | [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))],
92 | indirect=["mock_basic_persistor"],
93 | )
94 | def test_basic_persistor(
95 | mock_basic_persistor: Tuple[BasicPersistor, str],
96 | dataset_name: str,
97 | dataset_version: str,
98 | dataset_schema: DatasetSchema,
99 | dataset: List[Tuple[str, Dict[str, Any]]],
100 | ):
101 | """
102 | Test if the basic persistor can persist data in the format we have established.
103 |
104 | Ensure we read the right file locations, the right amount of bytes,
105 | and the ordering is correct.
106 | """
107 | # create the mock basic persistor
108 | mock_basic_persistor_obj, tempdir = mock_basic_persistor
109 | # persist the dataset
110 | mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset)
111 | # assert the dataset is correctly written
112 | assert_written_correctness(persistor=mock_basic_persistor_obj, data=dataset, tmpdir=tempdir)
113 |
--------------------------------------------------------------------------------
/tests/test_column_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import unittest
4 | import uuid
5 | from unittest.mock import MagicMock
6 |
7 | from wicker.core.column_files import (
8 | ColumnBytesFileCache,
9 | ColumnBytesFileLocationV1,
10 | ColumnBytesFileReader,
11 | ColumnBytesFileWriter,
12 | )
13 | from wicker.core.definitions import DatasetID, DatasetPartition
14 | from wicker.core.storage import S3PathFactory, WickerPathFactory
15 | from wicker.testing.storage import FakeS3DataStorage
16 |
17 | FAKE_SHARD_ID = "fake_shard_id"
18 | FAKE_DATA_PARTITION = DatasetPartition(dataset_id=DatasetID(name="name", version="0.0.1"), partition="partition")
19 | FAKE_BYTES = b"foo"
20 | FAKE_BYTES2 = b"0123456789"
21 | FAKE_COL = "col0"
22 | FAKE_COL2 = "col1"
23 |
24 |
25 | class TestColumnBytesFileWriter(unittest.TestCase):
26 | def test_write_empty(self) -> None:
27 | mock_storage = MagicMock()
28 | path_factory = S3PathFactory()
29 | with ColumnBytesFileWriter(
30 | storage=mock_storage,
31 | s3_path_factory=path_factory,
32 | ):
33 | pass
34 | mock_storage.put_object_s3.assert_not_called()
35 |
36 | def test_write_one_column_one_row(self) -> None:
37 | path_factory = S3PathFactory()
38 | mock_storage = MagicMock()
39 | with ColumnBytesFileWriter(
40 | storage=mock_storage,
41 | s3_path_factory=path_factory,
42 | ) as ccb:
43 | info = ccb.add(FAKE_COL, FAKE_BYTES)
44 | self.assertEqual(info.byte_offset, 0)
45 | self.assertEqual(info.data_size, len(FAKE_BYTES))
46 | mock_storage.put_file_s3.assert_called_once_with(
47 | unittest.mock.ANY,
48 | os.path.join(
49 | path_factory.get_column_concatenated_bytes_files_path(),
50 | str(info.file_id),
51 | ),
52 | )
53 |
54 | def test_write_one_column_multi_row(self) -> None:
55 | path_factory = S3PathFactory()
56 | mock_storage = MagicMock()
57 | with ColumnBytesFileWriter(
58 | storage=mock_storage,
59 | s3_path_factory=path_factory,
60 | ) as ccb:
61 | info = ccb.add(FAKE_COL, FAKE_BYTES)
62 | self.assertEqual(info.byte_offset, 0)
63 | self.assertEqual(info.data_size, len(FAKE_BYTES))
64 |
65 | next_info = ccb.add(FAKE_COL, FAKE_BYTES)
66 | self.assertEqual(next_info.byte_offset, len(FAKE_BYTES))
67 | self.assertEqual(next_info.data_size, len(FAKE_BYTES))
68 | mock_storage.put_file_s3.assert_called_once_with(
69 | unittest.mock.ANY,
70 | os.path.join(
71 | path_factory.get_column_concatenated_bytes_files_path(),
72 | str(info.file_id),
73 | ),
74 | )
75 |
76 | def test_write_multi_column_multi_row(self) -> None:
77 | path_factory = S3PathFactory()
78 | mock_storage = MagicMock()
79 | with ColumnBytesFileWriter(
80 | storage=mock_storage,
81 | s3_path_factory=path_factory,
82 | ) as ccb:
83 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
84 | self.assertEqual(info1.byte_offset, 0)
85 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
86 |
87 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
88 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES))
89 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
90 |
91 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES)
92 | self.assertEqual(info2.byte_offset, 0)
93 | self.assertEqual(info2.data_size, len(FAKE_BYTES))
94 |
95 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES)
96 | self.assertEqual(info2.byte_offset, len(FAKE_BYTES))
97 | self.assertEqual(info2.data_size, len(FAKE_BYTES))
98 |
99 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES)
100 | self.assertEqual(info2.byte_offset, len(FAKE_BYTES) * 2)
101 | self.assertEqual(info2.data_size, len(FAKE_BYTES))
102 |
103 | mock_storage.put_file_s3.assert_any_call(
104 | unittest.mock.ANY,
105 | os.path.join(
106 | path_factory.get_column_concatenated_bytes_files_path(),
107 | str(info1.file_id),
108 | ),
109 | )
110 | mock_storage.put_file_s3.assert_any_call(
111 | unittest.mock.ANY,
112 | os.path.join(
113 | path_factory.get_column_concatenated_bytes_files_path(),
114 | str(info2.file_id),
115 | ),
116 | )
117 |
118 | def test_write_large_file(self) -> None:
119 | with tempfile.TemporaryDirectory() as tmpdir:
120 | path_factory = S3PathFactory()
121 | storage = FakeS3DataStorage(tmpdir=tmpdir)
122 | with ColumnBytesFileWriter(storage=storage, s3_path_factory=path_factory, target_file_size=10) as ccb:
123 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
124 | self.assertEqual(info1.byte_offset, 0)
125 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
126 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
127 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES))
128 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
129 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
130 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES) * 2)
131 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
132 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
133 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES) * 3)
134 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
135 | info2 = ccb.add(FAKE_COL, FAKE_BYTES)
136 | self.assertEqual(info2.byte_offset, 0)
137 | self.assertEqual(info2.data_size, len(FAKE_BYTES))
138 | self.assertNotEqual(info1.file_id, info2.file_id)
139 |
140 | info1 = ccb.add(FAKE_COL2, FAKE_BYTES2)
141 | self.assertEqual(info1.byte_offset, 0)
142 | self.assertEqual(info1.data_size, len(FAKE_BYTES2))
143 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES2)
144 | self.assertEqual(info2.byte_offset, 0)
145 | self.assertEqual(info2.data_size, len(FAKE_BYTES2))
146 | self.assertNotEqual(info1.file_id, info2.file_id)
147 | info3 = ccb.add(FAKE_COL2, FAKE_BYTES2)
148 | self.assertEqual(info3.byte_offset, 0)
149 | self.assertEqual(info3.data_size, len(FAKE_BYTES2))
150 | self.assertNotEqual(info2.file_id, info3.file_id)
151 |
152 | def test_write_manyrows_file(self) -> None:
153 | with tempfile.TemporaryDirectory() as tmpdir:
154 | path_factory = S3PathFactory()
155 | storage = FakeS3DataStorage(tmpdir=tmpdir)
156 | with ColumnBytesFileWriter(
157 | storage=storage, s3_path_factory=path_factory, target_file_rowgroup_size=1
158 | ) as ccb:
159 | info1 = ccb.add(FAKE_COL, FAKE_BYTES)
160 | self.assertEqual(info1.byte_offset, 0)
161 | self.assertEqual(info1.data_size, len(FAKE_BYTES))
162 | info2 = ccb.add(FAKE_COL, FAKE_BYTES)
163 | self.assertEqual(info2.byte_offset, 0)
164 | self.assertEqual(info2.data_size, len(FAKE_BYTES))
165 | self.assertNotEqual(info1.file_id, info2.file_id)
166 |
167 | info1 = ccb.add(FAKE_COL2, FAKE_BYTES2)
168 | self.assertEqual(info1.byte_offset, 0)
169 | self.assertEqual(info1.data_size, len(FAKE_BYTES2))
170 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES2)
171 | self.assertEqual(info2.byte_offset, 0)
172 | self.assertEqual(info2.data_size, len(FAKE_BYTES2))
173 | self.assertNotEqual(info1.file_id, info2.file_id)
174 | info3 = ccb.add(FAKE_COL2, FAKE_BYTES2)
175 | self.assertEqual(info3.byte_offset, 0)
176 | self.assertEqual(info3.data_size, len(FAKE_BYTES2))
177 | self.assertNotEqual(info2.file_id, info3.file_id)
178 |
179 |
180 | class TestColumnBytesFileCacheAndReader(unittest.TestCase):
181 | def test_write_one_column_one_row_and_read(self) -> None:
182 | path_factory = S3PathFactory()
183 | mock_storage = MagicMock()
184 | # First, write a row to column bytes mock storage.
185 | with ColumnBytesFileWriter(
186 | storage=mock_storage,
187 | s3_path_factory=path_factory,
188 | ) as ccb:
189 | info = ccb.add(FAKE_COL, FAKE_BYTES)
190 | self.assertEqual(info.byte_offset, 0)
191 | self.assertEqual(info.data_size, len(FAKE_BYTES))
192 | s3_path = os.path.join(
193 | path_factory.get_column_concatenated_bytes_files_path(),
194 | str(info.file_id),
195 | )
196 | mock_storage.put_file_s3.assert_called_once_with(unittest.mock.ANY, s3_path)
197 |
198 | # Now, verify that we can read it back from mock storage.
199 | local_path = os.path.join("/tmp", s3_path.split("s3://fake_data/")[1])
200 | mock_storage.fetch_file = MagicMock(return_value=local_path)
201 |
202 | cbf_cache = ColumnBytesFileCache(
203 | storage=mock_storage,
204 | path_factory=path_factory,
205 | )
206 |
207 | # Mock the helper function that just opens a file and returns its contents.
208 | read_column_bytes_file_mock = MagicMock(return_value=FAKE_BYTES)
209 | cbf_cache._read_column_bytes_file = read_column_bytes_file_mock # type: ignore
210 |
211 | # Read back the column file.
212 | bytes_read = cbf_cache.read(info)
213 | mock_storage.fetch_file.assert_called_once_with(s3_path, "/tmp", timeout_seconds=-1)
214 | self.assertEqual(len(bytes_read), len(FAKE_BYTES))
215 | self.assertEqual(bytes_read, FAKE_BYTES)
216 |
217 | # Now let's verify that we can use the reader directly to read the local column bytes.
218 | wicker_path_factory = WickerPathFactory("/tmp")
219 | cbf_reader = ColumnBytesFileReader(wicker_path_factory)
220 |
221 | # Reset the reader function mock so we can check that it was called just once.
222 | read_column_bytes_file_mock = MagicMock(return_value=FAKE_BYTES)
223 | cbf_reader._read_column_bytes_file = read_column_bytes_file_mock # type: ignore
224 |
225 | bytes_read = cbf_reader.read(info)
226 | read_column_bytes_file_mock.assert_called_once_with(
227 | column_bytes_file_info=info,
228 | column_bytes_file_path=local_path,
229 | )
230 | self.assertEqual(len(bytes_read), len(FAKE_BYTES))
231 | self.assertEqual(bytes_read, FAKE_BYTES)
232 |
233 |
234 | class TestCCBInfo(unittest.TestCase):
235 | def test_to_string(self) -> None:
236 | ccb_info = ColumnBytesFileLocationV1(uuid.uuid4(), 100, 100)
237 | ccb_info_parsed = ColumnBytesFileLocationV1.from_bytes(ccb_info.to_bytes())
238 | self.assertEqual(ccb_info, ccb_info_parsed)
239 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import unittest
4 | from contextlib import contextmanager
5 | from typing import Any, Iterator, NamedTuple, Tuple
6 | from unittest.mock import patch
7 |
8 | import numpy as np
9 | import pyarrow as pa # type: ignore
10 | import pyarrow.fs as pafs # type: ignore
11 | import pyarrow.parquet as papq # type: ignore
12 |
13 | from wicker.core.column_files import ColumnBytesFileWriter
14 | from wicker.core.config import (
15 | FILESYSTEM_CONFIG,
16 | StorageDownloadConfig,
17 | WickerAwsS3Config,
18 | WickerConfig,
19 | WickerFileSystemConfig,
20 | WickerWandBConfig,
21 | )
22 | from wicker.core.datasets import FileSystemDataset, S3Dataset, build_dataset
23 | from wicker.core.definitions import DatasetID, DatasetPartition
24 | from wicker.core.storage import FileSystemDataStorage, S3PathFactory, WickerPathFactory
25 | from wicker.schema import schema, serialization
26 | from wicker.testing.storage import FakeS3DataStorage
27 |
28 | FAKE_NAME = "dataset_name"
29 | FAKE_VERSION = "0.0.1"
30 | FAKE_PARTITION = "train"
31 | FAKE_DATASET_ID = DatasetID(name=FAKE_NAME, version=FAKE_VERSION)
32 | FAKE_DATASET_PARTITION = DatasetPartition(FAKE_DATASET_ID, partition=FAKE_PARTITION)
33 |
34 | FAKE_SHAPE = (4, 4)
35 | FAKE_DTYPE = "float64"
36 | FAKE_NUMPY_CODEC = schema.WickerNumpyCodec(FAKE_SHAPE, FAKE_DTYPE)
37 | FAKE_SCHEMA = schema.DatasetSchema(
38 | primary_keys=["foo"],
39 | fields=[
40 | schema.StringField("foo"),
41 | schema.NumpyField("np_arr", shape=FAKE_SHAPE, dtype=FAKE_DTYPE),
42 | ],
43 | )
44 | FAKE_DATA = [{"foo": f"bar{i}", "np_arr": np.eye(4)} for i in range(1000)]
45 |
46 |
47 | def build_mock_wicker_config(tmpdir: str) -> WickerConfig:
48 | """Helper function to build WickerConfig objects to use as unit test mocks."""
49 | return WickerConfig(
50 | raw={},
51 | aws_s3_config=WickerAwsS3Config.from_json({}),
52 | filesystem_configs=[
53 | WickerFileSystemConfig.from_json(
54 | {
55 | "config_name": "filesystem_1",
56 | "prefix_replace_path": "",
57 | "root_datasets_path": os.path.join(tmpdir, "fake_data"),
58 | }
59 | ),
60 | ],
61 | storage_download_config=StorageDownloadConfig.from_json({}),
62 | wandb_config=WickerWandBConfig.from_json({}),
63 | )
64 |
65 |
66 | @contextmanager
67 | def cwd(path):
68 | """Changes the current working directory, and returns to the previous directory afterwards"""
69 | oldpwd = os.getcwd()
70 | os.chdir(path)
71 | try:
72 | yield
73 | finally:
74 | os.chdir(oldpwd)
75 |
76 |
77 | class TestFileSystemDataset(unittest.TestCase):
78 | @contextmanager
79 | def _setup_storage(self) -> Iterator[Tuple[FileSystemDataStorage, WickerPathFactory, str]]:
80 | with tempfile.TemporaryDirectory() as tmpdir, cwd(tmpdir):
81 | fake_local_fs_storage = FileSystemDataStorage()
82 | fake_local_path_factory = WickerPathFactory(root_path=os.path.join(tmpdir, "fake_data"))
83 | fake_s3_path_factory = S3PathFactory()
84 | fake_s3_storage = FakeS3DataStorage(tmpdir=tmpdir)
85 | with ColumnBytesFileWriter(
86 | storage=fake_s3_storage,
87 | s3_path_factory=fake_s3_path_factory,
88 | target_file_rowgroup_size=10,
89 | ) as writer:
90 | locs = [
91 | writer.add("np_arr", FAKE_NUMPY_CODEC.validate_and_encode_object(data["np_arr"])) # type: ignore
92 | for data in FAKE_DATA
93 | ]
94 |
95 | arrow_metadata_table = pa.Table.from_pydict(
96 | {"foo": [data["foo"] for data in FAKE_DATA], "np_arr": [loc.to_bytes() for loc in locs]}
97 | )
98 | metadata_table_path = os.path.join(
99 | tmpdir, fake_local_path_factory._get_dataset_partition_path(FAKE_DATASET_PARTITION)
100 | )
101 | os.makedirs(os.path.dirname(metadata_table_path), exist_ok=True)
102 | papq.write_table(arrow_metadata_table, metadata_table_path)
103 |
104 | # The mock storage class here actually writes to local storage, so we can use it in the test.
105 | fake_s3_storage.put_object_s3(
106 | serialization.dumps(FAKE_SCHEMA).encode("utf-8"),
107 | fake_local_path_factory._get_dataset_schema_path(FAKE_DATASET_ID),
108 | )
109 | yield fake_local_fs_storage, fake_local_path_factory, tmpdir
110 |
111 | def test_filesystem_dataset(self):
112 | with self._setup_storage() as (fake_local_storage, fake_local_path_factory, tmpdir):
113 | ds = FileSystemDataset(
114 | FAKE_NAME,
115 | FAKE_VERSION,
116 | FAKE_PARTITION,
117 | fake_local_path_factory,
118 | fake_local_storage,
119 | )
120 | for i in range(len(FAKE_DATA)):
121 | retrieved = ds[i]
122 | reference = FAKE_DATA[i]
123 | self.assertEqual(retrieved["foo"], reference["foo"])
124 | np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])
125 |
126 | # Also double-check that the builder function is working correctly.
127 | with patch("wicker.core.datasets.get_config") as mock_get_config:
128 | mock_get_config.return_value = build_mock_wicker_config(tmpdir)
129 | ds2 = build_dataset(
130 | FILESYSTEM_CONFIG,
131 | FAKE_NAME,
132 | FAKE_VERSION,
133 | FAKE_PARTITION,
134 | config_name="filesystem_1",
135 | )
136 | for i in range(len(FAKE_DATA)):
137 | retrieved = ds2[i]
138 | reference = FAKE_DATA[i]
139 | self.assertEqual(retrieved["foo"], reference["foo"])
140 | np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])
141 |
142 |
143 | class TestS3Dataset(unittest.TestCase):
144 | @contextmanager
145 | def _setup_storage(self) -> Iterator[Tuple[FakeS3DataStorage, S3PathFactory, str]]:
146 | """Context manager that sets up a local directory to mimic S3 storage for a committed dataset,
147 | and returns a tuple of (S3DataStorage, S3PathFactory, tmpdir_path) for the caller to use as
148 | fixtures in their tests.
149 | """
150 | with tempfile.TemporaryDirectory() as tmpdir, cwd(tmpdir):
151 | fake_s3_storage = FakeS3DataStorage(tmpdir=tmpdir)
152 | fake_s3_path_factory = S3PathFactory()
153 | with ColumnBytesFileWriter(
154 | storage=fake_s3_storage,
155 | s3_path_factory=fake_s3_path_factory,
156 | target_file_rowgroup_size=10,
157 | ) as writer:
158 | locs = [
159 | writer.add("np_arr", FAKE_NUMPY_CODEC.validate_and_encode_object(data["np_arr"])) # type: ignore
160 | for data in FAKE_DATA
161 | ]
162 | arrow_metadata_table = pa.Table.from_pydict(
163 | {
164 | "foo": [data["foo"] for data in FAKE_DATA],
165 | "np_arr": [loc.to_bytes() for loc in locs],
166 | }
167 | )
168 | metadata_table_path = os.path.join(
169 | tmpdir,
170 | fake_s3_path_factory.get_dataset_partition_path(FAKE_DATASET_PARTITION, s3_prefix=False),
171 | "part-1.parquet",
172 | )
173 | os.makedirs(os.path.dirname(metadata_table_path), exist_ok=True)
174 | papq.write_table(
175 | arrow_metadata_table,
176 | metadata_table_path,
177 | )
178 | fake_s3_storage.put_object_s3(
179 | serialization.dumps(FAKE_SCHEMA).encode("utf-8"),
180 | fake_s3_path_factory.get_dataset_schema_path(FAKE_DATASET_ID),
181 | )
182 | yield fake_s3_storage, fake_s3_path_factory, tmpdir
183 |
184 | def test_dataset(self) -> None:
185 | with self._setup_storage() as (fake_s3_storage, fake_s3_path_factory, tmpdir):
186 | ds = S3Dataset(
187 | FAKE_NAME,
188 | FAKE_VERSION,
189 | FAKE_PARTITION,
190 | local_cache_path_prefix=tmpdir,
191 | columns_to_load=None,
192 | storage=fake_s3_storage,
193 | s3_path_factory=fake_s3_path_factory,
194 | pa_filesystem=pafs.LocalFileSystem(),
195 | )
196 |
197 | for i in range(len(FAKE_DATA)):
198 | retrieved = ds[i]
199 | reference = FAKE_DATA[i]
200 | self.assertEqual(retrieved["foo"], reference["foo"])
201 | np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])
202 |
203 | def test_filters_dataset(self) -> None:
204 | filtered_value_list = [f"bar{i}" for i in range(100)]
205 | with self._setup_storage() as (fake_s3_storage, fake_s3_path_factory, tmpdir):
206 | ds = S3Dataset(
207 | FAKE_NAME,
208 | FAKE_VERSION,
209 | FAKE_PARTITION,
210 | local_cache_path_prefix=tmpdir,
211 | columns_to_load=None,
212 | storage=fake_s3_storage,
213 | s3_path_factory=fake_s3_path_factory,
214 | pa_filesystem=pafs.LocalFileSystem(),
215 | filters=[("foo", "in", filtered_value_list)],
216 | )
217 | self.assertEqual(len(ds), len(filtered_value_list))
218 | retrieved_values_list = [ds[i]["foo"] for i in range(len(ds))]
219 | retrieved_values_list.sort()
220 | filtered_value_list.sort()
221 | self.assertListEqual(retrieved_values_list, filtered_value_list)
222 |
223 | def test_dataset_size(self) -> None:
224 | with self._setup_storage() as (fake_s3_storage, fake_s3_path_factory, tmpdir):
225 |             # Replace the boto3 resource factory with a mock backed by the fake storage
226 | class FakeResponse(NamedTuple):
227 | content_length: int
228 |
229 |             # The dataset_size computation uses boto3 directly so that it can read object
230 |             # metadata (content length) without downloading the files. FakeS3DataStorage
231 |             # only mocks out the file downloads, so we also substitute a boto3.resource
232 |             # replacement that is backed by the fake storage.
233 | class MockedS3Resource:
234 | def __init__(self) -> None:
235 | pass
236 |
237 | def Object(self, bucket: str, key: str) -> FakeResponse:
238 | full_path = os.path.join("s3://", bucket, key)
239 | data = fake_s3_storage.fetch_obj_s3(full_path)
240 | return FakeResponse(content_length=len(data))
241 |
242 | def mock_resource_returner(_: Any):
243 | return MockedS3Resource()
244 |
245 | with patch("wicker.core.datasets.boto3.resource", mock_resource_returner):
246 | ds = S3Dataset(
247 | FAKE_NAME,
248 | FAKE_VERSION,
249 | FAKE_PARTITION,
250 | local_cache_path_prefix=tmpdir,
251 | columns_to_load=None,
252 | storage=fake_s3_storage,
253 | s3_path_factory=fake_s3_path_factory,
254 | pa_filesystem=pafs.LocalFileSystem(),
255 | )
256 |
257 | # sub this in to get the local size of the parquet dir
258 | def _get_parquet_dir_size_mocked():
259 | def get_parquet_size(path="."):
260 | total = 0
261 | with os.scandir(path) as it:
262 | for entry in it:
263 | if entry.is_file() and ".parquet" in entry.name:
264 | total += entry.stat().st_size
265 | elif entry.is_dir():
266 | total += get_parquet_size(entry.path)
267 | return total
268 |
269 | return get_parquet_size(fake_s3_storage._tmpdir)
270 |
271 | ds._get_parquet_dir_size = _get_parquet_dir_size_mocked # type: ignore
272 |
273 | dataset_size = ds.dataset_size
274 |
275 | # get the expected size, all of the col files plus pyarrow table
276 | def get_dir_size(path="."):
277 | total = 0
278 | with os.scandir(path) as it:
279 | for entry in it:
280 | if entry.is_file() and ".json" not in entry.name:
281 | total += entry.stat().st_size
282 | elif entry.is_dir():
283 | total += get_dir_size(entry.path)
284 | return total
285 |
286 | expected_bytes = get_dir_size(fake_s3_storage._tmpdir)
287 | assert expected_bytes == dataset_size
288 |
--------------------------------------------------------------------------------
/tests/test_filelock.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | import re
4 | import sys
5 | import tempfile
6 | import time
7 | import unittest
8 | from typing import Optional
9 |
10 | from wicker.core.filelock import SimpleUnixFileLock
11 |
12 |
13 | def write_index_to_file(filepath: str, index: int) -> None:
14 | # If we don't protect this writing operation with a lock, we will get interleaved
15 | # writing in the file instead of lines that have "foo-[index]"
16 | with SimpleUnixFileLock(f"{filepath}.lock"):
17 | for c in f"foo-{index}\n":
18 | with open(filepath, "a") as f:
19 | f.write(c)
20 |
21 |
22 | class TestFileLock(unittest.TestCase):
23 | def test_simple_acquire_no_race_conditions(self) -> None:
24 | with tempfile.NamedTemporaryFile() as lockfile:
25 | placeholder: Optional[str] = None
26 | with SimpleUnixFileLock(lockfile.name):
27 | placeholder = "foo"
28 | self.assertEqual(placeholder, "foo")
29 |
30 | def test_synchronized_concurrent_writes(self) -> None:
31 | with tempfile.NamedTemporaryFile("w+") as write_file:
32 | with multiprocessing.Pool(4) as pool:
33 | pool.starmap(write_index_to_file, [(write_file.name, i) for i in range(1000)])
34 | for line in write_file.readlines():
35 | self.assertTrue(re.match(r"foo-[0-9]+", line))
36 |
37 | def test_timeout(self) -> None:
38 | with tempfile.NamedTemporaryFile() as lockfile:
39 |
40 | def run_acquire_and_sleep() -> None:
41 | with SimpleUnixFileLock(lockfile.name):
42 | time.sleep(2)
43 |
44 | proc = multiprocessing.Process(target=run_acquire_and_sleep)
45 | proc.start()
46 |
47 | # Give the other process some time to start up
48 | time.sleep(1)
49 |
50 | with self.assertRaises(TimeoutError):
51 | with SimpleUnixFileLock(lockfile.name, timeout_seconds=1):
52 | assert False, "Lock acquisition should time out"
53 |
54 | def test_process_dies(self) -> None:
55 | with tempfile.NamedTemporaryFile() as lockfile:
56 |
57 | def run_acquire_and_die() -> None:
58 | """Simulate a process dying on some exception while the lock is acquired"""
59 | # Mute itself to prevent too much noise in stdout/stderr
60 | with open(os.devnull, "w") as devnull:
61 | sys.stdout = devnull
62 | sys.stderr = devnull
63 | lock = SimpleUnixFileLock(lockfile.name)
64 | lock.__enter__()
65 | raise ValueError("die")
66 |
67 | # When the child process dies, the fd (and the exclusive lock) should be closed automatically
68 | proc = multiprocessing.Process(target=run_acquire_and_die)
69 | proc.start()
70 |
71 | # Give the other process some time to start up
72 | time.sleep(1)
73 |
74 | placeholder: Optional[str] = None
75 | with SimpleUnixFileLock(lockfile.name, timeout_seconds=1):
76 | placeholder = "foo"
77 | self.assertEqual(placeholder, "foo")
78 |
--------------------------------------------------------------------------------
/tests/test_numpy_codec.py:
--------------------------------------------------------------------------------
1 | import io
2 | import unittest
3 |
4 | import numpy as np
5 |
6 | from wicker.core.errors import WickerSchemaException
7 | from wicker.schema.schema import WickerNumpyCodec
8 |
9 | EYE_ARR = np.eye(4)
10 | eye_arr_bio = io.BytesIO()
11 | np.save(eye_arr_bio, EYE_ARR)
12 | EYE_ARR_BYTES = eye_arr_bio.getvalue()
13 |
14 |
15 | class TestNumpyCodec(unittest.TestCase):
16 | def test_codec_none_shape(self) -> None:
17 | codec = WickerNumpyCodec(shape=None, dtype="float64")
18 | self.assertEqual(codec.get_codec_name(), "wicker_numpy")
19 | self.assertEqual(codec.object_type(), np.ndarray)
20 | self.assertEqual(
21 | codec.save_codec_to_dict(),
22 | {
23 | "dtype": "float64",
24 | "shape": None,
25 | },
26 | )
27 | self.assertEqual(WickerNumpyCodec.load_codec_from_dict(codec.save_codec_to_dict()), codec)
28 | self.assertEqual(codec.validate_and_encode_object(EYE_ARR), EYE_ARR_BYTES)
29 | np.testing.assert_equal(codec.decode_object(EYE_ARR_BYTES), EYE_ARR)
30 |
31 | def test_codec_unbounded_dim_shape(self) -> None:
32 | codec = WickerNumpyCodec(shape=(-1, -1), dtype="float64")
33 | self.assertEqual(codec.get_codec_name(), "wicker_numpy")
34 | self.assertEqual(codec.object_type(), np.ndarray)
35 | self.assertEqual(
36 | codec.save_codec_to_dict(),
37 | {
38 | "dtype": "float64",
39 | "shape": [-1, -1],
40 | },
41 | )
42 | self.assertEqual(WickerNumpyCodec.load_codec_from_dict(codec.save_codec_to_dict()), codec)
43 | self.assertEqual(codec.validate_and_encode_object(EYE_ARR), EYE_ARR_BYTES)
44 | np.testing.assert_equal(codec.decode_object(EYE_ARR_BYTES), EYE_ARR)
45 |
46 | # Should raise when provided with bad shapes with too few/many dimensions
47 | with self.assertRaises(WickerSchemaException):
48 | codec.validate_and_encode_object(np.ones((10,)))
49 | with self.assertRaises(WickerSchemaException):
50 | codec.validate_and_encode_object(np.ones((10, 10, 10)))
51 |
52 | def test_codec_fixed_shape(self) -> None:
53 | codec = WickerNumpyCodec(shape=(4, 4), dtype="float64")
54 | self.assertEqual(codec.get_codec_name(), "wicker_numpy")
55 | self.assertEqual(codec.object_type(), np.ndarray)
56 | self.assertEqual(
57 | codec.save_codec_to_dict(),
58 | {
59 | "dtype": "float64",
60 | "shape": [4, 4],
61 | },
62 | )
63 | self.assertEqual(WickerNumpyCodec.load_codec_from_dict(codec.save_codec_to_dict()), codec)
64 | self.assertEqual(codec.validate_and_encode_object(EYE_ARR), EYE_ARR_BYTES)
65 | np.testing.assert_equal(codec.decode_object(EYE_ARR_BYTES), EYE_ARR)
66 |
67 | # Should raise when provided with bad shapes with too few/many/wrong dimensions
68 | with self.assertRaises(WickerSchemaException):
69 | codec.validate_and_encode_object(np.ones((10,)))
70 | with self.assertRaises(WickerSchemaException):
71 | codec.validate_and_encode_object(np.ones((10, 10, 10)))
72 | with self.assertRaises(WickerSchemaException):
73 | codec.validate_and_encode_object(np.ones((5, 4)))
74 |
75 | def test_codec_bad_dtype(self) -> None:
76 | with self.assertRaises(WickerSchemaException):
77 | WickerNumpyCodec(shape=(4, 4), dtype="SOME_BAD_DTYPE")
78 |
--------------------------------------------------------------------------------
/tests/test_shuffle.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import tempfile
3 | import unittest
4 | import unittest.mock
5 | from typing import IO, Dict, List, Tuple
6 |
7 | from wicker.core.column_files import ColumnBytesFileLocationV1
8 | from wicker.core.config import get_config
9 | from wicker.core.definitions import DatasetDefinition, DatasetID, DatasetPartition
10 | from wicker.core.shuffle import ShuffleJob, ShuffleJobFactory, ShuffleWorker
11 | from wicker.core.writer import MetadataDatabaseScanRow
12 | from wicker.schema import dataparsing, schema, serialization
13 | from wicker.testing.codecs import Vector, VectorCodec
14 | from wicker.testing.storage import FakeS3DataStorage
15 |
16 | DATASET_NAME = "dataset1"
17 | DATASET_VERSION = "0.0.1"
18 | FAKE_DATASET_SCHEMA = schema.DatasetSchema(
19 | fields=[
20 | schema.IntField("timestamp"),
21 | schema.StringField("car_id"),
22 | schema.ObjectField("vector", VectorCodec(0)),
23 | ],
24 | primary_keys=["car_id", "timestamp"],
25 | )
26 | FAKE_EXAMPLE = {
27 | "timestamp": 1,
28 | "car_id": "somecar",
29 | "vector": Vector([0, 0, 0]),
30 | }
31 | FAKE_DATASET_ID = DatasetID(name=DATASET_NAME, version=DATASET_VERSION)
32 | FAKE_DATASET_DEFINITION = DatasetDefinition(
33 | dataset_id=FAKE_DATASET_ID,
34 | schema=FAKE_DATASET_SCHEMA,
35 | )
36 |
37 |
38 | class TestShuffleJobFactory(unittest.TestCase):
39 | def test_shuffle_job_factory_no_entries(self) -> None:
40 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend"""
41 | mock_rows: List[MetadataDatabaseScanRow] = []
42 | mock_dataset_writer_backend = unittest.mock.MagicMock()
43 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows)
44 | job_factory = ShuffleJobFactory(mock_dataset_writer_backend)
45 | self.assertEqual(list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)), [])
46 |
47 | def test_shuffle_job_factory_one_entry(self) -> None:
48 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend"""
49 | mock_rows: List[MetadataDatabaseScanRow] = [
50 | MetadataDatabaseScanRow(partition="train", row_data_path="somepath", row_size=1337)
51 | ]
52 | mock_dataset_writer_backend = unittest.mock.MagicMock()
53 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows)
54 | job_factory = ShuffleJobFactory(mock_dataset_writer_backend)
55 | self.assertEqual(
56 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)),
57 | [
58 | ShuffleJob(
59 | dataset_partition=DatasetPartition(
60 | dataset_id=FAKE_DATASET_ID,
61 | partition="train",
62 | ),
63 | files=[("somepath", 1337)],
64 | )
65 | ],
66 | )
67 |
68 | def test_shuffle_job_factory_one_partition(self) -> None:
69 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend"""
70 | partition = "train"
71 | mock_rows: List[MetadataDatabaseScanRow] = [
72 | MetadataDatabaseScanRow(partition=partition, row_data_path=f"somepath{i}", row_size=i) for i in range(10)
73 | ]
74 | mock_dataset_writer_backend = unittest.mock.MagicMock()
75 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows)
76 | job_factory = ShuffleJobFactory(mock_dataset_writer_backend)
77 | self.assertEqual(
78 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)),
79 | [
80 | ShuffleJob(
81 | dataset_partition=DatasetPartition(
82 | dataset_id=FAKE_DATASET_ID,
83 | partition=partition,
84 | ),
85 | files=[(row.row_data_path, row.row_size) for row in mock_rows],
86 | )
87 | ],
88 | )
89 |
90 | def test_shuffle_job_factory_one_partition_two_working_sets(self) -> None:
91 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend"""
92 | partition = "train"
93 | worker_max_working_set_size = 5
94 | mock_rows: List[MetadataDatabaseScanRow] = [
95 | MetadataDatabaseScanRow(partition=partition, row_data_path=f"somepath{i}", row_size=i) for i in range(10)
96 | ]
97 | mock_dataset_writer_backend = unittest.mock.MagicMock()
98 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows)
99 | job_factory = ShuffleJobFactory(
100 | mock_dataset_writer_backend, worker_max_working_set_size=worker_max_working_set_size
101 | )
102 | self.assertEqual(
103 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)),
104 | [
105 | ShuffleJob(
106 | dataset_partition=DatasetPartition(
107 | dataset_id=FAKE_DATASET_ID,
108 | partition=partition,
109 | ),
110 | files=[(row.row_data_path, row.row_size) for row in mock_rows[:worker_max_working_set_size]],
111 | ),
112 | ShuffleJob(
113 | dataset_partition=DatasetPartition(
114 | dataset_id=FAKE_DATASET_ID,
115 | partition=partition,
116 | ),
117 | files=[(row.row_data_path, row.row_size) for row in mock_rows[worker_max_working_set_size:]],
118 | ),
119 | ],
120 | )
121 |
122 | def test_shuffle_job_factory_multi_partition_multi_working_sets(self) -> None:
123 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend"""
124 | partitions = ("train", "test")
125 | num_rows = 10
126 | num_batches = 2
127 | worker_max_working_set_size = num_rows // num_batches
128 | mock_rows: List[MetadataDatabaseScanRow] = [
129 | MetadataDatabaseScanRow(partition=partition, row_data_path=f"somepath{i}", row_size=i)
130 | for partition in partitions
131 | for i in range(10)
132 | ]
133 | mock_dataset_writer_backend = unittest.mock.MagicMock()
134 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows)
135 | job_factory = ShuffleJobFactory(
136 | mock_dataset_writer_backend, worker_max_working_set_size=worker_max_working_set_size
137 | )
138 | self.assertEqual(
139 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)),
140 | [
141 | ShuffleJob(
142 | dataset_partition=DatasetPartition(
143 | dataset_id=FAKE_DATASET_ID,
144 | partition=partition,
145 | ),
146 | files=[
147 | (row.row_data_path, row.row_size)
148 | for row in mock_rows[
149 | worker_max_working_set_size
150 | * (batch + (partition_index * num_batches)) : worker_max_working_set_size
151 | * (batch + (partition_index * num_batches) + 1)
152 | ]
153 | ],
154 | )
155 | for partition_index, partition in enumerate(partitions)
156 | for batch in range(2)
157 | ],
158 | )
159 |
160 |
161 | class TestShuffleWorker(unittest.TestCase):
162 | def setUp(self) -> None:
163 | self.boto3_patcher = unittest.mock.patch("wicker.core.shuffle.boto3")
164 | self.boto3_mock = self.boto3_patcher.start()
165 | self.uploaded_column_bytes_files: Dict[Tuple[str, str], bytes] = {}
166 |
167 | def tearDown(self) -> None:
168 | self.boto3_patcher.stop()
169 | self.uploaded_column_bytes_files.clear()
170 |
171 | @staticmethod
172 | def download_fileobj_mock(bucket: str, key: str, bio: IO) -> None:
173 | bio.write(
174 | pickle.dumps(
175 | dataparsing.parse_example(
176 | FAKE_EXAMPLE,
177 | FAKE_DATASET_SCHEMA,
178 | )
179 | )
180 | )
181 | bio.seek(0)
182 | return None
183 |
184 | def test_process_job(self) -> None:
185 | with tempfile.TemporaryDirectory() as tmpdir:
186 | # The threaded workers each construct their own S3DataStorage from a boto3 client to download
187 | # file data in parallel. We mock those out here by mocking out the boto3 client itself.
188 | self.boto3_mock.session.Session().client().download_fileobj.side_effect = self.download_fileobj_mock
189 |
190 | fake_job = ShuffleJob(
191 | dataset_partition=DatasetPartition(
192 | dataset_id=FAKE_DATASET_ID,
193 | partition="test",
194 | ),
195 | files=[(f"s3://somebucket/path/{i}", i) for i in range(10)],
196 | )
197 |
198 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir)
199 | fake_storage.put_object_s3(
200 | serialization.dumps(FAKE_DATASET_SCHEMA).encode("utf-8"),
201 | f"{get_config().aws_s3_config.s3_datasets_path}/{FAKE_DATASET_ID.name}"
202 | f"/{FAKE_DATASET_ID.version}/avro_schema.json",
203 | )
204 | worker = ShuffleWorker(storage=fake_storage)
205 | shuffle_results = worker.process_job(fake_job)
206 |
207 | self.assertEqual(shuffle_results["timestamp"].to_pylist(), [FAKE_EXAMPLE["timestamp"] for _ in range(10)])
208 | self.assertEqual(shuffle_results["car_id"].to_pylist(), [FAKE_EXAMPLE["car_id"] for _ in range(10)])
209 | for location_bytes in shuffle_results["vector"].to_pylist():
210 | location = ColumnBytesFileLocationV1.from_bytes(location_bytes)
211 | path = worker.s3_path_factory.get_column_concatenated_bytes_s3path_from_uuid(location.file_id.bytes)
212 | self.assertTrue(fake_storage.check_exists_s3(path))
213 | data = fake_storage.fetch_obj_s3(path)[location.byte_offset : location.byte_offset + location.data_size]
214 | self.assertEqual(VectorCodec(0).decode_object(data), FAKE_EXAMPLE["vector"])
215 |
--------------------------------------------------------------------------------
/tests/test_spark.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import random
4 | import tempfile
5 | import unittest
6 | import uuid
7 |
8 | import pyarrow.parquet as papq
9 | from pyspark.sql import SparkSession
10 |
11 | from wicker import schema
12 | from wicker.core.config import get_config
13 | from wicker.core.errors import WickerDatastoreException
14 | from wicker.plugins.spark import persist_wicker_dataset
15 | from wicker.testing.storage import FakeS3DataStorage
16 |
17 | DATASET_NAME = "dataset"
18 | DATASET_VERSION = "0.0.1"
19 | TEST_ROWS_NUM = 10000
20 | SCHEMA = schema.DatasetSchema(
21 | primary_keys=["bar", "foo"],
22 | fields=[
23 | schema.IntField("foo"),
24 | schema.StringField("bar"),
25 | schema.BytesField("bytescol"),
26 | ],
27 | )
28 | EXAMPLES = [
29 | (
30 | "train" if i % 2 == 0 else "test",
31 | {
32 | "foo": random.randint(0, TEST_ROWS_NUM),
33 | "bar": str(uuid.uuid4()),
34 | "bytescol": b"0",
35 | },
36 | )
37 | for i in range(TEST_ROWS_NUM)
38 | ]
39 | # Examples with a duplicated key
40 | EXAMPLES_DUPES = copy.deepcopy(EXAMPLES)
41 | EXAMPLES_DUPES[5000] = EXAMPLES_DUPES[0]
42 |
43 |
44 | class LocalWritingTestCase(unittest.TestCase):
45 | def assert_written_correctness(self, tmpdir: str, row_num: int = TEST_ROWS_NUM) -> None:
46 |         """Asserts that all files are written as expected by the Wicker datastore"""
47 | # Check that files are correctly written locally by Spark/Parquet with a _SUCCESS marker file
48 | prefix = get_config().aws_s3_config.s3_datasets_path.replace("s3://", "")
49 | self.assertIn(DATASET_NAME, os.listdir(os.path.join(tmpdir, prefix)))
50 | self.assertIn(DATASET_VERSION, os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME)))
51 | for partition in ["train", "test"]:
52 | columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__")
53 | all_read_bytes = b""
54 | for filename in os.listdir(columns_path):
55 | concatenated_bytes_filepath = os.path.join(columns_path, filename)
56 | with open(concatenated_bytes_filepath, "rb") as bytescol_file:
57 | all_read_bytes += bytescol_file.read()
58 | self.assertEqual(all_read_bytes, b"0" * row_num)
59 |
60 | # Load parquet file and assert ordering of primary_key
61 | self.assertIn(
62 | f"{partition}.parquet", os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION))
63 | )
64 | tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet"))
65 | foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])]
66 | self.assertEqual(foobar, sorted(foobar))
67 |
68 | def test_simple_schema_local_writing(self) -> None:
69 | for local_reduction in (True, False):
70 | for sort in (True, False):
71 | with tempfile.TemporaryDirectory() as tmpdir:
72 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir)
73 | spark_session = SparkSession.builder.appName("test").master("local[*]")
74 | spark = spark_session.getOrCreate()
75 | sc = spark.sparkContext
76 | rdd = sc.parallelize(copy.deepcopy(EXAMPLES), 100)
77 | persist_wicker_dataset(
78 | DATASET_NAME,
79 | DATASET_VERSION,
80 | SCHEMA,
81 | rdd,
82 | fake_storage,
83 | local_reduction=local_reduction,
84 | sort=sort,
85 | )
86 | self.assert_written_correctness(tmpdir)
87 |
88 | def test_dupe_primary_keys_raises_exception(self) -> None:
89 | for local_reduction in (True, False):
90 | for sort in (True, False):
91 | with self.assertRaises(WickerDatastoreException) as e:
92 | with tempfile.TemporaryDirectory() as tmpdir:
93 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir)
94 | spark_session = SparkSession.builder.appName("test").master("local[*]")
95 | spark = spark_session.getOrCreate()
96 | sc = spark.sparkContext
97 | rdd = sc.parallelize(copy.deepcopy(EXAMPLES_DUPES), 100)
98 | persist_wicker_dataset(
99 | DATASET_NAME,
100 | DATASET_VERSION,
101 | SCHEMA,
102 | rdd,
103 | fake_storage,
104 | local_reduction=local_reduction,
105 | sort=sort,
106 | )
107 |
108 | self.assertIn(
109 | "Error: dataset examples do not have unique primary key tuples",
110 | str(e.exception),
111 | )
112 |
113 | def test_simple_schema_local_writing_4_row_dataset(self) -> None:
114 | for local_reduction in (True, False):
115 | for sort in (True, False):
116 | with tempfile.TemporaryDirectory() as tmpdir:
117 | small_row_cnt = 4
118 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir)
119 | spark_session = SparkSession.builder.appName("test").master("local[*]")
120 | spark = spark_session.getOrCreate()
121 | sc = spark.sparkContext
122 | rdd = sc.parallelize(copy.deepcopy(EXAMPLES)[:small_row_cnt], 1)
123 | persist_wicker_dataset(
124 | DATASET_NAME,
125 | DATASET_VERSION,
126 | SCHEMA,
127 | rdd,
128 | fake_storage,
129 | local_reduction=local_reduction,
130 | sort=sort,
131 | )
132 | self.assert_written_correctness(tmpdir, small_row_cnt)
133 |
--------------------------------------------------------------------------------
/tests/test_storage.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import tempfile
5 | from typing import Any, Dict
6 | from unittest import TestCase, mock
7 |
8 | from botocore.exceptions import ClientError # type: ignore
9 | from botocore.stub import Stubber # type: ignore
10 |
11 | from wicker.core.config import WickerConfig
12 | from wicker.core.storage import FileSystemDataStorage, S3DataStorage, S3PathFactory
13 |
14 | RANDOM_SEED_VALUE = 1
15 | RANDOM_STRING_CHAR_COUNT = 10
16 |
17 |
18 | class TestFileSystemDataStorage(TestCase):
19 | def test_fetch_file(self) -> None:
20 | """Unit test for fetching file from local/mounted file system to different location"""
21 | # put file in the directory that you're using for test
22 | with tempfile.TemporaryDirectory() as temp_dir:
23 | src_dir = os.path.join(temp_dir, "test", "location", "starting", "mount")
24 | os.makedirs(src_dir, exist_ok=True)
25 | src_path = os.path.join(src_dir, "test.txt")
26 | dst_dir = os.path.join(temp_dir, "desired", "location", "for", "test")
27 | os.makedirs(dst_dir, exist_ok=True)
28 | dst_path = os.path.join(dst_dir, "test.txt")
29 |
30 | random.seed(RANDOM_SEED_VALUE)
31 | expected_string = "".join(
32 | random.choices(string.ascii_uppercase + string.digits, k=RANDOM_STRING_CHAR_COUNT)
33 | )
34 | with open(src_path, "w") as open_src:
35 | open_src.write(expected_string)
36 |
37 | # create local file store
38 | local_datastore = FileSystemDataStorage()
39 | # save file to destination
40 | local_datastore.fetch_file(src_path, dst_dir)
41 |
42 | # verify file exists
43 | assert os.path.exists(dst_path)
44 |
45 | # assert contents are the expected
46 | with open(dst_path, "r") as open_dst_file:
47 | test_string = open_dst_file.readline()
48 | assert test_string == expected_string
49 |
50 |
51 | class TestS3DataStorage(TestCase):
52 | def test_bucket_key_from_s3_path(self) -> None:
53 | """Unit test for the S3DataStorage bucket_key_from_s3_path function"""
54 | data_storage = S3DataStorage()
55 |
56 | s3_url = "s3://hello/world"
57 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url)
58 | self.assertEqual(bucket, "hello")
59 | self.assertEqual(key, "world")
60 |
61 | s3_url = "s3://hello/"
62 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url)
63 | self.assertEqual(bucket, "hello")
64 | self.assertEqual(key, "")
65 |
66 | s3_url = "s3://"
67 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url)
68 | self.assertEqual(bucket, "")
69 | self.assertEqual(key, "")
70 |
71 | s3_url = "s3://hello/world/"
72 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url)
73 | self.assertEqual(bucket, "hello")
74 | self.assertEqual(key, "world/")
75 |
76 | def test_check_exists_s3(self) -> None:
77 | """Unit test for the check_exists_s3 function."""
78 | data_storage = S3DataStorage()
79 | input_path = "s3://foo/bar/baz/dummy"
80 |
81 | with Stubber(data_storage.client) as stubber:
82 | response = {} # type: ignore
83 | expected_params = {"Bucket": "foo", "Key": "bar/baz/dummy"}
84 | stubber.add_response("head_object", response, expected_params)
85 | self.assertTrue(data_storage.check_exists_s3(input_path))
86 |
87 | def test_check_exists_s3_nonexisting(self) -> None:
88 | """Unit test for the check_exists_s3 function."""
89 | data_storage = S3DataStorage()
90 | input_path = "s3://foo/bar/baz/dummy"
91 |
92 | with Stubber(data_storage.client) as stubber:
93 | stubber.add_client_error(
94 | expected_params={"Bucket": "foo", "Key": "bar/baz/dummy"},
95 | method="head_object",
96 | service_error_code="404",
97 | service_message="The specified key does not exist.",
98 | )
99 |
100 | # The check_exists_s3 function catches the exception when the key does not exist
101 | self.assertFalse(data_storage.check_exists_s3(input_path))
102 |
103 | def test_put_object_s3(self) -> None:
104 | """Unit test for the put_object_s3 function."""
105 | data_storage = S3DataStorage()
106 | object_bytes = b"this is my object"
107 | input_path = "s3://foo/bar/baz/dummy"
108 |
109 | with Stubber(data_storage.client) as stubber:
110 | response = {} # type: ignore
111 | expected_params = {
112 | "Body": object_bytes,
113 | "Bucket": "foo",
114 | "Key": "bar/baz/dummy",
115 | }
116 | stubber.add_response("put_object", response, expected_params)
117 | data_storage.put_object_s3(object_bytes, input_path)
118 |
119 | def test_put_file_s3(self) -> None:
120 | """Unit test for the put_file_s3 function"""
121 | data_storage = S3DataStorage()
122 | object_bytes = b"this is my object"
123 | input_path = "s3://foo/bar/baz/dummy"
124 |
125 | with tempfile.NamedTemporaryFile() as tmpfile:
126 | tmpfile.write(object_bytes)
127 | tmpfile.flush()
128 |
129 | with Stubber(data_storage.client) as stubber:
130 | response = {} # type: ignore
131 | stubber.add_response("put_object", response, None)
132 | data_storage.put_file_s3(tmpfile.name, input_path)
133 |
134 | @staticmethod
135 | def download_file_side_effect(*args, **kwargs) -> None: # type: ignore
136 | """Helper function to patch the S3 download_file function with a side-effect that creates an
137 | empty file at the correct path in order to mock the download"""
138 | input_path = str(kwargs["filename"])
139 | with open(input_path, "w"):
140 | pass
141 |
142 | # Stubber does not have a stub function for S3 client download_file function, so patch it
143 | @mock.patch("boto3.s3.transfer.S3Transfer.download_file")
144 | def test_fetch_file(self, download_file: mock.Mock) -> None:
145 | """Unit test for the fetch_file function."""
146 | data_storage = S3DataStorage()
147 | input_path = "s3://foo/bar/baz/dummy"
148 | with tempfile.TemporaryDirectory() as local_prefix:
149 | # Add a side-effect to create the file to download at the correct local path
150 | download_file.side_effect = self.download_file_side_effect
151 |
152 | local_path = data_storage.fetch_file(input_path, local_prefix)
153 | download_file.assert_called_once_with(
154 | bucket="foo",
155 | key="bar/baz/dummy",
156 | filename=f"{local_prefix}/bar/baz/dummy",
157 | extra_args=None,
158 | callback=None,
159 | )
160 | self.assertTrue(os.path.isfile(local_path))
161 |
162 | # Stubber does not have a stub function for S3 client download_file function, so patch it
163 | @mock.patch("boto3.s3.transfer.S3Transfer.download_file")
164 | def test_fetch_file_s3_on_nonexistent_file(self, download_file: mock.Mock) -> None:
165 | """Unit test for the fetch_file function for a non-existent file in S3."""
166 | data_storage = S3DataStorage()
167 | input_path = "s3://foo/bar/barbazz/dummy"
168 | local_prefix = "/tmp"
169 |
170 | response = {"Error": {"Code": "404"}}
171 | side_effect = ClientError(response, "unexpected")
172 | download_file.side_effect = side_effect
173 |
174 | with self.assertRaises(ClientError):
175 | data_storage.fetch_file(input_path, local_prefix)
176 |
177 |
178 | class TestS3PathFactory(TestCase):
179 | @mock.patch("wicker.core.storage.get_config")
180 | def test_get_column_concatenated_bytes_files_path(self, mock_get_config: mock.Mock) -> None:
181 | """Unit test for the S3PathFactory get_column_concatenated_bytes_files_path
182 | function"""
183 | # If store_concatenated_bytes_files_in_dataset is False, return the default path
184 | dummy_config: Dict[str, Any] = {
185 | "aws_s3_config": {
186 | "s3_datasets_path": "s3://dummy_bucket/wicker/",
187 | "region": "us-east-1",
188 | "boto_config": {"max_pool_connections": 10, "read_timeout_s": 140, "connect_timeout_s": 140},
189 | },
190 | "dynamodb_config": {"table_name": "fake-table-name", "region": "us-west-2"},
191 | "storage_download_config": {
192 | "retries": 2,
193 | "timeout": 150,
194 | "retry_backoff": 5,
195 | "retry_delay_s": 4,
196 | },
197 | }
198 | mock_get_config.return_value = WickerConfig.from_json(dummy_config)
199 |
200 | path_factory = S3PathFactory()
201 | self.assertEqual(
202 | path_factory.get_column_concatenated_bytes_files_path(),
203 | "s3://dummy_bucket/wicker/__COLUMN_CONCATENATED_FILES__",
204 | )
205 |
206 | # If store_concatenated_bytes_files_in_dataset is True, return the dataset-specific path
207 | dummy_config["aws_s3_config"]["store_concatenated_bytes_files_in_dataset"] = True
208 | mock_get_config.return_value = WickerConfig.from_json(dummy_config)
209 | dataset_name = "dummy_dataset"
210 | self.assertEqual(
211 | S3PathFactory().get_column_concatenated_bytes_files_path(dataset_name=dataset_name),
212 | f"s3://dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__",
213 | )
214 |
215 | # If the store_concatenated_bytes_files_in_dataset option is True but no
216 | # dataset_name, raise ValueError
217 | with self.assertRaises(ValueError):
218 | S3PathFactory().get_column_concatenated_bytes_files_path()
219 |
220 | # Test the remove s3 prefix option in the get_column_concatenated_bytes_files_path function
221 | self.assertEqual(
222 | S3PathFactory().get_column_concatenated_bytes_files_path(dataset_name=dataset_name, s3_prefix=False),
223 | f"dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__",
224 | )
225 |
226 |         # With s3_prefix=True the "s3://" prefix is kept and prefix_replace_path is not applied.
227 | self.assertEqual(
228 | S3PathFactory(prefix_replace_path="/test_mount_path").get_column_concatenated_bytes_files_path(
229 | dataset_name=dataset_name, s3_prefix=True
230 | ),
231 | f"s3://dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__",
232 | )
233 |
234 | self.assertEqual(
235 | S3PathFactory(prefix_replace_path="/test_mount_path/").get_column_concatenated_bytes_files_path(
236 | dataset_name=dataset_name, s3_prefix=False
237 | ),
238 | f"/test_mount_path/dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__",
239 | )
240 |
241 | self.assertEqual(
242 | S3PathFactory(prefix_replace_path="/test_mount_path/").get_column_concatenated_bytes_files_path(
243 | dataset_name=dataset_name, s3_prefix=False, cut_prefix_override="s3://dummy_bucket/"
244 | ),
245 | f"/test_mount_path/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__",
246 | )
247 |
248 | self.assertEqual(
249 | S3PathFactory(prefix_replace_path="/test_mount_path/").get_column_concatenated_bytes_files_path(
250 | dataset_name=dataset_name, s3_prefix=True, cut_prefix_override="s3://dummy_bucket/"
251 | ),
252 | f"/test_mount_path/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__",
253 | )
254 |
--------------------------------------------------------------------------------
/tests/test_wandb.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from unittest.mock import call, patch
4 |
5 | import pytest
6 |
7 | from wicker.plugins.wandb import (
8 | _identify_s3_url_for_dataset_version,
9 | _set_wandb_credentials,
10 | version_dataset,
11 | )
12 |
13 |
14 | @pytest.fixture
15 | def temp_config(request, tmpdir):
16 | # parse the config json into a proper config
17 | config_json = request.param.get("config_json")
18 | parsable_config = {
19 | "wandb_config": config_json.get("wandb_config", {}),
20 | "aws_s3_config": {
21 | "s3_datasets_path": config_json.get("aws_s3_config", {}).get("s3_datasets_path", None),
22 | "region": config_json.get("aws_s3_config", {}).get("region", None),
23 | },
24 | }
25 |
26 | # set the config path env and then get the creds from the function
27 | temp_config = tmpdir / "temp_config.json"
28 | with open(temp_config, "w") as file_stream:
29 | json.dump(parsable_config, file_stream)
30 | os.environ["WICKER_CONFIG_PATH"] = str(temp_config)
31 | return temp_config, parsable_config
32 |
33 |
34 | @pytest.mark.parametrize(
35 | "temp_config, dataset_name, dataset_version, dataset_metadata, entity",
36 | [
37 | (
38 | {
39 | "config_json": {
40 | "aws_s3_config": {
41 | "s3_datasets_path": "s3://test_path/to_nowhere/",
42 | "region": "us-west-2",
43 | },
44 | "wandb_config": {
45 | "wandb_api_key": "test_key",
46 | "wandb_base_url": "test_url",
47 | },
48 | }
49 | },
50 | "test_data",
51 | "0.0.1",
52 | {"test_key": "test_value"},
53 | "test_entity",
54 | )
55 | ],
56 | indirect=["temp_config"],
57 | )
58 | def test_version_dataset(temp_config, dataset_name, dataset_version, dataset_metadata, entity):
59 | """
60 | GIVEN: A mocked dataset folder to track, a dataset version to track, metadata to track, and a backend to use
61 | WHEN: This dataset is registered or versioned on W&B
62 | THEN: The dataset shows up as a run on the portal
63 | """
64 | # need to mock out all the wandb calls and test just the inputs to them
65 | _, config = temp_config
66 | with patch("wicker.plugins.wandb.wandb.init") as patched_wandb_init:
67 | with patch("wicker.plugins.wandb.wandb.Artifact") as patched_artifact:
68 | # version the dataset with the patched functions/classes
69 | version_dataset(dataset_name, dataset_version, entity, dataset_metadata)
70 |
71 | # establish the expected calls
72 | expected_artifact_calls = [
73 | call(f"{dataset_name}_{dataset_version}", type="dataset"),
74 | call().add_reference(
75 | f"{config['aws_s3_config']['s3_datasets_path']}{dataset_name}/{dataset_version}/assets",
76 | name="dataset",
77 | ),
78 | call().metadata.__setitem__("version", dataset_version),
79 | call().metadata.__setitem__(
80 | "s3_uri", f"{config['aws_s3_config']['s3_datasets_path']}{dataset_name}/{dataset_version}/assets"
81 | ),
82 | ]
83 | for key, value in dataset_metadata.items():
84 | expected_artifact_calls.append(call().metadata.__setitem__(key, value))
85 |
86 | expected_run_calls = [
87 | call(project="dataset_curation", name=f"{dataset_name}_{dataset_version}", entity=entity),
88 | call().log_artifact(patched_artifact()),
89 | ]
90 |
91 | # assert that these are properly called
92 | patched_artifact.assert_has_calls(expected_artifact_calls, any_order=True)
93 |
94 | patched_wandb_init.assert_has_calls(expected_run_calls, any_order=True)
95 |
96 |
97 | @pytest.mark.parametrize(
98 | "credentials_to_load, temp_config",
99 | [
100 | (
101 | {"WANDB_BASE_URL": "env_base", "WANDB_API_KEY": "env_key", "WANDB_USER_EMAIL": "env_email"},
102 | {
103 | "config_json": {
104 | "wandb_config": {
105 | "wandb_base_url": "config_base",
106 | "wandb_api_key": "config_key",
107 | }
108 | }
109 | },
110 | )
111 | ],
112 | ids=["basic test to override all params"],
113 | indirect=["temp_config"],
114 | )
115 | def test_set_wandb_credentials(credentials_to_load, temp_config, tmpdir):
116 | """
117 | GIVEN: A set of credentials as existing envs and a config json specifying creds
118 | WHEN: The configs are requested for wandb
119 |     THEN: The proper env variables should be set in the environment; values from the wicker
120 |     config take precedence over any preset env variables
121 | """
122 | temp_config_pth, config_json = temp_config
123 |
124 | # load the creds into the env as the base comparison
125 | for key, value in credentials_to_load.items():
126 | os.environ[key] = value
127 |
128 | # compare the creds for expected results
129 | _set_wandb_credentials()
130 | assert os.environ["WANDB_BASE_URL"] == config_json["wandb_config"]["wandb_base_url"]
131 | assert os.environ["WANDB_API_KEY"] == config_json["wandb_config"]["wandb_api_key"]
132 |
133 |
134 | @pytest.mark.parametrize(
135 | "temp_config, dataset_name, dataset_version, dataset_backend, expected_url",
136 | [
137 | (
138 | {"config_json": {"aws_s3_config": {"s3_datasets_path": "s3://test_path_to_nowhere", "region": "test"}}},
139 | "test_dataset",
140 | "0.0.0",
141 | "s3",
142 | "s3://test_path_to_nowhere/test_dataset/0.0.0/assets",
143 | )
144 | ],
145 | ids=["Basic test with s3"],
146 | indirect=["temp_config"],
147 | )
148 | def test_identify_s3_url_for_dataset_version(temp_config, dataset_name, dataset_version, dataset_backend, expected_url):
149 | """
150 | GIVEN: A temporary config, a dataset name, version, backend, and expected url
151 | WHEN: The assets url is pulled from the path factory with these parameters
152 | THEN: The expected url should match what is returned from the function
153 | """
154 | parsed_dataset_url = _identify_s3_url_for_dataset_version(dataset_name, dataset_version, dataset_backend)
155 | assert parsed_dataset_url == expected_url
156 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | exclude=docs
4 |
--------------------------------------------------------------------------------
/wicker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/__init__.py
--------------------------------------------------------------------------------
/wicker/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/core/__init__.py
--------------------------------------------------------------------------------
/wicker/core/abstract.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/core/abstract.py
--------------------------------------------------------------------------------
/wicker/core/config.py:
--------------------------------------------------------------------------------
1 | """This module defines how to configure Wicker from the user environment
2 | """
3 |
4 | from __future__ import annotations
5 |
6 | import dataclasses
7 | import json
8 | import os
9 | from typing import Any, Dict, List
10 |
11 | AWS_S3_CONFIG = "aws_s3_config"
12 | FILESYSTEM_CONFIG = "filesystem_config"
13 |
14 |
15 | @dataclasses.dataclass(frozen=True)
16 | class WickerWandBConfig:
17 | wandb_base_url: str
18 | wandb_api_key: str
19 |
20 | @classmethod
21 | def from_json(cls, data: Dict[str, Any]) -> WickerWandBConfig:
22 | return cls(
23 | wandb_api_key=data.get("wandb_api_key", None),
24 | wandb_base_url=data.get("wandb_base_url", None),
25 | )
26 |
27 |
28 | @dataclasses.dataclass(frozen=True)
29 | class BotoS3Config:
30 | max_pool_connections: int
31 | read_timeout_s: int
32 | connect_timeout_s: int
33 |
34 | @classmethod
35 | def from_json(cls, data: Dict[str, Any]) -> BotoS3Config:
36 | return cls(
37 | max_pool_connections=data.get("max_pool_connections", 0),
38 | read_timeout_s=data.get("read_timeout_s", 0),
39 | connect_timeout_s=data.get("connect_timeout_s", 0),
40 | )
41 |
42 |
43 | @dataclasses.dataclass(frozen=True)
44 | class WickerAwsS3Config:
45 | s3_datasets_path: str
46 | region: str
47 | boto_config: BotoS3Config
48 | store_concatenated_bytes_files_in_dataset: bool = False
49 |
50 | @classmethod
51 | def from_json(cls, data: Dict[str, Any]) -> WickerAwsS3Config:
52 | return cls(
53 | s3_datasets_path=data.get("s3_datasets_path", ""),
54 | region=data.get("region", ""),
55 | boto_config=BotoS3Config.from_json(data.get("boto_config", {})),
56 | store_concatenated_bytes_files_in_dataset=data.get("store_concatenated_bytes_files_in_dataset", False),
57 | )
58 |
59 |
60 | @dataclasses.dataclass(frozen=True)
61 | class WickerFileSystemConfig:
62 | config_name: str
63 | prefix_replace_path: str
64 | root_datasets_path: str
65 |
66 | @classmethod
67 | def from_json(cls, data: Dict[str, Any]) -> WickerFileSystemConfig:
68 | return cls(
69 | config_name=data.get("config_name", ""),
70 | prefix_replace_path=data.get("prefix_replace_path", ""),
71 | root_datasets_path=data.get("root_datasets_path", ""),
72 | )
73 |
74 | @classmethod
75 | def from_json_list(cls, data: List[Dict[str, Any]]) -> List[WickerFileSystemConfig]:
76 | return [WickerFileSystemConfig.from_json(d) for d in data]
77 |
78 |
79 | @dataclasses.dataclass(frozen=True)
80 | class StorageDownloadConfig:
81 | retries: int
82 | timeout: int
83 | retry_backoff: int
84 | retry_delay_s: int
85 |
86 | @classmethod
87 | def from_json(cls, data: Dict[str, Any]) -> StorageDownloadConfig:
88 | return cls(
89 | retries=data.get("retries", 0),
90 | timeout=data.get("timeout", 0),
91 | retry_backoff=data.get("retry_backoff", 0),
92 | retry_delay_s=data.get("retry_delay_s", 0),
93 | )
94 |
95 |
96 | @dataclasses.dataclass()
97 | class WickerConfig:
98 | raw: Dict[str, Any]
99 | aws_s3_config: WickerAwsS3Config
100 | filesystem_configs: List[WickerFileSystemConfig]
101 | storage_download_config: StorageDownloadConfig
102 | wandb_config: WickerWandBConfig
103 |
104 | @classmethod
105 | def from_json(cls, data: Dict[str, Any]) -> WickerConfig:
106 | return cls(
107 | raw=data,
108 | aws_s3_config=WickerAwsS3Config.from_json(data.get(AWS_S3_CONFIG, {})),
109 | filesystem_configs=WickerFileSystemConfig.from_json_list(data.get("filesystem_configs", [])),
110 | storage_download_config=StorageDownloadConfig.from_json(data.get("storage_download_config", {})),
111 | wandb_config=WickerWandBConfig.from_json(data.get("wandb_config", {})),
112 | )
113 |
114 |
115 | def get_config() -> WickerConfig:
116 | """Retrieves the Wicker config for the current process"""
117 |
118 | wicker_config_path = os.getenv("WICKER_CONFIG_PATH", os.path.expanduser("~/wickerconfig.json"))
119 | with open(wicker_config_path, "r") as f:
120 | config = WickerConfig.from_json(json.load(f))
121 | return config
122 |
--------------------------------------------------------------------------------
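A minimal sketch (not part of the repository) of how the configuration classes above are consumed: get_config() reads a JSON document from WICKER_CONFIG_PATH (or ~/wickerconfig.json) and parses it with WickerConfig.from_json. The bucket, region, timeout, and credential values below are illustrative assumptions, not defaults shipped with Wicker.

from wicker.core.config import WickerConfig

# Shape of the JSON document that get_config() expects to find on disk.
example_config = {
    "aws_s3_config": {
        "s3_datasets_path": "s3://my-bucket/wicker/",  # hypothetical bucket
        "region": "us-west-2",
        "boto_config": {"max_pool_connections": 10, "read_timeout_s": 140, "connect_timeout_s": 140},
    },
    "filesystem_configs": [
        {"config_name": "filesystem_1", "prefix_replace_path": "", "root_datasets_path": "/tmp/wicker"},
    ],
    "storage_download_config": {"retries": 2, "timeout": 150, "retry_backoff": 5, "retry_delay_s": 4},
    "wandb_config": {"wandb_api_key": "my-key", "wandb_base_url": "https://wandb.example.com"},
}

config = WickerConfig.from_json(example_config)
assert config.aws_s3_config.region == "us-west-2"
assert config.filesystem_configs[0].config_name == "filesystem_1"
--------------------------------------------------------------------------------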
/wicker/core/definitions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import enum
5 | import re
6 | from typing import Any, Dict, TypeVar
7 |
8 | # flake8 linting struggles with the module name `schema` and DatasetDefinition dataclass field name `schema`
9 | from wicker import schema # noqa: 401
10 |
11 | Example = TypeVar("Example")
12 | ExampleMetadata = Dict[str, Any]
13 |
14 |
15 | DATASET_ID_REGEX = re.compile(r"^(?P<name>[0-9a-zA-Z_]+)/(?P<version>[0-9]+\.[0-9]+\.[0-9]+)$")
16 |
17 |
18 | @dataclasses.dataclass(frozen=True)
19 | class DatasetID:
20 | """Representation of the unique identifier of a dataset
21 |
22 | `name` should be alphanumeric and contain no spaces, only underscores
23 | `version` should be a semantic version (e.g. 1.0.0)
24 | """
25 |
26 | name: str
27 | version: str
28 |
29 | @classmethod
30 | def from_str(cls, s: str) -> DatasetID:
31 | """Parses a DatasetID from a string"""
32 | match = DATASET_ID_REGEX.match(s)
33 | if not match:
34 | raise ValueError(f"{s} is not a valid DatasetID")
35 | return cls(name=match["name"], version=match["version"])
36 |
37 | def __str__(self) -> str:
38 | """Helper function to return the representation of the dataset version as a path-like string"""
39 | return f"{self.name}/{self.version}"
40 |
41 | @staticmethod
42 | def validate_dataset_id(name: str, version: str) -> None:
43 | """Validates the name and version of a dataset"""
44 | if not re.match(r"^[0-9a-zA-Z_]+$", name):
45 | raise ValueError(
46 | f"Provided dataset name {name} must be alphanumeric and contain no spaces, only underscores"
47 | )
48 |         if not re.match(r"^[0-9]+\.[0-9]+\.[0-9]+$", version):
49 |             pass  # Strict enforcement is intentionally disabled; re-enabling it would raise:
50 |             # raise ValueError(
51 |             #     f"Provided dataset version {version} should be a semantic version without any prefixes/suffixes"
52 |             # )
53 |
54 | def __post_init__(self) -> None:
55 | DatasetID.validate_dataset_id(self.name, self.version)
56 |
57 |
58 | @dataclasses.dataclass(frozen=True)
59 | class DatasetDefinition:
60 | """Representation of the definition of a dataset (immutable once dataset is added)
61 |
62 | `name` should be alphanumeric and contain no spaces, only underscores
63 | `version` should be a semantic version (e.g. 1.0.0)
64 | """
65 |
66 | dataset_id: DatasetID
67 | schema: schema.DatasetSchema
68 |
69 | @property
70 | def identifier(self) -> DatasetID:
71 | return DatasetID(name=self.dataset_id.name, version=self.dataset_id.version)
72 |
73 | def __post_init__(self) -> None:
74 | DatasetID.validate_dataset_id(self.dataset_id.name, self.dataset_id.version)
75 |
76 |
77 | @dataclasses.dataclass(frozen=True)
78 | class DatasetPartition:
79 | """Representation of the definition of a partition within dataset
80 |
81 | The partition here is meant to represent the common train/val/test splits of a dataset, but
82 | could also represent partitions for other use cases.
83 |
84 | `partition` should be alphanumeric and contain no spaces, only underscores
85 | """
86 |
87 | dataset_id: DatasetID
88 | partition: str
89 |
90 | def __str__(self) -> str:
91 | """Helper function to return the representation of the partition as a path-like string"""
92 | return f"{self.dataset_id.name}/{self.dataset_id.version}/{self.partition}"
93 |
94 |
95 | class DatasetState(enum.Enum):
96 | """Representation of the state of a dataset"""
97 |
98 | STAGED = "PENDING"
99 | COMMITTED = "COMMITTED"
100 |
--------------------------------------------------------------------------------
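A brief usage sketch for the module above, using a hypothetical dataset name and version, showing how DatasetID and DatasetPartition are constructed, rendered as path-like strings, and validated:

from wicker.core.definitions import DatasetID, DatasetPartition

dataset_id = DatasetID.from_str("my_dataset/1.0.0")  # hypothetical name/version
assert dataset_id == DatasetID(name="my_dataset", version="1.0.0")
assert str(dataset_id) == "my_dataset/1.0.0"

partition = DatasetPartition(dataset_id=dataset_id, partition="train")
assert str(partition) == "my_dataset/1.0.0/train"

# Names with spaces or other disallowed characters are rejected at construction time.
try:
    DatasetID(name="bad name", version="1.0.0")
except ValueError:
    print("invalid dataset name rejected")
--------------------------------------------------------------------------------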
/wicker/core/errors.py:
--------------------------------------------------------------------------------
1 | class WickerDatastoreException(Exception):
2 | pass
3 |
4 |
5 | class WickerSchemaException(Exception):
6 | pass
7 |
--------------------------------------------------------------------------------
/wicker/core/filelock.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import contextlib
4 | import fcntl
5 | import os
6 | import signal
7 | from typing import Any, Iterator, Optional
8 |
9 |
10 | @contextlib.contextmanager
11 | def _timeout(seconds: int, lockfile: str) -> Iterator[None]:
12 | """Context manager for triggering a timeout via SIGALRM if a blocking call
13 | does not return in `seconds` number of seconds.
14 |
15 | Usage:
16 | >>> with _timeout(5, "somefile"):
17 | >>> some_long_running_function()
18 |
19 | In the above example, if some_long_running_function takes more than 5 seconds, a
20 | TimeoutError will be raised from the `with _timeout(5, "somefile")` context.
21 | """
22 |     if seconds == -1:
23 | # No timeout specified, we execute without registering a SIGALRM
24 | yield
25 | else:
26 |
27 | def timeout_handler(signum: Any, frame: Any) -> None:
28 | """Handles receiving a SIGALRM by throwing a TimeoutError with a useful
29 | error message
30 | """
31 | raise TimeoutError(f"{lockfile} could not be acquired after {seconds}s")
32 |
33 | # Register `timeout_handler` as the handler for any SIGALRMs that are raised
34 | # We store the original SIGALRM handler if available to restore it afterwards
35 | original_handler = signal.signal(signal.SIGALRM, timeout_handler)
36 |
37 | try:
38 | # Raise a SIGALRM in `seconds` number of seconds, and then yield to code
39 | # in the context manager's code-block.
40 | signal.alarm(seconds)
41 | yield
42 | # Make sure to always restore the original SIGALRM handler
43 | finally:
44 | signal.alarm(0)
45 | signal.signal(signal.SIGALRM, original_handler)
46 |
47 |
48 | class SimpleUnixFileLock:
49 | """Simple blocking lock class that uses files to achieve system-wide locking
50 |
51 | 1. If timeout is not specified, waits forever by default
52 | 2. Does not support recursive locking and WILL break if called recursively
53 | """
54 |
55 | def __init__(self, lock_file: str, timeout_seconds: int = -1):
56 | # The path to the lock file.
57 | self._lock_file = lock_file
58 |
59 | # The file descriptor for the *_lock_file* as it is returned by the
60 | # os.open() function.
61 | # This file lock is only NOT None, if the object currently holds the
62 | # lock.
63 | self._lock_file_fd: Optional[int] = None
64 |
65 | # The default timeout value.
66 | self.timeout_seconds = timeout_seconds
67 |
68 | return None
69 |
70 | def __enter__(self) -> SimpleUnixFileLock:
71 | open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
72 | fd = os.open(self._lock_file, open_mode)
73 | try:
74 | with _timeout(self.timeout_seconds, self._lock_file):
75 | fcntl.lockf(fd, fcntl.LOCK_EX)
76 | except (IOError, OSError, TimeoutError):
77 | os.close(fd)
78 | raise
79 | else:
80 | self._lock_file_fd = fd
81 | return self
82 |
83 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
84 | assert self._lock_file_fd is not None, "self._lock_file_fd not set by __enter__"
85 | fd = self._lock_file_fd
86 | self._lock_file_fd = None
87 |         fcntl.lockf(fd, fcntl.LOCK_UN)  # release the lock acquired with lockf() in __enter__
88 | os.close(fd)
89 | return None
90 |
--------------------------------------------------------------------------------
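A small sketch of how SimpleUnixFileLock can serialize appends to a shared file across processes, mirroring the pattern in tests/test_filelock.py; the output path is a hypothetical example:

import multiprocessing

from wicker.core.filelock import SimpleUnixFileLock


def append_line(path: str) -> None:
    # Guard the append with a sibling ".lock" file so concurrent writers do not interleave.
    with SimpleUnixFileLock(f"{path}.lock", timeout_seconds=5):
        with open(path, "a") as f:
            f.write("hello\n")


if __name__ == "__main__":
    with multiprocessing.Pool(4) as pool:
        pool.map(append_line, ["/tmp/wicker_lock_demo.txt"] * 8)  # hypothetical path
--------------------------------------------------------------------------------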
/wicker/core/persistance.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Any, Dict, Iterable, List, Optional, Tuple
3 |
4 | import pyarrow as pa
5 | import pyarrow.compute as pc
6 |
7 | from wicker import schema as schema_module
8 | from wicker.core.column_files import ColumnBytesFileWriter
9 | from wicker.core.definitions import DatasetID
10 | from wicker.core.shuffle import save_index
11 | from wicker.core.storage import S3DataStorage, S3PathFactory
12 | from wicker.schema import dataparsing, serialization
13 |
14 | MAX_COL_FILE_NUMROW = 50 # TODO(isaak-willett): Magic number, we should derive this based on row size
15 |
16 | UnparsedExample = Dict[str, Any]
17 | ParsedExample = Dict[str, Any]
18 | PointerParsedExample = Dict[str, Any]
19 |
20 |
21 | class AbstractDataPersistor(abc.ABC):
22 | """
23 | Abstract class for persisting data onto a user defined cloud or local instance.
24 |
25 | Only s3 is supported right now but plan to support other data stores
26 | (BigQuery, Azure, Postgres)
27 | """
28 |
29 | def __init__(
30 | self,
31 | s3_storage: S3DataStorage = S3DataStorage(),
32 | s3_path_factory: S3PathFactory = S3PathFactory(),
33 | ) -> None:
34 | """
35 | Init a Persister
36 |
37 | :param s3_storage: The storage abstraction for S3
38 | :type s3_storage: S3DataStore
39 | :param s3_path_factory: The path factory for generating s3 paths
40 | based on dataset name and version
41 | :type s3_path_factory: S3PathFactory
42 | """
43 | super().__init__()
44 | self.s3_storage = s3_storage
45 | self.s3_path_factory = s3_path_factory
46 |
47 | @abc.abstractmethod
48 | def persist_wicker_dataset(
49 | self,
50 | dataset_name: str,
51 | dataset_version: str,
52 | dataset_schema: schema_module.DatasetSchema,
53 | dataset: Any,
54 | ) -> Optional[Dict[str, int]]:
55 | """
56 | Persist a user specified dataset defined by name, version, schema, and data.
57 |
58 | :param dataset_name: Name of the dataset
59 | :type dataset_name: str
60 | :param dataset_version: Version of the dataset
61 | :type: dataset_version: str
62 | :param dataset_schema: Schema of the dataset
63 | :type dataset_schema: wicker.schema.schema.DatasetSchema
64 | :param dataset: Data of the dataset
65 | :type dataset: User defined
66 | """
67 |         raise NotImplementedError("Method persist_wicker_dataset must be implemented by the inheriting class.")
68 |
69 | @staticmethod
70 | def parse_row(data_row: UnparsedExample, schema: schema_module.DatasetSchema) -> ParsedExample:
71 | """
72 |         Parse a row and validate it against the dataset schema.
73 |         :param data_row: Data row to be parsed
74 |         :type data_row: UnparsedExample
75 |         :param schema: Schema used to parse and validate the row
76 |         :return: parsed row containing the correct types associated with the schema
77 |         :rtype: ParsedExample
78 | """
79 | return dataparsing.parse_example(data_row, schema)
80 |
81 | # Write data to Column Byte Files
82 |
83 | @staticmethod
84 | def persist_wicker_partition(
85 | dataset_name: str,
86 | spark_partition_iter: Iterable[Tuple[str, ParsedExample]],
87 | schema: schema_module.DatasetSchema,
88 | s3_storage: S3DataStorage,
89 | s3_path_factory: S3PathFactory,
90 | target_max_column_file_numrows: int = 50,
91 | ) -> Iterable[Tuple[str, PointerParsedExample]]:
92 | """Persists a Spark partition of examples with parsed bytes into S3Storage as ColumnBytesFiles,
93 | returning a new Spark partition of examples with heavy-pointers and metadata only.
94 | :param dataset_name: dataset name
95 | :param spark_partition_iter: Spark partition of `(partition_str, example)`, where `example`
96 | is a dictionary of parsed bytes that needs to be uploaded to S3
97 | :param target_max_column_file_numrows: Maximum number of rows in column files. Defaults to 50.
98 | :return: a Generator of `(partition_str, example)`, where `example` is a dictionary with heavy-pointers
99 | that point to ColumnBytesFiles in S3 in place of the parsed bytes
100 | """
101 | column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {}
102 | heavy_pointer_columns = schema.get_pointer_columns()
103 | metadata_columns = schema.get_non_pointer_columns()
104 |
105 | for partition, example in spark_partition_iter:
106 | # Create ColumnBytesFileWriter lazily as required, for each partition
107 | if partition not in column_bytes_file_writers:
108 | column_bytes_file_writers[partition] = ColumnBytesFileWriter(
109 | s3_storage,
110 | s3_path_factory,
111 | target_file_rowgroup_size=target_max_column_file_numrows,
112 | dataset_name=dataset_name,
113 | )
114 |
115 | # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers
116 | parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns}
117 | for col in heavy_pointer_columns:
118 | loc = column_bytes_file_writers[partition].add(col, example[col])
119 | parquet_metadata[col] = loc.to_bytes()
120 | yield partition, parquet_metadata
121 |
122 | # Flush all writers when finished
123 | for partition in column_bytes_file_writers:
124 | column_bytes_file_writers[partition].close()
125 |
126 | @staticmethod
127 | def save_partition_tbl(
128 | partition_table_tuple: Tuple[str, pa.Table],
129 | dataset_name: str,
130 | dataset_version: str,
131 | s3_storage: S3DataStorage,
132 | s3_path_factory: S3PathFactory,
133 | ) -> Tuple[str, int]:
134 | """
135 | Save a partition table to s3 under the dataset name and version.
136 |
137 | :param partition_table_tuple: Tuple of partition id and pyarrow table to save
138 | :type partition_table_tuple: Tuple[str, pyarrow.Table]
139 |         :return: A tuple containing the partition id and the number of saved rows
140 | :rtype: Tuple[str, int]
141 | """
142 | partition, pa_tbl = partition_table_tuple
143 | save_index(
144 | dataset_name,
145 | dataset_version,
146 | {partition: pa_tbl},
147 | s3_storage=s3_storage,
148 | s3_path_factory=s3_path_factory,
149 | )
150 | return (partition, pa_tbl.num_rows)
151 |
152 |
153 | def persist_wicker_dataset(
154 | dataset_name: str,
155 | dataset_version: str,
156 | dataset_schema: schema_module.DatasetSchema,
157 | dataset: Any,
158 | s3_storage: S3DataStorage = S3DataStorage(),
159 | s3_path_factory: S3PathFactory = S3PathFactory(),
160 | ) -> Optional[Dict[str, int]]:
161 | """
162 |     Public-facing API function for persisting a wicker dataset, provided for API consistency.
163 | :param dataset_name: name of dataset persisted
164 | :type dataset_name: str
165 | :param dataset_version: version of dataset persisted
166 | :type dataset_version: str
167 | :param dataset_schema: schema of dataset to be persisted
168 | :type dataset_schema: DatasetSchema
169 |     :param dataset: data to persist, as an iterable of (partition, example) tuples
170 |     :type dataset: Any
171 | :param s3_storage: s3 storage abstraction
172 | :type s3_storage: S3DataStorage
173 | :param s3_path_factory: s3 path abstraction
174 | :type s3_path_factory: S3PathFactory
175 | """
176 | return BasicPersistor(s3_storage, s3_path_factory).persist_wicker_dataset(
177 | dataset_name, dataset_version, dataset_schema, dataset
178 | )
179 |
180 |
181 | class BasicPersistor(AbstractDataPersistor):
182 | """
183 |     Basic persistor class that persists wicker data to S3 on a single machine, without a distributed engine.
184 |
185 |     We will move to supporting other features, such as shuffling and other data engines, over time.
186 | """
187 |
188 | def __init__(
189 | self, s3_storage: S3DataStorage = S3DataStorage(), s3_path_factory: S3PathFactory = S3PathFactory()
190 | ) -> None:
191 | super().__init__(s3_storage, s3_path_factory)
192 |
193 | def persist_wicker_dataset(
194 | self, dataset_name: str, dataset_version: str, dataset_schema: schema_module.DatasetSchema, dataset: Any
195 | ) -> Optional[Dict[str, int]]:
196 | """
197 |         Persist a user-defined dataset, pushing data to S3 in a basic manner.
198 |
199 | :param dataset_name: Name of the dataset
200 | :type dataset_name: str
201 | :param dataset_version: Version of the dataset
202 |         :type dataset_version: str
203 | :param dataset_schema: Schema of the dataset
204 | :type dataset_schema: wicker.schema.schema.DatasetSchema
205 |         :param dataset: Data of the dataset, as an iterable of (partition, example) tuples
206 |         :type dataset: User defined
207 | """
208 |         # What needs to be done within this function:
209 |         # 1. Check that the required variables are set,
210 |         #    i.e. not None and of the expected types
211 | if (
212 | not isinstance(dataset_name, str)
213 | or not isinstance(dataset_version, str)
214 | or not isinstance(dataset_schema, schema_module.DatasetSchema)
215 | ):
216 |             raise ValueError("Not all dataset variables are set; provide non-None values of the expected types")
217 |
218 | # 2. Put the schema up on s3
219 | schema_path = self.s3_path_factory.get_dataset_schema_path(
220 | DatasetID(name=dataset_name, version=dataset_version)
221 | )
222 | self.s3_storage.put_object_s3(serialization.dumps(dataset_schema).encode("utf-8"), schema_path)
223 |
224 |         # 3. Parse and validate the rows to ensure the data is well formed
225 | dataset_0 = [(row[0], self.parse_row(row[1], dataset_schema)) for row in dataset]
226 |
227 | # 4. Sort the dataset if not sorted
228 | sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0])
229 |
230 |         # 5. Persist the partitions to S3
231 | metadata_iterator = self.persist_wicker_partition(
232 | dataset_name,
233 | sorted_dataset_0,
234 | dataset_schema,
235 | self.s3_storage,
236 | self.s3_path_factory,
237 | MAX_COL_FILE_NUMROW,
238 | )
239 |
240 |         # 6. Create the partition tables: combine the rows so we can form one table per partition.
241 |         # Split into k dicts (one per partition), where each dict maps a column name to the
242 |         # list of values for that column across all rows in the partition
243 | merged_dicts: Dict[str, Dict[str, List[Any]]] = {}
244 | for partition_key, row in metadata_iterator:
245 | current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {})
246 | for col in row.keys():
247 | if col in current_dict:
248 | current_dict[col].append(row[col])
249 | else:
250 | current_dict[col] = [row[col]]
251 | merged_dicts[partition_key] = current_dict
252 | # convert each of the dicts to a pyarrow table in the same way SparkPersistor
253 | # converts, needed to ensure parity between the two
254 | arrow_dict = {}
255 | for partition_key, data_dict in merged_dicts.items():
256 | data_table = pa.Table.from_pydict(data_dict)
257 | arrow_dict[partition_key] = pc.take(
258 | pa.Table.from_pydict(data_dict),
259 | pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in dataset_schema.primary_keys]),
260 | )
261 |
262 |         # 7. Persist the partition tables to S3
263 | written_dict = {}
264 | for partition_key, pa_table in arrow_dict.items():
265 | self.save_partition_tbl(
266 | (partition_key, pa_table), dataset_name, dataset_version, self.s3_storage, self.s3_path_factory
267 | )
268 | written_dict[partition_key] = pa_table.num_rows
269 |
270 | return written_dict
271 |
--------------------------------------------------------------------------------
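For reference, a minimal sketch of driving the BasicPersistor above through the module-level `persist_wicker_dataset` entry point in `wicker/core/persistance.py`. The schema object (`my_schema`), the example row fields, and the S3/wicker configuration are assumptions, not part of the library; the rows must conform to whatever schema and primary keys you actually define.

# Minimal sketch: persisting a small in-memory dataset with the BasicPersistor.
# `my_schema` is an assumed, pre-built wicker.schema.schema.DatasetSchema whose
# fields match the illustrative row dicts below; S3 credentials/config are assumed.
from wicker.core.persistance import persist_wicker_dataset

rows = [
    ("train", {"example_id": 0, "label": "car"}),
    ("train", {"example_id": 1, "label": "bike"}),
    ("eval", {"example_id": 2, "label": "bus"}),
]

written = persist_wicker_dataset(
    dataset_name="my_dataset",
    dataset_version="0.0.1",
    dataset_schema=my_schema,  # assumed to exist
    dataset=rows,              # iterable of (partition, example) tuples
)
print(written)  # e.g. {"eval": 1, "train": 2}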
/wicker/core/shuffle.py:
--------------------------------------------------------------------------------
1 | """Classes handling the shuffling of data in S3 when committing a dataset.
2 |
3 | When committing a dataset, we sort the data by primary_key before materializing in S3
4 | as Parquet files.
5 |
6 | 1. The ShuffleJob is the unit of work, and it is just an ordered set of Examples that should
7 | be bundled together in one Parquet file
8 |
9 | 2. The ShuffleJobFactory produces ShuffleJobs, using a DatasetWriter object to retrieve the
10 | written examples for divvying up as ShuffleJobs.
11 |
12 | 3. ShuffleWorkers receive the ShuffleJobs and perform the act of retrieving the data and
13 | persisting the data into S3 as Parquet files (one for each ShuffleJob)
14 | """
15 | from __future__ import annotations
16 |
17 | import collections
18 | import concurrent.futures
19 | import dataclasses
20 | import os
21 | import pickle
22 | import tempfile
23 | from typing import Any, Dict, Generator, List, Optional, Tuple
24 |
25 | import boto3
26 | import pyarrow as pa
27 | import pyarrow.fs as pafs
28 | import pyarrow.parquet as papq
29 |
30 | from wicker.core.column_files import ColumnBytesFileWriter
31 | from wicker.core.definitions import DatasetID, DatasetPartition
32 | from wicker.core.storage import S3DataStorage, S3PathFactory
33 | from wicker.core.writer import DatasetWriterBackend
34 | from wicker.schema import serialization
35 |
36 | # Maximum working set for each worker
37 | DEFAULT_WORKER_MAX_WORKING_SET_SIZE = 16384
38 |
39 |
40 | @dataclasses.dataclass
41 | class ShuffleJob:
42 | """Represents all the shuffling operations that will happen for a given partition (train/eval/test) on a given
43 | compute shard."""
44 |
45 | dataset_partition: DatasetPartition
46 | files: List[Tuple[str, int]]
47 |
48 |
49 | class ShuffleJobFactory:
50 | def __init__(
51 | self,
52 | writer_backend: DatasetWriterBackend,
53 | worker_max_working_set_size: int = DEFAULT_WORKER_MAX_WORKING_SET_SIZE,
54 | ):
55 | self.writer_backend = writer_backend
56 |
57 | # Factory configurations
58 | self.worker_max_working_set_size = worker_max_working_set_size
59 |
60 | def build_shuffle_jobs(self, dataset_id: DatasetID) -> Generator[ShuffleJob, None, None]:
61 | # Initialize with first item
62 | example_keys = self.writer_backend._metadata_db.scan_sorted(dataset_id)
63 | try:
64 | initial_key = next(example_keys)
65 | except StopIteration:
66 | return
67 | job = ShuffleJob(
68 | dataset_partition=DatasetPartition(dataset_id=dataset_id, partition=initial_key.partition),
69 | files=[(initial_key.row_data_path, initial_key.row_size)],
70 | )
71 |
72 | # Yield ShuffleJobs as we accumulate ExampleKeys, where each ShuffleJob is upper-bounded in size by
73 | # self.worker_max_working_set_size and has all ExampleKeys from the same partition
74 | for example_key in example_keys:
75 | if (
76 | example_key.partition == job.dataset_partition.partition
77 | and len(job.files) < self.worker_max_working_set_size
78 | ):
79 | job.files.append((example_key.row_data_path, example_key.row_size))
80 | continue
81 |
82 | # Yield job and construct new job to keep iterating
83 | yield job
84 | job = ShuffleJob(
85 | dataset_partition=DatasetPartition(dataset_id=dataset_id, partition=example_key.partition),
86 | files=[(example_key.row_data_path, example_key.row_size)],
87 | )
88 | else:
89 | yield job
90 |
91 |
92 | _download_thread_session: Optional[boto3.session.Session] = None
93 | _download_thread_client: Optional[S3DataStorage] = None
94 |
95 |
96 | def _initialize_download_thread():
97 | global _download_thread_session
98 | global _download_thread_client
99 | if _download_thread_client is None:
100 | _download_thread_session = boto3.session.Session()
101 | _download_thread_client = S3DataStorage(session=_download_thread_session)
102 |
103 |
104 | class ShuffleWorker:
105 | def __init__(
106 | self,
107 | target_rowgroup_bytes_size: int = int(256e6),
108 | max_worker_threads: int = 16,
109 | max_memory_usage_bytes: int = int(2e9),
110 | storage: S3DataStorage = S3DataStorage(),
111 | s3_path_factory: S3PathFactory = S3PathFactory(),
112 | ):
113 | self.target_rowgroup_bytes_size = target_rowgroup_bytes_size
114 | self.max_worker_threads = max_worker_threads
115 | self.max_memory_usage_bytes = max_memory_usage_bytes
116 | self.storage = storage
117 | self.s3_path_factory = s3_path_factory
118 |
119 | def _download_files(self, job: ShuffleJob) -> Generator[Dict[str, Any], None, None]:
120 | """Downloads the files in a ShuffleJob, yielding a generator of Dict[str, Any]
121 |
122 | Internally, this function maintains a buffer of lookahead downloads to execute downloads
123 | in parallel over a ThreadPoolExecutor, up to a maximum of `max_memory_usage_bytes` bytes.
124 |
125 | Args:
126 | job (ShuffleJob): job to download files for
127 |
128 | Yields:
129 | Generator[Dict[str, Any], None, None]: stream of Dict[str, Any] from downloading the files
130 | in order
131 | """
132 | # TODO(jchia): Add retries here
133 | def _download_file(filepath: str) -> Dict[str, Any]:
134 | assert _download_thread_client is not None
135 | return pickle.loads(_download_thread_client.fetch_obj_s3(filepath))
136 |
137 | with concurrent.futures.ThreadPoolExecutor(
138 | self.max_worker_threads, initializer=_initialize_download_thread
139 | ) as executor:
140 | buffer: Dict[int, Tuple[concurrent.futures.Future[Dict[str, Any]], int]] = {}
141 | buffer_size = 0
142 | buffer_index = 0
143 | for current_index in range(len(job.files)):
144 | while buffer_index < len(job.files) and buffer_size < self.max_memory_usage_bytes:
145 | filepath, file_size = job.files[buffer_index]
146 | buffer[buffer_index] = (executor.submit(_download_file, filepath), file_size)
147 | buffer_size += file_size
148 | buffer_index += 1
149 | current_future, current_file_size = buffer[current_index]
150 | yield current_future.result()
151 | buffer_size -= current_file_size
152 | del buffer[current_index]
153 |
154 | def _estimate_target_file_rowgroup_size(
155 | self,
156 | job: ShuffleJob,
157 | target_rowgroup_size_bytes: int = int(256e6),
158 | min_target_rowgroup_size: int = 16,
159 | ) -> int:
160 | """Estimates the number of rows to include in each rowgroup using a target size for the rowgroup
161 |
162 | :param job: job to estimate
163 | :param target_rowgroup_size_bytes: target size in bytes of a rowgroup, defaults to 256MB
164 | :param min_target_rowgroup_size: minimum number of rows in a rowgroup
165 | :return: target number of rows in a rowgroup
166 | """
167 | average_filesize = sum([size for _, size in job.files]) / len(job.files)
168 | return max(min_target_rowgroup_size, int(target_rowgroup_size_bytes / average_filesize))
169 |
170 | def process_job(self, job: ShuffleJob) -> pa.Table:
171 | # Load dataset schema
172 | dataset_schema = serialization.loads(
173 | self.storage.fetch_obj_s3(
174 | self.s3_path_factory.get_dataset_schema_path(job.dataset_partition.dataset_id)
175 | ).decode("utf-8")
176 | )
177 |
178 | # Estimate how many rows to add to each ColumnBytesFile
179 | target_file_rowgroup_size = self._estimate_target_file_rowgroup_size(job)
180 |
181 | # Initialize data containers to dump into parquet
182 | heavy_pointer_columns = dataset_schema.get_pointer_columns()
183 | metadata_columns = dataset_schema.get_non_pointer_columns()
184 | parquet_metadata: Dict[str, List[Any]] = collections.defaultdict(list)
185 |
186 | # Parse each row, uploading heavy_pointer bytes to S3 and storing only pointers
187 | # in parquet_metadata
188 | with ColumnBytesFileWriter(
189 | self.storage,
190 | self.s3_path_factory,
191 | target_file_rowgroup_size=target_file_rowgroup_size,
192 | dataset_name=job.dataset_partition.dataset_id.name,
193 | ) as writer:
194 | for data in self._download_files(job):
195 | for col in metadata_columns:
196 | parquet_metadata[col].append(data[col])
197 | for col in heavy_pointer_columns:
198 | loc = writer.add(col, data[col])
199 | parquet_metadata[col].append(loc.to_bytes())
200 |
201 | # Save parquet_metadata as a PyArrow Table
202 | assert len({len(parquet_metadata[col]) for col in parquet_metadata}) == 1, "All columns must have same length"
203 | return pa.Table.from_pydict(parquet_metadata)
204 |
205 |
206 | def save_index(
207 | dataset_name: str,
208 | dataset_version: str,
209 | final_indices: Dict[str, pa.Table],
210 | s3_path_factory: Optional[S3PathFactory] = None,
211 | s3_storage: Optional[S3DataStorage] = None,
212 | ) -> None:
213 | """Saves a generated final_index into persistent storage
214 |
215 | :param dataset_name: Name of the dataset
216 | :param dataset_version: Version of the dataset
217 |     :param final_indices: Dictionary mapping partition name to the finalized pyarrow Table
218 |         for that partition's index
219 | :param s3_path_factory: S3PathFactory to use
220 | :param s3_storage: S3DataStorage to use
221 | """
222 | s3_storage = s3_storage if s3_storage is not None else S3DataStorage()
223 | s3_path_factory = s3_path_factory if s3_path_factory is not None else S3PathFactory()
224 | for partition_name in final_indices:
225 | dataset_partition = DatasetPartition(
226 | dataset_id=DatasetID(
227 | name=dataset_name,
228 | version=dataset_version,
229 | ),
230 | partition=partition_name,
231 | )
232 |
233 | parquet_folder = s3_path_factory.get_dataset_partition_path(dataset_partition, s3_prefix=True)
234 | parquet_path = os.path.join(parquet_folder, "part-0.parquet")
235 |
236 | # Write the Parquet file as one file locally, then upload to S3
237 | with tempfile.NamedTemporaryFile() as tmpfile:
238 | pa_table = final_indices[partition_name]
239 | papq.write_table(
240 | pa_table,
241 | tmpfile.name,
242 | compression="zstd",
243 | row_group_size=None,
244 | filesystem=pafs.LocalFileSystem(),
245 | # We skip writing statistics since it bloats the file, and we don't actually run any queries
246 | # on the Parquet files that could make use of predicate push-down
247 | write_statistics=False,
248 | )
249 | s3_storage.put_file_s3(
250 | tmpfile.name,
251 | parquet_path,
252 | )
253 |
254 | return None
255 |
--------------------------------------------------------------------------------
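The three pieces above compose into the commit flow roughly as follows. This is a local, single-process sketch (the Flyte plugin below distributes the same steps); `metadata_db` is an assumed implementation of AbstractDatasetWriterMetadataDatabase that has already been populated by a DatasetWriter, and the dataset name/version are placeholders.

# Sketch of the commit flow: build shuffle jobs, process them, then save the index.
import pyarrow as pa

from wicker.core.definitions import DatasetID
from wicker.core.shuffle import ShuffleJobFactory, ShuffleWorker, save_index
from wicker.core.storage import S3DataStorage, S3PathFactory
from wicker.core.writer import DatasetWriterBackend

backend = DatasetWriterBackend(S3PathFactory(), S3DataStorage(), metadata_db)  # metadata_db assumed
factory = ShuffleJobFactory(backend)
worker = ShuffleWorker()

tables_by_partition = {}
for job in factory.build_shuffle_jobs(DatasetID(name="my_dataset", version="0.0.1")):
    # Downloads the rows, writes ColumnBytesFiles, and returns a metadata-only table
    table = worker.process_job(job)
    tables_by_partition.setdefault(job.dataset_partition.partition, []).append(table)

# One Parquet index per partition, concatenated from each job's table
final_indices = {p: pa.concat_tables(ts) for p, ts in tables_by_partition.items()}
save_index("my_dataset", "0.0.1", final_indices)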
/wicker/core/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import signal
3 |
4 |
5 | @contextlib.contextmanager
6 | def time_limit(seconds, error_message: str):
7 | def signal_handler(signum, frame):
8 |         raise TimeoutError(f"Timed out! {error_message}")
9 |
10 | signal.signal(signal.SIGALRM, signal_handler)
11 | signal.alarm(seconds)
12 | try:
13 | yield
14 | finally:
15 | signal.alarm(0)
16 |
--------------------------------------------------------------------------------
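A small usage sketch of the `time_limit` helper above. Because it relies on SIGALRM, it only works in the main thread of a Unix process and the timeout is in whole seconds; `slow_operation` is a hypothetical callable.

# Sketch: bounding a potentially slow call with the SIGALRM-based time_limit helper.
from wicker.core.utils import time_limit

with time_limit(30, "slow_operation exceeded its 30s budget"):
    slow_operation()  # hypothetical long-running call; raises TimeoutError if it overruns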
/wicker/core/writer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | import contextlib
5 | import dataclasses
6 | import hashlib
7 | import os
8 | import pickle
9 | import threading
10 | import time
11 | from concurrent.futures import Executor, Future, ThreadPoolExecutor
12 | from types import TracebackType
13 | from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
14 |
15 | from wicker.core.definitions import DatasetDefinition, DatasetID
16 | from wicker.core.errors import WickerDatastoreException
17 | from wicker.core.storage import S3DataStorage, S3PathFactory
18 | from wicker.schema import dataparsing, serialization
19 |
20 |
21 | @dataclasses.dataclass
22 | class ExampleKey:
23 | """Unique identifier for one example."""
24 |
25 | # Name of the partition (usually train/eval/test)
26 | partition: str
27 | # Values of the primary keys for the example (in order of key precedence)
28 | primary_key_values: List[Union[str, int]]
29 |
30 | def hash(self) -> str:
31 | return hashlib.sha256(
32 | "/".join([self.partition, *[str(obj) for obj in self.primary_key_values]]).encode("utf-8")
33 | ).hexdigest()
34 |
35 |
36 | @dataclasses.dataclass
37 | class MetadataDatabaseScanRow:
38 | """Container for data obtained by scanning the MetadataDatabase"""
39 |
40 | partition: str
41 | row_data_path: str
42 | row_size: int
43 |
44 |
45 | class AbstractDatasetWriterMetadataDatabase:
46 | """Database for storing metadata, used from inside the DatasetWriterBackend.
47 |
48 | NOTE: Implementors - this is the main implementation integration point for creating a new kind of DatasetWriter.
49 | """
50 |
51 | @abc.abstractmethod
52 | def save_row_metadata(self, dataset_id: DatasetID, key: ExampleKey, location: str, row_size: int) -> None:
53 | """Saves a row in the metadata database, marking it as having been uploaded to S3 and
54 | ready for shuffling.
55 |
56 | :param dataset_id: The ID of the dataset to save to
57 | :param key: The key of the example
58 | :param location: The location of the example in S3
59 | :param row_size: The size of the file in S3
60 | """
61 | pass
62 |
63 | @abc.abstractmethod
64 | def scan_sorted(self, dataset_id: DatasetID) -> Generator[MetadataDatabaseScanRow, None, None]:
65 | """Scans the MetadataDatabase for a **SORTED** stream of MetadataDatabaseScanRows for a given dataset.
66 | The stream is sorted by partition first, and then primary_key_values second.
67 |
68 | Should be fast O(minutes) to perform as this will be called from a single machine to assign chunks to jobs
69 | to run.
70 |
71 | :param dataset_id: The dataset to scan the metadata database for
72 | :return: a Generator of ExampleDBRows in **SORTED** partition + primary_key order
73 | """
74 | pass
75 |
76 |
77 | class DatasetWriterBackend:
78 | """The backend for a DatasetWriter.
79 |
80 | Responsible for saving and retrieving data used during the dataset writing and committing workflow.
81 | """
82 |
83 | def __init__(
84 | self,
85 | s3_path_factory: S3PathFactory,
86 | s3_storage: S3DataStorage,
87 | metadata_database: AbstractDatasetWriterMetadataDatabase,
88 | ):
89 | self._s3_path_factory = s3_path_factory
90 | self._s3_storage = s3_storage
91 | self._metadata_db = metadata_database
92 |
93 | def save_row(self, dataset_id: DatasetID, key: ExampleKey, raw_data: Dict[str, Any]) -> None:
94 | """Adds an example to the backend
95 |
96 | :param dataset_id: ID of the dataset to save the row to
97 | :param key: Key of the example to write
98 | :param raw_data: raw data for the example that conforms to the schema provided at initialization
99 | """
100 | hashed_row_key = key.hash()
101 | pickled_row = pickle.dumps(raw_data) # TODO(jchia): Do we want a more sophisticated storage format here?
102 | row_s3_path = os.path.join(
103 | self._s3_path_factory.get_temporary_row_files_path(dataset_id),
104 | hashed_row_key,
105 | )
106 |
107 | # Persist data in S3 and in MetadataDatabase
108 | self._s3_storage.put_object_s3(pickled_row, row_s3_path)
109 | self._metadata_db.save_row_metadata(dataset_id, key, row_s3_path, len(pickled_row))
110 |
111 | def commit_schema(
112 | self,
113 | dataset_definition: DatasetDefinition,
114 | ) -> None:
115 | """Write the schema to the backend as part of the commit step."""
116 | schema_path = self._s3_path_factory.get_dataset_schema_path(dataset_definition.identifier)
117 | serialized_schema = serialization.dumps(dataset_definition.schema)
118 | self._s3_storage.put_object_s3(serialized_schema.encode(), schema_path)
119 |
120 |
121 | DEFAULT_BUFFER_SIZE_LIMIT = 1000
122 |
123 |
124 | class DatasetWriter:
125 | """DatasetWriter providing async writing functionality. Implementors should override
126 | the ._save_row_impl method to define functionality for saving each individual row from inside the
127 | async thread executors.
128 | """
129 |
130 | def __init__(
131 | self,
132 | dataset_definition: DatasetDefinition,
133 | metadata_database: AbstractDatasetWriterMetadataDatabase,
134 | s3_path_factory: Optional[S3PathFactory] = None,
135 | s3_storage: Optional[S3DataStorage] = None,
136 | buffer_size_limit: int = DEFAULT_BUFFER_SIZE_LIMIT,
137 | executor: Optional[Executor] = None,
138 | wait_flush_timeout_seconds: int = 300,
139 | ):
140 | """Create a new DatasetWriter
141 |
142 | :param dataset_definition: definition of the dataset
143 | :param s3_path_factory: factory for s3 paths
144 | :param s3_storage: S3-compatible storage for storing data
145 | :param buffer_size_limit: size limit to the number of examples buffered in-memory, defaults to 1000
146 | :param executor: concurrent.futures.Executor to use for async writing, defaults to None
147 | :param wait_flush_timeout_seconds: number of seconds to wait before timing out on flushing
149 |             all examples, defaults to 300
149 | """
150 | self.dataset_definition = dataset_definition
151 | self.backend = DatasetWriterBackend(
152 | s3_path_factory if s3_path_factory is not None else S3PathFactory(),
153 | s3_storage if s3_storage is not None else S3DataStorage(),
154 | metadata_database,
155 | )
156 |
157 | self.buffer: List[Tuple[ExampleKey, Dict[str, Any]]] = []
158 | self.buffer_size_limit = buffer_size_limit
159 |
160 | self.wait_flush_timeout_seconds = wait_flush_timeout_seconds
161 | self.executor = executor if executor is not None else ThreadPoolExecutor(max_workers=16)
162 | self.writes_in_flight: Dict[str, Dict[str, Any]] = {}
163 | self.flush_condition_variable = threading.Condition()
164 |
165 | def __enter__(self) -> DatasetWriter:
166 | return self
167 |
168 | def __exit__(
169 | self,
170 | exception_type: Optional[Type[BaseException]],
171 | exception_value: Optional[BaseException],
172 | traceback: Optional[TracebackType],
173 | ) -> None:
174 | self.flush(block=True)
175 |
176 | def __del__(self) -> None:
177 | self.flush(block=True)
178 |
179 | def add_example(self, partition_name: str, raw_data: Dict[str, Any]) -> None:
180 | """Adds an example to the writer
181 |
182 | :param partition_name: partition name where the example belongs
183 | :param raw_data: raw data for the example that conforms to the schema provided at initialization
184 | """
185 | # Run sanity checks on the data, fill empty fields.
186 | ex = dataparsing.parse_example(raw_data, self.dataset_definition.schema)
187 | example_key = ExampleKey(
188 | partition=partition_name, primary_key_values=[ex[k] for k in self.dataset_definition.schema.primary_keys]
189 | )
190 | self.buffer.append((example_key, ex))
191 |
192 | # Flush buffer to persistent storage if at size limit
193 | if len(self.buffer) > self.buffer_size_limit:
194 | self.flush(block=False)
195 |
196 | def flush(self, block: bool = True) -> None:
197 | """Flushes the writer
198 |
199 | :param block: whether to block on flushing all currently buffered examples, defaults to True
200 | :raises TimeoutError: timing out on flushing all examples
201 | """
202 | batch_data = self.buffer
203 | self.buffer = []
204 | self._save_batch_data(batch_data)
205 |
206 | if block:
207 | with self._block_on_writes_in_flight(max_in_flight=0, timeout_seconds=self.wait_flush_timeout_seconds):
208 | pass
209 |
210 | @contextlib.contextmanager
211 | def _block_on_writes_in_flight(self, max_in_flight: int = 0, timeout_seconds: int = 60) -> Iterator[None]:
212 | """Blocks until number of writes in flight <= max_in_flight and yields control to the with block
213 |
214 | Usage:
215 | >>> with self._block_on_writes_in_flight(max_in_flight=10):
216 | >>> # Code here holds the exclusive lock on self.flush_condition_variable
217 | >>> ...
218 | >>> # Code here has released the lock, and has no guarantees about the number of writes in flight
219 |
220 | :param max_in_flight: maximum number of writes in flight, defaults to 0
221 |         :param timeout_seconds: maximum number of seconds to wait, defaults to 60
222 | :raises TimeoutError: if waiting for more than timeout_seconds
223 | """
224 | start_time = time.time()
225 | with self.flush_condition_variable:
226 | while len(self.writes_in_flight) > max_in_flight:
227 | if time.time() - start_time > timeout_seconds:
228 | raise TimeoutError(
229 | f"Timed out while flushing dataset writes with {len(self.writes_in_flight)} writes in flight"
230 | )
231 | self.flush_condition_variable.wait()
232 | yield
233 |
234 | def _save_row(self, key: ExampleKey, data: Dict[str, Any]) -> str:
235 | self.backend.save_row(self.dataset_definition.identifier, key, data)
236 | return key.hash()
237 |
238 | def _save_batch_data(self, batch_data: List[Tuple[ExampleKey, Dict[str, Any]]]) -> None:
239 | """Save a batch of data to persistent storage
240 |
241 |         :param batch_data: Batch of (ExampleKey, raw data) tuples to save; each key hashes to a
242 |             hex string (0-9a-f) that is used as the row's file name
243 | """
244 |
245 | def done_callback(future: Future[str]) -> None:
246 | # TODO(jchia): We can add retries on failure by reappending to the buffer
247 | # this currently may raise a CancelledError or some Exception thrown by the .save_row function
248 | with self.flush_condition_variable:
249 | del self.writes_in_flight[future.result()]
250 | self.flush_condition_variable.notify()
251 |
252 | for key, data in batch_data:
253 | # Keep the number of writes in flight always smaller than 2 * self.buffer_size_limit
254 | with self._block_on_writes_in_flight(
255 | max_in_flight=2 * self.buffer_size_limit,
256 | timeout_seconds=self.wait_flush_timeout_seconds,
257 | ):
258 | key_hash = key.hash()
259 | if key_hash in self.writes_in_flight:
260 | raise WickerDatastoreException(
261 | f"Error: data example has non unique key {key}, primary keys must be unique"
262 | )
263 | self.writes_in_flight[key_hash] = data
264 | future = self.executor.submit(self._save_row, key, data)
265 | future.add_done_callback(done_callback)
266 |
--------------------------------------------------------------------------------
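A sketch of the write path described above. `definition` is an assumed DatasetDefinition, `metadata_db` an assumed AbstractDatasetWriterMetadataDatabase implementation (for example the DynamoDB plugin below), and `example_stream` an assumed iterable of (partition, raw row) pairs conforming to the schema in the definition.

# Sketch: buffering examples through the async DatasetWriter.
from wicker.core.writer import DatasetWriter

with DatasetWriter(definition, metadata_db) as writer:    # definition/metadata_db assumed
    for partition, raw_row in example_stream:             # example_stream assumed
        writer.add_example(partition, raw_row)
# __exit__ calls flush(block=True), so all buffered rows are persisted to S3 and the
# metadata database before execution continues.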
/wicker/plugins/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/plugins/__init__.py
--------------------------------------------------------------------------------
/wicker/plugins/dynamodb.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import heapq
3 | from typing import Generator, List, Tuple
4 |
5 | try:
6 | import pynamodb # type: ignore
7 | except ImportError:
8 | raise RuntimeError(
9 | "pynamodb is not detected in current environment, install Wicker with extra arguments:"
10 | " `pip install wicker[dynamodb]`"
11 | )
12 | from pynamodb.attributes import NumberAttribute, UnicodeAttribute
13 | from pynamodb.models import Model
14 | from retry import retry
15 |
16 | from wicker.core.config import get_config
17 | from wicker.core.definitions import DatasetID
18 | from wicker.core.writer import (
19 | AbstractDatasetWriterMetadataDatabase,
20 | ExampleKey,
21 | MetadataDatabaseScanRow,
22 | )
23 |
24 | # DANGER: If these constants are ever changed, this is a backward-incompatible change.
25 | # Make sure that all the writers and readers of dynamodb are in sync when changing.
26 | NUM_DYNAMODB_SHARDS = 32
27 | HASH_PREFIX_LENGTH = 4
28 | DYNAMODB_QUERY_PAGINATION_LIMIT = 1000
29 |
30 |
31 | @dataclasses.dataclass(frozen=True)
32 | class DynamoDBConfig:
33 | table_name: str
34 | region: str
35 |
36 |
37 | def get_dynamodb_config() -> DynamoDBConfig:
38 | raw_config = get_config().raw
39 | if "dynamodb_config" not in raw_config:
40 | raise RuntimeError("Could not find 'dynamodb' parameters in config")
41 | if "table_name" not in raw_config["dynamodb_config"]:
42 | raise RuntimeError("Could not find 'table_name' parameter in config.dynamodb_config")
43 | if "region" not in raw_config["dynamodb_config"]:
44 | raise RuntimeError("Could not find 'region' parameter in config.dynamodb_config")
45 | return DynamoDBConfig(
46 | table_name=raw_config["dynamodb_config"]["table_name"],
47 | region=raw_config["dynamodb_config"]["region"],
48 | )
49 |
50 |
51 | class DynamoDBExampleDBRow(Model):
52 | class Meta:
53 | table_name = get_dynamodb_config().table_name
54 | region = get_dynamodb_config().region
55 |
56 | dataset_id = UnicodeAttribute(hash_key=True)
57 | example_id = UnicodeAttribute(range_key=True)
58 | partition = UnicodeAttribute()
59 | row_data_path = UnicodeAttribute()
60 | row_size = NumberAttribute()
61 |
62 |
63 | def _key_to_row_id_and_shard_id(example_key: ExampleKey) -> Tuple[str, int]:
64 | """Deterministically convert an ExampleKey into a row_id and a shard_id which are used as
65 | DynamoDB RangeKeys and HashKeys respectively.
66 |
67 | HashKeys help to increase read/write throughput by allowing us to use different partitions.
68 | RangeKeys are how the rows are sorted within partitions by DynamoDB.
69 |
70 | We completely randomize the hash and range key to optimize for write throughput, but this means
71 | that sorting needs to be done entirely client-side in our application.
72 | """
73 | partition_example_id = f"{example_key.partition}//{'/'.join([str(key) for key in example_key.primary_key_values])}"
74 | hash = example_key.hash()
75 | shard = int(hash, 16) % NUM_DYNAMODB_SHARDS
76 | return partition_example_id, shard
77 |
78 |
79 | def _dataset_shard_name(dataset_id: DatasetID, shard_id: int) -> str:
80 | """Get the name of the DynamoDB partition for a given dataset_definition and shard number"""
81 | return f"{dataset_id}_shard{shard_id:02d}"
82 |
83 |
84 | class DynamodbMetadataDatabase(AbstractDatasetWriterMetadataDatabase):
85 | def save_row_metadata(self, dataset_id: DatasetID, key: ExampleKey, location: str, row_size: int) -> None:
86 | """Saves a row in the metadata database, marking it as having been uploaded to S3 and
87 | ready for shuffling.
88 |
89 | :param dataset_id: The ID of the dataset to save to
90 | :param key: The key of the example
91 | :param location: The location of the example in S3
92 | :param row_size: The size of the file in S3
93 | """
94 | partition_example_id, shard_id = _key_to_row_id_and_shard_id(key)
95 | entry = DynamoDBExampleDBRow(
96 | dataset_id=_dataset_shard_name(dataset_id, shard_id),
97 | example_id=partition_example_id,
98 | partition=key.partition,
99 | row_data_path=location,
100 | row_size=row_size,
101 | )
102 | entry.save()
103 |
104 | def scan_sorted(self, dataset_id: DatasetID) -> Generator[MetadataDatabaseScanRow, None, None]:
105 | """Scans the MetadataDatabase for a **SORTED** list of ExampleKeys for a given dataset. Should be fast O(minutes)
106 | to perform as this will be called from a single machine to assign chunks to jobs to run.
107 |
108 |         :param dataset_id: The dataset to scan the metadata database for
109 | :return: a Generator of MetadataDatabaseScanRow in **SORTED** primary_key order
110 | """
111 |
112 | @retry(pynamodb.exceptions.QueryError, tries=10, backoff=2, delay=4, jitter=(0, 2))
113 | def shard_iterator(shard_id: int) -> Generator[DynamoDBExampleDBRow, None, None]:
114 | """Yields DynamoDBExampleDBRows from a given shard to exhaustion, sorted by example_id in ascending order
115 | DynamoDBExampleDBRows are yielded in sorted order of the Dynamodb RangeKey, which is the example_id
116 | """
117 | hash_key = _dataset_shard_name(dataset_id, shard_id)
118 | last_evaluated_key = None
119 | while True:
120 | query_results = DynamoDBExampleDBRow.query(
121 | hash_key,
122 | consistent_read=True,
123 | last_evaluated_key=last_evaluated_key,
124 | limit=DYNAMODB_QUERY_PAGINATION_LIMIT,
125 | )
126 | yield from query_results
127 | if query_results.last_evaluated_key is None:
128 | break
129 | last_evaluated_key = query_results.last_evaluated_key
130 |
131 | # Individual shards have their rows already in sorted order
132 | # Elements are popped off each shard to exhaustion and put into a minheap
133 | # We yield from the heap to exhaustion to provide a stream of globally sorted example_ids
134 | heap: List[Tuple[str, int, DynamoDBExampleDBRow]] = []
135 | shard_iterators = {shard_id: shard_iterator(shard_id) for shard_id in range(NUM_DYNAMODB_SHARDS)}
136 | for shard_id, iterator in shard_iterators.items():
137 | try:
138 | row = next(iterator)
139 | heapq.heappush(heap, (row.example_id, shard_id, row))
140 | except StopIteration:
141 | pass
142 | while heap:
143 | _, shard_id, row = heapq.heappop(heap)
144 | try:
145 | nextrow = next(shard_iterators[shard_id])
146 | heapq.heappush(heap, (nextrow.example_id, shard_id, nextrow))
147 | except StopIteration:
148 | pass
149 | yield MetadataDatabaseScanRow(
150 | partition=row.partition,
151 | row_data_path=row.row_data_path,
152 | row_size=row.row_size,
153 | )
154 |
--------------------------------------------------------------------------------
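Using the DynamoDB-backed metadata database only requires constructing it and handing it to a DatasetWriter. Importing the plugin requires a `dynamodb_config` block with `table_name` and `region` in the wicker config (read by `get_dynamodb_config` above); `definition` is an assumed DatasetDefinition and the row keys are illustrative.

# Sketch: pairing DynamodbMetadataDatabase with a DatasetWriter.
from wicker.core.writer import DatasetWriter
from wicker.plugins.dynamodb import DynamodbMetadataDatabase

metadata_db = DynamodbMetadataDatabase()                   # reads dynamodb_config from the wicker config
with DatasetWriter(definition, metadata_db) as writer:     # definition assumed
    writer.add_example("train", {"example_id": 0})         # row fields are illustrative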
/wicker/plugins/flyte.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import dataclasses
3 | import json
4 | import tempfile
5 | from typing import Dict, List, Type, cast
6 |
7 | try:
8 | import flytekit # type: ignore
9 | except ImportError:
10 | raise RuntimeError(
11 | "flytekit is not detected in current environment, install Wicker with extra arguments:"
12 | " `pip install wicker[flyte]`"
13 | )
14 | import pyarrow as pa # type: ignore
15 | from flytekit.extend import TypeEngine, TypeTransformer # type: ignore
16 |
17 | from wicker.core.definitions import DatasetID, DatasetPartition
18 | from wicker.core.shuffle import ShuffleJob, ShuffleJobFactory, ShuffleWorker, save_index
19 | from wicker.core.storage import S3DataStorage, S3PathFactory
20 | from wicker.core.writer import DatasetWriterBackend
21 | from wicker.plugins.dynamodb import DynamodbMetadataDatabase
22 |
23 | ###
24 | # Custom type definitions to clean up passing of data between Flyte tasks
25 | ###
26 |
27 |
28 | class ShuffleJobTransformer(TypeTransformer[ShuffleJob]):
29 | _TYPE_INFO = flytekit.BlobType(
30 | format="binary",
31 | dimensionality=flytekit.BlobType.BlobDimensionality.SINGLE,
32 | )
33 |
34 | def __init__(self):
35 | super(ShuffleJobTransformer, self).__init__(name="shufflejob-transform", t=ShuffleJob)
36 |
37 | def get_literal_type(self, t: Type[ShuffleJob]) -> flytekit.LiteralType:
38 | """
39 |         This is useful to tell the Flytekit type system which literal type ``ShuffleJob`` corresponds to.
40 |         Here we declare it as a single binary blob that Flyte should not try to introspect.
41 | """
42 | return flytekit.LiteralType(blob=self._TYPE_INFO)
43 |
44 | @staticmethod
45 | def _shuffle_jobs_to_bytes(job: ShuffleJob) -> bytes:
46 | return json.dumps(
47 | {
48 | "dataset_partition": {
49 | "dataset_id": {
50 | "name": job.dataset_partition.dataset_id.name,
51 | "version": job.dataset_partition.dataset_id.version,
52 | },
53 | "partition": job.dataset_partition.partition,
54 | },
55 | "files": job.files,
56 | }
57 | ).encode("utf-8")
58 |
59 | @staticmethod
60 | def _shuffle_jobs_from_bytes(b: bytes) -> ShuffleJob:
61 | data = json.loads(b.decode("utf-8"))
62 | return ShuffleJob(
63 | dataset_partition=DatasetPartition(
64 | dataset_id=DatasetID(
65 | name=data["dataset_partition"]["dataset_id"]["name"],
66 | version=data["dataset_partition"]["dataset_id"]["version"],
67 | ),
68 | partition=data["dataset_partition"]["partition"],
69 | ),
70 | files=[(path, size) for path, size in data["files"]],
71 | )
72 |
73 | def to_literal(
74 | self,
75 | ctx: flytekit.FlyteContext,
76 | python_val: ShuffleJob,
77 | python_type: Type[ShuffleJob],
78 | expected: flytekit.LiteralType,
79 | ) -> flytekit.Literal:
80 | """
81 |         This method is used to convert from the given Python object of type ``ShuffleJob`` to the Literal representation
82 | """
83 | # Step 1: lets upload all the data into a remote place recommended by Flyte
84 | remote_path = ctx.file_access.get_random_remote_path()
85 | with tempfile.NamedTemporaryFile() as tmpfile:
86 | tmpfile.write(self._shuffle_jobs_to_bytes(python_val))
87 | tmpfile.flush()
88 | tmpfile.seek(0)
89 | ctx.file_access.upload(tmpfile.name, remote_path)
90 | # Step 2: lets return a pointer to this remote_path in the form of a literal
91 | return flytekit.Literal(
92 | scalar=flytekit.Scalar(
93 | blob=flytekit.Blob(uri=remote_path, metadata=flytekit.BlobMetadata(type=self._TYPE_INFO))
94 | )
95 | )
96 |
97 | def to_python_value(
98 | self, ctx: flytekit.FlyteContext, lv: flytekit.Literal, expected_python_type: Type[ShuffleJob]
99 | ) -> ShuffleJob:
100 | """
101 |         In this function we re-hydrate the custom ``ShuffleJob`` object from the Flyte Literal value
102 | """
103 | # Step 1: lets download remote data locally
104 | local_path = ctx.file_access.get_random_local_path()
105 | ctx.file_access.download(lv.scalar.blob.uri, local_path)
106 | # Step 2: create the ShuffleJob object
107 | with open(local_path, "rb") as f:
108 | return self._shuffle_jobs_from_bytes(f.read())
109 |
110 |
111 | @dataclasses.dataclass
112 | class ShuffleWorkerResults:
113 | partition: str
114 | pa_table: pa.Table
115 |
116 |
117 | class ShuffleWorkerResultsTransformer(TypeTransformer[ShuffleWorkerResults]):
118 | _TYPE_INFO = flytekit.BlobType(
119 | format="binary",
120 | dimensionality=flytekit.BlobType.BlobDimensionality.SINGLE,
121 | )
122 |
123 | def __init__(self):
124 | super(ShuffleWorkerResultsTransformer, self).__init__(
125 | name="shuffleworkerresults-transform",
126 | t=ShuffleWorkerResults,
127 | )
128 |
129 | def get_literal_type(self, t: Type[ShuffleWorkerResults]) -> flytekit.LiteralType:
130 | """
131 |         This is useful to tell the Flytekit type system which literal type ``ShuffleWorkerResults``
132 |         corresponds to. Here we declare it as a single binary blob that Flyte should not
133 |         try to introspect.
134 | """
135 | return flytekit.LiteralType(blob=self._TYPE_INFO)
136 |
137 | def to_literal(
138 | self,
139 | ctx: flytekit.FlyteContext,
140 | python_val: ShuffleWorkerResults,
141 | python_type: Type[ShuffleWorkerResults],
142 | expected: flytekit.LiteralType,
143 | ) -> flytekit.Literal:
144 | """
145 |         This method is used to convert from the given Python object of type ``ShuffleWorkerResults``
146 | to the Literal representation
147 | """
148 | # Step 1: lets upload all the data into a remote place recommended by Flyte
149 | remote_path = ctx.file_access.get_random_remote_path()
150 | local_path = ctx.file_access.get_random_local_path()
151 | with pa.ipc.new_stream(
152 | local_path, python_val.pa_table.schema.with_metadata({"partition": python_val.partition})
153 | ) as stream:
154 | stream.write(python_val.pa_table)
155 | ctx.file_access.upload(local_path, remote_path)
156 | # Step 2: lets return a pointer to this remote_path in the form of a literal
157 | return flytekit.Literal(
158 | scalar=flytekit.Scalar(
159 | blob=flytekit.Blob(uri=remote_path, metadata=flytekit.BlobMetadata(type=self._TYPE_INFO))
160 | )
161 | )
162 |
163 | def to_python_value(
164 | self, ctx: flytekit.FlyteContext, lv: flytekit.Literal, expected_python_type: Type[ShuffleWorkerResults]
165 | ) -> ShuffleWorkerResults:
166 | """
167 |         In this function we re-hydrate the custom ``ShuffleWorkerResults`` object from the Flyte Literal value
168 | """
169 | # Step 1: lets download remote data locally
170 | local_path = ctx.file_access.get_random_local_path()
171 | ctx.file_access.download(lv.scalar.blob.uri, local_path)
172 | # Step 2: create the ShuffleWorkerResults object
173 | with pa.ipc.open_stream(local_path) as reader:
174 | return ShuffleWorkerResults(
175 | pa_table=pa.Table.from_batches([b for b in reader]),
176 | partition=reader.schema.metadata[b"partition"].decode("utf-8"),
177 | )
178 |
179 |
180 | TypeEngine.register(ShuffleJobTransformer())
181 | TypeEngine.register(ShuffleWorkerResultsTransformer())
182 |
183 |
184 | ###
185 | # Task and Workflow definitions
186 | ###
187 |
188 |
189 | @flytekit.task(requests=flytekit.Resources(mem="2Gi", cpu="1"), retries=2)
190 | def initialize_dataset(
191 | schema_json_str: str,
192 | dataset_id: str,
193 | ) -> str:
194 | """Write the schema to the storage."""
195 | s3_path_factory = S3PathFactory()
196 | s3_storage = S3DataStorage()
197 | schema_path = s3_path_factory.get_dataset_schema_path(DatasetID.from_str(dataset_id))
198 | s3_storage.put_object_s3(schema_json_str.encode("utf-8"), schema_path)
199 | return schema_json_str
200 |
201 |
202 | @flytekit.task(requests=flytekit.Resources(mem="8Gi", cpu="2"), retries=2)
203 | def create_shuffling_jobs(
204 | schema_json_str: str,
205 | dataset_id: str,
206 | worker_max_working_set_size: int = 16384,
207 | ) -> List[ShuffleJob]:
208 | """Read the DynamoDB and return a list of shuffling jobs to distribute.
209 |
210 | The job descriptions are stored into files managed by Flyte.
211 | :param dataset_id: string representation of the DatasetID we need to process (dataset name + version).
212 | :param schema_json_str: string representation of the dataset schema
213 |     :param worker_max_working_set_size: Maximum number of rows to assign per working set, defaults to 16384 but can be
214 | increased if dataset sizes are so large that we want to use fewer workers and don't mind the committing
215 | step taking longer per-worker.
216 | :return: a list of shuffling jobs to do.
217 | """
218 | # TODO(jchia): Dynamically decide on what backends to use for S3 and MetadataDatabase instead of hardcoding here
219 | backend = DatasetWriterBackend(S3PathFactory(), S3DataStorage(), DynamodbMetadataDatabase())
220 | job_factory = ShuffleJobFactory(backend, worker_max_working_set_size=worker_max_working_set_size)
221 | return list(job_factory.build_shuffle_jobs(DatasetID.from_str(dataset_id)))
222 |
223 |
224 | @flytekit.task(requests=flytekit.Resources(mem="8Gi", cpu="2"), retries=4, cache=True, cache_version="v1")
225 | def run_shuffling_job(job: ShuffleJob) -> ShuffleWorkerResults:
226 | """Run one shuffling job
227 | :param job: the ShuffleJob for this worker to run
228 | :return: pyarrow table containing only metadata and pointers to the ColumnBytesFiles in S3 for
229 | bytes columns in the dataset.
230 | """
231 | worker = ShuffleWorker(storage=S3DataStorage(), s3_path_factory=S3PathFactory())
232 | return ShuffleWorkerResults(
233 | pa_table=worker.process_job(job),
234 | partition=job.dataset_partition.partition,
235 | )
236 |
237 |
238 | @flytekit.task(requests=flytekit.Resources(mem="8Gi", cpu="2"), retries=2)
239 | def finalize_shuffling_jobs(dataset_id: str, shuffle_results: List[ShuffleWorkerResults]) -> Dict[str, int]:
240 | """Aggregate the indexes from the various shuffling jobs and publish the final parquet files for the dataset.
241 | :param dataset_id: string representation of the DatasetID we need to process (dataset name + version).
242 |     :param shuffle_results: list of ShuffleWorkerResults (partition name plus pyarrow table) generated by the shuffling jobs.
243 | :return: A dictionary mapping partition_name -> size_of_partition.
244 | """
245 | results_by_partition = collections.defaultdict(list)
246 | for result in shuffle_results:
247 | results_by_partition[result.partition].append(result.pa_table)
248 |
249 | tables_by_partition: Dict[str, pa.Table] = {}
250 | for partition in results_by_partition:
251 | tables_by_partition[partition] = pa.concat_tables(results_by_partition[partition])
252 |
253 | dataset_id_obj = DatasetID.from_str(dataset_id)
254 | save_index(
255 | dataset_id_obj.name,
256 | dataset_id_obj.version,
257 | tables_by_partition,
258 | s3_storage=S3DataStorage(),
259 | s3_path_factory=S3PathFactory(),
260 | )
261 | return {partition_name: len(tables_by_partition[partition_name]) for partition_name in tables_by_partition}
262 |
263 |
264 | @flytekit.workflow # type: ignore
265 | def WickerDataShufflingWorkflow(
266 | dataset_id: str,
267 | schema_json_str: str,
268 | worker_max_working_set_size: int = 16384,
269 | ) -> Dict[str, int]:
270 | """Pipeline finalizing a wicker dataset.
271 | :param dataset_id: string representation of the DatasetID we need to process (dataset name + version).
272 | :param schema_json_str: string representation of the schema, serialized as JSON
273 | :return: A dictionary mapping partition_name -> size_of_partition.
274 | """
275 | schema_json_str_committed = initialize_dataset(
276 | dataset_id=dataset_id,
277 | schema_json_str=schema_json_str,
278 | )
279 | jobs = create_shuffling_jobs(
280 | schema_json_str=schema_json_str_committed,
281 | dataset_id=dataset_id,
282 | worker_max_working_set_size=worker_max_working_set_size,
283 | )
284 | shuffle_results = flytekit.map_task(run_shuffling_job, metadata=flytekit.TaskMetadata(retries=1))(job=jobs)
285 | result = cast(Dict[str, int], finalize_shuffling_jobs(dataset_id=dataset_id, shuffle_results=shuffle_results))
286 | return result
287 |
--------------------------------------------------------------------------------
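The workflow above can be registered and launched on a Flyte deployment, or run locally through flytekit's local execution by calling it like a plain Python function. In the sketch below, `my_schema` is an assumed DatasetSchema, the rows are assumed to have been written already via a DatasetWriter backed by DynamodbMetadataDatabase, and we assume DatasetID's string form round-trips through `DatasetID.from_str`.

# Sketch: local execution of the shuffling workflow; my_schema is an assumed DatasetSchema.
from wicker.core.definitions import DatasetID
from wicker.plugins.flyte import WickerDataShufflingWorkflow
from wicker.schema import serialization

dataset_id_str = str(DatasetID(name="my_dataset", version="0.0.1"))  # parsed back by DatasetID.from_str
partition_sizes = WickerDataShufflingWorkflow(
    dataset_id=dataset_id_str,
    schema_json_str=serialization.dumps(my_schema),
)
print(partition_sizes)  # e.g. {"train": 90000, "eval": 10000}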
/wicker/plugins/spark.py:
--------------------------------------------------------------------------------
1 | """Spark plugin for writing a dataset with Spark only (no external metadata database required)
2 |
3 | This plugin does an expensive global sorting step using Spark, which could be prohibitive
4 | for large datasets.
5 | """
6 | from __future__ import annotations
7 |
8 | from typing import Any, Dict, Iterable, List, Optional, Tuple
9 |
10 | import pyarrow as pa
11 | import pyarrow.compute as pc
12 |
13 | try:
14 | import pyspark
15 | except ImportError:
16 | raise RuntimeError(
17 | "pyspark is not detected in current environment, install Wicker with extra arguments:"
18 | " `pip install wicker[spark]`"
19 | )
20 |
21 | from operator import add
22 |
23 | from wicker import schema as schema_module
24 | from wicker.core.definitions import DatasetID
25 | from wicker.core.errors import WickerDatastoreException
26 | from wicker.core.persistance import AbstractDataPersistor
27 | from wicker.core.storage import S3DataStorage, S3PathFactory
28 | from wicker.schema import serialization
29 |
30 | DEFAULT_SPARK_PARTITION_SIZE = 256
31 | MAX_COL_FILE_NUMROW = 50 # TODO(isaak-willett): Magic number, we should derive this based on row size
32 |
33 | PrimaryKeyTuple = Tuple[Any, ...]
34 | UnparsedExample = Dict[str, Any]
35 | ParsedExample = Dict[str, Any]
36 | PointerParsedExample = Dict[str, Any]
37 |
38 |
39 | def persist_wicker_dataset(
40 | dataset_name: str,
41 | dataset_version: str,
42 | dataset_schema: schema_module.DatasetSchema,
43 | rdd: pyspark.rdd.RDD[Tuple[str, UnparsedExample]],
44 | s3_storage: S3DataStorage = S3DataStorage(),
45 | s3_path_factory: S3PathFactory = S3PathFactory(),
46 | local_reduction: bool = False,
47 | sort: bool = True,
48 | partition_size=DEFAULT_SPARK_PARTITION_SIZE,
49 | ) -> Optional[Dict[str, int]]:
50 | """
51 |     Public-facing API function for persisting a wicker dataset, provided for API consistency.
52 | :param dataset_name: name of dataset persisted
53 | :type dataset_name: str
54 | :param dataset_version: version of dataset persisted
55 | :type dataset_version: str
56 | :param dataset_schema: schema of dataset to be persisted
57 | :type dataset_schema: DatasetSchema
58 | :param rdd: rdd of data to persist
59 | :type rdd: RDD
60 | :param s3_storage: s3 storage abstraction
61 | :type s3_storage: S3DataStorage
62 | :param s3_path_factory: s3 path abstraction
63 | :type s3_path_factory: S3PathFactory
64 | :param local_reduction: if true, reduce tables on main instance (no spark). This is useful if the spark reduction
65 |         is not feasible for a large dataset.
66 | :type local_reduction: bool
67 | :param sort: if true, sort the resulting table by primary keys
68 | :type sort: bool
69 | :param partition_size: partition size during the sort
70 | :type partition_size: int
71 | """
72 | return SparkPersistor(s3_storage, s3_path_factory).persist_wicker_dataset(
73 | dataset_name,
74 | dataset_version,
75 | dataset_schema,
76 | rdd,
77 | local_reduction=local_reduction,
78 | sort=sort,
79 | partition_size=partition_size,
80 | )
81 |
82 |
83 | class SparkPersistor(AbstractDataPersistor):
84 | def __init__(
85 | self,
86 | s3_storage: S3DataStorage = S3DataStorage(),
87 | s3_path_factory: S3PathFactory = S3PathFactory(),
88 | ) -> None:
89 | """
90 | Init a SparkPersistor
91 |
92 | :param s3_storage: The storage abstraction for S3
93 | :type s3_storage: S3DataStore
94 | :param s3_path_factory: The path factory for generating s3 paths
95 | based on dataset name and version
96 | :type s3_path_factory: S3PathFactory
97 | """
98 | super().__init__(s3_storage, s3_path_factory)
99 |
100 | def persist_wicker_dataset(
101 | self,
102 | dataset_name: str,
103 | dataset_version: str,
104 | schema: schema_module.DatasetSchema,
105 | rdd: pyspark.rdd.RDD[Tuple[str, UnparsedExample]],
106 | local_reduction: bool = False,
107 | sort: bool = True,
108 | partition_size=DEFAULT_SPARK_PARTITION_SIZE,
109 | ) -> Optional[Dict[str, int]]:
110 | """
111 | Persist the current rdd dataset defined by name, version, schema, and data.
112 | """
113 | # check if variables have been set ie: not None
114 | if (
115 | not isinstance(dataset_name, str)
116 | or not isinstance(dataset_version, str)
117 | or not isinstance(schema, schema_module.DatasetSchema)
118 | or not isinstance(rdd, pyspark.rdd.RDD)
119 | ):
120 |             raise ValueError("Not all dataset variables are set; provide non-None values of the expected types")
121 |
122 | # define locally for passing to spark rdd ops, breaks if relying on self
123 | # since it passes to spark engine and we lose self context
124 | s3_storage = self.s3_storage
125 | s3_path_factory = self.s3_path_factory
126 | parse_row = self.parse_row
127 | get_row_keys = self.get_row_keys
128 | persist_wicker_partition = self.persist_wicker_partition
129 | save_partition_tbl = self.save_partition_tbl
130 |
131 | # put the schema up on to s3
132 | schema_path = s3_path_factory.get_dataset_schema_path(DatasetID(name=dataset_name, version=dataset_version))
133 | s3_storage.put_object_s3(serialization.dumps(schema).encode("utf-8"), schema_path)
134 |
135 |         # parse the rows and ensure validation passes, i.e. the rows' actual data matches the expected types
136 | rdd0 = rdd.mapValues(lambda row: parse_row(row, schema)) # type: ignore
137 |
138 | # Make sure to cache the RDD to ease future computations, since it seems that sortBy and zipWithIndex
139 | # trigger actions and we want to avoid recomputing the source RDD at all costs
140 | rdd0 = rdd0.cache()
141 | dataset_size = rdd0.count()
142 |
143 | rdd1 = rdd0.keyBy(lambda row: get_row_keys(row, schema))
144 |
145 | # Sort RDD by keys
146 | rdd2: pyspark.rdd.RDD[Tuple[Tuple[Any, ...], Tuple[str, ParsedExample]]] = rdd1.sortByKey(
147 | # TODO(jchia): Magic number, we should derive this based on row size
148 | numPartitions=max(1, dataset_size // partition_size),
149 | ascending=True,
150 | )
151 |
152 | def set_partition(iterator: Iterable[PrimaryKeyTuple]) -> Iterable[int]:
153 | key_set = set(iterator)
154 | yield len(key_set)
155 |
156 | # the number of unique keys in rdd partitions
157 |         # this is a softer check than collecting all the keys from all partitions to check uniqueness
158 | rdd_key_count: int = rdd2.map(lambda x: x[0]).mapPartitions(set_partition).reduce(add)
159 | num_unique_keys = rdd_key_count
160 | if dataset_size != num_unique_keys:
161 | raise WickerDatastoreException(
162 | f"""Error: dataset examples do not have unique primary key tuples.
163 |                 Dataset has {dataset_size} examples but {num_unique_keys} unique primary keys"""
164 | )
165 |
166 | # persist the spark partition to S3Storage
167 | rdd3 = rdd2.values()
168 |
169 | rdd4 = rdd3.mapPartitions(
170 | lambda spark_iterator: persist_wicker_partition(
171 | dataset_name,
172 | spark_iterator,
173 | schema,
174 | s3_storage,
175 | s3_path_factory,
176 | target_max_column_file_numrows=MAX_COL_FILE_NUMROW,
177 | )
178 | )
179 |
180 | if not local_reduction:
181 | # combine the rdd by the keys in the pyarrow table
182 | rdd5 = rdd4.combineByKey(
183 | createCombiner=lambda data: pa.Table.from_pydict(
184 | {col: [data[col]] for col in schema.get_all_column_names()}
185 | ),
186 | mergeValue=lambda tbl, data: pa.Table.from_batches(
187 | [
188 | *tbl.to_batches(), # type: ignore
189 | *pa.Table.from_pydict({col: [data[col]] for col in schema.get_all_column_names()}).to_batches(),
190 | ]
191 | ),
192 | mergeCombiners=lambda tbl1, tbl2: pa.Table.from_batches(
193 | [
194 | *tbl1.to_batches(), # type: ignore
195 | *tbl2.to_batches(), # type: ignore
196 | ]
197 | ),
198 | )
199 |
200 | # create the partition tables
201 | if sort:
202 | rdd6 = rdd5.mapValues(
203 | lambda pa_tbl: pc.take(
204 | pa_tbl,
205 | pc.sort_indices(
206 | pa_tbl,
207 | sort_keys=[(pk, "ascending") for pk in schema.primary_keys],
208 | ),
209 | )
210 | )
211 | else:
212 | rdd6 = rdd5
213 |
214 |             # save the partition tables to s3
215 | rdd7 = rdd6.map(
216 | lambda partition_table: save_partition_tbl(
217 | partition_table, dataset_name, dataset_version, s3_storage, s3_path_factory
218 | )
219 | )
220 | rdd_list = rdd7.collect()
221 | written = {partition: size for partition, size in rdd_list}
222 | else:
223 |             # In normal operation, rdd5 may have thousands of partitions at the start of the reduction;
224 |             # however, because there are typically at most three dataset splits (train, test, val),
225 |             # the output of rdd5 has only three partitions with any data. Empirically this has led to
226 |             # issues related to driver memory. As a workaround, we can instead perform this reduction
227 |             # manually outside of the JVM without much of a performance penalty, since Spark's parallelism
228 |             # was not being taken advantage of anyway.
229 |
230 | # We are going to iterate over each partition, grab its pyarrow table, and merge them.
231 |             # To avoid toLocalIterator running rdd4 sequentially, we first cache it and trigger an action.
232 | rdd4 = rdd4.cache()
233 | _ = rdd4.count()
234 |
235 | # the rest of this is adapted from the map task persist
236 | merged_dicts: Dict[str, Dict[str, List[Any]]] = {}
237 | for partition_key, row in rdd4.toLocalIterator():
238 | current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {})
239 | for col in row.keys():
240 | if col in current_dict:
241 | current_dict[col].append(row[col])
242 | else:
243 | current_dict[col] = [row[col]]
244 | merged_dicts[partition_key] = current_dict
245 |
246 | arrow_dict = {}
247 | for partition_key, data_dict in merged_dicts.items():
248 | data_table = pa.Table.from_pydict(data_dict)
249 | if sort:
250 | arrow_dict[partition_key] = pc.take(
251 | pa.Table.from_pydict(data_dict),
252 | pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in schema.primary_keys]),
253 | )
254 | else:
255 | arrow_dict[partition_key] = data_table
256 |
257 | written = {}
258 | for partition_key, pa_table in arrow_dict.items():
259 | self.save_partition_tbl(
260 | (partition_key, pa_table), dataset_name, dataset_version, self.s3_storage, self.s3_path_factory
261 | )
262 | written[partition_key] = pa_table.num_rows
263 |
264 | return written
265 |
266 | @staticmethod
267 | def get_row_keys(
268 | partition_data_tup: Tuple[str, ParsedExample], schema: schema_module.DatasetSchema
269 | ) -> PrimaryKeyTuple:
270 | """
271 |         Get the primary key tuple of a row based on the partition tuple and the dataset schema.
272 |
273 | :param partition_data_tup: Tuple of partition id and ParsedExample row
274 | :type partition_data_tup: Tuple[str, ParsedExample]
275 | :return: Tuple of primary key values from parsed row and schema
276 | :rtype: PrimaryKeyTuple
277 | """
278 | partition, data = partition_data_tup
279 | return (partition,) + tuple(data[pk] for pk in schema.primary_keys)
280 |
--------------------------------------------------------------------------------
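The sort branch above orders each partition table by the schema's primary keys using pyarrow.compute. A minimal standalone sketch of that pattern, with made-up columns (scene_id, frame) standing in for real primary keys:

import pyarrow as pa
import pyarrow.compute as pc

# A tiny table standing in for one dataset partition.
tbl = pa.Table.from_pydict({"scene_id": ["b", "a", "c"], "frame": [2, 1, 3]})

# Mirror of the rdd6 mapValues step: compute the sorted row order, then gather rows.
sort_keys = [("scene_id", "ascending"), ("frame", "ascending")]
sorted_tbl = pc.take(tbl, pc.sort_indices(tbl, sort_keys=sort_keys))

print(sorted_tbl.to_pydict())  # {'scene_id': ['a', 'b', 'c'], 'frame': [1, 2, 3]}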
/wicker/plugins/wandb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict, Literal
3 |
4 | import wandb
5 |
6 | from wicker.core.config import get_config
7 | from wicker.core.definitions import DatasetID
8 | from wicker.core.storage import S3PathFactory
9 |
10 |
11 | def version_dataset(
12 | dataset_name: str,
13 | dataset_version: str,
14 | entity: str,
15 | metadata: Dict[str, Any],
16 | dataset_backend: Literal["s3"] = "s3",
17 | ) -> None:
18 | """
19 | Version the dataset on Weights and Biases using the config parameters defined in wickerconfig.json.
20 |
21 | Args:
22 | dataset_name: The name of the dataset to be versioned
23 | dataset_version: The version of the dataset to be versioned
24 | entity: Who the run will belong to
25 | metadata: The metadata to be logged as an artifact, enforces dataclass for metadata documentation
26 | dataset_backend: The backend where the dataset is stored, currently only supports s3
27 | """
28 | # needs to first acquire and set wandb creds
29 | # WANDB_API_KEY, WANDB_BASE_URL
30 | # _set_wandb_credentials()
31 |
32 | # needs to init the wandb run, this is going to be a 'data' run
33 | dataset_run = wandb.init(project="dataset_curation", name=f"{dataset_name}_{dataset_version}", entity=entity)
34 |
35 | # grab the uri of the dataset to be versioned
36 | dataset_uri = _identify_s3_url_for_dataset_version(dataset_name, dataset_version, dataset_backend)
37 |
38 | # establish the artifact and save the dir/s3_url to the artifact
39 | data_artifact = wandb.Artifact(f"{dataset_name}_{dataset_version}", type="dataset")
40 | data_artifact.add_reference(dataset_uri, name="dataset")
41 |
42 | # save metadata dict to the artifact
43 | data_artifact.metadata["version"] = dataset_version
44 | data_artifact.metadata["s3_uri"] = dataset_uri
45 | for key, value in metadata.items():
46 | data_artifact.metadata[key] = value
47 |
48 | # save the artifact to the run
49 | dataset_run.log_artifact(data_artifact) # type: ignore
50 | dataset_run.finish() # type: ignore
51 |
52 |
53 | def _set_wandb_credentials() -> None:
54 | """
55 | Acquire the weights and biases credentials and load them into the environment.
56 |
57 | This loads the variables into the environment as env variables for WandB to use.
58 | It overrides any previously set wandb env variables with the ones specified in
59 | the wicker config, if they exist.
60 | """
61 | # load the config
62 | config = get_config()
63 |
64 | # if the keys are present in the config add them to the env
65 | wandb_config = config.wandb_config
66 | for field in wandb_config.__dataclass_fields__: # type: ignore
67 | attr = wandb_config.__getattribute__(field)
68 | if attr is not None:
69 | os.environ[str(field).upper()] = attr
70 | else:
71 | if str(field).upper() not in os.environ:
72 | raise EnvironmentError(
73 | f"Cannot use W&B without setting {str(field.upper())}. "
74 | f"Specify in either ENV or through wicker config file."
75 | )
76 |
77 |
78 | def _identify_s3_url_for_dataset_version(
79 | dataset_name: str,
80 | dataset_version: str,
81 | dataset_backend: Literal["s3"] = "s3",
82 | ) -> str:
83 | """
84 | Identify and return the s3 url for the dataset and version specified in the backend.
85 |
86 | Args:
87 | dataset_name: name of the dataset to retrieve url
88 | dataset_version: version of the dataset to retrieve url
89 | dataset_backend: backend of the dataset to retrieve url
90 |
91 | Returns:
92 | The url pointing to the dataset on storage
93 | """
94 | schema_path = ""
95 | if dataset_backend == "s3":
96 | # needs to do the parsing work to fetch the correct s3 uri
97 | schema_path = S3PathFactory().get_dataset_assets_path(DatasetID(name=dataset_name, version=dataset_version))
98 | return schema_path
99 |
--------------------------------------------------------------------------------
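A minimal sketch of calling this plugin; the dataset name, version, entity, and metadata below are made up, and W&B credentials (WANDB_API_KEY, WANDB_BASE_URL) plus a valid wicker config are assumed to be available:

from wicker.plugins.wandb import version_dataset

# Hypothetical values for illustration only.
version_dataset(
    dataset_name="my_dataset",
    dataset_version="0.0.1",
    entity="my_wandb_team",
    metadata={"description": "example dataset run", "owner": "ml-infra"},
)
# This starts a W&B run named "my_dataset_0.0.1", attaches the dataset's S3 URI as an
# artifact reference, logs the metadata, and finishes the run.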
/wicker/schema/__init__.py:
--------------------------------------------------------------------------------
1 | from .schema import * # noqa: F401, F403
2 |
--------------------------------------------------------------------------------
/wicker/schema/codecs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | from typing import Any, Dict, Type
5 |
6 |
7 | class Codec(abc.ABC):
8 | """Base class for all object decoder/encoders.
9 |
10 | Defining Codec classes allows user to use arbitrary types in their Wicker Examples. Subclasses of Codec provide
11 | functionality to serialize/deserialize data to/from underlying storage, and can be used to abstract functionality
12 | such as compression, validation of object fields and provide adaptor functionality to other libraries (e.g. NumPy,
13 | PIL, torch).
14 |
15 | Wicker provides some Codec implementations by default for NumPy, but users can choose to define their own Codecs
16 | and use them during data-writing and data-reading.
17 | """
18 |
19 | # Maps from codec_name to codec class.
20 | codec_registry: Dict[str, Type[Codec]] = {}
21 |
22 | def __init_subclass__(cls: Type[Codec], **kwargs: Any) -> None:
23 | """Automatically register subclasses of Codec."""
24 | if cls._codec_name() in Codec.codec_registry:
25 | raise KeyError(f"Codec '{cls._codec_name()}' was already defined.")
26 | if cls._codec_name():
27 | Codec.codec_registry[cls._codec_name()] = cls
28 |
29 | @staticmethod
30 | @abc.abstractmethod
31 | def _codec_name() -> str:
32 | """Needs to be overriden to return a globally unique name for the codec."""
33 | pass
34 |
35 | def get_codec_name(self) -> str:
36 | """Accessor for _codec_name. In general, derived classes should not touch this, and should just implement
37 | _codec_name. This accessor is used internally to support generic schema serialization/deserialization.
38 | """
39 | return self._codec_name()
40 |
41 | def save_codec_to_dict(self) -> Dict[str, Any]:
42 | """If you want to save some parameters of this codec with the dataset
43 | schema, return the fields here. The returned dictionary should be JSON compatible.
44 | Note that this is a dataset-level value, not a per example value."""
45 | return {}
46 |
47 | @staticmethod
48 | @abc.abstractmethod
49 | def load_codec_from_dict(data: Dict[str, Any]) -> Codec:
50 | """Create a new instance of this codec with the given parameters."""
51 | pass
52 |
53 | @abc.abstractmethod
54 | def validate_and_encode_object(self, obj: Any) -> bytes:
55 | """Encode the given object into bytes. The function is also responsible for validating the data.
56 | :param obj: Object to encode
57 | :return: The encoded bytes for the given object."""
58 | pass
59 |
60 | @abc.abstractmethod
61 | def decode_object(self, data: bytes) -> Any:
62 | """Decode an object from the given bytes. This is the opposite of validate_and_encode_object.
63 | We expect obj == decode_object(validate_and_encode_object(obj))
64 | :param data: bytes to decode.
65 | :return: Decoded object."""
66 | pass
67 |
68 | def object_type(self) -> Type[Any]:
69 | """Return the expected type of the objects handled by this codec.
70 | This method can be overridden to match more specific classes."""
71 | return object
72 |
73 | def __eq__(self, other: Any) -> bool:
74 | return (
75 | super().__eq__(other)
76 | and self.get_codec_name() == other.get_codec_name()
77 | and self.save_codec_to_dict() == other.save_codec_to_dict()
78 | )
79 |
--------------------------------------------------------------------------------
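A minimal sketch of a user-defined codec, here a hypothetical JsonDictCodec that stores plain dicts as UTF-8 JSON bytes; simply defining the subclass triggers the __init_subclass__ registration shown above:

import json
from typing import Any, Dict, Type

from wicker.schema.codecs import Codec


class JsonDictCodec(Codec):
    """Illustrative codec that stores plain dicts as UTF-8 JSON bytes."""

    @staticmethod
    def _codec_name() -> str:
        return "JsonDictCodec"  # must be globally unique

    @staticmethod
    def load_codec_from_dict(data: Dict[str, Any]) -> Codec:
        return JsonDictCodec()  # this codec has no parameters to restore

    def validate_and_encode_object(self, obj: Dict[str, Any]) -> bytes:
        return json.dumps(obj).encode("utf-8")

    def decode_object(self, data: bytes) -> Dict[str, Any]:
        return json.loads(data.decode("utf-8"))

    def object_type(self) -> Type[Any]:
        return dict


# Defining the subclass is enough to register it in the codec registry.
assert Codec.codec_registry["JsonDictCodec"] is JsonDictCodec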
/wicker/schema/dataloading.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, TypeVar
2 |
3 | from wicker.schema import schema, validation
4 |
5 | T = TypeVar("T")
6 |
7 |
8 | def load_example(example: validation.AvroRecord, schema: schema.DatasetSchema) -> Dict[str, Any]:
9 | """Loads an example according to the provided schema, converting data based on their column
10 | information into their corresponding in-memory representations (e.g. numpy arrays, torch tensors)
11 |
12 | The returned dictionary will be in the same shape as defined in the schema, with keys
13 | corresponding to the schema names, and values validated/transformed against the schema's fields.
14 |
15 | :param example: example to load
16 | :type example: validation.AvroRecord
17 | :param schema: schema to load against
18 | :type schema: schema.DatasetSchema
19 | :return: loaded example
20 | :rtype: Dict[str, Any]
21 | """
22 | load_example_visitor = LoadExampleVisitor(example, schema)
23 | return load_example_visitor.load_example()
24 |
25 |
26 | class LoadExampleVisitor(schema.DatasetSchemaVisitor[Any]):
27 | """schema.DatasetSchemaVisitor class that will validate and load an Example
28 | in accordance to a provided schema.DatasetSchema
29 | """
30 |
31 | def __init__(self, example: validation.AvroRecord, schema: schema.DatasetSchema):
32 | # Pointers to original data (data should be kept immutable)
33 | self._schema = schema
34 | self._example = example
35 |
36 | # Pointers to help keep visitor state during tree traversal
37 | self._current_data: Any = self._example
38 | self._current_path: Tuple[str, ...] = tuple()
39 |
40 | def load_example(self) -> Dict[str, Any]:
41 | """Loads an example from its Avro format into its in-memory representations"""
42 | # Since the original input example is non-None, the loaded example will be non-None also
43 | example: Dict[str, Any] = self._schema.schema_record._accept_visitor(self)
44 | return example
45 |
46 | def process_record_field(self, field: schema.RecordField) -> Optional[validation.AvroRecord]:
47 | """Visit an schema.RecordField schema field"""
48 | current_data = validation.validate_dict(self._current_data, field.required, self._current_path)
49 | if current_data is None:
50 | return current_data
51 |
52 | # Process nested fields by setting up the visitor's state and visiting each node
53 | processing_path = self._current_path
54 | processing_example = current_data
55 | loaded = {}
56 |
57 | # When reading records, the client might restrict the columns to load to a subset of the
58 | # full columns, so check if the key is actually present in the example being processed
59 | for nested_field in field.fields:
60 | if nested_field.name in processing_example:
61 | self._current_path = processing_path + (nested_field.name,)
62 | self._current_data = processing_example[nested_field.name]
63 | loaded[nested_field.name] = nested_field._accept_visitor(self)
64 | return loaded
65 |
66 | def process_int_field(self, field: schema.IntField) -> Optional[int]:
67 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path)
68 |
69 | def process_long_field(self, field: schema.LongField) -> Optional[int]:
70 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path)
71 |
72 | def process_string_field(self, field: schema.StringField) -> Optional[str]:
73 | return validation.validate_field_type(self._current_data, str, field.required, self._current_path)
74 |
75 | def process_bool_field(self, field: schema.BoolField) -> Optional[bool]:
76 | return validation.validate_field_type(self._current_data, bool, field.required, self._current_path)
77 |
78 | def process_float_field(self, field: schema.FloatField) -> Optional[float]:
79 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path)
80 |
81 | def process_double_field(self, field: schema.DoubleField) -> Optional[float]:
82 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path)
83 |
84 | def process_object_field(self, field: schema.ObjectField) -> Optional[Any]:
85 | data = validation.validate_field_type(self._current_data, bytes, field.required, self._current_path)
86 | if data is None:
87 | return data
88 | return field.codec.decode_object(data)
89 |
90 | def process_array_field(self, field: schema.ArrayField) -> Optional[List[Any]]:
91 | current_data = validation.validate_field_type(self._current_data, list, field.required, self._current_path)
92 | if current_data is None:
93 | return current_data
94 |
95 | # Process array elements by setting up the visitor's state and visiting each element
96 | processing_path = self._current_path
97 | loaded = []
98 |
99 | # Arrays may contain None values if the element field declares that it is not required
100 | for element_index, element in enumerate(current_data):
101 | self._current_path = processing_path + (f"elem[{element_index}]",)
102 | self._current_data = element
103 | loaded.append(field.element_field._accept_visitor(self))
104 | return loaded
105 |
--------------------------------------------------------------------------------
/wicker/schema/dataparsing.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple
2 |
3 | from wicker.core.definitions import ExampleMetadata
4 | from wicker.core.errors import WickerSchemaException
5 | from wicker.schema import schema, validation
6 |
7 |
8 | def parse_example(example: Dict[str, Any], schema: schema.DatasetSchema) -> validation.AvroRecord:
9 | """Parses an example according to the provided schema, converting known types such as
10 | numpy arrays, torch tensors etc into bytes for storage on disk
11 |
12 | The returned dictionary will be in the same shape as defined in the schema, with keys
13 | corresponding to the schema names, and values validated against the schema's fields.
14 |
15 | :param example: example to parse
16 | :type example: Dict[str, Any]
17 | :param schema: schema to parse against
18 | :type schema: schema.DatasetSchema
19 | :return: parsed example
20 | :rtype: Dict[str, Any]
21 | """
22 | parse_example_visitor = ParseExampleVisitor(example, schema)
23 | return parse_example_visitor.parse_example()
24 |
25 |
26 | def parse_example_metadata(example: Dict[str, Any], schema: schema.DatasetSchema) -> ExampleMetadata:
27 | """Parses ExampleMetadata from an example according to the provided schema
28 |
29 | The returned dictionary will be in the same shape as defined in the schema, with keys
30 | corresponding to the schema names, and values validated against the schema's fields.
31 |
32 | Certain field types will be ignored when parsing examples as they are not considered part of the example metadata:
33 |
34 | 1. BytesField
35 | ... More to be implemented (e.g. AvroNumpyField etc)
36 |
37 | :param example: example to parse
38 | :type example: Dict[str, Any]
39 | :param schema: schema to parse against
40 | :type schema: schema.DatasetSchema
41 | """
42 | parse_example_metadata_visitor = ParseExampleMetadataVisitor(example, schema)
43 | return parse_example_metadata_visitor.parse_example()
44 |
45 |
46 | # A special exception to indicate that a field should be skipped
47 | class _SkipFieldException(Exception):
48 | pass
49 |
50 |
51 | class ParseExampleVisitor(schema.DatasetSchemaVisitor[Any]):
52 | """schema.DatasetSchemaVisitor class that will validate and transform an Example
53 | in accordance to a provided schema.DatasetSchema
54 |
55 | The original example is not modified in-place, and we incur the cost of copying
56 | primitives (e.g. bytes, strings)
57 | """
58 |
59 | def __init__(self, example: Dict[str, Any], schema: schema.DatasetSchema):
60 | # Pointers to original data (data should be kept immutable)
61 | self._schema = schema
62 | self._example = example
63 |
64 | # Pointers to help keep visitor state during tree traversal
65 | self._current_data: Any = self._example
66 | self._current_path: Tuple[str, ...] = tuple()
67 |
68 | def parse_example(self) -> Dict[str, Any]:
69 | """Parses an example into a form that is suitable for storage in an Avro format"""
70 | # Since the original input example is non-None, the parsed example will be non-None also
71 | example: Dict[str, Any] = self._schema.schema_record._accept_visitor(self)
72 | return example
73 |
74 | def process_record_field(self, field: schema.RecordField) -> Optional[validation.AvroRecord]:
75 | """Visit an schema.RecordField schema field"""
76 | val = validation.validate_field_type(self._current_data, dict, field.required, self._current_path)
77 | if val is None:
78 | return val
79 |
80 | # Add keys to the example for any non-required fields that were left unset in the raw data
81 | for optional_field in [f for f in field.fields if not f.required]:
82 | if optional_field.name not in val:
83 | val[optional_field.name] = None
84 |
85 | # Validate that data matches schema exactly
86 | schema_key_names = {nested_field.name for nested_field in field.fields}
87 | if val.keys() != schema_key_names:
88 | raise WickerSchemaException(
89 | f"Error at path {'.'.join(self._current_path)}: "
90 | f"Example missing keys: {list(schema_key_names - val.keys())} "
91 | f"and has extra keys: {list(val.keys() - schema_key_names)}"
92 | )
93 |
94 | # Process nested fields by setting up the visitor's state and visiting each node
95 | res = {}
96 | processing_path = self._current_path
97 | processing_example = self._current_data
98 |
99 | for nested_field in field.fields:
100 | self._current_path = processing_path + (nested_field.name,)
101 | self._current_data = processing_example[nested_field.name]
102 | try:
103 | res[nested_field.name] = nested_field._accept_visitor(self)
104 | except _SkipFieldException:
105 | pass
106 | return res
107 |
108 | def process_int_field(self, field: schema.IntField) -> Optional[int]:
109 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path)
110 |
111 | def process_long_field(self, field: schema.LongField) -> Optional[int]:
112 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path)
113 |
114 | def process_string_field(self, field: schema.StringField) -> Optional[str]:
115 | return validation.validate_field_type(self._current_data, str, field.required, self._current_path)
116 |
117 | def process_bool_field(self, field: schema.BoolField) -> Optional[bool]:
118 | return validation.validate_field_type(self._current_data, bool, field.required, self._current_path)
119 |
120 | def process_float_field(self, field: schema.FloatField) -> Optional[float]:
121 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path)
122 |
123 | def process_double_field(self, field: schema.DoubleField) -> Optional[float]:
124 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path)
125 |
126 | def process_object_field(self, field: schema.ObjectField) -> Optional[bytes]:
127 | data = validation.validate_field_type(
128 | self._current_data, field.codec.object_type(), field.required, self._current_path
129 | )
130 | if data is None:
131 | return None
132 | return field.codec.validate_and_encode_object(data)
133 |
134 | def process_array_field(self, field: schema.ArrayField) -> Optional[List[Any]]:
135 | val = validation.validate_field_type(self._current_data, list, field.required, self._current_path)
136 | if val is None:
137 | return val
138 |
139 | # Process array elements by setting up the visitor's state and visiting each element
140 | res = []
141 | processing_path = self._current_path
142 | processing_example = self._current_data
143 |
144 | # Arrays may contain None values if the element field declares that it is not required
145 | for element_index, element in enumerate(processing_example):
146 | self._current_path = processing_path + (f"elem[{element_index}]",)
147 | self._current_data = element
148 | # Allow _SkipFieldExceptions to propagate up and skip this array field
149 | res.append(field.element_field._accept_visitor(self))
150 | return res
151 |
152 |
153 | class ParseExampleMetadataVisitor(ParseExampleVisitor):
154 | """Specialization of ParseExampleVisitor which skips over certain fields that are now parsed as metadata"""
155 |
156 | def process_object_field(self, field: schema.ObjectField) -> Optional[bytes]:
157 | # Raise a special error to indicate that this field should be skipped
158 | raise _SkipFieldException()
159 |
--------------------------------------------------------------------------------
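A minimal parse/load round-trip sketch for a flat, primitive-typed example. The field and DatasetSchema constructor arguments (name, description, required, primary_keys) are assumed from how serialization.loads() builds schemas, and scene_id/frame are made-up columns:

from wicker.schema import dataloading, dataparsing, schema

ds_schema = schema.DatasetSchema(
    fields=[
        schema.StringField(name="scene_id", description="Unique scene identifier", required=True),
        schema.IntField(name="frame", description="Frame index within the scene", required=True),
    ],
    primary_keys=["scene_id", "frame"],
)

raw = {"scene_id": "scene_0001", "frame": 7}
stored = dataparsing.parse_example(raw, ds_schema)    # validated, storage-ready form
loaded = dataloading.load_example(stored, ds_schema)  # back to the in-memory form
assert loaded == raw  # primitive fields pass through unchanged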
/wicker/schema/serialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Any, Dict, Type, Union
4 |
5 | from wicker.core.errors import WickerSchemaException
6 | from wicker.schema import codecs, schema
7 | from wicker.schema.schema import PRIMARY_KEYS_TAG
8 |
9 | JSON_SCHEMA_VERSION = 2
10 | DATASET_ID_REGEX = re.compile(r"^(?P<name>[0-9a-zA-Z_]+)/v(?P<version>[0-9]+\.[0-9]+\.[0-9]+)$")
11 |
12 |
13 | def dumps(schema: schema.DatasetSchema, pretty: bool = True) -> str:
14 | """Dumps a schema as JSON
15 |
16 | :param schema: schema to dump
17 | :type schema: schema.DatasetSchema
19 | :param pretty: whether to dump the schema as a prettified JSON, defaults to True
19 | :type pretty: bool, optional
20 | :return: JSON string
21 | :rtype: str
22 | """
23 | visitor = AvroDatasetSchemaSerializer()
24 | jdata = schema.schema_record._accept_visitor(visitor)
25 | jdata.update(jdata["type"])
26 | jdata["_json_version"] = JSON_SCHEMA_VERSION
27 | if pretty:
28 | return json.dumps(jdata, sort_keys=True, indent=4)
29 | return json.dumps(jdata)
30 |
31 |
32 | def loads(schema_str: str, treat_objects_as_bytes: bool = False) -> schema.DatasetSchema:
33 | """Loads a DatasetSchema from a JSON str
34 |
35 | :param schema_str: JSON string
36 | :param treat_objects_as_bytes: If set, don't load codecs for object types; instead, just treat them as bytes.
37 | This is useful for code that needs to work with datasets in a generic way, but does not need to actually decode
38 | the data.
39 | :return: Deserialized DatasetSchema object
40 | """
41 | # Parse string as JSON
42 | try:
43 | schema_dict = json.loads(schema_str)
44 | except json.decoder.JSONDecodeError:
45 | raise WickerSchemaException(f"Unable to load string as JSON: {schema_str}")
46 |
47 | # Construct the DatasetSchema
48 | try:
49 | fields = [_loads(d, treat_objects_as_bytes=treat_objects_as_bytes) for d in schema_dict["fields"]]
50 | return schema.DatasetSchema(
51 | fields=fields,
52 | primary_keys=json.loads(schema_dict.get(PRIMARY_KEYS_TAG, "[]")),
53 | allow_empty_primary_keys=True, # For backward compatibility. Clean me AVSW-78939.
54 | )
55 | except KeyError as err:
56 | raise WickerSchemaException(f"Malformed serialization of DatasetSchema: {err}")
57 |
58 |
59 | def _loads(schema_dict: Dict[str, Any], treat_objects_as_bytes: bool) -> schema.SchemaField:
60 | """Recursively parses a schema dictionary into the appropriate SchemaField"""
61 | type_ = schema_dict["type"]
62 | required = True
63 |
64 | # Nullable types are represented in Avro schemas as Union (list) types with "null" as the
65 | # first element (by convention)
66 | if isinstance(type_, list):
67 | assert type_[0] == "null"
68 | required = False
69 | type_ = type_[1]
70 |
71 | if isinstance(type_, dict) and type_["type"] == "record":
72 | return schema.RecordField(
73 | name=schema_dict["name"],
74 | fields=[_loads(f, treat_objects_as_bytes=treat_objects_as_bytes) for f in type_["fields"]],
75 | description=schema_dict["_description"],
76 | required=required,
77 | )
78 | if isinstance(type_, dict) and type_["type"] == "array":
79 | # ArrayField columns are limited to contain only dicts or simple types (and not nested arrays)
80 | element_dict = type_.copy()
81 | element_dict["type"] = type_["items"]
82 | element_field = _loads(element_dict, treat_objects_as_bytes=treat_objects_as_bytes)
83 | return schema.ArrayField(
84 | element_field=element_field,
85 | required=required,
86 | )
87 | return _loads_base_types(type_, required, schema_dict, treat_objects_as_bytes=treat_objects_as_bytes)
88 |
89 |
90 | def _loads_base_types(
91 | type_: str, required: bool, schema_dict: Dict[str, Any], treat_objects_as_bytes: bool = False
92 | ) -> schema.SchemaField:
93 | if type_ == "int":
94 | return schema.IntField(
95 | name=schema_dict["name"],
96 | description=schema_dict["_description"],
97 | required=required,
98 | )
99 | elif type_ == "long":
100 | return schema.LongField(
101 | name=schema_dict["name"],
102 | description=schema_dict["_description"],
103 | required=required,
104 | )
105 | elif type_ == "string":
106 | return schema.StringField(
107 | name=schema_dict["name"],
108 | description=schema_dict["_description"],
109 | required=required,
110 | )
111 | elif type_ == "boolean":
112 | return schema.BoolField(
113 | name=schema_dict["name"],
114 | description=schema_dict["_description"],
115 | required=required,
116 | )
117 | elif type_ == "float":
118 | return schema.FloatField(
119 | name=schema_dict["name"],
120 | description=schema_dict["_description"],
121 | required=required,
122 | )
123 | elif type_ == "double":
124 | return schema.DoubleField(
125 | name=schema_dict["name"],
126 | description=schema_dict["_description"],
127 | required=required,
128 | )
129 | elif type_ == "bytes":
130 | l5ml_metatype = schema_dict["_l5ml_metatype"]
131 | if l5ml_metatype == "object":
132 | codec_name = schema_dict["_codec_name"]
133 | if treat_objects_as_bytes:
134 | codec: codecs.Codec = _PassThroughObjectCodec(codec_name, json.loads(schema_dict["_codec_params"]))
135 | else:
136 | try:
137 | codec_cls: Type[codecs.Codec] = codecs.Codec.codec_registry[codec_name]
138 | except KeyError:
139 | raise WickerSchemaException(
140 | f"Could not find a registered codec with name {codec_name} "
141 | f"for field {schema_dict['name']}. Please define a subclass of ObjectField.Codec and define "
142 | "the codec_name static method."
143 | )
144 | codec = codec_cls.load_codec_from_dict(json.loads(schema_dict["_codec_params"]))
145 | return schema.ObjectField(
146 | name=schema_dict["name"],
147 | codec=codec,
148 | description=schema_dict["_description"],
149 | required=required,
150 | is_heavy_pointer=schema_dict.get("_is_heavy_pointer", False),
151 | )
152 | elif l5ml_metatype == "numpy":
153 | return schema.NumpyField(
154 | name=schema_dict["name"],
155 | description=schema_dict["_description"],
156 | is_heavy_pointer=schema_dict.get("_is_heavy_pointer", False),
157 | required=required,
158 | shape=schema_dict["_shape"],
159 | dtype=schema_dict["_dtype"],
160 | )
161 | elif l5ml_metatype == "bytes":
162 | return schema.BytesField(
163 | name=schema_dict["name"],
164 | description=schema_dict["_description"],
165 | required=required,
166 | is_heavy_pointer=schema_dict.get("_is_heavy_pointer", False),
167 | )
168 | raise WickerSchemaException(f"Unhandled _l5ml_metatype for avro bytes type: {l5ml_metatype}")
169 | elif type_ == "record":
170 | schema_fields = []
171 | for schema_field in schema_dict["fields"]:
172 | loaded_field = _loads_base_types(schema_field["type"], True, schema_field)
173 | schema_fields.append(loaded_field)
174 | return schema.RecordField(
175 | fields=schema_fields,
176 | name=schema_dict["name"],
177 | )
178 | raise WickerSchemaException(f"Unhandled type: {type_}")
179 |
180 |
181 | class _PassThroughObjectCodec(codecs.Codec):
182 | """The _PassThroughObjectCodec class is a placeholder for any object codec, it is used when we need to parse
183 | any possible schema properly, but we don't need to read the data. This codec does not decode/encode the data,
184 | instead is just acts as an identity function. However, it keeps track of all the attributes of the original codec
185 | as they were stored in the loaded schema, so that if we save that schema again, we do not lose any information.
186 | """
187 |
188 | def __init__(self, codec_name: str, codec_attributes: Dict[str, Any]):
189 | self._codec_name_value = codec_name
190 | self._codec_attributes = codec_attributes
191 |
192 | @staticmethod
193 | def _codec_name() -> str:
194 | # The static method does not make sense here, because this class is a placeholder for any codec class.
195 | # Instead we override the get_codec_name method.
196 | return ""
197 |
198 | def get_codec_name(self) -> str:
199 | return self._codec_name_value
200 |
201 | def save_codec_to_dict(self) -> Dict[str, Any]:
202 | return self._codec_attributes
203 |
204 | @staticmethod
205 | def load_codec_from_dict(data: Dict[str, Any]) -> codecs.Codec:
206 | pass
207 |
208 | def validate_and_encode_object(self, obj: bytes) -> bytes:
209 | return obj
210 |
211 | def decode_object(self, data: bytes) -> bytes:
212 | return data
213 |
214 | def object_type(self) -> Type[Any]:
215 | return bytes
216 |
217 |
218 | class AvroDatasetSchemaSerializer(schema.DatasetSchemaVisitor[Dict[str, Any]]):
219 | """A visitor class that serializes an AvroDatasetSchema as Avro-compatible JSON"""
220 |
221 | def process_schema_field(self, field: schema.SchemaField, avro_type: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
222 | """Common processing across all field types"""
223 | # The way to declare nullable fields in Avro schemas is to declare Avro Union (list) types.
224 | # The default type (usually null) should be listed first:
225 | # https://avro.apache.org/docs/current/spec.html#Unions
226 | field_type = avro_type if field.required else ["null", avro_type]
227 | return {
228 | "name": field.name,
229 | "type": field_type,
230 | "_description": field.description,
231 | **field.custom_field_tags,
232 | }
233 |
234 | def process_record_field(self, field: schema.RecordField) -> Dict[str, Any]:
235 | record_type = {
236 | "type": "record",
237 | "fields": [nested_field._accept_visitor(self) for nested_field in field.fields],
238 | "name": field.name,
239 | }
240 | return self.process_schema_field(field, record_type)
241 |
242 | def process_int_field(self, field: schema.IntField) -> Dict[str, Any]:
243 | return {
244 | **self.process_schema_field(field, "int"),
245 | }
246 |
247 | def process_long_field(self, field: schema.LongField) -> Dict[str, Any]:
248 | return {
249 | **self.process_schema_field(field, "long"),
250 | }
251 |
252 | def process_string_field(self, field: schema.StringField) -> Dict[str, Any]:
253 | return {
254 | **self.process_schema_field(field, "string"),
255 | }
256 |
257 | def process_bool_field(self, field: schema.BoolField) -> Dict[str, Any]:
258 | return {
259 | **self.process_schema_field(field, "boolean"),
260 | }
261 |
262 | def process_float_field(self, field: schema.FloatField) -> Dict[str, Any]:
263 | return {
264 | **self.process_schema_field(field, "float"),
265 | }
266 |
267 | def process_double_field(self, field: schema.DoubleField) -> Dict[str, Any]:
268 | return {
269 | **self.process_schema_field(field, "double"),
270 | }
271 |
272 | def process_object_field(self, field: schema.ObjectField) -> Dict[str, Any]:
273 | return {
274 | **self.process_schema_field(field, "bytes"),
275 | "_l5ml_metatype": "object",
276 | "_is_heavy_pointer": field.is_heavy_pointer,
277 | }
278 |
279 | def process_array_field(self, field: schema.ArrayField) -> Dict[str, Any]:
280 | array_type = field.element_field._accept_visitor(self)
281 | array_type["items"] = array_type["type"]
282 | array_type["type"] = "array"
283 |
284 | field_type = array_type if field.required else ["null", array_type]
285 | return {
286 | "name": field.name,
287 | "type": field_type,
288 | }
289 |
--------------------------------------------------------------------------------
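A minimal dumps/loads round-trip sketch, reusing the assumed constructor arguments from the sketch after dataparsing.py; it also assumes the primary keys survive the PRIMARY_KEYS_TAG path described above:

from wicker.schema import schema, serialization

ds_schema = schema.DatasetSchema(
    fields=[
        schema.StringField(name="scene_id", description="Unique scene identifier", required=True),
        schema.IntField(name="frame", description="Frame index within the scene", required=True),
    ],
    primary_keys=["scene_id", "frame"],
)

schema_json = serialization.dumps(ds_schema)  # pretty-printed JSON by default (pretty=True)
restored = serialization.loads(schema_json)

assert sorted(restored.get_all_column_names()) == sorted(ds_schema.get_all_column_names())
assert restored.primary_keys == ["scene_id", "frame"]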
/wicker/schema/validation.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Type, TypeVar
2 |
3 | from wicker.core.errors import WickerSchemaException
4 |
5 | T = TypeVar("T")
6 |
7 |
8 | # Unfortunately, mypy has weak support for recursive typing so we resort to Any here, but Any
9 | # should really be an alias for: Union[AvroRecord, str, bool, int, float, bytes]
10 | AvroRecord = Dict[str, Any]
11 |
12 |
13 | def validate_field_type(val: Any, type_: Type[T], required: bool, current_path: Tuple[str, ...]) -> Optional[T]:
14 | """Validates the type of a field
15 |
16 | :param val: value to validate
17 | :type val: Any
18 | :param type_: type to validate against
19 | :type type_: Type[T]
20 | :param required: whether or not the value is required (to be non-None)
21 | :type required: bool
22 | :param current_path: current parsing path
23 | :type current_path: Tuple[str, ...]
24 | :raises WickerSchemaException: when parsing error occurs
25 | :return: val, but validated to be of type T
26 | :rtype: T
27 | """
28 | if val is None:
29 | if not required:
30 | return val
31 | raise WickerSchemaException(
32 | f"Error at path {'.'.join(current_path)}: Example provided a None value for required field"
33 | )
34 | elif not isinstance(val, type_):
35 | raise WickerSchemaException(
36 | f"Error at path {'.'.join(current_path)}: Example provided a {type(val)} value, expected {type_}"
37 | )
38 | return val
39 |
40 |
41 | def validate_dict(val: Any, required: bool, current_path: Tuple[str, ...]) -> Optional[Dict[str, Any]]:
42 | """Validates a dictionary
43 |
44 | :param val: incoming value to validate as a dictionary
45 | :type val: Any
46 | :param required: whether or not the value is required (to be non-None)
47 | :type required: bool
48 | :param current_path: current parsing path
49 | :type current_path: Tuple[str, ...]
50 | :return: parsed dictionary
51 | :rtype: dict
52 | """
53 | if val is None:
54 | if not required:
55 | return val
56 | raise WickerSchemaException(
57 | f"Error at path {'.'.join(current_path)}: Example provided a None value for required field"
58 | )
59 | # PyArrow returns record fields as lists of (k, v) tuples
60 | elif isinstance(val, list):
61 | try:
62 | parsed_val = dict(val)
63 | except ValueError:
64 | raise WickerSchemaException(f"Error at path {'.'.join(current_path)}: Unable to convert list to dict")
65 | elif isinstance(val, dict):
66 | parsed_val = val
67 | else:
68 | raise WickerSchemaException(f"Error at path {'.'.join(current_path)}: Unable to convert object to dict")
69 |
70 | return validate_field_type(parsed_val, dict, required, current_path)
71 |
--------------------------------------------------------------------------------
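A few illustrative calls against these helpers; the current_path tuples are arbitrary labels used only in error messages:

from wicker.core.errors import WickerSchemaException
from wicker.schema.validation import validate_dict, validate_field_type

# A value of the expected type is returned unchanged.
assert validate_field_type(42, int, required=True, current_path=("frame",)) == 42

# None is accepted only for non-required fields.
assert validate_field_type(None, int, required=False, current_path=("frame",)) is None

# PyArrow-style lists of (key, value) tuples are coerced into a dict.
assert validate_dict([("a", 1), ("b", 2)], required=True, current_path=("rec",)) == {"a": 1, "b": 2}

# Type mismatches raise WickerSchemaException with the offending path in the message.
try:
    validate_field_type("not-an-int", int, required=True, current_path=("frame",))
except WickerSchemaException as err:
    print(err)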
/wicker/testing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/testing/__init__.py
--------------------------------------------------------------------------------
/wicker/testing/codecs.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, Dict, List, Type
3 |
4 | from wicker.schema import codecs
5 |
6 |
7 | class Vector:
8 | def __init__(self, data: List[int]):
9 | self.data = data
10 |
11 | def __eq__(self, other: Any) -> bool:
12 | return super().__eq__(other) and self.data == other.data
13 |
14 |
15 | class VectorCodec(codecs.Codec):
16 | def __init__(self, compression_method: int) -> None:
17 | self.compression_method = compression_method
18 |
19 | @staticmethod
20 | def _codec_name() -> str:
21 | return "VectorCodec"
22 |
23 | def save_codec_to_dict(self) -> Dict[str, Any]:
24 | return {"compression_method": self.compression_method}
25 |
26 | @staticmethod
27 | def load_codec_from_dict(data: Dict[str, Any]) -> codecs.Codec:
28 | return VectorCodec(compression_method=data["compression_method"])
29 |
30 | def validate_and_encode_object(self, obj: Vector) -> bytes:
31 | # Inefficient but simple encoding method for testing.
32 | return json.dumps(obj.data).encode("utf-8")
33 |
34 | def decode_object(self, data: bytes) -> Vector:
35 | return Vector(json.loads(data.decode()))
36 |
37 | def object_type(self) -> Type[Any]:
38 | return Vector
39 |
--------------------------------------------------------------------------------
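A minimal encode/decode round trip with the test codec above; compression_method=0 is an arbitrary test value:

from wicker.testing.codecs import Vector, VectorCodec

codec = VectorCodec(compression_method=0)

# Encode validates the object and returns bytes; decode is the inverse.
payload = codec.validate_and_encode_object(Vector([1, 2, 3]))
assert isinstance(payload, bytes)
assert codec.decode_object(payload).data == [1, 2, 3]

# Codec parameters are stored once per dataset schema, not per example.
restored = VectorCodec.load_codec_from_dict(codec.save_codec_to_dict())
assert restored.compression_method == 0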
/wicker/testing/storage.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | from typing import Any, Dict
5 |
6 | import pyarrow.fs as pafs
7 |
8 | from wicker.core.storage import S3DataStorage
9 |
10 |
11 | class FakeS3DataStorage(S3DataStorage):
12 | def __init__(self, tmpdir: str = "/tmp") -> None:
13 | self._tmpdir = tmpdir
14 |
15 | def __getstate__(self) -> Dict[Any, Any]:
16 | return {"tmpdir": self._tmpdir}
17 |
18 | def __setstate__(self, state: Dict[Any, Any]) -> None:
19 | self._tmpdir = state["tmpdir"]
20 | return
21 |
22 | def _get_local_path(self, path: str) -> str:
23 | return os.path.join(self._tmpdir, path.replace("s3://", ""))
24 |
25 | def check_exists_s3(self, input_path: str) -> bool:
26 | return os.path.exists(self._get_local_path(input_path))
27 |
28 | def fetch_obj_s3(self, input_path: str) -> bytes:
29 | if not self.check_exists_s3(input_path):
30 | raise KeyError(f"File {input_path} not found in the fake s3 storage.")
31 | with open(self._get_local_path(input_path), "rb") as f:
32 | return f.read()
33 |
34 | def fetch_file(self, input_path: str, local_prefix: str, timeout_seconds: int = 120) -> str:
35 | if not self.check_exists_s3(input_path):
36 | raise KeyError(f"File {input_path} not found in the fake s3 storage.")
37 | bucket, key = self.bucket_key_from_s3_path(input_path)
38 | dest_path = os.path.join(local_prefix, key)
39 | os.makedirs(os.path.dirname(dest_path), exist_ok=True)
40 | if not os.path.isfile(dest_path):
41 | shutil.copy2(self._get_local_path(input_path), dest_path)
42 | return dest_path
43 |
44 | def put_object_s3(self, object_bytes: bytes, s3_path: str) -> None:
45 | full_tmp_path = self._get_local_path(s3_path)
46 | os.makedirs(os.path.dirname(full_tmp_path), exist_ok=True)
47 | with open(full_tmp_path, "wb") as f:
48 | f.write(object_bytes)
49 |
50 | def put_file_s3(self, local_path: str, s3_path: str) -> None:
51 | full_tmp_path = self._get_local_path(s3_path)
52 | os.makedirs(os.path.dirname(full_tmp_path), exist_ok=True)
53 | shutil.copy2(local_path, full_tmp_path)
54 |
55 |
56 | class LocalDataStorage(S3DataStorage):
57 | def __init__(self, root_path: str):
58 | super().__init__()
59 | self._root_path = Path(root_path)
60 | self._fs = pafs.LocalFileSystem()
61 |
62 | @property
63 | def filesystem(self) -> pafs.FileSystem:
64 | return self._fs
65 |
66 | def _create_path(self, path: str) -> None:
67 | """Ensures the given path exists."""
68 | self._fs.create_dir(path, recursive=True)
69 |
70 | # Override.
71 | def check_exists_s3(self, input_path: str) -> bool:
72 | file_info = self._fs.get_file_info(input_path)
73 | return file_info.type != pafs.FileType.NotFound
74 |
75 | # Override.
76 | def fetch_file(self, input_path: str, local_prefix: str, timeout_seconds: int = 120) -> str:
77 | # This raises if the input path is not relative to the root.
78 | relative_input_path = Path(input_path).relative_to(self._root_path)
79 |
80 | target_path = os.path.join(local_prefix, str(relative_input_path))
81 | self._create_path(os.path.dirname(target_path))
82 | self._fs.copy_file(input_path, target_path)
83 | return target_path
84 |
85 | # Override.
86 | def fetch_partial_file_s3(
87 | self, input_path: str, local_prefix: str, offset: int, size: int, timeout_seconds: int = 120
88 | ) -> str:
89 | raise NotImplementedError("fetch_partial_file_s3")
90 |
91 | # Override.
92 | def put_object_s3(self, object_bytes: bytes, s3_path: str) -> None:
93 | self._create_path(os.path.dirname(s3_path))
94 | with self._fs.open_output_stream(s3_path) as ostream:
95 | ostream.write(object_bytes)
96 |
97 | # Override.
98 | def put_file_s3(self, local_path: str, s3_path: str) -> None:
99 | self._create_path(os.path.dirname(s3_path))
100 | self._fs.copy_file(local_path, s3_path)
101 |
--------------------------------------------------------------------------------
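A minimal sketch of exercising FakeS3DataStorage in a test; the bucket and key are made up, and the same environment assumptions as the repo's own tests (e.g. WICKER_CONFIG_PATH) apply:

import tempfile

from wicker.testing.storage import FakeS3DataStorage

with tempfile.TemporaryDirectory() as tmpdir:
    storage = FakeS3DataStorage(tmpdir=tmpdir)

    # Objects written to s3:// paths land under tmpdir instead of AWS.
    storage.put_object_s3(b"hello wicker", "s3://my-bucket/datasets/example/part-0.bin")
    assert storage.check_exists_s3("s3://my-bucket/datasets/example/part-0.bin")
    assert storage.fetch_obj_s3("s3://my-bucket/datasets/example/part-0.bin") == b"hello wicker"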