├── requirements.txt
├── CHANGELOG.md
├── MANIFEST.in
├── compile_snowflake.sh
├── dask_snowflake
│   ├── __init__.py
│   ├── core.py
│   └── tests
│       └── test_core.py
├── .github
│   └── workflows
│       ├── pre-commit.yml
│       ├── tests.yml
│       └── wheels.yml
├── ci
│   ├── environment-3.10.yaml
│   ├── environment-3.11.yaml
│   ├── environment-3.12.yaml
│   ├── environment-3.13.yaml
│   └── environment-3.9.yaml
├── setup.cfg
├── .pre-commit-config.yaml
├── setup.py
├── Dockerfile
├── LICENSE
├── .gitignore
└── README.md

/requirements.txt:
--------------------------------------------------------------------------------
1 | dask>=2024.3.0
2 | distributed
3 | snowflake-connector-python[pandas]>=2.6.0
4 | snowflake-sqlalchemy
5 | 
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | Change log
2 | ==========
3 | 
4 | 0.3.4 (2025-08-04)
5 | ------------------
6 | 
7 | - python 3.13 support
8 | - pandas >= 2.2 support
9 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include dask_snowflake *.py
2 | 
3 | include setup.py
4 | include README.md
5 | include LICENSE
6 | include MANIFEST.in
7 | include requirements.txt
--------------------------------------------------------------------------------
/compile_snowflake.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | export SF_ARROW_LIBDIR=${CONDA_PREFIX}/lib
4 | export SF_NO_COPY_ARROW_LIB=1
5 | printenv
6 | 
7 | python setup.py build_ext
8 | python setup.py bdist_wheel
9 | 
10 | 
11 | 
--------------------------------------------------------------------------------
/dask_snowflake/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import PackageNotFoundError, version
2 | 
3 | from .core import read_snowflake, to_snowflake
4 | 
5 | try:
6 |     __version__ = version(__name__)
7 | except PackageNotFoundError:
8 |     __version__ = "unknown"
9 | 
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 | 
3 | on: [push]
4 | 
5 | jobs:
6 |   checks:
7 |     name: "pre-commit hooks"
8 |     runs-on: ubuntu-latest
9 |     steps:
10 |       - uses: actions/checkout@v2
11 |       - uses: actions/setup-python@v2
12 |       - uses: pre-commit/action@v3.0.1
13 | 
--------------------------------------------------------------------------------
/ci/environment-3.10.yaml:
--------------------------------------------------------------------------------
1 | name: test-environment
2 | channels:
3 |   - conda-forge
4 | dependencies:
5 |   # Required
6 |   - python=3.10
7 |   - dask
8 |   - distributed
9 |   - pandas
10 |   - pyarrow
11 |   - snowflake-connector-python >=2.6.0
12 |   - snowflake-sqlalchemy
13 |   # Testing
14 |   - pytest
15 | 
--------------------------------------------------------------------------------
/ci/environment-3.11.yaml:
--------------------------------------------------------------------------------
1 | name: test-environment
2 | channels:
3 |   - conda-forge
4 | dependencies:
5 |   # Required
6 |   - python=3.11
7 |   - dask
8 |   - distributed
9 |   - pandas
10 |   - pyarrow
11 |   - snowflake-connector-python >=2.6.0
12 |   - snowflake-sqlalchemy
13 |   # Testing
14 |   - pytest
15 | 
--------------------------------------------------------------------------------
/ci/environment-3.12.yaml:
--------------------------------------------------------------------------------
1 | name: test-environment
2 | channels:
3 |   - conda-forge
4 | dependencies:
5 |   # Required
6 |   - python=3.12
7 |   - dask
8 |   - distributed
9 |   - pandas
10 |   - pyarrow
11 |   - snowflake-connector-python >=2.6.0
12 |   - snowflake-sqlalchemy
13 |   # Testing
14 |   - pytest
15 | 
--------------------------------------------------------------------------------
/ci/environment-3.13.yaml:
--------------------------------------------------------------------------------
1 | name: test-environment
2 | channels:
3 |   - conda-forge
4 | dependencies:
5 |   # Required
6 |   - python=3.13
7 |   - dask
8 |   - distributed
9 |   - pandas
10 |   - pyarrow
11 |   - snowflake-connector-python >=2.6.0
12 |   - snowflake-sqlalchemy
13 |   # Testing
14 |   - pytest
15 | 
--------------------------------------------------------------------------------
/ci/environment-3.9.yaml:
--------------------------------------------------------------------------------
1 | name: test-environment
2 | channels:
3 |   - conda-forge
4 | dependencies:
5 |   # Required
6 |   - python=3.9
7 |   - dask
8 |   - distributed
9 |   - pandas
10 |   - pyarrow
11 |   - snowflake-connector-python >=2.6.0
12 |   - snowflake-sqlalchemy
13 |   # Testing
14 |   - pytest
15 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | exclude = __init__.py
3 | max-line-length = 120
4 | 
5 | [isort]
6 | sections = FUTURE,STDLIB,THIRDPARTY,DASK,FIRSTPARTY,LOCALFOLDER
7 | profile = black
8 | skip_gitignore = true
9 | force_to_top = true
10 | default_section = THIRDPARTY
11 | known_first_party = dask_snowflake
12 | known_dask = dask,distributed
13 | 
14 | [tool:pytest]
15 | addopts = -v -rsxfE --durations=10 --color=yes
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/psf/black
3 |     rev: 23.3.0
4 |     hooks:
5 |       - id: black
6 |         language_version: python3
7 |         exclude: versioneer.py
8 |   - repo: https://github.com/keewis/blackdoc
9 |     rev: v0.3.8
10 |     hooks:
11 |       - id: blackdoc
12 |   - repo: https://github.com/pycqa/flake8
13 |     rev: 6.0.0
14 |     hooks:
15 |       - id: flake8
16 |         language_version: python3
17 |   - repo: https://github.com/pycqa/isort
18 |     rev: 5.12.0
19 |     hooks:
20 |       - id: isort
21 |         language_version: python3
22 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | from setuptools import setup
4 | 
5 | setup(
6 |     name="dask-snowflake",
7 |     use_scm_version=True,
8 |     setup_requires=["setuptools_scm"],
9 |     description="Dask + Snowflake integration",
10 |     license="BSD",
11 |     maintainer="James Bourbeau",
12 |     maintainer_email="james@coiled.io",
13 |     packages=["dask_snowflake"],
14 |     long_description=open("README.md").read(),
15 |     long_description_content_type="text/markdown",
16 |     python_requires=">=3.9",
17 |     install_requires=open("requirements.txt").read().strip().split("\n"),
18 |     include_package_data=True,
19 |     zip_safe=False,
20 | )
21 | 
--------------------------------------------------------------------------------
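Note that `setup.py` above does not hard-code a version string: `use_scm_version=True` lets `setuptools_scm` derive the version from git tags at build time, and `dask_snowflake/__init__.py` (shown earlier) reads it back from the installed package metadata. A minimal sketch for checking which version ended up installed (assuming the package has been installed, for example with `pip install -e .`; the `print` calls are only for illustration):

```python
# Query the installed distribution metadata directly; this mirrors what
# dask_snowflake/__init__.py does at import time.
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("dask-snowflake"))
except PackageNotFoundError:
    print("dask-snowflake is not installed in this environment")
```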
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM mambaorg/micromamba:0.11.3
2 | 
3 | RUN mkdir /workspace
4 | 
5 | 
6 | RUN micromamba install --yes \
7 |     -c conda-forge \
8 |     nomkl \
9 |     git \
10 |     python=3.8 \
11 |     cython \
12 |     coiled \
13 |     c-compiler \
14 |     cxx-compiler \
15 |     dask \
16 |     distributed \
17 |     xgboost \
18 |     dask-ml \
19 |     xarray \
20 |     pyarrow \
21 |     tini \
22 |     && \
23 |     micromamba clean --all --yes
24 | 
25 | 
26 | WORKDIR /workspace
27 | RUN git clone https://github.com/snowflakedb/snowflake-connector-python.git
28 | WORKDIR /workspace/snowflake-connector-python
29 | RUN git checkout parallel-fetch-prpr
30 | 
31 | SHELL ["/bin/bash", "-c"]
32 | COPY compile_snowflake.sh compile_snowflake.sh
33 | RUN ./compile_snowflake.sh \
34 |     && rm -rf build \
35 |     && pip install dist/*.whl \
36 |     && rm -rf dist \
37 |     && rm compile_snowflake.sh
38 | 
39 | RUN mkdir /workspace/dask_snowflake
40 | 
41 | WORKDIR /workspace
42 | COPY * dask_snowflake/
43 | RUN pip install ./dask_snowflake
44 | 
45 | 
46 | ENTRYPOINT ["tini", "-g", "--"]
47 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2021, Coiled
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 |    list of conditions and the following disclaimer.
11 | 
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 |    this list of conditions and the following disclaimer in the documentation
14 |    and/or other materials provided with the distribution.
15 | 
16 | 3. Neither the name of the copyright holder nor the names of its
17 |    contributors may be used to endorse or promote products derived from
18 |    this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | workflow_dispatch: 6 | jobs: 7 | test: 8 | runs-on: ${{ matrix.os }} 9 | defaults: 10 | run: 11 | shell: bash -l {0} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: ["windows-latest", "ubuntu-latest", "macos-latest"] 16 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 17 | exclude: 18 | # Python 3.11 build on macOS times out for some reason 19 | # xref https://github.com/dask-contrib/dask-snowflake/pull/56 20 | - os: macos-latest 21 | python-version: "3.11" 22 | steps: 23 | - name: Checkout source 24 | uses: actions/checkout@v4 25 | 26 | - name: Setup Conda Environment 27 | uses: conda-incubator/setup-miniconda@v3 28 | with: 29 | miniforge-version: latest 30 | use-mamba: true 31 | channel-priority: strict 32 | python-version: ${{ matrix.python-version }} 33 | environment-file: ci/environment-${{ matrix.python-version }}.yaml 34 | activate-environment: test-environment 35 | auto-activate-base: false 36 | 37 | - name: Install dask_snowflake 38 | run: python -m pip install -e . 39 | 40 | - name: Run tests 41 | env: 42 | SNOWFLAKE_USER: ${{ secrets.SNOWFLAKE_USER }} 43 | SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }} 44 | SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} 45 | SNOWFLAKE_WAREHOUSE: ${{ secrets.SNOWFLAKE_WAREHOUSE }} 46 | SNOWFLAKE_ROLE: ${{ secrets.SNOWFLAKE_ROLE }} 47 | run: pytest dask_snowflake -------------------------------------------------------------------------------- /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build and maybe upload to PyPI 2 | 3 | on: 4 | push: 5 | pull_request: 6 | release: 7 | types: 8 | - released 9 | - prereleased 10 | 11 | jobs: 12 | artifacts: 13 | name: Build wheels on ${{ matrix.os }} 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | with: 23 | fetch-depth: 0 24 | - uses: actions/setup-python@v4 25 | with: 26 | python-version: '3.10' 27 | - name: Build wheels 28 | run: pip wheel . 
-w dist 29 | - name: Build Source Dist 30 | run: python setup.py sdist 31 | - uses: actions/upload-artifact@v4 32 | with: 33 | name: wheel 34 | path: ./dist/dask_snowflake* 35 | - uses: actions/upload-artifact@v4 36 | with: 37 | name: sdist 38 | path: ./dist/dask-snowflake* 39 | 40 | list_artifacts: 41 | name: List build artifacts 42 | needs: [artifacts] 43 | runs-on: ubuntu-latest 44 | steps: 45 | - uses: actions/download-artifact@v4 46 | with: 47 | name: sdist 48 | path: dist 49 | - uses: actions/download-artifact@v4 50 | with: 51 | name: wheel 52 | path: dist 53 | - name: test 54 | run: | 55 | ls 56 | ls dist 57 | 58 | upload_pypi: 59 | needs: [artifacts] 60 | if: "startsWith(github.ref, 'refs/tags/')" 61 | runs-on: ubuntu-latest 62 | environment: 63 | name: releases 64 | url: https://pypi.org/p/dask-snowflake 65 | permissions: 66 | id-token: write 67 | steps: 68 | - uses: actions/download-artifact@v4 69 | with: 70 | name: sdist 71 | path: dist 72 | - uses: actions/download-artifact@v4 73 | with: 74 | name: wheel 75 | path: dist 76 | - uses: pypa/gh-action-pypi-publish@release/v1 77 | with: 78 | packages-dir: dist 79 | skip-existing: true 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dask-worker-space/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
97 | __pypackages__/
98 | 
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 | 
103 | # SageMath parsed files
104 | *.sage.py
105 | 
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 | 
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 | 
119 | # Rope project settings
120 | .ropeproject
121 | 
122 | # VS Code
123 | .vscode
124 | 
125 | # mkdocs documentation
126 | /site
127 | 
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 | 
133 | # Pyre type checker
134 | .pyre/
135 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dask-Snowflake
2 | 
3 | [![Tests](https://github.com/dask-contrib/dask-snowflake/actions/workflows/tests.yml/badge.svg)](https://github.com/dask-contrib/dask-snowflake/actions/workflows/tests.yml)
4 | [![Linting](https://github.com/dask-contrib/dask-snowflake/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/dask-contrib/dask-snowflake/actions/workflows/pre-commit.yml)
5 | 
6 | ## Installation
7 | 
8 | `dask-snowflake` can be installed with `pip`:
9 | 
10 | ```shell
11 | pip install dask-snowflake
12 | ```
13 | 
14 | or with `conda`:
15 | 
16 | ```shell
17 | conda install -c conda-forge dask-snowflake
18 | ```
19 | 
20 | ## Usage
21 | 
22 | `dask-snowflake` provides `read_snowflake` and `to_snowflake` functions
23 | for parallel IO to and from Snowflake with Dask.
24 | 
25 | ```python
26 | >>> from dask_snowflake import read_snowflake
27 | >>> example_query = '''
28 | ...    SELECT *
29 | ...    FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.CUSTOMER;
30 | ... '''
31 | >>> ddf = read_snowflake(
32 | ...     query=example_query,
33 | ...     connection_kwargs={
34 | ...         "user": "...",
35 | ...         "password": "...",
36 | ...         "account": "...",
37 | ...     },
38 | ... )
39 | ```
40 | 
41 | ```python
42 | >>> from dask_snowflake import to_snowflake
43 | >>> to_snowflake(
44 | ...     ddf,
45 | ...     name="my_table",
46 | ...     connection_kwargs={
47 | ...         "user": "...",
48 | ...         "password": "...",
49 | ...         "account": "...",
50 | ...     },
51 | ... )
52 | ```
53 | 
54 | See their docstrings for further API information.
55 | 
56 | ## Tests
57 | 
58 | Running tests requires a Snowflake account and access to a database.
59 | The test suite will automatically look for specific `SNOWFLAKE_*`
60 | environment variables (listed below) that must be set.
61 | 
62 | It's recommended (though not required) to store these environment variables
63 | in a local `.env` file in the root of the `dask-snowflake` repository.
64 | This file will be automatically ignored by `git`, reducing the risk of accidentally
65 | committing it.
66 | 
67 | Here's what an example `.env` file looks like:
68 | 
69 | ```env
70 | SNOWFLAKE_USER=""
71 | SNOWFLAKE_PASSWORD=""
72 | SNOWFLAKE_ACCOUNT="..aws"
73 | SNOWFLAKE_WAREHOUSE=""
74 | SNOWFLAKE_ROLE=""
75 | SNOWFLAKE_DATABASE=""
76 | SNOWFLAKE_SCHEMA=""
77 | ```
78 | 
79 | You may then `source .env` or install [`pytest-dotenv`](https://github.com/quiqua/pytest-dotenv)
80 | to automatically set these environment variables.
81 | 
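For reference, once these variables are set (via `source .env`, `pytest-dotenv`, or CI secrets), they can be assembled into the `connection_kwargs` dictionary that `read_snowflake` and `to_snowflake` expect; this is essentially what the `connection_kwargs` fixture in `dask_snowflake/tests/test_core.py` does. A minimal sketch:

```python
import os

connection_kwargs = {
    "user": os.environ["SNOWFLAKE_USER"],
    "password": os.environ["SNOWFLAKE_PASSWORD"],
    "account": os.environ["SNOWFLAKE_ACCOUNT"],
    "warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
    "role": os.environ["SNOWFLAKE_ROLE"],
    # Optional in the test suite; these defaults are used there when unset.
    "database": os.environ.get("SNOWFLAKE_DATABASE", "testdb"),
    "schema": os.environ.get("SNOWFLAKE_SCHEMA", "public"),
}
```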
82 | > **_Note:_**
83 | > If you run the tests and get a `MemoryError` mentioning
84 | > "write+execute memory for ffi.callback()", you probably have a stale
85 | > build of `cffi` from conda-forge. Remove it and install the version
86 | > using `pip`:
87 | >
88 | > ```shell
89 | > conda remove cffi --force
90 | > pip install cffi
91 | > ```
92 | 
93 | ## License
94 | 
95 | [BSD-3](LICENSE)
96 | 
--------------------------------------------------------------------------------
/dask_snowflake/core.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from functools import partial
4 | from typing import Optional, Sequence
5 | 
6 | import pandas as pd
7 | import pyarrow as pa
8 | import snowflake.connector
9 | from snowflake.connector.pandas_tools import pd_writer, write_pandas
10 | from snowflake.connector.result_batch import ArrowResultBatch
11 | from snowflake.sqlalchemy import URL
12 | from sqlalchemy import create_engine
13 | 
14 | import dask
15 | import dask.dataframe as dd
16 | from dask.delayed import delayed
17 | from dask.utils import parse_bytes
18 | 
19 | 
20 | @delayed
21 | def write_snowflake(
22 |     df: pd.DataFrame,
23 |     name: str,
24 |     connection_kwargs: dict,
25 |     write_pandas_kwargs: Optional[dict] = None,
26 | ):
27 |     connection_kwargs = {
28 |         **{"application": dask.config.get("snowflake.partner", "dask")},
29 |         **connection_kwargs,
30 |     }
31 |     with snowflake.connector.connect(**connection_kwargs) as conn:
32 |         write_pandas(
33 |             conn=conn,
34 |             df=df,
35 |             schema=connection_kwargs.get("schema", None),
36 |             # NOTE: since ensure_db_exists uses uppercase for the table name
37 |             table_name=name.upper(),
38 |             quote_identifiers=False,
39 |             **(write_pandas_kwargs or {}),
40 |         )
41 | 
42 | 
43 | @delayed
44 | def ensure_db_exists(
45 |     df: pd.DataFrame,
46 |     name: str,
47 |     connection_kwargs,
48 | ):
49 |     connection_kwargs = {
50 |         **{"application": dask.config.get("snowflake.partner", "dask")},
51 |         **connection_kwargs,
52 |     }
53 |     # NOTE: we have a separate `ensure_db_exists` function in order to use
54 |     # pandas' `to_sql` which will create a table if the requested one doesn't
55 |     # already exist. However, we don't always want to use Snowflake's `pd_writer`
56 |     # approach because it doesn't allow us to disable parallel file uploading.
57 |     # For these cases we use a separate `write_snowflake` function.
58 |     engine = create_engine(URL(**connection_kwargs))
59 |     # NOTE: pd_writer will automatically uppercase the table name
60 |     df.to_sql(
61 |         name=name,
62 |         schema=connection_kwargs.get("schema", None),
63 |         con=engine,
64 |         index=False,
65 |         if_exists="append",
66 |         method=pd_writer,
67 |     )
68 | 
69 | 
70 | def to_snowflake(
71 |     df: dd.DataFrame,
72 |     name: str,
73 |     connection_kwargs: dict,
74 |     write_pandas_kwargs: Optional[dict] = None,
75 |     compute: bool = True,
76 | ):
77 |     """Write a Dask DataFrame to a Snowflake table.
78 | 
79 |     Parameters
80 |     ----------
81 |     df:
82 |         Dask DataFrame to save.
83 |     name:
84 |         Name of the table to save to.
85 |     connection_kwargs:
86 |         Connection arguments used when connecting to Snowflake with
87 |         ``snowflake.connector.connect``.
88 |     compute:
89 |         Whether or not to compute immediately. If ``True``, write DataFrame
90 |         partitions to Snowflake immediately. If ``False``, return a list of
91 |         delayed objects that can be computed later. Defaults to ``True``.
92 |     write_pandas_kwargs:
93 |         Additional keyword arguments that will be passed to ``snowflake.connector.pandas_tools.write_pandas``.
94 |     Examples
95 |     --------
96 | 
97 |     >>> from dask_snowflake import to_snowflake
98 |     >>> df = ...  # Create a Dask DataFrame
99 |     >>> to_snowflake(
100 |     ...     df,
101 |     ... 
name="my_table", 102 | ... connection_kwargs={ 103 | ... "user": "...", 104 | ... "password": "...", 105 | ... "account": "...", 106 | ... }, 107 | ... ) 108 | 109 | """ 110 | # Write the DataFrame meta to ensure table exists before 111 | # trying to write all partitions in parallel. Otherwise 112 | # we run into race conditions around creating a new table. 113 | # Also, some clusters will overwrite the `snowflake.partner` configuration value. 114 | # We run `ensure_db_exists` on the cluster to ensure we capture the 115 | # right partner application ID. 116 | ensure_db_exists(df._meta, name, connection_kwargs).compute() 117 | parts = [ 118 | write_snowflake(partition, name, connection_kwargs, write_pandas_kwargs) 119 | for partition in df.to_delayed() 120 | ] 121 | if compute: 122 | dask.compute(parts) 123 | else: 124 | return parts 125 | 126 | 127 | def _fetch_batches(chunks: list[ArrowResultBatch], arrow_options: dict): 128 | return pa.concat_tables([chunk.to_arrow() for chunk in chunks]).to_pandas( 129 | **arrow_options 130 | ) 131 | 132 | 133 | @delayed 134 | def _fetch_query_batches(query, connection_kwargs, execute_params): 135 | connection_kwargs = { 136 | **{"application": dask.config.get("snowflake.partner", "dask")}, 137 | **connection_kwargs, 138 | } 139 | with snowflake.connector.connect(**connection_kwargs) as conn: 140 | with conn.cursor() as cur: 141 | cur.check_can_use_pandas() 142 | cur.check_can_use_arrow_resultset() 143 | cur.execute(query, execute_params) 144 | batches = cur.get_result_batches() 145 | 146 | return [b for b in batches if b.rowcount > 0] 147 | 148 | 149 | def _partition_batches( 150 | batches: list[ArrowResultBatch], 151 | meta: pd.DataFrame, 152 | npartitions: None | int = None, 153 | partition_size: None | str | int = None, 154 | ) -> list[list[ArrowResultBatch]]: 155 | """ 156 | Given a list of batches and a sample, partition the batches into dask dataframe 157 | partitions. 158 | 159 | Batch sizing is seemingly not under our control, and is typically much smaller 160 | than the optimal partition size: 161 | https://docs.snowflake.com/en/user-guide/python-connector-distributed-fetch.html 162 | So instead batch the batches into partitions of approximately the right size. 
163 |     """
164 |     if (npartitions is None) is (partition_size is None):
165 |         raise ValueError(
166 |             "Must provide exactly one of `npartitions` or `partition_size`"
167 |         )
168 | 
169 |     if npartitions is not None:
170 |         assert npartitions >= 1
171 |         target = sum([b.rowcount for b in batches]) // npartitions
172 |     elif partition_size is not None:
173 |         partition_bytes = (
174 |             parse_bytes(partition_size)
175 |             if isinstance(partition_size, str)
176 |             else partition_size
177 |         )
178 |         approx_row_size = meta.memory_usage().sum() / len(meta)
179 |         target = max(partition_bytes / approx_row_size, 1)
180 |     else:
181 |         assert False  # unreachable
182 | 
183 |     batches_partitioned: list[list[ArrowResultBatch]] = []
184 |     curr: list[ArrowResultBatch] = []
185 |     partition_len = 0
186 |     for batch in batches:
187 |         if len(curr) > 0 and batch.rowcount + partition_len > target:
188 |             batches_partitioned.append(curr)
189 |             curr = [batch]
190 |             partition_len = batch.rowcount
191 |         else:
192 |             curr.append(batch)
193 |             partition_len += batch.rowcount
194 |     if curr:
195 |         batches_partitioned.append(curr)
196 | 
197 |     return batches_partitioned
198 | 
199 | 
200 | def read_snowflake(
201 |     query: str,
202 |     *,
203 |     connection_kwargs: dict,
204 |     arrow_options: dict | None = None,
205 |     execute_params: Sequence | dict | None = None,
206 |     partition_size: str | int | None = None,
207 |     npartitions: int | None = None,
208 | ) -> dd.DataFrame:
209 |     """Load a Dask DataFrame based on the result of a Snowflake query.
210 | 
211 |     Parameters
212 |     ----------
213 |     query:
214 |         The Snowflake query to execute.
215 |     connection_kwargs:
216 |         Connection arguments used when connecting to Snowflake with
217 |         ``snowflake.connector.connect``.
218 |     arrow_options:
219 |         Optional arguments forwarded to ``arrow.Table.to_pandas`` when
220 |         converting data to a pandas DataFrame.
221 |     execute_params:
222 |         Optional query parameters to pass to Snowflake's ``Cursor.execute(...)``
223 |         method.
224 |     partition_size: int or str
225 |         Approximate size of each partition in the target Dask DataFrame. Either
226 |         an integer number of bytes, or a string description like "100 MiB".
227 |         Reasonable values are often around a few hundred MiB per partition. At most
228 |         one of this and ``npartitions`` may be provided; if neither is given, a
229 |         default of ``"100MiB"`` is used. Partitioning is approximate, and your actual
230 |         partition sizes may vary.
231 |     npartitions: int
232 |         An integer number of partitions for the target Dask DataFrame. At most one
233 |         of this and ``partition_size`` may be provided. Partitioning is approximate,
234 |         and your actual number of partitions may
235 |         vary.
236 | 
237 |     Examples
238 |     --------
239 | 
240 |     >>> from dask_snowflake import read_snowflake
241 |     >>> example_query = '''
242 |     ...    SELECT *
243 |     ...    FROM SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.CUSTOMER;
244 |     ... '''
245 |     >>> ddf = read_snowflake(
246 |     ...     query=example_query,
247 |     ...     connection_kwargs={
248 |     ...         "user": "...",
249 |     ...         "password": "...",
250 |     ...         "account": "...",
251 |     ...     },
252 |     ... )
253 | 
254 |     """
255 |     if arrow_options is None:
256 |         arrow_options = {}
257 | 
258 |     # Provide a reasonable default, as the raw batches tend to be too small.
259 |     if partition_size is None and npartitions is None:
260 |         partition_size = "100MiB"
261 | 
262 |     # Disable `log_imported_packages_in_telemetry` as a temporary workaround for
263 |     # https://github.com/snowflakedb/snowflake-connector-python/issues/1648.
264 | # Also xref https://github.com/dask-contrib/dask-snowflake/issues/51. 265 | if connection_kwargs.get("log_imported_packages_in_telemetry"): 266 | raise ValueError( 267 | "Using `log_imported_packages_in_telemetry=True` when creating a " 268 | "Snowflake connection is not currently supported." 269 | ) 270 | else: 271 | connection_kwargs["log_imported_packages_in_telemetry"] = False 272 | 273 | # Some clusters will overwrite the `snowflake.partner` configuration value. 274 | # We fetch snowflake batches on the cluster to ensure we capture the 275 | # right partner application ID. 276 | batches = _fetch_query_batches(query, connection_kwargs, execute_params).compute() 277 | if not batches: 278 | return dd.from_pandas(pd.DataFrame(), npartitions=1) 279 | 280 | batch_types = set(type(b) for b in batches) 281 | if len(batch_types) > 1 or next(iter(batch_types)) is not ArrowResultBatch: 282 | # See https://github.com/dask-contrib/dask-snowflake/issues/21 283 | raise RuntimeError( 284 | f"Currently only `ArrowResultBatch` are supported, but received batch types {batch_types}" 285 | ) 286 | 287 | # Read the first non-empty batch to determine meta, which is useful for a 288 | # better size estimate when partitioning. We could also allow empty meta 289 | # here, which should involve less data transfer to the client, at the 290 | # cost of worse size estimates. Batches seem less than 1MiB in practice, 291 | # so this is likely okay right now, but could be revisited. 292 | meta = batches[0].to_pandas(**arrow_options) 293 | 294 | batches_partitioned = _partition_batches( 295 | batches, meta, npartitions=npartitions, partition_size=partition_size 296 | ) 297 | 298 | return dd.from_map( 299 | partial(_fetch_batches, arrow_options=arrow_options), 300 | batches_partitioned, 301 | meta=meta, 302 | ) 303 | -------------------------------------------------------------------------------- /dask_snowflake/tests/test_core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | 4 | import pandas as pd 5 | import pytest 6 | import snowflake.connector 7 | from snowflake.sqlalchemy import URL 8 | from sqlalchemy import create_engine, text 9 | 10 | import dask 11 | import dask.dataframe as dd 12 | import dask.datasets 13 | from dask.utils import is_dataframe_like, parse_bytes 14 | from distributed import Client, Lock, worker_client 15 | 16 | from dask_snowflake import read_snowflake, to_snowflake 17 | 18 | 19 | @pytest.fixture 20 | def client(): 21 | with Client(n_workers=2, threads_per_worker=10) as client: 22 | yield client 23 | 24 | 25 | @pytest.fixture 26 | def table(connection_kwargs): 27 | name = f"test_table_{uuid.uuid4().hex}".upper() 28 | 29 | yield name 30 | 31 | engine = create_engine(URL(**connection_kwargs)) 32 | with engine.connect() as conn: 33 | conn.execute(text(f"DROP TABLE IF EXISTS {name}")) 34 | 35 | 36 | @pytest.fixture(scope="module") 37 | def connection_kwargs(): 38 | return dict( 39 | user=os.environ["SNOWFLAKE_USER"], 40 | password=os.environ["SNOWFLAKE_PASSWORD"], 41 | account=os.environ["SNOWFLAKE_ACCOUNT"], 42 | database=os.environ.get("SNOWFLAKE_DATABASE", "testdb"), 43 | schema=os.environ.get("SNOWFLAKE_SCHEMA", "public"), 44 | warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], 45 | role=os.environ["SNOWFLAKE_ROLE"], 46 | ) 47 | 48 | 49 | # TODO: Find out if snowflake supports lower-case column names 50 | df = pd.DataFrame({"A": range(10), "B": range(10, 20)}) 51 | ddf = dd.from_pandas(df, npartitions=2) 52 | 53 | 54 | def 
test_write_read_roundtrip(table, connection_kwargs, client):
55 |     to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
56 | 
57 |     query = f"SELECT * FROM {table}"
58 |     df_out = read_snowflake(query, connection_kwargs=connection_kwargs, npartitions=2)
59 |     # FIXME: Why does read_snowflake return lower-case column names?
60 |     df_out.columns = df_out.columns.str.upper()
61 |     # FIXME: We need to sort the DataFrame because partitions are written
62 |     # in a non-sequential order.
63 |     dd.utils.assert_eq(
64 |         df, df_out.sort_values(by="A").reset_index(drop=True), check_dtype=False
65 |     )
66 | 
67 | 
68 | def test_read_empty_result(table, connection_kwargs, client):
69 |     # A query that yields an empty result set should return an empty DataFrame
70 |     to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
71 | 
72 |     result = read_snowflake(
73 |         f"SELECT * FROM {table} where A > %(target)s",
74 |         execute_params={"target": df.A.max().item()},
75 |         connection_kwargs=connection_kwargs,
76 |         npartitions=2,
77 |     )
78 |     assert is_dataframe_like(result)
79 |     assert len(result.index) == 0
80 |     assert len(result.columns) == 0
81 | 
82 | 
83 | def test_to_snowflake_compute_false(table, connection_kwargs, client):
84 |     result = to_snowflake(
85 |         ddf, name=table, connection_kwargs=connection_kwargs, compute=False
86 |     )
87 |     assert isinstance(result, list)
88 |     assert len(result) == ddf.npartitions
89 | 
90 |     dask.compute(result)
91 | 
92 |     ddf2 = read_snowflake(
93 |         f"SELECT * FROM {table}",
94 |         connection_kwargs=connection_kwargs,
95 |         npartitions=2,
96 |     )
97 |     # FIXME: Why does read_snowflake return lower-case column names?
98 |     ddf2.columns = ddf2.columns.str.upper()
99 |     # FIXME: We need to sort the DataFrame because partitions are written
100 |     # in a non-sequential order.
101 |     dd.utils.assert_eq(
102 |         df, ddf2.sort_values(by="A").reset_index(drop=True), check_dtype=False
103 |     )
104 | 
105 | 
106 | def test_arrow_options(table, connection_kwargs, client):
107 |     to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
108 | 
109 |     query = f"SELECT * FROM {table}"
110 |     df_out = read_snowflake(
111 |         query,
112 |         connection_kwargs=connection_kwargs,
113 |         arrow_options={"types_mapper": lambda x: pd.Float32Dtype()},
114 |         npartitions=2,
115 |     )
116 |     # FIXME: Why does read_snowflake return lower-case column names?
117 |     df_out.columns = df_out.columns.str.upper()
118 |     # FIXME: We need to sort the DataFrame because partitions are written
119 |     # in a non-sequential order.
120 |     expected = df.astype(pd.Float32Dtype())
121 |     dd.utils.assert_eq(
122 |         expected, df_out.sort_values(by="A").reset_index(drop=True), check_dtype=False
123 |     )
124 | 
125 | 
126 | def test_write_pandas_kwargs(table, connection_kwargs, client):
127 |     to_snowflake(
128 |         ddf.repartition(npartitions=1), name=table, connection_kwargs=connection_kwargs
129 |     )
130 |     # Overwrite existing table
131 |     to_snowflake(
132 |         ddf.repartition(npartitions=1),
133 |         name=table,
134 |         connection_kwargs=connection_kwargs,
135 |         write_pandas_kwargs={"overwrite": True},
136 |     )
137 | 
138 |     query = f"SELECT * FROM {table}"
139 |     df_out = read_snowflake(query, connection_kwargs=connection_kwargs, npartitions=2)
140 |     # FIXME: Why does read_snowflake return lower-case column names?
141 |     df_out.columns = df_out.columns.str.upper()
142 |     # FIXME: We need to sort the DataFrame because partitions are written
143 |     # in a non-sequential order.
144 | dd.utils.assert_eq( 145 | df, df_out.sort_values(by="A").reset_index(drop=True), check_dtype=False 146 | ) 147 | 148 | 149 | def test_application_id_default(table, connection_kwargs, monkeypatch): 150 | # Patch Snowflake's normal connection mechanism with checks that 151 | # the expected application ID is set 152 | count = 0 153 | 154 | def mock_connect(**kwargs): 155 | nonlocal count 156 | count += 1 157 | assert kwargs["application"] == "dask" 158 | return snowflake.connector.Connect(**kwargs) 159 | 160 | monkeypatch.setattr(snowflake.connector, "connect", mock_connect) 161 | 162 | to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs) 163 | # One extra connection is made to ensure the DB table exists 164 | count_after_write = ddf.npartitions + 1 165 | assert count == count_after_write 166 | 167 | ddf_out = read_snowflake( 168 | f"SELECT * FROM {table}", connection_kwargs=connection_kwargs, npartitions=2 169 | ) 170 | assert count == count_after_write + ddf_out.npartitions 171 | 172 | 173 | def test_application_id_config(table, connection_kwargs, monkeypatch): 174 | with dask.config.set({"snowflake.partner": "foo"}): 175 | # Patch Snowflake's normal connection mechanism with checks that 176 | # the expected application ID is set 177 | count = 0 178 | 179 | def mock_connect(**kwargs): 180 | nonlocal count 181 | count += 1 182 | assert kwargs["application"] == "foo" 183 | return snowflake.connector.Connect(**kwargs) 184 | 185 | monkeypatch.setattr(snowflake.connector, "connect", mock_connect) 186 | 187 | to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs) 188 | # One extra connection is made to ensure the DB table exists 189 | count_after_write = ddf.npartitions + 1 190 | assert count == count_after_write 191 | 192 | ddf_out = read_snowflake( 193 | f"SELECT * FROM {table}", connection_kwargs=connection_kwargs, npartitions=2 194 | ) 195 | assert count == count_after_write + ddf_out.npartitions 196 | 197 | 198 | def test_application_id_config_on_cluster(table, connection_kwargs, client): 199 | # Ensure client and workers have different `snowflake.partner` values set. 200 | # Later we'll check that the config value on the workers is the one that's actually used. 
201 |     with dask.config.set({"snowflake.partner": "foo"}):
202 |         client.run(lambda: dask.config.set({"snowflake.partner": "bar"}))
203 |         assert dask.config.get("snowflake.partner") == "foo"
204 |         assert all(
205 |             client.run(lambda: dask.config.get("snowflake.partner") == "bar").values()
206 |         )
207 | 
208 |         # Patch Snowflake's normal connection mechanism with checks that
209 |         # the expected application ID is set
210 |         def patch_snowflake_connect():
211 |             def mock_connect(**kwargs):
212 |                 with worker_client() as client:
213 |                     # A lock is needed to safely increment the connect counter below
214 |                     with Lock("snowflake-connect"):
215 |                         assert kwargs["application"] == "bar"
216 |                         count = client.get_metadata("connect-count", 0)
217 |                         client.set_metadata("connect-count", count + 1)
218 |                 return snowflake.connector.Connect(**kwargs)
219 | 
220 |             snowflake.connector.connect = mock_connect
221 | 
222 |         client.run(patch_snowflake_connect)
223 | 
224 |         to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
225 |         # One extra connection is made to ensure the DB table exists
226 |         count_after_write = ddf.npartitions + 1
227 | 
228 |         ddf_out = read_snowflake(
229 |             f"SELECT * FROM {table}", connection_kwargs=connection_kwargs, npartitions=2
230 |         )
231 |         assert (
232 |             client.get_metadata("connect-count")
233 |             == count_after_write + ddf_out.npartitions
234 |         )
235 | 
236 | 
237 | def test_application_id_explicit(table, connection_kwargs, monkeypatch):
238 |     # Include explicit application ID in input `connection_kwargs`
239 |     connection_kwargs["application"] = "foo"
240 | 
241 |     # Patch Snowflake's normal connection mechanism with checks that
242 |     # the expected application ID is set
243 |     count = 0
244 | 
245 |     def mock_connect(**kwargs):
246 |         nonlocal count
247 |         count += 1
248 |         assert kwargs["application"] == "foo"
249 |         return snowflake.connector.Connect(**kwargs)
250 | 
251 |     monkeypatch.setattr(snowflake.connector, "connect", mock_connect)
252 | 
253 |     to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
254 |     # One extra connection is made to ensure the DB table exists
255 |     count_after_write = ddf.npartitions + 1
256 |     assert count == count_after_write
257 | 
258 |     ddf_out = read_snowflake(
259 |         f"SELECT * FROM {table}", connection_kwargs=connection_kwargs, npartitions=2
260 |     )
261 |     assert count == count_after_write + ddf_out.npartitions
262 | 
263 | 
264 | def test_execute_params(table, connection_kwargs, client):
265 |     to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
266 | 
267 |     df_out = read_snowflake(
268 |         f"SELECT * FROM {table} where A = %(target)s",
269 |         execute_params={"target": 3},
270 |         connection_kwargs=connection_kwargs,
271 |         npartitions=2,
272 |     )
273 |     # FIXME: Why does read_snowflake return lower-case column names?
274 |     df_out.columns = df_out.columns.str.upper()
275 |     # FIXME: We need to sort the DataFrame because partitions are written
276 |     # in a non-sequential order.
277 |     dd.utils.assert_eq(
278 |         df[df["A"] == 3],
279 |         df_out,
280 |         check_dtype=False,
281 |         check_index=False,
282 |     )
283 | 
284 | 
285 | def test_result_batching(table, connection_kwargs, client):
286 |     ddf = (
287 |         dask.datasets.timeseries(freq="10s", seed=1)
288 |         .reset_index(drop=True)
289 |         .rename(columns=lambda c: c.upper())
290 |     )
291 | 
292 |     to_snowflake(ddf, name=table, connection_kwargs=connection_kwargs)
293 | 
294 |     # Test partition_size logic
295 |     ddf_out = read_snowflake(
296 |         f"SELECT * FROM {table}",
297 |         connection_kwargs=connection_kwargs,
298 |         partition_size="2 MiB",
299 |     )
300 | 
301 |     partition_sizes = ddf_out.memory_usage_per_partition().compute()
302 |     assert (partition_sizes < 2 * parse_bytes("2 MiB")).all()
303 | 
304 |     # Test npartitions logic
305 |     ddf_out = read_snowflake(
306 |         f"SELECT * FROM {table}",
307 |         connection_kwargs=connection_kwargs,
308 |         npartitions=4,
309 |     )
310 |     assert abs(ddf_out.npartitions - 4) <= 2
311 | 
312 |     # Can't specify both
313 |     with pytest.raises(ValueError, match="exactly one"):
314 |         ddf_out = read_snowflake(
315 |             f"SELECT * FROM {table}",
316 |             connection_kwargs=connection_kwargs,
317 |             npartitions=4,
318 |             partition_size="2 MiB",
319 |         )
320 | 
321 |     dd.utils.assert_eq(ddf, ddf_out, check_dtype=False, check_index=False)
322 | 
--------------------------------------------------------------------------------
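The partition-size and partition-count behaviour exercised by `test_result_batching` above comes from the greedy batch-grouping in `_partition_batches` (see `dask_snowflake/core.py`). The following self-contained sketch illustrates that grouping without a Snowflake connection; `FakeBatch` is a hypothetical stand-in for `ArrowResultBatch` that only provides the `rowcount` attribute used on the `npartitions` code path:

```python
from dataclasses import dataclass

import pandas as pd

from dask_snowflake.core import _partition_batches


@dataclass
class FakeBatch:
    # Only `rowcount` is consulted when grouping by `npartitions`.
    rowcount: int


batches = [FakeBatch(30), FakeBatch(30), FakeBatch(30), FakeBatch(30)]

# 120 rows split across 2 requested partitions gives a target of 60 rows per
# partition, so the four 30-row batches are greedily grouped into two pairs.
groups = _partition_batches(
    batches,
    meta=pd.DataFrame(),  # unused on the npartitions path
    npartitions=2,
)
print([[b.rowcount for b in g] for g in groups])
# [[30, 30], [30, 30]]
```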