├── .cruft.json ├── .devcontainer └── devcontainer.json ├── .flake8 ├── .github ├── dependabot.yml └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yml ├── pyproject.toml ├── src └── graphchain │ ├── __init__.py │ ├── core.py │ ├── py.typed │ └── utils.py └── tests ├── __init__.py ├── test_dask_dataframe.py ├── test_graphchain.py └── test_high_level_graph.py /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "git@github.com:radix-ai/poetry-cookiecutter", 3 | "commit": "966391b0342fc923a177a30c39285adf4416fa17", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "package_name": "Graphchain", 8 | "package_description": "An efficient cache for the execution of dask graphs.", 9 | "package_url": "https://github.com/radix-ai/graphchain", 10 | "author_name": "Laurent Sorber", 11 | "author_email": "laurent@radix.ai", 12 | "python_version": "3.8", 13 | "with_fastapi_api": "0", 14 | "with_jupyter_lab": "0", 15 | "with_pydantic_typing": "0", 16 | "with_sentry_logging": "0", 17 | "with_streamlit_app": "0", 18 | "with_typer_cli": "0", 19 | "continuous_integration": "GitHub", 20 | "docstring_style": "NumPy", 21 | "private_package_repository_name": "", 22 | "private_package_repository_url": "", 23 | "__package_name_kebab_case": "graphchain", 24 | "__package_name_snake_case": "graphchain", 25 | "_template": "git@github.com:radix-ai/poetry-cookiecutter" 26 | } 27 | }, 28 | "directory": null 29 | } 30 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "graphchain", 3 | "dockerComposeFile": "../docker-compose.yml", 4 | "service": "dev", 5 | "workspaceFolder": "/app/", 6 | "overrideCommand": true, 7 | "postStartCommand": "cp --update /opt/build/poetry/poetry.lock /app/ && mkdir -p /app/.git/hooks/ && cp --update /opt/build/git/* /app/.git/hooks/", 8 | "customizations": { 9 | "vscode": { 10 | "extensions": [ 11 | "bungcip.better-toml", 12 | "eamodio.gitlens", 13 | "ms-azuretools.vscode-docker", 14 | "ms-python.python", 15 | "ms-python.vscode-pylance", 16 | "ryanluker.vscode-coverage-gutters", 17 | "visualstudioexptteam.vscodeintellicode" 18 | ], 19 | "settings": { 20 | "coverage-gutters.coverageFileNames": [ 21 | "reports/coverage.xml" 22 | ], 23 | "editor.codeActionsOnSave": { 24 | "source.organizeImports": true 25 | }, 26 | "editor.formatOnSave": true, 27 | "editor.rulers": [ 28 | 100 29 | ], 30 | "files.autoSave": "onFocusChange", 31 | "python.defaultInterpreterPath": "/opt/app-env/bin/python", 32 | "python.formatting.provider": "black", 33 | "python.linting.banditArgs": [ 34 | "--configfile", 35 | "pyproject.toml" 36 | ], 37 | "python.linting.banditEnabled": true, 38 | "python.linting.flake8Enabled": true, 39 | "python.linting.mypyEnabled": true, 40 | "python.linting.pydocstyleEnabled": true, 41 | "python.terminal.activateEnvironment": false, 42 | "python.testing.pytestEnabled": true, 43 | "terminal.integrated.defaultProfile.linux": "zsh", 44 | "terminal.integrated.profiles.linux": { 45 | "zsh": { 46 | "path": "/usr/bin/zsh" 47 | } 48 | } 49 | } 50 | } 51 | } 52 | } -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 
| # http://flake8.pycqa.org/en/latest/user/configuration.html#project-configuration 3 | # https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-length 4 | # TODO: https://github.com/PyCQA/flake8/issues/234 5 | color = always 6 | doctests = True 7 | ignore = DAR103,E203,E501,W503 8 | max_line_length = 100 9 | max_complexity = 10 10 | 11 | # https://github.com/terrencepreilly/darglint#flake8 12 | # TODO: https://github.com/terrencepreilly/darglint/issues/130 13 | docstring_style = numpy 14 | strictness = long -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: github-actions 5 | directory: / 6 | schedule: 7 | interval: monthly 8 | commit-message: 9 | prefix: "ci" 10 | prefix-development: "ci" 11 | include: "scope" 12 | - package-ecosystem: pip 13 | directory: / 14 | schedule: 15 | interval: monthly 16 | commit-message: 17 | prefix: "build" 18 | prefix-development: "build" 19 | include: "scope" 20 | versioning-strategy: lockfile-only 21 | allow: 22 | - dependency-type: "all" 23 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.8" 20 | 21 | - name: Install Poetry 22 | run: pip install --no-input poetry 23 | 24 | - name: Publish package 25 | run: | 26 | poetry config pypi-token.pypi "${{ secrets.POETRY_PYPI_TOKEN_PYPI }}" 27 | poetry publish --build 28 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v3 17 | 18 | - name: Set up Node.js 19 | uses: actions/setup-node@v3 20 | with: 21 | node-version: 16 22 | 23 | - name: Install @devcontainers/cli 24 | run: npm install --location=global @devcontainers/cli 25 | 26 | - name: Start Dev Container 27 | env: 28 | DOCKER_BUILDKIT: 1 29 | run: | 30 | git config --global init.defaultBranch main 31 | devcontainer up --workspace-folder . 32 | 33 | - name: Lint package 34 | run: devcontainer exec --workspace-folder . poe lint 35 | 36 | - name: Test package 37 | run: devcontainer exec --workspace-folder . 
poe test 38 | 39 | - name: Upload coverage 40 | uses: codecov/codecov-action@v3 41 | with: 42 | files: reports/coverage.xml 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Coverage.py 2 | htmlcov/ 3 | reports/ 4 | 5 | # cruft 6 | *.rej 7 | 8 | # Data 9 | *.csv* 10 | *.dat* 11 | *.pickle* 12 | *.xls* 13 | *.zip* 14 | 15 | # direnv 16 | .envrc 17 | 18 | # dotenv 19 | .env 20 | 21 | # graphchain 22 | __graphchain_cache__/ 23 | 24 | # Hypothesis 25 | .hypothesis/ 26 | 27 | # Jupyter 28 | *.ipynb 29 | .ipynb_checkpoints/ 30 | notebooks/ 31 | 32 | # macOS 33 | .DS_Store 34 | 35 | # mypy 36 | .dmypy.json 37 | .mypy_cache/ 38 | 39 | # Node.js 40 | node_modules/ 41 | 42 | # Poetry 43 | .venv/ 44 | dist/ 45 | poetry.lock 46 | 47 | # PyCharm 48 | .idea/ 49 | 50 | # pyenv 51 | .python-version 52 | 53 | # pytest 54 | .pytest_cache/ 55 | 56 | # Python 57 | __pycache__/ 58 | *.py[cdo] 59 | 60 | # Terraform 61 | .terraform/ 62 | 63 | # VS Code 64 | .vscode/ 65 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | default_install_hook_types: [commit-msg, pre-commit] 3 | default_stages: [commit, manual] 4 | fail_fast: true 5 | repos: 6 | - repo: https://github.com/pre-commit/pygrep-hooks 7 | rev: v1.9.0 8 | hooks: 9 | - id: python-check-blanket-noqa 10 | - id: python-check-blanket-type-ignore 11 | - id: python-check-mock-methods 12 | - id: python-no-eval 13 | - id: python-no-log-warn 14 | - id: python-use-type-annotations 15 | - id: python-check-blanket-noqa 16 | - id: rst-backticks 17 | - id: rst-directive-colons 18 | - id: rst-inline-touching-normal 19 | - id: text-unicode-replacement-char 20 | - repo: https://github.com/pre-commit/pre-commit-hooks 21 | rev: v4.3.0 22 | hooks: 23 | - id: check-added-large-files 24 | - id: check-ast 25 | - id: check-builtin-literals 26 | - id: check-case-conflict 27 | - id: check-docstring-first 28 | - id: check-json 29 | - id: check-merge-conflict 30 | - id: check-shebang-scripts-are-executable 31 | - id: check-symlinks 32 | - id: check-toml 33 | - id: check-vcs-permalinks 34 | - id: check-xml 35 | - id: check-yaml 36 | - id: debug-statements 37 | - id: detect-private-key 38 | - id: fix-byte-order-marker 39 | - id: mixed-line-ending 40 | - id: trailing-whitespace 41 | types: [python] 42 | - id: end-of-file-fixer 43 | types: [python] 44 | - repo: local 45 | hooks: 46 | - id: commitizen 47 | name: commitizen 48 | entry: cz check --commit-msg-file 49 | require_serial: true 50 | language: system 51 | stages: [commit-msg] 52 | - id: pyupgrade 53 | name: pyupgrade 54 | entry: pyupgrade --py38-plus 55 | require_serial: true 56 | language: system 57 | types: [python] 58 | - id: absolufy-imports 59 | name: absolufy-imports 60 | entry: absolufy-imports 61 | require_serial: true 62 | language: system 63 | types: [python] 64 | - id: yesqa 65 | name: yesqa 66 | entry: yesqa 67 | require_serial: true 68 | language: system 69 | types: [python] 70 | - id: isort 71 | name: isort 72 | entry: isort 73 | require_serial: true 74 | language: system 75 | types: [python] 76 | - id: black 77 | name: black 78 | entry: black 79 | require_serial: true 80 | language: system 81 | types: [python] 82 | - id: shellcheck 83 | name: shellcheck 84 | entry: shellcheck --check-sourced 85 | language: system 
86 | types: [shell] 87 | - id: bandit 88 | name: bandit 89 | entry: bandit --configfile pyproject.toml 90 | language: system 91 | types: [python] 92 | - id: pydocstyle 93 | name: pydocstyle 94 | entry: pydocstyle 95 | language: system 96 | types: [python] 97 | - id: flake8 98 | name: flake8 99 | entry: flake8 100 | language: system 101 | types: [python] 102 | - id: poetry-check 103 | name: poetry check 104 | entry: poetry check 105 | language: system 106 | files: pyproject.toml 107 | pass_filenames: false 108 | - id: mypy 109 | name: mypy 110 | entry: mypy 111 | language: system 112 | types: [python] 113 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v1.4.0 (2022-08-29) 2 | 3 | ### Feat 4 | 5 | - support multiprocessing schedulers (#88) 6 | 7 | ### Fix 8 | 9 | - add high level graph support (#92) 10 | 11 | ## v1.3.0 (2022-06-16) 12 | 13 | ### Feat 14 | 15 | - add (de)serialization customization (#76) 16 | 17 | ## v1.2.0 (2022-01-21) 18 | 19 | ### Feat 20 | 21 | - add support for dask.distributed (thanks to @cbyrohl) 22 | - rescaffold project 23 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM python:3.8-slim AS base 3 | 4 | # Configure Python to print tracebacks on crash [1], and to not buffer stdout and stderr [2]. 5 | # [1] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONFAULTHANDLER 6 | # [2] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONUNBUFFERED 7 | ENV PYTHONFAULTHANDLER 1 8 | ENV PYTHONUNBUFFERED 1 9 | 10 | # Install Poetry. 11 | ENV POETRY_VERSION 1.1.13 12 | RUN --mount=type=cache,target=/root/.cache/ \ 13 | pip install poetry==$POETRY_VERSION 14 | 15 | # Create and activate a virtual environment. 16 | RUN python -m venv /opt/app-env 17 | ENV PATH /opt/app-env/bin:$PATH 18 | ENV VIRTUAL_ENV /opt/app-env 19 | 20 | # Install compilers that may be required for certain packages or platforms. 21 | RUN rm /etc/apt/apt.conf.d/docker-clean 22 | RUN --mount=type=cache,target=/var/cache/apt/ \ 23 | --mount=type=cache,target=/var/lib/apt/ \ 24 | apt-get update && \ 25 | apt-get install --no-install-recommends --yes build-essential 26 | 27 | # Set the working directory. 28 | WORKDIR /app/ 29 | 30 | # Install the run time Python environment. 31 | COPY poetry.lock* pyproject.toml /app/ 32 | RUN --mount=type=cache,target=/root/.cache/ \ 33 | mkdir -p src/graphchain/ && touch src/graphchain/__init__.py && touch README.md && \ 34 | poetry install --no-dev --no-interaction 35 | 36 | # Create a non-root user. 37 | ARG UID=1000 38 | ARG GID=$UID 39 | RUN groupadd --gid $GID app && \ 40 | useradd --create-home --gid $GID --uid $UID app 41 | 42 | FROM base as ci 43 | 44 | # Install git so we can run pre-commit. 45 | RUN --mount=type=cache,target=/var/cache/apt/ \ 46 | --mount=type=cache,target=/var/lib/apt/ \ 47 | apt-get update && \ 48 | apt-get install --no-install-recommends --yes git 49 | 50 | # Install the development Python environment. 51 | RUN --mount=type=cache,target=/root/.cache/ \ 52 | poetry install --no-interaction 53 | 54 | # Give the non-root user ownership and switch to the non-root user. 
55 | RUN chown --recursive app /app/ /opt/ 56 | USER app 57 | 58 | FROM base as dev 59 | 60 | # Install development tools: compilers, curl, git, gpg, ssh, starship, sudo, vim, and zsh. 61 | RUN --mount=type=cache,target=/var/cache/apt/ \ 62 | --mount=type=cache,target=/var/lib/apt/ \ 63 | apt-get update && \ 64 | apt-get install --no-install-recommends --yes build-essential curl git gnupg ssh sudo vim zsh zsh-antigen && \ 65 | sh -c "$(curl -fsSL https://starship.rs/install.sh)" -- "--yes" && \ 66 | usermod --shell /usr/bin/zsh app 67 | 68 | # Install the development Python environment. 69 | RUN --mount=type=cache,target=/root/.cache/ \ 70 | poetry install --no-interaction 71 | 72 | # Persist output generated during docker build so that we can restore it in the dev container. 73 | COPY .pre-commit-config.yaml /app/ 74 | RUN mkdir -p /opt/build/poetry/ && cp poetry.lock /opt/build/poetry/ && \ 75 | git init && pre-commit install --install-hooks && \ 76 | mkdir -p /opt/build/git/ && cp .git/hooks/commit-msg .git/hooks/pre-commit /opt/build/git/ 77 | 78 | # Give the non-root user ownership and switch to the non-root user. 79 | RUN chown --recursive app /app/ /opt/ && \ 80 | echo 'app ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/app && \ 81 | chmod 0440 /etc/sudoers.d/app 82 | USER app 83 | 84 | # Configure the non-root user's shell. 85 | RUN echo 'source /usr/share/zsh-antigen/antigen.zsh' >> ~/.zshrc && \ 86 | echo 'antigen bundle zsh-users/zsh-syntax-highlighting' >> ~/.zshrc && \ 87 | echo 'antigen bundle zsh-users/zsh-autosuggestions' >> ~/.zshrc && \ 88 | echo 'antigen apply' >> ~/.zshrc && \ 89 | echo 'eval "$(starship init zsh)"' >> ~/.zshrc && \ 90 | echo 'HISTFILE=~/.zsh_history' >> ~/.zshrc && \ 91 | zsh -c 'source ~/.zshrc' 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 radix.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/github/license/mashape/apistatus.svg)](https://choosealicense.com/licenses/mit/) [![PyPI](https://img.shields.io/pypi/v/graphchain.svg)](https://pypi.python.org/pypi/graphchain/) 2 | 3 | # Graphchain 4 | 5 | ## What is graphchain? 
6 | 7 | Graphchain is like [joblib.Memory](https://joblib.readthedocs.io/en/latest/memory.html) for dask graphs. [Dask graph computations](https://docs.dask.org/en/latest/spec.html) are cached to a local or remote location of your choice, specified by a [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html). 8 | 9 | When you change your dask graph (by changing a computation's implementation or its inputs), graphchain will take care to only recompute the minimum number of computations necessary to fetch the result. This allows you to iterate quickly over your graph without spending time on recomputing previously computed keys. 10 | 11 |

12 | [Figure: the xkcd comic "Is It Worth the Time?" (xkcd 1205)]
13 | Source: xkcd.com/1205/
14 |
15 | 16 | The main difference between graphchain and joblib.Memory is that in graphchain a computation's materialised inputs are _not_ serialised and hashed (which can be very expensive when the inputs are large objects such as pandas DataFrames). Instead, a chain of hashes (hence the name graphchain) of the computation object and its dependencies (which are also computation objects) is used to identify the cache file. 17 | 18 | Additionally, the result of a computation is only cached if it is estimated that loading that computation from cache will save time compared to simply computing the computation. The decision on whether to cache depends on the characteristics of the cache location, which are different when caching to the local filesystem compared to caching to S3 for example. 19 | 20 | ## Usage by example 21 | 22 | ### Basic usage 23 | 24 | Install graphchain with pip to get started: 25 | 26 | ```sh 27 | pip install graphchain 28 | ``` 29 | 30 | To demonstrate how graphchain can save you time, let's first create a simple dask graph that (1) creates a few pandas DataFrames, (2) runs a relatively heavy operation on these DataFrames, and (3) summarises the results. 31 | 32 | ```python 33 | import dask 34 | import graphchain 35 | import pandas as pd 36 | 37 | def create_dataframe(num_rows, num_cols): 38 | print("Creating DataFrame...") 39 | return pd.DataFrame(data=[range(num_cols)]*num_rows) 40 | 41 | def expensive_computation(df, num_quantiles): 42 | print("Running expensive computation on DataFrame...") 43 | return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)]) 44 | 45 | def summarize_dataframes(*dfs): 46 | print("Summing DataFrames...") 47 | return sum(df.sum().sum() for df in dfs) 48 | 49 | dsk = { 50 | "df_a": (create_dataframe, 10_000, 1000), 51 | "df_b": (create_dataframe, 10_000, 1000), 52 | "df_c": (expensive_computation, "df_a", 2048), 53 | "df_d": (expensive_computation, "df_b", 2048), 54 | "result": (summarize_dataframes, "df_c", "df_d") 55 | } 56 | ``` 57 | 58 | Using `dask.get` to fetch the `"result"` key takes about 6 seconds: 59 | 60 | ```python 61 | >>> %time dask.get(dsk, "result") 62 | 63 | Creating DataFrame... 64 | Running expensive computation on DataFrame... 65 | Creating DataFrame... 66 | Running expensive computation on DataFrame... 67 | Summing DataFrames... 68 | 69 | CPU times: user 7.39 s, sys: 686 ms, total: 8.08 s 70 | Wall time: 6.19 s 71 | ``` 72 | 73 | On the other hand, using `graphchain.get` for the first time to fetch `'result'` takes only 4 seconds: 74 | 75 | ```python 76 | >>> %time graphchain.get(dsk, "result") 77 | 78 | Creating DataFrame... 79 | Running expensive computation on DataFrame... 80 | Summing DataFrames... 81 | 82 | CPU times: user 4.7 s, sys: 519 ms, total: 5.22 s 83 | Wall time: 4.04 s 84 | ``` 85 | 86 | The reason `graphchain.get` is faster than `dask.get` is because it can load `df_b` and `df_d` from cache after `df_a` and `df_c` have been computed and cached. Note that graphchain will only cache the result of a computation if loading that computation from cache is estimated to be faster than simply running the computation. 
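That decision is based on a simple cost model (see `estimate_load_time` in `graphchain.core`), which you can tune through dask's configuration system. A minimal sketch using the configuration keys the current implementation reads, namely `cache_latency` (seconds), `cache_throughput` (bytes per second), and `cache_estimated_compression_ratio`:

```python
import dask

# Pretend the cache is free to read, so that every computation is considered worth caching.
dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
```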
87 | 88 | Running `graphchain.get` a second time to fetch `"result"` will be almost instant since this time the result itself is also available from cache: 89 | 90 | ```python 91 | >>> %time graphchain.get(dsk, "result") 92 | 93 | CPU times: user 4.79 ms, sys: 1.79 ms, total: 6.58 ms 94 | Wall time: 5.34 ms 95 | ``` 96 | 97 | Now let's say we want to change how the result is summarised from a sum to an average: 98 | 99 | ```python 100 | def summarize_dataframes(*dfs): 101 | print("Averaging DataFrames...") 102 | return sum(df.mean().mean() for df in dfs) / len(dfs) 103 | ``` 104 | 105 | If we then ask graphchain to fetch `"result"`, it will detect that only `summarize_dataframes` has changed and therefore only recompute this function with inputs loaded from cache: 106 | 107 | ```python 108 | >>> %time graphchain.get(dsk, "result") 109 | 110 | Averaging DataFrames... 111 | 112 | CPU times: user 123 ms, sys: 37.2 ms, total: 160 ms 113 | Wall time: 86.6 ms 114 | ``` 115 | 116 | ### Storing the graphchain cache remotely 117 | 118 | Graphchain's cache is by default `./__graphchain_cache__`, but you can ask graphchain to use a cache at any [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html) such as `s3://mybucket/__graphchain_cache__`: 119 | 120 | ```python 121 | graphchain.get(dsk, "result", location="s3://mybucket/__graphchain_cache__") 122 | ``` 123 | 124 | ### Excluding keys from being cached 125 | 126 | In some cases you may not want a key to be cached. To avoid writing certain keys to the graphchain cache, you can use the `skip_keys` argument: 127 | 128 | ```python 129 | graphchain.get(dsk, "result", skip_keys=["result"]) 130 | ``` 131 | 132 | ### Using graphchain with dask.delayed 133 | 134 | Alternatively, you can use graphchain together with dask.delayed for easier dask graph creation: 135 | 136 | ```python 137 | import dask 138 | import pandas as pd 139 | 140 | @dask.delayed 141 | def create_dataframe(num_rows, num_cols): 142 | print("Creating DataFrame...") 143 | return pd.DataFrame(data=[range(num_cols)]*num_rows) 144 | 145 | @dask.delayed 146 | def expensive_computation(df, num_quantiles): 147 | print("Running expensive computation on DataFrame...") 148 | return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)]) 149 | 150 | @dask.delayed 151 | def summarize_dataframes(*dfs): 152 | print("Summing DataFrames...") 153 | return sum(df.sum().sum() for df in dfs) 154 | 155 | df_a = create_dataframe(num_rows=10_000, num_cols=1000) 156 | df_b = create_dataframe(num_rows=10_000, num_cols=1000) 157 | df_c = expensive_computation(df_a, num_quantiles=2048) 158 | df_d = expensive_computation(df_b, num_quantiles=2048) 159 | result = summarize_dataframes(df_c, df_d) 160 | ``` 161 | 162 | After which you can compute `result` by setting the `delayed_optimize` method to `graphchain.optimize`: 163 | 164 | ```python 165 | import graphchain 166 | from functools import partial 167 | 168 | optimize_s3 = partial(graphchain.optimize, location="s3://mybucket/__graphchain_cache__/") 169 | 170 | with dask.config.set(scheduler="sync", delayed_optimize=optimize_s3): 171 | print(result.compute()) 172 | ``` 173 | 174 | ### Using a custom a serializer/deserializer 175 | 176 | By default graphchain will cache dask computations with [joblib.dump](https://joblib.readthedocs.io/en/latest/generated/joblib.dump.html) and LZ4 compression. 
However, you may also supply a custom `serialize` and `deserialize` function that writes and reads computations to and from a [PyFilesystem filesystem](https://docs.pyfilesystem.org/en/latest/introduction.html), respectively. For example, the following snippet shows how to serialize dask DataFrames with [dask.dataframe.to_parquet](https://docs.dask.org/en/stable/generated/dask.dataframe.to_parquet.html), while other objects are serialized with joblib: 177 | 178 | ```python 179 | import dask.dataframe 180 | import graphchain 181 | import fs.osfs 182 | import joblib 183 | import os 184 | from functools import partial 185 | from typing import Any 186 | 187 | def custom_serialize(obj: Any, fs: fs.osfs.OSFS, key: str) -> None: 188 | """Serialize dask DataFrames with to_parquet, and other objects with joblib.dump.""" 189 | if isinstance(obj, dask.dataframe.DataFrame): 190 | obj.to_parquet(os.path.join(fs.root_path, "parquet", key)) 191 | else: 192 | with fs.open(f"{key}.joblib", "wb") as fid: 193 | joblib.dump(obj, fid) 194 | 195 | def custom_deserialize(fs: fs.osfs.OSFS, key: str) -> Any: 196 | """Deserialize dask DataFrames with read_parquet, and other objects with joblib.load.""" 197 | if fs.exists(f"{key}.joblib"): 198 | with fs.open(f"{key}.joblib", "rb") as fid: 199 | return joblib.load(fid) 200 | else: 201 | return dask.dataframe.read_parquet(os.path.join(fs.root_path, "parquet", key)) 202 | 203 | optimize_parquet = partial( 204 | graphchain.optimize, 205 | location="./__graphchain_cache__/custom/", 206 | serialize=custom_serialize, 207 | deserialize=custom_deserialize 208 | ) 209 | 210 | with dask.config.set(scheduler="sync", delayed_optimize=optimize_parquet): 211 | print(result.compute()) 212 | ``` 213 | 214 | ## Contributing 215 | 216 |
217 | Setup: once per device 218 | 219 | 1. [Generate an SSH key](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/generating-a-new-ssh-key-and-adding-it-to-the-ssh-agent#generating-a-new-ssh-key) and [add the SSH key to your GitHub account](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account). 220 | 1. Configure SSH to automatically load your SSH keys: 221 | ```sh 222 | cat << EOF >> ~/.ssh/config 223 | Host * 224 | AddKeysToAgent yes 225 | IgnoreUnknown UseKeychain 226 | UseKeychain yes 227 | EOF 228 | ``` 229 | 1. [Install Docker Desktop](https://www.docker.com/get-started). 230 | - Enable _Use Docker Compose V2_ in Docker Desktop's preferences window. 231 | - _Linux only_: 232 | - [Configure Docker and Docker Compose to use the BuildKit build system](https://docs.docker.com/develop/develop-images/build_enhancements/#to-enable-buildkit-builds). On macOS and Windows, BuildKit is enabled by default in Docker Desktop. 233 | - Export your user's user id and group id so that [files created in the Dev Container are owned by your user](https://github.com/moby/moby/issues/3206): 234 | ```sh 235 | cat << EOF >> ~/.bashrc 236 | export UID=$(id --user) 237 | export GID=$(id --group) 238 | EOF 239 | ``` 240 | 1. [Install VS Code](https://code.visualstudio.com/) and [VS Code's Remote-Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). Alternatively, install [PyCharm](https://www.jetbrains.com/pycharm/download/). 241 | - _Optional:_ Install a [Nerd Font](https://www.nerdfonts.com/font-downloads) such as [FiraCode Nerd Font](https://github.com/ryanoasis/nerd-fonts/tree/master/patched-fonts/FiraCode) with `brew tap homebrew/cask-fonts && brew install --cask font-fira-code-nerd-font` and [configure VS Code](https://github.com/tonsky/FiraCode/wiki/VS-Code-Instructions) or [configure PyCharm](https://github.com/tonsky/FiraCode/wiki/Intellij-products-instructions) to use `'FiraCode Nerd Font'`. 242 | 243 |
244 | 245 |
246 | Setup: once per project 247 | 248 | 1. Clone this repository. 249 | 2. Start a [Dev Container](https://code.visualstudio.com/docs/remote/containers) in your preferred development environment: 250 | - _VS Code_: open the cloned repository and run Ctrl/⌘ + + P → _Remote-Containers: Reopen in Container_. 251 | - _PyCharm_: open the cloned repository and [configure Docker Compose as a remote interpreter](https://www.jetbrains.com/help/pycharm/using-docker-compose-as-a-remote-interpreter.html#docker-compose-remote). 252 | - _Terminal_: open the cloned repository and run `docker compose run --rm dev` to start an interactive Dev Container. 253 | 254 |
255 | 256 |
257 | Developing 258 | 259 | - This project follows the [Conventional Commits](https://www.conventionalcommits.org/) standard to automate [Semantic Versioning](https://semver.org/) and [Keep A Changelog](https://keepachangelog.com/) with [Commitizen](https://github.com/commitizen-tools/commitizen). 260 | - Run `poe` from within the development environment to print a list of [Poe the Poet](https://github.com/nat-n/poethepoet) tasks available to run on this project. 261 | - Run `poetry add {package}` from within the development environment to install a run time dependency and add it to `pyproject.toml` and `poetry.lock`. 262 | - Run `poetry remove {package}` from within the development environment to uninstall a run time dependency and remove it from `pyproject.toml` and `poetry.lock`. 263 | - Run `poetry update` from within the development environment to upgrade all dependencies to the latest versions allowed by `pyproject.toml`. 264 | - Run `cz bump` to bump the package's version, update the `CHANGELOG.md`, and create a git tag. 265 | 266 |
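For example, a typical inner loop inside the Dev Container combines the Poe tasks defined in `pyproject.toml` with Commitizen:

```sh
poe lint   # run pre-commit on all files, followed by a safety check of the dependencies
poe test   # run the test suite under coverage and write reports/coverage.xml
poe docs   # build the API documentation with pdoc into ./docs
cz bump    # bump the version, update CHANGELOG.md, and create the release tag
```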
267 | 268 | ## Developed by Radix 269 | 270 | [Radix](https://radix.ai) is a Belgium-based Machine Learning company. 271 | 272 | Our vision is to make technology work for and with us. We believe that if technology is used in a creative way, jobs become more fulfilling, people become the best version of themselves, and companies grow. 273 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | dev: 5 | build: 6 | context: . 7 | target: dev 8 | args: 9 | UID: ${UID:-1000} 10 | GID: ${GID:-1000} 11 | stdin_open: true 12 | tty: true 13 | entrypoint: [] 14 | command: 15 | [ 16 | "sh", 17 | "-c", 18 | "cp --update /opt/build/poetry/poetry.lock /app/ && mkdir -p /app/.git/hooks/ && cp --update /opt/build/git/* /app/.git/hooks/ && zsh" 19 | ] 20 | environment: 21 | - POETRY_PYPI_TOKEN_PYPI 22 | - SSH_AUTH_SOCK=/run/host-services/ssh-auth.sock 23 | volumes: 24 | - .:/app/ 25 | - ~/.gitconfig:/etc/gitconfig 26 | - ~/.ssh/known_hosts:/home/app/.ssh/known_hosts 27 | - ${SSH_AGENT_AUTH_SOCK:-/run/host-services/ssh-auth.sock}:/run/host-services/ssh-auth.sock 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # https://python-poetry.org/docs/pyproject/#poetry-and-pep-517 2 | [build-system] 3 | requires = ["poetry-core>=1.0.0"] 4 | build-backend = "poetry.core.masonry.api" 5 | 6 | # https://python-poetry.org/docs/pyproject/ 7 | [tool.poetry] 8 | name = "graphchain" 9 | version = "1.4.0" 10 | description = "An efficient cache for the execution of dask graphs." 11 | authors = ["Laurent Sorber "] 12 | readme = "README.md" 13 | repository = "https://github.com/radix-ai/graphchain" 14 | 15 | # https://commitizen-tools.github.io/commitizen/config/ 16 | [tool.commitizen] 17 | bump_message = "bump(release): v$current_version → v$new_version" 18 | tag_format = "v$version" 19 | update_changelog_on_bump = true 20 | version = "1.4.0" 21 | version_files = ["pyproject.toml:version"] 22 | 23 | # https://python-poetry.org/docs/dependency-specification/ 24 | [tool.poetry.dependencies] 25 | cloudpickle = ">=1.0.0,<3.0.0" 26 | dask = ">=2020.12.0" 27 | fs-s3fs = "^1" 28 | joblib = "^1" 29 | lz4 = ">=3,<5" 30 | python = "^3.8,<3.11" 31 | 32 | # https://python-poetry.org/docs/master/managing-dependencies/ 33 | # TODO: Split in `tool.poetry.group.dev` and `tool.poetry.group.test` when Poetry 1.2.0 is released. 
34 | [tool.poetry.dev-dependencies] 35 | absolufy-imports = "^0.3.1" 36 | bandit = { extras = ["toml"], version = "^1.7.4" } 37 | black = "^22.6.0" 38 | commitizen = "^2.27.1" 39 | coverage = { extras = ["toml"], version = "^6.4.1" } 40 | cruft = "^2.11.0" 41 | darglint = "^1.8.1" 42 | fastparquet = "^0.8.1" 43 | flake8 = "^5.0.4" 44 | flake8-bugbear = "^22.6.22" 45 | flake8-comprehensions = "^3.10.0" 46 | flake8-mutable = "^1.2.0" 47 | flake8-print = "^5.0.0" 48 | flake8-pytest-style = "^1.6.0" 49 | flake8-rst-docstrings = "^0.2.6" 50 | flake8-tidy-imports = "^4.8.0" 51 | isort = "^5.10.1" 52 | mypy = "^0.971" 53 | pandas = "^1.3.5" 54 | pandas-stubs = "^1.4.3" 55 | pdoc = "^12.0.2" 56 | pep8-naming = "^0.13.0" 57 | poethepoet = "^0.16.0" 58 | pre-commit = "^2.19.0" 59 | pydocstyle = { extras = ["toml"], version = "^6.1.1" } 60 | pytest = "^7.1.2" 61 | pytest-clarity = "^1.0.1" 62 | pytest-mock = "^3.8.1" 63 | pytest-xdist = "^2.5.0" 64 | pyupgrade = "^2.34.0" 65 | safety = "^2.1.1" 66 | shellcheck-py = "^0.8.0" 67 | yesqa = "^1.4.0" 68 | 69 | # https://bandit.readthedocs.io/en/latest/config.html 70 | [tool.bandit] 71 | skips = ["B101", "B110", "B403"] 72 | 73 | # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file 74 | [tool.black] 75 | line-length = 100 76 | 77 | # https://coverage.readthedocs.io/en/latest/config.html#report 78 | [tool.coverage.report] 79 | fail_under = 50 80 | precision = 1 81 | show_missing = true 82 | skip_covered = true 83 | 84 | # https://coverage.readthedocs.io/en/latest/config.html#run 85 | [tool.coverage.run] 86 | branch = true 87 | command_line = "--module pytest" 88 | data_file = "reports/.coverage" 89 | source = ["src"] 90 | 91 | # https://coverage.readthedocs.io/en/latest/config.html#xml 92 | [tool.coverage.xml] 93 | output = "reports/coverage.xml" 94 | 95 | # https://pycqa.github.io/isort/docs/configuration/options.html 96 | [tool.isort] 97 | color_output = true 98 | line_length = 100 99 | profile = "black" 100 | src_paths = ["src", "tests"] 101 | 102 | # https://mypy.readthedocs.io/en/latest/config_file.html 103 | [tool.mypy] 104 | junit_xml = "reports/mypy.xml" 105 | strict = true 106 | disallow_subclassing_any = false 107 | disallow_untyped_calls = false 108 | disallow_untyped_decorators = false 109 | ignore_missing_imports = true 110 | no_implicit_reexport = false 111 | pretty = true 112 | show_column_numbers = true 113 | show_error_codes = true 114 | show_error_context = true 115 | warn_unreachable = true 116 | 117 | # http://www.pydocstyle.org/en/latest/usage.html#configuration-files 118 | [tool.pydocstyle] 119 | convention = "numpy" 120 | 121 | # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref 122 | [tool.pytest.ini_options] 123 | addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml" 124 | filterwarnings = ["error", "ignore::DeprecationWarning"] 125 | testpaths = ["src", "tests"] 126 | xfail_strict = true 127 | 128 | # https://github.com/nat-n/poethepoet 129 | [tool.poe.tasks] 130 | 131 | [tool.poe.tasks.docs] 132 | help = "Generate this package's docs" 133 | cmd = """ 134 | pdoc 135 | --docformat $docformat 136 | --output-directory $outputdirectory 137 | graphchain 138 | """ 139 | 140 | [[tool.poe.tasks.docs.args]] 141 | help = "The docstring style (default: numpy)" 142 | name = "docformat" 143 | options = ["--docformat"] 144 | default = "numpy" 145 | 146 | 
[[tool.poe.tasks.docs.args]] 147 | help = "The output directory (default: docs)" 148 | name = "outputdirectory" 149 | options = ["--output-directory"] 150 | default = "docs" 151 | 152 | [tool.poe.tasks.lint] 153 | help = "Lint this package" 154 | 155 | [[tool.poe.tasks.lint.sequence]] 156 | cmd = """ 157 | pre-commit run 158 | --all-files 159 | --color always 160 | """ 161 | 162 | [[tool.poe.tasks.lint.sequence]] 163 | shell = "safety check --continue-on-error --full-report" 164 | 165 | [tool.poe.tasks.test] 166 | help = "Test this package" 167 | 168 | [[tool.poe.tasks.test.sequence]] 169 | cmd = "coverage run" 170 | 171 | [[tool.poe.tasks.test.sequence]] 172 | cmd = "coverage report" 173 | 174 | [[tool.poe.tasks.test.sequence]] 175 | cmd = "coverage xml" 176 | -------------------------------------------------------------------------------- /src/graphchain/__init__.py: -------------------------------------------------------------------------------- 1 | """Graphchain is a cache for dask graphs.""" 2 | 3 | from graphchain.core import get, optimize 4 | 5 | __all__ = ["get", "optimize"] 6 | -------------------------------------------------------------------------------- /src/graphchain/core.py: -------------------------------------------------------------------------------- 1 | """Graphchain core.""" 2 | 3 | import datetime as dt 4 | import logging 5 | import time 6 | from copy import deepcopy 7 | from functools import lru_cache, partial 8 | from pickle import HIGHEST_PROTOCOL 9 | from typing import ( 10 | Any, 11 | Callable, 12 | Container, 13 | Dict, 14 | Hashable, 15 | Iterable, 16 | Optional, 17 | Sequence, 18 | TypeVar, 19 | Union, 20 | cast, 21 | ) 22 | 23 | import cloudpickle 24 | import dask 25 | import fs 26 | import fs.base 27 | import fs.memoryfs 28 | import fs.osfs 29 | import joblib 30 | from dask.highlevelgraph import HighLevelGraph, Layer 31 | 32 | from graphchain.utils import get_size, str_to_posix_fully_portable_filename 33 | 34 | T = TypeVar("T") 35 | 36 | 37 | # We have to define an lru_cache wrapper here because mypy doesn't support decorated properties [1]. 38 | # [1] https://github.com/python/mypy/issues/5858 39 | def _cache(__user_function: Callable[..., T]) -> Callable[..., T]: 40 | return lru_cache(maxsize=None)(__user_function) 41 | 42 | 43 | def hlg_setitem(self: HighLevelGraph, key: Hashable, value: Any) -> None: 44 | """Set a HighLevelGraph computation.""" 45 | for d in self.layers.values(): 46 | if key in d: 47 | d[key] = value # type: ignore[index] 48 | 49 | 50 | # Monkey patch HighLevelGraph to add a missing `__setitem__` method. 51 | if not hasattr(HighLevelGraph, "__setitem__"): 52 | HighLevelGraph.__setitem__ = hlg_setitem # type: ignore[index] 53 | 54 | 55 | def layer_setitem(self: Layer, key: Hashable, value: Any) -> None: 56 | """Set a Layer computation.""" 57 | self.mapping[key] = value # type: ignore[attr-defined] 58 | 59 | 60 | # Monkey patch Layer to add a missing `__setitem__` method. 
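# These two patches let `optimize` assign CachedComputation wrappers into a HighLevelGraph in
# place. The Layer patch assumes the layer exposes a `mapping` dict (as dask's materialized
# layers do), hence the `type: ignore` on the assignment in `layer_setitem` above.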
61 | if not hasattr(Layer, "__setitem__"): 62 | Layer.__setitem__ = layer_setitem # type: ignore[index] 63 | 64 | 65 | logger = logging.getLogger(__name__) 66 | 67 | 68 | def joblib_dump( 69 | obj: Any, fs: fs.base.FS, key: str, ext: str = "joblib", **kwargs: Dict[str, Any] 70 | ) -> None: 71 | """Store an object on a filesystem.""" 72 | filename = f"{key}.{ext}" 73 | with fs.open(filename, "wb") as fid: 74 | joblib.dump(obj, fid, **kwargs) 75 | 76 | 77 | def joblib_load(fs: fs.base.FS, key: str, ext: str = "joblib", **kwargs: Dict[str, Any]) -> Any: 78 | """Load an object from a filesystem.""" 79 | filename = f"{key}.{ext}" 80 | with fs.open(filename, "rb") as fid: 81 | return joblib.load(fid, **kwargs) 82 | 83 | 84 | joblib_dump_lz4 = partial(joblib_dump, compress="lz4", ext="joblib.lz4", protocol=HIGHEST_PROTOCOL) 85 | joblib_load_lz4 = partial(joblib_load, ext="joblib.lz4") 86 | 87 | 88 | class CacheFS: 89 | """Lazily opened PyFileSystem.""" 90 | 91 | def __init__(self, location: Union[str, fs.base.FS]): 92 | """Wrap a PyFilesystem FS URL to provide a single FS instance. 93 | 94 | Parameters 95 | ---------- 96 | location 97 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such 98 | as ``"./__graphchain_cache__/"`` or a remote directory such as 99 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a PyFilesystem itself 100 | instead. 101 | """ 102 | self.location = location 103 | 104 | @property # type: ignore[misc] 105 | @_cache 106 | def fs(self) -> fs.base.FS: 107 | """Open a PyFilesystem FS to the cache directory.""" 108 | # create=True does not yet work for S3FS [1]. This should probably be left to the user as we 109 | # don't know in which region to create the bucket, among other configuration options. 110 | # [1] https://github.com/PyFilesystem/s3fs/issues/23 111 | if isinstance(self.location, fs.base.FS): 112 | return self.location 113 | return fs.open_fs(self.location, create=True) 114 | 115 | 116 | class CachedComputation: 117 | """A replacement for computations in dask graphs.""" 118 | 119 | def __init__( 120 | self, 121 | dsk: Dict[Hashable, Any], 122 | key: Hashable, 123 | computation: Any, 124 | location: Union[str, fs.base.FS, CacheFS] = "./__graphchain_cache__/", 125 | serialize: Callable[[Any, fs.base.FS, str], None] = joblib_dump_lz4, 126 | deserialize: Callable[[fs.base.FS, str], Any] = joblib_load_lz4, 127 | write_to_cache: Union[bool, str] = "auto", 128 | ) -> None: 129 | """Cache a dask graph computation. 130 | 131 | A wrapper for the computation object to replace the original computation with in the dask 132 | graph. 133 | 134 | Parameters 135 | ---------- 136 | dsk 137 | The dask graph this computation is a part of. 138 | key 139 | The key corresponding to this computation in the dask graph. 140 | computation 141 | The computation to cache. 142 | location 143 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such 144 | as ``"./__graphchain_cache__/"`` or a remote directory such as 145 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a CacheFS instance or a 146 | PyFilesystem itself instead. 147 | serialize 148 | A function of the form ``serialize(result: Any, fs: fs.base.FS, key: str)`` that caches 149 | a computation `result` to a filesystem `fs` under a given `key`. 150 | deserialize 151 | A function of the form ``deserialize(fs: fs.base.FS, key: str)`` that reads a cached 152 | computation `result` from a `key` on a given filesystem `fs`. 
153 | write_to_cache 154 | Whether or not to cache this computation. If set to ``"auto"``, will only write to cache 155 | if it is expected this will speed up future gets of this computation, taking into 156 | account the characteristics of the `location` filesystem. 157 | """ 158 | self.dsk = dsk 159 | self.key = key 160 | self.computation = computation 161 | self.location = location 162 | self.serialize = serialize 163 | self.deserialize = deserialize 164 | self.write_to_cache = write_to_cache 165 | 166 | @property # type: ignore[misc] 167 | @_cache 168 | def cache_fs(self) -> fs.base.FS: 169 | """Open a PyFilesystem FS to the cache directory.""" 170 | # create=True does not yet work for S3FS [1]. This should probably be left to the user as we 171 | # don't know in which region to create the bucket, among other configuration options. 172 | # [1] https://github.com/PyFilesystem/s3fs/issues/23 173 | if isinstance(self.location, fs.base.FS): 174 | return self.location 175 | if isinstance(self.location, CacheFS): 176 | return self.location.fs 177 | return fs.open_fs(self.location, create=True) 178 | 179 | def __repr__(self) -> str: 180 | """Represent this CachedComputation object as a string.""" 181 | return f"" 182 | 183 | def _subs_dependencies_with_hash(self, computation: Any) -> Any: 184 | """Replace key references in a computation by their hashes.""" 185 | dependencies = dask.core.get_dependencies( 186 | self.dsk, task=0 if computation is None else computation 187 | ) 188 | for dep in dependencies: 189 | computation = dask.core.subs( 190 | computation, 191 | dep, 192 | self.dsk[dep].hash 193 | if isinstance(self.dsk[dep], CachedComputation) 194 | else self.dsk[dep][0].hash, 195 | ) 196 | return computation 197 | 198 | def _subs_tasks_with_src(self, computation: Any) -> Any: 199 | """Replace task functions by their source code.""" 200 | if type(computation) is list: 201 | # This computation is a list of computations. 202 | computation = [self._subs_tasks_with_src(x) for x in computation] 203 | elif dask.core.istask(computation): 204 | # This computation is a task. 205 | src = joblib.func_inspect.get_func_code(computation[0])[0] 206 | computation = (src,) + computation[1:] 207 | return computation 208 | 209 | def compute_hash(self) -> str: 210 | """Compute a hash of this computation object and its dependencies.""" 211 | # Replace dependencies with their hashes and functions with source. 212 | computation = self._subs_dependencies_with_hash(self.computation) 213 | computation = self._subs_tasks_with_src(computation) 214 | # Return the hash of the resulting computation. 215 | comp_hash: str = joblib.hash(cloudpickle.dumps(computation)) 216 | return comp_hash 217 | 218 | @property 219 | def hash(self) -> str: 220 | """Return the hash of this CachedComputation.""" 221 | if not hasattr(self, "_hash"): 222 | self._hash = self.compute_hash() 223 | return self._hash 224 | 225 | def estimate_load_time(self, result: Any) -> float: 226 | """Estimate the time to load the given result from cache.""" 227 | size: float = get_size(result) / dask.config.get("cache_estimated_compression_ratio", 2.0) 228 | # Use typical SSD latency and bandwith if cache_fs is a local filesystem, else use a typical 229 | # latency and bandwidth for network-based filesystems. 
230 | read_latency = float( 231 | dask.config.get( 232 | "cache_latency", 233 | 1e-4 if isinstance(self.cache_fs, (fs.osfs.OSFS, fs.memoryfs.MemoryFS)) else 50e-3, 234 | ) 235 | ) 236 | read_throughput = float( 237 | dask.config.get( 238 | "cache_throughput", 239 | 500e6 if isinstance(self.cache_fs, (fs.osfs.OSFS, fs.memoryfs.MemoryFS)) else 50e6, 240 | ) 241 | ) 242 | return read_latency + size / read_throughput 243 | 244 | @_cache 245 | def read_time(self, timing_type: str) -> float: 246 | """Read the time to load, compute, or store from file.""" 247 | time_filename = f"{self.hash}.time.{timing_type}" 248 | with self.cache_fs.open(time_filename, "r") as fid: 249 | return float(fid.read()) 250 | 251 | def write_time(self, timing_type: str, seconds: float) -> None: 252 | """Write the time to load, compute, or store from file.""" 253 | time_filename = f"{self.hash}.time.{timing_type}" 254 | with self.cache_fs.open(time_filename, "w") as fid: 255 | fid.write(str(seconds)) 256 | 257 | def write_log(self, log_type: str) -> None: 258 | """Write the timestamp of a load, compute, or store operation.""" 259 | key = str_to_posix_fully_portable_filename(str(self.key)) 260 | now = str_to_posix_fully_portable_filename(str(dt.datetime.now())) 261 | log_filename = f".{now}.{log_type}.{key}.log" 262 | with self.cache_fs.open(log_filename, "w") as fid: 263 | fid.write(self.hash) 264 | 265 | def time_to_result(self, memoize: bool = True) -> float: 266 | """Estimate the time to load or compute this computation.""" 267 | if hasattr(self, "_time_to_result"): 268 | return self._time_to_result # type: ignore[has-type,no-any-return] 269 | if memoize: 270 | try: 271 | try: 272 | load_time = self.read_time("load") 273 | except Exception: 274 | load_time = self.read_time("store") / 2 275 | self._time_to_result = load_time 276 | return load_time 277 | except Exception: 278 | pass 279 | compute_time = self.read_time("compute") 280 | dependency_time = 0 281 | dependencies = dask.core.get_dependencies( 282 | self.dsk, task=0 if self.computation is None else self.computation 283 | ) 284 | for dep in dependencies: 285 | dependency_time += self.dsk[dep][0].time_to_result() 286 | total_time = compute_time + dependency_time 287 | if memoize: 288 | self._time_to_result = total_time 289 | return total_time 290 | 291 | def cache_file_exists(self) -> bool: 292 | """Check if this CachedComputation's cache file exists.""" 293 | return self.cache_fs.exists(f"{self.hash}.time.store") 294 | 295 | def load(self) -> Any: 296 | """Load this result of this computation from cache.""" 297 | try: 298 | # Load from cache. 299 | start_time = time.perf_counter() 300 | logger.info(f"LOAD {self} from {self.cache_fs}/{self.hash}") 301 | result = self.deserialize(self.cache_fs, self.hash) 302 | load_time = time.perf_counter() - start_time 303 | # Write load time and log operation. 304 | self.write_time("load", load_time) 305 | self.write_log("load") 306 | return result 307 | except Exception: 308 | logger.exception( 309 | f"Could not read {self.cache_fs}/{self.hash}. Marking cache as invalid, please try again!" 310 | ) 311 | self.cache_fs.remove(f"{self.hash}.time.store") 312 | raise 313 | 314 | def compute(self, *args: Any, **kwargs: Any) -> Any: 315 | """Compute this computation.""" 316 | # Compute the computation. 
317 | logger.info(f"COMPUTE {self}") 318 | start_time = time.perf_counter() 319 | if dask.core.istask(self.computation): 320 | result = self.computation[0](*args, **kwargs) 321 | else: 322 | result = args[0] 323 | compute_time = time.perf_counter() - start_time 324 | # Write compute time and log operation 325 | self.write_time("compute", compute_time) 326 | self.write_log("compute") 327 | return result 328 | 329 | def store(self, result: Any) -> None: 330 | """Store the result of this computation in the cache.""" 331 | if not self.cache_file_exists(): 332 | logger.info(f"STORE {self} to {self.cache_fs}/{self.hash}") 333 | try: 334 | # Store to cache. 335 | start_time = time.perf_counter() 336 | self.serialize(result, self.cache_fs, self.hash) 337 | store_time = time.perf_counter() - start_time 338 | # Write store time and log operation 339 | self.write_time("store", store_time) 340 | self.write_log("store") 341 | except Exception: 342 | # Not crucial to stop if caching fails. 343 | logger.exception(f"Could not write {self.hash}.") 344 | 345 | def patch_computation_in_graph(self) -> None: 346 | """Patch the graph to use this CachedComputation.""" 347 | if self.cache_file_exists(): 348 | # If there are cache candidates to load this computation from, remove all dependencies 349 | # for this task from the graph as far as dask is concerned. 350 | self.dsk[self.key] = (self,) 351 | else: 352 | # If there are no cache candidates, wrap the execution of the computation with this 353 | # CachedComputation's __call__ method and keep references to its dependencies. 354 | self.dsk[self.key] = ( 355 | (self,) + self.computation[1:] 356 | if dask.core.istask(self.computation) 357 | else (self, self.computation) 358 | ) 359 | 360 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 361 | """Load this computation from cache, or compute and then store it.""" 362 | # Load. 363 | if self.cache_file_exists(): 364 | return self.load() 365 | # Compute. 366 | result = self.compute(*args, **kwargs) 367 | # Store. 368 | write_to_cache = self.write_to_cache 369 | if write_to_cache == "auto": 370 | compute_time = self.time_to_result(memoize=False) 371 | estimated_load_time = self.estimate_load_time(result) 372 | write_to_cache = estimated_load_time < compute_time 373 | logger.debug( 374 | f'{"Going" if write_to_cache else "Not going"} to cache {self}' 375 | f" because estimated_load_time={estimated_load_time} " 376 | f'{"<" if write_to_cache else ">="} ' 377 | f"compute_time={compute_time}" 378 | ) 379 | if write_to_cache: 380 | self.store(result) 381 | return result 382 | 383 | 384 | def optimize( 385 | dsk: Union[Dict[Hashable, Any], HighLevelGraph], 386 | keys: Optional[Union[Hashable, Iterable[Hashable]]] = None, 387 | skip_keys: Optional[Container[Hashable]] = None, 388 | location: Union[str, fs.base.FS, CacheFS] = "./__graphchain_cache__/", 389 | serialize: Callable[[Any, fs.base.FS, str], None] = joblib_dump_lz4, 390 | deserialize: Callable[[fs.base.FS, str], Any] = joblib_load_lz4, 391 | ) -> Dict[Hashable, Any]: 392 | """Optimize a dask graph with cached computations. 393 | 394 | According to the dask graph specification [1]_, a dask graph is a dictionary that maps `keys` to 395 | `computations`. A computation can be: 396 | 397 | 1. Another key in the graph. 398 | 2. A literal. 399 | 3. A task, which is of the form `(Callable, *args)`. 400 | 4. A list of other computations. 
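    For example, in the graph ``{"a": 1, "b": "a", "c": (sum, [1, 2]), "d": ["b", "c"]}``, ``1`` is a
    literal, ``"a"`` (the value of ``"b"``) is another key, the tuple is a task, and ``["b", "c"]``
    is a list of other computations.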
401 | 402 | This optimizer replaces all computations in a graph with ``CachedComputation``'s, so that 403 | getting items from the graph will be backed by a cache of your choosing. With this cache, only 404 | the very minimum number of computations will actually be computed to return the values 405 | corresponding to the given keys. 406 | 407 | `CachedComputation` objects *do not* hash task inputs (which is the approach that 408 | `functools.lru_cache` and `joblib.Memory` take) to identify which cache file to load. Instead, a 409 | chain of hashes (hence the name `graphchain`) of the computation object and its dependencies 410 | (which are also computation objects) is used to identify the cache file. 411 | 412 | Since it is generally cheap to hash the graph's computation objects, `graphchain`'s cache is 413 | likely to be much faster than hashing task inputs, which can be slow for large objects such as 414 | `pandas.DataFrame`'s. 415 | 416 | Parameters 417 | ---------- 418 | dsk 419 | The dask graph to optimize with caching computations. 420 | keys 421 | Not used. Is present for compatibility with dask optimizers [2]_. 422 | skip_keys 423 | A container of keys not to cache. 424 | location 425 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such as 426 | ``"./__graphchain_cache__/"`` or a remote directory such as 427 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a PyFilesystem itself instead. 428 | serialize 429 | A function of the form ``serialize(result: Any, fs: fs.base.FS, key: str)`` that caches a 430 | computation `result` to a filesystem `fs` under a given `key`. 431 | deserialize 432 | A function of the form ``deserialize(fs: fs.base.FS, key: str)`` that reads a cached 433 | computation `result` from a `key` on a given filesystem `fs`. 434 | 435 | Returns 436 | ------- 437 | dict 438 | A copy of the dask graph where the computations have been replaced by `CachedComputation`'s. 439 | 440 | References 441 | ---------- 442 | .. [1] https://docs.dask.org/en/latest/spec.html 443 | .. [2] https://docs.dask.org/en/latest/optimize.html 444 | """ 445 | # Technically a HighLevelGraph isn't actually a dict, but it has largely the same API so we can 446 | # treat it as one. We can't use a type union or protocol either, because HighLevelGraph doesn't 447 | # actually have a __setitem__ implementation, we just monkey-patched that in. 448 | dsk = cast(Dict[Hashable, Any], deepcopy(dsk)) 449 | # Verify that the graph is a DAG. 450 | assert dask.core.isdag(dsk, list(dsk.keys())) 451 | if isinstance(location, str): 452 | location = CacheFS(location) 453 | # Replace graph computations by CachedComputations. 454 | skip_keys = skip_keys or set() 455 | for key, computation in dsk.items(): 456 | dsk[key] = CachedComputation( 457 | dsk, 458 | key, 459 | computation, 460 | location=location, 461 | serialize=serialize, 462 | deserialize=deserialize, 463 | write_to_cache=False if key in skip_keys else "auto", 464 | ) 465 | # Remove task arguments if we can load from cache. 
466 | for key in dsk: 467 | dsk[key].patch_computation_in_graph() 468 | 469 | return dsk 470 | 471 | 472 | def get( 473 | dsk: Dict[Hashable, Any], 474 | keys: Union[Hashable, Sequence[Hashable]], 475 | skip_keys: Optional[Container[Hashable]] = None, 476 | location: Union[str, fs.base.FS, CacheFS] = "./__graphchain_cache__/", 477 | serialize: Callable[[Any, fs.base.FS, str], None] = joblib_dump_lz4, 478 | deserialize: Callable[[fs.base.FS, str], Any] = joblib_load_lz4, 479 | scheduler: Optional[ 480 | Callable[[Dict[Hashable, Any], Union[Hashable, Sequence[Hashable]]], Any] 481 | ] = None, 482 | ) -> Any: 483 | """Get one or more keys from a dask graph with caching. 484 | 485 | Optimizes a dask graph with `graphchain.optimize` and then computes the requested keys with the 486 | desired scheduler, which is by default `dask.get`. 487 | 488 | See `graphchain.optimize` for more information on how `graphchain`'s cache mechanism works. 489 | 490 | Parameters 491 | ---------- 492 | dsk 493 | The dask graph to query. 494 | keys 495 | The keys to compute. 496 | skip_keys 497 | A container of keys not to cache. 498 | location 499 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such as 500 | ``"./__graphchain_cache__/"`` or a remote directory such as 501 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a PyFilesystem itself instead. 502 | serialize 503 | A function of the form ``serialize(result: Any, fs: fs.base.FS, key: str)`` that caches a 504 | computation `result` to a filesystem `fs` under a given `key`. 505 | deserialize 506 | A function of the form ``deserialize(fs: fs.base.FS, key: str)`` that reads a cached 507 | computation `result` from a `key` on a given filesystem `fs`. 508 | scheduler 509 | The dask scheduler to use to retrieve the keys from the graph. 510 | 511 | Returns 512 | ------- 513 | Any 514 | The computed values corresponding to the given keys. 515 | """ 516 | cached_dsk = optimize( 517 | dsk, 518 | keys, 519 | skip_keys=skip_keys, 520 | location=location, 521 | serialize=serialize, 522 | deserialize=deserialize, 523 | ) 524 | schedule = dask.base.get_scheduler(scheduler=scheduler) or dask.get 525 | return schedule(cached_dsk, keys) 526 | -------------------------------------------------------------------------------- /src/graphchain/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlinear-ai/graphchain/1146693720eb9a3077a342d6733106d3ec7de1ad/src/graphchain/py.typed -------------------------------------------------------------------------------- /src/graphchain/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used by graphchain.""" 2 | 3 | import string 4 | import sys 5 | from typing import Any, Optional, Set 6 | 7 | 8 | def _fast_get_size(obj: Any) -> int: 9 | if hasattr(obj, "__len__") and len(obj) <= 0: 10 | return 0 11 | if hasattr(obj, "sample") and hasattr(obj, "memory_usage"): # DF, Series. 12 | n = min(len(obj), 1000) 13 | s = obj.sample(frac=n / len(obj)).memory_usage(index=True, deep=True) 14 | if hasattr(s, "sum"): 15 | s = s.sum() 16 | if hasattr(s, "compute"): 17 | s = s.compute() 18 | s = s / n * len(obj) 19 | return int(s) 20 | elif hasattr(obj, "nbytes"): # Numpy. 21 | return int(obj.nbytes) 22 | elif hasattr(obj, "data") and hasattr(obj.data, "nbytes"): # Sparse. 
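        # Heuristic: the index arrays that sparse containers keep alongside their values are not
        # counted by `data.nbytes`, so scale the data buffer as a rough estimate.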
23 | return int(3 * obj.data.nbytes) 24 | raise TypeError("Could not determine size of the given object.") 25 | 26 | 27 | def _slow_get_size(obj: Any, seen: Optional[Set[Any]] = None) -> int: 28 | size = sys.getsizeof(obj) 29 | seen = seen or set() 30 | obj_id = id(obj) 31 | if obj_id in seen: 32 | return 0 33 | seen.add(obj_id) 34 | if isinstance(obj, dict): 35 | size += sum(get_size(v, seen) for v in obj.values()) 36 | size += sum(get_size(k, seen) for k in obj.keys()) 37 | elif hasattr(obj, "__dict__"): 38 | size += get_size(obj.__dict__, seen) 39 | elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)): 40 | size += sum(get_size(i, seen) for i in obj) 41 | return size 42 | 43 | 44 | def get_size(obj: Any, seen: Optional[Set[Any]] = None) -> int: 45 | """Recursively compute the size of an object. 46 | 47 | Parameters 48 | ---------- 49 | obj 50 | The object to get the size of. 51 | seen 52 | A set of seen objects. 53 | 54 | Returns 55 | ------- 56 | int 57 | The (approximate) size in bytes of the given object. 58 | """ 59 | # Short-circuit some types. 60 | try: 61 | return _fast_get_size(obj) 62 | except TypeError: 63 | pass 64 | # General-purpose size computation. 65 | return _slow_get_size(obj, seen) 66 | 67 | 68 | def str_to_posix_fully_portable_filename(s: str) -> str: 69 | """Convert key to POSIX fully portable filename [1]. 70 | 71 | Parameters 72 | ---------- 73 | s 74 | The string to convert to a POSIX fully portable filename. 75 | 76 | Returns 77 | ------- 78 | str 79 | A POSIX fully portable filename. 80 | 81 | References 82 | ---------- 83 | .. [1] https://en.wikipedia.org/wiki/Filename 84 | """ 85 | safechars = string.ascii_letters + string.digits + "._-" 86 | return "".join(c if c in safechars else "-" for c in s) 87 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Graphchain module tests.""" 2 | -------------------------------------------------------------------------------- /tests/test_dask_dataframe.py: -------------------------------------------------------------------------------- 1 | """Test module for dask DataFrames.""" 2 | 3 | import os 4 | from functools import partial 5 | from typing import Any 6 | 7 | import dask 8 | import dask.dataframe 9 | import fs.base 10 | import fs.osfs 11 | import joblib 12 | import pytest 13 | from dask.highlevelgraph import HighLevelGraph 14 | 15 | from graphchain.core import optimize 16 | 17 | 18 | @pytest.fixture() 19 | def dask_dataframe_graph() -> HighLevelGraph: 20 | """Generate an example dask DataFrame graph.""" 21 | 22 | @dask.delayed(pure=True) 23 | def create_dataframe() -> dask.dataframe.DataFrame: 24 | df: dask.dataframe.DataFrame = dask.datasets.timeseries(seed=42) 25 | return df 26 | 27 | @dask.delayed(pure=True) 28 | def summarise_dataframe(df: dask.dataframe.DataFrame) -> float: 29 | value: float = df["x"].sum().compute() + df["y"].sum().compute() 30 | return value 31 | 32 | df = create_dataframe() 33 | result: HighLevelGraph = summarise_dataframe(df) 34 | return result 35 | 36 | 37 | def test_dask_dataframe_graph(dask_dataframe_graph: HighLevelGraph) -> None: 38 | """Test that the graph can be traversed and its result is correct.""" 39 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")}) 40 | with dask.config.set(scheduler="sync", delayed_optimize=optimize): 41 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined] 42 | 
assert result == 856.0466289487188 43 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined] 44 | assert result == 856.0466289487188 45 | 46 | 47 | def test_custom_serde(dask_dataframe_graph: HighLevelGraph) -> None: 48 | """Test that we can use a custom serializer/deserializer.""" 49 | 50 | def custom_serialize(obj: Any, fs: fs.osfs.OSFS, key: str) -> None: 51 | if isinstance(obj, dask.dataframe.DataFrame): 52 | obj.to_parquet(os.path.join(fs.root_path, key)) 53 | else: 54 | with fs.open(f"{key}.joblib", "wb") as fid: 55 | joblib.dump(obj, fid) 56 | 57 | def custom_deserialize(fs: fs.osfs.OSFS, key: str) -> Any: 58 | if fs.exists(f"{key}.joblib"): 59 | with fs.open(f"{key}.joblib", "rb") as fid: 60 | return joblib.load(fid) 61 | else: 62 | return dask.dataframe.read_parquet(os.path.join(fs.root_path, key)) 63 | 64 | custom_optimize = partial( 65 | optimize, 66 | location="__graphchain_cache__/parquet/", 67 | serialize=custom_serialize, 68 | deserialize=custom_deserialize, 69 | ) 70 | 71 | # Ensure everything gets cached. 72 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")}) 73 | 74 | with dask.config.set(scheduler="sync", delayed_optimize=custom_optimize): 75 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined] 76 | assert result == 856.0466289487188 77 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined] 78 | assert result == 856.0466289487188 79 | -------------------------------------------------------------------------------- /tests/test_graphchain.py: -------------------------------------------------------------------------------- 1 | """Test module for the graphchain core.""" 2 | 3 | import functools 4 | import os 5 | import shutil 6 | import tempfile 7 | from typing import Any, Callable, Dict, Hashable, Iterable, Tuple, Union 8 | 9 | import dask 10 | import fs 11 | import pytest 12 | 13 | from graphchain.core import CachedComputation, optimize 14 | 15 | 16 | @pytest.fixture() 17 | def dask_graph() -> Dict[Hashable, Any]: 18 | r"""Generate an example dask graph. 19 | 20 | Will be used as a basis for the functional testing of the graphchain 21 | module:: 22 | 23 | O top(..) 24 | ____|____ 25 | / \ 26 | d1 O baz(..) 27 | _________|________ 28 | / \ 29 | O boo(...) O goo(...) 30 | _______|_______ ____|____ 31 | / | \ / | \ 32 | O O O O | O 33 | foo(.) bar(.) baz(.) foo(.) v6 bar(.) 
 34 |        |        |        |        |             |
 35 |        |        |        |        |             |
 36 |        v1       v2       v3       v4            v5
 37 |     """
 38 |     # Functions
 39 |     def foo(argument: int) -> int:
 40 |         return argument
 41 | 
 42 |     def bar(argument: int) -> int:
 43 |         return argument + 2
 44 | 
 45 |     def baz(*args: int) -> int:
 46 |         return sum(args)
 47 | 
 48 |     def boo(*args: int) -> int:
 49 |         return len(args) + sum(args)
 50 | 
 51 |     def goo(*args: int) -> int:
 52 |         return sum(args) + 1
 53 | 
 54 |     def top(argument: int, argument2: int) -> int:
 55 |         return argument - argument2
 56 | 
 57 |     # Graph (for the function definitions above)
 58 |     dsk = {
 59 |         "v0": None,
 60 |         "v1": 1,
 61 |         "v2": 2,
 62 |         "v3": 3,
 63 |         "v4": 0,
 64 |         "v5": -1,
 65 |         "v6": -2,
 66 |         "d1": -3,
 67 |         "foo1": (foo, "v1"),
 68 |         "foo2": (foo, "v4"),
 69 |         "bar1": (bar, "v2"),
 70 |         "bar2": (bar, "v5"),
 71 |         "baz1": (baz, "v3"),
 72 |         "baz2": (baz, "boo1", "goo1"),
 73 |         "boo1": (boo, "foo1", "bar1", "baz1"),
 74 |         "goo1": (goo, "foo2", "bar2", "v6"),
 75 |         "top1": (top, "d1", "baz2"),
 76 |     }
 77 |     return dsk  # type: ignore[return-value]
 78 | 
 79 | 
 80 | @pytest.fixture(scope="module")
 81 | def temp_dir() -> str:  # type: ignore[misc]
 82 |     """Create a temporary directory to store the cache in."""
 83 |     with tempfile.TemporaryDirectory(prefix="__graphchain_cache__") as tmpdir:
 84 |         yield tmpdir
 85 | 
 86 | 
 87 | @pytest.fixture(scope="module")
 88 | def temp_dir_s3() -> str:
 89 |     """Create the directory used for the graphchain tests on S3."""
 90 |     location = "s3://graphchain-test-bucket/__pytest_graphchain_cache__"
 91 |     return location
 92 | 
 93 | 
 94 | def test_dag(dask_graph: Dict[Hashable, Any]) -> None:
 95 |     """Test that the graph can be traversed and its result is correct."""
 96 |     dsk = dask_graph
 97 |     result = dask.get(dsk, ["top1"])
 98 |     assert result == (-14,)
 99 | 
100 | 
101 | @pytest.fixture()
102 | def optimizer(temp_dir: str) -> Tuple[str, Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]]:
103 |     """Prefill the graphchain optimizer's parameters."""
104 |     return temp_dir, functools.partial(optimize, location=temp_dir)
105 | 
106 | 
107 | @pytest.fixture()
108 | def optimizer_exec_only_nodes(
109 |     temp_dir: str,
110 | ) -> Tuple[str, Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]]:
111 |     """Prefill the graphchain optimizer's parameters."""
112 |     return temp_dir, functools.partial(optimize, location=temp_dir, skip_keys=["boo1"])
113 | 
114 | 
115 | @pytest.fixture()
116 | def optimizer_s3(
117 |     temp_dir_s3: str,
118 | ) -> Tuple[str, Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]]:
119 |     """Prefill the graphchain optimizer's parameters."""
120 |     return temp_dir_s3, functools.partial(optimize, location=temp_dir_s3)
121 | 
122 | 
123 | def test_first_run(
124 |     dask_graph: Dict[Hashable, Any],
125 |     optimizer: Tuple[
126 |         str,
127 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
128 |     ],
129 | ) -> None:
130 |     """First run.
131 | 
132 |     Tests a first run of the graphchain optimization function ``optimize``. It
133 |     checks the final result, that all function calls are wrapped for
134 |     execution and output storing, that the hashchain is created, that hashed
135 |     outputs (the .joblib.lz4 files) are generated, and that the name of
136 |     each file is a key in the hashchain.
137 |     """
138 |     dsk = dask_graph
139 |     cache_dir, graphchain_optimize = optimizer
140 | 
141 |     # Run optimizer
142 |     newdsk = graphchain_optimize(dsk, ["top1"])
143 | 
144 |     # Check the final result
145 |     result = dask.get(newdsk, ["top1"])
146 |     assert result == (-14,)
147 | 
148 |     # Check that all functions have been wrapped
149 |     for key, _task in dsk.items():
150 |         newtask = newdsk[key]
151 |         assert isinstance(newtask[0], CachedComputation)
152 | 
153 |     # Check that the hash files are written and that each
154 |     # filename can be found as a key in the hashchain
155 |     # (the association of hash <-> DAG tasks is not tested)
156 |     storage = fs.osfs.OSFS(cache_dir)
157 |     filelist = storage.listdir("/")
158 |     nfiles = len(filelist)
159 |     assert nfiles >= len(dsk)
160 |     storage.close()
161 | 
162 | 
163 | @pytest.mark.skip(reason="Need AWS credentials to test")
164 | def test_single_run_s3(
165 |     dask_graph: Dict[Hashable, Any],
166 |     optimizer_s3: Tuple[
167 |         str,
168 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
169 |     ],
170 | ) -> None:
171 |     """Run on S3.
172 | 
173 |     Tests a single run of the graphchain optimization function ``optimize``
174 |     using Amazon S3 as a persistency layer. It checks the final result, that
175 |     all function calls are wrapped for execution and output storing, that the
176 |     hashchain is created, that hashed outputs (the .joblib.lz4 files)
177 |     are generated and that the name of each file is a key in the hashchain.
178 |     """
179 |     dsk = dask_graph
180 |     cache_dir, graphchain_optimize = optimizer_s3
181 | 
182 |     # Run optimizer
183 |     newdsk = graphchain_optimize(dsk, ["top1"])
184 | 
185 |     # Check the final result
186 |     result = dask.get(newdsk, ["top1"])
187 |     assert result == (-14,)
188 | 
189 |     data_ext = ".joblib.lz4"
190 | 
191 |     # Check that all functions have been wrapped
192 |     for key, _task in dsk.items():
193 |         newtask = newdsk[key]
194 |         assert isinstance(newtask[0], CachedComputation)
195 | 
196 |     # Check that the hash files are written and that each
197 |     # filename can be found as a key in the hashchain
198 |     # (the association of hash <-> DAG tasks is not tested)
199 |     storage = fs.open_fs(cache_dir)
200 |     filelist = storage.listdir("/")
201 |     nfiles = sum(int(x.endswith(data_ext)) for x in filelist)
202 |     assert nfiles == len(dsk)
203 | 
204 | 
205 | def test_second_run(
206 |     dask_graph: Dict[Hashable, Any],
207 |     optimizer: Tuple[
208 |         str,
209 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
210 |     ],
211 | ) -> None:
212 |     """Second run.
213 | 
214 |     Tests a second run of the graphchain optimization function `optimize`. It
215 |     checks the final result, that all function calls are wrapped for
216 |     loading, and that the result key has no dependencies.
217 |     """
218 |     dsk = dask_graph
219 |     _, graphchain_optimize = optimizer
220 | 
221 |     # Run optimizer
222 |     newdsk = graphchain_optimize(dsk, ["top1"])
223 | 
224 |     # Check the final result
225 |     result = dask.get(newdsk, ["top1"])
226 |     assert result == (-14,)
227 | 
228 |     # Check that the functions are wrapped for loading
229 |     for key in dsk.keys():
230 |         newtask = newdsk[key]
231 |         assert isinstance(newtask, tuple)
232 |         assert isinstance(newtask[0], CachedComputation)
233 | 
234 | 
235 | def test_node_changes(
236 |     dask_graph: Dict[Hashable, Any],
237 |     optimizer: Tuple[
238 |         str,
239 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
240 |     ],
241 | ) -> None:
242 |     """Test node changes.
243 | 
244 |     Tests the functionality of the graphchain in the event of changes in the
245 |     structure of the graph, namely by altering the functions/constants
246 |     associated with the tasks. After optimization, the affected nodes should be
247 |     wrapped in a store-and-execute wrapper and their dependency lists should
248 |     not be empty.
249 |     """
250 |     dsk = dask_graph
251 |     _, graphchain_optimize = optimizer
252 | 
253 |     # Replacement function 'goo'
254 |     def goo(*args: int) -> int:
255 |         # hash miss!
256 |         return sum(args) + 1
257 | 
258 |     # Replacement function 'top'
259 |     def top(argument: int, argument2: int) -> int:
260 |         # hash miss!
261 |         return argument - argument2
262 | 
263 |     moddata = {
264 |         "goo1": (goo, {"goo1", "baz2", "top1"}, (-14,)),
265 |         # "top1": (top, {"top1"}, (-14,)),
266 |         "top1": (lambda *args: -14, {"top1"}, (-14,)),
267 |         "v2": (1000, {"v2", "bar1", "boo1", "baz2", "top1"}, (-1012,)),
268 |     }
269 | 
270 |     for (modkey, (taskobj, _affected_nodes, result)) in moddata.items():
271 |         workdsk = dsk.copy()
272 |         if callable(taskobj):
273 |             workdsk[modkey] = (taskobj, *dsk[modkey][1:])
274 |         else:
275 |             workdsk[modkey] = taskobj
276 | 
277 |         newdsk = graphchain_optimize(workdsk, ["top1"])
278 |         assert result == dask.get(newdsk, ["top1"])
279 | 
280 | 
281 | def test_exec_only_nodes(
282 |     dask_graph: Dict[Hashable, Any],
283 |     optimizer_exec_only_nodes: Tuple[
284 |         str,
285 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
286 |     ],
287 | ) -> None:
288 |     """Test skipping some tasks.
289 | 
290 |     Tests that execution-only nodes execute in the event that dependencies of
291 |     their parent nodes (i.e. in the dask graph) get modified.
292 |     """
293 |     dsk = dask_graph
294 |     cache_dir, graphchain_optimize = optimizer_exec_only_nodes
295 | 
296 |     # Cleanup temporary directory
297 |     filelist = os.listdir(cache_dir)
298 |     for entry in filelist:
299 |         entrypath = os.path.join(cache_dir, entry)
300 |         if os.path.isdir(entrypath):
301 |             shutil.rmtree(entrypath, ignore_errors=True)
302 |         else:
303 |             os.remove(entrypath)
304 |     filelist = os.listdir(cache_dir)
305 |     assert not filelist
306 | 
307 |     # Run optimizer first time
308 |     newdsk = graphchain_optimize(dsk, ["top1"])
309 |     result = dask.get(newdsk, ["top1"])
310 |     assert result == (-14,)
311 | 
312 |     # Modify function
313 |     def goo(*args: int) -> int:
314 |         # hash miss this!
315 |         return sum(args) + 1
316 | 
317 |     dsk["goo1"] = (goo, *dsk["goo1"][1:])
318 | 
319 |     # Run optimizer a second time
320 |     newdsk = graphchain_optimize(dsk, ["top1"])
321 | 
322 |     # Check the final result:
323 |     # The output of node 'boo1' is needed at node 'baz2'
324 |     # because 'goo1' was modified. A matching result indicates
325 |     # that the 'boo1' node was executed and its dependencies loaded,
326 |     # which is the desired behaviour in such cases.
327 |     result = dask.get(newdsk, ["top1"])
328 |     assert result == (-14,)
329 | 
330 | 
331 | def test_cache_deletion(
332 |     dask_graph: Dict[Hashable, Any],
333 |     optimizer: Tuple[
334 |         str,
335 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
336 |     ],
337 | ) -> None:
338 |     """Test cache deletion.
339 | 
340 |     Tests the ability to obtain results in the event that cache files are
341 |     deleted (in the event of a cache miss, the exec-store wrapper should be
342 |     re-run by the load wrapper).
343 |     """
344 |     dsk = dask_graph
345 |     cache_dir, graphchain_optimize = optimizer
346 |     storage = fs.osfs.OSFS(cache_dir)
347 | 
348 |     # Cleanup first
349 |     storage.removetree("/")
350 | 
351 |     # Run optimizer (first time)
352 |     newdsk = graphchain_optimize(dsk, ["top1"])
353 |     result = dask.get(newdsk, ["top1"])
354 | 
355 |     newdsk = graphchain_optimize(dsk, ["top1"])
356 |     result = dask.get(newdsk, ["top1"])
357 | 
358 |     # Check the final result
359 |     assert result == (-14,)
360 | 
361 | 
362 | def test_identical_nodes(
363 |     optimizer: Tuple[
364 |         str,
365 |         Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
366 |     ]
367 | ) -> None:
368 |     """Small test for the presence of identical nodes."""
369 |     cache_dir, graphchain_optimize = optimizer
370 | 
371 |     def foo(x: int) -> int:
372 |         return x + 1
373 | 
374 |     def bar(*args: int) -> int:
375 |         return sum(args)
376 | 
377 |     dsk = {"foo1": (foo, 1), "foo2": (foo, 1), "top1": (bar, "foo1", "foo2")}
378 | 
379 |     # First run
380 |     newdsk = graphchain_optimize(dsk, ["top1"])  # type: ignore[arg-type]
381 |     result = dask.get(newdsk, ["top1"])
382 |     assert result == (4,)
383 | 
384 |     # Second run
385 |     newdsk = graphchain_optimize(dsk, ["top1"])  # type: ignore[arg-type]
386 |     result = dask.get(newdsk, ["top1"])
387 |     assert result == (4,)
388 | 
--------------------------------------------------------------------------------
/tests/test_high_level_graph.py:
--------------------------------------------------------------------------------
  1 | """Test module for the dask HighLevelGraphs."""
  2 | 
  3 | from functools import partial
  4 | from typing import Any, Dict, cast
  5 | 
  6 | import dask
  7 | import dask.dataframe as dd
  8 | import fs.base
  9 | import pandas as pd
 10 | import pytest
 11 | from dask.highlevelgraph import HighLevelGraph
 12 | from fs.memoryfs import MemoryFS
 13 | 
 14 | from graphchain.core import optimize
 15 | 
 16 | 
 17 | @pytest.fixture()
 18 | def dask_high_level_graph() -> HighLevelGraph:
 19 |     """Generate an example dask HighLevelGraph."""
 20 | 
 21 |     @dask.delayed(pure=True)
 22 |     def create_dataframe(num_rows: int, num_cols: int) -> pd.DataFrame:
 23 |         return pd.DataFrame(data=[range(num_cols)] * num_rows)
 24 | 
 25 |     @dask.delayed(pure=True)
 26 |     def create_dataframe2(num_rows: int, num_cols: int) -> pd.DataFrame:
 27 |         return pd.DataFrame(data=[range(num_cols)] * num_rows)
 28 | 
 29 |     @dask.delayed(pure=True)
 30 |     def complicated_computation(df: pd.DataFrame, num_quantiles: int) -> pd.DataFrame:
 31 |         return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])
 32 | 
 33 |     @dask.delayed(pure=True)
 34 |     def summarise_dataframes(*dfs: pd.DataFrame) -> float:
 35 |         return sum(cast("pd.Series[float]", df.sum()).sum() for df in dfs)
 36 | 
 37 |     df_a = create_dataframe(1000, 1000)
 38 |     df_b = create_dataframe2(1000, 1000)
 39 |     df_c = complicated_computation(df_a, 2048)
 40 |     df_d = complicated_computation(df_b, 2048)
 41 |     result: HighLevelGraph = summarise_dataframes(df_c, df_d)
 42 |     return result
 43 | 
 44 | 
 45 | def test_high_level_dag(dask_high_level_graph: HighLevelGraph) -> None:
 46 |     """Test that the graph can be traversed and its result is correct."""
 47 |     with dask.config.set(scheduler="sync"):
 48 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
 49 |         assert result == 2045952000.0
 50 | 
 51 | 
 52 | def test_high_level_graph(dask_high_level_graph: HighLevelGraph) -> None:
 53 |     """Test that the graph can be traversed and its result is correct."""
 54 |     dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
 55 |     with dask.config.set(scheduler="sync", delayed_optimize=optimize):
 56 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
 57 |         assert result == 2045952000.0
 58 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
 59 |         assert result == 2045952000.0
 60 | 
 61 | 
 62 | def test_high_level_graph_optimize() -> None:
 63 |     """Test that we can handle the case where a `HighLevelGraph` is passed directly to `optimize()`."""
 64 |     with dask.config.set(
 65 |         dataframe_optimize=optimize,
 66 |         scheduler="sync",
 67 |     ):
 68 |         df = dd.from_pandas(
 69 |             pd.DataFrame({"a": range(1000), "b": range(1000, 2000)}), npartitions=2
 70 |         ).set_index("b")
 71 |         computed = df.compute()
 72 |         assert computed.at[1000, "a"] == 0
 73 | 
 74 | 
 75 | def test_high_level_graph_parallel(dask_high_level_graph: HighLevelGraph) -> None:
 76 |     """Test that the graph can be traversed and its result is correct when using parallel scheduler."""
 77 |     dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
 78 |     with dask.config.set(scheduler="processes", delayed_optimize=optimize):
 79 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
 80 |         assert result == 2045952000.0
 81 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
 82 |         assert result == 2045952000.0
 83 | 
 84 | 
 85 | def test_custom_serde(dask_high_level_graph: HighLevelGraph) -> None:
 86 |     """Test that we can use a custom serializer/deserializer."""
 87 |     custom_cache: Dict[str, Any] = {}
 88 | 
 89 |     def custom_serialize(obj: Any, fs: fs.base.FS, key: str) -> None:
 90 |         # Write the key itself to the filesystem.
 91 |         with fs.open(f"{key}.dat", "wb") as fid:
 92 |             fid.write(key.encode())
 93 |         # Store the actual result in an in-memory cache.
 94 |         custom_cache[key] = obj
 95 | 
 96 |     def custom_deserialize(fs: fs.base.FS, key: str) -> Any:
 97 |         # Verify that we have written the key to the filesystem.
 98 |         with fs.open(f"{key}.dat", "rb") as fid:
 99 |             assert key == fid.read().decode()
100 |         # Get the result corresponding to that key.
101 |         return custom_cache[key]
102 | 
103 |     # Use a custom location so that we don't corrupt the default cache.
104 |     custom_optimize = partial(
105 |         optimize,
106 |         location=MemoryFS(),
107 |         serialize=custom_serialize,
108 |         deserialize=custom_deserialize,
109 |     )
110 | 
111 |     # Ensure everything gets cached.
112 |     dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
113 | 
114 |     with dask.config.set(scheduler="sync", delayed_optimize=custom_optimize):
115 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
116 |         assert result == 2045952000.0
117 |         result = dask_high_level_graph.compute()  # type: ignore[attr-defined]
118 |         assert result == 2045952000.0
119 | 
--------------------------------------------------------------------------------
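
A minimal usage sketch of the two entry points shown above: ``graphchain.core.get`` for a plain dask graph, and ``graphchain.core.optimize`` registered as a dask optimizer as the tests do. The example graph, its keys, and the cache directory are illustrative assumptions; only the function signatures and the ``dask.config`` keys are taken from the sources above.

import dask

from graphchain.core import get, optimize


def add(x: int, y: int) -> int:
    return x + y


# A plain dask graph (a dict that maps keys to computations), as described in
# the optimize() docstring above. The graph and its keys are illustrative.
dsk = {"a": 1, "b": 2, "sum": (add, "a", "b")}

# Option 1: one-shot get() with caching in a (hypothetical) local directory.
result = get(dsk, "sum", location="./__graphchain_cache__/")
assert result == 3

# Option 2: register optimize() as a dask optimizer, as the tests above do, so
# that dask.delayed computations are cached transparently. The cache_latency /
# cache_throughput settings mirror the tests and force caching of cheap tasks.
dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
with dask.config.set(scheduler="sync", delayed_optimize=optimize):
    total = dask.delayed(add, pure=True)(1, 2).compute()
assert total == 3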