├── .github └── workflows │ ├── dependabot-merge.yml │ ├── pages.yml │ ├── publish_pypi.yml │ └── test.yml ├── .gitignore ├── CHANGELOG.md ├── CODEOWNERS ├── LICENSE.txt ├── README.md ├── ddataflow ├── __init__.py ├── data_source.py ├── data_sources.py ├── ddataflow.py ├── downloader.py ├── exceptions.py ├── sampling │ ├── __init__.py │ ├── default.py │ └── sampler.py ├── setup │ ├── __init__.py │ ├── ddataflow_config.py │ └── setup_project.py └── utils.py ├── ddataflow_local.py ├── docs ├── FAQ.md ├── api_reference │ ├── DDataflow.md │ ├── DataSource.md │ ├── DataSourceDownloader.md │ └── DataSources.md ├── ddataflow.png ├── index.md ├── local_development.md ├── sampling.md └── troubleshooting.md ├── examples ├── __init__.py ├── ddataflow_config.py └── pipeline.py ├── html ├── ddataflow.html ├── ddataflow │ ├── data_source.html │ ├── data_sources.html │ ├── ddataflow.html │ ├── downloader.html │ ├── exceptions.html │ ├── samples.html │ ├── samples │ │ └── ddataflow_config.html │ ├── sampling.html │ ├── sampling │ │ ├── default.html │ │ └── sampler.html │ ├── setup_project.html │ └── utils.html ├── index.html └── search.js ├── mkdocs.yml ├── pyproject.toml └── tests ├── __init__.py ├── test_configuration.py ├── test_sampling.py ├── test_sql.py └── tutorial ├── __init__.py ├── test_local_dev.py └── test_sources_and_writers.py /.github/workflows/dependabot-merge.yml: -------------------------------------------------------------------------------- 1 | name: Dependabot Approve 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: 0 9 * * Mon 7 | - cron: 0 10 * * Mon 8 | 9 | 10 | permissions: 11 | pull-requests: write 12 | 13 | jobs: 14 | auto-approve: 15 | name: Approve PR 16 | uses: getyourguide/actions/.github/workflows/dependabot-approve.yml@main 17 | with: 18 | targets: patch:all,minor:all,major:development 19 | pr-url: ${{ github.event.pull_request.html_url }} 20 | auto-merge: 21 | name: Auto merge 22 | uses: getyourguide/actions/.github/workflows/dependabot-merge.yml@main 23 | with: 24 | slack_channel: '#mlplatform-alerts' 25 | slack_ping_support: true 26 | hide_report_sections: ${{ (github.event.schedule == '0 9 * * Mon-Fri' && 'blocked,pending-approval') || ''}} 27 | secrets: inherit 28 | -------------------------------------------------------------------------------- /.github/workflows/pages.yml: -------------------------------------------------------------------------------- 1 | # Simple workflow for deploying static content to GitHub Pages 2 | name: Deploy static content to Pages 3 | 4 | on: 5 | push: 6 | branches: 7 | - master 8 | - main 9 | permissions: 10 | contents: write 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 16 | - name: Configure Git Credentials 17 | run: | 18 | git config user.name github-actions[bot] 19 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 20 | - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 21 | with: 22 | python-version: 3.x 23 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 24 | - uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 25 | with: 26 | key: mkdocs-material-${{ env.cache_id }} 27 | path: .cache 28 | restore-keys: | 29 | mkdocs-material- 30 | - run: pip install mkdocs mkdocstrings[python] mkdocs-material 31 | - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- 
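The Pages workflow above builds the documentation site from `mkdocs.yml` and the `docs/` folder. To preview the site locally before pushing, the same dependencies can be installed and served; this is a sketch of the usual mkdocs workflow rather than a command taken from the repository:

```sh
pip install mkdocs "mkdocstrings[python]" mkdocs-material
mkdocs serve  # serves the site locally, by default at http://127.0.0.1:8000
```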
/.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish python poetry package 2 | on: 3 | # Triggers the workflow on push or pull request events but only for the "main" branch 4 | push: 5 | branches: [ "main" ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0 12 | - name: Build and publish to pypi 13 | shell: bash 14 | env: 15 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 16 | run: | 17 | pip install poetry 18 | poetry config virtualenvs.create false 19 | poetry build 20 | echo "Publishing now..." 21 | poetry publish -u "__token__" --password "$PYPI_TOKEN" || true -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: Tests 4 | 5 | # Controls when the workflow will run 6 | on: 7 | # Triggers the workflow on push or pull request events but only for the "main" branch 8 | push: 9 | branches: [ "main", "master" ] 10 | pull_request: 11 | branches: [ "main", "master" ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | jobs: 17 | test: 18 | # The type of runner that the job will run on 19 | runs-on: ubuntu-latest 20 | 21 | # Steps represent a sequence of tasks that will be executed as part of the job 22 | steps: 23 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 24 | - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0 25 | 26 | # Runs a single command using the runners shell 27 | - name: Install dependencies 28 | run: | 29 | sudo apt update -y 30 | sudo apt install gnome-keyring 31 | pip install --upgrade pip 32 | pip install poetry 33 | poetry install 34 | 35 | # Runs a set of commands using the runners shell 36 | - name: Run a multi-line script 37 | run: | 38 | poetry run pytest 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | *__pycache__* 3 | .idea/* 4 | *.swp 5 | dist/ 6 | site/ 7 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 1.1.15 2 | 3 | - Disable logs when DDataflow is not enable. Support setting logger level. 4 | 5 | # 1.1.14 6 | 7 | - Fix sample and download function 8 | 9 | # 1.1.13 10 | 11 | - Support s3 path and default database 12 | 13 | # 1.1.9 14 | 15 | - Upgrade dependencies due to security 16 | 17 | # 1.1.8 18 | 19 | - Release new version with security patch on urllib 20 | 21 | # 1.1.7 22 | 23 | - Fix bug in source_name. 
24 | 25 | # 1.1.6 26 | 27 | - Security dependency fix 28 | 29 | # 1.1.5 - 2022-10-12 30 | 31 | - Allow customization of default sampler limit 32 | 33 | # 1.1.4 - 2022-10-11 34 | 35 | - improved initial project config 36 | 37 | # 1.1.3 - 2022-10-11 38 | 39 | - is_enabled api added 40 | - add documentation website 41 | 42 | # 1.1.2 - 2022-10-05 43 | 44 | - is_local api added 45 | 46 | # 1.1.0 - 2022-09-30 47 | 48 | - Support default sampling enablement in the sources configuration 49 | - Use default_sample: True 50 | - Support omitting the source name in the sources configuration 51 | - When omitted assumes table and uses the source key as the table name 52 | 53 | # 1.0.2 - 2022-09-29 54 | 55 | - Loosen up dependencies 56 | 57 | # 1.0.1 - 2022-09-06 58 | 59 | - Upgrade dependencies for security reasons 60 | 61 | # 1.0.0 - 2022-08-26 62 | 63 | - Add missing fire dependency and graduate project to 1.0 64 | 65 | # 0.2.0 66 | 67 | - Fix bug while downloading data sources 68 | 69 | # 0.1.12 70 | 71 | - Better error messages and documentation 72 | 73 | # 0.1.10 74 | 75 | - Support ddataflow.path to only replace the path in an environment 76 | 77 | # 0.1.9 78 | 79 | - Refactorings 80 | - Fix setup command GB to int 81 | 82 | # 0.1.8 83 | 84 | - Fix missing creating tmp view 85 | - Improve messages indentation 86 | 87 | # 0.1.7 88 | 89 | - Fix source_name for offline mode 90 | - Fix tests 91 | 92 | # 0.1.6 93 | 94 | - Make examples simpler 95 | - Use ddataflow instead of ddataflow_object in the setup examples 96 | 97 | # 0.1.5 98 | 99 | - Improve documentation and add an example 100 | - Improve the setup_project options and documentation 101 | 102 | # 0.1.4 103 | 104 | - create setup_project command 105 | - use limit only rather than sampling for default sampled entities 106 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @getyourguide/mlp 2 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDataFlow 2 | 3 | DDataFlow is an end2end tests and local development solution for machine learning and data pipelines using pyspark. 4 | Check out this blogpost if you want to [understand deeper its design motivation](https://www.getyourguide.careers/posts/ddataflow-a-tool-for-data-end-to-end-tests-for-machine-learning-pipelines). 5 | 6 | ![ddataflow overview](docs/ddataflow.png) 7 | 8 | You can find our documentation under this [link](https://code.getyourguide.com/DDataFlow/). 9 | 10 | ## Features 11 | 12 | - Read a subset of our data so to speed up the running of the pipelines during tests 13 | - Write to a test location our artifacts so you don't pollute production 14 | - Download data for enabling local machine development 15 | 16 | Enables to run on the pipelines in the CI 17 | 18 | ## 1. Install DDataflow 19 | 20 | ```sh 21 | pip install ddataflow 22 | ``` 23 | 24 | `ddataflow --help` will give you an overview of the available commands. 25 | 26 | 27 | # Getting Started (<5min Tutorial) 28 | 29 | 30 | ## 1. Setup some synthetic data 31 | 32 | See the [examples folder](examples/pipeline.py). 33 | 34 | ## 2. Create a ddataflow_config.py file 35 | 36 | The command `ddtaflow setup_project` creates a file like this for you. 37 | 38 | ```py 39 | from ddataflow import DDataflow 40 | 41 | config = { 42 | # add here your tables or paths with customized sampling logic 43 | "data_sources": { 44 | "demo_tours": { 45 | "source": lambda spark: spark.table('demo_tours'), 46 | "filter": lambda df: df.limit(500) 47 | } 48 | "demo_locations": { 49 | "source": lambda spark: spark.table('demo_locations'), 50 | "default_sampling": True, 51 | } 52 | }, 53 | "project_folder_name": "ddataflow_demo", 54 | } 55 | 56 | # initialize the application and validate the configuration 57 | ddataflow = DDataflow(**config) 58 | ``` 59 | 60 | ## 3. Use ddataflow in a pipeline 61 | 62 | ```py 63 | from ddataflow_config import ddataflow 64 | 65 | # replace spark.table for ddataflow source will return a spark dataframe 66 | print(ddataflow.source('demo_locations').count()) 67 | # for sql queries replace only the name of the table for the sample data source name provided by ddataflow 68 | print(spark.sql(f""" SELECT COUNT(1) from {ddataflow.name('demo_tours')}""").collect()[0]['count(1)']) 69 | ``` 70 | 71 | Now run it twice and observe the difference in the amount of records: 72 | `python pipeline.py` 73 | 74 | `ENABLE_DDATAFLOW=True python pipeline.py` 75 | 76 | You will see that the dataframes are sampled when ddataflow is enabled and full when the tool is disabled. 77 | 78 | You completed the short demo! 79 | 80 | ## How to develop 81 | 82 | The recommended approach to use ddataflow is to use the offline mode, which allows you to test your pipelines without the need for an active cluster. This is especially important for development and debugging purposes, as it allows you to quickly test and identify any issues with your pipelines. 83 | 84 | Alternatively, you can use Databricks Connect to test your pipelines on an active cluster. However, our experience with this approach has not been great, memory issues are common and there is the risk of overriding production data, so we recommend using the offline mode instead. 85 | 86 | If you have any questions or need any help, please don't hesitate to reach out. 
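As a reference, a minimal offline round trip could look like the commands below. This is a sketch: it assumes a configured databricks CLI, the `ddataflow_config.py` created in step 2, and `pipeline.py` standing in for your own entry point:

```sh
# sample the configured sources into the snapshot folder (dbfs:/ddataflow by default)
# and download them to ~/.ddataflow/<project_folder_name>
ddataflow current_project sample_and_download

# run the pipeline against the local snapshot, without an active cluster
ENABLE_OFFLINE_MODE=True python pipeline.py
```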
We are here to help you get the most out of ddataflow. 87 | 88 | 89 | ## Support 90 | 91 | In case of questions feel free to reach out or create an issue. 92 | 93 | Check out our [FAQ in case of problems](https://github.com/getyourguide/DDataFlow/blob/main/docs/FAQ.md) 94 | 95 | ## Contributing 96 | 97 | We welcome contributions to DDataFlow! If you would like to contribute, please follow these guidelines: 98 | 99 | 1. Fork the repository and create a new branch for your contribution. 100 | 2. Make your changes and ensure that the code passes all tests. 101 | 3. Submit a pull request with a clear description of your changes and the problem it solves. 102 | 103 | Please note that all contributions are subject to review and approval by the project maintainers. We appreciate your help in making DDataFlow even better! 104 | 105 | If you have any questions or need any help, please don't hesitate to reach out. We are here to assist you throughout the contribution process. 106 | 107 | ## License 108 | DDataFlow is licensed under the [MIT License](https://github.com/getyourguide/DDataFlow/blob/main/LICENSE). 109 | -------------------------------------------------------------------------------- /ddataflow/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # import DDataflow here to make importing easier 3 | from .ddataflow import DDataflow 4 | -------------------------------------------------------------------------------- /ddataflow/data_source.py: -------------------------------------------------------------------------------- 1 | import logging as logger 2 | import os 3 | 4 | from pyspark.sql import DataFrame 5 | 6 | from ddataflow.exceptions import BiggerThanMaxSize 7 | from ddataflow.sampling.default import filter_function 8 | from ddataflow.utils import get_or_create_spark 9 | 10 | 11 | class DataSource: 12 | """ 13 | Utility functions at data source level 14 | """ 15 | 16 | def __init__( 17 | self, 18 | *, 19 | name: str, 20 | config: dict, 21 | local_data_folder: str, 22 | snapshot_path: str, 23 | size_limit, 24 | ): 25 | self._name = name 26 | self._local_data_folder = local_data_folder 27 | self._snapshot_path = snapshot_path 28 | self._size_limit = size_limit 29 | self._config = config 30 | self._filter = None 31 | self._source = None 32 | 33 | if "source" in self._config: 34 | self._source = config["source"] 35 | else: 36 | if self._config.get("file-type") == "parquet": 37 | self._source = lambda spark: spark.read.parquet(self._name) 38 | else: 39 | self._source = lambda spark: spark.table(self._name) 40 | 41 | if "filter" in self._config: 42 | self._filter = self._config["filter"] 43 | else: 44 | if self._config.get("default_sampling"): 45 | self._filter = lambda df: filter_function(df) 46 | 47 | def query(self): 48 | """ 49 | query with filter unless none is present 50 | """ 51 | df = self.query_without_filter() 52 | 53 | if self._filter is not None: 54 | print(f"Filter set for {self._name}, applying it") 55 | df = self._filter(df) 56 | else: 57 | print(f"No filter set for {self._name}") 58 | 59 | return df 60 | 61 | def has_filter(self) -> bool: 62 | return self._filter is not None 63 | 64 | def query_without_filter(self): 65 | """ 66 | Go to the raw data source without any filtering 67 | """ 68 | spark = get_or_create_spark() 69 | logger.debug(f"Querying without filter source: '{self._name}'") 70 | return self._source(spark) 71 | 72 | def query_locally(self): 73 | logger.info(f"Querying locally {self._name}") 74 | 75 
| path = self.get_local_path() 76 | if not os.path.exists(path): 77 | raise Exception( 78 | f"""Data source '{self.get_name()}' does not have data in {path}. 79 | Consider downloading using the following command: 80 | ddataflow current_project download_data_sources""" 81 | ) 82 | spark = get_or_create_spark() 83 | df = spark.read.parquet(path) 84 | 85 | return df 86 | 87 | def get_dbfs_sample_path(self) -> str: 88 | return os.path.join(self._snapshot_path, self._get_name_as_path()) 89 | 90 | def get_local_path(self) -> str: 91 | return os.path.join(self._local_data_folder, self._get_name_as_path()) 92 | 93 | def _get_name_as_path(self): 94 | """ 95 | converts the name when it has "/mnt/envents" in the name to a single file in a (flat structure) _mnt_events 96 | """ 97 | return self.get_name().replace("/", "_") 98 | 99 | def get_name(self) -> str: 100 | return self._name 101 | 102 | def get_parquet_filename(self) -> str: 103 | return self._name + ".parquet" 104 | 105 | def estimate_size_and_fail_if_too_big(self): 106 | """ 107 | Estimate the size of the data source use the _name used in the _config 108 | It will throw an exception if the estimated size is bigger than the maximum allowed in the configuration 109 | """ 110 | 111 | print("Estimating size of data source: ", self.get_name()) 112 | df = self.query() 113 | size_estimation = self._estimate_size(df) 114 | 115 | print("Estimated size of the Dataset in GB: ", size_estimation) 116 | 117 | if size_estimation > self._size_limit: 118 | raise BiggerThanMaxSize(self._name, size_estimation, self._size_limit) 119 | 120 | return df 121 | 122 | def _estimate_size(self, df: DataFrame) -> float: 123 | """ 124 | Estimates the size of a dataframe in Gigabytes 125 | 126 | Formula: 127 | number of gigabytes = (N*V*W) / 1024^3 128 | """ 129 | 130 | print(f"Amount of rows in dataframe to estimate size: {df.count()}") 131 | average_variable_size_bytes = 50 132 | return (df.count() * len(df.columns) * average_variable_size_bytes) / (1024**3) 133 | -------------------------------------------------------------------------------- /ddataflow/data_sources.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from ddataflow.data_source import DataSource 4 | 5 | 6 | class DataSources: 7 | """ 8 | Validates and Abstract the access to data sources 9 | """ 10 | 11 | def __init__( 12 | self, *, config, local_folder: str, snapshot_path: str, size_limit: int 13 | ): 14 | self.config = config 15 | self.data_source: Dict[str, Any] = {} 16 | self.download_folder = local_folder 17 | for data_source_name, data_source_config in self.config.items(): 18 | self.data_source[data_source_name] = DataSource( 19 | name=data_source_name, 20 | config=data_source_config, 21 | local_data_folder=local_folder, 22 | snapshot_path=snapshot_path, 23 | size_limit=size_limit, 24 | ) 25 | 26 | def all_data_sources_names(self) -> List[str]: 27 | return list(self.data_source.keys()) 28 | 29 | def get_data_source(self, name) -> DataSource: 30 | if name not in self.data_source: 31 | raise Exception(f"Data source does not exist {name}") 32 | return self.data_source[name] 33 | 34 | def get_filter(self, data_source_name: str): 35 | return self.config[data_source_name]["query"] 36 | 37 | def get_parquet_name(self, data_source_name: str): 38 | return self.config[data_source_name]["parquet_name"] 39 | -------------------------------------------------------------------------------- /ddataflow/ddataflow.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List, Optional, Union 4 | 5 | from ddataflow.data_source import DataSource 6 | from ddataflow.data_sources import DataSources 7 | from ddataflow.downloader import DataSourceDownloader 8 | from ddataflow.exceptions import WriterNotFoundException 9 | from ddataflow.sampling.default import ( 10 | build_default_sampling_for_sources, 11 | DefaultSamplerOptions, 12 | ) 13 | from ddataflow.sampling.sampler import Sampler 14 | from ddataflow.utils import get_or_create_spark, using_databricks_connect 15 | from pyspark.sql import DataFrame 16 | 17 | logger = logging.getLogger(__name__) 18 | handler = logging.StreamHandler() 19 | logger.addHandler(handler) 20 | 21 | 22 | class DDataflow: 23 | """ 24 | DDataflow is an end2end tests solution. 25 | See our docs manual for more details. 26 | Additionally, use help(ddataflow) to see the available methods. 27 | """ 28 | 29 | _DEFAULT_SNAPSHOT_BASE_PATH = "dbfs:/ddataflow" 30 | _LOCAL_BASE_SNAPSHOT_PATH = os.environ["HOME"] + "/.ddataflow" 31 | _ENABLE_DDATAFLOW_ENVVARIABLE = "ENABLE_DDATAFLOW" 32 | _ENABLE_OFFLINE_MODE_ENVVARIABLE = "ENABLE_OFFLINE_MODE" 33 | _DDATAFLOW_CONFIG_FILE = "ddataflow_config.py" 34 | 35 | _local_path: str 36 | 37 | def __init__( 38 | self, 39 | project_folder_name: str, 40 | data_sources: Optional[dict] = None, 41 | data_writers: Optional[dict] = None, 42 | data_source_size_limit_gb: int = 1, 43 | enable_ddataflow=False, 44 | sources_with_default_sampling: Optional[List[str]] = None, 45 | snapshot_path: Optional[str] = None, 46 | default_sampler: Optional[dict] = None, 47 | default_database: Optional[str] = None, 48 | ): 49 | """ 50 | Initialize the dataflow object. 51 | The input of this object is the config dictionary outlined in our integrator manual. 52 | 53 | Important params: 54 | project_folder_name: 55 | the name of the project that will be stored in the disk 56 | snapshot_path: 57 | path to the snapshot folder 58 | data_source_size_limit_gb: 59 | limit the size of the data sources 60 | default_sampler: 61 | options to pass to the default sampler 62 | sources_with_default_sampling: 63 | if you have tables you want to have by default and dont want to sample them first 64 | default_database: 65 | name of the default database. If ddataflow is enabled, a test db will be created and used. 
66 | sources_with_default_sampling : 67 | Deprecated: use sources with default_sampling=True instead 68 | if you have tables you want to have by default and dont want to sample them first 69 | """ 70 | self._size_limit = data_source_size_limit_gb 71 | 72 | self.project_folder_name = project_folder_name 73 | 74 | base_path = snapshot_path if snapshot_path else self._DEFAULT_SNAPSHOT_BASE_PATH 75 | 76 | self._snapshot_path = base_path + "/" + project_folder_name 77 | self._local_path = self._LOCAL_BASE_SNAPSHOT_PATH + "/" + project_folder_name 78 | 79 | if default_sampler: 80 | # set this before creating data sources 81 | DefaultSamplerOptions.set(default_sampler) 82 | 83 | if not data_sources: 84 | data_sources = {} 85 | 86 | all_data_sources = { 87 | **build_default_sampling_for_sources(sources_with_default_sampling), 88 | **data_sources, 89 | } 90 | 91 | self._data_sources = DataSources( 92 | config=all_data_sources, 93 | local_folder=self._local_path, 94 | snapshot_path=self._snapshot_path, 95 | size_limit=self._size_limit, 96 | ) 97 | 98 | self._data_writers: dict = data_writers if data_writers else {} 99 | 100 | self._offline_enabled = os.getenv(self._ENABLE_OFFLINE_MODE_ENVVARIABLE, False) 101 | 102 | self._ddataflow_enabled: Union[str, bool] = os.getenv( 103 | self._ENABLE_DDATAFLOW_ENVVARIABLE, enable_ddataflow 104 | ) 105 | 106 | # if offline is enabled we should use local data 107 | if self._offline_enabled: 108 | self.enable_offline() 109 | 110 | self.save_sampled_data_sources = Sampler( 111 | self._snapshot_path, self._data_sources 112 | ).save_sampled_data_sources 113 | 114 | if default_database: 115 | self.set_up_database(default_database) 116 | 117 | # Print detailed logs when ddataflow is enabled 118 | if self._ddataflow_enabled: 119 | self.set_logger_level(logging.DEBUG) 120 | else: 121 | logger.info( 122 | "DDataflow is now DISABLED." 123 | "PRODUCTION data will be used and it will write to production tables." 124 | ) 125 | 126 | @staticmethod 127 | def setup_project(): 128 | """ 129 | Sets up a new ddataflow project with empty data sources in the current directory 130 | """ 131 | from ddataflow.setup.setup_project import setup_project 132 | 133 | setup_project() 134 | 135 | @staticmethod 136 | def current_project() -> "DDataflow": 137 | """ 138 | Returns a ddataflow configured with the current directory configuration file 139 | Requirements for this to work: 140 | 141 | 1. MLTools must be called from withing the project root directory 142 | 2. There must be a file called ddataflow_config.py there 143 | 3. the module must have defined DDataflow object with the name of ddataflow 144 | 145 | @todo investigate if we can use import_class_from_string 146 | """ 147 | import sys 148 | 149 | CONFIGURATION_FILE_NAME = "ddataflow_config.py" 150 | 151 | current_folder = os.getcwd() 152 | logger.debug("Loading config from folder", current_folder) 153 | config_location = os.path.join(current_folder, CONFIGURATION_FILE_NAME) 154 | 155 | if not os.path.exists(config_location): 156 | raise Exception( 157 | f""" 158 | This command needs to be executed within a project containing a {CONFIGURATION_FILE_NAME} file. 
159 | You can start a new one for the current folder by running the following command: 160 | $ ddataflow setup_project""" 161 | ) 162 | 163 | sys.path.append(current_folder) 164 | 165 | import ddataflow_config 166 | 167 | if hasattr(ddataflow_config, "ddataflow_client"): 168 | return ddataflow_config.ddataflow_client 169 | 170 | if not hasattr(ddataflow_config, "ddataflow"): 171 | raise Exception("ddataflow object is not defined in your _config file") 172 | 173 | return ddataflow_config.ddataflow 174 | 175 | def source(self, name: str, debugger=False) -> DataFrame: 176 | """ 177 | Gives access to the data source configured in the dataflow 178 | 179 | You can also use this function in the terminal with --debugger=True to inspect the dataframe. 180 | """ 181 | self.print_status() 182 | 183 | logger.debug("Loading data source") 184 | data_source: DataSource = self._data_sources.get_data_source(name) 185 | logger.debug("Data source loaded") 186 | df = self._get_data_from_data_source(data_source) 187 | 188 | if debugger: 189 | logger.debug(f"Debugger enabled: {debugger}") 190 | logger.debug("In debug mode now, use query to inspect it") 191 | breakpoint() 192 | 193 | return df 194 | 195 | def _get_spark(self): 196 | return get_or_create_spark() 197 | 198 | def enable(self): 199 | """ 200 | When enabled ddataflow will read from the filtered data sources 201 | instead of production tables. And write to testing tables instead of production ones. 202 | """ 203 | 204 | self._ddataflow_enabled = True 205 | 206 | def is_enabled(self) -> bool: 207 | return self._ddataflow_enabled 208 | 209 | def enable_offline(self): 210 | """Programatically enable offline mode""" 211 | self._offline_enabled = True 212 | self.enable() 213 | 214 | def is_local(self) -> bool: 215 | return self._offline_enabled 216 | 217 | def disable_offline(self): 218 | """Programatically enable offline mode""" 219 | self._offline_enabled = False 220 | 221 | def source_name(self, name, disable_view_creation=False) -> str: 222 | """ 223 | Given the name of a production table, returns the name of the corresponding ddataflow table when ddataflow is enabled 224 | If ddataflow is disabled get the production one. 225 | """ 226 | logger.debug("Source name used: ", name) 227 | source_name = name 228 | 229 | # the gist of ddtafalow 230 | if self._ddataflow_enabled: 231 | source_name = self._get_new_table_name(name) 232 | if disable_view_creation: 233 | return source_name 234 | 235 | logger.debug(f"Creating a temp view with the name: {source_name}") 236 | data_source: DataSource = self._data_sources.get_data_source(name) 237 | 238 | if self._offline_enabled: 239 | df = data_source.query_locally() 240 | else: 241 | df = data_source.query() 242 | 243 | df.createOrReplaceTempView(source_name) 244 | 245 | return source_name 246 | 247 | return source_name 248 | 249 | def path(self, path): 250 | """ 251 | returns a deterministic path replacing the real production path with one based on the current environment needs. 252 | Currently support path starts with 'dbfs:/' and 's3://'. 253 | """ 254 | if not self._ddataflow_enabled: 255 | return path 256 | 257 | base_path = self._get_current_environment_data_folder() 258 | 259 | for path_prefix in ["dbfs:/", "s3://"]: 260 | path = path.replace(path_prefix, "") 261 | 262 | return base_path + "/" + path 263 | 264 | def set_up_database(self, db_name: str): 265 | """ 266 | Perform USE $DATABASE query to set up a default database. 267 | If ddataflow is enabled, use a test db to prevent writing data into production. 
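        Example (illustrative, "analytics" is a placeholder database name):
        set_up_database("analytics") creates the database if needed and runs USE analytics,
        or USE ddataflow_analytics when ddataflow is enabled.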
268 | """ 269 | # rename database if ddataflow is enabled 270 | if self._ddataflow_enabled: 271 | db_name = f"ddataflow_{db_name}" 272 | # get spark 273 | spark = self._get_spark() 274 | # create db if not exist 275 | sql = "CREATE DATABASE IF NOT EXISTS {0}".format(db_name) 276 | spark.sql(sql) 277 | # set default db 278 | spark.sql("USE {}".format(db_name)) 279 | logger.warning(f"The default database is now set to {db_name}") 280 | 281 | def _get_new_table_name(self, name) -> str: 282 | overriden_name = name.replace("dwh.", "") 283 | return self.project_folder_name + "_" + overriden_name 284 | 285 | def name(self, *args, **kwargs): 286 | """ 287 | A shorthand for source_name 288 | """ 289 | return self.source_name(*args, **kwargs) 290 | 291 | def disable(self): 292 | """Disable ddtaflow overriding tables, uses production state in other words""" 293 | self._ddataflow_enabled = False 294 | 295 | def _get_data_from_data_source(self, data_source: DataSource) -> DataFrame: 296 | if not self._ddataflow_enabled: 297 | logger.debug("DDataflow not enabled") 298 | # goes directly to production without prefilters 299 | return data_source.query_without_filter() 300 | 301 | if self._offline_enabled: 302 | # uses snapshot data 303 | if using_databricks_connect(): 304 | logger.debug( 305 | "Looks like you are using databricks-connect in offline mode. You probably want to run it " 306 | "without databricks connect in offline mode" 307 | ) 308 | 309 | return data_source.query_locally() 310 | 311 | logger.debug("DDataflow enabled and filtering") 312 | return data_source.query() 313 | 314 | def download_data_sources(self, overwrite: bool = True, debug=False): 315 | """ 316 | Download the data sources locally for development offline 317 | Note: you need databricks-cli for this command to work 318 | 319 | Options: 320 | overwrite: will first clean the existing files 321 | """ 322 | DataSourceDownloader().download_all(self._data_sources, overwrite, debug) 323 | 324 | def sample_and_download( 325 | self, ask_confirmation: bool = True, overwrite: bool = True 326 | ): 327 | """ 328 | Create a sample folder in dbfs and then downloads it in the local machine 329 | """ 330 | self.save_sampled_data_sources(dry_run=False, ask_confirmation=ask_confirmation) 331 | self.download_data_sources(overwrite) 332 | 333 | def write(self, df, name: str): 334 | """ 335 | Write a dataframe either to a local folder or the production one 336 | """ 337 | if name not in self._data_writers: 338 | raise WriterNotFoundException(name) 339 | 340 | if self._ddataflow_enabled: 341 | writing_path = self._snapshot_path 342 | 343 | if self._offline_enabled: 344 | writing_path = self._local_path 345 | else: 346 | if not writing_path.startswith(DDataflow._DEFAULT_SNAPSHOT_BASE_PATH): 347 | raise Exception( 348 | f"Only writing to {DDataflow._DEFAULT_SNAPSHOT_BASE_PATH} is enabled" 349 | ) 350 | 351 | writing_path = os.path.join(writing_path, name) 352 | logger.info("Writing data to parquet file: " + writing_path) 353 | return df.write.parquet(writing_path, mode="overwrite") 354 | 355 | # if none of the above writes to production 356 | return self._data_writers[name]["writer"](df, name, self._get_spark()) # type: ignore 357 | 358 | def read(self, name: str): 359 | """ 360 | Read the data writers parquet file which are stored in the ddataflow folder 361 | """ 362 | path = self._snapshot_path 363 | if self._offline_enabled: 364 | path = self._local_path 365 | 366 | parquet_path = os.path.join(path, name) 367 | return 
self._get_spark().read.parquet(parquet_path) 368 | 369 | def _print_snapshot_size(self): 370 | """ 371 | Prints the final size of the dataset in the folder 372 | Note: Only works in notebooks. 373 | """ 374 | import subprocess 375 | 376 | location = "/dbfs/ddataflow/" 377 | output = subprocess.getoutput(f"du -h -d2 {location}") 378 | print(output) 379 | 380 | def _print_download_folder_size(self): 381 | """ 382 | Prints the final size of the dataset in the folder 383 | """ 384 | import subprocess 385 | 386 | output = subprocess.getoutput(f"du -h -d2 {self._local_path}") 387 | print(output) 388 | 389 | def get_mlflow_path(self, original_path: str): 390 | """ 391 | overrides the mlflow path if 392 | """ 393 | overriden_path = self._get_overriden_arctifacts_current_path() 394 | if overriden_path: 395 | model_name = original_path.split("/")[-1] 396 | return overriden_path + "/" + model_name 397 | 398 | return original_path 399 | 400 | def _get_overriden_arctifacts_current_path(self): 401 | if self._offline_enabled: 402 | return self._local_path 403 | 404 | if self._ddataflow_enabled: 405 | return self._snapshot_path 406 | 407 | return None 408 | 409 | def is_enabled(self): 410 | """ 411 | To be enabled ddataflow has to be either in offline mode or with enable=True 412 | """ 413 | return self._offline_enabled or self._ddataflow_enabled 414 | 415 | def print_status(self): 416 | """ 417 | Print the status of the ddataflow 418 | """ 419 | if self._offline_enabled: 420 | logger.debug("DDataflow is now ENABLED in OFFLINE mode") 421 | logger.debug( 422 | "To disable it remove from your code or unset the enviroment variable 'unset ENABLE_DDATAFLOW ; unset ENABLE_OFFLINE_MODE'" 423 | ) 424 | elif self._ddataflow_enabled: 425 | logger.debug( 426 | """ 427 | DDataflow is now ENABLED in ONLINE mode. Filtered data will be used and it will write to temporary tables. 428 | """ 429 | ) 430 | else: 431 | logger.debug( 432 | f""" 433 | DDataflow is now DISABLED. So PRODUCTION data will be used and it will write to production tables. 434 | Use enable() function or export {self._ENABLE_DDATAFLOW_ENVVARIABLE}=True to enable it. 435 | If you are working offline use export ENABLE_OFFLINE_MODE=True instead. 436 | """ 437 | ) 438 | 439 | def _get_current_environment_data_folder(self) -> Optional[str]: 440 | if not self._ddataflow_enabled: 441 | raise Exception("DDataflow is disabled so no data folder is available") 442 | 443 | if self._offline_enabled: 444 | return self._local_path 445 | 446 | return self._snapshot_path 447 | 448 | def set_logger_level(self, level): 449 | """ 450 | Set logger level. 
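        Example (illustrative): ddataflow.set_logger_level(logging.INFO)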
451 | Levels can be found here: https://docs.python.org/3/library/logging.html#logging-levels 452 | """ 453 | logger.info(f"Set logger level to: {level}") 454 | logger.setLevel(level) 455 | 456 | 457 | def main(): 458 | import fire 459 | 460 | fire.Fire(DDataflow) 461 | -------------------------------------------------------------------------------- /ddataflow/downloader.py: -------------------------------------------------------------------------------- 1 | import logging as logger 2 | import os 3 | 4 | from ddataflow.data_source import DataSource 5 | from ddataflow.data_sources import DataSources 6 | 7 | 8 | class DataSourceDownloader: 9 | 10 | _download_folder: str 11 | 12 | def download_all( 13 | self, data_sources: DataSources, overwrite: bool = True, debug=False 14 | ): 15 | """ 16 | Download the data sources locally for development offline 17 | Note: you need databricks-cli for this command to work 18 | 19 | Options: 20 | overwrite: will first clean the existing files 21 | """ 22 | self._download_folder = data_sources.download_folder 23 | if overwrite: 24 | if ".ddataflow" not in self._download_folder: 25 | raise Exception("Can only clean folders within .ddataflow") 26 | 27 | cmd_delete = f"rm -rf {self._download_folder}" 28 | print("Deleting content from", cmd_delete) 29 | os.system(cmd_delete) 30 | 31 | print("Starting to download the data-sources into your snapshot folder") 32 | 33 | for data_source_name in data_sources.all_data_sources_names(): 34 | print(f"Starting download process for datasource: {data_source_name}") 35 | data_source = data_sources.get_data_source(data_source_name) 36 | self._download_data_source(data_source, debug) 37 | 38 | print("Download of all data-sources finished successfully!") 39 | 40 | def _download_data_source(self, data_source: DataSource, debug=False): 41 | """ 42 | Download the latest data snapshot to the local machine for developing locally 43 | """ 44 | os.makedirs(self._download_folder, exist_ok=True) 45 | 46 | debug_str = "" 47 | if debug: 48 | debug_str = "--debug" 49 | 50 | cmd = f'databricks fs cp {debug_str} -r "{data_source.get_dbfs_sample_path()}" "{data_source.get_local_path()}"' 51 | 52 | logger.info(cmd) 53 | result = os.system(cmd) 54 | 55 | if result != 0: 56 | raise Exception( 57 | f""" 58 | Databricks cli failed! See error message above. 59 | Also consider rerunning the download command in your terminal to see the results. 60 | {cmd} 61 | """ 62 | ) 63 | -------------------------------------------------------------------------------- /ddataflow/exceptions.py: -------------------------------------------------------------------------------- 1 | class BiggerThanMaxSize(Exception): 2 | def create(self, datasource_name, size_estimation, size_limit): 3 | return BiggerThanMaxSize( 4 | f"DF {datasource_name} is estimated to be {round(size_estimation)}GB large. It is larger than the " 5 | f"current" 6 | f" limit {size_limit}GB. 
Adjust sampling filter or set higher limit" 7 | ) 8 | 9 | 10 | class WritingToLocationDenied(Exception): 11 | def create(self, base_folder): 12 | return WritingToLocationDenied(f"Only writing to {base_folder} is allowed") 13 | 14 | 15 | class WriterNotFoundException(Exception): 16 | def create(self, name: str): 17 | return WriterNotFoundException( 18 | f"The data writer was not found in your config: {name}" 19 | ) 20 | -------------------------------------------------------------------------------- /ddataflow/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/getyourguide/DDataFlow/7d7ed8e3fa9e4a51ef85afa50c22de3e60779056/ddataflow/sampling/__init__.py -------------------------------------------------------------------------------- /ddataflow/sampling/default.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import List, Optional 3 | 4 | import pyspark.sql.functions as F 5 | from pyspark.sql import DataFrame 6 | from dataclasses import dataclass 7 | 8 | 9 | @dataclass 10 | class DefaultSamplerOptions: 11 | """ 12 | Options to customize the default sampler 13 | """ 14 | 15 | limit: Optional[int] = 100000 16 | 17 | _instance = None 18 | 19 | @staticmethod 20 | def set(data: dict): 21 | DefaultSamplerOptions._instance = DefaultSamplerOptions(**data) 22 | 23 | return DefaultSamplerOptions._instance 24 | 25 | @staticmethod 26 | def get_instance(): 27 | if not DefaultSamplerOptions._instance: 28 | DefaultSamplerOptions._instance = DefaultSamplerOptions() 29 | 30 | return DefaultSamplerOptions._instance 31 | 32 | 33 | def sample_by_yesterday(df: DataFrame) -> DataFrame: 34 | """ 35 | Sample by yesterday 36 | """ 37 | 38 | yesterday = (datetime.now() - timedelta(days=1)).date() 39 | 40 | if "date" in df.columns: 41 | print("Found a date column, sampling by yesterday") 42 | return df.filter(F.col("date") == yesterday) 43 | 44 | return df 45 | 46 | 47 | def filter_function(df: DataFrame) -> DataFrame: 48 | """ 49 | Default filter function 50 | :param df: 51 | :return: 52 | """ 53 | df = sample_by_yesterday(df) 54 | df = df.limit(DefaultSamplerOptions.get_instance().limit) 55 | 56 | return df 57 | 58 | 59 | def build_default_sampling_for_sources(sources: Optional[List[dict]] = None) -> dict: 60 | """ 61 | Setup standard filters for the entries that we do not specify them 62 | """ 63 | result = {} 64 | if not sources: 65 | return result 66 | 67 | for source in sources: 68 | print("Build default sampling for source: " + source) 69 | result[source] = { 70 | "source": lambda spark: spark.table(source), 71 | "filter": lambda df: filter_function(df), 72 | } 73 | 74 | return result 75 | -------------------------------------------------------------------------------- /ddataflow/sampling/sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | from ddataflow.data_source import DataSource 5 | from ddataflow.data_sources import DataSources 6 | from ddataflow.exceptions import WritingToLocationDenied 7 | 8 | 9 | class Sampler: 10 | """ 11 | Samples and copy datasources 12 | """ 13 | 14 | def __init__(self, snapshot_path: str, data_sources: DataSources): 15 | self._BASE_SNAPSHOT_PATH = snapshot_path 16 | self._data_sources: DataSources = data_sources 17 | self._dry_run = True 18 | 19 | def save_sampled_data_sources( 20 | self, 21 | *, 22 | 
dry_run=True, 23 | ask_confirmation=False, 24 | sample_only: Optional[List[str]] = None, 25 | ): 26 | """ 27 | Make a snapshot of the sampled data for later downloading. 28 | 29 | The writing part is overriding! SO watch out where you write 30 | By default only do a dry-run. Do a dry_run=False to actually write to the sampled folder. 31 | """ 32 | self._dry_run = dry_run 33 | if self._dry_run: 34 | print( 35 | "Dry run enabled, no data will be written. Use dry_run=False to actually perform the operation." 36 | ) 37 | 38 | if sample_only is not None: 39 | print(f"Sampling only the following data sources: {sample_only}") 40 | 41 | print( 42 | "Sampling process starting for all data-sources." 43 | f" Saving the results to {self._BASE_SNAPSHOT_PATH} so you can download them." 44 | ) 45 | 46 | for data_source_name in self._data_sources.all_data_sources_names(): 47 | 48 | if sample_only is not None and data_source_name not in sample_only: 49 | print("data_source_name not in selected list", data_source_name) 50 | continue 51 | 52 | print(f"Starting sampling process for datasource: {data_source_name}") 53 | data_source: DataSource = self._data_sources.get_data_source( 54 | data_source_name 55 | ) 56 | 57 | print( 58 | f""" 59 | Writing copy to folder: {data_source.get_dbfs_sample_path()}. 60 | If you are writing to the wrong folder it could lead to data loss. 61 | """ 62 | ) 63 | 64 | df = data_source.estimate_size_and_fail_if_too_big() 65 | if ask_confirmation: 66 | proceed = input( 67 | f"Do you want to proceed with the sampling creation? (y/n):" 68 | ) 69 | if proceed != "y": 70 | print("Skipping the creation of the data source") 71 | continue 72 | self._write_sample_data(df, data_source) 73 | 74 | print("Success! Copying of sample data sources finished") 75 | 76 | def _write_sample_data(self, df, data_source: DataSource) -> None: 77 | """ 78 | Write the sampled data source into a temporary location in dbfs so it can be later downloaded 79 | If the location already exists it will be overwritten 80 | If you write to a location different than the standard ddataflow location in it will fail. 
81 | """ 82 | 83 | # a security check so we dont write by mistake where we should not 84 | sample_destination = data_source.get_dbfs_sample_path() 85 | if not sample_destination.startswith(self._BASE_SNAPSHOT_PATH): 86 | raise WritingToLocationDenied(self._BASE_SNAPSHOT_PATH) 87 | 88 | if self._dry_run: 89 | print("Not writing, dry run enabled") 90 | return 91 | # add by default repartition so we only download a single file 92 | df = df.repartition(1) 93 | 94 | df.write.mode("overwrite").parquet(data_source.get_dbfs_sample_path()) 95 | -------------------------------------------------------------------------------- /ddataflow/setup/__init__.py: -------------------------------------------------------------------------------- 1 | def setup_project(config_file): 2 | content = """ 3 | from ddataflow import DDataflow 4 | 5 | config = { 6 | # change the name of the project for a meaningful name (github project name usually a good idea) 7 | "project_folder_name": "my_project", 8 | # add here your tables or paths with customized sampling logic 9 | "data_sources": { 10 | # 'demo_tours': { 11 | # # use default_sampling=True to rely on automatic sampling otherwise it will use the hole data 12 | # 'default_sampling': True, 13 | # }, 14 | # 'demo_tours2': { 15 | # 'source': lambda spark: spark.table('demo_tours'), 16 | # # to customize the sampling logic 17 | # 'filter': lambda df: df.limit(500) 18 | # } 19 | }, 20 | # to customize the max size of your examples uncomment the line below 21 | # "data_source_size_limit_gb": 3 22 | # to customize the location of your test datasets in your data wharehouse 23 | # "snapshot_path": "dbfs:/another_databricks_path", 24 | } 25 | 26 | # initialize the application and validate the configuration 27 | ddataflow = DDataflow(**config) 28 | """ 29 | 30 | with open(config_file, "w") as f: 31 | f.write(content) 32 | print(f"File {config_file} created in the current directory.") 33 | -------------------------------------------------------------------------------- /ddataflow/setup/ddataflow_config.py: -------------------------------------------------------------------------------- 1 | from ddataflow import DDataflow 2 | 3 | config = { 4 | # change the name of the project for a meaningful name (github project name usually a good idea) 5 | "project_folder_name": "your_project_name", 6 | # add here your tables or paths with customized sampling logic 7 | "data_sources": { 8 | # "demo_tours": { 9 | # # use default_sampling=True to rely on automatic sampling otherwise it will use the whole data 10 | # "default_sampling": True, 11 | # }, 12 | # "demo_tours2": { 13 | # "source": lambda spark: spark.table("demo_tours"), 14 | # # to customize the sampling logic 15 | # "filter": lambda df: df.limit(500), 16 | # }, 17 | # "/mnt/cleaned/EventName": { 18 | # "file-type": "parquet", 19 | # "default_sampling": True, 20 | # }, 21 | }, 22 | # "default_sampler": { 23 | # # defines the amount of rows retrieved with the default sampler, used as .limit(limit) in the dataframe 24 | # # default = 10000 25 | # "limit": 100000, 26 | # }, 27 | # to customize the max size of your examples uncomment the line below 28 | # "data_source_size_limit_gb": 3 29 | # to customize the location of your test datasets in your data wharehouse 30 | # "snapshot_path": "dbfs:/another_databricks_path", 31 | } 32 | 33 | # initialize the application and validate the configuration 34 | ddataflow = DDataflow(**config) 35 | -------------------------------------------------------------------------------- 
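The template above only covers `data_sources`. Data writers can be declared in the same configuration dictionary; the sketch below is illustrative (the `demo_output` name and the `saveAsTable` target are placeholders) and follows the `writer(df, name, spark)` call signature that `DDataflow.write` uses when ddataflow is disabled:

```py
# illustrative config sketch; "demo_output" is a placeholder writer name
from ddataflow import DDataflow

config = {
    "project_folder_name": "my_project",
    "data_sources": {
        "demo_tours": {"default_sampling": True},
    },
    "data_writers": {
        "demo_output": {
            # called as writer(df, name, spark) only when ddataflow is disabled;
            # when enabled, ddataflow.write(df, "demo_output") keeps a parquet copy
            # in the snapshot (or local) folder instead of writing to production
            "writer": lambda df, name, spark: df.write.mode("overwrite").saveAsTable(name),
        },
    },
}

ddataflow = DDataflow(**config)
```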
/ddataflow/setup/setup_project.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from ddataflow import DDataflow 5 | 6 | 7 | def setup_project(): 8 | config_file = DDataflow._DDATAFLOW_CONFIG_FILE 9 | path = os.path.dirname(os.path.realpath(__file__)) + "/ddataflow_config.py" 10 | 11 | shutil.copyfile(path, config_file) 12 | print(f"File {config_file} created in the current directory.") 13 | -------------------------------------------------------------------------------- /ddataflow/utils.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import SparkSession 2 | 3 | 4 | def get_or_create_spark(): 5 | return SparkSingleton.get_instance() 6 | 7 | 8 | class SparkSingleton: 9 | spark = None 10 | 11 | @staticmethod 12 | def get_instance(): 13 | if SparkSingleton.spark is None: 14 | SparkSingleton.spark = SparkSession.builder.getOrCreate() 15 | 16 | return SparkSingleton.spark 17 | 18 | 19 | def summarize_spark_dataframe(spark_dataframe): 20 | spark_dataframe.show(3) 21 | result = { 22 | "num_of_rows": spark_dataframe.count(), 23 | "num_of_columns": len(spark_dataframe.columns), 24 | "estimated_size": estimate_spark_dataframe_size( 25 | spark_dataframe=spark_dataframe 26 | ), 27 | } 28 | return result 29 | 30 | 31 | def is_local(spark): 32 | setting = spark.conf.get("spark.master") 33 | return "local" in setting 34 | 35 | 36 | def using_databricks_connect() -> bool: 37 | """ 38 | Return true if databricks connect is being used 39 | """ 40 | import os 41 | 42 | result = os.system("pip show databricks-connect") 43 | 44 | return result == 0 45 | -------------------------------------------------------------------------------- /ddataflow_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import fire 3 | 4 | from ddataflow.ddataflow import main 5 | 6 | if __name__ == "__main__": 7 | main() 8 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | 4 | ## I am trying to download data but the system complains that my databricks CLI is not configured 5 | 6 | After installing ddataflow, run the configuration procedure on the machine where you installed it: 7 | 8 | ```sh 9 | databricks configure --token 10 | ``` 11 | 12 | Follow the wizard until the end.
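If you want to confirm the CLI is wired up before retrying the download, listing the DBFS root is a quick sanity check (a suggestion assuming the legacy `databricks-cli` that ddataflow depends on, not a ddataflow command):

```sh
databricks fs ls dbfs:/
```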
13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /docs/api_reference/DDataflow.md: -------------------------------------------------------------------------------- 1 | ::: ddataflow.ddataflow.DDataflow -------------------------------------------------------------------------------- /docs/api_reference/DataSource.md: -------------------------------------------------------------------------------- 1 | ::: ddataflow.data_source -------------------------------------------------------------------------------- /docs/api_reference/DataSourceDownloader.md: -------------------------------------------------------------------------------- 1 | ::: ddataflow.downloader -------------------------------------------------------------------------------- /docs/api_reference/DataSources.md: -------------------------------------------------------------------------------- 1 | ::: ddataflow.data_sources -------------------------------------------------------------------------------- /docs/ddataflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/getyourguide/DDataFlow/7d7ed8e3fa9e4a51ef85afa50c22de3e60779056/docs/ddataflow.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Home 2 | 3 | DDataFlow is an end-to-end testing and local development solution for machine learning and data pipelines built with PySpark. 4 | 5 | It allows you to: 6 | - Read a subset of your data to speed up pipeline runs during tests 7 | - Write your artifacts to a test location so you don't pollute production 8 | - Download data to enable development on your local machine 9 | 10 | Below is the DDataFlow integration manual. 11 | If you want to know how to use DDataFlow on your local machine, jump to [this section](local_development.md). 12 | 13 | ## Install DDataflow 14 | 15 | ```sh 16 | pip install ddataflow 17 | ``` 18 | 19 | ## Mapping your data sources 20 | 21 | DDataflow is declarative and completely configurable through a single configuration passed at DDataflow startup. To create a configuration for your project, simply run: 22 | 23 | ```shell 24 | ddataflow setup_project 25 | ``` 26 | 27 | You can also use this config in a notebook, with databricks-connect, or in the repository with db-rocket. Example config below: 28 | 29 | ```python 30 | # later save this script as ddataflow_config.py to follow our convention 31 | from ddataflow import DDataflow 32 | import pyspark.sql.functions as F 33 | from datetime import datetime, timedelta 34 | start_time = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d") 35 | end_time = datetime.now().strftime("%Y-%m-%d") 36 | 37 | config = { 38 | "data_sources": { 39 | # data sources define how to access data 40 | "events": { 41 | "source": lambda spark: spark.table("events"), 42 | # here we define the spark query to reduce the size of the data 43 | # the filtering strategy will most likely depend on the domain.
44 | "filter": lambda df: 45 | df.filter(F.col("date") >= start_time) 46 | .filter(F.col("date") <= end_time) 47 | .filter(F.col("event_name").isin(["BookAction", "ActivityCardImpression"])), 48 | }, 49 | "ActivityCardImpression": { 50 | # sources can also be parquet files 51 | "source": lambda spark: spark.read.parquet( 52 | f"dbfs:/events/eventname/date={start_time}/" 53 | ) 54 | }, 55 | }, 56 | "project_folder_name": "myproject", 57 | } 58 | 59 | # initialize the application and validate the configuration 60 | ddataflow_client = DDataflow(**config) 61 | ``` 62 | 63 | ## Replace the sources 64 | 65 | In your code, replace the calls to the original data sources with the ones provided by ddataflow. 66 | 67 | ```py 68 | spark.table('events') #... 69 | spark.read.parquet("dbfs:/mnt/analytics/cleaned/v1/ActivityCardImpression") # ... 70 | ``` 71 | 72 | Replace with the following: 73 | 74 | ```py 75 | from ddataflow_config import ddataflow_client 76 | 77 | ddataflow_client.source('events') 78 | ddataflow_client.source("ActivityCardImpression") 79 | ``` 80 | 81 | It's not a problem if you don't map all data sources; any source you don't map will keep going to the production tables and 82 | might be slower. From this point on you can use ddataflow to run your pipelines on the sample data instead of the full data. 83 | 84 | **Note: BY DEFAULT ddataflow is DISABLED, so the calls will go to production, which, if done wrong, can 85 | lead to writing trash data**. 86 | 87 | To enable DDataFlow, you can export an environment variable without changing the code: 88 | 89 | ```shell 90 | # in shell or in the CICD pipeline 91 | export ENABLE_DDATAFLOW=true 92 | # run your pipeline as normal 93 | python conduction_time_predictor/train.py 94 | ``` 95 | 96 | Or you can enable it programmatically in Python: 97 | 98 | ```py 99 | ddataflow_client.enable() 100 | ``` 101 | 102 | At any point in time you can check whether the tool is enabled or disabled by running: 103 | 104 | ```py 105 | ddataflow_client.print_status() 106 | ``` 107 | 108 | ## Writing data 109 | 110 | To write data, we advise you to use the same code as in production and just write to a different destination. 111 | DDataflow provides the `path` function, which returns a staging path when ddataflow is enabled. 112 | 113 | ```py 114 | final_path = ddataflow.path('/mnt/my/production/path') 115 | # final_path=/mnt/my/production/path when ddataflow is DISABLED 116 | # final_path=$DDATAFLOW_FOLDER/project_name/mnt/my/production/path when ddataflow is ENABLED 117 | ``` 118 | 119 | And you are good to go! -------------------------------------------------------------------------------- /docs/local_development.md: -------------------------------------------------------------------------------- 1 | # Local Development 2 | 3 | DDataflow also enables you to develop with local data. We see this as a more advanced use case, though, and it might not be 4 | the first choice for everybody. First, make a copy in DBFS of the files you need to download: 5 | 6 | ```py 7 | ddataflow.save_sampled_data_sources(ask_confirmation=False) 8 | ``` 9 | 10 | Then, on your machine: 11 | 12 | ```sh 13 | ddataflow current_project download_data_sources 14 | ``` 15 | 16 | Now you can run the pipeline locally by exporting the following environment variable: 17 | 18 | ```shell 19 | export ENABLE_OFFLINE_MODE=true 20 | # run your pipeline as normal 21 | python yourproject/train.py 22 | ``` 23 | 24 | The downloaded data sources will be stored at `$HOME/.ddataflow`.
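Inside the pipeline itself nothing needs to change: with offline mode enabled, the same source calls resolve to the downloaded copies. A minimal sketch, assuming `demo_tours` is one of the data sources mapped in your `ddataflow_config.py` (the name is illustrative):

```py
# with ENABLE_OFFLINE_MODE=true (or ddataflow.enable_offline()), source() reads the
# downloaded sample from $HOME/.ddataflow instead of the data warehouse
from ddataflow_config import ddataflow

df = ddataflow.source("demo_tours")  # "demo_tours" is a hypothetical mapped source name
df.show(3)
```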
25 | 26 | ## Local setup for Spark 27 | 28 | If you run Spark locally, you might need to tweak some parameters compared to your cluster. Below is a good example you can use. 29 | 30 | ```py 31 | def configure_spark(): 32 | 33 | if ddataflow.is_local(): 34 | import pyspark 35 | 36 | spark_conf = pyspark.SparkConf() 37 | spark_conf.set("spark.sql.warehouse.dir", "/tmp") 38 | spark_conf.set("spark.sql.catalogImplementation", "hive") 39 | spark_conf.set("spark.driver.memory", "15g") 40 | spark_conf.setMaster("local[*]") 41 | sc = pyspark.SparkContext(conf=spark_conf) 42 | session = pyspark.sql.SparkSession(sc) 43 | 44 | return session 45 | 46 | return SparkSession.builder.getOrCreate() 47 | ``` 48 | 49 | If you run into a Snappy compression problem, reinstall pyspark. 50 | -------------------------------------------------------------------------------- /docs/sampling.md: -------------------------------------------------------------------------------- 1 | ## Sampling on the notebook 2 | 3 | Add the following to your setup.py: 4 | 5 | ```py 6 | # given that the ddataflow config usually sits on the root of the project 7 | # we add it to the package data manually if we want to access the config 8 | # installed as a library 9 | py_modules=[ 10 | "ddataflow_config", 11 | ], 12 | ``` 13 | 14 | ## With DBrocket 15 | 16 | Cell 1 17 | 18 | ```sh 19 | %pip install --upgrade pip 20 | %pip install ddataflow 21 | %pip install /dbfs/temp/user/search_ranking_pipeline-1.0.1-py3-none-any.whl --force-reinstall 22 | ``` 23 | 24 | Cell 2 25 | 26 | ```py 27 | from ddataflow_config import ddataflow 28 | ddataflow.save_sampled_data_sources() 29 | ``` 30 | 31 | Then use `dry_run=False` when you are ready to copy. 32 | -------------------------------------------------------------------------------- /docs/troubleshooting.md: -------------------------------------------------------------------------------- 1 | 2 | One drawback of having the ddataflow config in the root folder is that it can conflict with other ddataflow installations. 3 | Prefer installing ddataflow in submodules of your main project (`myproject/main_module/ddataflow_config.py`) instead of globally (`myproject/ddataflow_config.py`).
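With that layout the import also stays explicit about which project's config is in use; a minimal sketch (the module path here is the hypothetical one from the example above):

```py
# hypothetical layout: myproject/main_module/ddataflow_config.py
from main_module.ddataflow_config import ddataflow

ddataflow.print_status()  # check whether ddataflow is enabled or disabled
```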
4 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/getyourguide/DDataFlow/7d7ed8e3fa9e4a51ef85afa50c22de3e60779056/examples/__init__.py -------------------------------------------------------------------------------- /examples/ddataflow_config.py: -------------------------------------------------------------------------------- 1 | from ddataflow import DDataflow 2 | 3 | config = { 4 | # add here your tables or paths with customized sampling logic 5 | "data_sources": { 6 | "demo_tours": { 7 | "source": lambda spark: spark.table("demo_tours"), 8 | "filter": lambda df: df.limit(500), 9 | }, 10 | "demo_locations": { 11 | "source": lambda spark: spark.table("demo_locations"), 12 | "default_sampling": True, 13 | }, 14 | }, 15 | "project_folder_name": "ddataflow_demo", 16 | } 17 | 18 | # initialize the application and validate the configuration 19 | ddataflow = DDataflow(**config) 20 | -------------------------------------------------------------------------------- /examples/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import randrange 3 | 4 | from ddataflow_config import ddataflow 5 | from pyspark.sql.session import SparkSession 6 | 7 | spark = SparkSession.builder.getOrCreate() 8 | 9 | 10 | def create_data(): 11 | locations = [ 12 | {"location_id": i, "location_name": f"Location {i}"} for i in range(2000) 13 | ] 14 | 15 | locations_df = spark.createDataFrame(locations) 16 | locations_df.write.parquet("/tmp/demo_locations.parquet") 17 | 18 | tours = [ 19 | {"tour_id": i, "tour_name": f"Tour {i}", "location_id": randrange(2000)} 20 | for i in range(50000) 21 | ] 22 | tours_df = spark.createDataFrame(tours) 23 | tours_df.write.parquet("/tmp/demo_tours.parquet") 24 | 25 | 26 | def pipeline(): 27 | spark = SparkSession.builder.getOrCreate() 28 | 29 | # if we are dealing with offline data we dont need to register anything as ddataflow will take care of it for us 30 | if "ENABLE_OFFLINE_MODE" not in os.environ: 31 | spark.read.parquet("/tmp/demo_locations.parquet").registerTempTable( 32 | "demo_locations" 33 | ) 34 | spark.read.parquet("/tmp/demo_tours.parquet").registerTempTable("demo_tours") 35 | 36 | # pyspark code using a different source name 37 | total_locations = spark.table(ddataflow.name("demo_locations")).count() 38 | # sql code also works 39 | total_tours = spark.sql( 40 | f""" SELECT COUNT(1) from {ddataflow.name('demo_tours')}""" 41 | ).collect()[0]["count(1)"] 42 | return { 43 | "total_locations": total_locations, 44 | "total_tours": total_tours, 45 | } 46 | 47 | 48 | def run_scenarios(): 49 | create_data() 50 | ddataflow.disable() 51 | result = pipeline() 52 | assert result["total_tours"] == 50000 53 | 54 | ddataflow.enable() 55 | result = pipeline() 56 | print(result) 57 | assert result["total_tours"] == 500 58 | 59 | 60 | if __name__ == "__main__": 61 | import fire 62 | 63 | fire.Fire() 64 | -------------------------------------------------------------------------------- /html/ddataflow.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ddataflow API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 47 |
48 |
49 |

50 | ddataflow

51 | 52 | 53 | 54 | 55 | 56 | 57 |
1# flake8: noqa
 58 | 2# import  DDataflow here to make importing easier
 59 | 3from .ddataflow import DDataflow
 60 | 
61 | 62 | 63 |
64 |
65 | 247 | -------------------------------------------------------------------------------- /html/ddataflow/samples.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ddataflow.samples API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 44 |
45 |
46 |

47 | ddataflow.samples

48 | 49 | 50 | 51 | 52 | 53 |
54 |
55 | 237 | -------------------------------------------------------------------------------- /html/ddataflow/sampling.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | ddataflow.sampling API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 45 |
46 |
47 |

48 | ddataflow.sampling

49 | 50 | 51 | 52 | 53 | 54 |
55 |
56 | 238 | -------------------------------------------------------------------------------- /html/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: DDataflow 2 | site_url: https://code.getyourguide.com/DDataFlow/ 3 | repo_url: https://github.com/getyourguide/DDataFlow/ 4 | edit_uri: edit/main/docs/ 5 | 6 | theme: 7 | name: material 8 | icon: 9 | edit: material/pencil 10 | repo: fontawesome/brands/github 11 | features: 12 | - content.action.edit 13 | 14 | 15 | 16 | markdown_extensions: 17 | - pymdownx.superfences 18 | 19 | nav: 20 | - 'index.md' 21 | - 'local_development.md' 22 | - 'sampling.md' 23 | - API Reference: 24 | - 'api_reference/DDataflow.md' 25 | - 'api_reference/DataSource.md' 26 | - 'api_reference/DataSources.md' 27 | - 'api_reference/DataSourceDownloader.md' 28 | - 'troubleshooting.md' 29 | - 'FAQ.md' 30 | 31 | plugins: 32 | - search 33 | - mkdocstrings: 34 | handlers: 35 | # See: https://mkdocstrings.github.io/python/usage/ 36 | python: 37 | options: 38 | docstring_style: sphinx 39 | allow_inspection: true 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "DDataFlow" 3 | version = "1.1.16" 4 | description = "A tool for end2end data tests" 5 | authors = ["Data products GYG "] 6 | readme = "README.md" 7 | repository = "https://github.com/getyourguide/DDataFlow" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.8,<4" 11 | databricks-cli = {version = ">=0.16"} 12 | pyspark = "3.4.1" 13 | fire = ">=0.4" 14 | # this is not a direct dependency of the template, but a dependency of databricks-cli 15 | # that comes from mlflow, if we dont pin it here we get a lower version that has a security problem 16 | oauthlib = ">=3.2.1" 17 | urllib3 = ">=1.24.2" 18 | requests = "^2.23.3" 19 | 20 | [tool.poetry.scripts] 21 | ddataflow = 'ddataflow.ddataflow:main' 22 | 23 | [tool.poetry.dev-dependencies] 24 | pytest = "^6.2" 25 | 26 | [build-system] 27 | requires = ["poetry-core>=1.0.0"] 28 | build-backend = "poetry.core.masonry.api" 29 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if os.environ.get("ENABLE_OFFLINE_MODE"): 4 | raise Exception("Offline mode enabled might confuse unit tests, unset it first.") 5 | 6 | if os.environ.get("ENABLE_DDATAFLOW"): 7 | raise Exception( 8 | "ENABLE_DDATAFLOW env variable might confuse unit tests, unset it first." 
9 | ) 10 | -------------------------------------------------------------------------------- /tests/test_configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from ddataflow import DDataflow 6 | 7 | from datetime import datetime, timedelta 8 | 9 | import pyspark.sql.functions as F 10 | 11 | from ddataflow.ddataflow import DDataflow 12 | 13 | start_time = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d") 14 | end_time = datetime.now().strftime("%Y-%m-%d") 15 | 16 | config = { 17 | "data_sources": { 18 | # data sources define how to access data 19 | "events": { 20 | "source": lambda spark: spark.table("events"), 21 | "filter": lambda df: df.filter(F.col("date") >= start_time) 22 | .filter(F.col("date") <= end_time) 23 | .filter(F.col("event_name").isin(["BookAction", "ActivityCardImpression"])) 24 | .filter(F.hour("kafka_timestamp").between(8, 9)) 25 | .limit(1000), 26 | }, 27 | "ActivityCardImpression": { 28 | "source": lambda spark: spark.read.parquet( 29 | "dbfs:/events/ActivityCardImpression" 30 | ), 31 | "filter": lambda df: df.filter(F.col("date") >= start_time) 32 | .filter(F.col("date") <= end_time) 33 | .filter(F.hour("kafka_timestamp").between(8, 9)) 34 | .limit(1000), 35 | }, 36 | "dim_tour": { 37 | "source": lambda spark: spark.table("tour"), 38 | }, 39 | }, 40 | "project_folder_name": "tests", 41 | "data_source_size_limit_gb": 3, 42 | } 43 | 44 | 45 | def test_initialize_successfully(): 46 | """ 47 | Tests that a correct config will not fail to be instantiated 48 | """ 49 | 50 | DDataflow(**config) 51 | 52 | 53 | def test_wrong_config_fails(): 54 | with pytest.raises(BaseException, match="wrong param"): 55 | 56 | DDataflow(**{**config, **{"a wrong param": "a wrong value"}}) 57 | 58 | 59 | def test_current_project_path(): 60 | """ 61 | Test that varying our environment we get different paths 62 | """ 63 | config = { 64 | "project_folder_name": "my_tests", 65 | } 66 | ddataflow = DDataflow(**config) 67 | # by default do not override 68 | assert ddataflow._get_overriden_arctifacts_current_path() is None 69 | ddataflow.enable() 70 | assert ( 71 | "dbfs:/ddataflow/my_tests" == ddataflow._get_overriden_arctifacts_current_path() 72 | ) 73 | ddataflow.enable_offline() 74 | assert ( 75 | os.getenv("HOME") + "/.ddataflow/my_tests" 76 | == ddataflow._get_overriden_arctifacts_current_path() 77 | ) 78 | 79 | 80 | def test_temp_table_name(): 81 | 82 | config = { 83 | "sources_with_default_sampling": ["location"], 84 | "project_folder_name": "unit_tests", 85 | } 86 | 87 | ddataflow = DDataflow(**config) 88 | ddataflow.disable() 89 | # by default do not override 90 | assert ddataflow.name("location", disable_view_creation=True) == "location" 91 | ddataflow.enable() 92 | assert ( 93 | ddataflow.name("location", disable_view_creation=True) == "unit_tests_location" 94 | ) 95 | -------------------------------------------------------------------------------- /tests/test_sampling.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from pyspark.sql.session import SparkSession 4 | 5 | from ddataflow import DDataflow 6 | 7 | 8 | def test_sampling_end2end(): 9 | """ 10 | Tests that a correct _config will not fail to be instantiated 11 | """ 12 | spark = SparkSession.builder.getOrCreate() 13 | entries = [ 14 | ["id1", "Sagrada Familila"], 15 | ["id2", "Eiffel Tower"], 16 | ["id3", "abc"], 17 | ] 18 | df = spark.createDataFrame(entries, ["id", "_name"]) 
19 | df.createOrReplaceTempView("location") 20 | 21 | assert spark.table("location").count() == 3 22 | 23 | config = { 24 | "data_sources": { 25 | "location_filtered": { 26 | "source": lambda spark: spark.table("location"), 27 | "filter": lambda df: df.filter(df._name == "abc"), 28 | }, 29 | "location": {}, 30 | }, 31 | "project_folder_name": "unit_tests", 32 | "snapshot_path": "/tmp/ddataflow_test", 33 | } 34 | 35 | ddataflow = DDataflow(**config) 36 | 37 | ddataflow.disable() 38 | ddataflow.disable_offline() 39 | assert ddataflow.source("location").count() == 3 40 | assert ddataflow.source("location_filtered").count() == 3 41 | 42 | ddataflow.enable() 43 | assert ddataflow.source("location").count() == 3 44 | assert ddataflow.source("location_filtered").count() == 1 45 | 46 | ddataflow.save_sampled_data_sources(dry_run=False) 47 | 48 | # after sampling the following destinations have the copy 49 | assert os.path.exists("/tmp/ddataflow_test/unit_tests/location") 50 | assert os.path.exists("/tmp/ddataflow_test/unit_tests/location_filtered") 51 | 52 | 53 | def test_sampling_paths(): 54 | config = { 55 | "data_sources": { 56 | "location_filtered": { 57 | "source": lambda spark: spark.table("location"), 58 | }, 59 | "/mnt/foo/bar": {}, 60 | }, 61 | "project_folder_name": "unit_tests", 62 | "snapshot_path": "/tmp/ddataflow_test", 63 | } 64 | 65 | ddataflow = DDataflow(**config) 66 | assert ( 67 | ddataflow._data_sources.get_data_source( 68 | "location_filtered" 69 | ).get_dbfs_sample_path() 70 | == "/tmp/ddataflow_test/unit_tests/location_filtered" 71 | ) 72 | assert ( 73 | ddataflow._data_sources.get_data_source("/mnt/foo/bar").get_dbfs_sample_path() 74 | == "/tmp/ddataflow_test/unit_tests/_mnt_foo_bar" 75 | ) 76 | -------------------------------------------------------------------------------- /tests/test_sql.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql.session import SparkSession 2 | 3 | from ddataflow import DDataflow 4 | 5 | 6 | def test_sql(): 7 | spark = SparkSession.builder.getOrCreate() 8 | 9 | config = { 10 | "project_folder_name": "unit_tests", 11 | "data_sources": {"location": {"default_sampling": True}}, 12 | "default_sampler": {"limit": 2}, 13 | } 14 | ddataflow = DDataflow(**config) 15 | 16 | query = f""" select count(1) as total 17 | from {ddataflow.source_name('location')} 18 | """ 19 | 20 | ddataflow.disable() 21 | result = spark.sql(query) 22 | # default amount of tours 23 | assert result.collect()[0].total == 3 24 | 25 | ddataflow.enable() 26 | query = f""" select count(1) as total 27 | from {ddataflow.source_name('location')} 28 | """ 29 | result = spark.sql(query) 30 | assert result.collect()[0].total == 2 31 | 32 | 33 | if __name__ == "__main__": 34 | import fire 35 | 36 | fire.Fire() 37 | -------------------------------------------------------------------------------- /tests/tutorial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/getyourguide/DDataFlow/7d7ed8e3fa9e4a51ef85afa50c22de3e60779056/tests/tutorial/__init__.py -------------------------------------------------------------------------------- /tests/tutorial/test_local_dev.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from ddataflow import DDataflow 6 | from ddataflow.utils import using_databricks_connect 7 | 8 | config = { 9 | "sources_with_default_sampling": ["location"], 10 | 
"data_sources": { 11 | # data sources define how to access data 12 | "tour": { 13 | "source": lambda spark: spark.table("tour"), 14 | "filter": lambda df: df.sample(0.1).limit(500), 15 | }, 16 | }, 17 | "project_folder_name": "unit_tests", 18 | "data_source_size_limit_gb": 0.5, 19 | } 20 | 21 | ddataflow = DDataflow(**config) # type: ignore 22 | ddataflow.enable_offline() 23 | 24 | 25 | @pytest.mark.skipif( 26 | not using_databricks_connect(), reason="needs databricks connect to work" 27 | ) 28 | def test_sample_and_download(): 29 | ddataflow.sample_and_download(ask_confirmation=False) 30 | assert os.path.exists(ddataflow._local_path + "/tour") 31 | assert os.path.exists(ddataflow._local_path + "/location") 32 | 33 | # Done! From now on you can use your sources entirely offline 34 | 35 | 36 | @pytest.mark.skipif( 37 | not using_databricks_connect(), reason="needs databricks connect to work" 38 | ) 39 | def test_read_offline_files(): 40 | ddataflow = DDataflow(**config) # type: ignore 41 | ddataflow.enable_offline() 42 | 43 | df = ddataflow.source("tour") 44 | assert df.count() == 500 45 | df.show(3) 46 | 47 | df = ddataflow.source("location") 48 | assert df.count() == 1000 49 | df.show(3) 50 | 51 | 52 | if __name__ == "__main__": 53 | import fire 54 | 55 | fire.Fire() 56 | -------------------------------------------------------------------------------- /tests/tutorial/test_sources_and_writers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pyspark.sql.functions as F 4 | from pyspark.sql.session import SparkSession 5 | 6 | from ddataflow import DDataflow 7 | 8 | data = [ 9 | { 10 | "tour_id": 132, 11 | "tour_title": "Helicopter Tour", 12 | "status": "active", 13 | "location_id": 65, 14 | }, 15 | { 16 | "tour_id": 558, 17 | "tour_title": "Walking Tour", 18 | "location_id": 59, 19 | }, 20 | { 21 | "tour_id": 3009, 22 | "tour_title": "Transfer Hotel", 23 | "location_id": 65, 24 | }, 25 | ] 26 | 27 | 28 | def test_source_filters_when_toggled(): 29 | """ 30 | Creates a datasource 31 | """ 32 | 33 | spark = SparkSession.builder.getOrCreate() 34 | 35 | df = spark.createDataFrame(data) 36 | df.createOrReplaceTempView("tour") 37 | 38 | config = { 39 | "data_sources": { 40 | # data sources define how to access data 41 | "tour": { 42 | "source": lambda spark: spark.table("tour"), 43 | "filter": lambda df: df.filter("location_id == 65"), 44 | }, 45 | }, 46 | "project_folder_name": "unit_tests", 47 | } 48 | 49 | ddataflow = DDataflow(**config) 50 | assert df.count() == 3 51 | 52 | ddataflow.enable() 53 | assert ddataflow.source("tour").count() == 2 54 | 55 | ddataflow.disable() 56 | assert ddataflow.source("tour").count() == 3 57 | 58 | 59 | def test_writer(): 60 | """Test that writing works""" 61 | spark = SparkSession.builder.getOrCreate() 62 | config = { 63 | "data_writers": { 64 | "test_new_dim_tour": { 65 | "writer": lambda df, name, spark: df.registerTempTable(name) 66 | }, 67 | }, 68 | "project_folder_name": "unit_tests", 69 | } 70 | ddataflow = DDataflow(**config) 71 | # to not write in dbfs lets do this operation offline 72 | ddataflow.enable_offline() 73 | 74 | df = spark.createDataFrame(data) 75 | new_df = df.withColumn("new_column", F.lit("a new value")) 76 | 77 | ddataflow.write(new_df, "test_new_dim_tour") 78 | loaded_from_disk = ddataflow.read("test_new_dim_tour") 79 | 80 | assert new_df.columns == loaded_from_disk.columns 81 | 82 | 83 | def test_mlflow(): 84 | """ 85 | Test that varying our environment we get different 
paths 86 | """ 87 | original_model_path = "/Shared/RnR/MLflow/MyModel" 88 | ddataflow = DDataflow( 89 | **{ 90 | "project_folder_name": "my_tests", 91 | } 92 | ) 93 | 94 | # when ddataflow is disabled do not change the path 95 | original_model_path == ddataflow.get_mlflow_path(original_model_path) 96 | 97 | # when it is enabled write do a special place in dbfs 98 | ddataflow.enable() 99 | assert "dbfs:/ddataflow/my_tests/MyModel" == ddataflow.get_mlflow_path( 100 | original_model_path 101 | ) 102 | 103 | # when it is in offline mode write to the local filesystems 104 | ddataflow.enable_offline() 105 | assert os.getenv( 106 | "HOME" 107 | ) + "/.ddataflow/my_tests/MyModel" == ddataflow.get_mlflow_path(original_model_path) 108 | 109 | 110 | if __name__ == "__main__": 111 | import fire 112 | 113 | fire.Fire() 114 | --------------------------------------------------------------------------------