├── .cruft.json
├── .devcontainer
│   └── devcontainer.json
├── .flake8
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── Dockerfile
├── LICENSE
├── README.md
├── docker-compose.yml
├── pyproject.toml
├── src
│   └── graphchain
│       ├── __init__.py
│       ├── core.py
│       ├── py.typed
│       └── utils.py
└── tests
    ├── __init__.py
    ├── test_dask_dataframe.py
    ├── test_graphchain.py
    └── test_high_level_graph.py
/.cruft.json:
--------------------------------------------------------------------------------
1 | {
2 | "template": "git@github.com:radix-ai/poetry-cookiecutter",
3 | "commit": "966391b0342fc923a177a30c39285adf4416fa17",
4 | "checkout": null,
5 | "context": {
6 | "cookiecutter": {
7 | "package_name": "Graphchain",
8 | "package_description": "An efficient cache for the execution of dask graphs.",
9 | "package_url": "https://github.com/radix-ai/graphchain",
10 | "author_name": "Laurent Sorber",
11 | "author_email": "laurent@radix.ai",
12 | "python_version": "3.8",
13 | "with_fastapi_api": "0",
14 | "with_jupyter_lab": "0",
15 | "with_pydantic_typing": "0",
16 | "with_sentry_logging": "0",
17 | "with_streamlit_app": "0",
18 | "with_typer_cli": "0",
19 | "continuous_integration": "GitHub",
20 | "docstring_style": "NumPy",
21 | "private_package_repository_name": "",
22 | "private_package_repository_url": "",
23 | "__package_name_kebab_case": "graphchain",
24 | "__package_name_snake_case": "graphchain",
25 | "_template": "git@github.com:radix-ai/poetry-cookiecutter"
26 | }
27 | },
28 | "directory": null
29 | }
30 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "graphchain",
3 | "dockerComposeFile": "../docker-compose.yml",
4 | "service": "dev",
5 | "workspaceFolder": "/app/",
6 | "overrideCommand": true,
7 | "postStartCommand": "cp --update /opt/build/poetry/poetry.lock /app/ && mkdir -p /app/.git/hooks/ && cp --update /opt/build/git/* /app/.git/hooks/",
8 | "customizations": {
9 | "vscode": {
10 | "extensions": [
11 | "bungcip.better-toml",
12 | "eamodio.gitlens",
13 | "ms-azuretools.vscode-docker",
14 | "ms-python.python",
15 | "ms-python.vscode-pylance",
16 | "ryanluker.vscode-coverage-gutters",
17 | "visualstudioexptteam.vscodeintellicode"
18 | ],
19 | "settings": {
20 | "coverage-gutters.coverageFileNames": [
21 | "reports/coverage.xml"
22 | ],
23 | "editor.codeActionsOnSave": {
24 | "source.organizeImports": true
25 | },
26 | "editor.formatOnSave": true,
27 | "editor.rulers": [
28 | 100
29 | ],
30 | "files.autoSave": "onFocusChange",
31 | "python.defaultInterpreterPath": "/opt/app-env/bin/python",
32 | "python.formatting.provider": "black",
33 | "python.linting.banditArgs": [
34 | "--configfile",
35 | "pyproject.toml"
36 | ],
37 | "python.linting.banditEnabled": true,
38 | "python.linting.flake8Enabled": true,
39 | "python.linting.mypyEnabled": true,
40 | "python.linting.pydocstyleEnabled": true,
41 | "python.terminal.activateEnvironment": false,
42 | "python.testing.pytestEnabled": true,
43 | "terminal.integrated.defaultProfile.linux": "zsh",
44 | "terminal.integrated.profiles.linux": {
45 | "zsh": {
46 | "path": "/usr/bin/zsh"
47 | }
48 | }
49 | }
50 | }
51 | }
52 | }
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | # http://flake8.pycqa.org/en/latest/user/configuration.html#project-configuration
3 | # https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-length
4 | # TODO: https://github.com/PyCQA/flake8/issues/234
5 | color = always
6 | doctests = True
7 | ignore = DAR103,E203,E501,W503
8 | max_line_length = 100
9 | max_complexity = 10
10 |
11 | # https://github.com/terrencepreilly/darglint#flake8
12 | # TODO: https://github.com/terrencepreilly/darglint/issues/130
13 | docstring_style = numpy
14 | strictness = long
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | updates:
4 | - package-ecosystem: github-actions
5 | directory: /
6 | schedule:
7 | interval: monthly
8 | commit-message:
9 | prefix: "ci"
10 | prefix-development: "ci"
11 | include: "scope"
12 | - package-ecosystem: pip
13 | directory: /
14 | schedule:
15 | interval: monthly
16 | commit-message:
17 | prefix: "build"
18 | prefix-development: "build"
19 | include: "scope"
20 | versioning-strategy: lockfile-only
21 | allow:
22 | - dependency-type: "all"
23 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | release:
5 | types:
6 | - created
7 |
8 | jobs:
9 | publish:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@v3
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.8"
20 |
21 | - name: Install Poetry
22 | run: pip install --no-input poetry
23 |
24 | - name: Publish package
25 | run: |
26 | poetry config pypi-token.pypi "${{ secrets.POETRY_PYPI_TOKEN_PYPI }}"
27 | poetry publish --build
28 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | pull_request:
9 |
10 | jobs:
11 | test:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Checkout
16 | uses: actions/checkout@v3
17 |
18 | - name: Set up Node.js
19 | uses: actions/setup-node@v3
20 | with:
21 | node-version: 16
22 |
23 | - name: Install @devcontainers/cli
24 | run: npm install --location=global @devcontainers/cli
25 |
26 | - name: Start Dev Container
27 | env:
28 | DOCKER_BUILDKIT: 1
29 | run: |
30 | git config --global init.defaultBranch main
31 | devcontainer up --workspace-folder .
32 |
33 | - name: Lint package
34 | run: devcontainer exec --workspace-folder . poe lint
35 |
36 | - name: Test package
37 | run: devcontainer exec --workspace-folder . poe test
38 |
39 | - name: Upload coverage
40 | uses: codecov/codecov-action@v3
41 | with:
42 | files: reports/coverage.xml
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Coverage.py
2 | htmlcov/
3 | reports/
4 |
5 | # cruft
6 | *.rej
7 |
8 | # Data
9 | *.csv*
10 | *.dat*
11 | *.pickle*
12 | *.xls*
13 | *.zip*
14 |
15 | # direnv
16 | .envrc
17 |
18 | # dotenv
19 | .env
20 |
21 | # graphchain
22 | __graphchain_cache__/
23 |
24 | # Hypothesis
25 | .hypothesis/
26 |
27 | # Jupyter
28 | *.ipynb
29 | .ipynb_checkpoints/
30 | notebooks/
31 |
32 | # macOS
33 | .DS_Store
34 |
35 | # mypy
36 | .dmypy.json
37 | .mypy_cache/
38 |
39 | # Node.js
40 | node_modules/
41 |
42 | # Poetry
43 | .venv/
44 | dist/
45 | poetry.lock
46 |
47 | # PyCharm
48 | .idea/
49 |
50 | # pyenv
51 | .python-version
52 |
53 | # pytest
54 | .pytest_cache/
55 |
56 | # Python
57 | __pycache__/
58 | *.py[cdo]
59 |
60 | # Terraform
61 | .terraform/
62 |
63 | # VS Code
64 | .vscode/
65 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # https://pre-commit.com
2 | default_install_hook_types: [commit-msg, pre-commit]
3 | default_stages: [commit, manual]
4 | fail_fast: true
5 | repos:
6 | - repo: https://github.com/pre-commit/pygrep-hooks
7 | rev: v1.9.0
8 | hooks:
9 | - id: python-check-blanket-noqa
10 | - id: python-check-blanket-type-ignore
11 | - id: python-check-mock-methods
12 | - id: python-no-eval
13 | - id: python-no-log-warn
14 | - id: python-use-type-annotations
15 | - id: python-check-blanket-noqa
16 | - id: rst-backticks
17 | - id: rst-directive-colons
18 | - id: rst-inline-touching-normal
19 | - id: text-unicode-replacement-char
20 | - repo: https://github.com/pre-commit/pre-commit-hooks
21 | rev: v4.3.0
22 | hooks:
23 | - id: check-added-large-files
24 | - id: check-ast
25 | - id: check-builtin-literals
26 | - id: check-case-conflict
27 | - id: check-docstring-first
28 | - id: check-json
29 | - id: check-merge-conflict
30 | - id: check-shebang-scripts-are-executable
31 | - id: check-symlinks
32 | - id: check-toml
33 | - id: check-vcs-permalinks
34 | - id: check-xml
35 | - id: check-yaml
36 | - id: debug-statements
37 | - id: detect-private-key
38 | - id: fix-byte-order-marker
39 | - id: mixed-line-ending
40 | - id: trailing-whitespace
41 | types: [python]
42 | - id: end-of-file-fixer
43 | types: [python]
44 | - repo: local
45 | hooks:
46 | - id: commitizen
47 | name: commitizen
48 | entry: cz check --commit-msg-file
49 | require_serial: true
50 | language: system
51 | stages: [commit-msg]
52 | - id: pyupgrade
53 | name: pyupgrade
54 | entry: pyupgrade --py38-plus
55 | require_serial: true
56 | language: system
57 | types: [python]
58 | - id: absolufy-imports
59 | name: absolufy-imports
60 | entry: absolufy-imports
61 | require_serial: true
62 | language: system
63 | types: [python]
64 | - id: yesqa
65 | name: yesqa
66 | entry: yesqa
67 | require_serial: true
68 | language: system
69 | types: [python]
70 | - id: isort
71 | name: isort
72 | entry: isort
73 | require_serial: true
74 | language: system
75 | types: [python]
76 | - id: black
77 | name: black
78 | entry: black
79 | require_serial: true
80 | language: system
81 | types: [python]
82 | - id: shellcheck
83 | name: shellcheck
84 | entry: shellcheck --check-sourced
85 | language: system
86 | types: [shell]
87 | - id: bandit
88 | name: bandit
89 | entry: bandit --configfile pyproject.toml
90 | language: system
91 | types: [python]
92 | - id: pydocstyle
93 | name: pydocstyle
94 | entry: pydocstyle
95 | language: system
96 | types: [python]
97 | - id: flake8
98 | name: flake8
99 | entry: flake8
100 | language: system
101 | types: [python]
102 | - id: poetry-check
103 | name: poetry check
104 | entry: poetry check
105 | language: system
106 | files: pyproject.toml
107 | pass_filenames: false
108 | - id: mypy
109 | name: mypy
110 | entry: mypy
111 | language: system
112 | types: [python]
113 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## v1.4.0 (2022-08-29)
2 |
3 | ### Feat
4 |
5 | - support multiprocessing schedulers (#88)
6 |
7 | ### Fix
8 |
9 | - add high level graph support (#92)
10 |
11 | ## v1.3.0 (2022-06-16)
12 |
13 | ### Feat
14 |
15 | - add (de)serialization customization (#76)
16 |
17 | ## v1.2.0 (2022-01-21)
18 |
19 | ### Feat
20 |
21 | - add support for dask.distributed (thanks to @cbyrohl)
22 | - rescaffold project
23 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:1
2 | FROM python:3.8-slim AS base
3 |
4 | # Configure Python to print tracebacks on crash [1], and to not buffer stdout and stderr [2].
5 | # [1] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONFAULTHANDLER
6 | # [2] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONUNBUFFERED
7 | ENV PYTHONFAULTHANDLER 1
8 | ENV PYTHONUNBUFFERED 1
9 |
10 | # Install Poetry.
11 | ENV POETRY_VERSION 1.1.13
12 | RUN --mount=type=cache,target=/root/.cache/ \
13 | pip install poetry==$POETRY_VERSION
14 |
15 | # Create and activate a virtual environment.
16 | RUN python -m venv /opt/app-env
17 | ENV PATH /opt/app-env/bin:$PATH
18 | ENV VIRTUAL_ENV /opt/app-env
19 |
20 | # Install compilers that may be required for certain packages or platforms.
21 | RUN rm /etc/apt/apt.conf.d/docker-clean
22 | RUN --mount=type=cache,target=/var/cache/apt/ \
23 | --mount=type=cache,target=/var/lib/apt/ \
24 | apt-get update && \
25 | apt-get install --no-install-recommends --yes build-essential
26 |
27 | # Set the working directory.
28 | WORKDIR /app/
29 |
30 | # Install the run time Python environment.
31 | COPY poetry.lock* pyproject.toml /app/
32 | RUN --mount=type=cache,target=/root/.cache/ \
33 | mkdir -p src/graphchain/ && touch src/graphchain/__init__.py && touch README.md && \
34 | poetry install --no-dev --no-interaction
35 |
36 | # Create a non-root user.
37 | ARG UID=1000
38 | ARG GID=$UID
39 | RUN groupadd --gid $GID app && \
40 | useradd --create-home --gid $GID --uid $UID app
41 |
42 | FROM base as ci
43 |
44 | # Install git so we can run pre-commit.
45 | RUN --mount=type=cache,target=/var/cache/apt/ \
46 | --mount=type=cache,target=/var/lib/apt/ \
47 | apt-get update && \
48 | apt-get install --no-install-recommends --yes git
49 |
50 | # Install the development Python environment.
51 | RUN --mount=type=cache,target=/root/.cache/ \
52 | poetry install --no-interaction
53 |
54 | # Give the non-root user ownership and switch to the non-root user.
55 | RUN chown --recursive app /app/ /opt/
56 | USER app
57 |
58 | FROM base as dev
59 |
60 | # Install development tools: compilers, curl, git, gpg, ssh, starship, sudo, vim, and zsh.
61 | RUN --mount=type=cache,target=/var/cache/apt/ \
62 | --mount=type=cache,target=/var/lib/apt/ \
63 | apt-get update && \
64 | apt-get install --no-install-recommends --yes build-essential curl git gnupg ssh sudo vim zsh zsh-antigen && \
65 | sh -c "$(curl -fsSL https://starship.rs/install.sh)" -- "--yes" && \
66 | usermod --shell /usr/bin/zsh app
67 |
68 | # Install the development Python environment.
69 | RUN --mount=type=cache,target=/root/.cache/ \
70 | poetry install --no-interaction
71 |
72 | # Persist output generated during docker build so that we can restore it in the dev container.
73 | COPY .pre-commit-config.yaml /app/
74 | RUN mkdir -p /opt/build/poetry/ && cp poetry.lock /opt/build/poetry/ && \
75 | git init && pre-commit install --install-hooks && \
76 | mkdir -p /opt/build/git/ && cp .git/hooks/commit-msg .git/hooks/pre-commit /opt/build/git/
77 |
78 | # Give the non-root user ownership and switch to the non-root user.
79 | RUN chown --recursive app /app/ /opt/ && \
80 | echo 'app ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/app && \
81 | chmod 0440 /etc/sudoers.d/app
82 | USER app
83 |
84 | # Configure the non-root user's shell.
85 | RUN echo 'source /usr/share/zsh-antigen/antigen.zsh' >> ~/.zshrc && \
86 | echo 'antigen bundle zsh-users/zsh-syntax-highlighting' >> ~/.zshrc && \
87 | echo 'antigen bundle zsh-users/zsh-autosuggestions' >> ~/.zshrc && \
88 | echo 'antigen apply' >> ~/.zshrc && \
89 | echo 'eval "$(starship init zsh)"' >> ~/.zshrc && \
90 | echo 'HISTFILE=~/.zsh_history' >> ~/.zshrc && \
91 | zsh -c 'source ~/.zshrc'
92 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 radix.ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [License: MIT](https://choosealicense.com/licenses/mit/) [PyPI: graphchain](https://pypi.python.org/pypi/graphchain/)
2 |
3 | # Graphchain
4 |
5 | ## What is graphchain?
6 |
7 | Graphchain is like [joblib.Memory](https://joblib.readthedocs.io/en/latest/memory.html) for dask graphs. [Dask graph computations](https://docs.dask.org/en/latest/spec.html) are cached to a local or remote location of your choice, specified by a [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html).
8 |
9 | When you change your dask graph (by changing a computation's implementation or its inputs), graphchain will take care to only recompute the minimum number of computations necessary to fetch the result. This allows you to iterate quickly over your graph without spending time on recomputing previously computed keys.
10 |
11 |
12 | 
13 | Source: xkcd.com/1205/
14 |
15 |
16 | The main difference between graphchain and joblib.Memory is that in graphchain a computation's materialised inputs are _not_ serialised and hashed (which can be very expensive when the inputs are large objects such as pandas DataFrames). Instead, a chain of hashes (hence the name graphchain) of the computation object and its dependencies (which are also computation objects) is used to identify the cache file.
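
As a rough illustration of the idea, here is a simplified sketch (not graphchain's actual implementation; see `CachedComputation.compute_hash` in `src/graphchain/core.py`): a computation's hash covers its function's source code, its literal arguments, and the hashes of its dependencies, never their materialised values.

```python
import inspect

import joblib


def chain_hash(dsk, key, _memo=None):
    """Toy chain hash: hash a task's function source together with its dependencies' hashes."""
    _memo = {} if _memo is None else _memo
    if key not in _memo:
        func, *args = dsk[key]
        # Replace references to other keys by their (recursively computed) hashes.
        resolved = [chain_hash(dsk, arg, _memo) if arg in dsk else arg for arg in args]
        _memo[key] = joblib.hash((inspect.getsource(func), resolved))
    return _memo[key]
```

With the `dsk` defined in the example below, `chain_hash(dsk, "result")` changes whenever `summarize_dataframes` or anything upstream of it changes, yet computing it never touches the DataFrames themselves.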
17 |
18 | Additionally, the result of a computation is only cached if it is estimated that loading that computation from cache will save time compared to simply computing the computation. The decision on whether to cache depends on the characteristics of the cache location, which are different when caching to the local filesystem compared to caching to S3 for example.
19 |
20 | ## Usage by example
21 |
22 | ### Basic usage
23 |
24 | Install graphchain with pip to get started:
25 |
26 | ```sh
27 | pip install graphchain
28 | ```
29 |
30 | To demonstrate how graphchain can save you time, let's first create a simple dask graph that (1) creates a few pandas DataFrames, (2) runs a relatively heavy operation on these DataFrames, and (3) summarises the results.
31 |
32 | ```python
33 | import dask
34 | import graphchain
35 | import pandas as pd
36 |
37 | def create_dataframe(num_rows, num_cols):
38 | print("Creating DataFrame...")
39 | return pd.DataFrame(data=[range(num_cols)]*num_rows)
40 |
41 | def expensive_computation(df, num_quantiles):
42 | print("Running expensive computation on DataFrame...")
43 | return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])
44 |
45 | def summarize_dataframes(*dfs):
46 | print("Summing DataFrames...")
47 | return sum(df.sum().sum() for df in dfs)
48 |
49 | dsk = {
50 | "df_a": (create_dataframe, 10_000, 1000),
51 | "df_b": (create_dataframe, 10_000, 1000),
52 | "df_c": (expensive_computation, "df_a", 2048),
53 | "df_d": (expensive_computation, "df_b", 2048),
54 | "result": (summarize_dataframes, "df_c", "df_d")
55 | }
56 | ```
57 |
58 | Using `dask.get` to fetch the `"result"` key takes about 6 seconds:
59 |
60 | ```python
61 | >>> %time dask.get(dsk, "result")
62 |
63 | Creating DataFrame...
64 | Running expensive computation on DataFrame...
65 | Creating DataFrame...
66 | Running expensive computation on DataFrame...
67 | Summing DataFrames...
68 |
69 | CPU times: user 7.39 s, sys: 686 ms, total: 8.08 s
70 | Wall time: 6.19 s
71 | ```
72 |
73 | On the other hand, using `graphchain.get` for the first time to fetch `"result"` takes only 4 seconds:
74 |
75 | ```python
76 | >>> %time graphchain.get(dsk, "result")
77 |
78 | Creating DataFrame...
79 | Running expensive computation on DataFrame...
80 | Summing DataFrames...
81 |
82 | CPU times: user 4.7 s, sys: 519 ms, total: 5.22 s
83 | Wall time: 4.04 s
84 | ```
85 |
86 | The reason `graphchain.get` is faster than `dask.get` is that it can load `df_b` and `df_d` from cache after `df_a` and `df_c` have been computed and cached. Note that graphchain will only cache the result of a computation if loading that computation from cache is estimated to be faster than simply running the computation.
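
The time-to-load estimate is based on a few dask config values that graphchain reads internally (`cache_latency`, `cache_throughput`, and `cache_estimated_compression_ratio`; see `estimate_load_time` in `src/graphchain/core.py`). If the defaults don't match your cache's characteristics, you can override them. A sketch with illustrative values:

```python
import dask

# Illustrative values for a slow, remote cache location:
with dask.config.set(
    {
        "cache_latency": 50e-3,  # read latency per cached key, in seconds
        "cache_throughput": 50e6,  # read throughput, in bytes per second
        "cache_estimated_compression_ratio": 2.0,  # expected compression ratio of cached results
    }
):
    graphchain.get(dsk, "result")
```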
87 |
88 | Running `graphchain.get` a second time to fetch `"result"` will be almost instant since this time the result itself is also available from cache:
89 |
90 | ```python
91 | >>> %time graphchain.get(dsk, "result")
92 |
93 | CPU times: user 4.79 ms, sys: 1.79 ms, total: 6.58 ms
94 | Wall time: 5.34 ms
95 | ```
96 |
97 | Now let's say we want to change how the result is summarised from a sum to an average:
98 |
99 | ```python
100 | def summarize_dataframes(*dfs):
101 | print("Averaging DataFrames...")
102 | return sum(df.mean().mean() for df in dfs) / len(dfs)
103 | ```
104 |
105 | If we then ask graphchain to fetch `"result"`, it will detect that only `summarize_dataframes` has changed and therefore only recompute this function with inputs loaded from cache:
106 |
107 | ```python
108 | >>> %time graphchain.get(dsk, "result")
109 |
110 | Averaging DataFrames...
111 |
112 | CPU times: user 123 ms, sys: 37.2 ms, total: 160 ms
113 | Wall time: 86.6 ms
114 | ```
115 |
116 | ### Storing the graphchain cache remotely
117 |
118 | Graphchain's cache is by default `./__graphchain_cache__`, but you can ask graphchain to use a cache at any [PyFilesystem FS URL](https://docs.pyfilesystem.org/en/latest/openers.html) such as `s3://mybucket/__graphchain_cache__`:
119 |
120 | ```python
121 | graphchain.get(dsk, "result", location="s3://mybucket/__graphchain_cache__")
122 | ```
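
The `location` parameter also accepts an already-opened [PyFilesystem](https://docs.pyfilesystem.org/en/latest/introduction.html) object instead of an FS URL, so you can open the cache filesystem once and reuse it across calls. A minimal sketch with a local cache directory:

```python
import fs
import graphchain

# Open the cache filesystem once and reuse it for multiple gets.
cache_fs = fs.open_fs("./__graphchain_cache__/", create=True)
graphchain.get(dsk, "result", location=cache_fs)
```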
123 |
124 | ### Excluding keys from being cached
125 |
126 | In some cases you may not want a key to be cached. To avoid writing certain keys to the graphchain cache, you can use the `skip_keys` argument:
127 |
128 | ```python
129 | graphchain.get(dsk, "result", skip_keys=["result"])
130 | ```
131 |
132 | ### Using graphchain with dask.delayed
133 |
134 | Alternatively, you can use graphchain together with dask.delayed for easier dask graph creation:
135 |
136 | ```python
137 | import dask
138 | import pandas as pd
139 |
140 | @dask.delayed
141 | def create_dataframe(num_rows, num_cols):
142 | print("Creating DataFrame...")
143 | return pd.DataFrame(data=[range(num_cols)]*num_rows)
144 |
145 | @dask.delayed
146 | def expensive_computation(df, num_quantiles):
147 | print("Running expensive computation on DataFrame...")
148 | return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])
149 |
150 | @dask.delayed
151 | def summarize_dataframes(*dfs):
152 | print("Summing DataFrames...")
153 | return sum(df.sum().sum() for df in dfs)
154 |
155 | df_a = create_dataframe(num_rows=10_000, num_cols=1000)
156 | df_b = create_dataframe(num_rows=10_000, num_cols=1000)
157 | df_c = expensive_computation(df_a, num_quantiles=2048)
158 | df_d = expensive_computation(df_b, num_quantiles=2048)
159 | result = summarize_dataframes(df_c, df_d)
160 | ```
161 |
162 | You can then compute `result` by setting dask's `delayed_optimize` option to `graphchain.optimize`:
163 |
164 | ```python
165 | import graphchain
166 | from functools import partial
167 |
168 | optimize_s3 = partial(graphchain.optimize, location="s3://mybucket/__graphchain_cache__/")
169 |
170 | with dask.config.set(scheduler="sync", delayed_optimize=optimize_s3):
171 | print(result.compute())
172 | ```
173 |
174 | ### Using a custom serializer/deserializer
175 |
176 | By default graphchain will cache dask computations with [joblib.dump](https://joblib.readthedocs.io/en/latest/generated/joblib.dump.html) and LZ4 compression. However, you may also supply a custom `serialize` and `deserialize` function that writes and reads computations to and from a [PyFilesystem filesystem](https://docs.pyfilesystem.org/en/latest/introduction.html), respectively. For example, the following snippet shows how to serialize dask DataFrames with [dask.dataframe.to_parquet](https://docs.dask.org/en/stable/generated/dask.dataframe.to_parquet.html), while other objects are serialized with joblib:
177 |
178 | ```python
179 | import dask.dataframe
180 | import graphchain
181 | import fs.osfs
182 | import joblib
183 | import os
184 | from functools import partial
185 | from typing import Any
186 |
187 | def custom_serialize(obj: Any, fs: fs.osfs.OSFS, key: str) -> None:
188 | """Serialize dask DataFrames with to_parquet, and other objects with joblib.dump."""
189 | if isinstance(obj, dask.dataframe.DataFrame):
190 | obj.to_parquet(os.path.join(fs.root_path, "parquet", key))
191 | else:
192 | with fs.open(f"{key}.joblib", "wb") as fid:
193 | joblib.dump(obj, fid)
194 |
195 | def custom_deserialize(fs: fs.osfs.OSFS, key: str) -> Any:
196 | """Deserialize dask DataFrames with read_parquet, and other objects with joblib.load."""
197 | if fs.exists(f"{key}.joblib"):
198 | with fs.open(f"{key}.joblib", "rb") as fid:
199 | return joblib.load(fid)
200 | else:
201 | return dask.dataframe.read_parquet(os.path.join(fs.root_path, "parquet", key))
202 |
203 | optimize_parquet = partial(
204 | graphchain.optimize,
205 | location="./__graphchain_cache__/custom/",
206 | serialize=custom_serialize,
207 | deserialize=custom_deserialize
208 | )
209 |
210 | with dask.config.set(scheduler="sync", delayed_optimize=optimize_parquet):
211 | print(result.compute())
212 | ```
213 |
214 | ## Contributing
215 |
216 |
217 | Setup: once per device
218 |
219 | 1. [Generate an SSH key](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/generating-a-new-ssh-key-and-adding-it-to-the-ssh-agent#generating-a-new-ssh-key) and [add the SSH key to your GitHub account](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account).
220 | 1. Configure SSH to automatically load your SSH keys:
221 | ```sh
222 | cat << EOF >> ~/.ssh/config
223 | Host *
224 | AddKeysToAgent yes
225 | IgnoreUnknown UseKeychain
226 | UseKeychain yes
227 | EOF
228 | ```
229 | 1. [Install Docker Desktop](https://www.docker.com/get-started).
230 | - Enable _Use Docker Compose V2_ in Docker Desktop's preferences window.
231 | - _Linux only_:
232 | - [Configure Docker and Docker Compose to use the BuildKit build system](https://docs.docker.com/develop/develop-images/build_enhancements/#to-enable-buildkit-builds). On macOS and Windows, BuildKit is enabled by default in Docker Desktop.
233 | - Export your user's user id and group id so that [files created in the Dev Container are owned by your user](https://github.com/moby/moby/issues/3206):
234 | ```sh
235 | cat << EOF >> ~/.bashrc
236 | export UID=$(id --user)
237 | export GID=$(id --group)
238 | EOF
239 | ```
240 | 1. [Install VS Code](https://code.visualstudio.com/) and [VS Code's Remote-Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). Alternatively, install [PyCharm](https://www.jetbrains.com/pycharm/download/).
241 | - _Optional:_ Install a [Nerd Font](https://www.nerdfonts.com/font-downloads) such as [FiraCode Nerd Font](https://github.com/ryanoasis/nerd-fonts/tree/master/patched-fonts/FiraCode) with `brew tap homebrew/cask-fonts && brew install --cask font-fira-code-nerd-font` and [configure VS Code](https://github.com/tonsky/FiraCode/wiki/VS-Code-Instructions) or [configure PyCharm](https://github.com/tonsky/FiraCode/wiki/Intellij-products-instructions) to use `'FiraCode Nerd Font'`.
242 |
243 |
244 |
245 |
246 | Setup: once per project
247 |
248 | 1. Clone this repository.
249 | 2. Start a [Dev Container](https://code.visualstudio.com/docs/remote/containers) in your preferred development environment:
250 | - _VS Code_: open the cloned repository and run Ctrl/⌘ + ⇧ + P → _Remote-Containers: Reopen in Container_.
251 | - _PyCharm_: open the cloned repository and [configure Docker Compose as a remote interpreter](https://www.jetbrains.com/help/pycharm/using-docker-compose-as-a-remote-interpreter.html#docker-compose-remote).
252 | - _Terminal_: open the cloned repository and run `docker compose run --rm dev` to start an interactive Dev Container.
253 |
254 |
255 |
256 |
257 | Developing
258 |
259 | - This project follows the [Conventional Commits](https://www.conventionalcommits.org/) standard to automate [Semantic Versioning](https://semver.org/) and [Keep A Changelog](https://keepachangelog.com/) with [Commitizen](https://github.com/commitizen-tools/commitizen).
260 | - Run `poe` from within the development environment to print a list of [Poe the Poet](https://github.com/nat-n/poethepoet) tasks available to run on this project.
261 | - Run `poetry add {package}` from within the development environment to install a run time dependency and add it to `pyproject.toml` and `poetry.lock`.
262 | - Run `poetry remove {package}` from within the development environment to uninstall a run time dependency and remove it from `pyproject.toml` and `poetry.lock`.
263 | - Run `poetry update` from within the development environment to upgrade all dependencies to the latest versions allowed by `pyproject.toml`.
264 | - Run `cz bump` to bump the package's version, update the `CHANGELOG.md`, and create a git tag.
265 |
266 |
267 |
268 | ## Developed by Radix
269 |
270 | [Radix](https://radix.ai) is a Belgium-based Machine Learning company.
271 |
272 | Our vision is to make technology work for and with us. We believe that if technology is used in a creative way, jobs become more fulfilling, people become the best version of themselves, and companies grow.
273 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.9"
2 |
3 | services:
4 | dev:
5 | build:
6 | context: .
7 | target: dev
8 | args:
9 | UID: ${UID:-1000}
10 | GID: ${GID:-1000}
11 | stdin_open: true
12 | tty: true
13 | entrypoint: []
14 | command:
15 | [
16 | "sh",
17 | "-c",
18 | "cp --update /opt/build/poetry/poetry.lock /app/ && mkdir -p /app/.git/hooks/ && cp --update /opt/build/git/* /app/.git/hooks/ && zsh"
19 | ]
20 | environment:
21 | - POETRY_PYPI_TOKEN_PYPI
22 | - SSH_AUTH_SOCK=/run/host-services/ssh-auth.sock
23 | volumes:
24 | - .:/app/
25 | - ~/.gitconfig:/etc/gitconfig
26 | - ~/.ssh/known_hosts:/home/app/.ssh/known_hosts
27 | - ${SSH_AGENT_AUTH_SOCK:-/run/host-services/ssh-auth.sock}:/run/host-services/ssh-auth.sock
28 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # https://python-poetry.org/docs/pyproject/#poetry-and-pep-517
2 | [build-system]
3 | requires = ["poetry-core>=1.0.0"]
4 | build-backend = "poetry.core.masonry.api"
5 |
6 | # https://python-poetry.org/docs/pyproject/
7 | [tool.poetry]
8 | name = "graphchain"
9 | version = "1.4.0"
10 | description = "An efficient cache for the execution of dask graphs."
11 | authors = ["Laurent Sorber <laurent@radix.ai>"]
12 | readme = "README.md"
13 | repository = "https://github.com/radix-ai/graphchain"
14 |
15 | # https://commitizen-tools.github.io/commitizen/config/
16 | [tool.commitizen]
17 | bump_message = "bump(release): v$current_version → v$new_version"
18 | tag_format = "v$version"
19 | update_changelog_on_bump = true
20 | version = "1.4.0"
21 | version_files = ["pyproject.toml:version"]
22 |
23 | # https://python-poetry.org/docs/dependency-specification/
24 | [tool.poetry.dependencies]
25 | cloudpickle = ">=1.0.0,<3.0.0"
26 | dask = ">=2020.12.0"
27 | fs-s3fs = "^1"
28 | joblib = "^1"
29 | lz4 = ">=3,<5"
30 | python = "^3.8,<3.11"
31 |
32 | # https://python-poetry.org/docs/master/managing-dependencies/
33 | # TODO: Split in `tool.poetry.group.dev` and `tool.poetry.group.test` when Poetry 1.2.0 is released.
34 | [tool.poetry.dev-dependencies]
35 | absolufy-imports = "^0.3.1"
36 | bandit = { extras = ["toml"], version = "^1.7.4" }
37 | black = "^22.6.0"
38 | commitizen = "^2.27.1"
39 | coverage = { extras = ["toml"], version = "^6.4.1" }
40 | cruft = "^2.11.0"
41 | darglint = "^1.8.1"
42 | fastparquet = "^0.8.1"
43 | flake8 = "^5.0.4"
44 | flake8-bugbear = "^22.6.22"
45 | flake8-comprehensions = "^3.10.0"
46 | flake8-mutable = "^1.2.0"
47 | flake8-print = "^5.0.0"
48 | flake8-pytest-style = "^1.6.0"
49 | flake8-rst-docstrings = "^0.2.6"
50 | flake8-tidy-imports = "^4.8.0"
51 | isort = "^5.10.1"
52 | mypy = "^0.971"
53 | pandas = "^1.3.5"
54 | pandas-stubs = "^1.4.3"
55 | pdoc = "^12.0.2"
56 | pep8-naming = "^0.13.0"
57 | poethepoet = "^0.16.0"
58 | pre-commit = "^2.19.0"
59 | pydocstyle = { extras = ["toml"], version = "^6.1.1" }
60 | pytest = "^7.1.2"
61 | pytest-clarity = "^1.0.1"
62 | pytest-mock = "^3.8.1"
63 | pytest-xdist = "^2.5.0"
64 | pyupgrade = "^2.34.0"
65 | safety = "^2.1.1"
66 | shellcheck-py = "^0.8.0"
67 | yesqa = "^1.4.0"
68 |
69 | # https://bandit.readthedocs.io/en/latest/config.html
70 | [tool.bandit]
71 | skips = ["B101", "B110", "B403"]
72 |
73 | # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file
74 | [tool.black]
75 | line-length = 100
76 |
77 | # https://coverage.readthedocs.io/en/latest/config.html#report
78 | [tool.coverage.report]
79 | fail_under = 50
80 | precision = 1
81 | show_missing = true
82 | skip_covered = true
83 |
84 | # https://coverage.readthedocs.io/en/latest/config.html#run
85 | [tool.coverage.run]
86 | branch = true
87 | command_line = "--module pytest"
88 | data_file = "reports/.coverage"
89 | source = ["src"]
90 |
91 | # https://coverage.readthedocs.io/en/latest/config.html#xml
92 | [tool.coverage.xml]
93 | output = "reports/coverage.xml"
94 |
95 | # https://pycqa.github.io/isort/docs/configuration/options.html
96 | [tool.isort]
97 | color_output = true
98 | line_length = 100
99 | profile = "black"
100 | src_paths = ["src", "tests"]
101 |
102 | # https://mypy.readthedocs.io/en/latest/config_file.html
103 | [tool.mypy]
104 | junit_xml = "reports/mypy.xml"
105 | strict = true
106 | disallow_subclassing_any = false
107 | disallow_untyped_calls = false
108 | disallow_untyped_decorators = false
109 | ignore_missing_imports = true
110 | no_implicit_reexport = false
111 | pretty = true
112 | show_column_numbers = true
113 | show_error_codes = true
114 | show_error_context = true
115 | warn_unreachable = true
116 |
117 | # http://www.pydocstyle.org/en/latest/usage.html#configuration-files
118 | [tool.pydocstyle]
119 | convention = "numpy"
120 |
121 | # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref
122 | [tool.pytest.ini_options]
123 | addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml"
124 | filterwarnings = ["error", "ignore::DeprecationWarning"]
125 | testpaths = ["src", "tests"]
126 | xfail_strict = true
127 |
128 | # https://github.com/nat-n/poethepoet
129 | [tool.poe.tasks]
130 |
131 | [tool.poe.tasks.docs]
132 | help = "Generate this package's docs"
133 | cmd = """
134 | pdoc
135 | --docformat $docformat
136 | --output-directory $outputdirectory
137 | graphchain
138 | """
139 |
140 | [[tool.poe.tasks.docs.args]]
141 | help = "The docstring style (default: numpy)"
142 | name = "docformat"
143 | options = ["--docformat"]
144 | default = "numpy"
145 |
146 | [[tool.poe.tasks.docs.args]]
147 | help = "The output directory (default: docs)"
148 | name = "outputdirectory"
149 | options = ["--output-directory"]
150 | default = "docs"
151 |
152 | [tool.poe.tasks.lint]
153 | help = "Lint this package"
154 |
155 | [[tool.poe.tasks.lint.sequence]]
156 | cmd = """
157 | pre-commit run
158 | --all-files
159 | --color always
160 | """
161 |
162 | [[tool.poe.tasks.lint.sequence]]
163 | shell = "safety check --continue-on-error --full-report"
164 |
165 | [tool.poe.tasks.test]
166 | help = "Test this package"
167 |
168 | [[tool.poe.tasks.test.sequence]]
169 | cmd = "coverage run"
170 |
171 | [[tool.poe.tasks.test.sequence]]
172 | cmd = "coverage report"
173 |
174 | [[tool.poe.tasks.test.sequence]]
175 | cmd = "coverage xml"
176 |
--------------------------------------------------------------------------------
/src/graphchain/__init__.py:
--------------------------------------------------------------------------------
1 | """Graphchain is a cache for dask graphs."""
2 |
3 | from graphchain.core import get, optimize
4 |
5 | __all__ = ["get", "optimize"]
6 |
--------------------------------------------------------------------------------
/src/graphchain/core.py:
--------------------------------------------------------------------------------
1 | """Graphchain core."""
2 |
3 | import datetime as dt
4 | import logging
5 | import time
6 | from copy import deepcopy
7 | from functools import lru_cache, partial
8 | from pickle import HIGHEST_PROTOCOL
9 | from typing import (
10 | Any,
11 | Callable,
12 | Container,
13 | Dict,
14 | Hashable,
15 | Iterable,
16 | Optional,
17 | Sequence,
18 | TypeVar,
19 | Union,
20 | cast,
21 | )
22 |
23 | import cloudpickle
24 | import dask
25 | import fs
26 | import fs.base
27 | import fs.memoryfs
28 | import fs.osfs
29 | import joblib
30 | from dask.highlevelgraph import HighLevelGraph, Layer
31 |
32 | from graphchain.utils import get_size, str_to_posix_fully_portable_filename
33 |
34 | T = TypeVar("T")
35 |
36 |
37 | # We have to define an lru_cache wrapper here because mypy doesn't support decorated properties [1].
38 | # [1] https://github.com/python/mypy/issues/5858
39 | def _cache(__user_function: Callable[..., T]) -> Callable[..., T]:
40 | return lru_cache(maxsize=None)(__user_function)
41 |
42 |
43 | def hlg_setitem(self: HighLevelGraph, key: Hashable, value: Any) -> None:
44 | """Set a HighLevelGraph computation."""
45 | for d in self.layers.values():
46 | if key in d:
47 | d[key] = value # type: ignore[index]
48 |
49 |
50 | # Monkey patch HighLevelGraph to add a missing `__setitem__` method.
51 | if not hasattr(HighLevelGraph, "__setitem__"):
52 | HighLevelGraph.__setitem__ = hlg_setitem # type: ignore[index]
53 |
54 |
55 | def layer_setitem(self: Layer, key: Hashable, value: Any) -> None:
56 | """Set a Layer computation."""
57 | self.mapping[key] = value # type: ignore[attr-defined]
58 |
59 |
60 | # Monkey patch Layer to add a missing `__setitem__` method.
61 | if not hasattr(Layer, "__setitem__"):
62 | Layer.__setitem__ = layer_setitem # type: ignore[index]
63 |
64 |
65 | logger = logging.getLogger(__name__)
66 |
67 |
68 | def joblib_dump(
69 | obj: Any, fs: fs.base.FS, key: str, ext: str = "joblib", **kwargs: Dict[str, Any]
70 | ) -> None:
71 | """Store an object on a filesystem."""
72 | filename = f"{key}.{ext}"
73 | with fs.open(filename, "wb") as fid:
74 | joblib.dump(obj, fid, **kwargs)
75 |
76 |
77 | def joblib_load(fs: fs.base.FS, key: str, ext: str = "joblib", **kwargs: Dict[str, Any]) -> Any:
78 | """Load an object from a filesystem."""
79 | filename = f"{key}.{ext}"
80 | with fs.open(filename, "rb") as fid:
81 | return joblib.load(fid, **kwargs)
82 |
83 |
84 | joblib_dump_lz4 = partial(joblib_dump, compress="lz4", ext="joblib.lz4", protocol=HIGHEST_PROTOCOL)
85 | joblib_load_lz4 = partial(joblib_load, ext="joblib.lz4")
86 |
87 |
88 | class CacheFS:
89 | """Lazily opened PyFileSystem."""
90 |
91 | def __init__(self, location: Union[str, fs.base.FS]):
92 | """Wrap a PyFilesystem FS URL to provide a single FS instance.
93 |
94 | Parameters
95 | ----------
96 | location
97 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such
98 | as ``"./__graphchain_cache__/"`` or a remote directory such as
99 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a PyFilesystem itself
100 | instead.
101 | """
102 | self.location = location
103 |
104 | @property # type: ignore[misc]
105 | @_cache
106 | def fs(self) -> fs.base.FS:
107 | """Open a PyFilesystem FS to the cache directory."""
108 | # create=True does not yet work for S3FS [1]. This should probably be left to the user as we
109 | # don't know in which region to create the bucket, among other configuration options.
110 | # [1] https://github.com/PyFilesystem/s3fs/issues/23
111 | if isinstance(self.location, fs.base.FS):
112 | return self.location
113 | return fs.open_fs(self.location, create=True)
114 |
115 |
116 | class CachedComputation:
117 | """A replacement for computations in dask graphs."""
118 |
119 | def __init__(
120 | self,
121 | dsk: Dict[Hashable, Any],
122 | key: Hashable,
123 | computation: Any,
124 | location: Union[str, fs.base.FS, CacheFS] = "./__graphchain_cache__/",
125 | serialize: Callable[[Any, fs.base.FS, str], None] = joblib_dump_lz4,
126 | deserialize: Callable[[fs.base.FS, str], Any] = joblib_load_lz4,
127 | write_to_cache: Union[bool, str] = "auto",
128 | ) -> None:
129 | """Cache a dask graph computation.
130 |
131 | A wrapper for the computation object to replace the original computation with in the dask
132 | graph.
133 |
134 | Parameters
135 | ----------
136 | dsk
137 | The dask graph this computation is a part of.
138 | key
139 | The key corresponding to this computation in the dask graph.
140 | computation
141 | The computation to cache.
142 | location
143 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such
144 | as ``"./__graphchain_cache__/"`` or a remote directory such as
145 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a CacheFS instance or a
146 | PyFilesystem itself instead.
147 | serialize
148 | A function of the form ``serialize(result: Any, fs: fs.base.FS, key: str)`` that caches
149 | a computation `result` to a filesystem `fs` under a given `key`.
150 | deserialize
151 | A function of the form ``deserialize(fs: fs.base.FS, key: str)`` that reads a cached
152 | computation `result` from a `key` on a given filesystem `fs`.
153 | write_to_cache
154 | Whether or not to cache this computation. If set to ``"auto"``, will only write to cache
155 | if it is expected this will speed up future gets of this computation, taking into
156 | account the characteristics of the `location` filesystem.
157 | """
158 | self.dsk = dsk
159 | self.key = key
160 | self.computation = computation
161 | self.location = location
162 | self.serialize = serialize
163 | self.deserialize = deserialize
164 | self.write_to_cache = write_to_cache
165 |
166 | @property # type: ignore[misc]
167 | @_cache
168 | def cache_fs(self) -> fs.base.FS:
169 | """Open a PyFilesystem FS to the cache directory."""
170 | # create=True does not yet work for S3FS [1]. This should probably be left to the user as we
171 | # don't know in which region to create the bucket, among other configuration options.
172 | # [1] https://github.com/PyFilesystem/s3fs/issues/23
173 | if isinstance(self.location, fs.base.FS):
174 | return self.location
175 | if isinstance(self.location, CacheFS):
176 | return self.location.fs
177 | return fs.open_fs(self.location, create=True)
178 |
179 | def __repr__(self) -> str:
180 | """Represent this CachedComputation object as a string."""
181 | return f"<CachedComputation key={self.key} task={self.computation}>"
182 |
183 | def _subs_dependencies_with_hash(self, computation: Any) -> Any:
184 | """Replace key references in a computation by their hashes."""
185 | dependencies = dask.core.get_dependencies(
186 | self.dsk, task=0 if computation is None else computation
187 | )
188 | for dep in dependencies:
189 | computation = dask.core.subs(
190 | computation,
191 | dep,
192 | self.dsk[dep].hash
193 | if isinstance(self.dsk[dep], CachedComputation)
194 | else self.dsk[dep][0].hash,
195 | )
196 | return computation
197 |
198 | def _subs_tasks_with_src(self, computation: Any) -> Any:
199 | """Replace task functions by their source code."""
200 | if type(computation) is list:
201 | # This computation is a list of computations.
202 | computation = [self._subs_tasks_with_src(x) for x in computation]
203 | elif dask.core.istask(computation):
204 | # This computation is a task.
205 | src = joblib.func_inspect.get_func_code(computation[0])[0]
206 | computation = (src,) + computation[1:]
207 | return computation
208 |
209 | def compute_hash(self) -> str:
210 | """Compute a hash of this computation object and its dependencies."""
211 | # Replace dependencies with their hashes and functions with source.
212 | computation = self._subs_dependencies_with_hash(self.computation)
213 | computation = self._subs_tasks_with_src(computation)
214 | # Return the hash of the resulting computation.
215 | comp_hash: str = joblib.hash(cloudpickle.dumps(computation))
216 | return comp_hash
217 |
218 | @property
219 | def hash(self) -> str:
220 | """Return the hash of this CachedComputation."""
221 | if not hasattr(self, "_hash"):
222 | self._hash = self.compute_hash()
223 | return self._hash
224 |
225 | def estimate_load_time(self, result: Any) -> float:
226 | """Estimate the time to load the given result from cache."""
227 | size: float = get_size(result) / dask.config.get("cache_estimated_compression_ratio", 2.0)
228 | # Use typical SSD latency and bandwidth if cache_fs is a local filesystem, else use a typical
229 | # latency and bandwidth for network-based filesystems.
230 | read_latency = float(
231 | dask.config.get(
232 | "cache_latency",
233 | 1e-4 if isinstance(self.cache_fs, (fs.osfs.OSFS, fs.memoryfs.MemoryFS)) else 50e-3,
234 | )
235 | )
236 | read_throughput = float(
237 | dask.config.get(
238 | "cache_throughput",
239 | 500e6 if isinstance(self.cache_fs, (fs.osfs.OSFS, fs.memoryfs.MemoryFS)) else 50e6,
240 | )
241 | )
242 | return read_latency + size / read_throughput
243 |
244 | @_cache
245 | def read_time(self, timing_type: str) -> float:
246 | """Read the time to load, compute, or store from file."""
247 | time_filename = f"{self.hash}.time.{timing_type}"
248 | with self.cache_fs.open(time_filename, "r") as fid:
249 | return float(fid.read())
250 |
251 | def write_time(self, timing_type: str, seconds: float) -> None:
252 | """Write the time to load, compute, or store from file."""
253 | time_filename = f"{self.hash}.time.{timing_type}"
254 | with self.cache_fs.open(time_filename, "w") as fid:
255 | fid.write(str(seconds))
256 |
257 | def write_log(self, log_type: str) -> None:
258 | """Write the timestamp of a load, compute, or store operation."""
259 | key = str_to_posix_fully_portable_filename(str(self.key))
260 | now = str_to_posix_fully_portable_filename(str(dt.datetime.now()))
261 | log_filename = f".{now}.{log_type}.{key}.log"
262 | with self.cache_fs.open(log_filename, "w") as fid:
263 | fid.write(self.hash)
264 |
265 | def time_to_result(self, memoize: bool = True) -> float:
266 | """Estimate the time to load or compute this computation."""
267 | if hasattr(self, "_time_to_result"):
268 | return self._time_to_result # type: ignore[has-type,no-any-return]
269 | if memoize:
270 | try:
271 | try:
272 | load_time = self.read_time("load")
273 | except Exception:
274 | load_time = self.read_time("store") / 2
275 | self._time_to_result = load_time
276 | return load_time
277 | except Exception:
278 | pass
279 | compute_time = self.read_time("compute")
280 | dependency_time = 0
281 | dependencies = dask.core.get_dependencies(
282 | self.dsk, task=0 if self.computation is None else self.computation
283 | )
284 | for dep in dependencies:
285 | dependency_time += self.dsk[dep][0].time_to_result()
286 | total_time = compute_time + dependency_time
287 | if memoize:
288 | self._time_to_result = total_time
289 | return total_time
290 |
291 | def cache_file_exists(self) -> bool:
292 | """Check if this CachedComputation's cache file exists."""
293 | return self.cache_fs.exists(f"{self.hash}.time.store")
294 |
295 | def load(self) -> Any:
296 | """Load this result of this computation from cache."""
297 | try:
298 | # Load from cache.
299 | start_time = time.perf_counter()
300 | logger.info(f"LOAD {self} from {self.cache_fs}/{self.hash}")
301 | result = self.deserialize(self.cache_fs, self.hash)
302 | load_time = time.perf_counter() - start_time
303 | # Write load time and log operation.
304 | self.write_time("load", load_time)
305 | self.write_log("load")
306 | return result
307 | except Exception:
308 | logger.exception(
309 | f"Could not read {self.cache_fs}/{self.hash}. Marking cache as invalid, please try again!"
310 | )
311 | self.cache_fs.remove(f"{self.hash}.time.store")
312 | raise
313 |
314 | def compute(self, *args: Any, **kwargs: Any) -> Any:
315 | """Compute this computation."""
316 | # Compute the computation.
317 | logger.info(f"COMPUTE {self}")
318 | start_time = time.perf_counter()
319 | if dask.core.istask(self.computation):
320 | result = self.computation[0](*args, **kwargs)
321 | else:
322 | result = args[0]
323 | compute_time = time.perf_counter() - start_time
324 | # Write compute time and log operation
325 | self.write_time("compute", compute_time)
326 | self.write_log("compute")
327 | return result
328 |
329 | def store(self, result: Any) -> None:
330 | """Store the result of this computation in the cache."""
331 | if not self.cache_file_exists():
332 | logger.info(f"STORE {self} to {self.cache_fs}/{self.hash}")
333 | try:
334 | # Store to cache.
335 | start_time = time.perf_counter()
336 | self.serialize(result, self.cache_fs, self.hash)
337 | store_time = time.perf_counter() - start_time
338 | # Write store time and log operation
339 | self.write_time("store", store_time)
340 | self.write_log("store")
341 | except Exception:
342 | # Not crucial to stop if caching fails.
343 | logger.exception(f"Could not write {self.hash}.")
344 |
345 | def patch_computation_in_graph(self) -> None:
346 | """Patch the graph to use this CachedComputation."""
347 | if self.cache_file_exists():
348 | # If there are cache candidates to load this computation from, remove all dependencies
349 | # for this task from the graph as far as dask is concerned.
350 | self.dsk[self.key] = (self,)
351 | else:
352 | # If there are no cache candidates, wrap the execution of the computation with this
353 | # CachedComputation's __call__ method and keep references to its dependencies.
354 | self.dsk[self.key] = (
355 | (self,) + self.computation[1:]
356 | if dask.core.istask(self.computation)
357 | else (self, self.computation)
358 | )
359 |
360 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
361 | """Load this computation from cache, or compute and then store it."""
362 | # Load.
363 | if self.cache_file_exists():
364 | return self.load()
365 | # Compute.
366 | result = self.compute(*args, **kwargs)
367 | # Store.
368 | write_to_cache = self.write_to_cache
369 | if write_to_cache == "auto":
370 | compute_time = self.time_to_result(memoize=False)
371 | estimated_load_time = self.estimate_load_time(result)
372 | write_to_cache = estimated_load_time < compute_time
373 | logger.debug(
374 | f'{"Going" if write_to_cache else "Not going"} to cache {self}'
375 | f" because estimated_load_time={estimated_load_time} "
376 | f'{"<" if write_to_cache else ">="} '
377 | f"compute_time={compute_time}"
378 | )
379 | if write_to_cache:
380 | self.store(result)
381 | return result
382 |
383 |
384 | def optimize(
385 | dsk: Union[Dict[Hashable, Any], HighLevelGraph],
386 | keys: Optional[Union[Hashable, Iterable[Hashable]]] = None,
387 | skip_keys: Optional[Container[Hashable]] = None,
388 | location: Union[str, fs.base.FS, CacheFS] = "./__graphchain_cache__/",
389 | serialize: Callable[[Any, fs.base.FS, str], None] = joblib_dump_lz4,
390 | deserialize: Callable[[fs.base.FS, str], Any] = joblib_load_lz4,
391 | ) -> Dict[Hashable, Any]:
392 | """Optimize a dask graph with cached computations.
393 |
394 | According to the dask graph specification [1]_, a dask graph is a dictionary that maps `keys` to
395 | `computations`. A computation can be:
396 |
397 | 1. Another key in the graph.
398 | 2. A literal.
399 | 3. A task, which is of the form `(Callable, *args)`.
400 | 4. A list of other computations.
401 |
402 | This optimizer replaces all computations in a graph with ``CachedComputation``'s, so that
403 | getting items from the graph will be backed by a cache of your choosing. With this cache, only
404 | the very minimum number of computations will actually be computed to return the values
405 | corresponding to the given keys.
406 |
407 | `CachedComputation` objects *do not* hash task inputs (which is the approach that
408 | `functools.lru_cache` and `joblib.Memory` take) to identify which cache file to load. Instead, a
409 | chain of hashes (hence the name `graphchain`) of the computation object and its dependencies
410 | (which are also computation objects) is used to identify the cache file.
411 |
412 | Since it is generally cheap to hash the graph's computation objects, `graphchain`'s cache is
413 | likely to be much faster than hashing task inputs, which can be slow for large objects such as
414 | `pandas.DataFrame`'s.
415 |
416 | Parameters
417 | ----------
418 | dsk
419 | The dask graph to optimize with caching computations.
420 | keys
421 | Not used. Is present for compatibility with dask optimizers [2]_.
422 | skip_keys
423 | A container of keys not to cache.
424 | location
425 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such as
426 | ``"./__graphchain_cache__/"`` or a remote directory such as
427 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a PyFilesystem itself instead.
428 | serialize
429 | A function of the form ``serialize(result: Any, fs: fs.base.FS, key: str)`` that caches a
430 | computation `result` to a filesystem `fs` under a given `key`.
431 | deserialize
432 | A function of the form ``deserialize(fs: fs.base.FS, key: str)`` that reads a cached
433 | computation `result` from a `key` on a given filesystem `fs`.
434 |
435 | Returns
436 | -------
437 | dict
438 | A copy of the dask graph where the computations have been replaced by `CachedComputation`'s.
439 |
440 | References
441 | ----------
442 | .. [1] https://docs.dask.org/en/latest/spec.html
443 | .. [2] https://docs.dask.org/en/latest/optimize.html
444 | """
445 | # Technically a HighLevelGraph isn't actually a dict, but it has largely the same API so we can
446 | # treat it as one. We can't use a type union or protocol either, because HighLevelGraph doesn't
447 | # actually have a __setitem__ implementation, we just monkey-patched that in.
448 | dsk = cast(Dict[Hashable, Any], deepcopy(dsk))
449 | # Verify that the graph is a DAG.
450 | assert dask.core.isdag(dsk, list(dsk.keys()))
451 | if isinstance(location, str):
452 | location = CacheFS(location)
453 | # Replace graph computations by CachedComputations.
454 | skip_keys = skip_keys or set()
455 | for key, computation in dsk.items():
456 | dsk[key] = CachedComputation(
457 | dsk,
458 | key,
459 | computation,
460 | location=location,
461 | serialize=serialize,
462 | deserialize=deserialize,
463 | write_to_cache=False if key in skip_keys else "auto",
464 | )
465 | # Remove task arguments if we can load from cache.
466 | for key in dsk:
467 | dsk[key].patch_computation_in_graph()
468 |
469 | return dsk
470 |
471 |
472 | def get(
473 | dsk: Dict[Hashable, Any],
474 | keys: Union[Hashable, Sequence[Hashable]],
475 | skip_keys: Optional[Container[Hashable]] = None,
476 | location: Union[str, fs.base.FS, CacheFS] = "./__graphchain_cache__/",
477 | serialize: Callable[[Any, fs.base.FS, str], None] = joblib_dump_lz4,
478 | deserialize: Callable[[fs.base.FS, str], Any] = joblib_load_lz4,
479 | scheduler: Optional[
480 | Callable[[Dict[Hashable, Any], Union[Hashable, Sequence[Hashable]]], Any]
481 | ] = None,
482 | ) -> Any:
483 | """Get one or more keys from a dask graph with caching.
484 |
485 | Optimizes a dask graph with `graphchain.optimize` and then computes the requested keys with the
486 | desired scheduler, which is by default `dask.get`.
487 |
488 | See `graphchain.optimize` for more information on how `graphchain`'s cache mechanism works.
489 |
490 | Parameters
491 | ----------
492 | dsk
493 | The dask graph to query.
494 | keys
495 | The keys to compute.
496 | skip_keys
497 | A container of keys not to cache.
498 | location
499 | A PyFilesystem FS URL to store the cached computations in. Can be a local directory such as
500 | ``"./__graphchain_cache__/"`` or a remote directory such as
501 | ``"s3://bucket/__graphchain_cache__/"``. You can also pass a PyFilesystem itself instead.
502 | serialize
503 | A function of the form ``serialize(result: Any, fs: fs.base.FS, key: str)`` that caches a
504 | computation `result` to a filesystem `fs` under a given `key`.
505 | deserialize
506 | A function of the form ``deserialize(fs: fs.base.FS, key: str)`` that reads a cached
507 | computation `result` from a `key` on a given filesystem `fs`.
508 | scheduler
509 | The dask scheduler to use to retrieve the keys from the graph.
510 |
511 | Returns
512 | -------
513 | Any
514 | The computed values corresponding to the given keys.
515 | """
516 | cached_dsk = optimize(
517 | dsk,
518 | keys,
519 | skip_keys=skip_keys,
520 | location=location,
521 | serialize=serialize,
522 | deserialize=deserialize,
523 | )
524 | schedule = dask.base.get_scheduler(scheduler=scheduler) or dask.get
525 | return schedule(cached_dsk, keys)
526 |
--------------------------------------------------------------------------------
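A minimal usage sketch of the ``optimize``/``get`` pair defined in core.py above (editorial note, not part of the repository). The graph, the ``preprocess``/``train`` functions, and the expected values are illustrative; the ``cache_latency``/``cache_throughput`` config keys are set to force caching of even trivially cheap results, mirroring the tests further below.

# Sketch: caching a hand-written dask graph with graphchain.core.get.
import dask

from graphchain.core import get


def preprocess(x: int) -> int:
    return x + 1


def train(x: int) -> int:
    return 10 * x


# A plain dask task graph, following the dask graph spec referenced in the docstring above.
dsk = {"raw": 1, "clean": (preprocess, "raw"), "model": (train, "clean")}

# Force caching of even very cheap results (the tests below do the same).
dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})

# The first call computes and stores each task under ./__graphchain_cache__/ (the default
# location); a second call with an unchanged graph loads "model" straight from the cache.
assert get(dsk, "model") == 20
assert get(dsk, "model") == 20

--------------------------------------------------------------------------------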
/src/graphchain/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlinear-ai/graphchain/1146693720eb9a3077a342d6733106d3ec7de1ad/src/graphchain/py.typed
--------------------------------------------------------------------------------
/src/graphchain/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions used by graphchain."""
2 |
3 | import string
4 | import sys
5 | from typing import Any, Optional, Set
6 |
7 |
8 | def _fast_get_size(obj: Any) -> int:
9 | if hasattr(obj, "__len__") and len(obj) <= 0:
10 | return 0
11 | if hasattr(obj, "sample") and hasattr(obj, "memory_usage"): # DF, Series.
12 | n = min(len(obj), 1000)
13 | s = obj.sample(frac=n / len(obj)).memory_usage(index=True, deep=True)
14 | if hasattr(s, "sum"):
15 | s = s.sum()
16 | if hasattr(s, "compute"):
17 | s = s.compute()
18 | s = s / n * len(obj)
19 | return int(s)
20 | elif hasattr(obj, "nbytes"): # Numpy.
21 | return int(obj.nbytes)
22 | elif hasattr(obj, "data") and hasattr(obj.data, "nbytes"): # Sparse.
23 | return int(3 * obj.data.nbytes)
24 | raise TypeError("Could not determine size of the given object.")
25 |
26 |
27 | def _slow_get_size(obj: Any, seen: Optional[Set[Any]] = None) -> int:
28 | size = sys.getsizeof(obj)
29 | seen = seen or set()
30 | obj_id = id(obj)
31 | if obj_id in seen:
32 | return 0
33 | seen.add(obj_id)
34 | if isinstance(obj, dict):
35 | size += sum(get_size(v, seen) for v in obj.values())
36 | size += sum(get_size(k, seen) for k in obj.keys())
37 | elif hasattr(obj, "__dict__"):
38 | size += get_size(obj.__dict__, seen)
39 | elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)):
40 | size += sum(get_size(i, seen) for i in obj)
41 | return size
42 |
43 |
44 | def get_size(obj: Any, seen: Optional[Set[Any]] = None) -> int:
45 | """Recursively compute the size of an object.
46 |
47 | Parameters
48 | ----------
49 | obj
50 | The object to get the size of.
51 | seen
52 | A set of seen objects.
53 |
54 | Returns
55 | -------
56 | int
57 | The (approximate) size in bytes of the given object.
58 | """
59 | # Short-circuit some types.
60 | try:
61 | return _fast_get_size(obj)
62 | except TypeError:
63 | pass
64 | # General-purpose size computation.
65 | return _slow_get_size(obj, seen)
66 |
67 |
68 | def str_to_posix_fully_portable_filename(s: str) -> str:
69 | """Convert key to POSIX fully portable filename [1].
70 |
71 | Parameters
72 | ----------
73 | s
74 | The string to convert to a POSIX fully portable filename.
75 |
76 | Returns
77 | -------
78 | str
79 | A POSIX fully portable filename.
80 |
81 | References
82 | ----------
83 | .. [1] https://en.wikipedia.org/wiki/Filename
84 | """
85 | safechars = string.ascii_letters + string.digits + "._-"
86 | return "".join(c if c in safechars else "-" for c in s)
87 |
--------------------------------------------------------------------------------
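An illustrative sketch of the two helpers above (editorial note, not part of the repository; the sample values are made up and numpy is assumed to be available). ``get_size`` tries the fast per-type estimates first and falls back to the recursive ``sys.getsizeof``-based walk, while ``str_to_posix_fully_portable_filename`` maps arbitrary key strings onto cache-safe filenames.

# Sketch: exercising graphchain's size estimation and filename sanitisation.
import numpy as np

from graphchain.utils import get_size, str_to_posix_fully_portable_filename

# Fast path: numpy arrays expose .nbytes, so no recursion is needed.
assert get_size(np.zeros(1000, dtype=np.float64)) == 8000

# Slow path: plain containers are sized recursively via sys.getsizeof.
print(get_size({"a": [1, 2, 3], "b": "hello"}))  # approximate size in bytes

# Cache keys may contain characters that are unsafe in filenames;
# everything outside [A-Za-z0-9._-] is replaced with "-".
assert str_to_posix_fully_portable_filename("('x', 0)/v1") == "--x---0--v1"

--------------------------------------------------------------------------------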
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Graphchain module tests."""
2 |
--------------------------------------------------------------------------------
/tests/test_dask_dataframe.py:
--------------------------------------------------------------------------------
1 | """Test module for dask DataFrames."""
2 |
3 | import os
4 | from functools import partial
5 | from typing import Any
6 |
7 | import dask
8 | import dask.dataframe
9 | import fs.base
10 | import fs.osfs
11 | import joblib
12 | import pytest
13 | from dask.highlevelgraph import HighLevelGraph
14 |
15 | from graphchain.core import optimize
16 |
17 |
18 | @pytest.fixture()
19 | def dask_dataframe_graph() -> HighLevelGraph:
20 | """Generate an example dask DataFrame graph."""
21 |
22 | @dask.delayed(pure=True)
23 | def create_dataframe() -> dask.dataframe.DataFrame:
24 | df: dask.dataframe.DataFrame = dask.datasets.timeseries(seed=42)
25 | return df
26 |
27 | @dask.delayed(pure=True)
28 | def summarise_dataframe(df: dask.dataframe.DataFrame) -> float:
29 | value: float = df["x"].sum().compute() + df["y"].sum().compute()
30 | return value
31 |
32 | df = create_dataframe()
33 | result: HighLevelGraph = summarise_dataframe(df)
34 | return result
35 |
36 |
37 | def test_dask_dataframe_graph(dask_dataframe_graph: HighLevelGraph) -> None:
38 | """Test that the graph can be traversed and its result is correct."""
39 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
40 | with dask.config.set(scheduler="sync", delayed_optimize=optimize):
41 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined]
42 | assert result == 856.0466289487188
43 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined]
44 | assert result == 856.0466289487188
45 |
46 |
47 | def test_custom_serde(dask_dataframe_graph: HighLevelGraph) -> None:
48 | """Test that we can use a custom serializer/deserializer."""
49 |
50 | def custom_serialize(obj: Any, fs: fs.osfs.OSFS, key: str) -> None:
51 | if isinstance(obj, dask.dataframe.DataFrame):
52 | obj.to_parquet(os.path.join(fs.root_path, key))
53 | else:
54 | with fs.open(f"{key}.joblib", "wb") as fid:
55 | joblib.dump(obj, fid)
56 |
57 | def custom_deserialize(fs: fs.osfs.OSFS, key: str) -> Any:
58 | if fs.exists(f"{key}.joblib"):
59 | with fs.open(f"{key}.joblib", "rb") as fid:
60 | return joblib.load(fid)
61 | else:
62 | return dask.dataframe.read_parquet(os.path.join(fs.root_path, key))
63 |
64 | custom_optimize = partial(
65 | optimize,
66 | location="__graphchain_cache__/parquet/",
67 | serialize=custom_serialize,
68 | deserialize=custom_deserialize,
69 | )
70 |
71 | # Ensure everything gets cached.
72 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
73 |
74 | with dask.config.set(scheduler="sync", delayed_optimize=custom_optimize):
75 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined]
76 | assert result == 856.0466289487188
77 | result = dask_dataframe_graph.compute() # type: ignore[attr-defined]
78 | assert result == 856.0466289487188
79 |
--------------------------------------------------------------------------------
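The tests above register graphchain's ``optimize`` as dask's delayed-graph optimizer through a ``dask.config.set`` context manager. The following sketch applies the same pattern globally (editorial note, not part of the repository; the ``double`` function is illustrative).

# Sketch: route every dask.delayed computation through graphchain's cache.
import dask

from graphchain.core import optimize

# Register graphchain as the optimizer for delayed graphs and force caching.
dask.config.set(delayed_optimize=optimize)
dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})


@dask.delayed(pure=True)
def double(x: int) -> int:
    return 2 * x


# The first compute() executes and caches; the second loads the result from the cache.
assert double(21).compute(scheduler="sync") == 42
assert double(21).compute(scheduler="sync") == 42

--------------------------------------------------------------------------------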
/tests/test_graphchain.py:
--------------------------------------------------------------------------------
1 | """Test module for the graphchain core."""
2 |
3 | import functools
4 | import os
5 | import shutil
6 | import tempfile
7 | from typing import Any, Callable, Dict, Hashable, Iterable, Tuple, Union
8 |
9 | import dask
10 | import fs.osfs
11 | import pytest
12 |
13 | from graphchain.core import CachedComputation, optimize
14 |
15 |
16 | @pytest.fixture()
17 | def dask_graph() -> Dict[Hashable, Any]:
18 | r"""Generate an example dask graph.
19 |
20 | Will be used as a basis for the functional testing of the graphchain
21 | module::
22 |
23 | O top(..)
24 | ____|____
25 | / \
26 | d1 O baz(..)
27 | _________|________
28 | / \
29 | O boo(...) O goo(...)
30 | _______|_______ ____|____
31 | / | \ / | \
32 | O O O O | O
33 | foo(.) bar(.) baz(.) foo(.) v6 bar(.)
34 | | | | | |
35 | | | | | |
36 | v1 v2 v3 v4 v5
37 | """
38 | # Functions
39 | def foo(argument: int) -> int:
40 | return argument
41 |
42 | def bar(argument: int) -> int:
43 | return argument + 2
44 |
45 | def baz(*args: int) -> int:
46 | return sum(args)
47 |
48 | def boo(*args: int) -> int:
49 | return len(args) + sum(args)
50 |
51 | def goo(*args: int) -> int:
52 | return sum(args) + 1
53 |
54 | def top(argument: int, argument2: int) -> int:
55 | return argument - argument2
56 |
57 | # Graph (for the function definitions above)
58 | dsk = {
59 | "v0": None,
60 | "v1": 1,
61 | "v2": 2,
62 | "v3": 3,
63 | "v4": 0,
64 | "v5": -1,
65 | "v6": -2,
66 | "d1": -3,
67 | "foo1": (foo, "v1"),
68 | "foo2": (foo, "v4"),
69 | "bar1": (bar, "v2"),
70 | "bar2": (bar, "v5"),
71 | "baz1": (baz, "v3"),
72 | "baz2": (baz, "boo1", "goo1"),
73 | "boo1": (boo, "foo1", "bar1", "baz1"),
74 | "goo1": (goo, "foo2", "bar2", "v6"),
75 | "top1": (top, "d1", "baz2"),
76 | }
77 | return dsk # type: ignore[return-value]
78 |
79 |
80 | @pytest.fixture(scope="module")
81 | def temp_dir() -> str: # type: ignore[misc]
82 | """Create a temporary directory to store the cache in."""
83 | with tempfile.TemporaryDirectory(prefix="__graphchain_cache__") as tmpdir:
84 | yield tmpdir
85 |
86 |
87 | @pytest.fixture(scope="module")
88 | def temp_dir_s3() -> str:
89 | """Create the directory used for the graphchain tests on S3."""
90 | location = "s3://graphchain-test-bucket/__pytest_graphchain_cache__"
91 | return location
92 |
93 |
94 | def test_dag(dask_graph: Dict[Hashable, Any]) -> None:
95 | """Test that the graph can be traversed and its result is correct."""
96 | dsk = dask_graph
97 | result = dask.get(dsk, ["top1"])
98 | assert result == (-14,)
99 |
100 |
101 | @pytest.fixture()
102 | def optimizer(temp_dir: str) -> Tuple[str, Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]]:
103 | """Prefill the graphchain optimizer's parameters."""
104 | return temp_dir, functools.partial(optimize, location=temp_dir)
105 |
106 |
107 | @pytest.fixture()
108 | def optimizer_exec_only_nodes(
109 | temp_dir: str,
110 | ) -> Tuple[str, Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]]:
111 | """Prefill the graphchain optimizer's parameters."""
112 | return temp_dir, functools.partial(optimize, location=temp_dir, skip_keys=["boo1"])
113 |
114 |
115 | @pytest.fixture()
116 | def optimizer_s3(
117 | temp_dir_s3: str,
118 | ) -> Tuple[str, Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]]:
119 | """Prefill the graphchain optimizer's parameters."""
120 | return temp_dir_s3, functools.partial(optimize, location=temp_dir_s3)
121 |
122 |
123 | def test_first_run(
124 | dask_graph: Dict[Hashable, Any],
125 | optimizer: Tuple[
126 | str,
127 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
128 | ],
129 | ) -> None:
130 | """First run.
131 |
132 |     Tests a first run of the graphchain optimization function ``optimize``. It
133 |     checks the final result, that all function calls are wrapped for execution
134 |     and output storing, that the hashchain is created, that hashed outputs (the
135 |     .joblib[.lz4] files) are generated, and that the name of each file is a key
136 |     in the hashchain.
137 | """
138 | dsk = dask_graph
139 | cache_dir, graphchain_optimize = optimizer
140 |
141 | # Run optimizer
142 | newdsk = graphchain_optimize(dsk, ["top1"])
143 |
144 | # Check the final result
145 | result = dask.get(newdsk, ["top1"])
146 | assert result == (-14,)
147 |
148 | # Check that all functions have been wrapped
149 | for key, _task in dsk.items():
150 | newtask = newdsk[key]
151 | assert isinstance(newtask[0], CachedComputation)
152 |
153 | # Check that the hash files are written and that each
154 | # filename can be found as a key in the hashchain
155 | # (the association of hash <-> DAG tasks is not tested)
156 | storage = fs.osfs.OSFS(cache_dir)
157 | filelist = storage.listdir("/")
158 | nfiles = len(filelist)
159 | assert nfiles >= len(dsk)
160 | storage.close()
161 |
162 |
163 | @pytest.mark.skip(reason="Need AWS credentials to test")
164 | def test_single_run_s3(
165 | dask_graph: Dict[Hashable, Any],
166 | optimizer_s3: Tuple[
167 | str,
168 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
169 | ],
170 | ) -> None:
171 | """Run on S3.
172 |
173 |     Tests a single run of the graphchain optimization function ``optimize``
174 |     using Amazon S3 as a persistence layer. It checks the final result, that
175 |     all function calls are wrapped for execution and output storing, that the
176 |     hashchain is created, that hashed outputs (the .joblib[.lz4] files) are
177 |     generated, and that the name of each file is a key in the hashchain.
178 | """
179 | dsk = dask_graph
180 | cache_dir, graphchain_optimize = optimizer_s3
181 |
182 | # Run optimizer
183 | newdsk = graphchain_optimize(dsk, ["top1"])
184 |
185 | # Check the final result
186 | result = dask.get(newdsk, ["top1"])
187 | assert result == (-14,)
188 |
189 | data_ext = ".joblib.lz4"
190 |
191 | # Check that all functions have been wrapped
192 | for key, _task in dsk.items():
193 | newtask = newdsk[key]
194 |         assert isinstance(newtask[0], CachedComputation)
195 |
196 | # Check that the hash files are written and that each
197 | # filename can be found as a key in the hashchain
198 | # (the association of hash <-> DAG tasks is not tested)
199 | storage = fs.open_fs(cache_dir)
200 | filelist = storage.listdir("/")
201 | nfiles = sum(int(x.endswith(data_ext)) for x in filelist)
202 | assert nfiles == len(dsk)
203 |
204 |
205 | def test_second_run(
206 | dask_graph: Dict[Hashable, Any],
207 | optimizer: Tuple[
208 | str,
209 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
210 | ],
211 | ) -> None:
212 | """Second run.
213 |
214 |     Tests a second run of the graphchain optimization function ``optimize``. It
215 |     checks the final result, that all function calls are wrapped for loading,
216 |     and that the result key has no dependencies.
217 | """
218 | dsk = dask_graph
219 | _, graphchain_optimize = optimizer
220 |
221 | # Run optimizer
222 | newdsk = graphchain_optimize(dsk, ["top1"])
223 |
224 | # Check the final result
225 | result = dask.get(newdsk, ["top1"])
226 | assert result == (-14,)
227 |
228 | # Check that the functions are wrapped for loading
229 | for key in dsk.keys():
230 | newtask = newdsk[key]
231 | assert isinstance(newtask, tuple)
232 | assert isinstance(newtask[0], CachedComputation)
233 |
234 |
235 | def test_node_changes(
236 | dask_graph: Dict[Hashable, Any],
237 | optimizer: Tuple[
238 | str,
239 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
240 | ],
241 | ) -> None:
242 | """Test node changes.
243 |
244 |     Tests the functionality of graphchain in the event of changes to the
245 |     structure of the graph, namely by altering the functions/constants
246 |     associated with the tasks. After optimization, the affected nodes should be
247 |     wrapped in a store-and-execution wrapper and their dependency lists should
248 |     not be empty.
249 | """
250 | dsk = dask_graph
251 | _, graphchain_optimize = optimizer
252 |
253 | # Replacement function 'goo'
254 | def goo(*args: int) -> int:
255 | # hash miss!
256 | return sum(args) + 1
257 |
258 | # Replacement function 'top'
259 | def top(argument: int, argument2: int) -> int:
260 | # hash miss!
261 | return argument - argument2
262 |
263 | moddata = {
264 | "goo1": (goo, {"goo1", "baz2", "top1"}, (-14,)),
265 | # "top1": (top, {"top1"}, (-14,)),
266 | "top1": (lambda *args: -14, {"top1"}, (-14,)),
267 | "v2": (1000, {"v2", "bar1", "boo1", "baz2", "top1"}, (-1012,)),
268 | }
269 |
270 | for (modkey, (taskobj, _affected_nodes, result)) in moddata.items():
271 | workdsk = dsk.copy()
272 | if callable(taskobj):
273 | workdsk[modkey] = (taskobj, *dsk[modkey][1:])
274 | else:
275 | workdsk[modkey] = taskobj
276 |
277 | newdsk = graphchain_optimize(workdsk, ["top1"])
278 | assert result == dask.get(newdsk, ["top1"])
279 |
280 |
281 | def test_exec_only_nodes(
282 | dask_graph: Dict[Hashable, Any],
283 | optimizer_exec_only_nodes: Tuple[
284 | str,
285 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
286 | ],
287 | ) -> None:
288 | """Test skipping some tasks.
289 |
290 | Tests that execution-only nodes execute in the event that dependencies of
291 | their parent nodes (i.e. in the dask graph) get modified.
292 | """
293 | dsk = dask_graph
294 | cache_dir, graphchain_optimize = optimizer_exec_only_nodes
295 |
296 | # Cleanup temporary directory
297 | filelist = os.listdir(cache_dir)
298 | for entry in filelist:
299 | entrypath = os.path.join(cache_dir, entry)
300 | if os.path.isdir(entrypath):
301 | shutil.rmtree(entrypath, ignore_errors=True)
302 | else:
303 | os.remove(entrypath)
304 | filelist = os.listdir(cache_dir)
305 | assert not filelist
306 |
307 | # Run optimizer first time
308 | newdsk = graphchain_optimize(dsk, ["top1"])
309 | result = dask.get(newdsk, ["top1"])
310 | assert result == (-14,)
311 |
312 | # Modify function
313 | def goo(*args: int) -> int:
314 | # hash miss this!
315 | return sum(args) + 1
316 |
317 | dsk["goo1"] = (goo, *dsk["goo1"][1:])
318 |
319 | # Run optimizer a second time
320 | newdsk = graphchain_optimize(dsk, ["top1"])
321 |
322 | # Check the final result:
323 |     # The output of node 'boo1' is needed at node 'baz2'
324 |     # because 'goo1' was modified. A matching result indicates
325 |     # that the 'boo1' node was executed and its dependencies loaded,
326 |     # which is the desired behaviour in such cases.
327 | result = dask.get(newdsk, ["top1"])
328 | assert result == (-14,)
329 |
330 |
331 | def test_cache_deletion(
332 | dask_graph: Dict[Hashable, Any],
333 | optimizer: Tuple[
334 | str,
335 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
336 | ],
337 | ) -> None:
338 | """Test cache deletion.
339 |
340 | Tests the ability to obtain results in the event that cache files are
341 |     deleted (in the event of a cache-miss, the exec-store wrapper should be
342 | re-run by the load-wrapper).
343 | """
344 | dsk = dask_graph
345 | cache_dir, graphchain_optimize = optimizer
346 | storage = fs.osfs.OSFS(cache_dir)
347 |
348 | # Cleanup first
349 | storage.removetree("/")
350 |
351 | # Run optimizer (first time)
352 | newdsk = graphchain_optimize(dsk, ["top1"])
353 | result = dask.get(newdsk, ["top1"])
354 |
355 | newdsk = graphchain_optimize(dsk, ["top1"])
356 | result = dask.get(newdsk, ["top1"])
357 |
358 | # Check the final result
359 | assert result == (-14,)
360 |
361 |
362 | def test_identical_nodes(
363 | optimizer: Tuple[
364 | str,
365 | Callable[[Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]], Dict[Hashable, Any]],
366 | ]
367 | ) -> None:
368 | """Small test for the presence of identical nodes."""
369 | cache_dir, graphchain_optimize = optimizer
370 |
371 | def foo(x: int) -> int:
372 | return x + 1
373 |
374 | def bar(*args: int) -> int:
375 | return sum(args)
376 |
377 | dsk = {"foo1": (foo, 1), "foo2": (foo, 1), "top1": (bar, "foo1", "foo2")}
378 |
379 | # First run
380 | newdsk = graphchain_optimize(dsk, ["top1"]) # type: ignore[arg-type]
381 | result = dask.get(newdsk, ["top1"])
382 | assert result == (4,)
383 |
384 | # Second run
385 | newdsk = graphchain_optimize(dsk, ["top1"]) # type: ignore[arg-type]
386 | result = dask.get(newdsk, ["top1"])
387 | assert result == (4,)
388 |
--------------------------------------------------------------------------------
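The ``optimizer_exec_only_nodes`` fixture above exercises ``skip_keys``, which marks tasks that should always be executed rather than cached. A sketch of the same option on a toy graph (editorial note, not part of the repository; the functions and key names are illustrative).

# Sketch: exclude selected keys from the cache with skip_keys.
from graphchain.core import get


def load(path: str) -> str:
    return path  # stand-in for a read you never want to cache


def transform(text: str) -> str:
    return text.upper()


dsk = {"raw": (load, "data.csv"), "clean": (transform, "raw")}

# "raw" is re-executed on every run; only "clean" may be written to the default
# ./__graphchain_cache__/ location (subject to the "auto" caching heuristic).
assert get(dsk, "clean", skip_keys={"raw"}) == "DATA.CSV"

--------------------------------------------------------------------------------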
/tests/test_high_level_graph.py:
--------------------------------------------------------------------------------
1 | """Test module for the dask HighLevelGraphs."""
2 |
3 | from functools import partial
4 | from typing import Any, Dict, cast
5 |
6 | import dask
7 | import dask.dataframe as dd
8 | import fs.base
9 | import pandas as pd
10 | import pytest
11 | from dask.highlevelgraph import HighLevelGraph
12 | from fs.memoryfs import MemoryFS
13 |
14 | from graphchain.core import optimize
15 |
16 |
17 | @pytest.fixture()
18 | def dask_high_level_graph() -> HighLevelGraph:
19 | """Generate an example dask HighLevelGraph."""
20 |
21 | @dask.delayed(pure=True)
22 | def create_dataframe(num_rows: int, num_cols: int) -> pd.DataFrame:
23 | return pd.DataFrame(data=[range(num_cols)] * num_rows)
24 |
25 | @dask.delayed(pure=True)
26 | def create_dataframe2(num_rows: int, num_cols: int) -> pd.DataFrame:
27 | return pd.DataFrame(data=[range(num_cols)] * num_rows)
28 |
29 | @dask.delayed(pure=True)
30 | def complicated_computation(df: pd.DataFrame, num_quantiles: int) -> pd.DataFrame:
31 | return df.quantile(q=[i / num_quantiles for i in range(num_quantiles)])
32 |
33 | @dask.delayed(pure=True)
34 | def summarise_dataframes(*dfs: pd.DataFrame) -> float:
35 | return sum(cast("pd.Series[float]", df.sum()).sum() for df in dfs)
36 |
37 | df_a = create_dataframe(1000, 1000)
38 | df_b = create_dataframe2(1000, 1000)
39 | df_c = complicated_computation(df_a, 2048)
40 | df_d = complicated_computation(df_b, 2048)
41 | result: HighLevelGraph = summarise_dataframes(df_c, df_d)
42 | return result
43 |
44 |
45 | def test_high_level_dag(dask_high_level_graph: HighLevelGraph) -> None:
46 | """Test that the graph can be traversed and its result is correct."""
47 | with dask.config.set(scheduler="sync"):
48 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
49 | assert result == 2045952000.0
50 |
51 |
52 | def test_high_level_graph(dask_high_level_graph: HighLevelGraph) -> None:
53 | """Test that the graph can be traversed and its result is correct."""
54 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
55 | with dask.config.set(scheduler="sync", delayed_optimize=optimize):
56 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
57 | assert result == 2045952000.0
58 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
59 | assert result == 2045952000.0
60 |
61 |
62 | def test_high_level_graph_optimize() -> None:
63 | """Test that we can handle the case where a `HighLevelGraph` is passed directly to `optimize()`."""
64 | with dask.config.set(
65 | dataframe_optimize=optimize,
66 | scheduler="sync",
67 | ):
68 | df = dd.from_pandas(
69 | pd.DataFrame({"a": range(1000), "b": range(1000, 2000)}), npartitions=2
70 | ).set_index("b")
71 | computed = df.compute()
72 | assert computed.at[1000, "a"] == 0
73 |
74 |
75 | def test_high_level_graph_parallel(dask_high_level_graph: HighLevelGraph) -> None:
76 | """Test that the graph can be traversed and its result is correct when using parallel scheduler."""
77 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
78 | with dask.config.set(scheduler="processes", delayed_optimize=optimize):
79 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
80 | assert result == 2045952000.0
81 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
82 | assert result == 2045952000.0
83 |
84 |
85 | def test_custom_serde(dask_high_level_graph: HighLevelGraph) -> None:
86 | """Test that we can use a custom serializer/deserializer."""
87 | custom_cache: Dict[str, Any] = {}
88 |
89 | def custom_serialize(obj: Any, fs: fs.base.FS, key: str) -> None:
90 | # Write the key itself to the filesystem.
91 |         with fs.open(f"{key}.dat", "w") as fid:
92 | fid.write(key)
93 | # Store the actual result in an in-memory cache.
94 |         custom_cache[key] = obj
95 |
96 | def custom_deserialize(fs: fs.base.FS, key: str) -> Any:
97 | # Verify that we have written the key to the filesystem.
98 |         with fs.open(f"{key}.dat", "r") as fid:
99 | assert key == fid.read()
100 | # Get the result corresponding to that key.
101 | return custom_cache[key]
102 |
103 | # Use a custom location so that we don't corrupt the default cache.
104 | custom_optimize = partial(
105 | optimize,
106 | location=MemoryFS(),
107 | serialize=custom_serialize,
108 | deserialize=custom_deserialize,
109 | )
110 |
111 | # Ensure everything gets cached.
112 | dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})
113 |
114 | with dask.config.set(scheduler="sync", delayed_optimize=custom_optimize):
115 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
116 | assert result == 2045952000.0
117 | result = dask_high_level_graph.compute() # type: ignore[attr-defined]
118 | assert result == 2045952000.0
119 |
--------------------------------------------------------------------------------
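The MemoryFS-based test above also shows that ``location`` accepts a PyFilesystem object directly instead of an FS URL string. A short sketch of that option (editorial note, not part of the repository; the ``square`` function and keys are illustrative).

# Sketch: point the cache at a PyFilesystem instance instead of a local path.
import dask
from fs.memoryfs import MemoryFS

from graphchain.core import get


def square(x: int) -> int:
    return x * x


dsk = {"x": 7, "y": (square, "x")}

# Force caching of small results, as the tests above do.
dask.config.set({"cache_latency": 0, "cache_throughput": float("inf")})

cache = MemoryFS()  # any fs.base.FS works, e.g. an S3 filesystem for a shared remote cache
assert get(dsk, "y", location=cache) == 49
assert get(dsk, "y", location=cache) == 49  # the second call is read back from the cache

--------------------------------------------------------------------------------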