├── .github └── workflows │ ├── ci.yml │ ├── docs.yml │ └── publish.yml ├── .gitignore ├── CODEOWNERS ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── VERSION ├── dev-requirements.txt ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── getstarted.rst │ ├── index.rst │ └── schema.rst ├── mypy.ini ├── notebooks └── Wicker Hello World.ipynb ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tests ├── .wickerconfig.test.json ├── __init__.py ├── core │ ├── __init__.py │ └── test_persistence.py ├── test_avro_schema.py ├── test_column_files.py ├── test_datasets.py ├── test_filelock.py ├── test_numpy_codec.py ├── test_shuffle.py ├── test_spark.py ├── test_storage.py └── test_wandb.py ├── tox.ini └── wicker ├── __init__.py ├── core ├── __init__.py ├── abstract.py ├── column_files.py ├── config.py ├── datasets.py ├── definitions.py ├── errors.py ├── filelock.py ├── persistance.py ├── shuffle.py ├── storage.py ├── utils.py └── writer.py ├── plugins ├── __init__.py ├── dynamodb.py ├── flyte.py ├── spark.py └── wandb.py ├── schema ├── __init__.py ├── codecs.py ├── dataloading.py ├── dataparsing.py ├── schema.py ├── serialization.py └── validation.py └── testing ├── __init__.py ├── codecs.py └── storage.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | # Controls when the workflow will run 4 | on: 5 | # Triggers the workflow on push or pull request events but only for the main branch 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 15 | jobs: 16 | # This workflow contains a single job called "run-ci" 17 | run-ci: 18 | # The type of runner that the job will run on 19 | runs-on: ubuntu-latest 20 | 21 | # Set environment variables 22 | # ToDo: remove this and make the config hermetic and tunable so each test can run without this file overhead 23 | env: 24 | WICKER_CONFIG_PATH: ${{ github.workspace }}/tests/.wickerconfig.test.json 25 | 26 | # Steps represent a sequence of tasks that will be executed as part of the job 27 | steps: 28 | 29 | - name: Check out repo 30 | uses: actions/checkout@v2 31 | 32 | - name: Set up Python 3.8 33 | uses: actions/setup-python@v1 34 | with: 35 | python-version: 3.8 36 | 37 | - name: Install dependencies 38 | run: | 39 | pip install --upgrade pip 40 | pip install -r dev-requirements.txt 41 | 42 | - name: Run tests 43 | run: make test 44 | 45 | - name: Run lints 46 | run: make lint 47 | 48 | - name: Run type checks 49 | run: make type-check 50 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | # Controls when the workflow will run 4 | on: 5 | # Triggers the workflow on push events but only for the main branch 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 15 | jobs: 16 | 17 | deploy-docs: 18 | # The type of runner that the job will run on 19 | runs-on: ubuntu-latest 20 | 21 | # Set environment variables 22 | env: 23 | WICKER_CONFIG_PATH: tests/.wickerconfig.test.json 24 | 25 
| # Steps represent a sequence of tasks that will be executed as part of the job 26 | steps: 27 | 28 | - name: Check out repo 29 | uses: actions/checkout@v2 30 | 31 | - name: Set up Python 3.8 32 | uses: actions/setup-python@v1 33 | with: 34 | python-version: 3.8 35 | 36 | - name: Install dependencies 37 | run: | 38 | pip install --upgrade pip 39 | pip install -r dev-requirements.txt 40 | 41 | - name: Build docs 42 | run: | 43 | pushd docs 44 | make html 45 | popd 46 | touch docs/build/html/.nojekyll 47 | 48 | - name: Deploy docs 49 | uses: JamesIves/github-pages-deploy-action@4.1.8 50 | with: 51 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 52 | BASE_BRANCH: master # The branch the action should deploy from. 53 | BRANCH: gh-pages # The branch the action should deploy to. 54 | FOLDER: docs/build/html # The folder the action should deploy (only stuff inside will be copied). 55 | # Reactivate when ready 56 | if: github.event_name == 'push' 57 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: publish 2 | 3 | # Controls when the workflow will run 4 | on: 5 | # Triggers the workflow on push events but only for tags starting with v* 6 | push: 7 | tags: ["v*"] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 13 | jobs: 14 | 15 | publish-pypi: 16 | # The type of runner that the job will run on 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Check out repo 21 | uses: actions/checkout@v2 22 | 23 | - name: Build package 24 | run: | 25 | python3 -m pip install --upgrade build 26 | python3 -m build 27 | 28 | - name: Publish package 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.PYPI_API_TOKEN }} 33 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | .idea/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # Code env specific 133 | .vscode 134 | *.egg 135 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @aalavian @anantsimran @chrisochoatri @convexquad @marccarre @pickles-bread-and-butter 2 | 3 | CODEOWNERS @aalavian @marccarre 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Woven Planet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | include README.md 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | docs: 2 | sphinx-apidoc -f -o docs/source/ . 3 | cd docs && make html 4 | @echo "Check out file://docs/html/index.html for the generated docs." 5 | 6 | lint-fix: 7 | @echo "Running automated lint fixes" 8 | python -m isort tests/* wicker/* 9 | python -m black tests/* wicker/* 10 | 11 | lint: 12 | @echo "Running wicker lints" 13 | # Ignoring slice errors E203: https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices 14 | python -m flake8 tests/* wicker/* --ignore E203,W503 15 | python -m isort tests/* wicker/* --check --diff 16 | python -m black tests/* wicker/* --check 17 | 18 | type-check: 19 | @echo "Running wicker type checking with mypy" 20 | python -m mypy tests 21 | python -m mypy wicker 22 | 23 | test: 24 | @echo "Running wicker tests" 25 | python -m unittest discover 26 | 27 | .PHONY: black docs lint type-check test 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wicker 2 | 3 | Wicker is an open source framework for Machine Learning dataset storage and serving developed at Woven Planet L5. 4 | 5 | # Usage 6 | 7 | Refer to the [Wicker documentation's Getting Started page](https://woven-planet.github.io/wicker/getstarted.html) for more information. 8 | 9 | # Development 10 | 11 | To develop on Wicker to contribute, set up your local environment as follows: 12 | 13 | 1. Create a new virtual environment 14 | 2. Do a `pip install -r dev-requirements.txt` to install the development dependencies 15 | 3. Run `make test` to run all unit tests 16 | 4. Run `make lint-fix` to fix all lints and `make lint` to check for any lints that must be fixed manually 17 | 5. Run `make type-check` to check for type errors 18 | 19 | To contribute a new plugin to have Wicker be compatible with other technologies (e.g. Kubernetes, Ray, AWS batch etc): 20 | 21 | 1. Add your plugin into the `wicker.plugins` module as an appropriately named module 22 | 2. If your new plugin requires new external dependencies: 23 | 1. Add a new extra-requires entry to `setup.cfg` 24 | 2. Update `dev-requirements.txt` with any necessary dependencies to run your module in unit tests 25 | 3. 
Write a unit test in `tests/` to test your module 26 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.0.18 2 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | attrs==21.2.0 3 | avro==1.10.2 4 | Babel==2.9.1 5 | black==21.11b1 6 | boto3==1.20.22 7 | botocore==1.23.22 8 | certifi==2021.10.8 9 | charset-normalizer==2.0.9 10 | click==8.0.3 11 | decorator==5.1.0 12 | docutils==0.17.1 13 | flake8==4.0.1 14 | idna==3.3 15 | imagesize==1.3.0 16 | iniconfig==1.1.1 17 | isort==5.9.3 18 | Jinja2==3.0.3 19 | jmespath==0.10.0 20 | MarkupSafe==2.0.1 21 | mccabe==0.6.1 22 | mypy==0.960 23 | mypy-extensions==0.4.3 24 | numpy==1.21.2 25 | packaging==21.0 26 | pathspec==0.9.0 27 | platformdirs==2.4.0 28 | pluggy==1.0.0 29 | py==1.10.0 30 | py4j==0.10.9.2 31 | pyarrow==5.0.0 32 | pycodestyle==2.8.0 33 | pyflakes==2.4.0 34 | Pygments==2.10.0 35 | pyparsing==2.4.7 36 | pyspark==3.2.0 37 | pytest==6.2.5 38 | python-dateutil==2.8.2 39 | pytz==2021.3 40 | regex==2021.11.10 41 | requests==2.26.0 42 | retry==0.9.2 43 | retrying==1.3.3 44 | s3transfer==0.5.0 45 | six==1.16.0 46 | snowballstemmer==2.2.0 47 | Sphinx==4.3.2 48 | sphinxcontrib-applehelp==1.0.2 49 | sphinxcontrib-devhelp==1.0.2 50 | sphinxcontrib-htmlhelp==2.0.0 51 | sphinxcontrib-jsmath==1.0.1 52 | sphinxcontrib-qthelp==1.0.3 53 | sphinxcontrib-serializinghtml==1.1.5 54 | toml==0.10.2 55 | tomli==1.2.2 56 | types-retry==0.9.2 57 | typing-extensions==3.10.0.2 58 | urllib3==1.26.7 59 | wandb==0.12.21 60 | tqdm 61 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'Wicker' 21 | copyright = '2021, Woven Planet L5' 22 | author = 'Jay Chia, Alex Bain, Francois Lefevere, Ben Hansen, Jason Zhao' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "sphinx.ext.autodoc", 32 | ] 33 | 34 | # Add any paths that contain templates here, relative to this directory. 35 | templates_path = ['_templates'] 36 | 37 | # List of patterns, relative to source directory, that match files and 38 | # directories to ignore when looking for source files. 39 | # This pattern also affects html_static_path and html_extra_path. 40 | exclude_patterns = [ 41 | "tests*", 42 | "setup.py", 43 | ] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 50 | # 51 | html_theme = 'alabaster' 52 | 53 | # Add any paths that contain custom static files (such as style sheets) here, 54 | # relative to this directory. They are copied after the builtin static files, 55 | # so a file named "default.css" will overwrite the builtin "default.css". 56 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/source/getstarted.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | Wicker is an open source framework for Machine Learning dataset storage and serving developed at Woven Planet L5. 5 | 6 | Wicker leverages other open source technologies such as Apache Arrow and Apache Parquet to store and serve data. Operating 7 | Wicker mainly requires users to provide an object store (currently Wicker is only compatible with AWS S3, but integrations with 8 | other cloud object stores are a work-in-progress). 9 | 10 | Out of the box, Wicker provides integrations with several widely used technologies such as Spark, Flyte and DynamoDB to allow users 11 | to write Wicker datasets from these data infrastructures. 
However, Wicker was built with a high degree of extensibility in mind, and 12 | allows users to build and use their own implementations to easily integrate with their own infrastructure. 13 | 14 | Installation 15 | ------------ 16 | 17 | ``pip install wicker`` 18 | 19 | Additionally, in order to use some of the provided integrations with other open-source tooling such as Spark, Flyte and Kubernetes, 20 | users may optionally add these options as extra install arguments: 21 | 22 | ``pip install wicker[spark,flyte,kubernetes,...]`` 23 | 24 | Configuration 25 | ------------- 26 | 27 | By default, Wicker searches for a configuration file at ``~/wickerconfig.json``. Users may also change this path by setting the 28 | ``WICKER_CONFIG_PATH`` variable to point to their configuration JSON file. 29 | 30 | .. code-block:: json 31 | 32 | { 33 | "aws_s3_config": { 34 | "s3_datasets_path": "s3://my-bucket/somepath", // Path to the AWS bucket + prefix to use 35 | "region": "us-west-2", // Region of your bucket 36 | "store_concatenated_bytes_files_in_dataset": true // (Optional) Whether to store concatenated bytes files in the dataset 37 | } 38 | } 39 | 40 | Writing your first Dataset 41 | -------------------------- 42 | 43 | Wicker allows users to work both locally and in the cloud, by leveraging different compute and storage backends. 44 | 45 | Note that every dataset must have a defined schema. We define schemas using Wicker's schema library: 46 | 47 | .. code-block:: python3 48 | 49 | from wicker import schema 50 | 51 | MY_SCHEMA = schema.DatasetSchema( 52 | primary_keys=["foo"], 53 | fields=[ 54 | schema.StringField("foo", description="This is an optional human-readable description of the field"), 55 | schema.NumpyField("arr", shape=(4, 4), dtype="float64"), 56 | ] 57 | ) 58 | 59 | The above schema defines a dataset that consists of data that looks like: 60 | 61 | .. code-block:: python3 62 | 63 | { 64 | "foo": "some_string", 65 | "arr": np.array([ 66 | [1., 1., 1., 1.], 67 | [1., 1., 1., 1.], 68 | [1., 1., 1., 1.], 69 | [1., 1., 1., 1.], 70 | ]) 71 | } 72 | 73 | We have the guarantee that: 74 | 75 | 1. The dataset is sorted by each example's `"foo"` field, as this is the only primary_key of the dataset 76 | 2. Each example's `"arr"` field contains a 4-by-4 numpy array of float64 values 77 | 78 | After defining a schema, we can then start to write data conforming to this schema to a dataset. 79 | 80 | Using Spark 81 | ^^^^^^^^^^^ 82 | 83 | Spark is a common data engine and Wicker provides integrations to write datasets from Spark. 84 | 85 | .. code-block:: python3 86 | 87 | from wicker.plugins.spark import SparkPersistor 88 | 89 | examples = [ 90 | ( 91 | "train", # Wicker dataset partition that this row belongs to 92 | { 93 | "foo": f"foo{i}", 94 | "arr": np.ones((4, 4)), 95 | } 96 | ) for i in range(1000) 97 | ] 98 | 99 | rdd = spark_context.parallelize(examples) 100 | persistor = SparkPersistor() 101 | persistor.persist_wicker_dataset( 102 | "my_dataset_name", 103 | "0.0.1", 104 | MY_SCHEMA, 105 | rdd, 106 | ) 107 | 108 | And that's it! Wicker will handle all the sorting and persisting of the data for you under the hood. 109 | 110 | Using Non-Data Engine Infrastructures 111 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 112 | 113 | Not all users have access to infrastructure like Spark, or want to fire up something quite as heavyweight for 114 | maybe a smaller dataset or use-case. For these users, Wicker exposes a ``DatasetWriter`` API for adding and committing 115 | examples from any environment.
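For very small datasets or quick local experiments that need neither Spark nor a MetadataDatabase, the unit tests in ``tests/core/test_persistence.py`` (included later in this repository) also exercise a ``BasicPersistor`` that sorts and persists a plain in-memory list of ``(partition, example)`` tuples in a single process. The following is a minimal sketch based on that test, not an official recipe: the ``LocalDataStorage``/``S3PathFactory`` arguments mirror the test fixtures and write under a local scratch directory (a placeholder path here) rather than S3, and ``MY_SCHEMA`` and ``np`` are reused from the examples above. The ``DatasetWriter`` flow itself is described next.

.. code-block:: python3

    from wicker.core.persistance import BasicPersistor
    from wicker.core.storage import S3PathFactory
    from wicker.testing.storage import LocalDataStorage

    examples = [
        (
            "train",  # Wicker dataset partition that this row belongs to
            {
                "foo": f"foo{i}",
                "arr": np.ones((4, 4)),
            },
        )
        for i in range(1000)
    ]

    # Storage backend and path factory mirroring tests/core/test_persistence.py;
    # "/tmp/wicker_scratch" is an illustrative placeholder directory.
    storage = LocalDataStorage(root_path="/tmp/wicker_scratch")
    path_factory = S3PathFactory(s3_root_path="/tmp/wicker_scratch/datasets")

    persistor = BasicPersistor(storage, path_factory)
    persistor.persist_wicker_dataset(
        "my_dataset_name",
        "0.0.1",
        MY_SCHEMA,
        examples,
    )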
116 | 117 | To make this work, Wicker needs an intermediate ``MetadataDatabase`` to store and index information about each row before 118 | it commits the dataset. We provide a default integration with DynamoDB, but users can implement their own integrations easily 119 | by implementing the abstract interface ``wicker.core.writer.AbstractDatasetWriterMetadataDatabase``, and use their own 120 | MetadataDatabases as intermediate storage for persisting their data. Integrations with other databases for use as Wicker-compatible 121 | MetadataDatabases are a work-in-progress. 122 | 123 | Below, we provide an example of how we can use `Flyte <https://flyte.org/>`_ to commit our datasets, using DynamoDB as our 124 | MetadataDatabase. More plugins are being written for other commonly used cloud infrastructure such as AWS Batch, Kubernetes, etc. 125 | 126 | .. code-block:: python3 127 | 128 | from wicker.schema import serialization 129 | from wicker.core.definitions import DatasetDefinition, DatasetID 130 | from wicker.core.writer import DatasetWriter 131 | from wicker.plugins import dynamodb, flyte 132 | 133 | # First, add the following to our ~/.wickerconfig.json file to enable Wicker's DynamoDB integrations 134 | # 135 | # "dynamodb_config": { // only if users need to use DynamoDB for writing datasets 136 | # "table_name": "my-table", // name of the table to use in dynamodb 137 | # "region": "us-west-2" // region of your table 138 | # } 139 | 140 | metadata_database = dynamodb.DynamodbMetadataDatabase() 141 | dataset_definition = DatasetDefinition(DatasetID(name="my_dataset", version="0.0.1"), MY_SCHEMA) 142 | 143 | # (1): Add examples to your dataset 144 | # 145 | # Note that this can be called from anywhere asynchronously, e.g. in different Flyte workers, from 146 | # a Jupyter notebook, a local Python script etc - as long as the same metadata_database config is used 147 | with DatasetWriter(dataset_definition, metadata_database) as writer: 148 | writer.add_example( 149 | "train", # Name of your Wicker dataset partition (e.g. train, test, eval, unittest, ...) 150 | { 151 | "foo": "foo1", 152 | "arr": np.eye(4).astype("float64"), 153 | }, # Raw data for a single example that conforms to your schema 154 | ) 155 | 156 | # (2): When ready, commit the dataset. 157 | # 158 | # Trigger the Flyte workflow to commit the dataset, either from the Flyte UI, Flyte CLI or from a Python script 159 | flyte.WickerDataShufflingWorkflow( 160 | dataset_id=str(dataset_definition.dataset_id), 161 | schema_json_str=serialization.dumps(MY_SCHEMA), 162 | ) 163 | 164 | 1. Start adding examples to your dataset. Note: 165 | a) Here we use a ``DynamodbMetadataDatabase`` as the metadata storage for this dataset, but users can use other MetadataDatabase implementations here as well if they do not have an accessible DynamoDB instance. 166 | b) The ``.add_example(...)`` call writes a single example to the ``"train"`` partition, and can potentially throw a ``WickerSchemaException`` error if the data provided does not conform to the schema. 167 | 2. Commit your dataset. Note here that we use the committing functionality provided by ``wicker.plugins.flyte``, but more plugins for other data infrastructures are a work-in-progress (e.g. Kubernetes, AWS Batch). 168 | 169 | 170 | Reading from your Dataset 171 | ------------------------- 172 | 173 | ..
code-block:: python3 174 | 175 | from wicker.core.datasets import S3Dataset 176 | 177 | ds = S3Dataset("my_new_dataset", "0.0.1", "train", columns_to_load=["foo", "arr"]) 178 | 179 | # Check the size of your "train" partition 180 | len(ds) 181 | 182 | # Retrieve a single item, initial access is slow (O(seconds)) 183 | x0 = ds[0] 184 | 185 | # Subsequent data accesses are fast (O(us)), data is cached in page buffers 186 | x0_ = ds[0] 187 | 188 | # Access to contiguous indices is also fast (O(ms)), data is cached on disk/in page buffers 189 | x1 = ds[1] 190 | 191 | Reading from your dataset is as simple as indexing on an ``S3Dataset`` handle. Note: 192 | 193 | 1. Wicker is built for high-throughput and initial access times are amortized by accessing contiguous chunks of indices. Sampling for distributed ML training should take this into account and provide each worker with a contiguous chunk of indices as its working set for good performance. 194 | 195 | 2. Wicker allows users to select columns that they are interested in using, using the ``columns_to_load`` keyword argument 196 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Wicker documentation master file, created by 2 | sphinx-quickstart on Sun Dec 26 08:41:22 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Wicker's documentation! 7 | ================================== 8 | 9 | Wicker is a Python framework for storing and serving Machine Learning datasets. It provides: 10 | 11 | **Abstraction of underlying storage** 12 | 13 | Read and write data as Python dictionaries - don't worry about where/how it is stored. 14 | 15 | 16 | **Schematized Data** 17 | 18 | Enforce a schema when reading and writing your data, providing readability and useability for users of your datasets. 19 | 20 | 21 | **Metadata Indexing** 22 | 23 | Easily store, query and visualize your metadata in your own stores such as BigQuery or Snowflake for downstream analyses. 24 | 25 | 26 | **Efficient Dataloading** 27 | 28 | Built and optimized for high-throughput distributed training of deep learning models. 29 | 30 | 31 | .. toctree:: 32 | getstarted 33 | schema 34 | 35 | Indices and tables 36 | ================== 37 | 38 | * :ref:`genindex` 39 | * :ref:`modindex` 40 | * :ref:`search` 41 | -------------------------------------------------------------------------------- /docs/source/schema.rst: -------------------------------------------------------------------------------- 1 | Wicker Schemas 2 | ============== 3 | 4 | Every Wicker dataset has an associated schema which is declared at schema write-time. 5 | 6 | Wicker schemas are Python objects which are serialized in storage as Avro-compatible JSON files. 7 | When declaring schemas, we use the ``wicker.schema.DatasetSchema`` object: 8 | 9 | .. code-block:: python3 10 | 11 | from wicker.schema import DatasetSchema 12 | 13 | my_schema = DatasetSchema( 14 | primary_keys=["foo", "bar"], 15 | fields=[...], 16 | ) 17 | 18 | Your schema must be defined with a set of primary_keys. Your primary keys must be the names of 19 | string, float, int or bool fields in your schema, and will be used to order your dataset. 20 | 21 | Schema Fields 22 | ------------- 23 | 24 | Here is a list of Schema fields that Wicker provides. 
Most notably, users can implement custom fields 25 | by implementing their own codecs and using the ``ObjectField``. 26 | 27 | .. automodule:: wicker.schema.schema 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | :exclude-members: DatasetSchema, *._accept_visitor 32 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.8 3 | warn_unused_configs = True 4 | 5 | [mypy-boto3.*] 6 | ignore_missing_imports = True 7 | 8 | [mypy-pynamodb.*] 9 | ignore_missing_imports = True 10 | 11 | [mypy-retrying.*] 12 | ignore_missing_imports = True 13 | 14 | [mypy-pyarrow.*] 15 | ignore_missing_imports = True 16 | 17 | [mypy-pyspark.*] 18 | ignore_missing_imports = True 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.black] 9 | line-length = 120 10 | 11 | [tool.isort] 12 | profile = "black" 13 | multi_line_output = 3 14 | skip = ["docs"] 15 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = wicker 3 | version = file: VERSION 4 | author = wicker-maintainers 5 | author_email = wicker-maintainers@woven-planet.global 6 | description = An open source framework for Machine Learning dataset storage and serving 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/woven-planet/Wicker 10 | project_urls = 11 | Bug Tracker = https://github.com/woven-planet/Wicker/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: Apache Software License 15 | Operating System :: OS Independent 16 | 17 | [options] 18 | include_package_data = True 19 | package_dir = 20 | =. 21 | packages = find: 22 | python_requires = >=3.6 23 | install_requires = 24 | numpy 25 | pyarrow 26 | boto3 27 | 28 | [options.extras_require] 29 | flyte = flytekit 30 | dynamodb = pynamodb 31 | spark = pyspark 32 | wandb = wandb 33 | 34 | [options.packages.find] 35 | where = . 
36 | exclude = 37 | tests 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /tests/.wickerconfig.test.json: -------------------------------------------------------------------------------- 1 | { 2 | "aws_s3_config": { 3 | "s3_datasets_path": "s3://fake_data/", 4 | "region": "us-west-2", 5 | "boto_config": { 6 | "max_pool_connections":10, 7 | "read_timeout_s": 140, 8 | "connect_timeout_s": 140 9 | } 10 | }, 11 | "filesystem_configs": [ 12 | { 13 | "config_name": "filesystem_1", 14 | "prefix_replace_path": "s3://fake_data_1/", 15 | "root_datasets_path": "/mnt/fake_data_1/" 16 | }, 17 | { 18 | "config_name": "filesystem_2", 19 | "prefix_replace_path": "s3://fake_data_2/", 20 | "root_datasets_path": "/mnt/fake_data_2/" 21 | } 22 | ], 23 | "dynamodb_config": { 24 | "table_name": "fake_db", 25 | "region": "us-west-2" 26 | }, 27 | "storage_download_config":{ 28 | "retries": 2, 29 | "timeout": 150, 30 | "retry_backoff":5, 31 | "retry_delay_s": 4 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/tests/__init__.py -------------------------------------------------------------------------------- /tests/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/tests/core/__init__.py -------------------------------------------------------------------------------- /tests/core/test_persistence.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import uuid 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import pyarrow.fs as pafs 8 | import pyarrow.parquet as papq 9 | import pytest 10 | 11 | from wicker import schema 12 | from wicker.core.datasets import S3Dataset 13 | from wicker.core.persistance import BasicPersistor 14 | from wicker.core.storage import S3PathFactory 15 | from wicker.schema.schema import DatasetSchema 16 | from wicker.testing.storage import LocalDataStorage 17 | 18 | DATASET_NAME = "dataset" 19 | DATASET_VERSION = "0.0.1" 20 | SCHEMA = schema.DatasetSchema( 21 | primary_keys=["bar", "foo"], 22 | fields=[ 23 | schema.IntField("foo"), 24 | schema.StringField("bar"), 25 | schema.BytesField("bytescol"), 26 | ], 27 | ) 28 | EXAMPLES = [ 29 | ( 30 | "train" if i % 2 == 0 else "test", 31 | { 32 | "foo": random.randint(0, 10000), 33 | "bar": str(uuid.uuid4()), 34 | "bytescol": b"0", 35 | }, 36 | ) 37 | for i in range(10000) 38 | ] 39 | # Examples with a duplicated key 40 | EXAMPLES_DUPES = copy.deepcopy(EXAMPLES) 41 | 42 | 43 | @pytest.fixture 44 | def mock_basic_persistor(request, tmpdir) -> Tuple[BasicPersistor, str]: 45 | storage = request.param.get("storage", LocalDataStorage(root_path=tmpdir)) 46 | path_factory = request.param.get("path_factory", S3PathFactory(s3_root_path=os.path.join(tmpdir, "datasets"))) 47 | return BasicPersistor(storage, path_factory), tmpdir 48 | 49 | 50 | def assert_written_correctness(persistor: BasicPersistor, 
data: List[Tuple[str, Dict[str, Any]]], tmpdir: str) -> None: 51 | """Asserts that all files are written as expected by the L5MLDatastore""" 52 | # Check that files are correctly written locally by Spark/Parquet with a _SUCCESS marker file 53 | prefix = persistor.s3_path_factory.root_path 54 | assert DATASET_NAME in os.listdir(os.path.join(tmpdir, prefix)) 55 | assert DATASET_VERSION in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME)) 56 | for partition in ["train", "test"]: 57 | print(os.listdir(os.path.join(tmpdir, prefix))) 58 | columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__") 59 | all_read_bytes = b"" 60 | for filename in os.listdir(columns_path): 61 | concatenated_bytes_filepath = os.path.join(columns_path, filename) 62 | with open(concatenated_bytes_filepath, "rb") as bytescol_file: 63 | all_read_bytes += bytescol_file.read() 64 | assert all_read_bytes == b"0" * 10000 65 | 66 | # Load parquet file and assert ordering of primary_key 67 | assert f"{partition}.parquet" in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION)) 68 | tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet")) 69 | foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])] 70 | assert foobar == sorted(foobar) 71 | 72 | # Also load the data from the dataset and check it. 73 | ds = S3Dataset( 74 | DATASET_NAME, 75 | DATASET_VERSION, 76 | partition, 77 | local_cache_path_prefix=tmpdir, 78 | storage=persistor.s3_storage, 79 | s3_path_factory=persistor.s3_path_factory, 80 | pa_filesystem=pafs.LocalFileSystem(), 81 | ) 82 | 83 | ds_values = [ds[idx] for idx in range(len(ds))] 84 | # Sort the expected values by the primary keys. 85 | expected_values = sorted([value for p, value in data if p == partition], key=lambda v: (v["bar"], v["foo"])) 86 | assert ds_values == expected_values 87 | 88 | 89 | @pytest.mark.parametrize( 90 | "mock_basic_persistor, dataset_name, dataset_version, dataset_schema, dataset", 91 | [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))], 92 | indirect=["mock_basic_persistor"], 93 | ) 94 | def test_basic_persistor( 95 | mock_basic_persistor: Tuple[BasicPersistor, str], 96 | dataset_name: str, 97 | dataset_version: str, 98 | dataset_schema: DatasetSchema, 99 | dataset: List[Tuple[str, Dict[str, Any]]], 100 | ): 101 | """ 102 | Test if the basic persistor can persist data in the format we have established. 103 | 104 | Ensure we read the right file locations, the right amount of bytes, 105 | and the ordering is correct. 
106 | """ 107 | # create the mock basic persistor 108 | mock_basic_persistor_obj, tempdir = mock_basic_persistor 109 | # persist the dataset 110 | mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) 111 | # assert the dataset is correctly written 112 | assert_written_correctness(persistor=mock_basic_persistor_obj, data=dataset, tmpdir=tempdir) 113 | -------------------------------------------------------------------------------- /tests/test_column_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | import uuid 5 | from unittest.mock import MagicMock 6 | 7 | from wicker.core.column_files import ( 8 | ColumnBytesFileCache, 9 | ColumnBytesFileLocationV1, 10 | ColumnBytesFileReader, 11 | ColumnBytesFileWriter, 12 | ) 13 | from wicker.core.definitions import DatasetID, DatasetPartition 14 | from wicker.core.storage import S3PathFactory, WickerPathFactory 15 | from wicker.testing.storage import FakeS3DataStorage 16 | 17 | FAKE_SHARD_ID = "fake_shard_id" 18 | FAKE_DATA_PARTITION = DatasetPartition(dataset_id=DatasetID(name="name", version="0.0.1"), partition="partition") 19 | FAKE_BYTES = b"foo" 20 | FAKE_BYTES2 = b"0123456789" 21 | FAKE_COL = "col0" 22 | FAKE_COL2 = "col1" 23 | 24 | 25 | class TestColumnBytesFileWriter(unittest.TestCase): 26 | def test_write_empty(self) -> None: 27 | mock_storage = MagicMock() 28 | path_factory = S3PathFactory() 29 | with ColumnBytesFileWriter( 30 | storage=mock_storage, 31 | s3_path_factory=path_factory, 32 | ): 33 | pass 34 | mock_storage.put_object_s3.assert_not_called() 35 | 36 | def test_write_one_column_one_row(self) -> None: 37 | path_factory = S3PathFactory() 38 | mock_storage = MagicMock() 39 | with ColumnBytesFileWriter( 40 | storage=mock_storage, 41 | s3_path_factory=path_factory, 42 | ) as ccb: 43 | info = ccb.add(FAKE_COL, FAKE_BYTES) 44 | self.assertEqual(info.byte_offset, 0) 45 | self.assertEqual(info.data_size, len(FAKE_BYTES)) 46 | mock_storage.put_file_s3.assert_called_once_with( 47 | unittest.mock.ANY, 48 | os.path.join( 49 | path_factory.get_column_concatenated_bytes_files_path(), 50 | str(info.file_id), 51 | ), 52 | ) 53 | 54 | def test_write_one_column_multi_row(self) -> None: 55 | path_factory = S3PathFactory() 56 | mock_storage = MagicMock() 57 | with ColumnBytesFileWriter( 58 | storage=mock_storage, 59 | s3_path_factory=path_factory, 60 | ) as ccb: 61 | info = ccb.add(FAKE_COL, FAKE_BYTES) 62 | self.assertEqual(info.byte_offset, 0) 63 | self.assertEqual(info.data_size, len(FAKE_BYTES)) 64 | 65 | next_info = ccb.add(FAKE_COL, FAKE_BYTES) 66 | self.assertEqual(next_info.byte_offset, len(FAKE_BYTES)) 67 | self.assertEqual(next_info.data_size, len(FAKE_BYTES)) 68 | mock_storage.put_file_s3.assert_called_once_with( 69 | unittest.mock.ANY, 70 | os.path.join( 71 | path_factory.get_column_concatenated_bytes_files_path(), 72 | str(info.file_id), 73 | ), 74 | ) 75 | 76 | def test_write_multi_column_multi_row(self) -> None: 77 | path_factory = S3PathFactory() 78 | mock_storage = MagicMock() 79 | with ColumnBytesFileWriter( 80 | storage=mock_storage, 81 | s3_path_factory=path_factory, 82 | ) as ccb: 83 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 84 | self.assertEqual(info1.byte_offset, 0) 85 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 86 | 87 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 88 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES)) 89 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 90 | 
91 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES) 92 | self.assertEqual(info2.byte_offset, 0) 93 | self.assertEqual(info2.data_size, len(FAKE_BYTES)) 94 | 95 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES) 96 | self.assertEqual(info2.byte_offset, len(FAKE_BYTES)) 97 | self.assertEqual(info2.data_size, len(FAKE_BYTES)) 98 | 99 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES) 100 | self.assertEqual(info2.byte_offset, len(FAKE_BYTES) * 2) 101 | self.assertEqual(info2.data_size, len(FAKE_BYTES)) 102 | 103 | mock_storage.put_file_s3.assert_any_call( 104 | unittest.mock.ANY, 105 | os.path.join( 106 | path_factory.get_column_concatenated_bytes_files_path(), 107 | str(info1.file_id), 108 | ), 109 | ) 110 | mock_storage.put_file_s3.assert_any_call( 111 | unittest.mock.ANY, 112 | os.path.join( 113 | path_factory.get_column_concatenated_bytes_files_path(), 114 | str(info2.file_id), 115 | ), 116 | ) 117 | 118 | def test_write_large_file(self) -> None: 119 | with tempfile.TemporaryDirectory() as tmpdir: 120 | path_factory = S3PathFactory() 121 | storage = FakeS3DataStorage(tmpdir=tmpdir) 122 | with ColumnBytesFileWriter(storage=storage, s3_path_factory=path_factory, target_file_size=10) as ccb: 123 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 124 | self.assertEqual(info1.byte_offset, 0) 125 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 126 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 127 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES)) 128 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 129 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 130 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES) * 2) 131 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 132 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 133 | self.assertEqual(info1.byte_offset, len(FAKE_BYTES) * 3) 134 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 135 | info2 = ccb.add(FAKE_COL, FAKE_BYTES) 136 | self.assertEqual(info2.byte_offset, 0) 137 | self.assertEqual(info2.data_size, len(FAKE_BYTES)) 138 | self.assertNotEqual(info1.file_id, info2.file_id) 139 | 140 | info1 = ccb.add(FAKE_COL2, FAKE_BYTES2) 141 | self.assertEqual(info1.byte_offset, 0) 142 | self.assertEqual(info1.data_size, len(FAKE_BYTES2)) 143 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES2) 144 | self.assertEqual(info2.byte_offset, 0) 145 | self.assertEqual(info2.data_size, len(FAKE_BYTES2)) 146 | self.assertNotEqual(info1.file_id, info2.file_id) 147 | info3 = ccb.add(FAKE_COL2, FAKE_BYTES2) 148 | self.assertEqual(info3.byte_offset, 0) 149 | self.assertEqual(info3.data_size, len(FAKE_BYTES2)) 150 | self.assertNotEqual(info2.file_id, info3.file_id) 151 | 152 | def test_write_manyrows_file(self) -> None: 153 | with tempfile.TemporaryDirectory() as tmpdir: 154 | path_factory = S3PathFactory() 155 | storage = FakeS3DataStorage(tmpdir=tmpdir) 156 | with ColumnBytesFileWriter( 157 | storage=storage, s3_path_factory=path_factory, target_file_rowgroup_size=1 158 | ) as ccb: 159 | info1 = ccb.add(FAKE_COL, FAKE_BYTES) 160 | self.assertEqual(info1.byte_offset, 0) 161 | self.assertEqual(info1.data_size, len(FAKE_BYTES)) 162 | info2 = ccb.add(FAKE_COL, FAKE_BYTES) 163 | self.assertEqual(info2.byte_offset, 0) 164 | self.assertEqual(info2.data_size, len(FAKE_BYTES)) 165 | self.assertNotEqual(info1.file_id, info2.file_id) 166 | 167 | info1 = ccb.add(FAKE_COL2, FAKE_BYTES2) 168 | self.assertEqual(info1.byte_offset, 0) 169 | self.assertEqual(info1.data_size, len(FAKE_BYTES2)) 170 | info2 = ccb.add(FAKE_COL2, FAKE_BYTES2) 171 | self.assertEqual(info2.byte_offset, 0) 172 | self.assertEqual(info2.data_size, 
len(FAKE_BYTES2)) 173 | self.assertNotEqual(info1.file_id, info2.file_id) 174 | info3 = ccb.add(FAKE_COL2, FAKE_BYTES2) 175 | self.assertEqual(info3.byte_offset, 0) 176 | self.assertEqual(info3.data_size, len(FAKE_BYTES2)) 177 | self.assertNotEqual(info2.file_id, info3.file_id) 178 | 179 | 180 | class TestColumnBytesFileCacheAndReader(unittest.TestCase): 181 | def test_write_one_column_one_row_and_read(self) -> None: 182 | path_factory = S3PathFactory() 183 | mock_storage = MagicMock() 184 | # First, write a row to column bytes mock storage. 185 | with ColumnBytesFileWriter( 186 | storage=mock_storage, 187 | s3_path_factory=path_factory, 188 | ) as ccb: 189 | info = ccb.add(FAKE_COL, FAKE_BYTES) 190 | self.assertEqual(info.byte_offset, 0) 191 | self.assertEqual(info.data_size, len(FAKE_BYTES)) 192 | s3_path = os.path.join( 193 | path_factory.get_column_concatenated_bytes_files_path(), 194 | str(info.file_id), 195 | ) 196 | mock_storage.put_file_s3.assert_called_once_with(unittest.mock.ANY, s3_path) 197 | 198 | # Now, verify that we can read it back from mock storage. 199 | local_path = os.path.join("/tmp", s3_path.split("s3://fake_data/")[1]) 200 | mock_storage.fetch_file = MagicMock(return_value=local_path) 201 | 202 | cbf_cache = ColumnBytesFileCache( 203 | storage=mock_storage, 204 | path_factory=path_factory, 205 | ) 206 | 207 | # Mock the helper function that just opens a file and returns its contents. 208 | read_column_bytes_file_mock = MagicMock(return_value=FAKE_BYTES) 209 | cbf_cache._read_column_bytes_file = read_column_bytes_file_mock # type: ignore 210 | 211 | # Read back the column file. 212 | bytes_read = cbf_cache.read(info) 213 | mock_storage.fetch_file.assert_called_once_with(s3_path, "/tmp", timeout_seconds=-1) 214 | self.assertEqual(len(bytes_read), len(FAKE_BYTES)) 215 | self.assertEqual(bytes_read, FAKE_BYTES) 216 | 217 | # Now let's verify that we can use the reader directly to read the local column bytes. 218 | wicker_path_factory = WickerPathFactory("/tmp") 219 | cbf_reader = ColumnBytesFileReader(wicker_path_factory) 220 | 221 | # Reset the reader function mock so we can check that it was called just once. 
222 | read_column_bytes_file_mock = MagicMock(return_value=FAKE_BYTES) 223 | cbf_reader._read_column_bytes_file = read_column_bytes_file_mock # type: ignore 224 | 225 | bytes_read = cbf_reader.read(info) 226 | read_column_bytes_file_mock.assert_called_once_with( 227 | column_bytes_file_info=info, 228 | column_bytes_file_path=local_path, 229 | ) 230 | self.assertEqual(len(bytes_read), len(FAKE_BYTES)) 231 | self.assertEqual(bytes_read, FAKE_BYTES) 232 | 233 | 234 | class TestCCBInfo(unittest.TestCase): 235 | def test_to_string(self) -> None: 236 | ccb_info = ColumnBytesFileLocationV1(uuid.uuid4(), 100, 100) 237 | ccb_info_parsed = ColumnBytesFileLocationV1.from_bytes(ccb_info.to_bytes()) 238 | self.assertEqual(ccb_info, ccb_info_parsed) 239 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | from contextlib import contextmanager 5 | from typing import Any, Iterator, NamedTuple, Tuple 6 | from unittest.mock import patch 7 | 8 | import numpy as np 9 | import pyarrow as pa # type: ignore 10 | import pyarrow.fs as pafs # type: ignore 11 | import pyarrow.parquet as papq # type: ignore 12 | 13 | from wicker.core.column_files import ColumnBytesFileWriter 14 | from wicker.core.config import ( 15 | FILESYSTEM_CONFIG, 16 | StorageDownloadConfig, 17 | WickerAwsS3Config, 18 | WickerConfig, 19 | WickerFileSystemConfig, 20 | WickerWandBConfig, 21 | ) 22 | from wicker.core.datasets import FileSystemDataset, S3Dataset, build_dataset 23 | from wicker.core.definitions import DatasetID, DatasetPartition 24 | from wicker.core.storage import FileSystemDataStorage, S3PathFactory, WickerPathFactory 25 | from wicker.schema import schema, serialization 26 | from wicker.testing.storage import FakeS3DataStorage 27 | 28 | FAKE_NAME = "dataset_name" 29 | FAKE_VERSION = "0.0.1" 30 | FAKE_PARTITION = "train" 31 | FAKE_DATASET_ID = DatasetID(name=FAKE_NAME, version=FAKE_VERSION) 32 | FAKE_DATASET_PARTITION = DatasetPartition(FAKE_DATASET_ID, partition=FAKE_PARTITION) 33 | 34 | FAKE_SHAPE = (4, 4) 35 | FAKE_DTYPE = "float64" 36 | FAKE_NUMPY_CODEC = schema.WickerNumpyCodec(FAKE_SHAPE, FAKE_DTYPE) 37 | FAKE_SCHEMA = schema.DatasetSchema( 38 | primary_keys=["foo"], 39 | fields=[ 40 | schema.StringField("foo"), 41 | schema.NumpyField("np_arr", shape=FAKE_SHAPE, dtype=FAKE_DTYPE), 42 | ], 43 | ) 44 | FAKE_DATA = [{"foo": f"bar{i}", "np_arr": np.eye(4)} for i in range(1000)] 45 | 46 | 47 | def build_mock_wicker_config(tmpdir: str) -> WickerConfig: 48 | """Helper function to build WickerConfig objects to use as unit test mocks.""" 49 | return WickerConfig( 50 | raw={}, 51 | aws_s3_config=WickerAwsS3Config.from_json({}), 52 | filesystem_configs=[ 53 | WickerFileSystemConfig.from_json( 54 | { 55 | "config_name": "filesystem_1", 56 | "prefix_replace_path": "", 57 | "root_datasets_path": os.path.join(tmpdir, "fake_data"), 58 | } 59 | ), 60 | ], 61 | storage_download_config=StorageDownloadConfig.from_json({}), 62 | wandb_config=WickerWandBConfig.from_json({}), 63 | ) 64 | 65 | 66 | @contextmanager 67 | def cwd(path): 68 | """Changes the current working directory, and returns to the previous directory afterwards""" 69 | oldpwd = os.getcwd() 70 | os.chdir(path) 71 | try: 72 | yield 73 | finally: 74 | os.chdir(oldpwd) 75 | 76 | 77 | class TestFileSystemDataset(unittest.TestCase): 78 | @contextmanager 79 | def _setup_storage(self) -> 
Iterator[Tuple[FileSystemDataStorage, WickerPathFactory, str]]: 80 | with tempfile.TemporaryDirectory() as tmpdir, cwd(tmpdir): 81 | fake_local_fs_storage = FileSystemDataStorage() 82 | fake_local_path_factory = WickerPathFactory(root_path=os.path.join(tmpdir, "fake_data")) 83 | fake_s3_path_factory = S3PathFactory() 84 | fake_s3_storage = FakeS3DataStorage(tmpdir=tmpdir) 85 | with ColumnBytesFileWriter( 86 | storage=fake_s3_storage, 87 | s3_path_factory=fake_s3_path_factory, 88 | target_file_rowgroup_size=10, 89 | ) as writer: 90 | locs = [ 91 | writer.add("np_arr", FAKE_NUMPY_CODEC.validate_and_encode_object(data["np_arr"])) # type: ignore 92 | for data in FAKE_DATA 93 | ] 94 | 95 | arrow_metadata_table = pa.Table.from_pydict( 96 | {"foo": [data["foo"] for data in FAKE_DATA], "np_arr": [loc.to_bytes() for loc in locs]} 97 | ) 98 | metadata_table_path = os.path.join( 99 | tmpdir, fake_local_path_factory._get_dataset_partition_path(FAKE_DATASET_PARTITION) 100 | ) 101 | os.makedirs(os.path.dirname(metadata_table_path), exist_ok=True) 102 | papq.write_table(arrow_metadata_table, metadata_table_path) 103 | 104 | # The mock storage class here actually writes to local storage, so we can use it in the test. 105 | fake_s3_storage.put_object_s3( 106 | serialization.dumps(FAKE_SCHEMA).encode("utf-8"), 107 | fake_local_path_factory._get_dataset_schema_path(FAKE_DATASET_ID), 108 | ) 109 | yield fake_local_fs_storage, fake_local_path_factory, tmpdir 110 | 111 | def test_filesystem_dataset(self): 112 | with self._setup_storage() as (fake_local_storage, fake_local_path_factory, tmpdir): 113 | ds = FileSystemDataset( 114 | FAKE_NAME, 115 | FAKE_VERSION, 116 | FAKE_PARTITION, 117 | fake_local_path_factory, 118 | fake_local_storage, 119 | ) 120 | for i in range(len(FAKE_DATA)): 121 | retrieved = ds[i] 122 | reference = FAKE_DATA[i] 123 | self.assertEqual(retrieved["foo"], reference["foo"]) 124 | np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"]) 125 | 126 | # Also double-check that the builder function is working correctly. 127 | with patch("wicker.core.datasets.get_config") as mock_get_config: 128 | mock_get_config.return_value = build_mock_wicker_config(tmpdir) 129 | ds2 = build_dataset( 130 | FILESYSTEM_CONFIG, 131 | FAKE_NAME, 132 | FAKE_VERSION, 133 | FAKE_PARTITION, 134 | config_name="filesystem_1", 135 | ) 136 | for i in range(len(FAKE_DATA)): 137 | retrieved = ds2[i] 138 | reference = FAKE_DATA[i] 139 | self.assertEqual(retrieved["foo"], reference["foo"]) 140 | np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"]) 141 | 142 | 143 | class TestS3Dataset(unittest.TestCase): 144 | @contextmanager 145 | def _setup_storage(self) -> Iterator[Tuple[FakeS3DataStorage, S3PathFactory, str]]: 146 | """Context manager that sets up a local directory to mimic S3 storage for a committed dataset, 147 | and returns a tuple of (S3DataStorage, S3PathFactory, tmpdir_path) for the caller to use as 148 | fixtures in their tests. 
149 | """ 150 | with tempfile.TemporaryDirectory() as tmpdir, cwd(tmpdir): 151 | fake_s3_storage = FakeS3DataStorage(tmpdir=tmpdir) 152 | fake_s3_path_factory = S3PathFactory() 153 | with ColumnBytesFileWriter( 154 | storage=fake_s3_storage, 155 | s3_path_factory=fake_s3_path_factory, 156 | target_file_rowgroup_size=10, 157 | ) as writer: 158 | locs = [ 159 | writer.add("np_arr", FAKE_NUMPY_CODEC.validate_and_encode_object(data["np_arr"])) # type: ignore 160 | for data in FAKE_DATA 161 | ] 162 | arrow_metadata_table = pa.Table.from_pydict( 163 | { 164 | "foo": [data["foo"] for data in FAKE_DATA], 165 | "np_arr": [loc.to_bytes() for loc in locs], 166 | } 167 | ) 168 | metadata_table_path = os.path.join( 169 | tmpdir, 170 | fake_s3_path_factory.get_dataset_partition_path(FAKE_DATASET_PARTITION, s3_prefix=False), 171 | "part-1.parquet", 172 | ) 173 | os.makedirs(os.path.dirname(metadata_table_path), exist_ok=True) 174 | papq.write_table( 175 | arrow_metadata_table, 176 | metadata_table_path, 177 | ) 178 | fake_s3_storage.put_object_s3( 179 | serialization.dumps(FAKE_SCHEMA).encode("utf-8"), 180 | fake_s3_path_factory.get_dataset_schema_path(FAKE_DATASET_ID), 181 | ) 182 | yield fake_s3_storage, fake_s3_path_factory, tmpdir 183 | 184 | def test_dataset(self) -> None: 185 | with self._setup_storage() as (fake_s3_storage, fake_s3_path_factory, tmpdir): 186 | ds = S3Dataset( 187 | FAKE_NAME, 188 | FAKE_VERSION, 189 | FAKE_PARTITION, 190 | local_cache_path_prefix=tmpdir, 191 | columns_to_load=None, 192 | storage=fake_s3_storage, 193 | s3_path_factory=fake_s3_path_factory, 194 | pa_filesystem=pafs.LocalFileSystem(), 195 | ) 196 | 197 | for i in range(len(FAKE_DATA)): 198 | retrieved = ds[i] 199 | reference = FAKE_DATA[i] 200 | self.assertEqual(retrieved["foo"], reference["foo"]) 201 | np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"]) 202 | 203 | def test_filters_dataset(self) -> None: 204 | filtered_value_list = [f"bar{i}" for i in range(100)] 205 | with self._setup_storage() as (fake_s3_storage, fake_s3_path_factory, tmpdir): 206 | ds = S3Dataset( 207 | FAKE_NAME, 208 | FAKE_VERSION, 209 | FAKE_PARTITION, 210 | local_cache_path_prefix=tmpdir, 211 | columns_to_load=None, 212 | storage=fake_s3_storage, 213 | s3_path_factory=fake_s3_path_factory, 214 | pa_filesystem=pafs.LocalFileSystem(), 215 | filters=[("foo", "in", filtered_value_list)], 216 | ) 217 | self.assertEqual(len(ds), len(filtered_value_list)) 218 | retrieved_values_list = [ds[i]["foo"] for i in range(len(ds))] 219 | retrieved_values_list.sort() 220 | filtered_value_list.sort() 221 | self.assertListEqual(retrieved_values_list, filtered_value_list) 222 | 223 | def test_dataset_size(self) -> None: 224 | with self._setup_storage() as (fake_s3_storage, fake_s3_path_factory, tmpdir): 225 | # overwrite the mocked resource function using the fake storage 226 | class FakeResponse(NamedTuple): 227 | content_length: int 228 | 229 | # we do this to mock out using boto3, we use boto3 on the dataset 230 | # size because we can get just file metadata but we only mock 231 | # out the file storage pull usually to mock out boto3 we sub in 232 | # a replacement function that uses the fake storage 233 | class MockedS3Resource: 234 | def __init__(self) -> None: 235 | pass 236 | 237 | def Object(self, bucket: str, key: str) -> FakeResponse: 238 | full_path = os.path.join("s3://", bucket, key) 239 | data = fake_s3_storage.fetch_obj_s3(full_path) 240 | return FakeResponse(content_length=len(data)) 241 | 242 | def mock_resource_returner(_: 
Any): 243 | return MockedS3Resource() 244 | 245 | with patch("wicker.core.datasets.boto3.resource", mock_resource_returner): 246 | ds = S3Dataset( 247 | FAKE_NAME, 248 | FAKE_VERSION, 249 | FAKE_PARTITION, 250 | local_cache_path_prefix=tmpdir, 251 | columns_to_load=None, 252 | storage=fake_s3_storage, 253 | s3_path_factory=fake_s3_path_factory, 254 | pa_filesystem=pafs.LocalFileSystem(), 255 | ) 256 | 257 | # sub this in to get the local size of the parquet dir 258 | def _get_parquet_dir_size_mocked(): 259 | def get_parquet_size(path="."): 260 | total = 0 261 | with os.scandir(path) as it: 262 | for entry in it: 263 | if entry.is_file() and ".parquet" in entry.name: 264 | total += entry.stat().st_size 265 | elif entry.is_dir(): 266 | total += get_parquet_size(entry.path) 267 | return total 268 | 269 | return get_parquet_size(fake_s3_storage._tmpdir) 270 | 271 | ds._get_parquet_dir_size = _get_parquet_dir_size_mocked # type: ignore 272 | 273 | dataset_size = ds.dataset_size 274 | 275 | # get the expected size, all of the col files plus pyarrow table 276 | def get_dir_size(path="."): 277 | total = 0 278 | with os.scandir(path) as it: 279 | for entry in it: 280 | if entry.is_file() and ".json" not in entry.name: 281 | total += entry.stat().st_size 282 | elif entry.is_dir(): 283 | total += get_dir_size(entry.path) 284 | return total 285 | 286 | expected_bytes = get_dir_size(fake_s3_storage._tmpdir) 287 | assert expected_bytes == dataset_size 288 | -------------------------------------------------------------------------------- /tests/test_filelock.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import re 4 | import sys 5 | import tempfile 6 | import time 7 | import unittest 8 | from typing import Optional 9 | 10 | from wicker.core.filelock import SimpleUnixFileLock 11 | 12 | 13 | def write_index_to_file(filepath: str, index: int) -> None: 14 | # If we don't protect this writing operation with a lock, we will get interleaved 15 | # writing in the file instead of lines that have "foo-[index]" 16 | with SimpleUnixFileLock(f"{filepath}.lock"): 17 | for c in f"foo-{index}\n": 18 | with open(filepath, "a") as f: 19 | f.write(c) 20 | 21 | 22 | class TestFileLock(unittest.TestCase): 23 | def test_simple_acquire_no_race_conditions(self) -> None: 24 | with tempfile.NamedTemporaryFile() as lockfile: 25 | placeholder: Optional[str] = None 26 | with SimpleUnixFileLock(lockfile.name): 27 | placeholder = "foo" 28 | self.assertEqual(placeholder, "foo") 29 | 30 | def test_synchronized_concurrent_writes(self) -> None: 31 | with tempfile.NamedTemporaryFile("w+") as write_file: 32 | with multiprocessing.Pool(4) as pool: 33 | pool.starmap(write_index_to_file, [(write_file.name, i) for i in range(1000)]) 34 | for line in write_file.readlines(): 35 | self.assertTrue(re.match(r"foo-[0-9]+", line)) 36 | 37 | def test_timeout(self) -> None: 38 | with tempfile.NamedTemporaryFile() as lockfile: 39 | 40 | def run_acquire_and_sleep() -> None: 41 | with SimpleUnixFileLock(lockfile.name): 42 | time.sleep(2) 43 | 44 | proc = multiprocessing.Process(target=run_acquire_and_sleep) 45 | proc.start() 46 | 47 | # Give the other process some time to start up 48 | time.sleep(1) 49 | 50 | with self.assertRaises(TimeoutError): 51 | with SimpleUnixFileLock(lockfile.name, timeout_seconds=1): 52 | assert False, "Lock acquisition should time out" 53 | 54 | def test_process_dies(self) -> None: 55 | with tempfile.NamedTemporaryFile() as lockfile: 56 | 57 | def 
run_acquire_and_die() -> None: 58 | """Simulate a process dying on some exception while the lock is acquired""" 59 | # Mute itself to prevent too much noise in stdout/stderr 60 | with open(os.devnull, "w") as devnull: 61 | sys.stdout = devnull 62 | sys.stderr = devnull 63 | lock = SimpleUnixFileLock(lockfile.name) 64 | lock.__enter__() 65 | raise ValueError("die") 66 | 67 | # When the child process dies, the fd (and the exclusive lock) should be closed automatically 68 | proc = multiprocessing.Process(target=run_acquire_and_die) 69 | proc.start() 70 | 71 | # Give the other process some time to start up 72 | time.sleep(1) 73 | 74 | placeholder: Optional[str] = None 75 | with SimpleUnixFileLock(lockfile.name, timeout_seconds=1): 76 | placeholder = "foo" 77 | self.assertEqual(placeholder, "foo") 78 | -------------------------------------------------------------------------------- /tests/test_numpy_codec.py: -------------------------------------------------------------------------------- 1 | import io 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | from wicker.core.errors import WickerSchemaException 7 | from wicker.schema.schema import WickerNumpyCodec 8 | 9 | EYE_ARR = np.eye(4) 10 | eye_arr_bio = io.BytesIO() 11 | np.save(eye_arr_bio, EYE_ARR) 12 | EYE_ARR_BYTES = eye_arr_bio.getvalue() 13 | 14 | 15 | class TestNumpyCodec(unittest.TestCase): 16 | def test_codec_none_shape(self) -> None: 17 | codec = WickerNumpyCodec(shape=None, dtype="float64") 18 | self.assertEqual(codec.get_codec_name(), "wicker_numpy") 19 | self.assertEqual(codec.object_type(), np.ndarray) 20 | self.assertEqual( 21 | codec.save_codec_to_dict(), 22 | { 23 | "dtype": "float64", 24 | "shape": None, 25 | }, 26 | ) 27 | self.assertEqual(WickerNumpyCodec.load_codec_from_dict(codec.save_codec_to_dict()), codec) 28 | self.assertEqual(codec.validate_and_encode_object(EYE_ARR), EYE_ARR_BYTES) 29 | np.testing.assert_equal(codec.decode_object(EYE_ARR_BYTES), EYE_ARR) 30 | 31 | def test_codec_unbounded_dim_shape(self) -> None: 32 | codec = WickerNumpyCodec(shape=(-1, -1), dtype="float64") 33 | self.assertEqual(codec.get_codec_name(), "wicker_numpy") 34 | self.assertEqual(codec.object_type(), np.ndarray) 35 | self.assertEqual( 36 | codec.save_codec_to_dict(), 37 | { 38 | "dtype": "float64", 39 | "shape": [-1, -1], 40 | }, 41 | ) 42 | self.assertEqual(WickerNumpyCodec.load_codec_from_dict(codec.save_codec_to_dict()), codec) 43 | self.assertEqual(codec.validate_and_encode_object(EYE_ARR), EYE_ARR_BYTES) 44 | np.testing.assert_equal(codec.decode_object(EYE_ARR_BYTES), EYE_ARR) 45 | 46 | # Should raise when provided with bad shapes with too few/many dimensions 47 | with self.assertRaises(WickerSchemaException): 48 | codec.validate_and_encode_object(np.ones((10,))) 49 | with self.assertRaises(WickerSchemaException): 50 | codec.validate_and_encode_object(np.ones((10, 10, 10))) 51 | 52 | def test_codec_fixed_shape(self) -> None: 53 | codec = WickerNumpyCodec(shape=(4, 4), dtype="float64") 54 | self.assertEqual(codec.get_codec_name(), "wicker_numpy") 55 | self.assertEqual(codec.object_type(), np.ndarray) 56 | self.assertEqual( 57 | codec.save_codec_to_dict(), 58 | { 59 | "dtype": "float64", 60 | "shape": [4, 4], 61 | }, 62 | ) 63 | self.assertEqual(WickerNumpyCodec.load_codec_from_dict(codec.save_codec_to_dict()), codec) 64 | self.assertEqual(codec.validate_and_encode_object(EYE_ARR), EYE_ARR_BYTES) 65 | np.testing.assert_equal(codec.decode_object(EYE_ARR_BYTES), EYE_ARR) 66 | 67 | # Should raise when provided with bad shapes with too 
few/many/wrong dimensions 68 | with self.assertRaises(WickerSchemaException): 69 | codec.validate_and_encode_object(np.ones((10,))) 70 | with self.assertRaises(WickerSchemaException): 71 | codec.validate_and_encode_object(np.ones((10, 10, 10))) 72 | with self.assertRaises(WickerSchemaException): 73 | codec.validate_and_encode_object(np.ones((5, 4))) 74 | 75 | def test_codec_bad_dtype(self) -> None: 76 | with self.assertRaises(WickerSchemaException): 77 | WickerNumpyCodec(shape=(4, 4), dtype="SOME_BAD_DTYPE") 78 | -------------------------------------------------------------------------------- /tests/test_shuffle.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tempfile 3 | import unittest 4 | import unittest.mock 5 | from typing import IO, Dict, List, Tuple 6 | 7 | from wicker.core.column_files import ColumnBytesFileLocationV1 8 | from wicker.core.config import get_config 9 | from wicker.core.definitions import DatasetDefinition, DatasetID, DatasetPartition 10 | from wicker.core.shuffle import ShuffleJob, ShuffleJobFactory, ShuffleWorker 11 | from wicker.core.writer import MetadataDatabaseScanRow 12 | from wicker.schema import dataparsing, schema, serialization 13 | from wicker.testing.codecs import Vector, VectorCodec 14 | from wicker.testing.storage import FakeS3DataStorage 15 | 16 | DATASET_NAME = "dataset1" 17 | DATASET_VERSION = "0.0.1" 18 | FAKE_DATASET_SCHEMA = schema.DatasetSchema( 19 | fields=[ 20 | schema.IntField("timestamp"), 21 | schema.StringField("car_id"), 22 | schema.ObjectField("vector", VectorCodec(0)), 23 | ], 24 | primary_keys=["car_id", "timestamp"], 25 | ) 26 | FAKE_EXAMPLE = { 27 | "timestamp": 1, 28 | "car_id": "somecar", 29 | "vector": Vector([0, 0, 0]), 30 | } 31 | FAKE_DATASET_ID = DatasetID(name=DATASET_NAME, version=DATASET_VERSION) 32 | FAKE_DATASET_DEFINITION = DatasetDefinition( 33 | dataset_id=FAKE_DATASET_ID, 34 | schema=FAKE_DATASET_SCHEMA, 35 | ) 36 | 37 | 38 | class TestShuffleJobFactory(unittest.TestCase): 39 | def test_shuffle_job_factory_no_entries(self) -> None: 40 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend""" 41 | mock_rows: List[MetadataDatabaseScanRow] = [] 42 | mock_dataset_writer_backend = unittest.mock.MagicMock() 43 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows) 44 | job_factory = ShuffleJobFactory(mock_dataset_writer_backend) 45 | self.assertEqual(list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)), []) 46 | 47 | def test_shuffle_job_factory_one_entry(self) -> None: 48 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend""" 49 | mock_rows: List[MetadataDatabaseScanRow] = [ 50 | MetadataDatabaseScanRow(partition="train", row_data_path="somepath", row_size=1337) 51 | ] 52 | mock_dataset_writer_backend = unittest.mock.MagicMock() 53 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows) 54 | job_factory = ShuffleJobFactory(mock_dataset_writer_backend) 55 | self.assertEqual( 56 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)), 57 | [ 58 | ShuffleJob( 59 | dataset_partition=DatasetPartition( 60 | dataset_id=FAKE_DATASET_ID, 61 | partition="train", 62 | ), 63 | files=[("somepath", 1337)], 64 | ) 65 | ], 66 | ) 67 | 68 | def test_shuffle_job_factory_one_partition(self) -> None: 69 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend""" 70 | partition = "train" 71 | 
mock_rows: List[MetadataDatabaseScanRow] = [ 72 | MetadataDatabaseScanRow(partition=partition, row_data_path=f"somepath{i}", row_size=i) for i in range(10) 73 | ] 74 | mock_dataset_writer_backend = unittest.mock.MagicMock() 75 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows) 76 | job_factory = ShuffleJobFactory(mock_dataset_writer_backend) 77 | self.assertEqual( 78 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)), 79 | [ 80 | ShuffleJob( 81 | dataset_partition=DatasetPartition( 82 | dataset_id=FAKE_DATASET_ID, 83 | partition=partition, 84 | ), 85 | files=[(row.row_data_path, row.row_size) for row in mock_rows], 86 | ) 87 | ], 88 | ) 89 | 90 | def test_shuffle_job_factory_one_partition_two_working_sets(self) -> None: 91 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend""" 92 | partition = "train" 93 | worker_max_working_set_size = 5 94 | mock_rows: List[MetadataDatabaseScanRow] = [ 95 | MetadataDatabaseScanRow(partition=partition, row_data_path=f"somepath{i}", row_size=i) for i in range(10) 96 | ] 97 | mock_dataset_writer_backend = unittest.mock.MagicMock() 98 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows) 99 | job_factory = ShuffleJobFactory( 100 | mock_dataset_writer_backend, worker_max_working_set_size=worker_max_working_set_size 101 | ) 102 | self.assertEqual( 103 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)), 104 | [ 105 | ShuffleJob( 106 | dataset_partition=DatasetPartition( 107 | dataset_id=FAKE_DATASET_ID, 108 | partition=partition, 109 | ), 110 | files=[(row.row_data_path, row.row_size) for row in mock_rows[:worker_max_working_set_size]], 111 | ), 112 | ShuffleJob( 113 | dataset_partition=DatasetPartition( 114 | dataset_id=FAKE_DATASET_ID, 115 | partition=partition, 116 | ), 117 | files=[(row.row_data_path, row.row_size) for row in mock_rows[worker_max_working_set_size:]], 118 | ), 119 | ], 120 | ) 121 | 122 | def test_shuffle_job_factory_multi_partition_multi_working_sets(self) -> None: 123 | """Tests the functionality of a shuffle_job_factory when provided with a mock backend""" 124 | partitions = ("train", "test") 125 | num_rows = 10 126 | num_batches = 2 127 | worker_max_working_set_size = num_rows // num_batches 128 | mock_rows: List[MetadataDatabaseScanRow] = [ 129 | MetadataDatabaseScanRow(partition=partition, row_data_path=f"somepath{i}", row_size=i) 130 | for partition in partitions 131 | for i in range(10) 132 | ] 133 | mock_dataset_writer_backend = unittest.mock.MagicMock() 134 | mock_dataset_writer_backend._metadata_db.scan_sorted.return_value = (row for row in mock_rows) 135 | job_factory = ShuffleJobFactory( 136 | mock_dataset_writer_backend, worker_max_working_set_size=worker_max_working_set_size 137 | ) 138 | self.assertEqual( 139 | list(job_factory.build_shuffle_jobs(FAKE_DATASET_ID)), 140 | [ 141 | ShuffleJob( 142 | dataset_partition=DatasetPartition( 143 | dataset_id=FAKE_DATASET_ID, 144 | partition=partition, 145 | ), 146 | files=[ 147 | (row.row_data_path, row.row_size) 148 | for row in mock_rows[ 149 | worker_max_working_set_size 150 | * (batch + (partition_index * num_batches)) : worker_max_working_set_size 151 | * (batch + (partition_index * num_batches) + 1) 152 | ] 153 | ], 154 | ) 155 | for partition_index, partition in enumerate(partitions) 156 | for batch in range(2) 157 | ], 158 | ) 159 | 160 | 161 | class TestShuffleWorker(unittest.TestCase): 162 | def setUp(self) -> None: 163 | self.boto3_patcher = 
unittest.mock.patch("wicker.core.shuffle.boto3") 164 | self.boto3_mock = self.boto3_patcher.start() 165 | self.uploaded_column_bytes_files: Dict[Tuple[str, str], bytes] = {} 166 | 167 | def tearDown(self) -> None: 168 | self.boto3_patcher.stop() 169 | self.uploaded_column_bytes_files.clear() 170 | 171 | @staticmethod 172 | def download_fileobj_mock(bucket: str, key: str, bio: IO) -> None: 173 | bio.write( 174 | pickle.dumps( 175 | dataparsing.parse_example( 176 | FAKE_EXAMPLE, 177 | FAKE_DATASET_SCHEMA, 178 | ) 179 | ) 180 | ) 181 | bio.seek(0) 182 | return None 183 | 184 | def test_process_job(self) -> None: 185 | with tempfile.TemporaryDirectory() as tmpdir: 186 | # The threaded workers each construct their own S3DataStorage from a boto3 client to download 187 | # file data in parallel. We mock those out here by mocking out the boto3 client itself. 188 | self.boto3_mock.session.Session().client().download_fileobj.side_effect = self.download_fileobj_mock 189 | 190 | fake_job = ShuffleJob( 191 | dataset_partition=DatasetPartition( 192 | dataset_id=FAKE_DATASET_ID, 193 | partition="test", 194 | ), 195 | files=[(f"s3://somebucket/path/{i}", i) for i in range(10)], 196 | ) 197 | 198 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir) 199 | fake_storage.put_object_s3( 200 | serialization.dumps(FAKE_DATASET_SCHEMA).encode("utf-8"), 201 | f"{get_config().aws_s3_config.s3_datasets_path}/{FAKE_DATASET_ID.name}" 202 | f"/{FAKE_DATASET_ID.version}/avro_schema.json", 203 | ) 204 | worker = ShuffleWorker(storage=fake_storage) 205 | shuffle_results = worker.process_job(fake_job) 206 | 207 | self.assertEqual(shuffle_results["timestamp"].to_pylist(), [FAKE_EXAMPLE["timestamp"] for _ in range(10)]) 208 | self.assertEqual(shuffle_results["car_id"].to_pylist(), [FAKE_EXAMPLE["car_id"] for _ in range(10)]) 209 | for location_bytes in shuffle_results["vector"].to_pylist(): 210 | location = ColumnBytesFileLocationV1.from_bytes(location_bytes) 211 | path = worker.s3_path_factory.get_column_concatenated_bytes_s3path_from_uuid(location.file_id.bytes) 212 | self.assertTrue(fake_storage.check_exists_s3(path)) 213 | data = fake_storage.fetch_obj_s3(path)[location.byte_offset : location.byte_offset + location.data_size] 214 | self.assertEqual(VectorCodec(0).decode_object(data), FAKE_EXAMPLE["vector"]) 215 | -------------------------------------------------------------------------------- /tests/test_spark.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import tempfile 5 | import unittest 6 | import uuid 7 | 8 | import pyarrow.parquet as papq 9 | from pyspark.sql import SparkSession 10 | 11 | from wicker import schema 12 | from wicker.core.config import get_config 13 | from wicker.core.errors import WickerDatastoreException 14 | from wicker.plugins.spark import persist_wicker_dataset 15 | from wicker.testing.storage import FakeS3DataStorage 16 | 17 | DATASET_NAME = "dataset" 18 | DATASET_VERSION = "0.0.1" 19 | TEST_ROWS_NUM = 10000 20 | SCHEMA = schema.DatasetSchema( 21 | primary_keys=["bar", "foo"], 22 | fields=[ 23 | schema.IntField("foo"), 24 | schema.StringField("bar"), 25 | schema.BytesField("bytescol"), 26 | ], 27 | ) 28 | EXAMPLES = [ 29 | ( 30 | "train" if i % 2 == 0 else "test", 31 | { 32 | "foo": random.randint(0, TEST_ROWS_NUM), 33 | "bar": str(uuid.uuid4()), 34 | "bytescol": b"0", 35 | }, 36 | ) 37 | for i in range(TEST_ROWS_NUM) 38 | ] 39 | # Examples with a duplicated key 40 | EXAMPLES_DUPES = copy.deepcopy(EXAMPLES) 41 
| EXAMPLES_DUPES[5000] = EXAMPLES_DUPES[0] 42 | 43 | 44 | class LocalWritingTestCase(unittest.TestCase): 45 | def assert_written_correctness(self, tmpdir: str, row_num: int = TEST_ROWS_NUM) -> None: 46 | """Asserts that all files are written as expected by the L5MLDatastore""" 47 | # Check that files are correctly written locally by Spark/Parquet with a _SUCCESS marker file 48 | prefix = get_config().aws_s3_config.s3_datasets_path.replace("s3://", "") 49 | self.assertIn(DATASET_NAME, os.listdir(os.path.join(tmpdir, prefix))) 50 | self.assertIn(DATASET_VERSION, os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME))) 51 | for partition in ["train", "test"]: 52 | columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__") 53 | all_read_bytes = b"" 54 | for filename in os.listdir(columns_path): 55 | concatenated_bytes_filepath = os.path.join(columns_path, filename) 56 | with open(concatenated_bytes_filepath, "rb") as bytescol_file: 57 | all_read_bytes += bytescol_file.read() 58 | self.assertEqual(all_read_bytes, b"0" * row_num) 59 | 60 | # Load parquet file and assert ordering of primary_key 61 | self.assertIn( 62 | f"{partition}.parquet", os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION)) 63 | ) 64 | tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet")) 65 | foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])] 66 | self.assertEqual(foobar, sorted(foobar)) 67 | 68 | def test_simple_schema_local_writing(self) -> None: 69 | for local_reduction in (True, False): 70 | for sort in (True, False): 71 | with tempfile.TemporaryDirectory() as tmpdir: 72 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir) 73 | spark_session = SparkSession.builder.appName("test").master("local[*]") 74 | spark = spark_session.getOrCreate() 75 | sc = spark.sparkContext 76 | rdd = sc.parallelize(copy.deepcopy(EXAMPLES), 100) 77 | persist_wicker_dataset( 78 | DATASET_NAME, 79 | DATASET_VERSION, 80 | SCHEMA, 81 | rdd, 82 | fake_storage, 83 | local_reduction=local_reduction, 84 | sort=sort, 85 | ) 86 | self.assert_written_correctness(tmpdir) 87 | 88 | def test_dupe_primary_keys_raises_exception(self) -> None: 89 | for local_reduction in (True, False): 90 | for sort in (True, False): 91 | with self.assertRaises(WickerDatastoreException) as e: 92 | with tempfile.TemporaryDirectory() as tmpdir: 93 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir) 94 | spark_session = SparkSession.builder.appName("test").master("local[*]") 95 | spark = spark_session.getOrCreate() 96 | sc = spark.sparkContext 97 | rdd = sc.parallelize(copy.deepcopy(EXAMPLES_DUPES), 100) 98 | persist_wicker_dataset( 99 | DATASET_NAME, 100 | DATASET_VERSION, 101 | SCHEMA, 102 | rdd, 103 | fake_storage, 104 | local_reduction=local_reduction, 105 | sort=sort, 106 | ) 107 | 108 | self.assertIn( 109 | "Error: dataset examples do not have unique primary key tuples", 110 | str(e.exception), 111 | ) 112 | 113 | def test_simple_schema_local_writing_4_row_dataset(self) -> None: 114 | for local_reduction in (True, False): 115 | for sort in (True, False): 116 | with tempfile.TemporaryDirectory() as tmpdir: 117 | small_row_cnt = 4 118 | fake_storage = FakeS3DataStorage(tmpdir=tmpdir) 119 | spark_session = SparkSession.builder.appName("test").master("local[*]") 120 | spark = spark_session.getOrCreate() 121 | sc = spark.sparkContext 122 | rdd = sc.parallelize(copy.deepcopy(EXAMPLES)[:small_row_cnt], 1) 123 | persist_wicker_dataset( 124 | DATASET_NAME, 
125 | DATASET_VERSION, 126 | SCHEMA, 127 | rdd, 128 | fake_storage, 129 | local_reduction=local_reduction, 130 | sort=sort, 131 | ) 132 | self.assert_written_correctness(tmpdir, small_row_cnt) 133 | -------------------------------------------------------------------------------- /tests/test_storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import string 4 | import tempfile 5 | from typing import Any, Dict 6 | from unittest import TestCase, mock 7 | 8 | from botocore.exceptions import ClientError # type: ignore 9 | from botocore.stub import Stubber # type: ignore 10 | 11 | from wicker.core.config import WickerConfig 12 | from wicker.core.storage import FileSystemDataStorage, S3DataStorage, S3PathFactory 13 | 14 | RANDOM_SEED_VALUE = 1 15 | RANDOM_STRING_CHAR_COUNT = 10 16 | 17 | 18 | class TestFileSystemDataStorage(TestCase): 19 | def test_fetch_file(self) -> None: 20 | """Unit test for fetching file from local/mounted file system to different location""" 21 | # put file in the directory that you're using for test 22 | with tempfile.TemporaryDirectory() as temp_dir: 23 | src_dir = os.path.join(temp_dir, "test", "location", "starting", "mount") 24 | os.makedirs(src_dir, exist_ok=True) 25 | src_path = os.path.join(src_dir, "test.txt") 26 | dst_dir = os.path.join(temp_dir, "desired", "location", "for", "test") 27 | os.makedirs(dst_dir, exist_ok=True) 28 | dst_path = os.path.join(dst_dir, "test.txt") 29 | 30 | random.seed(RANDOM_SEED_VALUE) 31 | expected_string = "".join( 32 | random.choices(string.ascii_uppercase + string.digits, k=RANDOM_STRING_CHAR_COUNT) 33 | ) 34 | with open(src_path, "w") as open_src: 35 | open_src.write(expected_string) 36 | 37 | # create local file store 38 | local_datastore = FileSystemDataStorage() 39 | # save file to destination 40 | local_datastore.fetch_file(src_path, dst_dir) 41 | 42 | # verify file exists 43 | assert os.path.exists(dst_path) 44 | 45 | # assert contents are the expected 46 | with open(dst_path, "r") as open_dst_file: 47 | test_string = open_dst_file.readline() 48 | assert test_string == expected_string 49 | 50 | 51 | class TestS3DataStorage(TestCase): 52 | def test_bucket_key_from_s3_path(self) -> None: 53 | """Unit test for the S3DataStorage bucket_key_from_s3_path function""" 54 | data_storage = S3DataStorage() 55 | 56 | s3_url = "s3://hello/world" 57 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url) 58 | self.assertEqual(bucket, "hello") 59 | self.assertEqual(key, "world") 60 | 61 | s3_url = "s3://hello/" 62 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url) 63 | self.assertEqual(bucket, "hello") 64 | self.assertEqual(key, "") 65 | 66 | s3_url = "s3://" 67 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url) 68 | self.assertEqual(bucket, "") 69 | self.assertEqual(key, "") 70 | 71 | s3_url = "s3://hello/world/" 72 | bucket, key = data_storage.bucket_key_from_s3_path(s3_url) 73 | self.assertEqual(bucket, "hello") 74 | self.assertEqual(key, "world/") 75 | 76 | def test_check_exists_s3(self) -> None: 77 | """Unit test for the check_exists_s3 function.""" 78 | data_storage = S3DataStorage() 79 | input_path = "s3://foo/bar/baz/dummy" 80 | 81 | with Stubber(data_storage.client) as stubber: 82 | response = {} # type: ignore 83 | expected_params = {"Bucket": "foo", "Key": "bar/baz/dummy"} 84 | stubber.add_response("head_object", response, expected_params) 85 | self.assertTrue(data_storage.check_exists_s3(input_path)) 86 | 87 | def 
test_check_exists_s3_nonexisting(self) -> None: 88 | """Unit test for the check_exists_s3 function.""" 89 | data_storage = S3DataStorage() 90 | input_path = "s3://foo/bar/baz/dummy" 91 | 92 | with Stubber(data_storage.client) as stubber: 93 | stubber.add_client_error( 94 | expected_params={"Bucket": "foo", "Key": "bar/baz/dummy"}, 95 | method="head_object", 96 | service_error_code="404", 97 | service_message="The specified key does not exist.", 98 | ) 99 | 100 | # The check_exists_s3 function catches the exception when the key does not exist 101 | self.assertFalse(data_storage.check_exists_s3(input_path)) 102 | 103 | def test_put_object_s3(self) -> None: 104 | """Unit test for the put_object_s3 function.""" 105 | data_storage = S3DataStorage() 106 | object_bytes = b"this is my object" 107 | input_path = "s3://foo/bar/baz/dummy" 108 | 109 | with Stubber(data_storage.client) as stubber: 110 | response = {} # type: ignore 111 | expected_params = { 112 | "Body": object_bytes, 113 | "Bucket": "foo", 114 | "Key": "bar/baz/dummy", 115 | } 116 | stubber.add_response("put_object", response, expected_params) 117 | data_storage.put_object_s3(object_bytes, input_path) 118 | 119 | def test_put_file_s3(self) -> None: 120 | """Unit test for the put_file_s3 function""" 121 | data_storage = S3DataStorage() 122 | object_bytes = b"this is my object" 123 | input_path = "s3://foo/bar/baz/dummy" 124 | 125 | with tempfile.NamedTemporaryFile() as tmpfile: 126 | tmpfile.write(object_bytes) 127 | tmpfile.flush() 128 | 129 | with Stubber(data_storage.client) as stubber: 130 | response = {} # type: ignore 131 | stubber.add_response("put_object", response, None) 132 | data_storage.put_file_s3(tmpfile.name, input_path) 133 | 134 | @staticmethod 135 | def download_file_side_effect(*args, **kwargs) -> None: # type: ignore 136 | """Helper function to patch the S3 download_file function with a side-effect that creates an 137 | empty file at the correct path in order to mock the download""" 138 | input_path = str(kwargs["filename"]) 139 | with open(input_path, "w"): 140 | pass 141 | 142 | # Stubber does not have a stub function for S3 client download_file function, so patch it 143 | @mock.patch("boto3.s3.transfer.S3Transfer.download_file") 144 | def test_fetch_file(self, download_file: mock.Mock) -> None: 145 | """Unit test for the fetch_file function.""" 146 | data_storage = S3DataStorage() 147 | input_path = "s3://foo/bar/baz/dummy" 148 | with tempfile.TemporaryDirectory() as local_prefix: 149 | # Add a side-effect to create the file to download at the correct local path 150 | download_file.side_effect = self.download_file_side_effect 151 | 152 | local_path = data_storage.fetch_file(input_path, local_prefix) 153 | download_file.assert_called_once_with( 154 | bucket="foo", 155 | key="bar/baz/dummy", 156 | filename=f"{local_prefix}/bar/baz/dummy", 157 | extra_args=None, 158 | callback=None, 159 | ) 160 | self.assertTrue(os.path.isfile(local_path)) 161 | 162 | # Stubber does not have a stub function for S3 client download_file function, so patch it 163 | @mock.patch("boto3.s3.transfer.S3Transfer.download_file") 164 | def test_fetch_file_s3_on_nonexistent_file(self, download_file: mock.Mock) -> None: 165 | """Unit test for the fetch_file function for a non-existent file in S3.""" 166 | data_storage = S3DataStorage() 167 | input_path = "s3://foo/bar/barbazz/dummy" 168 | local_prefix = "/tmp" 169 | 170 | response = {"Error": {"Code": "404"}} 171 | side_effect = ClientError(response, "unexpected") 172 | download_file.side_effect 
= side_effect 173 | 174 | with self.assertRaises(ClientError): 175 | data_storage.fetch_file(input_path, local_prefix) 176 | 177 | 178 | class TestS3PathFactory(TestCase): 179 | @mock.patch("wicker.core.storage.get_config") 180 | def test_get_column_concatenated_bytes_files_path(self, mock_get_config: mock.Mock) -> None: 181 | """Unit test for the S3PathFactory get_column_concatenated_bytes_files_path 182 | function""" 183 | # If store_concatenated_bytes_files_in_dataset is False, return the default path 184 | dummy_config: Dict[str, Any] = { 185 | "aws_s3_config": { 186 | "s3_datasets_path": "s3://dummy_bucket/wicker/", 187 | "region": "us-east-1", 188 | "boto_config": {"max_pool_connections": 10, "read_timeout_s": 140, "connect_timeout_s": 140}, 189 | }, 190 | "dynamodb_config": {"table_name": "fake-table-name", "region": "us-west-2"}, 191 | "storage_download_config": { 192 | "retries": 2, 193 | "timeout": 150, 194 | "retry_backoff": 5, 195 | "retry_delay_s": 4, 196 | }, 197 | } 198 | mock_get_config.return_value = WickerConfig.from_json(dummy_config) 199 | 200 | path_factory = S3PathFactory() 201 | self.assertEqual( 202 | path_factory.get_column_concatenated_bytes_files_path(), 203 | "s3://dummy_bucket/wicker/__COLUMN_CONCATENATED_FILES__", 204 | ) 205 | 206 | # If store_concatenated_bytes_files_in_dataset is True, return the dataset-specific path 207 | dummy_config["aws_s3_config"]["store_concatenated_bytes_files_in_dataset"] = True 208 | mock_get_config.return_value = WickerConfig.from_json(dummy_config) 209 | dataset_name = "dummy_dataset" 210 | self.assertEqual( 211 | S3PathFactory().get_column_concatenated_bytes_files_path(dataset_name=dataset_name), 212 | f"s3://dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__", 213 | ) 214 | 215 | # If the store_concatenated_bytes_files_in_dataset option is True but no 216 | # dataset_name, raise ValueError 217 | with self.assertRaises(ValueError): 218 | S3PathFactory().get_column_concatenated_bytes_files_path() 219 | 220 | # Test the remove s3 prefix option in the get_column_concatenated_bytes_files_path function 221 | self.assertEqual( 222 | S3PathFactory().get_column_concatenated_bytes_files_path(dataset_name=dataset_name, s3_prefix=False), 223 | f"dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__", 224 | ) 225 | 226 | # Test when the s3 prefix remove bool is not passed the prefix isn't eliminated. 
227 | self.assertEqual( 228 | S3PathFactory(prefix_replace_path="/test_mount_path").get_column_concatenated_bytes_files_path( 229 | dataset_name=dataset_name, s3_prefix=True 230 | ), 231 | f"s3://dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__", 232 | ) 233 | 234 | self.assertEqual( 235 | S3PathFactory(prefix_replace_path="/test_mount_path/").get_column_concatenated_bytes_files_path( 236 | dataset_name=dataset_name, s3_prefix=False 237 | ), 238 | f"/test_mount_path/dummy_bucket/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__", 239 | ) 240 | 241 | self.assertEqual( 242 | S3PathFactory(prefix_replace_path="/test_mount_path/").get_column_concatenated_bytes_files_path( 243 | dataset_name=dataset_name, s3_prefix=False, cut_prefix_override="s3://dummy_bucket/" 244 | ), 245 | f"/test_mount_path/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__", 246 | ) 247 | 248 | self.assertEqual( 249 | S3PathFactory(prefix_replace_path="/test_mount_path/").get_column_concatenated_bytes_files_path( 250 | dataset_name=dataset_name, s3_prefix=True, cut_prefix_override="s3://dummy_bucket/" 251 | ), 252 | f"/test_mount_path/wicker/{dataset_name}/__COLUMN_CONCATENATED_FILES__", 253 | ) 254 | -------------------------------------------------------------------------------- /tests/test_wandb.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from unittest.mock import call, patch 4 | 5 | import pytest 6 | 7 | from wicker.plugins.wandb import ( 8 | _identify_s3_url_for_dataset_version, 9 | _set_wandb_credentials, 10 | version_dataset, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def temp_config(request, tmpdir): 16 | # parse the config json into a proper config 17 | config_json = request.param.get("config_json") 18 | parsable_config = { 19 | "wandb_config": config_json.get("wandb_config", {}), 20 | "aws_s3_config": { 21 | "s3_datasets_path": config_json.get("aws_s3_config", {}).get("s3_datasets_path", None), 22 | "region": config_json.get("aws_s3_config", {}).get("region", None), 23 | }, 24 | } 25 | 26 | # set the config path env and then get the creds from the function 27 | temp_config = tmpdir / "temp_config.json" 28 | with open(temp_config, "w") as file_stream: 29 | json.dump(parsable_config, file_stream) 30 | os.environ["WICKER_CONFIG_PATH"] = str(temp_config) 31 | return temp_config, parsable_config 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "temp_config, dataset_name, dataset_version, dataset_metadata, entity", 36 | [ 37 | ( 38 | { 39 | "config_json": { 40 | "aws_s3_config": { 41 | "s3_datasets_path": "s3://test_path/to_nowhere/", 42 | "region": "us-west-2", 43 | }, 44 | "wandb_config": { 45 | "wandb_api_key": "test_key", 46 | "wandb_base_url": "test_url", 47 | }, 48 | } 49 | }, 50 | "test_data", 51 | "0.0.1", 52 | {"test_key": "test_value"}, 53 | "test_entity", 54 | ) 55 | ], 56 | indirect=["temp_config"], 57 | ) 58 | def test_version_dataset(temp_config, dataset_name, dataset_version, dataset_metadata, entity): 59 | """ 60 | GIVEN: A mocked dataset folder to track, a dataset version to track, metadata to track, and a backend to use 61 | WHEN: This dataset is registered or versioned on W&B 62 | THEN: The dataset shows up as a run on the portal 63 | """ 64 | # need to mock out all the wandb calls and test just the inputs to them 65 | _, config = temp_config 66 | with patch("wicker.plugins.wandb.wandb.init") as patched_wandb_init: 67 | with patch("wicker.plugins.wandb.wandb.Artifact") as patched_artifact: 68 | # version the 
dataset with the patched functions/classes 69 | version_dataset(dataset_name, dataset_version, entity, dataset_metadata) 70 | 71 | # establish the expected calls 72 | expected_artifact_calls = [ 73 | call(f"{dataset_name}_{dataset_version}", type="dataset"), 74 | call().add_reference( 75 | f"{config['aws_s3_config']['s3_datasets_path']}{dataset_name}/{dataset_version}/assets", 76 | name="dataset", 77 | ), 78 | call().metadata.__setitem__("version", dataset_version), 79 | call().metadata.__setitem__( 80 | "s3_uri", f"{config['aws_s3_config']['s3_datasets_path']}{dataset_name}/{dataset_version}/assets" 81 | ), 82 | ] 83 | for key, value in dataset_metadata.items(): 84 | expected_artifact_calls.append(call().metadata.__setitem__(key, value)) 85 | 86 | expected_run_calls = [ 87 | call(project="dataset_curation", name=f"{dataset_name}_{dataset_version}", entity=entity), 88 | call().log_artifact(patched_artifact()), 89 | ] 90 | 91 | # assert that these are properly called 92 | patched_artifact.assert_has_calls(expected_artifact_calls, any_order=True) 93 | 94 | patched_wandb_init.assert_has_calls(expected_run_calls, any_order=True) 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "credentials_to_load, temp_config", 99 | [ 100 | ( 101 | {"WANDB_BASE_URL": "env_base", "WANDB_API_KEY": "env_key", "WANDB_USER_EMAIL": "env_email"}, 102 | { 103 | "config_json": { 104 | "wandb_config": { 105 | "wandb_base_url": "config_base", 106 | "wandb_api_key": "config_key", 107 | } 108 | } 109 | }, 110 | ) 111 | ], 112 | ids=["basic test to override all params"], 113 | indirect=["temp_config"], 114 | ) 115 | def test_set_wandb_credentials(credentials_to_load, temp_config, tmpdir): 116 | """ 117 | GIVEN: A set of credentials as existing envs and a config json specifying creds 118 | WHEN: The configs are requested for wandb 119 | THEN: The proper env variables should be set in the environment based on rules, default to 120 | wicker config and reject preset env variables 121 | """ 122 | temp_config_pth, config_json = temp_config 123 | 124 | # load the creds into the env as the base comparison 125 | for key, value in credentials_to_load.items(): 126 | os.environ[key] = value 127 | 128 | # compare the creds for expected results 129 | _set_wandb_credentials() 130 | assert os.environ["WANDB_BASE_URL"] == config_json["wandb_config"]["wandb_base_url"] 131 | assert os.environ["WANDB_API_KEY"] == config_json["wandb_config"]["wandb_api_key"] 132 | 133 | 134 | @pytest.mark.parametrize( 135 | "temp_config, dataset_name, dataset_version, dataset_backend, expected_url", 136 | [ 137 | ( 138 | {"config_json": {"aws_s3_config": {"s3_datasets_path": "s3://test_path_to_nowhere", "region": "test"}}}, 139 | "test_dataset", 140 | "0.0.0", 141 | "s3", 142 | "s3://test_path_to_nowhere/test_dataset/0.0.0/assets", 143 | ) 144 | ], 145 | ids=["Basic test with s3"], 146 | indirect=["temp_config"], 147 | ) 148 | def test_identify_s3_url_for_dataset_version(temp_config, dataset_name, dataset_version, dataset_backend, expected_url): 149 | """ 150 | GIVEN: A temporary config, a dataset name, version, backend, and expected url 151 | WHEN: The assets url is pulled from the path factory with these parameters 152 | THEN: The expected url should match what is returned from the function 153 | """ 154 | parsed_dataset_url = _identify_s3_url_for_dataset_version(dataset_name, dataset_version, dataset_backend) 155 | assert parsed_dataset_url == expected_url 156 | -------------------------------------------------------------------------------- /tox.ini: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | exclude=docs 4 | -------------------------------------------------------------------------------- /wicker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/__init__.py -------------------------------------------------------------------------------- /wicker/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/core/__init__.py -------------------------------------------------------------------------------- /wicker/core/abstract.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/core/abstract.py -------------------------------------------------------------------------------- /wicker/core/config.py: -------------------------------------------------------------------------------- 1 | """This module defines how to configure Wicker from the user environment 2 | """ 3 | 4 | from __future__ import annotations 5 | 6 | import dataclasses 7 | import json 8 | import os 9 | from typing import Any, Dict, List 10 | 11 | AWS_S3_CONFIG = "aws_s3_config" 12 | FILESYSTEM_CONFIG = "filesystem_config" 13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class WickerWandBConfig: 17 | wandb_base_url: str 18 | wandb_api_key: str 19 | 20 | @classmethod 21 | def from_json(cls, data: Dict[str, Any]) -> WickerWandBConfig: 22 | return cls( 23 | wandb_api_key=data.get("wandb_api_key", None), 24 | wandb_base_url=data.get("wandb_base_url", None), 25 | ) 26 | 27 | 28 | @dataclasses.dataclass(frozen=True) 29 | class BotoS3Config: 30 | max_pool_connections: int 31 | read_timeout_s: int 32 | connect_timeout_s: int 33 | 34 | @classmethod 35 | def from_json(cls, data: Dict[str, Any]) -> BotoS3Config: 36 | return cls( 37 | max_pool_connections=data.get("max_pool_connections", 0), 38 | read_timeout_s=data.get("read_timeout_s", 0), 39 | connect_timeout_s=data.get("connect_timeout_s", 0), 40 | ) 41 | 42 | 43 | @dataclasses.dataclass(frozen=True) 44 | class WickerAwsS3Config: 45 | s3_datasets_path: str 46 | region: str 47 | boto_config: BotoS3Config 48 | store_concatenated_bytes_files_in_dataset: bool = False 49 | 50 | @classmethod 51 | def from_json(cls, data: Dict[str, Any]) -> WickerAwsS3Config: 52 | return cls( 53 | s3_datasets_path=data.get("s3_datasets_path", ""), 54 | region=data.get("region", ""), 55 | boto_config=BotoS3Config.from_json(data.get("boto_config", {})), 56 | store_concatenated_bytes_files_in_dataset=data.get("store_concatenated_bytes_files_in_dataset", False), 57 | ) 58 | 59 | 60 | @dataclasses.dataclass(frozen=True) 61 | class WickerFileSystemConfig: 62 | config_name: str 63 | prefix_replace_path: str 64 | root_datasets_path: str 65 | 66 | @classmethod 67 | def from_json(cls, data: Dict[str, Any]) -> WickerFileSystemConfig: 68 | return cls( 69 | config_name=data.get("config_name", ""), 70 | prefix_replace_path=data.get("prefix_replace_path", ""), 71 | root_datasets_path=data.get("root_datasets_path", ""), 72 | ) 73 | 74 | @classmethod 75 | def from_json_list(cls, data: List[Dict[str, Any]]) -> List[WickerFileSystemConfig]: 76 | return 
[WickerFileSystemConfig.from_json(d) for d in data] 77 | 78 | 79 | @dataclasses.dataclass(frozen=True) 80 | class StorageDownloadConfig: 81 | retries: int 82 | timeout: int 83 | retry_backoff: int 84 | retry_delay_s: int 85 | 86 | @classmethod 87 | def from_json(cls, data: Dict[str, Any]) -> StorageDownloadConfig: 88 | return cls( 89 | retries=data.get("retries", 0), 90 | timeout=data.get("timeout", 0), 91 | retry_backoff=data.get("retry_backoff", 0), 92 | retry_delay_s=data.get("retry_delay_s", 0), 93 | ) 94 | 95 | 96 | @dataclasses.dataclass() 97 | class WickerConfig: 98 | raw: Dict[str, Any] 99 | aws_s3_config: WickerAwsS3Config 100 | filesystem_configs: List[WickerFileSystemConfig] 101 | storage_download_config: StorageDownloadConfig 102 | wandb_config: WickerWandBConfig 103 | 104 | @classmethod 105 | def from_json(cls, data: Dict[str, Any]) -> WickerConfig: 106 | return cls( 107 | raw=data, 108 | aws_s3_config=WickerAwsS3Config.from_json(data.get(AWS_S3_CONFIG, {})), 109 | filesystem_configs=WickerFileSystemConfig.from_json_list(data.get("filesystem_configs", [])), 110 | storage_download_config=StorageDownloadConfig.from_json(data.get("storage_download_config", {})), 111 | wandb_config=WickerWandBConfig.from_json(data.get("wandb_config", {})), 112 | ) 113 | 114 | 115 | def get_config() -> WickerConfig: 116 | """Retrieves the Wicker config for the current process""" 117 | 118 | wicker_config_path = os.getenv("WICKER_CONFIG_PATH", os.path.expanduser("~/wickerconfig.json")) 119 | with open(wicker_config_path, "r") as f: 120 | config = WickerConfig.from_json(json.load(f)) 121 | return config 122 | -------------------------------------------------------------------------------- /wicker/core/definitions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import enum 5 | import re 6 | from typing import Any, Dict, TypeVar 7 | 8 | # flake8 linting struggles with the module name `schema` and DatasetDefinition dataclass field name `schema` 9 | from wicker import schema # noqa: 401 10 | 11 | Example = TypeVar("Example") 12 | ExampleMetadata = Dict[str, Any] 13 | 14 | 15 | DATASET_ID_REGEX = re.compile(r"^(?P[0-9a-zA-Z_]+)/(?P[0-9]+\.[0-9]+\.[0-9]+)$") 16 | 17 | 18 | @dataclasses.dataclass(frozen=True) 19 | class DatasetID: 20 | """Representation of the unique identifier of a dataset 21 | 22 | `name` should be alphanumeric and contain no spaces, only underscores 23 | `version` should be a semantic version (e.g. 
1.0.0) 24 | """ 25 | 26 | name: str 27 | version: str 28 | 29 | @classmethod 30 | def from_str(cls, s: str) -> DatasetID: 31 | """Parses a DatasetID from a string""" 32 | match = DATASET_ID_REGEX.match(s) 33 | if not match: 34 | raise ValueError(f"{s} is not a valid DatasetID") 35 | return cls(name=match["name"], version=match["version"]) 36 | 37 | def __str__(self) -> str: 38 | """Helper function to return the representation of the dataset version as a path-like string""" 39 | return f"{self.name}/{self.version}" 40 | 41 | @staticmethod 42 | def validate_dataset_id(name: str, version: str) -> None: 43 | """Validates the name and version of a dataset""" 44 | if not re.match(r"^[0-9a-zA-Z_]+$", name): 45 | raise ValueError( 46 | f"Provided dataset name {name} must be alphanumeric and contain no spaces, only underscores" 47 | ) 48 | if not re.match(r"^[0-9]+\.[0-9]+\.[0-9]+$", version): 49 | pass 50 | """raise ValueError( 51 | f"Provided dataset version {version} should be a semantic version without any prefixes/suffixes" 52 | )""" 53 | 54 | def __post_init__(self) -> None: 55 | DatasetID.validate_dataset_id(self.name, self.version) 56 | 57 | 58 | @dataclasses.dataclass(frozen=True) 59 | class DatasetDefinition: 60 | """Representation of the definition of a dataset (immutable once dataset is added) 61 | 62 | `name` should be alphanumeric and contain no spaces, only underscores 63 | `version` should be a semantic version (e.g. 1.0.0) 64 | """ 65 | 66 | dataset_id: DatasetID 67 | schema: schema.DatasetSchema 68 | 69 | @property 70 | def identifier(self) -> DatasetID: 71 | return DatasetID(name=self.dataset_id.name, version=self.dataset_id.version) 72 | 73 | def __post_init__(self) -> None: 74 | DatasetID.validate_dataset_id(self.dataset_id.name, self.dataset_id.version) 75 | 76 | 77 | @dataclasses.dataclass(frozen=True) 78 | class DatasetPartition: 79 | """Representation of the definition of a partition within dataset 80 | 81 | The partition here is meant to represent the common train/val/test splits of a dataset, but 82 | could also represent partitions for other use cases. 83 | 84 | `partition` should be alphanumeric and contain no spaces, only underscores 85 | """ 86 | 87 | dataset_id: DatasetID 88 | partition: str 89 | 90 | def __str__(self) -> str: 91 | """Helper function to return the representation of the partition as a path-like string""" 92 | return f"{self.dataset_id.name}/{self.dataset_id.version}/{self.partition}" 93 | 94 | 95 | class DatasetState(enum.Enum): 96 | """Representation of the state of a dataset""" 97 | 98 | STAGED = "PENDING" 99 | COMMITTED = "COMMITTED" 100 | -------------------------------------------------------------------------------- /wicker/core/errors.py: -------------------------------------------------------------------------------- 1 | class WickerDatastoreException(Exception): 2 | pass 3 | 4 | 5 | class WickerSchemaException(Exception): 6 | pass 7 | -------------------------------------------------------------------------------- /wicker/core/filelock.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import contextlib 4 | import fcntl 5 | import os 6 | import signal 7 | from typing import Any, Iterator, Optional 8 | 9 | 10 | @contextlib.contextmanager 11 | def _timeout(seconds: int, lockfile: str) -> Iterator[None]: 12 | """Context manager for triggering a timeout via SIGALRM if a blocking call 13 | does not return in `seconds` number of seconds. 
14 | 15 | Usage: 16 | >>> with _timeout(5, "somefile"): 17 | >>> some_long_running_function() 18 | 19 | In the above example, if some_long_running_function takes more than 5 seconds, a 20 | TimeoutError will be raised from the `with _timeout(5, "somefile")` context. 21 | """ 22 | if seconds == -1.0: 23 | # No timeout specified, we execute without registering a SIGALRM 24 | yield 25 | else: 26 | 27 | def timeout_handler(signum: Any, frame: Any) -> None: 28 | """Handles receiving a SIGALRM by throwing a TimeoutError with a useful 29 | error message 30 | """ 31 | raise TimeoutError(f"{lockfile} could not be acquired after {seconds}s") 32 | 33 | # Register `timeout_handler` as the handler for any SIGALRMs that are raised 34 | # We store the original SIGALRM handler if available to restore it afterwards 35 | original_handler = signal.signal(signal.SIGALRM, timeout_handler) 36 | 37 | try: 38 | # Raise a SIGALRM in `seconds` number of seconds, and then yield to code 39 | # in the context manager's code-block. 40 | signal.alarm(seconds) 41 | yield 42 | # Make sure to always restore the original SIGALRM handler 43 | finally: 44 | signal.alarm(0) 45 | signal.signal(signal.SIGALRM, original_handler) 46 | 47 | 48 | class SimpleUnixFileLock: 49 | """Simple blocking lock class that uses files to achieve system-wide locking 50 | 51 | 1. If timeout is not specified, waits forever by default 52 | 2. Does not support recursive locking and WILL break if called recursively 53 | """ 54 | 55 | def __init__(self, lock_file: str, timeout_seconds: int = -1): 56 | # The path to the lock file. 57 | self._lock_file = lock_file 58 | 59 | # The file descriptor for the *_lock_file* as it is returned by the 60 | # os.open() function. 61 | # This file lock is only NOT None, if the object currently holds the 62 | # lock. 63 | self._lock_file_fd: Optional[int] = None 64 | 65 | # The default timeout value. 
66 | self.timeout_seconds = timeout_seconds 67 | 68 | return None 69 | 70 | def __enter__(self) -> SimpleUnixFileLock: 71 | open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC 72 | fd = os.open(self._lock_file, open_mode) 73 | try: 74 | with _timeout(self.timeout_seconds, self._lock_file): 75 | fcntl.lockf(fd, fcntl.LOCK_EX) 76 | except (IOError, OSError, TimeoutError): 77 | os.close(fd) 78 | raise 79 | else: 80 | self._lock_file_fd = fd 81 | return self 82 | 83 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 84 | assert self._lock_file_fd is not None, "self._lock_file_fd not set by __enter__" 85 | fd = self._lock_file_fd 86 | self._lock_file_fd = None 87 | fcntl.flock(fd, fcntl.LOCK_UN) 88 | os.close(fd) 89 | return None 90 | -------------------------------------------------------------------------------- /wicker/core/persistance.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Dict, Iterable, List, Optional, Tuple 3 | 4 | import pyarrow as pa 5 | import pyarrow.compute as pc 6 | 7 | from wicker import schema as schema_module 8 | from wicker.core.column_files import ColumnBytesFileWriter 9 | from wicker.core.definitions import DatasetID 10 | from wicker.core.shuffle import save_index 11 | from wicker.core.storage import S3DataStorage, S3PathFactory 12 | from wicker.schema import dataparsing, serialization 13 | 14 | MAX_COL_FILE_NUMROW = 50 # TODO(isaak-willett): Magic number, we should derive this based on row size 15 | 16 | UnparsedExample = Dict[str, Any] 17 | ParsedExample = Dict[str, Any] 18 | PointerParsedExample = Dict[str, Any] 19 | 20 | 21 | class AbstractDataPersistor(abc.ABC): 22 | """ 23 | Abstract class for persisting data onto a user defined cloud or local instance. 24 | 25 | Only s3 is supported right now but plan to support other data stores 26 | (BigQuery, Azure, Postgres) 27 | """ 28 | 29 | def __init__( 30 | self, 31 | s3_storage: S3DataStorage = S3DataStorage(), 32 | s3_path_factory: S3PathFactory = S3PathFactory(), 33 | ) -> None: 34 | """ 35 | Init a Persister 36 | 37 | :param s3_storage: The storage abstraction for S3 38 | :type s3_storage: S3DataStore 39 | :param s3_path_factory: The path factory for generating s3 paths 40 | based on dataset name and version 41 | :type s3_path_factory: S3PathFactory 42 | """ 43 | super().__init__() 44 | self.s3_storage = s3_storage 45 | self.s3_path_factory = s3_path_factory 46 | 47 | @abc.abstractmethod 48 | def persist_wicker_dataset( 49 | self, 50 | dataset_name: str, 51 | dataset_version: str, 52 | dataset_schema: schema_module.DatasetSchema, 53 | dataset: Any, 54 | ) -> Optional[Dict[str, int]]: 55 | """ 56 | Persist a user specified dataset defined by name, version, schema, and data. 57 | 58 | :param dataset_name: Name of the dataset 59 | :type dataset_name: str 60 | :param dataset_version: Version of the dataset 61 | :type: dataset_version: str 62 | :param dataset_schema: Schema of the dataset 63 | :type dataset_schema: wicker.schema.schema.DatasetSchema 64 | :param dataset: Data of the dataset 65 | :type dataset: User defined 66 | """ 67 | raise NotImplementedError("Method, persist_wicker_dataset, needs to be implemented in inhertiance class.") 68 | 69 | @staticmethod 70 | def parse_row(data_row: UnparsedExample, schema: schema_module.DatasetSchema) -> ParsedExample: 71 | """ 72 | Parse a row to test for validation errors. 
73 | 74 | :param data_row: Data row to be parsed 75 | :type data_row: UnparsedExample 76 | :return: parsed row containing the correct types associated with schema 77 | :rtype: ParsedExample 78 | """ 79 | return dataparsing.parse_example(data_row, schema) 80 | 81 | # Write data to Column Byte Files 82 | 83 | @staticmethod 84 | def persist_wicker_partition( 85 | dataset_name: str, 86 | spark_partition_iter: Iterable[Tuple[str, ParsedExample]], 87 | schema: schema_module.DatasetSchema, 88 | s3_storage: S3DataStorage, 89 | s3_path_factory: S3PathFactory, 90 | target_max_column_file_numrows: int = 50, 91 | ) -> Iterable[Tuple[str, PointerParsedExample]]: 92 | """Persists a Spark partition of examples with parsed bytes into S3Storage as ColumnBytesFiles, 93 | returning a new Spark partition of examples with heavy-pointers and metadata only. 94 | :param dataset_name: dataset name 95 | :param spark_partition_iter: Spark partition of `(partition_str, example)`, where `example` 96 | is a dictionary of parsed bytes that needs to be uploaded to S3 97 | :param target_max_column_file_numrows: Maximum number of rows in column files. Defaults to 50. 98 | :return: a Generator of `(partition_str, example)`, where `example` is a dictionary with heavy-pointers 99 | that point to ColumnBytesFiles in S3 in place of the parsed bytes 100 | """ 101 | column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {} 102 | heavy_pointer_columns = schema.get_pointer_columns() 103 | metadata_columns = schema.get_non_pointer_columns() 104 | 105 | for partition, example in spark_partition_iter: 106 | # Create ColumnBytesFileWriter lazily as required, for each partition 107 | if partition not in column_bytes_file_writers: 108 | column_bytes_file_writers[partition] = ColumnBytesFileWriter( 109 | s3_storage, 110 | s3_path_factory, 111 | target_file_rowgroup_size=target_max_column_file_numrows, 112 | dataset_name=dataset_name, 113 | ) 114 | 115 | # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers 116 | parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns} 117 | for col in heavy_pointer_columns: 118 | loc = column_bytes_file_writers[partition].add(col, example[col]) 119 | parquet_metadata[col] = loc.to_bytes() 120 | yield partition, parquet_metadata 121 | 122 | # Flush all writers when finished 123 | for partition in column_bytes_file_writers: 124 | column_bytes_file_writers[partition].close() 125 | 126 | @staticmethod 127 | def save_partition_tbl( 128 | partition_table_tuple: Tuple[str, pa.Table], 129 | dataset_name: str, 130 | dataset_version: str, 131 | s3_storage: S3DataStorage, 132 | s3_path_factory: S3PathFactory, 133 | ) -> Tuple[str, int]: 134 | """ 135 | Save a partition table to s3 under the dataset name and version. 
136 | 137 | :param partition_table_tuple: Tuple of partition id and pyarrow table to save 138 | :type partition_table_tuple: Tuple[str, pyarrow.Table] 139 | :return: A tuple containing the partition id and the number of saved rows 140 | :rtype: Tuple[str, int] 141 | """ 142 | partition, pa_tbl = partition_table_tuple 143 | save_index( 144 | dataset_name, 145 | dataset_version, 146 | {partition: pa_tbl}, 147 | s3_storage=s3_storage, 148 | s3_path_factory=s3_path_factory, 149 | ) 150 | return (partition, pa_tbl.num_rows) 151 | 152 | 153 | def persist_wicker_dataset( 154 | dataset_name: str, 155 | dataset_version: str, 156 | dataset_schema: schema_module.DatasetSchema, 157 | dataset: Any, 158 | s3_storage: S3DataStorage = S3DataStorage(), 159 | s3_path_factory: S3PathFactory = S3PathFactory(), 160 | ) -> Optional[Dict[str, int]]: 161 | """ 162 | Public-facing API function to persist a Wicker dataset, provided for API consistency. 163 | :param dataset_name: name of dataset persisted 164 | :type dataset_name: str 165 | :param dataset_version: version of dataset persisted 166 | :type dataset_version: str 167 | :param dataset_schema: schema of dataset to be persisted 168 | :type dataset_schema: DatasetSchema 169 | :param dataset: data of the dataset to persist 170 | :type dataset: User defined 171 | :param s3_storage: s3 storage abstraction 172 | :type s3_storage: S3DataStorage 173 | :param s3_path_factory: s3 path abstraction 174 | :type s3_path_factory: S3PathFactory 175 | """ 176 | return BasicPersistor(s3_storage, s3_path_factory).persist_wicker_dataset( 177 | dataset_name, dataset_version, dataset_schema, dataset 178 | ) 179 | 180 | 181 | class BasicPersistor(AbstractDataPersistor): 182 | """ 183 | Basic persistor class that persists wicker data on s3 in a non sorted manner. 184 | 185 | We will move to supporting other features like shuffling, other data engines, etc... 186 | """ 187 | 188 | def __init__( 189 | self, s3_storage: S3DataStorage = S3DataStorage(), s3_path_factory: S3PathFactory = S3PathFactory() 190 | ) -> None: 191 | super().__init__(s3_storage, s3_path_factory) 192 | 193 | def persist_wicker_dataset( 194 | self, dataset_name: str, dataset_version: str, dataset_schema: schema_module.DatasetSchema, dataset: Any 195 | ) -> Optional[Dict[str, int]]: 196 | """ 197 | Persist a user defined dataset, pushing data to s3 in a basic manner 198 | 199 | :param dataset_name: Name of the dataset 200 | :type dataset_name: str 201 | :param dataset_version: Version of the dataset 202 | :type dataset_version: str 203 | :param dataset_schema: Schema of the dataset 204 | :type dataset_schema: wicker.schema.schema.DatasetSchema 205 | :param dataset: Data of the dataset 206 | :type dataset: User defined 207 | """ 208 | # what needs to be done within this function 209 | # 1. Check if the variables are set 210 | # check if variables have been set, i.e. not None 211 | if ( 212 | not isinstance(dataset_name, str) 213 | or not isinstance(dataset_version, str) 214 | or not isinstance(dataset_schema, schema_module.DatasetSchema) 215 | ): 216 | raise ValueError("Current dataset variables not all set, set all to proper not None values") 217 | 218 | # 2. Put the schema up on s3 219 | schema_path = self.s3_path_factory.get_dataset_schema_path( 220 | DatasetID(name=dataset_name, version=dataset_version) 221 | ) 222 | self.s3_storage.put_object_s3(serialization.dumps(dataset_schema).encode("utf-8"), schema_path) 223 | 224 | # 3. 
Parse and validate the rows to ensure the data is well formed
225 | dataset_0 = [(row[0], self.parse_row(row[1], dataset_schema)) for row in dataset]
226 | 
227 | # 4. Sort the dataset by partition key if not already sorted
228 | sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0])
229 | 
230 | # 5. Persist the partitions to S3
231 | metadata_iterator = self.persist_wicker_partition(
232 | dataset_name,
233 | sorted_dataset_0,
234 | dataset_schema,
235 | self.s3_storage,
236 | self.s3_path_factory,
237 | MAX_COL_FILE_NUMROW,
238 | )
239 | 
240 | # 6. Create the partition tables: combine the rows keyed on partition so we can form tables,
241 | # split into k dicts, where k is the number of partitions, and the data is a list of values
242 | # for each column across all the rows in the partition
243 | merged_dicts: Dict[str, Dict[str, List[Any]]] = {}
244 | for partition_key, row in metadata_iterator:
245 | current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {})
246 | for col in row.keys():
247 | if col in current_dict:
248 | current_dict[col].append(row[col])
249 | else:
250 | current_dict[col] = [row[col]]
251 | merged_dicts[partition_key] = current_dict
252 | # convert each of the dicts to a pyarrow table in the same way SparkPersistor
253 | # does, which is needed to ensure parity between the two persistors
254 | arrow_dict = {}
255 | for partition_key, data_dict in merged_dicts.items():
256 | data_table = pa.Table.from_pydict(data_dict)
257 | arrow_dict[partition_key] = pc.take(
258 | pa.Table.from_pydict(data_dict),
259 | pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in dataset_schema.primary_keys]),
260 | )
261 | 
262 | # 7. Persist the partition tables to s3
263 | written_dict = {}
264 | for partition_key, pa_table in arrow_dict.items():
265 | self.save_partition_tbl(
266 | (partition_key, pa_table), dataset_name, dataset_version, self.s3_storage, self.s3_path_factory
267 | )
268 | written_dict[partition_key] = pa_table.num_rows
269 | 
270 | return written_dict
271 | -------------------------------------------------------------------------------- /wicker/core/shuffle.py: -------------------------------------------------------------------------------- 1 | """Classes handling the shuffling of data in S3 when committing a dataset.
2 | 
3 | When committing a dataset, we sort the data by primary_key before materializing in S3
4 | as Parquet files.
5 | 
6 | 1. The ShuffleJob is the unit of work, and it is just an ordered set of Examples that should
7 | be bundled together in one Parquet file
8 | 
9 | 2. The ShuffleJobFactory produces ShuffleJobs, using a DatasetWriter object to retrieve the
10 | written examples for divvying up as ShuffleJobs.
11 | 
12 | 3.
ShuffleWorkers receive the ShuffleJobs and perform the act of retrieving the data and 13 | persisting the data into S3 as Parquet files (one for each ShuffleJob) 14 | """ 15 | from __future__ import annotations 16 | 17 | import collections 18 | import concurrent.futures 19 | import dataclasses 20 | import os 21 | import pickle 22 | import tempfile 23 | from typing import Any, Dict, Generator, List, Optional, Tuple 24 | 25 | import boto3 26 | import pyarrow as pa 27 | import pyarrow.fs as pafs 28 | import pyarrow.parquet as papq 29 | 30 | from wicker.core.column_files import ColumnBytesFileWriter 31 | from wicker.core.definitions import DatasetID, DatasetPartition 32 | from wicker.core.storage import S3DataStorage, S3PathFactory 33 | from wicker.core.writer import DatasetWriterBackend 34 | from wicker.schema import serialization 35 | 36 | # Maximum working set for each worker 37 | DEFAULT_WORKER_MAX_WORKING_SET_SIZE = 16384 38 | 39 | 40 | @dataclasses.dataclass 41 | class ShuffleJob: 42 | """Represents all the shuffling operations that will happen for a given partition (train/eval/test) on a given 43 | compute shard.""" 44 | 45 | dataset_partition: DatasetPartition 46 | files: List[Tuple[str, int]] 47 | 48 | 49 | class ShuffleJobFactory: 50 | def __init__( 51 | self, 52 | writer_backend: DatasetWriterBackend, 53 | worker_max_working_set_size: int = DEFAULT_WORKER_MAX_WORKING_SET_SIZE, 54 | ): 55 | self.writer_backend = writer_backend 56 | 57 | # Factory configurations 58 | self.worker_max_working_set_size = worker_max_working_set_size 59 | 60 | def build_shuffle_jobs(self, dataset_id: DatasetID) -> Generator[ShuffleJob, None, None]: 61 | # Initialize with first item 62 | example_keys = self.writer_backend._metadata_db.scan_sorted(dataset_id) 63 | try: 64 | initial_key = next(example_keys) 65 | except StopIteration: 66 | return 67 | job = ShuffleJob( 68 | dataset_partition=DatasetPartition(dataset_id=dataset_id, partition=initial_key.partition), 69 | files=[(initial_key.row_data_path, initial_key.row_size)], 70 | ) 71 | 72 | # Yield ShuffleJobs as we accumulate ExampleKeys, where each ShuffleJob is upper-bounded in size by 73 | # self.worker_max_working_set_size and has all ExampleKeys from the same partition 74 | for example_key in example_keys: 75 | if ( 76 | example_key.partition == job.dataset_partition.partition 77 | and len(job.files) < self.worker_max_working_set_size 78 | ): 79 | job.files.append((example_key.row_data_path, example_key.row_size)) 80 | continue 81 | 82 | # Yield job and construct new job to keep iterating 83 | yield job 84 | job = ShuffleJob( 85 | dataset_partition=DatasetPartition(dataset_id=dataset_id, partition=example_key.partition), 86 | files=[(example_key.row_data_path, example_key.row_size)], 87 | ) 88 | else: 89 | yield job 90 | 91 | 92 | _download_thread_session: Optional[boto3.session.Session] = None 93 | _download_thread_client: Optional[S3DataStorage] = None 94 | 95 | 96 | def _initialize_download_thread(): 97 | global _download_thread_session 98 | global _download_thread_client 99 | if _download_thread_client is None: 100 | _download_thread_session = boto3.session.Session() 101 | _download_thread_client = S3DataStorage(session=_download_thread_session) 102 | 103 | 104 | class ShuffleWorker: 105 | def __init__( 106 | self, 107 | target_rowgroup_bytes_size: int = int(256e6), 108 | max_worker_threads: int = 16, 109 | max_memory_usage_bytes: int = int(2e9), 110 | storage: S3DataStorage = S3DataStorage(), 111 | s3_path_factory: S3PathFactory = 
S3PathFactory(), 112 | ): 113 | self.target_rowgroup_bytes_size = target_rowgroup_bytes_size 114 | self.max_worker_threads = max_worker_threads 115 | self.max_memory_usage_bytes = max_memory_usage_bytes 116 | self.storage = storage 117 | self.s3_path_factory = s3_path_factory 118 | 119 | def _download_files(self, job: ShuffleJob) -> Generator[Dict[str, Any], None, None]: 120 | """Downloads the files in a ShuffleJob, yielding a generator of Dict[str, Any] 121 | 122 | Internally, this function maintains a buffer of lookahead downloads to execute downloads 123 | in parallel over a ThreadPoolExecutor, up to a maximum of `max_memory_usage_bytes` bytes. 124 | 125 | Args: 126 | job (ShuffleJob): job to download files for 127 | 128 | Yields: 129 | Generator[Dict[str, Any], None, None]: stream of Dict[str, Any] from downloading the files 130 | in order 131 | """ 132 | # TODO(jchia): Add retries here 133 | def _download_file(filepath: str) -> Dict[str, Any]: 134 | assert _download_thread_client is not None 135 | return pickle.loads(_download_thread_client.fetch_obj_s3(filepath)) 136 | 137 | with concurrent.futures.ThreadPoolExecutor( 138 | self.max_worker_threads, initializer=_initialize_download_thread 139 | ) as executor: 140 | buffer: Dict[int, Tuple[concurrent.futures.Future[Dict[str, Any]], int]] = {} 141 | buffer_size = 0 142 | buffer_index = 0 143 | for current_index in range(len(job.files)): 144 | while buffer_index < len(job.files) and buffer_size < self.max_memory_usage_bytes: 145 | filepath, file_size = job.files[buffer_index] 146 | buffer[buffer_index] = (executor.submit(_download_file, filepath), file_size) 147 | buffer_size += file_size 148 | buffer_index += 1 149 | current_future, current_file_size = buffer[current_index] 150 | yield current_future.result() 151 | buffer_size -= current_file_size 152 | del buffer[current_index] 153 | 154 | def _estimate_target_file_rowgroup_size( 155 | self, 156 | job: ShuffleJob, 157 | target_rowgroup_size_bytes: int = int(256e6), 158 | min_target_rowgroup_size: int = 16, 159 | ) -> int: 160 | """Estimates the number of rows to include in each rowgroup using a target size for the rowgroup 161 | 162 | :param job: job to estimate 163 | :param target_rowgroup_size_bytes: target size in bytes of a rowgroup, defaults to 256MB 164 | :param min_target_rowgroup_size: minimum number of rows in a rowgroup 165 | :return: target number of rows in a rowgroup 166 | """ 167 | average_filesize = sum([size for _, size in job.files]) / len(job.files) 168 | return max(min_target_rowgroup_size, int(target_rowgroup_size_bytes / average_filesize)) 169 | 170 | def process_job(self, job: ShuffleJob) -> pa.Table: 171 | # Load dataset schema 172 | dataset_schema = serialization.loads( 173 | self.storage.fetch_obj_s3( 174 | self.s3_path_factory.get_dataset_schema_path(job.dataset_partition.dataset_id) 175 | ).decode("utf-8") 176 | ) 177 | 178 | # Estimate how many rows to add to each ColumnBytesFile 179 | target_file_rowgroup_size = self._estimate_target_file_rowgroup_size(job) 180 | 181 | # Initialize data containers to dump into parquet 182 | heavy_pointer_columns = dataset_schema.get_pointer_columns() 183 | metadata_columns = dataset_schema.get_non_pointer_columns() 184 | parquet_metadata: Dict[str, List[Any]] = collections.defaultdict(list) 185 | 186 | # Parse each row, uploading heavy_pointer bytes to S3 and storing only pointers 187 | # in parquet_metadata 188 | with ColumnBytesFileWriter( 189 | self.storage, 190 | self.s3_path_factory, 191 | 
target_file_rowgroup_size=target_file_rowgroup_size,
192 | dataset_name=job.dataset_partition.dataset_id.name,
193 | ) as writer:
194 | for data in self._download_files(job):
195 | for col in metadata_columns:
196 | parquet_metadata[col].append(data[col])
197 | for col in heavy_pointer_columns:
198 | loc = writer.add(col, data[col])
199 | parquet_metadata[col].append(loc.to_bytes())
200 | 
201 | # Save parquet_metadata as a PyArrow Table
202 | assert len({len(parquet_metadata[col]) for col in parquet_metadata}) == 1, "All columns must have same length"
203 | return pa.Table.from_pydict(parquet_metadata)
204 | 
205 | 
206 | def save_index(
207 | dataset_name: str,
208 | dataset_version: str,
209 | final_indices: Dict[str, pa.Table],
210 | s3_path_factory: Optional[S3PathFactory] = None,
211 | s3_storage: Optional[S3DataStorage] = None,
212 | ) -> None:
213 | """Saves the generated final indices into persistent storage
214 | 
215 | :param dataset_name: Name of the dataset
216 | :param dataset_version: Version of the dataset
217 | :param final_indices: Dictionary mapping each partition name to the finalized
218 | pyarrow Table index for that partition
219 | :param s3_path_factory: S3PathFactory to use
220 | :param s3_storage: S3DataStorage to use
221 | """
222 | s3_storage = s3_storage if s3_storage is not None else S3DataStorage()
223 | s3_path_factory = s3_path_factory if s3_path_factory is not None else S3PathFactory()
224 | for partition_name in final_indices:
225 | dataset_partition = DatasetPartition(
226 | dataset_id=DatasetID(
227 | name=dataset_name,
228 | version=dataset_version,
229 | ),
230 | partition=partition_name,
231 | )
232 | 
233 | parquet_folder = s3_path_factory.get_dataset_partition_path(dataset_partition, s3_prefix=True)
234 | parquet_path = os.path.join(parquet_folder, "part-0.parquet")
235 | 
236 | # Write the Parquet file as one file locally, then upload to S3
237 | with tempfile.NamedTemporaryFile() as tmpfile:
238 | pa_table = final_indices[partition_name]
239 | papq.write_table(
240 | pa_table,
241 | tmpfile.name,
242 | compression="zstd",
243 | row_group_size=None,
244 | filesystem=pafs.LocalFileSystem(),
245 | # We skip writing statistics since it bloats the file, and we don't actually run any queries
246 | # on the Parquet files that could make use of predicate push-down
247 | write_statistics=False,
248 | )
249 | s3_storage.put_file_s3(
250 | tmpfile.name,
251 | parquet_path,
252 | )
253 | 
254 | return None
255 | -------------------------------------------------------------------------------- /wicker/core/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib
2 | import signal
3 | 
4 | 
5 | @contextlib.contextmanager
6 | def time_limit(seconds, error_message: str):
7 | def signal_handler(signum, frame):
8 | raise TimeoutError(f"Timed out! 
{error_message}") 9 | 10 | signal.signal(signal.SIGALRM, signal_handler) 11 | signal.alarm(seconds) 12 | try: 13 | yield 14 | finally: 15 | signal.alarm(0) 16 | -------------------------------------------------------------------------------- /wicker/core/writer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | import contextlib 5 | import dataclasses 6 | import hashlib 7 | import os 8 | import pickle 9 | import threading 10 | import time 11 | from concurrent.futures import Executor, Future, ThreadPoolExecutor 12 | from types import TracebackType 13 | from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union 14 | 15 | from wicker.core.definitions import DatasetDefinition, DatasetID 16 | from wicker.core.errors import WickerDatastoreException 17 | from wicker.core.storage import S3DataStorage, S3PathFactory 18 | from wicker.schema import dataparsing, serialization 19 | 20 | 21 | @dataclasses.dataclass 22 | class ExampleKey: 23 | """Unique identifier for one example.""" 24 | 25 | # Name of the partition (usually train/eval/test) 26 | partition: str 27 | # Values of the primary keys for the example (in order of key precedence) 28 | primary_key_values: List[Union[str, int]] 29 | 30 | def hash(self) -> str: 31 | return hashlib.sha256( 32 | "/".join([self.partition, *[str(obj) for obj in self.primary_key_values]]).encode("utf-8") 33 | ).hexdigest() 34 | 35 | 36 | @dataclasses.dataclass 37 | class MetadataDatabaseScanRow: 38 | """Container for data obtained by scanning the MetadataDatabase""" 39 | 40 | partition: str 41 | row_data_path: str 42 | row_size: int 43 | 44 | 45 | class AbstractDatasetWriterMetadataDatabase: 46 | """Database for storing metadata, used from inside the DatasetWriterBackend. 47 | 48 | NOTE: Implementors - this is the main implementation integration point for creating a new kind of DatasetWriter. 49 | """ 50 | 51 | @abc.abstractmethod 52 | def save_row_metadata(self, dataset_id: DatasetID, key: ExampleKey, location: str, row_size: int) -> None: 53 | """Saves a row in the metadata database, marking it as having been uploaded to S3 and 54 | ready for shuffling. 55 | 56 | :param dataset_id: The ID of the dataset to save to 57 | :param key: The key of the example 58 | :param location: The location of the example in S3 59 | :param row_size: The size of the file in S3 60 | """ 61 | pass 62 | 63 | @abc.abstractmethod 64 | def scan_sorted(self, dataset_id: DatasetID) -> Generator[MetadataDatabaseScanRow, None, None]: 65 | """Scans the MetadataDatabase for a **SORTED** stream of MetadataDatabaseScanRows for a given dataset. 66 | The stream is sorted by partition first, and then primary_key_values second. 67 | 68 | Should be fast O(minutes) to perform as this will be called from a single machine to assign chunks to jobs 69 | to run. 70 | 71 | :param dataset_id: The dataset to scan the metadata database for 72 | :return: a Generator of ExampleDBRows in **SORTED** partition + primary_key order 73 | """ 74 | pass 75 | 76 | 77 | class DatasetWriterBackend: 78 | """The backend for a DatasetWriter. 79 | 80 | Responsible for saving and retrieving data used during the dataset writing and committing workflow. 
81 | """ 82 | 83 | def __init__( 84 | self, 85 | s3_path_factory: S3PathFactory, 86 | s3_storage: S3DataStorage, 87 | metadata_database: AbstractDatasetWriterMetadataDatabase, 88 | ): 89 | self._s3_path_factory = s3_path_factory 90 | self._s3_storage = s3_storage 91 | self._metadata_db = metadata_database 92 | 93 | def save_row(self, dataset_id: DatasetID, key: ExampleKey, raw_data: Dict[str, Any]) -> None: 94 | """Adds an example to the backend 95 | 96 | :param dataset_id: ID of the dataset to save the row to 97 | :param key: Key of the example to write 98 | :param raw_data: raw data for the example that conforms to the schema provided at initialization 99 | """ 100 | hashed_row_key = key.hash() 101 | pickled_row = pickle.dumps(raw_data) # TODO(jchia): Do we want a more sophisticated storage format here? 102 | row_s3_path = os.path.join( 103 | self._s3_path_factory.get_temporary_row_files_path(dataset_id), 104 | hashed_row_key, 105 | ) 106 | 107 | # Persist data in S3 and in MetadataDatabase 108 | self._s3_storage.put_object_s3(pickled_row, row_s3_path) 109 | self._metadata_db.save_row_metadata(dataset_id, key, row_s3_path, len(pickled_row)) 110 | 111 | def commit_schema( 112 | self, 113 | dataset_definition: DatasetDefinition, 114 | ) -> None: 115 | """Write the schema to the backend as part of the commit step.""" 116 | schema_path = self._s3_path_factory.get_dataset_schema_path(dataset_definition.identifier) 117 | serialized_schema = serialization.dumps(dataset_definition.schema) 118 | self._s3_storage.put_object_s3(serialized_schema.encode(), schema_path) 119 | 120 | 121 | DEFAULT_BUFFER_SIZE_LIMIT = 1000 122 | 123 | 124 | class DatasetWriter: 125 | """DatasetWriter providing async writing functionality. Implementors should override 126 | the ._save_row_impl method to define functionality for saving each individual row from inside the 127 | async thread executors. 
128 | """ 129 | 130 | def __init__( 131 | self, 132 | dataset_definition: DatasetDefinition, 133 | metadata_database: AbstractDatasetWriterMetadataDatabase, 134 | s3_path_factory: Optional[S3PathFactory] = None, 135 | s3_storage: Optional[S3DataStorage] = None, 136 | buffer_size_limit: int = DEFAULT_BUFFER_SIZE_LIMIT, 137 | executor: Optional[Executor] = None, 138 | wait_flush_timeout_seconds: int = 300, 139 | ): 140 | """Create a new DatasetWriter 141 | 142 | :param dataset_definition: definition of the dataset 143 | :param s3_path_factory: factory for s3 paths 144 | :param s3_storage: S3-compatible storage for storing data 145 | :param buffer_size_limit: size limit to the number of examples buffered in-memory, defaults to 1000 146 | :param executor: concurrent.futures.Executor to use for async writing, defaults to None 147 | :param wait_flush_timeout_seconds: number of seconds to wait before timing out on flushing 148 | all examples, defaults to 10 149 | """ 150 | self.dataset_definition = dataset_definition 151 | self.backend = DatasetWriterBackend( 152 | s3_path_factory if s3_path_factory is not None else S3PathFactory(), 153 | s3_storage if s3_storage is not None else S3DataStorage(), 154 | metadata_database, 155 | ) 156 | 157 | self.buffer: List[Tuple[ExampleKey, Dict[str, Any]]] = [] 158 | self.buffer_size_limit = buffer_size_limit 159 | 160 | self.wait_flush_timeout_seconds = wait_flush_timeout_seconds 161 | self.executor = executor if executor is not None else ThreadPoolExecutor(max_workers=16) 162 | self.writes_in_flight: Dict[str, Dict[str, Any]] = {} 163 | self.flush_condition_variable = threading.Condition() 164 | 165 | def __enter__(self) -> DatasetWriter: 166 | return self 167 | 168 | def __exit__( 169 | self, 170 | exception_type: Optional[Type[BaseException]], 171 | exception_value: Optional[BaseException], 172 | traceback: Optional[TracebackType], 173 | ) -> None: 174 | self.flush(block=True) 175 | 176 | def __del__(self) -> None: 177 | self.flush(block=True) 178 | 179 | def add_example(self, partition_name: str, raw_data: Dict[str, Any]) -> None: 180 | """Adds an example to the writer 181 | 182 | :param partition_name: partition name where the example belongs 183 | :param raw_data: raw data for the example that conforms to the schema provided at initialization 184 | """ 185 | # Run sanity checks on the data, fill empty fields. 
186 | ex = dataparsing.parse_example(raw_data, self.dataset_definition.schema) 187 | example_key = ExampleKey( 188 | partition=partition_name, primary_key_values=[ex[k] for k in self.dataset_definition.schema.primary_keys] 189 | ) 190 | self.buffer.append((example_key, ex)) 191 | 192 | # Flush buffer to persistent storage if at size limit 193 | if len(self.buffer) > self.buffer_size_limit: 194 | self.flush(block=False) 195 | 196 | def flush(self, block: bool = True) -> None: 197 | """Flushes the writer 198 | 199 | :param block: whether to block on flushing all currently buffered examples, defaults to True 200 | :raises TimeoutError: timing out on flushing all examples 201 | """ 202 | batch_data = self.buffer 203 | self.buffer = [] 204 | self._save_batch_data(batch_data) 205 | 206 | if block: 207 | with self._block_on_writes_in_flight(max_in_flight=0, timeout_seconds=self.wait_flush_timeout_seconds): 208 | pass 209 | 210 | @contextlib.contextmanager 211 | def _block_on_writes_in_flight(self, max_in_flight: int = 0, timeout_seconds: int = 60) -> Iterator[None]: 212 | """Blocks until number of writes in flight <= max_in_flight and yields control to the with block 213 | 214 | Usage: 215 | >>> with self._block_on_writes_in_flight(max_in_flight=10): 216 | >>> # Code here holds the exclusive lock on self.flush_condition_variable 217 | >>> ... 218 | >>> # Code here has released the lock, and has no guarantees about the number of writes in flight 219 | 220 | :param max_in_flight: maximum number of writes in flight, defaults to 0 221 | :param timeout_seconds: maximum number of seconds to wait, defaults to 10 222 | :raises TimeoutError: if waiting for more than timeout_seconds 223 | """ 224 | start_time = time.time() 225 | with self.flush_condition_variable: 226 | while len(self.writes_in_flight) > max_in_flight: 227 | if time.time() - start_time > timeout_seconds: 228 | raise TimeoutError( 229 | f"Timed out while flushing dataset writes with {len(self.writes_in_flight)} writes in flight" 230 | ) 231 | self.flush_condition_variable.wait() 232 | yield 233 | 234 | def _save_row(self, key: ExampleKey, data: Dict[str, Any]) -> str: 235 | self.backend.save_row(self.dataset_definition.identifier, key, data) 236 | return key.hash() 237 | 238 | def _save_batch_data(self, batch_data: List[Tuple[ExampleKey, Dict[str, Any]]]) -> None: 239 | """Save a batch of data to persistent storage 240 | 241 | :param row_keys: Unique identifiers for each row Uses only 0-9A-F, can be used as a file name. 
242 | :param batch_data: Batch of data 243 | """ 244 | 245 | def done_callback(future: Future[str]) -> None: 246 | # TODO(jchia): We can add retries on failure by reappending to the buffer 247 | # this currently may raise a CancelledError or some Exception thrown by the .save_row function 248 | with self.flush_condition_variable: 249 | del self.writes_in_flight[future.result()] 250 | self.flush_condition_variable.notify() 251 | 252 | for key, data in batch_data: 253 | # Keep the number of writes in flight always smaller than 2 * self.buffer_size_limit 254 | with self._block_on_writes_in_flight( 255 | max_in_flight=2 * self.buffer_size_limit, 256 | timeout_seconds=self.wait_flush_timeout_seconds, 257 | ): 258 | key_hash = key.hash() 259 | if key_hash in self.writes_in_flight: 260 | raise WickerDatastoreException( 261 | f"Error: data example has non unique key {key}, primary keys must be unique" 262 | ) 263 | self.writes_in_flight[key_hash] = data 264 | future = self.executor.submit(self._save_row, key, data) 265 | future.add_done_callback(done_callback) 266 | -------------------------------------------------------------------------------- /wicker/plugins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/plugins/__init__.py -------------------------------------------------------------------------------- /wicker/plugins/dynamodb.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import heapq 3 | from typing import Generator, List, Tuple 4 | 5 | try: 6 | import pynamodb # type: ignore 7 | except ImportError: 8 | raise RuntimeError( 9 | "pynamodb is not detected in current environment, install Wicker with extra arguments:" 10 | " `pip install wicker[dynamodb]`" 11 | ) 12 | from pynamodb.attributes import NumberAttribute, UnicodeAttribute 13 | from pynamodb.models import Model 14 | from retry import retry 15 | 16 | from wicker.core.config import get_config 17 | from wicker.core.definitions import DatasetID 18 | from wicker.core.writer import ( 19 | AbstractDatasetWriterMetadataDatabase, 20 | ExampleKey, 21 | MetadataDatabaseScanRow, 22 | ) 23 | 24 | # DANGER: If these constants are ever changed, this is a backward-incompatible change. 25 | # Make sure that all the writers and readers of dynamodb are in sync when changing. 
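# Sketch of how one row lands in DynamoDB (values below are illustrative only):
#   ExampleKey(partition="train", primary_key_values=["scene_0001"]) is hashed, the hash is taken
#   modulo NUM_DYNAMODB_SHARDS to pick a shard, and the row is written with
#   hash_key = f"{dataset_id}_shard{shard:02d}" and range_key = "train//scene_0001".
# See _key_to_row_id_and_shard_id and _dataset_shard_name below.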
26 | NUM_DYNAMODB_SHARDS = 32 27 | HASH_PREFIX_LENGTH = 4 28 | DYNAMODB_QUERY_PAGINATION_LIMIT = 1000 29 | 30 | 31 | @dataclasses.dataclass(frozen=True) 32 | class DynamoDBConfig: 33 | table_name: str 34 | region: str 35 | 36 | 37 | def get_dynamodb_config() -> DynamoDBConfig: 38 | raw_config = get_config().raw 39 | if "dynamodb_config" not in raw_config: 40 | raise RuntimeError("Could not find 'dynamodb' parameters in config") 41 | if "table_name" not in raw_config["dynamodb_config"]: 42 | raise RuntimeError("Could not find 'table_name' parameter in config.dynamodb_config") 43 | if "region" not in raw_config["dynamodb_config"]: 44 | raise RuntimeError("Could not find 'region' parameter in config.dynamodb_config") 45 | return DynamoDBConfig( 46 | table_name=raw_config["dynamodb_config"]["table_name"], 47 | region=raw_config["dynamodb_config"]["region"], 48 | ) 49 | 50 | 51 | class DynamoDBExampleDBRow(Model): 52 | class Meta: 53 | table_name = get_dynamodb_config().table_name 54 | region = get_dynamodb_config().region 55 | 56 | dataset_id = UnicodeAttribute(hash_key=True) 57 | example_id = UnicodeAttribute(range_key=True) 58 | partition = UnicodeAttribute() 59 | row_data_path = UnicodeAttribute() 60 | row_size = NumberAttribute() 61 | 62 | 63 | def _key_to_row_id_and_shard_id(example_key: ExampleKey) -> Tuple[str, int]: 64 | """Deterministically convert an ExampleKey into a row_id and a shard_id which are used as 65 | DynamoDB RangeKeys and HashKeys respectively. 66 | 67 | HashKeys help to increase read/write throughput by allowing us to use different partitions. 68 | RangeKeys are how the rows are sorted within partitions by DynamoDB. 69 | 70 | We completely randomize the hash and range key to optimize for write throughput, but this means 71 | that sorting needs to be done entirely client-side in our application. 72 | """ 73 | partition_example_id = f"{example_key.partition}//{'/'.join([str(key) for key in example_key.primary_key_values])}" 74 | hash = example_key.hash() 75 | shard = int(hash, 16) % NUM_DYNAMODB_SHARDS 76 | return partition_example_id, shard 77 | 78 | 79 | def _dataset_shard_name(dataset_id: DatasetID, shard_id: int) -> str: 80 | """Get the name of the DynamoDB partition for a given dataset_definition and shard number""" 81 | return f"{dataset_id}_shard{shard_id:02d}" 82 | 83 | 84 | class DynamodbMetadataDatabase(AbstractDatasetWriterMetadataDatabase): 85 | def save_row_metadata(self, dataset_id: DatasetID, key: ExampleKey, location: str, row_size: int) -> None: 86 | """Saves a row in the metadata database, marking it as having been uploaded to S3 and 87 | ready for shuffling. 88 | 89 | :param dataset_id: The ID of the dataset to save to 90 | :param key: The key of the example 91 | :param location: The location of the example in S3 92 | :param row_size: The size of the file in S3 93 | """ 94 | partition_example_id, shard_id = _key_to_row_id_and_shard_id(key) 95 | entry = DynamoDBExampleDBRow( 96 | dataset_id=_dataset_shard_name(dataset_id, shard_id), 97 | example_id=partition_example_id, 98 | partition=key.partition, 99 | row_data_path=location, 100 | row_size=row_size, 101 | ) 102 | entry.save() 103 | 104 | def scan_sorted(self, dataset_id: DatasetID) -> Generator[MetadataDatabaseScanRow, None, None]: 105 | """Scans the MetadataDatabase for a **SORTED** list of ExampleKeys for a given dataset. Should be fast O(minutes) 106 | to perform as this will be called from a single machine to assign chunks to jobs to run. 
107 | 108 | :param dataset: The dataset to scan the metadata database for 109 | :return: a Generator of MetadataDatabaseScanRow in **SORTED** primary_key order 110 | """ 111 | 112 | @retry(pynamodb.exceptions.QueryError, tries=10, backoff=2, delay=4, jitter=(0, 2)) 113 | def shard_iterator(shard_id: int) -> Generator[DynamoDBExampleDBRow, None, None]: 114 | """Yields DynamoDBExampleDBRows from a given shard to exhaustion, sorted by example_id in ascending order 115 | DynamoDBExampleDBRows are yielded in sorted order of the Dynamodb RangeKey, which is the example_id 116 | """ 117 | hash_key = _dataset_shard_name(dataset_id, shard_id) 118 | last_evaluated_key = None 119 | while True: 120 | query_results = DynamoDBExampleDBRow.query( 121 | hash_key, 122 | consistent_read=True, 123 | last_evaluated_key=last_evaluated_key, 124 | limit=DYNAMODB_QUERY_PAGINATION_LIMIT, 125 | ) 126 | yield from query_results 127 | if query_results.last_evaluated_key is None: 128 | break 129 | last_evaluated_key = query_results.last_evaluated_key 130 | 131 | # Individual shards have their rows already in sorted order 132 | # Elements are popped off each shard to exhaustion and put into a minheap 133 | # We yield from the heap to exhaustion to provide a stream of globally sorted example_ids 134 | heap: List[Tuple[str, int, DynamoDBExampleDBRow]] = [] 135 | shard_iterators = {shard_id: shard_iterator(shard_id) for shard_id in range(NUM_DYNAMODB_SHARDS)} 136 | for shard_id, iterator in shard_iterators.items(): 137 | try: 138 | row = next(iterator) 139 | heapq.heappush(heap, (row.example_id, shard_id, row)) 140 | except StopIteration: 141 | pass 142 | while heap: 143 | _, shard_id, row = heapq.heappop(heap) 144 | try: 145 | nextrow = next(shard_iterators[shard_id]) 146 | heapq.heappush(heap, (nextrow.example_id, shard_id, nextrow)) 147 | except StopIteration: 148 | pass 149 | yield MetadataDatabaseScanRow( 150 | partition=row.partition, 151 | row_data_path=row.row_data_path, 152 | row_size=row.row_size, 153 | ) 154 | -------------------------------------------------------------------------------- /wicker/plugins/flyte.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import dataclasses 3 | import json 4 | import tempfile 5 | from typing import Dict, List, Type, cast 6 | 7 | try: 8 | import flytekit # type: ignore 9 | except ImportError: 10 | raise RuntimeError( 11 | "flytekit is not detected in current environment, install Wicker with extra arguments:" 12 | " `pip install wicker[flyte]`" 13 | ) 14 | import pyarrow as pa # type: ignore 15 | from flytekit.extend import TypeEngine, TypeTransformer # type: ignore 16 | 17 | from wicker.core.definitions import DatasetID, DatasetPartition 18 | from wicker.core.shuffle import ShuffleJob, ShuffleJobFactory, ShuffleWorker, save_index 19 | from wicker.core.storage import S3DataStorage, S3PathFactory 20 | from wicker.core.writer import DatasetWriterBackend 21 | from wicker.plugins.dynamodb import DynamodbMetadataDatabase 22 | 23 | ### 24 | # Custom type definitions to clean up passing of data between Flyte tasks 25 | ### 26 | 27 | 28 | class ShuffleJobTransformer(TypeTransformer[ShuffleJob]): 29 | _TYPE_INFO = flytekit.BlobType( 30 | format="binary", 31 | dimensionality=flytekit.BlobType.BlobDimensionality.SINGLE, 32 | ) 33 | 34 | def __init__(self): 35 | super(ShuffleJobTransformer, self).__init__(name="shufflejob-transform", t=ShuffleJob) 36 | 37 | def get_literal_type(self, t: Type[ShuffleJob]) -> 
flytekit.LiteralType: 38 | """ 39 | This is useful to tell the Flytekit type system that ``ShuffleJob`` actually refers to what corresponding type 40 | In this example, we say its of format binary (do not try to introspect) and there are more than one files in it 41 | """ 42 | return flytekit.LiteralType(blob=self._TYPE_INFO) 43 | 44 | @staticmethod 45 | def _shuffle_jobs_to_bytes(job: ShuffleJob) -> bytes: 46 | return json.dumps( 47 | { 48 | "dataset_partition": { 49 | "dataset_id": { 50 | "name": job.dataset_partition.dataset_id.name, 51 | "version": job.dataset_partition.dataset_id.version, 52 | }, 53 | "partition": job.dataset_partition.partition, 54 | }, 55 | "files": job.files, 56 | } 57 | ).encode("utf-8") 58 | 59 | @staticmethod 60 | def _shuffle_jobs_from_bytes(b: bytes) -> ShuffleJob: 61 | data = json.loads(b.decode("utf-8")) 62 | return ShuffleJob( 63 | dataset_partition=DatasetPartition( 64 | dataset_id=DatasetID( 65 | name=data["dataset_partition"]["dataset_id"]["name"], 66 | version=data["dataset_partition"]["dataset_id"]["version"], 67 | ), 68 | partition=data["dataset_partition"]["partition"], 69 | ), 70 | files=[(path, size) for path, size in data["files"]], 71 | ) 72 | 73 | def to_literal( 74 | self, 75 | ctx: flytekit.FlyteContext, 76 | python_val: ShuffleJob, 77 | python_type: Type[ShuffleJob], 78 | expected: flytekit.LiteralType, 79 | ) -> flytekit.Literal: 80 | """ 81 | This method is used to convert from given python type object ``MyDataset`` to the Literal representation 82 | """ 83 | # Step 1: lets upload all the data into a remote place recommended by Flyte 84 | remote_path = ctx.file_access.get_random_remote_path() 85 | with tempfile.NamedTemporaryFile() as tmpfile: 86 | tmpfile.write(self._shuffle_jobs_to_bytes(python_val)) 87 | tmpfile.flush() 88 | tmpfile.seek(0) 89 | ctx.file_access.upload(tmpfile.name, remote_path) 90 | # Step 2: lets return a pointer to this remote_path in the form of a literal 91 | return flytekit.Literal( 92 | scalar=flytekit.Scalar( 93 | blob=flytekit.Blob(uri=remote_path, metadata=flytekit.BlobMetadata(type=self._TYPE_INFO)) 94 | ) 95 | ) 96 | 97 | def to_python_value( 98 | self, ctx: flytekit.FlyteContext, lv: flytekit.Literal, expected_python_type: Type[ShuffleJob] 99 | ) -> ShuffleJob: 100 | """ 101 | In this function we want to be able to re-hydrate the custom object from Flyte Literal value 102 | """ 103 | # Step 1: lets download remote data locally 104 | local_path = ctx.file_access.get_random_local_path() 105 | ctx.file_access.download(lv.scalar.blob.uri, local_path) 106 | # Step 2: create the ShuffleJob object 107 | with open(local_path, "rb") as f: 108 | return self._shuffle_jobs_from_bytes(f.read()) 109 | 110 | 111 | @dataclasses.dataclass 112 | class ShuffleWorkerResults: 113 | partition: str 114 | pa_table: pa.Table 115 | 116 | 117 | class ShuffleWorkerResultsTransformer(TypeTransformer[ShuffleWorkerResults]): 118 | _TYPE_INFO = flytekit.BlobType( 119 | format="binary", 120 | dimensionality=flytekit.BlobType.BlobDimensionality.SINGLE, 121 | ) 122 | 123 | def __init__(self): 124 | super(ShuffleWorkerResultsTransformer, self).__init__( 125 | name="shuffleworkerresults-transform", 126 | t=ShuffleWorkerResults, 127 | ) 128 | 129 | def get_literal_type(self, t: Type[ShuffleWorkerResults]) -> flytekit.LiteralType: 130 | """ 131 | This is useful to tell the Flytekit type system that ``ShuffleWorkerResults`` actually refers to 132 | what corresponding type In this example, we say its of format binary (do not try to introspect) and 
133 | there are more than one files in it 134 | """ 135 | return flytekit.LiteralType(blob=self._TYPE_INFO) 136 | 137 | def to_literal( 138 | self, 139 | ctx: flytekit.FlyteContext, 140 | python_val: ShuffleWorkerResults, 141 | python_type: Type[ShuffleWorkerResults], 142 | expected: flytekit.LiteralType, 143 | ) -> flytekit.Literal: 144 | """ 145 | This method is used to convert from given python type object ``ShuffleWorkerResults`` 146 | to the Literal representation 147 | """ 148 | # Step 1: lets upload all the data into a remote place recommended by Flyte 149 | remote_path = ctx.file_access.get_random_remote_path() 150 | local_path = ctx.file_access.get_random_local_path() 151 | with pa.ipc.new_stream( 152 | local_path, python_val.pa_table.schema.with_metadata({"partition": python_val.partition}) 153 | ) as stream: 154 | stream.write(python_val.pa_table) 155 | ctx.file_access.upload(local_path, remote_path) 156 | # Step 2: lets return a pointer to this remote_path in the form of a literal 157 | return flytekit.Literal( 158 | scalar=flytekit.Scalar( 159 | blob=flytekit.Blob(uri=remote_path, metadata=flytekit.BlobMetadata(type=self._TYPE_INFO)) 160 | ) 161 | ) 162 | 163 | def to_python_value( 164 | self, ctx: flytekit.FlyteContext, lv: flytekit.Literal, expected_python_type: Type[ShuffleWorkerResults] 165 | ) -> ShuffleWorkerResults: 166 | """ 167 | In this function we want to be able to re-hydrate the custom object from Flyte Literal value 168 | """ 169 | # Step 1: lets download remote data locally 170 | local_path = ctx.file_access.get_random_local_path() 171 | ctx.file_access.download(lv.scalar.blob.uri, local_path) 172 | # Step 2: create the ShuffleWorkerResults object 173 | with pa.ipc.open_stream(local_path) as reader: 174 | return ShuffleWorkerResults( 175 | pa_table=pa.Table.from_batches([b for b in reader]), 176 | partition=reader.schema.metadata[b"partition"].decode("utf-8"), 177 | ) 178 | 179 | 180 | TypeEngine.register(ShuffleJobTransformer()) 181 | TypeEngine.register(ShuffleWorkerResultsTransformer()) 182 | 183 | 184 | ### 185 | # Task and Workflow definitions 186 | ### 187 | 188 | 189 | @flytekit.task(requests=flytekit.Resources(mem="2Gi", cpu="1"), retries=2) 190 | def initialize_dataset( 191 | schema_json_str: str, 192 | dataset_id: str, 193 | ) -> str: 194 | """Write the schema to the storage.""" 195 | s3_path_factory = S3PathFactory() 196 | s3_storage = S3DataStorage() 197 | schema_path = s3_path_factory.get_dataset_schema_path(DatasetID.from_str(dataset_id)) 198 | s3_storage.put_object_s3(schema_json_str.encode("utf-8"), schema_path) 199 | return schema_json_str 200 | 201 | 202 | @flytekit.task(requests=flytekit.Resources(mem="8Gi", cpu="2"), retries=2) 203 | def create_shuffling_jobs( 204 | schema_json_str: str, 205 | dataset_id: str, 206 | worker_max_working_set_size: int = 16384, 207 | ) -> List[ShuffleJob]: 208 | """Read the DynamoDB and return a list of shuffling jobs to distribute. 209 | 210 | The job descriptions are stored into files managed by Flyte. 211 | :param dataset_id: string representation of the DatasetID we need to process (dataset name + version). 212 | :param schema_json_str: string representation of the dataset schema 213 | :param max_rows_per_worker: Maximum number of rows to assign per working set, defaults to 16384 but can be 214 | increased if dataset sizes are so large that we want to use fewer workers and don't mind the committing 215 | step taking longer per-worker. 216 | :return: a list of shuffling jobs to do. 
217 | """ 218 | # TODO(jchia): Dynamically decide on what backends to use for S3 and MetadataDatabase instead of hardcoding here 219 | backend = DatasetWriterBackend(S3PathFactory(), S3DataStorage(), DynamodbMetadataDatabase()) 220 | job_factory = ShuffleJobFactory(backend, worker_max_working_set_size=worker_max_working_set_size) 221 | return list(job_factory.build_shuffle_jobs(DatasetID.from_str(dataset_id))) 222 | 223 | 224 | @flytekit.task(requests=flytekit.Resources(mem="8Gi", cpu="2"), retries=4, cache=True, cache_version="v1") 225 | def run_shuffling_job(job: ShuffleJob) -> ShuffleWorkerResults: 226 | """Run one shuffling job 227 | :param job: the ShuffleJob for this worker to run 228 | :return: pyarrow table containing only metadata and pointers to the ColumnBytesFiles in S3 for 229 | bytes columns in the dataset. 230 | """ 231 | worker = ShuffleWorker(storage=S3DataStorage(), s3_path_factory=S3PathFactory()) 232 | return ShuffleWorkerResults( 233 | pa_table=worker.process_job(job), 234 | partition=job.dataset_partition.partition, 235 | ) 236 | 237 | 238 | @flytekit.task(requests=flytekit.Resources(mem="8Gi", cpu="2"), retries=2) 239 | def finalize_shuffling_jobs(dataset_id: str, shuffle_results: List[ShuffleWorkerResults]) -> Dict[str, int]: 240 | """Aggregate the indexes from the various shuffling jobs and publish the final parquet files for the dataset. 241 | :param dataset_id: string representation of the DatasetID we need to process (dataset name + version). 242 | :param shuffled_jobs_files: list of files containing the pandas Dataframe generated by the shuffling jobs. 243 | :return: A dictionary mapping partition_name -> size_of_partition. 244 | """ 245 | results_by_partition = collections.defaultdict(list) 246 | for result in shuffle_results: 247 | results_by_partition[result.partition].append(result.pa_table) 248 | 249 | tables_by_partition: Dict[str, pa.Table] = {} 250 | for partition in results_by_partition: 251 | tables_by_partition[partition] = pa.concat_tables(results_by_partition[partition]) 252 | 253 | dataset_id_obj = DatasetID.from_str(dataset_id) 254 | save_index( 255 | dataset_id_obj.name, 256 | dataset_id_obj.version, 257 | tables_by_partition, 258 | s3_storage=S3DataStorage(), 259 | s3_path_factory=S3PathFactory(), 260 | ) 261 | return {partition_name: len(tables_by_partition[partition_name]) for partition_name in tables_by_partition} 262 | 263 | 264 | @flytekit.workflow # type: ignore 265 | def WickerDataShufflingWorkflow( 266 | dataset_id: str, 267 | schema_json_str: str, 268 | worker_max_working_set_size: int = 16384, 269 | ) -> Dict[str, int]: 270 | """Pipeline finalizing a wicker dataset. 271 | :param dataset_id: string representation of the DatasetID we need to process (dataset name + version). 272 | :param schema_json_str: string representation of the schema, serialized as JSON 273 | :return: A dictionary mapping partition_name -> size_of_partition. 
274 | """ 275 | schema_json_str_committed = initialize_dataset( 276 | dataset_id=dataset_id, 277 | schema_json_str=schema_json_str, 278 | ) 279 | jobs = create_shuffling_jobs( 280 | schema_json_str=schema_json_str_committed, 281 | dataset_id=dataset_id, 282 | worker_max_working_set_size=worker_max_working_set_size, 283 | ) 284 | shuffle_results = flytekit.map_task(run_shuffling_job, metadata=flytekit.TaskMetadata(retries=1))(job=jobs) 285 | result = cast(Dict[str, int], finalize_shuffling_jobs(dataset_id=dataset_id, shuffle_results=shuffle_results)) 286 | return result 287 | -------------------------------------------------------------------------------- /wicker/plugins/spark.py: -------------------------------------------------------------------------------- 1 | """Spark plugin for writing a dataset with Spark only (no external metadata database required) 2 | 3 | This plugin does an expensive global sorting step using Spark, which could be prohibitive 4 | for large datasets. 5 | """ 6 | from __future__ import annotations 7 | 8 | from typing import Any, Dict, Iterable, List, Optional, Tuple 9 | 10 | import pyarrow as pa 11 | import pyarrow.compute as pc 12 | 13 | try: 14 | import pyspark 15 | except ImportError: 16 | raise RuntimeError( 17 | "pyspark is not detected in current environment, install Wicker with extra arguments:" 18 | " `pip install wicker[spark]`" 19 | ) 20 | 21 | from operator import add 22 | 23 | from wicker import schema as schema_module 24 | from wicker.core.definitions import DatasetID 25 | from wicker.core.errors import WickerDatastoreException 26 | from wicker.core.persistance import AbstractDataPersistor 27 | from wicker.core.storage import S3DataStorage, S3PathFactory 28 | from wicker.schema import serialization 29 | 30 | DEFAULT_SPARK_PARTITION_SIZE = 256 31 | MAX_COL_FILE_NUMROW = 50 # TODO(isaak-willett): Magic number, we should derive this based on row size 32 | 33 | PrimaryKeyTuple = Tuple[Any, ...] 34 | UnparsedExample = Dict[str, Any] 35 | ParsedExample = Dict[str, Any] 36 | PointerParsedExample = Dict[str, Any] 37 | 38 | 39 | def persist_wicker_dataset( 40 | dataset_name: str, 41 | dataset_version: str, 42 | dataset_schema: schema_module.DatasetSchema, 43 | rdd: pyspark.rdd.RDD[Tuple[str, UnparsedExample]], 44 | s3_storage: S3DataStorage = S3DataStorage(), 45 | s3_path_factory: S3PathFactory = S3PathFactory(), 46 | local_reduction: bool = False, 47 | sort: bool = True, 48 | partition_size=DEFAULT_SPARK_PARTITION_SIZE, 49 | ) -> Optional[Dict[str, int]]: 50 | """ 51 | Persist wicker dataset public facing api function, for api consistency. 52 | :param dataset_name: name of dataset persisted 53 | :type dataset_name: str 54 | :param dataset_version: version of dataset persisted 55 | :type dataset_version: str 56 | :param dataset_schema: schema of dataset to be persisted 57 | :type dataset_schema: DatasetSchema 58 | :param rdd: rdd of data to persist 59 | :type rdd: RDD 60 | :param s3_storage: s3 storage abstraction 61 | :type s3_storage: S3DataStorage 62 | :param s3_path_factory: s3 path abstraction 63 | :type s3_path_factory: S3PathFactory 64 | :param local_reduction: if true, reduce tables on main instance (no spark). This is useful if the spark reduction 65 | is not feasible for a large dataste. 
66 | :type local_reduction: bool 67 | :param sort: if true, sort the resulting table by primary keys 68 | :type sort: bool 69 | :param partition_size: partition size during the sort 70 | :type partition_size: int 71 | """ 72 | return SparkPersistor(s3_storage, s3_path_factory).persist_wicker_dataset( 73 | dataset_name, 74 | dataset_version, 75 | dataset_schema, 76 | rdd, 77 | local_reduction=local_reduction, 78 | sort=sort, 79 | partition_size=partition_size, 80 | ) 81 | 82 | 83 | class SparkPersistor(AbstractDataPersistor): 84 | def __init__( 85 | self, 86 | s3_storage: S3DataStorage = S3DataStorage(), 87 | s3_path_factory: S3PathFactory = S3PathFactory(), 88 | ) -> None: 89 | """ 90 | Init a SparkPersistor 91 | 92 | :param s3_storage: The storage abstraction for S3 93 | :type s3_storage: S3DataStore 94 | :param s3_path_factory: The path factory for generating s3 paths 95 | based on dataset name and version 96 | :type s3_path_factory: S3PathFactory 97 | """ 98 | super().__init__(s3_storage, s3_path_factory) 99 | 100 | def persist_wicker_dataset( 101 | self, 102 | dataset_name: str, 103 | dataset_version: str, 104 | schema: schema_module.DatasetSchema, 105 | rdd: pyspark.rdd.RDD[Tuple[str, UnparsedExample]], 106 | local_reduction: bool = False, 107 | sort: bool = True, 108 | partition_size=DEFAULT_SPARK_PARTITION_SIZE, 109 | ) -> Optional[Dict[str, int]]: 110 | """ 111 | Persist the current rdd dataset defined by name, version, schema, and data. 112 | """ 113 | # check if variables have been set ie: not None 114 | if ( 115 | not isinstance(dataset_name, str) 116 | or not isinstance(dataset_version, str) 117 | or not isinstance(schema, schema_module.DatasetSchema) 118 | or not isinstance(rdd, pyspark.rdd.RDD) 119 | ): 120 | raise ValueError("Current dataset variables not all set, set all to proper not None values") 121 | 122 | # define locally for passing to spark rdd ops, breaks if relying on self 123 | # since it passes to spark engine and we lose self context 124 | s3_storage = self.s3_storage 125 | s3_path_factory = self.s3_path_factory 126 | parse_row = self.parse_row 127 | get_row_keys = self.get_row_keys 128 | persist_wicker_partition = self.persist_wicker_partition 129 | save_partition_tbl = self.save_partition_tbl 130 | 131 | # put the schema up on to s3 132 | schema_path = s3_path_factory.get_dataset_schema_path(DatasetID(name=dataset_name, version=dataset_version)) 133 | s3_storage.put_object_s3(serialization.dumps(schema).encode("utf-8"), schema_path) 134 | 135 | # parse the rows and ensure validation passes, ie: rows actual data matches expected types 136 | rdd0 = rdd.mapValues(lambda row: parse_row(row, schema)) # type: ignore 137 | 138 | # Make sure to cache the RDD to ease future computations, since it seems that sortBy and zipWithIndex 139 | # trigger actions and we want to avoid recomputing the source RDD at all costs 140 | rdd0 = rdd0.cache() 141 | dataset_size = rdd0.count() 142 | 143 | rdd1 = rdd0.keyBy(lambda row: get_row_keys(row, schema)) 144 | 145 | # Sort RDD by keys 146 | rdd2: pyspark.rdd.RDD[Tuple[Tuple[Any, ...], Tuple[str, ParsedExample]]] = rdd1.sortByKey( 147 | # TODO(jchia): Magic number, we should derive this based on row size 148 | numPartitions=max(1, dataset_size // partition_size), 149 | ascending=True, 150 | ) 151 | 152 | def set_partition(iterator: Iterable[PrimaryKeyTuple]) -> Iterable[int]: 153 | key_set = set(iterator) 154 | yield len(key_set) 155 | 156 | # the number of unique keys in rdd partitions 157 | # this is softer check than collecting 
all the keys in all partitions to check uniqueness
158 | rdd_key_count: int = rdd2.map(lambda x: x[0]).mapPartitions(set_partition).reduce(add)
159 | num_unique_keys = rdd_key_count
160 | if dataset_size != num_unique_keys:
161 | raise WickerDatastoreException(
162 | f"""Error: dataset examples do not have unique primary key tuples.
163 | Dataset has {dataset_size} examples but {num_unique_keys} unique primary keys"""
164 | )
165 | 
166 | # persist the spark partition to S3Storage
167 | rdd3 = rdd2.values()
168 | 
169 | rdd4 = rdd3.mapPartitions(
170 | lambda spark_iterator: persist_wicker_partition(
171 | dataset_name,
172 | spark_iterator,
173 | schema,
174 | s3_storage,
175 | s3_path_factory,
176 | target_max_column_file_numrows=MAX_COL_FILE_NUMROW,
177 | )
178 | )
179 | 
180 | if not local_reduction:
181 | # combine the rdd rows by partition key, building one pyarrow table per partition
182 | rdd5 = rdd4.combineByKey(
183 | createCombiner=lambda data: pa.Table.from_pydict(
184 | {col: [data[col]] for col in schema.get_all_column_names()}
185 | ),
186 | mergeValue=lambda tbl, data: pa.Table.from_batches(
187 | [
188 | *tbl.to_batches(),  # type: ignore
189 | *pa.Table.from_pydict({col: [data[col]] for col in schema.get_all_column_names()}).to_batches(),
190 | ]
191 | ),
192 | mergeCombiners=lambda tbl1, tbl2: pa.Table.from_batches(
193 | [
194 | *tbl1.to_batches(),  # type: ignore
195 | *tbl2.to_batches(),  # type: ignore
196 | ]
197 | ),
198 | )
199 | 
200 | # create the partition tables
201 | if sort:
202 | rdd6 = rdd5.mapValues(
203 | lambda pa_tbl: pc.take(
204 | pa_tbl,
205 | pc.sort_indices(
206 | pa_tbl,
207 | sort_keys=[(pk, "ascending") for pk in schema.primary_keys],
208 | ),
209 | )
210 | )
211 | else:
212 | rdd6 = rdd5
213 | 
214 | # save the partition tables to s3
215 | rdd7 = rdd6.map(
216 | lambda partition_table: save_partition_tbl(
217 | partition_table, dataset_name, dataset_version, s3_storage, s3_path_factory
218 | )
219 | )
220 | rdd_list = rdd7.collect()
221 | written = {partition: size for partition, size in rdd_list}
222 | else:
223 | # In normal operation, rdd5 may have thousands of partitions at the start of operation,
224 | # however because there are typically only at most three dataset splits (train,test,val),
225 | # the output of rdd5 has only three partitions with any data. Empirically this has led to
226 | # issues related to driver memory. As a workaround, we can instead perform this reduction
227 | # manually outside of the JVM without much of a performance penalty since Spark's parallelism
228 | # was not being taken advantage of anyway.
229 | 
230 | # We iterate over each row locally and merge the rows into one pyarrow table per partition.
231 | # To avoid toLocalIterator from running rdd4 sequentially, we first cache it and trigger an action.
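            # Shape sketch (illustrative values): rdd4 yields (partition, row_metadata) pairs such as
            #   ("train", {"car_id": b"...", "image": <ColumnBytesFile pointer bytes>, ...})
            # Below we accumulate {"train": {"car_id": [...], "image": [...]}, ...} locally and then
            # build one pyarrow table per partition, mirroring the combineByKey path above.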
232 | rdd4 = rdd4.cache() 233 | _ = rdd4.count() 234 | 235 | # the rest of this is adapted from the map task persist 236 | merged_dicts: Dict[str, Dict[str, List[Any]]] = {} 237 | for partition_key, row in rdd4.toLocalIterator(): 238 | current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {}) 239 | for col in row.keys(): 240 | if col in current_dict: 241 | current_dict[col].append(row[col]) 242 | else: 243 | current_dict[col] = [row[col]] 244 | merged_dicts[partition_key] = current_dict 245 | 246 | arrow_dict = {} 247 | for partition_key, data_dict in merged_dicts.items(): 248 | data_table = pa.Table.from_pydict(data_dict) 249 | if sort: 250 | arrow_dict[partition_key] = pc.take( 251 | pa.Table.from_pydict(data_dict), 252 | pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in schema.primary_keys]), 253 | ) 254 | else: 255 | arrow_dict[partition_key] = data_table 256 | 257 | written = {} 258 | for partition_key, pa_table in arrow_dict.items(): 259 | self.save_partition_tbl( 260 | (partition_key, pa_table), dataset_name, dataset_version, self.s3_storage, self.s3_path_factory 261 | ) 262 | written[partition_key] = pa_table.num_rows 263 | 264 | return written 265 | 266 | @staticmethod 267 | def get_row_keys( 268 | partition_data_tup: Tuple[str, ParsedExample], schema: schema_module.DatasetSchema 269 | ) -> PrimaryKeyTuple: 270 | """ 271 | Get the keys of a row based on the parition tuple and the data schema. 272 | 273 | :param partition_data_tup: Tuple of partition id and ParsedExample row 274 | :type partition_data_tup: Tuple[str, ParsedExample] 275 | :return: Tuple of primary key values from parsed row and schema 276 | :rtype: PrimaryKeyTuple 277 | """ 278 | partition, data = partition_data_tup 279 | return (partition,) + tuple(data[pk] for pk in schema.primary_keys) 280 | -------------------------------------------------------------------------------- /wicker/plugins/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, Literal 3 | 4 | import wandb 5 | 6 | from wicker.core.config import get_config 7 | from wicker.core.definitions import DatasetID 8 | from wicker.core.storage import S3PathFactory 9 | 10 | 11 | def version_dataset( 12 | dataset_name: str, 13 | dataset_version: str, 14 | entity: str, 15 | metadata: Dict[str, Any], 16 | dataset_backend: Literal["s3"] = "s3", 17 | ) -> None: 18 | """ 19 | Version the dataset on Weights and Biases using the config parameters defined in wickerconfig.json. 
20 | 21 | Args: 22 | dataset_name: The name of the dataset to be versioned 23 | dataset_version: The version of the dataset to be versioned 24 | entity: Who the run will belong to 25 | metadata: The metadata to be logged as an artifact, enforces dataclass for metadata documentation 26 | dataset_backend: The backend where the dataset is stored, currently only supports s3 27 | """ 28 | # needs to first acquire and set wandb creds 29 | # WANDB_API_KEY, WANDB_BASE_URL 30 | # _set_wandb_credentials() 31 | 32 | # needs to init the wandb run, this is going to be a 'data' run 33 | dataset_run = wandb.init(project="dataset_curation", name=f"{dataset_name}_{dataset_version}", entity=entity) 34 | 35 | # grab the uri of the dataset to be versioned 36 | dataset_uri = _identify_s3_url_for_dataset_version(dataset_name, dataset_version, dataset_backend) 37 | 38 | # establish the artifact and save the dir/s3_url to the artifact 39 | data_artifact = wandb.Artifact(f"{dataset_name}_{dataset_version}", type="dataset") 40 | data_artifact.add_reference(dataset_uri, name="dataset") 41 | 42 | # save metadata dict to the artifact 43 | data_artifact.metadata["version"] = dataset_version 44 | data_artifact.metadata["s3_uri"] = dataset_uri 45 | for key, value in metadata.items(): 46 | data_artifact.metadata[key] = value 47 | 48 | # save the artifact to the run 49 | dataset_run.log_artifact(data_artifact) # type: ignore 50 | dataset_run.finish() # type: ignore 51 | 52 | 53 | def _set_wandb_credentials() -> None: 54 | """ 55 | Acquire the weights and biases credentials and load them into the environment. 56 | 57 | This load the variables into the environment as ENV Variables for WandB to use, 58 | this function overrides the previously set wandb env variables with the ones specified in 59 | the wicker config if they exist. 60 | """ 61 | # load the config 62 | config = get_config() 63 | 64 | # if the keys are present in the config add them to the env 65 | wandb_config = config.wandb_config 66 | for field in wandb_config.__dataclass_fields__: # type: ignore 67 | attr = wandb_config.__getattribute__(field) 68 | if attr is not None: 69 | os.environ[str(field).upper()] = attr 70 | else: 71 | if str(field).upper() not in os.environ: 72 | raise EnvironmentError( 73 | f"Cannot use W&B without setting {str(field.upper())}. " 74 | f"Specify in either ENV or through wicker config file." 75 | ) 76 | 77 | 78 | def _identify_s3_url_for_dataset_version( 79 | dataset_name: str, 80 | dataset_version: str, 81 | dataset_backend: Literal["s3"] = "s3", 82 | ) -> str: 83 | """ 84 | Identify and return the s3 url for the dataset and version specified in the backend. 
85 | 86 | Args: 87 | dataset_name: name of the dataset to retrieve url 88 | dataset_version: version of the dataset to retrieve url 89 | dataset_backend: backend of the dataset to retrieve url 90 | 91 | Returns: 92 | The url pointing to the dataset on storage 93 | """ 94 | schema_path = "" 95 | if dataset_backend == "s3": 96 | # needs to do the parsing work to fetch the correct s3 uri 97 | schema_path = S3PathFactory().get_dataset_assets_path(DatasetID(name=dataset_name, version=dataset_version)) 98 | return schema_path 99 | -------------------------------------------------------------------------------- /wicker/schema/__init__.py: -------------------------------------------------------------------------------- 1 | from .schema import * # noqa: F401, F403 2 | -------------------------------------------------------------------------------- /wicker/schema/codecs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | from typing import Any, Dict, Type 5 | 6 | 7 | class Codec(abc.ABC): 8 | """Base class for all object decoder/encoders. 9 | 10 | Defining Codec classes allows user to use arbitrary types in their Wicker Examples. Subclasses of Codec provide 11 | functionality to serialize/deserialize data to/from underlying storage, and can be used to abstract functionality 12 | such as compression, validation of object fields and provide adaptor functionality to other libraries (e.g. NumPy, 13 | PIL, torch). 14 | 15 | Wicker provides some Codec implementations by default for NumPy, but users can choose to define their own Codecs 16 | and use them during data-writing and data-reading. 17 | """ 18 | 19 | # Maps from codec_name to codec class. 20 | codec_registry: Dict[str, Type[Codec]] = {} 21 | 22 | def __init_subclass__(cls: Type[Codec], **kwargs: Any) -> None: 23 | """Automatically register subclasses of Codec.""" 24 | if cls._codec_name() in Codec.codec_registry: 25 | raise KeyError(f"Codec '{cls._codec_name()}' was already defined.") 26 | if cls._codec_name(): 27 | Codec.codec_registry[cls._codec_name()] = cls 28 | 29 | @staticmethod 30 | @abc.abstractmethod 31 | def _codec_name() -> str: 32 | """Needs to be overriden to return a globally unique name for the codec.""" 33 | pass 34 | 35 | def get_codec_name(self) -> str: 36 | """Accessor for _codec_name. In general, derived classes should not touch this, and should just implement 37 | _codec_name. This accessor is used internally to support generic schema serialization/deserialization. 38 | """ 39 | return self._codec_name() 40 | 41 | def save_codec_to_dict(self) -> Dict[str, Any]: 42 | """If you want to save some parameters of this codec with the dataset 43 | schema, return the fields here. The returned dictionary should be JSON compatible. 44 | Note that this is a dataset-level value, not a per example value.""" 45 | return {} 46 | 47 | @staticmethod 48 | @abc.abstractmethod 49 | def load_codec_from_dict(data: Dict[str, Any]) -> Codec: 50 | """Create a new instance of this codec with the given parameters.""" 51 | pass 52 | 53 | @abc.abstractmethod 54 | def validate_and_encode_object(self, obj: Any) -> bytes: 55 | """Encode the given object into bytes. The function is also responsible for validating the data. 56 | :param obj: Object to encode 57 | :return: The encoded bytes for the given object.""" 58 | pass 59 | 60 | @abc.abstractmethod 61 | def decode_object(self, data: bytes) -> Any: 62 | """Decode an object from the given bytes. 
This is the opposite of validate_and_encode_object. 63 | We expect obj == decode_object(validate_and_encode_object(obj)) 64 | :param data: bytes to decode. 65 | :return: Decoded object.""" 66 | pass 67 | 68 | def object_type(self) -> Type[Any]: 69 | """Return the expected type of the objects handled by this codec. 70 | This method can be overriden to match more specific classes.""" 71 | return object 72 | 73 | def __eq__(self, other: Any) -> bool: 74 | return ( 75 | super().__eq__(other) 76 | and self.get_codec_name() == other.get_codec_name() 77 | and self.save_codec_to_dict() == other.save_codec_to_dict() 78 | ) 79 | -------------------------------------------------------------------------------- /wicker/schema/dataloading.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, TypeVar 2 | 3 | from wicker.schema import schema, validation 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def load_example(example: validation.AvroRecord, schema: schema.DatasetSchema) -> Dict[str, Any]: 9 | """Loads an example according to the provided schema, converting data based on their column 10 | information into their corresponding in-memory representations (e.g. numpy arrays, torch tensors) 11 | 12 | The returned dictionary will be in the same shape as defined in the schema, with keys 13 | corresponding to the schema names, and values validated/transformed against the schema's fields. 14 | 15 | :param example: example to load 16 | :type example: validation.AvroRecord 17 | :param schema: schema to load against 18 | :type schema: schema.DatasetSchema 19 | :return: loaded example 20 | :rtype: Dict[str, Any] 21 | """ 22 | load_example_visitor = LoadExampleVisitor(example, schema) 23 | return load_example_visitor.load_example() 24 | 25 | 26 | class LoadExampleVisitor(schema.DatasetSchemaVisitor[Any]): 27 | """schema.DatasetSchemaVisitor class that will validate and load an Example 28 | in accordance to a provided schema.DatasetSchema 29 | """ 30 | 31 | def __init__(self, example: validation.AvroRecord, schema: schema.DatasetSchema): 32 | # Pointers to original data (data should be kept immutable) 33 | self._schema = schema 34 | self._example = example 35 | 36 | # Pointers to help keep visitor state during tree traversal 37 | self._current_data: Any = self._example 38 | self._current_path: Tuple[str, ...] 
= tuple() 39 | 40 | def load_example(self) -> Dict[str, Any]: 41 | """Loads an example from its Avro format into its in-memory representations""" 42 | # Since the original input example is non-None, the loaded example will be non-None also 43 | example: Dict[str, Any] = self._schema.schema_record._accept_visitor(self) 44 | return example 45 | 46 | def process_record_field(self, field: schema.RecordField) -> Optional[validation.AvroRecord]: 47 | """Visit an schema.RecordField schema field""" 48 | current_data = validation.validate_dict(self._current_data, field.required, self._current_path) 49 | if current_data is None: 50 | return current_data 51 | 52 | # Process nested fields by setting up the visitor's state and visiting each node 53 | processing_path = self._current_path 54 | processing_example = current_data 55 | loaded = {} 56 | 57 | # When reading records, the client might restrict the columns to load to a subset of the 58 | # full columns, so check if the key is actually present in the example being processed 59 | for nested_field in field.fields: 60 | if nested_field.name in processing_example: 61 | self._current_path = processing_path + (nested_field.name,) 62 | self._current_data = processing_example[nested_field.name] 63 | loaded[nested_field.name] = nested_field._accept_visitor(self) 64 | return loaded 65 | 66 | def process_int_field(self, field: schema.IntField) -> Optional[int]: 67 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path) 68 | 69 | def process_long_field(self, field: schema.LongField) -> Optional[int]: 70 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path) 71 | 72 | def process_string_field(self, field: schema.StringField) -> Optional[str]: 73 | return validation.validate_field_type(self._current_data, str, field.required, self._current_path) 74 | 75 | def process_bool_field(self, field: schema.BoolField) -> Optional[bool]: 76 | return validation.validate_field_type(self._current_data, bool, field.required, self._current_path) 77 | 78 | def process_float_field(self, field: schema.FloatField) -> Optional[float]: 79 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path) 80 | 81 | def process_double_field(self, field: schema.DoubleField) -> Optional[float]: 82 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path) 83 | 84 | def process_object_field(self, field: schema.ObjectField) -> Optional[Any]: 85 | data = validation.validate_field_type(self._current_data, bytes, field.required, self._current_path) 86 | if data is None: 87 | return data 88 | return field.codec.decode_object(data) 89 | 90 | def process_array_field(self, field: schema.ArrayField) -> Optional[List[Any]]: 91 | current_data = validation.validate_field_type(self._current_data, list, field.required, self._current_path) 92 | if current_data is None: 93 | return current_data 94 | 95 | # Process array elements by setting up the visitor's state and visiting each element 96 | processing_path = self._current_path 97 | loaded = [] 98 | 99 | # Arrays may contain None values if the element field declares that it is not required 100 | for element_index, element in enumerate(current_data): 101 | self._current_path = processing_path + (f"elem[{element_index}]",) 102 | self._current_data = element 103 | loaded.append(field.element_field._accept_visitor(self)) 104 | return loaded 105 | 
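The dataparsing module that follows is the write-time counterpart of the loader above: parse_example validates a raw example against a DatasetSchema and encodes object columns to bytes, while load_example reverses the process at read time. The snippet below is a minimal round-trip sketch, not part of the library source: it assumes the field constructors accept the (name, description, required) keyword arguments used by serialization.py, that DatasetSchema accepts (fields, primary_keys), and the "sample_id"/"label" column names are invented for illustration.

from wicker.schema import dataloading, dataparsing, schema

# Hypothetical two-column schema; the column names are illustrative only.
dataset_schema = schema.DatasetSchema(
    fields=[
        schema.StringField(name="sample_id", description="unique example id", required=True),
        schema.IntField(name="label", description="class label", required=True),
    ],
    primary_keys=["sample_id"],
)

raw_example = {"sample_id": "train_0001", "label": 3}

# parse_example type-checks every value against its schema field (raising
# WickerSchemaException on mismatch) and returns an Avro-compatible dict.
stored = dataparsing.parse_example(raw_example, dataset_schema)

# load_example walks the same schema and rebuilds the in-memory example;
# ObjectField columns would additionally be decoded by their registered Codec.
restored = dataloading.load_example(stored, dataset_schema)
assert restored == {"sample_id": "train_0001", "label": 3}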
-------------------------------------------------------------------------------- /wicker/schema/dataparsing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | from wicker.core.definitions import ExampleMetadata 4 | from wicker.core.errors import WickerSchemaException 5 | from wicker.schema import schema, validation 6 | 7 | 8 | def parse_example(example: Dict[str, Any], schema: schema.DatasetSchema) -> validation.AvroRecord: 9 | """Parses an example according to the provided schema, converting known types such as 10 | numpy arrays, torch tensors etc into bytes for storage on disk 11 | 12 | The returned dictionary will be in the same shape as defined in the schema, with keys 13 | corresponding to the schema names, and values validated against the schema's fields. 14 | 15 | :param example: example to parse 16 | :type example: Dict[str, Any] 17 | :param schema: schema to parse against 18 | :type schema: schema.DatasetSchema 19 | :return: parsed example 20 | :rtype: Dict[str, Any] 21 | """ 22 | parse_example_visitor = ParseExampleVisitor(example, schema) 23 | return parse_example_visitor.parse_example() 24 | 25 | 26 | def parse_example_metadata(example: Dict[str, Any], schema: schema.DatasetSchema) -> ExampleMetadata: 27 | """Parses ExampleMetadata from an example according to the provided schema 28 | 29 | The returned dictionary will be in the same shape as defined in the schema, with keys 30 | corresponding to the schema names, and values validated against the schema's fields. 31 | 32 | Certain field types will be ignored when parsing examples as they are not considered part of the example metadata: 33 | 34 | 1. BytesField 35 | ... More to be implemented (e.g. AvroNumpyField etc) 36 | 37 | :param example: example to parse 38 | :type example: Dict[str, Any] 39 | :param schema: schema to parse against 40 | :type schema: schema.DatasetSchema 41 | """ 42 | parse_example_metadata_visitor = ParseExampleMetadataVisitor(example, schema) 43 | return parse_example_metadata_visitor.parse_example() 44 | 45 | 46 | # A special exception to indicate that a field should be skipped 47 | class _SkipFieldException(Exception): 48 | pass 49 | 50 | 51 | class ParseExampleVisitor(schema.DatasetSchemaVisitor[Any]): 52 | """schema.DatasetSchemaVisitor class that will validate and transform an Example 53 | in accordance to a provided schema.DatasetSchema 54 | 55 | The original example is not modified in-place, and we incur the cost of copying 56 | primitives (e.g. bytes, strings) 57 | """ 58 | 59 | def __init__(self, example: Dict[str, Any], schema: schema.DatasetSchema): 60 | # Pointers to original data (data should be kept immutable) 61 | self._schema = schema 62 | self._example = example 63 | 64 | # Pointers to help keep visitor state during tree traversal 65 | self._current_data: Any = self._example 66 | self._current_path: Tuple[str, ...] 
= tuple() 67 | 68 | def parse_example(self) -> Dict[str, Any]: 69 | """Parses an example into a form that is suitable for storage in an Avro format""" 70 | # Since the original input example is non-None, the parsed example will be non-None also 71 | example: Dict[str, Any] = self._schema.schema_record._accept_visitor(self) 72 | return example 73 | 74 | def process_record_field(self, field: schema.RecordField) -> Optional[validation.AvroRecord]: 75 | """Visit an schema.RecordField schema field""" 76 | val = validation.validate_field_type(self._current_data, dict, field.required, self._current_path) 77 | if val is None: 78 | return val 79 | 80 | # Add keys to the example for any non-required fields that were left unset in the raw data 81 | for optional_field in [f for f in field.fields if not f.required]: 82 | if optional_field.name not in val: 83 | val[optional_field.name] = None 84 | 85 | # Validate that data matches schema exactly 86 | schema_key_names = {nested_field.name for nested_field in field.fields} 87 | if val.keys() != schema_key_names: 88 | raise WickerSchemaException( 89 | f"Error at path {'.'.join(self._current_path)}: " 90 | f"Example missing keys: {list(schema_key_names - val.keys())} " 91 | f"and has extra keys: {list(val.keys() - schema_key_names)}" 92 | ) 93 | 94 | # Process nested fields by setting up the visitor's state and visiting each node 95 | res = {} 96 | processing_path = self._current_path 97 | processing_example = self._current_data 98 | 99 | for nested_field in field.fields: 100 | self._current_path = processing_path + (nested_field.name,) 101 | self._current_data = processing_example[nested_field.name] 102 | try: 103 | res[nested_field.name] = nested_field._accept_visitor(self) 104 | except _SkipFieldException: 105 | pass 106 | return res 107 | 108 | def process_int_field(self, field: schema.IntField) -> Optional[int]: 109 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path) 110 | 111 | def process_long_field(self, field: schema.LongField) -> Optional[int]: 112 | return validation.validate_field_type(self._current_data, int, field.required, self._current_path) 113 | 114 | def process_string_field(self, field: schema.StringField) -> Optional[str]: 115 | return validation.validate_field_type(self._current_data, str, field.required, self._current_path) 116 | 117 | def process_bool_field(self, field: schema.BoolField) -> Optional[bool]: 118 | return validation.validate_field_type(self._current_data, bool, field.required, self._current_path) 119 | 120 | def process_float_field(self, field: schema.FloatField) -> Optional[float]: 121 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path) 122 | 123 | def process_double_field(self, field: schema.DoubleField) -> Optional[float]: 124 | return validation.validate_field_type(self._current_data, float, field.required, self._current_path) 125 | 126 | def process_object_field(self, field: schema.ObjectField) -> Optional[bytes]: 127 | data = validation.validate_field_type( 128 | self._current_data, field.codec.object_type(), field.required, self._current_path 129 | ) 130 | if data is None: 131 | return None 132 | return field.codec.validate_and_encode_object(data) 133 | 134 | def process_array_field(self, field: schema.ArrayField) -> Optional[List[Any]]: 135 | val = validation.validate_field_type(self._current_data, list, field.required, self._current_path) 136 | if val is None: 137 | return val 138 | 139 | # Process array elements 
by setting up the visitor's state and visiting each element 140 | res = [] 141 | processing_path = self._current_path 142 | processing_example = self._current_data 143 | 144 | # Arrays may contain None values if the element field declares that it is not required 145 | for element_index, element in enumerate(processing_example): 146 | self._current_path = processing_path + (f"elem[{element_index}]",) 147 | self._current_data = element 148 | # Allow _SkipFieldExceptions to propagate up and skip this array field 149 | res.append(field.element_field._accept_visitor(self)) 150 | return res 151 | 152 | 153 | class ParseExampleMetadataVisitor(ParseExampleVisitor): 154 | """Specialization of ParseExampleVisitor which skips over certain fields that are now parsed as metadata""" 155 | 156 | def process_object_field(self, field: schema.ObjectField) -> Optional[bytes]: 157 | # Raise a special error to indicate that this field should be skipped 158 | raise _SkipFieldException() 159 | -------------------------------------------------------------------------------- /wicker/schema/serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Any, Dict, Type, Union 4 | 5 | from wicker.core.errors import WickerSchemaException 6 | from wicker.schema import codecs, schema 7 | from wicker.schema.schema import PRIMARY_KEYS_TAG 8 | 9 | JSON_SCHEMA_VERSION = 2 10 | DATASET_ID_REGEX = re.compile(r"^(?P[0-9a-zA-Z_]+)/v(?P[0-9]+\.[0-9]+\.[0-9]+)$") 11 | 12 | 13 | def dumps(schema: schema.DatasetSchema, pretty: bool = True) -> str: 14 | """Dumps a schema as JSON 15 | 16 | :param schema: schema to dump 17 | :type schema: schema.DatasetSchema 18 | :param pretty: whether to dump the schema as a prettified JSON, defaults to False 19 | :type pretty: bool, optional 20 | :return: JSON string 21 | :rtype: str 22 | """ 23 | visitor = AvroDatasetSchemaSerializer() 24 | jdata = schema.schema_record._accept_visitor(visitor) 25 | jdata.update(jdata["type"]) 26 | jdata["_json_version"] = JSON_SCHEMA_VERSION 27 | if pretty: 28 | return json.dumps(jdata, sort_keys=True, indent=4) 29 | return json.dumps(jdata) 30 | 31 | 32 | def loads(schema_str: str, treat_objects_as_bytes: bool = False) -> schema.DatasetSchema: 33 | """Loads a DatasetSchema from a JSON str 34 | 35 | :param schema_str: JSON string 36 | :param treat_objects_as_bytes: If set, don't load codecs for object types, instead just treat them as bytes. 37 | This is useful for code that needs to work with datasets in a generic way, but does not need to actually decode 38 | the data. 39 | :return: Deserialized DatasetSchema object 40 | """ 41 | # Parse string as JSON 42 | try: 43 | schema_dict = json.loads(schema_str) 44 | except json.decoder.JSONDecodeError: 45 | raise WickerSchemaException(f"Unable load string as JSON: {schema_str}") 46 | 47 | # Construct the DatasetSchema 48 | try: 49 | fields = [_loads(d, treat_objects_as_bytes=treat_objects_as_bytes) for d in schema_dict["fields"]] 50 | return schema.DatasetSchema( 51 | fields=fields, 52 | primary_keys=json.loads(schema_dict.get(PRIMARY_KEYS_TAG, "[]")), 53 | allow_empty_primary_keys=True, # For backward compatibility. Clean me AVSW-78939. 
54 | ) 55 | except KeyError as err: 56 | raise WickerSchemaException(f"Malformed serialization of DatasetSchema: {err}") 57 | 58 | 59 | def _loads(schema_dict: Dict[str, Any], treat_objects_as_bytes: bool) -> schema.SchemaField: 60 | """Recursively parses a schema dictionary into the appropriate SchemaField""" 61 | type_ = schema_dict["type"] 62 | required = True 63 | 64 | # Nullable types are represented in Avro schemas as Union (list) types with "null" as the 65 | # first element (by convention) 66 | if isinstance(type_, list): 67 | assert type_[0] == "null" 68 | required = False 69 | type_ = type_[1] 70 | 71 | if isinstance(type_, dict) and type_["type"] == "record": 72 | return schema.RecordField( 73 | name=schema_dict["name"], 74 | fields=[_loads(f, treat_objects_as_bytes=treat_objects_as_bytes) for f in type_["fields"]], 75 | description=schema_dict["_description"], 76 | required=required, 77 | ) 78 | if isinstance(type_, dict) and type_["type"] == "array": 79 | # ArrayField columns are limited to contain only dicts or simple types (and not nested arrays) 80 | element_dict = type_.copy() 81 | element_dict["type"] = type_["items"] 82 | element_field = _loads(element_dict, treat_objects_as_bytes=treat_objects_as_bytes) 83 | return schema.ArrayField( 84 | element_field=element_field, 85 | required=required, 86 | ) 87 | return _loads_base_types(type_, required, schema_dict, treat_objects_as_bytes=treat_objects_as_bytes) 88 | 89 | 90 | def _loads_base_types( 91 | type_: str, required: bool, schema_dict: Dict[str, Any], treat_objects_as_bytes: bool = False 92 | ) -> schema.SchemaField: 93 | if type_ == "int": 94 | return schema.IntField( 95 | name=schema_dict["name"], 96 | description=schema_dict["_description"], 97 | required=required, 98 | ) 99 | elif type_ == "long": 100 | return schema.LongField( 101 | name=schema_dict["name"], 102 | description=schema_dict["_description"], 103 | required=required, 104 | ) 105 | elif type_ == "string": 106 | return schema.StringField( 107 | name=schema_dict["name"], 108 | description=schema_dict["_description"], 109 | required=required, 110 | ) 111 | elif type_ == "boolean": 112 | return schema.BoolField( 113 | name=schema_dict["name"], 114 | description=schema_dict["_description"], 115 | required=required, 116 | ) 117 | elif type_ == "float": 118 | return schema.FloatField( 119 | name=schema_dict["name"], 120 | description=schema_dict["_description"], 121 | required=required, 122 | ) 123 | elif type_ == "double": 124 | return schema.DoubleField( 125 | name=schema_dict["name"], 126 | description=schema_dict["_description"], 127 | required=required, 128 | ) 129 | elif type_ == "bytes": 130 | l5ml_metatype = schema_dict["_l5ml_metatype"] 131 | if l5ml_metatype == "object": 132 | codec_name = schema_dict["_codec_name"] 133 | if treat_objects_as_bytes: 134 | codec: codecs.Codec = _PassThroughObjectCodec(codec_name, json.loads(schema_dict["_codec_params"])) 135 | else: 136 | try: 137 | codec_cls: Type[codecs.Codec] = codecs.Codec.codec_registry[codec_name] 138 | except KeyError: 139 | raise WickerSchemaException( 140 | f"Could not find a registered codec with name {codec_name} " 141 | f"for field {schema_dict['name']}. Please define a subclass of ObjectField.Codec and define " 142 | "the codec_name static method." 
143 | ) 144 | codec = codec_cls.load_codec_from_dict(json.loads(schema_dict["_codec_params"])) 145 | return schema.ObjectField( 146 | name=schema_dict["name"], 147 | codec=codec, 148 | description=schema_dict["_description"], 149 | required=required, 150 | is_heavy_pointer=schema_dict.get("_is_heavy_pointer", False), 151 | ) 152 | elif l5ml_metatype == "numpy": 153 | return schema.NumpyField( 154 | name=schema_dict["name"], 155 | description=schema_dict["_description"], 156 | is_heavy_pointer=schema_dict.get("_is_heavy_pointer", False), 157 | required=required, 158 | shape=schema_dict["_shape"], 159 | dtype=schema_dict["_dtype"], 160 | ) 161 | elif l5ml_metatype == "bytes": 162 | return schema.BytesField( 163 | name=schema_dict["name"], 164 | description=schema_dict["_description"], 165 | required=required, 166 | is_heavy_pointer=schema_dict.get("_is_heavy_pointer", False), 167 | ) 168 | raise WickerSchemaException(f"Unhandled _l5ml_metatype for avro bytes type: {l5ml_metatype}") 169 | elif type_ == "record": 170 | schema_fields = [] 171 | for schema_field in schema_dict["fields"]: 172 | loaded_field = _loads_base_types(schema_field["type"], True, schema_field) 173 | schema_fields.append(loaded_field) 174 | return schema.RecordField( 175 | fields=schema_fields, 176 | name=schema_dict["name"], 177 | ) 178 | raise WickerSchemaException(f"Unhandled type: {type_}") 179 | 180 | 181 | class _PassThroughObjectCodec(codecs.Codec): 182 | """The _PassThroughObjectCodec class is a placeholder for any object codec, it is used when we need to parse 183 | any possible schema properly, but we don't need to read the data. This codec does not decode/encode the data, 184 | instead is just acts as an identity function. However, it keeps track of all the attributes of the original codec 185 | as they were stored in the loaded schema, so that if we save that schema again, we do not lose any information. 186 | """ 187 | 188 | def __init__(self, codec_name: str, codec_attributes: Dict[str, Any]): 189 | self._codec_name_value = codec_name 190 | self._codec_attributes = codec_attributes 191 | 192 | @staticmethod 193 | def _codec_name() -> str: 194 | # The static method does not make any sense here, because this class is a placeholder for any codec class. 195 | # Instead we will overload the get_codec_name function. 196 | return "" 197 | 198 | def get_codec_name(self) -> str: 199 | return self._codec_name_value 200 | 201 | def save_codec_to_dict(self) -> Dict[str, Any]: 202 | return self._codec_attributes 203 | 204 | @staticmethod 205 | def load_codec_from_dict(data: Dict[str, Any]) -> codecs.Codec: 206 | pass 207 | 208 | def validate_and_encode_object(self, obj: bytes) -> bytes: 209 | return obj 210 | 211 | def decode_object(self, data: bytes) -> bytes: 212 | return data 213 | 214 | def object_type(self) -> Type[Any]: 215 | return bytes 216 | 217 | 218 | class AvroDatasetSchemaSerializer(schema.DatasetSchemaVisitor[Dict[str, Any]]): 219 | """A visitor class that serializes an AvroDatasetSchema as Avro-compatible JSON""" 220 | 221 | def process_schema_field(self, field: schema.SchemaField, avro_type: Union[str, Dict[str, Any]]) -> Dict[str, Any]: 222 | """Common processing across all field types""" 223 | # The way to declare nullable fields in Avro schemas is to declare Avro Union (list) types. 
224 | # The default type (usually null) should be listed first: 225 | # https://avro.apache.org/docs/current/spec.html#Unions 226 | field_type = avro_type if field.required else ["null", avro_type] 227 | return { 228 | "name": field.name, 229 | "type": field_type, 230 | "_description": field.description, 231 | **field.custom_field_tags, 232 | } 233 | 234 | def process_record_field(self, field: schema.RecordField) -> Dict[str, Any]: 235 | record_type = { 236 | "type": "record", 237 | "fields": [nested_field._accept_visitor(self) for nested_field in field.fields], 238 | "name": field.name, 239 | } 240 | return self.process_schema_field(field, record_type) 241 | 242 | def process_int_field(self, field: schema.IntField) -> Dict[str, Any]: 243 | return { 244 | **self.process_schema_field(field, "int"), 245 | } 246 | 247 | def process_long_field(self, field: schema.LongField) -> Dict[str, Any]: 248 | return { 249 | **self.process_schema_field(field, "long"), 250 | } 251 | 252 | def process_string_field(self, field: schema.StringField) -> Dict[str, Any]: 253 | return { 254 | **self.process_schema_field(field, "string"), 255 | } 256 | 257 | def process_bool_field(self, field: schema.BoolField) -> Dict[str, Any]: 258 | return { 259 | **self.process_schema_field(field, "boolean"), 260 | } 261 | 262 | def process_float_field(self, field: schema.FloatField) -> Dict[str, Any]: 263 | return { 264 | **self.process_schema_field(field, "float"), 265 | } 266 | 267 | def process_double_field(self, field: schema.DoubleField) -> Dict[str, Any]: 268 | return { 269 | **self.process_schema_field(field, "double"), 270 | } 271 | 272 | def process_object_field(self, field: schema.ObjectField) -> Dict[str, Any]: 273 | return { 274 | **self.process_schema_field(field, "bytes"), 275 | "_l5ml_metatype": "object", 276 | "_is_heavy_pointer": field.is_heavy_pointer, 277 | } 278 | 279 | def process_array_field(self, field: schema.ArrayField) -> Dict[str, Any]: 280 | array_type = field.element_field._accept_visitor(self) 281 | array_type["items"] = array_type["type"] 282 | array_type["type"] = "array" 283 | 284 | field_type = array_type if field.required else ["null", array_type] 285 | return { 286 | "name": field.name, 287 | "type": field_type, 288 | } 289 | -------------------------------------------------------------------------------- /wicker/schema/validation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Type, TypeVar 2 | 3 | from wicker.core.errors import WickerSchemaException 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | # Unfortunately, mypy has weak support for recursive typing so we resort to Any here, but Any 9 | # should really be an alias for: Union[AvroRecord, str, bool, int, float, bytes] 10 | AvroRecord = Dict[str, Any] 11 | 12 | 13 | def validate_field_type(val: Any, type_: Type[T], required: bool, current_path: Tuple[str, ...]) -> Optional[T]: 14 | """Validates the type of a field 15 | 16 | :param val: value to validate 17 | :type val: Any 18 | :param type_: type to validate against 19 | :type type_: Type[T] 20 | :param required: whether or not the value is required (to be non-None) 21 | :type required: bool 22 | :param current_path: current parsing path 23 | :type current_path: Tuple[str, ...] 
24 | :raises WickerSchemaException: when parsing error occurs 25 | :return: val, but validated to be of type T 26 | :rtype: T 27 | """ 28 | if val is None: 29 | if not required: 30 | return val 31 | raise WickerSchemaException( 32 | f"Error at path {'.'.join(current_path)}: Example provided a None value for required field" 33 | ) 34 | elif not isinstance(val, type_): 35 | raise WickerSchemaException( 36 | f"Error at path {'.'.join(current_path)}: Example provided a {type(val)} value, expected {type_}" 37 | ) 38 | return val 39 | 40 | 41 | def validate_dict(val: Any, required: bool, current_path: Tuple[str, ...]) -> Optional[Dict[str, Any]]: 42 | """Validates a dictionary 43 | 44 | :param val: incoming value to validate as a dictionary 45 | :type val: Any 46 | :param required: whether or not the value is required (to be non-None) 47 | :type required: bool 48 | :param current_path: current parsing path 49 | :type current_path: Tuple[str, ...] 50 | :return: parsed dictionary 51 | :rtype: dict 52 | """ 53 | if val is None: 54 | if not required: 55 | return val 56 | raise WickerSchemaException( 57 | f"Error at path {'.'.join(current_path)}: Example provided a None value for required field" 58 | ) 59 | # PyArrow returns record fields as lists of (k, v) tuples 60 | elif isinstance(val, list): 61 | try: 62 | parsed_val = dict(val) 63 | except ValueError: 64 | raise WickerSchemaException(f"Error at path {'.'.join(current_path)}: Unable to convert list to dict") 65 | elif isinstance(val, dict): 66 | parsed_val = val 67 | else: 68 | raise WickerSchemaException(f"Error at path {'.'.join(current_path)}: Unable to convert object to dict") 69 | 70 | return validate_field_type(parsed_val, dict, required, current_path) 71 | -------------------------------------------------------------------------------- /wicker/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/woven-planet/wicker/babcf9a50419e9ee0a58b115ba96300a008d6344/wicker/testing/__init__.py -------------------------------------------------------------------------------- /wicker/testing/codecs.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List, Type 3 | 4 | from wicker.schema import codecs 5 | 6 | 7 | class Vector: 8 | def __init__(self, data: List[int]): 9 | self.data = data 10 | 11 | def __eq__(self, other: Any) -> bool: 12 | return super().__eq__(other) and self.data == other.data 13 | 14 | 15 | class VectorCodec(codecs.Codec): 16 | def __init__(self, compression_method: int) -> None: 17 | self.compression_method = compression_method 18 | 19 | @staticmethod 20 | def _codec_name() -> str: 21 | return "VectorCodec" 22 | 23 | def save_codec_to_dict(self) -> Dict[str, Any]: 24 | return {"compression_method": self.compression_method} 25 | 26 | @staticmethod 27 | def load_codec_from_dict(data: Dict[str, Any]) -> codecs.Codec: 28 | return VectorCodec(compression_method=data["compression_method"]) 29 | 30 | def validate_and_encode_object(self, obj: Vector) -> bytes: 31 | # Inefficient but simple encoding method for testing. 
32 | return json.dumps(obj.data).encode("utf-8") 33 | 34 | def decode_object(self, data: bytes) -> Vector: 35 | return Vector(json.loads(data.decode())) 36 | 37 | def object_type(self) -> Type[Any]: 38 | return Vector 39 | -------------------------------------------------------------------------------- /wicker/testing/storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from typing import Any, Dict 5 | 6 | import pyarrow.fs as pafs 7 | 8 | from wicker.core.storage import S3DataStorage 9 | 10 | 11 | class FakeS3DataStorage(S3DataStorage): 12 | def __init__(self, tmpdir: str = "/tmp") -> None: 13 | self._tmpdir = tmpdir 14 | 15 | def __getstate__(self) -> Dict[Any, Any]: 16 | return {"tmpdir": self._tmpdir} 17 | 18 | def __setstate__(self, state: Dict[Any, Any]) -> None: 19 | self._tmpdir = state["tmpdir"] 20 | return 21 | 22 | def _get_local_path(self, path: str) -> str: 23 | return os.path.join(self._tmpdir, path.replace("s3://", "")) 24 | 25 | def check_exists_s3(self, input_path: str) -> bool: 26 | return os.path.exists(self._get_local_path(input_path)) 27 | 28 | def fetch_obj_s3(self, input_path: str) -> bytes: 29 | if not self.check_exists_s3(input_path): 30 | raise KeyError(f"File {input_path} not found in the fake s3 storage.") 31 | with open(self._get_local_path(input_path), "rb") as f: 32 | return f.read() 33 | 34 | def fetch_file(self, input_path: str, local_prefix: str, timeout_seconds: int = 120) -> str: 35 | if not self.check_exists_s3(input_path): 36 | raise KeyError(f"File {input_path} not found in the fake s3 storage.") 37 | bucket, key = self.bucket_key_from_s3_path(input_path) 38 | dest_path = os.path.join(local_prefix, key) 39 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 40 | if not os.path.isfile(dest_path): 41 | shutil.copy2(self._get_local_path(input_path), dest_path) 42 | return dest_path 43 | 44 | def put_object_s3(self, object_bytes: bytes, s3_path: str) -> None: 45 | full_tmp_path = self._get_local_path(s3_path) 46 | os.makedirs(os.path.dirname(full_tmp_path), exist_ok=True) 47 | with open(full_tmp_path, "wb") as f: 48 | f.write(object_bytes) 49 | 50 | def put_file_s3(self, local_path: str, s3_path: str) -> None: 51 | full_tmp_path = self._get_local_path(s3_path) 52 | os.makedirs(os.path.dirname(full_tmp_path), exist_ok=True) 53 | shutil.copy2(local_path, full_tmp_path) 54 | 55 | 56 | class LocalDataStorage(S3DataStorage): 57 | def __init__(self, root_path: str): 58 | super().__init__() 59 | self._root_path = Path(root_path) 60 | self._fs = pafs.LocalFileSystem() 61 | 62 | @property 63 | def filesystem(self) -> pafs.FileSystem: 64 | return self._fs 65 | 66 | def _create_path(self, path: str) -> None: 67 | """Ensures the given path exists.""" 68 | self._fs.create_dir(path, recursive=True) 69 | 70 | # Override. 71 | def check_exists_s3(self, input_path: str) -> bool: 72 | file_info = self._fs.get_file_info(input_path) 73 | return file_info.type != pafs.FileType.NotFound 74 | 75 | # Override. 76 | def fetch_file(self, input_path: str, local_prefix: str, timeout_seconds: int = 120) -> str: 77 | # This raises if the input path is not relative to the root. 78 | relative_input_path = Path(input_path).relative_to(self._root_path) 79 | 80 | target_path = os.path.join(local_prefix, str(relative_input_path)) 81 | self._create_path(os.path.dirname(target_path)) 82 | self._fs.copy_file(input_path, target_path) 83 | return target_path 84 | 85 | # Override. 
86 | def fetch_partial_file_s3( 87 | self, input_path: str, local_prefix: str, offset: int, size: int, timeout_seconds: int = 120 88 | ) -> str: 89 | raise NotImplementedError("fetch_partial_file_s3") 90 | 91 | # Override. 92 | def put_object_s3(self, object_bytes: bytes, s3_path: str) -> None: 93 | self._create_path(os.path.dirname(s3_path)) 94 | with self._fs.open_output_stream(s3_path) as ostream: 95 | ostream.write(object_bytes) 96 | 97 | # Override. 98 | def put_file_s3(self, local_path: str, s3_path: str) -> None: 99 | self._create_path(os.path.dirname(s3_path)) 100 | self._fs.copy_file(local_path, s3_path) 101 | --------------------------------------------------------------------------------
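FakeS3DataStorage above mirrors the S3DataStorage interface on local disk so that unit tests can exercise dataset reads and writes without touching AWS. The snippet below is an illustrative sketch rather than repository code: the bucket and key names are invented, and it assumes the inherited bucket_key_from_s3_path helper splits an s3:// path into its bucket and key, as its use in fetch_file suggests.

import os
import tempfile

from wicker.testing.storage import FakeS3DataStorage

with tempfile.TemporaryDirectory() as tmpdir:
    storage = FakeS3DataStorage(tmpdir=tmpdir)

    # Objects written through the fake land under tmpdir instead of S3.
    storage.put_object_s3(b"hello wicker", "s3://my-bucket/datasets/example/part-0.bin")
    assert storage.check_exists_s3("s3://my-bucket/datasets/example/part-0.bin")
    assert storage.fetch_obj_s3("s3://my-bucket/datasets/example/part-0.bin") == b"hello wicker"

    # fetch_file copies the object under a local prefix, keyed by its bucket-relative path.
    local_path = storage.fetch_file(
        "s3://my-bucket/datasets/example/part-0.bin", os.path.join(tmpdir, "downloads")
    )
    with open(local_path, "rb") as f:
        assert f.read() == b"hello wicker"

Because put_file_s3 and fetch_file only copy files on disk, the same fake instance can be handed to both writer and reader code paths in tests that expect an S3DataStorage.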