├── .github └── workflows │ ├── ci.yml │ └── docs.yml ├── .gitignore ├── LICENSE ├── README.md ├── contributing ├── DEVELOPMENT.md └── RELEASE.md ├── docs ├── api-reference.md ├── building-data-sources.md ├── data-sources-guide.md └── simple-stream-reader-architecture.md ├── examples └── salesforce_example.py ├── poetry.lock ├── pyproject.toml ├── pyspark_datasources ├── __init__.py ├── arrow.py ├── fake.py ├── github.py ├── googlesheets.py ├── huggingface.py ├── jsonplaceholder.py ├── kaggle.py ├── lance.py ├── opensky.py ├── robinhood.py ├── salesforce.py ├── simplejson.py ├── stock.py └── weather.py └── tests ├── __init__.py ├── test_data_sources.py ├── test_google_sheets.py └── test_robinhood.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main, master ] 6 | pull_request: 7 | branches: [ main, master ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ['3.9', '3.10', '3.11', '3.12'] 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install Poetry 27 | uses: snok/install-poetry@v1.3.4 28 | with: 29 | version: latest 30 | virtualenvs-create: true 31 | virtualenvs-in-project: true 32 | 33 | - name: Load cached venv 34 | id: cached-poetry-dependencies 35 | uses: actions/cache@v4 36 | with: 37 | path: .venv 38 | key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} 39 | 40 | - name: Install dependencies 41 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 42 | run: poetry install --no-interaction --no-root --extras "all" 43 | 44 | - name: Install project 45 | run: poetry install --no-interaction --extras "all" 46 | 47 | - name: Run tests 48 | run: poetry run pytest tests/ -v 49 | 50 | - name: Run tests with coverage 51 | run: | 52 | poetry run pytest tests/ --cov=pyspark_datasources --cov-report=xml --cov-report=term-missing 53 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy MkDocs to GitHub Pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - 'docs/**' 10 | - 'mkdocs.yml' 11 | - 'pyproject.toml' 12 | - '.github/workflows/docs.yml' 13 | workflow_dispatch: # Allow manual triggering 14 | 15 | permissions: 16 | contents: read 17 | pages: write 18 | id-token: write 19 | 20 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 21 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
22 | concurrency: 23 | group: "pages" 24 | cancel-in-progress: false 25 | 26 | jobs: 27 | build: 28 | runs-on: ubuntu-latest 29 | steps: 30 | - name: Checkout 31 | uses: actions/checkout@v4 32 | 33 | - name: Set up Python 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: '3.11' 37 | 38 | - name: Install Poetry 39 | run: | 40 | curl -sSL https://install.python-poetry.org | python - -y 41 | echo "$HOME/.local/bin" >> $GITHUB_PATH 42 | 43 | - name: Configure Poetry 44 | run: | 45 | poetry config virtualenvs.create true 46 | poetry config virtualenvs.in-project true 47 | 48 | - name: Load cached venv 49 | id: cached-poetry-dependencies 50 | uses: actions/cache@v4 51 | with: 52 | path: .venv 53 | key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 54 | 55 | - name: Install dependencies 56 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 57 | run: poetry install --no-interaction --no-root 58 | 59 | - name: Install project 60 | run: poetry install --no-interaction 61 | 62 | - name: Build MkDocs 63 | run: poetry run mkdocs build 64 | 65 | - name: Setup Pages 66 | uses: actions/configure-pages@v5 67 | 68 | - name: Upload artifact 69 | uses: actions/upload-pages-artifact@v3 70 | with: 71 | path: ./site 72 | 73 | deploy: 74 | environment: 75 | name: github-pages 76 | url: ${{ steps.deployment.outputs.page_url }} 77 | runs-on: ubuntu-latest 78 | needs: build 79 | steps: 80 | - name: Deploy to GitHub Pages 81 | id: deployment 82 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # Claude Code 163 | .claude/ 164 | claude_cache/ 165 | 166 | # Gemini 167 | .gemini/ 168 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PySpark Data Sources 2 | 3 | [![pypi](https://img.shields.io/pypi/v/pyspark-data-sources.svg?color=blue)](https://pypi.org/project/pyspark-data-sources/) 4 | [![code style: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 5 | 6 | Custom Apache Spark data sources using the [Python Data Source API](https://spark.apache.org/docs/4.0.0/api/python/tutorial/sql/python_data_source.html) (Spark 4.0+). Learn by example and build your own data sources. 7 | 8 | ## Quick Start 9 | 10 | ### Installation 11 | 12 | ```bash 13 | pip install pyspark-data-sources 14 | 15 | # Install with specific extras 16 | pip install pyspark-data-sources[faker] # For FakeDataSource 17 | 18 | pip install pyspark-data-sources[all] # All optional dependencies 19 | ``` 20 | 21 | ### Requirements 22 | - Apache Spark 4.0+ or [Databricks Runtime 15.4 LTS](https://docs.databricks.com/aws/en/release-notes/runtime/15.4lts)+ 23 | - Python 3.9-3.12 24 | 25 | ### Basic Usage 26 | 27 | ```python 28 | from pyspark.sql import SparkSession 29 | from pyspark_datasources import FakeDataSource 30 | 31 | # Create Spark session 32 | spark = SparkSession.builder.appName("datasource-demo").getOrCreate() 33 | 34 | # Register the data source 35 | spark.dataSource.register(FakeDataSource) 36 | 37 | # Read batch data 38 | df = spark.read.format("fake").option("numRows", 5).load() 39 | df.show() 40 | # +--------------+----------+-------+------------+ 41 | # | name| date|zipcode| state| 42 | # +--------------+----------+-------+------------+ 43 | # | Pam Mitchell|1988-10-20| 23788| Tennessee| 44 | # |Melissa Turner|1996-06-14| 30851| Nevada| 45 | # | Brian Ramsey|2021-08-21| 55277| Washington| 46 | # | Caitlin Reed|1983-06-22| 89813|Pennsylvania| 47 | # | Douglas James|2007-01-18| 46226| Alabama| 48 | # +--------------+----------+-------+------------+ 49 | 50 | # Stream data 51 | stream = spark.readStream.format("fake").load() 52 | query = stream.writeStream.format("console").start() 53 | ``` 54 | 55 | ## Available Data Sources 56 | 57 | | Data Source | Type | Description | Install | 58 | |-------------|------|-------------|---------| 59 | | `fake` | Batch/Stream | Generate synthetic test data using Faker | `pip install pyspark-data-sources[faker]` | 60 | | `github` | Batch | Read GitHub pull requests | Built-in | 61 | | `googlesheets` | Batch | Read public Google Sheets | Built-in | 62 | | `huggingface` | Batch | Load Hugging Face datasets | `[huggingface]` | 63 | | `stock` | Batch | Fetch stock market data (Alpha Vantage) | Built-in | 64 | | `opensky` | Batch/Stream | Live flight tracking data | Built-in | 65 | | `kaggle` | Batch | Load Kaggle datasets | `[kaggle]` | 66 | | `arrow` | Batch | Read Apache Arrow files | `[arrow]` | 67 | | `lance` | Batch Write | Write Lance vector format | `[lance]` | 68 | 69 | 📚 **[See detailed examples for all data sources →](docs/data-sources-guide.md)** 70 | 71 | ## Example: Generate Fake Data 72 | 73 | ```python 74 | from pyspark_datasources import FakeDataSource 75 | 76 | spark.dataSource.register(FakeDataSource) 77 | 78 | # Generate synthetic data with custom schema 79 | df = spark.read.format("fake") \ 80 | .schema("name string, email string, company string") \ 81 | .option("numRows", 5) \ 82 | .load() 83 | 84 | 
df.show(truncate=False) 85 | # +------------------+-------------------------+-----------------+ 86 | # |name |email |company | 87 | # +------------------+-------------------------+-----------------+ 88 | # |Christine Sampson |johnsonjeremy@example.com|Hernandez-Nguyen | 89 | # |Yolanda Brown |williamlowe@example.net |Miller-Hernandez | 90 | # +------------------+-------------------------+-----------------+ 91 | ``` 92 | 93 | ## Building Your Own Data Source 94 | 95 | Here's a minimal example to get started: 96 | 97 | ```python 98 | from pyspark.sql.datasource import DataSource, DataSourceReader 99 | from pyspark.sql.types import StructType, StructField, StringType, IntegerType 100 | 101 | class MyCustomDataSource(DataSource): 102 | def name(self): 103 | return "mycustom" 104 | 105 | def schema(self): 106 | return StructType([ 107 | StructField("id", IntegerType()), 108 | StructField("name", StringType()) 109 | ]) 110 | 111 | def reader(self, schema): 112 | return MyCustomReader(self.options, schema) 113 | 114 | class MyCustomReader(DataSourceReader): 115 | def __init__(self, options, schema): 116 | self.options = options 117 | self.schema = schema 118 | 119 | def read(self, partition): 120 | # Your data reading logic here 121 | for i in range(10): 122 | yield (i, f"name_{i}") 123 | 124 | # Register and use 125 | spark.dataSource.register(MyCustomDataSource) 126 | df = spark.read.format("mycustom").load() 127 | ``` 128 | 129 | 📖 **[Complete guide with advanced patterns →](docs/building-data-sources.md)** 130 | 131 | ## Documentation 132 | 133 | - 📚 **[Data Sources Guide](docs/data-sources-guide.md)** - Detailed examples for each data source 134 | - 🔧 **[Building Data Sources](docs/building-data-sources.md)** - Complete tutorial with advanced patterns 135 | - 📖 **[API Reference](docs/api-reference.md)** - Full API specification and method signatures 136 | - 💻 **[Development Guide](contributing/DEVELOPMENT.md)** - Contributing and development setup 137 | 138 | ## Requirements 139 | 140 | - Apache Spark 4.0+ or Databricks Runtime 15.4 LTS+ 141 | - Python 3.9-3.12 142 | 143 | ## Contributing 144 | 145 | We welcome contributions! See our [Development Guide](contributing/DEVELOPMENT.md) for details. 146 | 147 | ## Resources 148 | 149 | - [Python Data Source API Documentation](https://spark.apache.org/docs/4.0.0/api/python/tutorial/sql/python_data_source.html) 150 | - [API Source Code](https://github.com/apache/spark/blob/master/python/pyspark/sql/datasource.py) 151 | -------------------------------------------------------------------------------- /contributing/DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Development Guide 2 | 3 | ## Environment Setup 4 | 5 | ### Prerequisites 6 | - Python 3.9-3.12 7 | - Poetry for dependency management 8 | - Apache Spark 4.0+ (or Databricks Runtime 15.4 LTS+) 9 | 10 | ### Installation 11 | 12 | ```bash 13 | # Clone the repository 14 | git clone https://github.com/allisonwang-db/pyspark-data-sources.git 15 | cd pyspark-data-sources 16 | 17 | # Install dependencies 18 | poetry install 19 | 20 | # Install with all optional dependencies 21 | poetry install --extras all 22 | 23 | # Activate virtual environment 24 | poetry shell 25 | ``` 26 | 27 | ### macOS Setup 28 | On macOS, you may encounter fork safety issues with PyArrow. 
Set this environment variable: 29 | 30 | ```bash 31 | export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES 32 | 33 | # Add to your shell profile for persistence 34 | echo 'export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES' >> ~/.zshrc 35 | ``` 36 | 37 | ## Testing 38 | 39 | ### Running Tests 40 | 41 | ```bash 42 | # Run all tests 43 | pytest 44 | 45 | # Run specific test file 46 | pytest tests/test_data_sources.py 47 | 48 | # Run specific test with verbose output 49 | pytest tests/test_data_sources.py::test_fake_datasource -v 50 | 51 | # Run with coverage 52 | pytest --cov=pyspark_datasources --cov-report=html 53 | ``` 54 | 55 | ### Writing Tests 56 | 57 | Tests follow this pattern: 58 | 59 | ```python 60 | import pytest 61 | from pyspark.sql import SparkSession 62 | 63 | @pytest.fixture 64 | def spark(): 65 | return SparkSession.builder \ 66 | .appName("test") \ 67 | .config("spark.sql.shuffle.partitions", "1") \ 68 | .getOrCreate() 69 | 70 | def test_my_datasource(spark): 71 | from pyspark_datasources import MyDataSource 72 | spark.dataSource.register(MyDataSource) 73 | 74 | df = spark.read.format("myformat").load() 75 | assert df.count() > 0 76 | assert len(df.columns) == expected_columns 77 | ``` 78 | 79 | ## Code Quality 80 | 81 | ### Formatting with Ruff 82 | 83 | This project uses [Ruff](https://github.com/astral-sh/ruff) for code formatting and linting. 84 | 85 | ```bash 86 | # Format code 87 | poetry run ruff format . 88 | 89 | # Run linter 90 | poetry run ruff check . 91 | 92 | # Run linter with auto-fix 93 | poetry run ruff check . --fix 94 | 95 | # Check specific file 96 | poetry run ruff check pyspark_datasources/fake.py 97 | ``` 98 | 99 | ### Pre-commit Hooks (Optional) 100 | 101 | ```bash 102 | # Install pre-commit hooks 103 | poetry add --group dev pre-commit 104 | pre-commit install 105 | 106 | # Run manually 107 | pre-commit run --all-files 108 | ``` 109 | 110 | ## Documentation 111 | 112 | ### Building Documentation 113 | 114 | The project previously used MkDocs for documentation. Documentation now lives primarily in: 115 | - README.md - Main documentation 116 | - Docstrings in source code 117 | - Contributing guides in /contributing 118 | 119 | ### Writing Docstrings 120 | 121 | Follow this pattern for data source docstrings: 122 | 123 | ```python 124 | class MyDataSource(DataSource): 125 | """ 126 | Brief description of the data source. 127 | 128 | Longer description explaining what it does and any important details. 
129 | 130 | Name: `myformat` 131 | 132 | Options 133 | ------- 134 | option1 : str, optional 135 | Description of option1 (default: "value") 136 | option2 : int, required 137 | Description of option2 138 | 139 | Examples 140 | -------- 141 | Register and use the data source: 142 | 143 | >>> from pyspark_datasources import MyDataSource 144 | >>> spark.dataSource.register(MyDataSource) 145 | >>> df = spark.read.format("myformat").option("option2", 100).load() 146 | >>> df.show() 147 | +---+-----+ 148 | | id|value| 149 | +---+-----+ 150 | | 1| foo| 151 | | 2| bar| 152 | +---+-----+ 153 | 154 | >>> df.printSchema() 155 | root 156 | |-- id: integer (nullable = true) 157 | |-- value: string (nullable = true) 158 | """ 159 | ``` 160 | 161 | ## Adding New Data Sources 162 | 163 | ### Step 1: Create the Data Source File 164 | 165 | Create a new file in `pyspark_datasources/`: 166 | 167 | ```python 168 | # pyspark_datasources/mynewsource.py 169 | from pyspark.sql.datasource import DataSource, DataSourceReader 170 | from pyspark.sql.types import StructType, StructField, StringType 171 | 172 | class MyNewDataSource(DataSource): 173 | def name(self): 174 | return "mynewformat" 175 | 176 | def schema(self): 177 | return StructType([ 178 | StructField("field1", StringType(), True), 179 | StructField("field2", StringType(), True) 180 | ]) 181 | 182 | def reader(self, schema): 183 | return MyNewReader(self.options, schema) 184 | 185 | class MyNewReader(DataSourceReader): 186 | def __init__(self, options, schema): 187 | self.options = options 188 | self.schema = schema 189 | 190 | def read(self, partition): 191 | # Implement reading logic 192 | for i in range(10): 193 | yield ("value1", "value2") 194 | ``` 195 | 196 | ### Step 2: Add to __init__.py 197 | 198 | ```python 199 | # pyspark_datasources/__init__.py 200 | from pyspark_datasources.mynewsource import MyNewDataSource 201 | 202 | __all__ = [ 203 | # ... existing exports ... 204 | "MyNewDataSource", 205 | ] 206 | ``` 207 | 208 | ### Step 3: Add Tests 209 | 210 | ```python 211 | # tests/test_data_sources.py 212 | def test_mynew_datasource(spark): 213 | from pyspark_datasources import MyNewDataSource 214 | spark.dataSource.register(MyNewDataSource) 215 | 216 | df = spark.read.format("mynewformat").load() 217 | assert df.count() == 10 218 | assert df.columns == ["field1", "field2"] 219 | ``` 220 | 221 | ### Step 4: Update Documentation 222 | 223 | Add your data source to the table in README.md with examples. 224 | 225 | ## Package Management 226 | 227 | ### Adding Dependencies 228 | 229 | ```bash 230 | # Add required dependency 231 | poetry add requests 232 | 233 | # Add optional dependency 234 | poetry add --optional faker 235 | 236 | # Add dev dependency 237 | poetry add --group dev pytest-cov 238 | 239 | # Update dependencies 240 | poetry update 241 | ``` 242 | 243 | ### Managing Extras 244 | 245 | Edit `pyproject.toml` to add optional dependency groups: 246 | 247 | ```toml 248 | [tool.poetry.extras] 249 | mynewsource = ["special-library"] 250 | all = ["faker", "datasets", "special-library", ...] 
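# Consumers can then opt in with, e.g., `pip install pyspark-data-sources[mynewsource]`
# (hypothetical extra name taken from the example entry above).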
251 | ``` 252 | 253 | ## Debugging 254 | 255 | ### Enable Spark Logging 256 | 257 | ```python 258 | import logging 259 | logging.basicConfig(level=logging.INFO) 260 | 261 | # Or in Spark config 262 | spark = SparkSession.builder \ 263 | .config("spark.sql.execution.arrow.pyspark.enabled", "true") \ 264 | .config("spark.log.level", "INFO") \ 265 | .getOrCreate() 266 | ``` 267 | 268 | ### Debug Data Source 269 | 270 | ```python 271 | class DebugReader(DataSourceReader): 272 | def read(self, partition): 273 | print(f"Reading partition: {partition.value if hasattr(partition, 'value') else partition}") 274 | print(f"Options: {self.options}") 275 | 276 | # Your reading logic 277 | for row in data: 278 | print(f"Yielding: {row}") 279 | yield row 280 | ``` 281 | 282 | ### Common Issues 283 | 284 | 1. **Serialization errors**: Ensure all class attributes are pickle-able 285 | 2. **Schema mismatch**: Verify returned data matches declared schema 286 | 3. **Missing dependencies**: Use try/except to provide helpful error messages 287 | 4. **API rate limits**: Implement backoff and retry logic 288 | 289 | ## Performance Optimization 290 | 291 | ### Use Partitioning 292 | 293 | ```python 294 | class OptimizedReader(DataSourceReader): 295 | def partitions(self): 296 | # Split work into multiple partitions 297 | num_partitions = int(self.options.get("numPartitions", "4")) 298 | return [InputPartition(i) for i in range(num_partitions)] 299 | 300 | def read(self, partition): 301 | # Each partition processes its subset 302 | partition_id = partition.value 303 | # Process only this partition's data 304 | ``` 305 | 306 | ### Use Arrow Format 307 | 308 | ```python 309 | import pyarrow as pa 310 | 311 | class ArrowOptimizedReader(DataSourceReader): 312 | def read(self, partition): 313 | # Return pyarrow.RecordBatch for better performance 314 | arrays = [ 315 | pa.array(["value1", "value2"]), 316 | pa.array([1, 2]) 317 | ] 318 | batch = pa.RecordBatch.from_arrays(arrays, names=["col1", "col2"]) 319 | yield batch 320 | ``` 321 | 322 | ## Continuous Integration 323 | 324 | The project uses GitHub Actions for CI/CD. Workflows are defined in `.github/workflows/`. 325 | 326 | ### Running CI Locally 327 | 328 | ```bash 329 | # Install act (GitHub Actions locally) 330 | brew install act # macOS 331 | 332 | # Run workflows 333 | act -j test 334 | ``` 335 | 336 | ## Troubleshooting 337 | 338 | ### PyArrow Issues on macOS 339 | 340 | ```bash 341 | # Set environment variable 342 | export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES 343 | 344 | # Or in Python 345 | import os 346 | os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES" 347 | ``` 348 | 349 | ### Poetry Issues 350 | 351 | ```bash 352 | # Clear cache 353 | poetry cache clear pypi --all 354 | 355 | # Update lock file 356 | poetry lock --no-update 357 | 358 | # Reinstall 359 | poetry install --remove-untracked 360 | ``` 361 | 362 | ### Spark Session Issues 363 | 364 | ```python 365 | # Stop existing session 366 | from pyspark.sql import SparkSession 367 | spark = SparkSession.getActiveSession() 368 | if spark: 369 | spark.stop() 370 | 371 | # Create new session 372 | spark = SparkSession.builder.appName("debug").getOrCreate() 373 | ``` -------------------------------------------------------------------------------- /contributing/RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release Workflow 2 | 3 | This document outlines the steps to create a new release for the `pyspark-data-sources` project. 
4 | 5 | ## Prerequisites 6 | 7 | - Ensure you have Poetry installed 8 | - Ensure you have GitHub CLI installed (optional, for enhanced releases) 9 | - Ensure you have push access to the repository 10 | - Ensure all tests pass and the code is ready for release 11 | 12 | ## Release Steps 13 | 14 | ### 1. Update Version (Using Poetry Commands) 15 | 16 | Use Poetry's built-in version bumping commands: 17 | 18 | ```bash 19 | # Bump patch version (0.1.6 → 0.1.7) - for bug fixes 20 | poetry version patch 21 | 22 | # Bump minor version (0.1.6 → 0.2.0) - for new features 23 | poetry version minor 24 | 25 | # Bump major version (0.1.6 → 1.0.0) - for breaking changes 26 | poetry version major 27 | 28 | # Or set a specific version 29 | poetry version 1.2.3 30 | ``` 31 | 32 | This automatically updates the version in `pyproject.toml`. 33 | 34 | ### 2. Build and Publish 35 | 36 | ```bash 37 | # Build the package 38 | poetry build 39 | 40 | # Publish to PyPI (requires PyPI credentials) 41 | poetry publish 42 | ``` 43 | 44 | ### 3. Commit Version Changes 45 | 46 | ```bash 47 | # Add the version change 48 | git add pyproject.toml 49 | 50 | # Commit with the current version (automatically retrieved) 51 | git commit -m "Bump version to $(poetry version -s)" 52 | 53 | # Push to main branch 54 | git push 55 | ``` 56 | 57 | ### 4. Create GitHub Release 58 | 59 | #### Option A: Simple Git Tag 60 | ```bash 61 | # Create an annotated tag with current version 62 | git tag -a "v$(poetry version -s)" -m "Release version $(poetry version -s)" 63 | 64 | # Push the tag to GitHub 65 | git push origin "v$(poetry version -s)" 66 | ``` 67 | 68 | #### Option B: Rich GitHub Release (Recommended) 69 | ```bash 70 | # Create a GitHub release with current version 71 | gh release create "v$(poetry version -s)" \ 72 | --title "Release v$(poetry version -s)" \ 73 | --notes "Release notes for version $(poetry version -s)" \ 74 | --latest 75 | ``` 76 | 77 | ## Version Numbering 78 | 79 | Follow [Semantic Versioning](https://semver.org/): 80 | 81 | - **Patch** (`poetry version patch`): Bug fixes, no breaking changes 82 | - **Minor** (`poetry version minor`): New features, backward compatible 83 | - **Major** (`poetry version major`): Breaking changes 84 | 85 | ## Manual Version Update (Alternative) 86 | 87 | If you prefer to manually edit `pyproject.toml`: 88 | 89 | ```toml 90 | [tool.poetry] 91 | version = "x.y.z" # Update this line manually 92 | ``` 93 | 94 | Then follow steps 2-4 above. 
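For convenience, steps 1-4 can be chained into a single shell sketch. This is only a convenience wrapper around the commands documented above, not an official release script; it assumes Poetry, git, and (for the release step) the GitHub CLI are installed and authenticated, and it takes the bump level as an optional argument.

```bash
#!/usr/bin/env bash
# Sketch of the release flow described above; assumes Poetry, git, and gh
# are already configured with the right credentials.
set -euo pipefail

BUMP="${1:-patch}"                 # patch | minor | major | explicit version
poetry version "$BUMP"
VERSION="$(poetry version -s)"

poetry build
poetry publish

git add pyproject.toml
git commit -m "Bump version to $VERSION"
git push

# Option B: rich GitHub release; use a plain git tag instead if gh is unavailable.
gh release create "v$VERSION" \
  --title "Release v$VERSION" \
  --notes "Release notes for version $VERSION" \
  --latest
```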
95 | 96 | ## Release Checklist 97 | 98 | - [ ] All tests pass 99 | - [ ] Documentation is up to date 100 | - [ ] CHANGELOG.md is updated (if applicable) 101 | - [ ] Version is bumped using `poetry version [patch|minor|major]` 102 | - [ ] Package builds successfully (`poetry build`) 103 | - [ ] Package publishes successfully (`poetry publish`) 104 | - [ ] Version changes are committed and pushed 105 | - [ ] GitHub release/tag is created 106 | - [ ] Release notes are written 107 | 108 | ## Troubleshooting 109 | 110 | ### Publishing Issues 111 | - Ensure you're authenticated with PyPI: `poetry config pypi-token.pypi your-token` 112 | - Check if the version already exists on PyPI 113 | 114 | ### Git Tag Issues 115 | - If tag already exists: `git tag -d "v$(poetry version -s)"` (delete local) and `git push origin :refs/tags/"v$(poetry version -s)"` (delete remote) 116 | - Ensure you have push permissions to the repository 117 | 118 | ### GitHub CLI Issues 119 | - Authenticate: `gh auth login` 120 | - Check repository access: `gh repo view` 121 | 122 | ### PyArrow Compatibility Issues 123 | 124 | If you see `objc_initializeAfterForkError` crashes on macOS, set this environment variable: 125 | 126 | ```bash 127 | # For single commands 128 | OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES python your_script.py 129 | 130 | # For Poetry environment 131 | OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES poetry run python your_script.py 132 | 133 | # To set permanently in your shell (add to ~/.zshrc or ~/.bash_profile): 134 | export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES 135 | ``` 136 | 137 | ## Useful Poetry Version Commands 138 | 139 | ```bash 140 | # Check current version 141 | poetry version 142 | 143 | # Show version number only (useful for scripts) 144 | poetry version -s 145 | 146 | # Preview what the next version would be (without changing it) 147 | poetry version --dry-run patch 148 | poetry version --dry-run minor 149 | poetry version --dry-run major 150 | ``` 151 | 152 | ## Documentation Releases 153 | 154 | The project uses MkDocs with GitHub Pages for documentation. The documentation is automatically built and deployed via GitHub Actions. 155 | 156 | ### Automatic Documentation Updates 157 | 158 | The docs workflow (`.github/workflows/docs.yml`) automatically triggers when you push changes to: 159 | - `docs/**` - Any documentation files 160 | - `mkdocs.yml` - MkDocs configuration 161 | - `pyproject.toml` - Project configuration (version updates) 162 | - `.github/workflows/docs.yml` - The workflow itself 163 | 164 | ### Manual Documentation Deployment 165 | 166 | You can manually trigger documentation deployment: 167 | 168 | ```bash 169 | # Using GitHub CLI 170 | gh workflow run docs.yml 171 | 172 | # Or trigger via GitHub web interface: 173 | # Go to Actions tab → Deploy MkDocs to GitHub Pages → Run workflow 174 | ``` 175 | 176 | ### Documentation URLs 177 | 178 | - **Live Docs**: https://allisonwang-db.github.io/pyspark-data-sources 179 | - **Source**: `docs/` directory 180 | - **Configuration**: `mkdocs.yml` 181 | - **Workflow**: `.github/workflows/docs.yml` 182 | 183 | ### Adding New Documentation 184 | 185 | 1. Create new `.md` files in `docs/` or `docs/datasources/` 186 | 2. Update `mkdocs.yml` navigation if needed 187 | 3. Push to main/master branch 188 | 4. 
Documentation will auto-deploy via GitHub Actions 189 | -------------------------------------------------------------------------------- /docs/api-reference.md: -------------------------------------------------------------------------------- 1 | # Python Data Source API Reference 2 | 3 | Complete reference for the Python Data Source API introduced in Apache Spark 4.0. 4 | 5 | ## Core Abstract Base Classes 6 | 7 | ### DataSource 8 | 9 | The primary abstract base class for custom data sources supporting read/write operations. 10 | 11 | ```python 12 | class DataSource: 13 | def __init__(self, options: Dict[str, str]) 14 | def name() -> str 15 | def schema() -> StructType 16 | def reader(schema: StructType) -> DataSourceReader 17 | def writer(schema: StructType, overwrite: bool) -> DataSourceWriter 18 | def streamReader(schema: StructType) -> DataSourceStreamReader 19 | def streamWriter(schema: StructType, overwrite: bool) -> DataSourceStreamWriter 20 | def simpleStreamReader(schema: StructType) -> SimpleDataSourceStreamReader 21 | ``` 22 | 23 | #### Methods 24 | 25 | | Method | Required | Description | Default | 26 | |--------|----------|-------------|---------| 27 | | `__init__(options)` | No | Initialize with user options | Base class provides default | 28 | | `name()` | No | Return format name for registration | Class name | 29 | | `schema()` | Yes | Define the data source schema | No default | 30 | | `reader(schema)` | If batch read | Create batch reader | Not implemented | 31 | | `writer(schema, overwrite)` | If batch write | Create batch writer | Not implemented | 32 | | `streamReader(schema)` | If streaming* | Create streaming reader | Not implemented | 33 | | `streamWriter(schema, overwrite)` | If stream write | Create streaming writer | Not implemented | 34 | | `simpleStreamReader(schema)` | If streaming* | Create simple streaming reader | Not implemented | 35 | 36 | *For streaming read, implement either `streamReader` or `simpleStreamReader`, not both. 37 | 38 | ### DataSourceReader 39 | 40 | Abstract base class for reading data from sources in batch mode. 41 | 42 | ```python 43 | class DataSourceReader: 44 | def read(partition) -> Iterator 45 | def partitions() -> List[InputPartition] 46 | ``` 47 | 48 | #### Methods 49 | 50 | | Method | Required | Description | Default | 51 | |--------|----------|-------------|---------| 52 | | `read(partition)` | Yes | Read data from partition | No default | 53 | | `partitions()` | No | Return input partitions for parallel reading | Single partition | 54 | 55 | #### Return Types for read() 56 | 57 | The `read()` method can return: 58 | - **Tuples**: Matching the schema field order 59 | - **Row objects**: `from pyspark.sql import Row` 60 | - **pyarrow.RecordBatch**: For better performance 61 | 62 | ### DataSourceWriter 63 | 64 | Abstract base class for writing data to external sources in batch mode. 
65 | 66 | ```python 67 | class DataSourceWriter: 68 | def write(iterator) -> WriterCommitMessage 69 | def commit(messages: List[WriterCommitMessage]) -> WriteResult 70 | def abort(messages: List[WriterCommitMessage]) -> None 71 | ``` 72 | 73 | #### Methods 74 | 75 | | Method | Required | Description | 76 | |--------|----------|-------------| 77 | | `write(iterator)` | Yes | Write data from iterator, return commit message | 78 | | `commit(messages)` | Yes | Commit successful writes | 79 | | `abort(messages)` | No | Handle write failures and cleanup | 80 | 81 | ### DataSourceStreamReader 82 | 83 | Abstract base class for streaming data sources with full offset management and partition planning. 84 | 85 | ```python 86 | class DataSourceStreamReader: 87 | def initialOffset() -> dict 88 | def latestOffset() -> dict 89 | def partitions(start: dict, end: dict) -> List[InputPartition] 90 | def read(partition) -> Iterator 91 | def commit(end: dict) -> None 92 | def stop() -> None 93 | ``` 94 | 95 | #### Methods 96 | 97 | | Method | Required | Description | 98 | |--------|----------|-------------| 99 | | `initialOffset()` | Yes | Return starting offset | 100 | | `latestOffset()` | Yes | Return latest available offset | 101 | | `partitions(start, end)` | Yes | Get partitions for offset range | 102 | | `read(partition)` | Yes | Read data from partition | 103 | | `commit(end)` | No | Mark offsets as processed | 104 | | `stop()` | No | Clean up resources | 105 | 106 | ### SimpleDataSourceStreamReader 107 | 108 | Simplified streaming reader interface without partition planning. Choose this over `DataSourceStreamReader` when your data source doesn't naturally partition. 109 | 110 | ```python 111 | class SimpleDataSourceStreamReader: 112 | def initialOffset() -> dict 113 | def read(start: dict) -> Tuple[Iterator, dict] 114 | def readBetweenOffsets(start: dict, end: dict) -> Iterator 115 | def commit(end: dict) -> None 116 | ``` 117 | 118 | #### Methods 119 | 120 | | Method | Required | Description | 121 | |--------|----------|-------------| 122 | | `initialOffset()` | Yes | Return starting offset | 123 | | `read(start)` | Yes | Read from start offset, return (iterator, next_offset) | 124 | | `readBetweenOffsets(start, end)` | Recommended | Deterministic replay between offsets | 125 | | `commit(end)` | No | Mark offsets as processed | 126 | 127 | ### DataSourceStreamWriter 128 | 129 | Abstract base class for writing data to external sinks in streaming queries. 130 | 131 | ```python 132 | class DataSourceStreamWriter: 133 | def write(iterator) -> WriterCommitMessage 134 | def commit(messages: List[WriterCommitMessage], batchId: int) -> None 135 | def abort(messages: List[WriterCommitMessage], batchId: int) -> None 136 | ``` 137 | 138 | #### Methods 139 | 140 | | Method | Required | Description | 141 | |--------|----------|-------------| 142 | | `write(iterator)` | Yes | Write data for a partition | 143 | | `commit(messages, batchId)` | Yes | Commit successful microbatch writes | 144 | | `abort(messages, batchId)` | No | Handle write failures for a microbatch | 145 | 146 | ## Helper Classes 147 | 148 | ### InputPartition 149 | 150 | Represents a partition of input data. 151 | 152 | ```python 153 | class InputPartition: 154 | def __init__(self, value: Any) 155 | ``` 156 | 157 | The `value` can be any serializable Python object that identifies the partition (int, dict, tuple, etc.). 158 | 159 | ### WriterCommitMessage 160 | 161 | Message returned from successful write operations. 
162 | 163 | ```python 164 | class WriterCommitMessage: 165 | def __init__(self, value: Any) 166 | ``` 167 | 168 | The `value` typically contains metadata about the write (e.g., file path, row count). 169 | 170 | ### WriteResult 171 | 172 | Final result after committing all writes. 173 | 174 | ```python 175 | class WriteResult: 176 | def __init__(self, **kwargs) 177 | ``` 178 | 179 | Can contain any metadata about the completed write operation. 180 | 181 | ## Implementation Requirements 182 | 183 | ### 1. Serialization 184 | 185 | All classes must be pickle-serializable. This means: 186 | 187 | ```python 188 | # ❌ BAD - Connection objects can't be pickled 189 | class BadReader(DataSourceReader): 190 | def __init__(self, options): 191 | import psycopg2 192 | self.connection = psycopg2.connect(options['url']) 193 | 194 | # ✅ GOOD - Store configuration, create connections in read() 195 | class GoodReader(DataSourceReader): 196 | def __init__(self, options): 197 | self.connection_url = options['url'] 198 | 199 | def read(self, partition): 200 | import psycopg2 201 | conn = psycopg2.connect(self.connection_url) 202 | try: 203 | # Use connection 204 | yield from fetch_data(conn) 205 | finally: 206 | conn.close() 207 | ``` 208 | 209 | ### 2. Schema Definition 210 | 211 | Use PySpark's type system: 212 | 213 | ```python 214 | from pyspark.sql.types import * 215 | 216 | def schema(self): 217 | return StructType([ 218 | StructField("id", IntegerType(), nullable=False), 219 | StructField("name", StringType(), nullable=True), 220 | StructField("tags", ArrayType(StringType()), nullable=True), 221 | StructField("metadata", MapType(StringType(), StringType()), nullable=True), 222 | StructField("created_at", TimestampType(), nullable=True), 223 | ]) 224 | ``` 225 | 226 | ### 3. Data Types Support 227 | 228 | Supported Spark SQL data types: 229 | 230 | | Python Type | Spark Type | Notes | 231 | |-------------|------------|-------| 232 | | `int` | `IntegerType`, `LongType` | Use LongType for large integers | 233 | | `float` | `FloatType`, `DoubleType` | DoubleType is default for decimals | 234 | | `str` | `StringType` | Unicode strings | 235 | | `bool` | `BooleanType` | True/False | 236 | | `datetime.date` | `DateType` | Date without time | 237 | | `datetime.datetime` | `TimestampType` | Date with time | 238 | | `bytes` | `BinaryType` | Raw bytes | 239 | | `list` | `ArrayType(elementType)` | Homogeneous arrays | 240 | | `dict` | `MapType(keyType, valueType)` | Key-value pairs | 241 | | `tuple/Row` | `StructType([fields])` | Nested structures | 242 | 243 | ### 4. Error Handling 244 | 245 | Implement proper exception handling: 246 | 247 | ```python 248 | def read(self, partition): 249 | try: 250 | data = fetch_external_data() 251 | yield from data 252 | except NetworkError as e: 253 | if self.options.get("failOnError", "true") == "true": 254 | raise DataSourceError(f"Failed to fetch data: {e}") 255 | else: 256 | # Log and continue 257 | logger.warning(f"Skipping partition due to error: {e}") 258 | return # Empty iterator 259 | ``` 260 | 261 | ### 5. Resource Management 262 | 263 | Clean up resources properly: 264 | 265 | ```python 266 | class ManagedReader(DataSourceStreamReader): 267 | def stop(self): 268 | """Called when streaming query stops.""" 269 | if hasattr(self, 'connection_pool'): 270 | self.connection_pool.close() 271 | if hasattr(self, 'temp_files'): 272 | for f in self.temp_files: 273 | os.remove(f) 274 | ``` 275 | 276 | ### 6. 
File Path Convention 277 | 278 | When dealing with file paths: 279 | 280 | ```python 281 | # ✅ CORRECT - Path in load() 282 | df = spark.read.format("myformat").load("/path/to/data") 283 | 284 | # ❌ INCORRECT - Path in option() 285 | df = spark.read.format("myformat").option("path", "/path/to/data").load() 286 | ``` 287 | 288 | ### 7. Lazy Initialization 289 | 290 | Defer expensive operations: 291 | 292 | ```python 293 | class LazyReader(DataSourceReader): 294 | def __init__(self, options): 295 | # ✅ Just store configuration 296 | self.api_url = options['url'] 297 | self.api_key = options['apiKey'] 298 | 299 | def read(self, partition): 300 | # ✅ Expensive operations here 301 | session = create_authenticated_session(self.api_key) 302 | data = session.get(self.api_url) 303 | yield from parse_data(data) 304 | ``` 305 | 306 | ## Performance Optimization Techniques 307 | 308 | ### 1. Arrow Integration 309 | 310 | Return `pyarrow.RecordBatch` for better serialization: 311 | 312 | ```python 313 | import pyarrow as pa 314 | 315 | def read(self, partition): 316 | # Read data efficiently 317 | data = fetch_data_batch() 318 | 319 | # Convert to Arrow 320 | batch = pa.RecordBatch.from_pandas(data) 321 | yield batch 322 | ``` 323 | 324 | Benefits: 325 | - Zero-copy data transfer when possible 326 | - Columnar format efficiency 327 | - Better memory usage 328 | 329 | ### 2. Partitioning Strategy 330 | 331 | Effective partitioning for parallelism: 332 | 333 | ```python 334 | def partitions(self): 335 | # Example: Partition by date ranges 336 | dates = get_date_range(self.start_date, self.end_date) 337 | return [InputPartition({"date": date}) for date in dates] 338 | 339 | def read(self, partition): 340 | date = partition.value["date"] 341 | # Read only data for this date 342 | yield from read_data_for_date(date) 343 | ``` 344 | 345 | Guidelines: 346 | - Aim for equal-sized partitions 347 | - Consider data locality 348 | - Balance between too many small partitions and too few large ones 349 | 350 | ### 3. Batch Processing 351 | 352 | Process data in batches: 353 | 354 | ```python 355 | def read(self, partition): 356 | batch_size = 1000 357 | offset = 0 358 | 359 | while True: 360 | batch = fetch_batch(offset, batch_size) 361 | if not batch: 362 | break 363 | 364 | # Process entire batch at once 365 | processed = process_batch(batch) 366 | yield from processed 367 | 368 | offset += batch_size 369 | ``` 370 | 371 | ### 4. Connection Pooling 372 | 373 | Reuse connections within partitions: 374 | 375 | ```python 376 | def read(self, partition): 377 | # Create connection once per partition 378 | conn = create_connection() 379 | try: 380 | # Reuse for all records in partition 381 | for record_id in partition.value["ids"]: 382 | data = fetch_with_connection(conn, record_id) 383 | yield data 384 | finally: 385 | conn.close() 386 | ``` 387 | 388 | ### 5. 
Caching Strategies 389 | 390 | Cache frequently accessed data: 391 | 392 | ```python 393 | class CachedReader(DataSourceReader): 394 | _cache = {} # Class-level cache 395 | 396 | def read(self, partition): 397 | cache_key = partition.value["key"] 398 | 399 | if cache_key not in self._cache: 400 | self._cache[cache_key] = fetch_expensive_data(cache_key) 401 | 402 | data = self._cache[cache_key] 403 | yield from process_data(data) 404 | ``` 405 | 406 | ## Usage Examples 407 | 408 | ### Registration and Basic Usage 409 | 410 | ```python 411 | from pyspark.sql import SparkSession 412 | 413 | spark = SparkSession.builder.appName("datasource-example").getOrCreate() 414 | 415 | # Register the data source 416 | spark.dataSource.register(MyDataSource) 417 | 418 | # Use with format() 419 | df = spark.read.format("myformat").option("key", "value").load() 420 | 421 | # Use with streaming 422 | stream = spark.readStream.format("myformat").load() 423 | ``` 424 | 425 | ### Schema Handling 426 | 427 | ```python 428 | # Let data source define schema 429 | df = spark.read.format("myformat").load() 430 | 431 | # Override with custom schema 432 | custom_schema = StructType([ 433 | StructField("id", LongType()), 434 | StructField("value", DoubleType()) 435 | ]) 436 | df = spark.read.format("myformat").schema(custom_schema).load() 437 | 438 | # Schema subset - read only specific columns 439 | df = spark.read.format("myformat").load().select("id", "name") 440 | ``` 441 | 442 | ### Options Pattern 443 | 444 | ```python 445 | df = spark.read.format("myformat") \ 446 | .option("url", "https://api.example.com") \ 447 | .option("apiKey", "secret-key") \ 448 | .option("timeout", "30") \ 449 | .option("retries", "3") \ 450 | .option("batchSize", "1000") \ 451 | .load() 452 | ``` 453 | 454 | ## Common Pitfalls and Solutions 455 | 456 | ### Pitfall 1: Non-Serializable Classes 457 | 458 | **Problem**: Storing database connections, HTTP clients, or other non-serializable objects as instance variables. 459 | 460 | **Solution**: Create these objects in `read()` method. 461 | 462 | ### Pitfall 2: Ignoring Schema Parameter 463 | 464 | **Problem**: Always using `self.schema()` instead of the schema parameter passed to `reader()`. 465 | 466 | **Solution**: Respect user-specified schema for column pruning. 467 | 468 | ### Pitfall 3: Blocking in Constructor 469 | 470 | **Problem**: Making API calls or heavy I/O in `__init__`. 471 | 472 | **Solution**: Defer to `read()` method for lazy evaluation. 473 | 474 | ### Pitfall 4: No Partitioning 475 | 476 | **Problem**: Not implementing `partitions()`, limiting parallelism. 477 | 478 | **Solution**: Implement logical partitioning based on your data source. 479 | 480 | ### Pitfall 5: Poor Error Messages 481 | 482 | **Problem**: Generic exceptions without context. 483 | 484 | **Solution**: Provide detailed, actionable error messages. 485 | 486 | ### Pitfall 6: Resource Leaks 487 | 488 | **Problem**: Not cleaning up connections, files, or memory. 489 | 490 | **Solution**: Use try/finally blocks and implement `stop()` for streaming. 491 | 492 | ### Pitfall 7: Inefficient Data Transfer 493 | 494 | **Problem**: Yielding individual rows for large datasets. 495 | 496 | **Solution**: Use Arrow RecordBatch or batch processing. 
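To make Pitfalls 2 and 7 concrete, here is a minimal reader sketch that honors the schema passed to `reader()` and yields `pyarrow.RecordBatch` objects instead of individual rows. The data source name, field names, and in-memory rows are illustrative assumptions for this guide, not part of this package.

```python
import pyarrow as pa
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import LongType, StringType, StructField, StructType


class PrunedArrowDataSource(DataSource):
    """Hypothetical source used only to illustrate Pitfalls 2 and 7."""

    def name(self):
        return "prunedarrow"

    def schema(self):
        return StructType([
            StructField("id", LongType()),
            StructField("name", StringType()),
            StructField("description", StringType()),
        ])

    def reader(self, schema):
        # Pitfall 2: pass along the schema Spark provides instead of always using self.schema().
        return PrunedArrowReader(self.options, schema)


class PrunedArrowReader(DataSourceReader):
    def __init__(self, options, schema):
        self.options = options
        self.schema = schema  # may be a user-specified subset of the full schema

    def read(self, partition):
        # Stand-in for an external fetch; a real source would read from its backend here.
        rows = [{"id": i, "name": f"name_{i}", "description": f"desc_{i}"} for i in range(1000)]
        wanted = [field.name for field in self.schema.fields]
        # Pitfall 7: emit one columnar batch instead of yielding 1000 individual rows.
        arrays = [pa.array([row[col] for row in rows]) for col in wanted]
        yield pa.RecordBatch.from_arrays(arrays, names=wanted)
```

With this shape, a read such as `spark.read.format("prunedarrow").schema("id long, name string").load()` should only materialize the requested columns, and each partition transfers a single Arrow batch rather than row-by-row Python objects.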
497 | 498 | ## Testing Guidelines 499 | 500 | ### Unit Testing 501 | 502 | ```python 503 | import pytest 504 | from pyspark.sql import SparkSession 505 | 506 | @pytest.fixture(scope="session") 507 | def spark(): 508 | return SparkSession.builder \ 509 | .appName("test") \ 510 | .master("local[2]") \ 511 | .config("spark.sql.shuffle.partitions", "2") \ 512 | .getOrCreate() 513 | 514 | def test_reader(spark): 515 | spark.dataSource.register(MyDataSource) 516 | df = spark.read.format("myformat").load() 517 | assert df.count() > 0 518 | assert df.schema == expected_schema # expected_schema defined elsewhere in the test module 519 | ``` 520 | 521 | ### Integration Testing 522 | 523 | ```python 524 | def test_with_real_data(spark, real_api_key): 525 | spark.dataSource.register(MyDataSource) 526 | 527 | df = spark.read.format("myformat") \ 528 | .option("apiKey", real_api_key) \ 529 | .load() 530 | 531 | # Verify real data 532 | assert df.count() > 0 533 | df.write.mode("overwrite").parquet("/tmp/test_output") 534 | ``` 535 | 536 | ### Performance Testing 537 | 538 | ```python 539 | def test_performance(spark, benchmark): # assumes a custom "benchmark" timing fixture (not the pytest-benchmark API) 540 | spark.dataSource.register(MyDataSource) 541 | 542 | with benchmark("read_performance"): 543 | df = spark.read.format("myformat") \ 544 | .option("numRows", "1000000") \ 545 | .load() 546 | count = df.count() 547 | 548 | assert count == 1000000 549 | benchmark.assert_timing(max_seconds=10) 550 | ``` 551 | 552 | ## Additional Resources 553 | 554 | - [Apache Spark Python Data Source API](https://spark.apache.org/docs/4.0.0/api/python/tutorial/sql/python_data_source.html) 555 | - [Source Code](https://github.com/apache/spark/blob/master/python/pyspark/sql/datasource.py) 556 | - [Example Implementations](https://github.com/allisonwang-db/pyspark-data-sources) 557 | - [Simple Stream Reader Architecture](simple-stream-reader-architecture.md) -------------------------------------------------------------------------------- /docs/data-sources-guide.md: -------------------------------------------------------------------------------- 1 | # Data Sources Guide 2 | 3 | This guide provides detailed examples and usage patterns for all available data sources in the PySpark Data Sources library. 4 | 5 | ## Table of Contents 6 | 1. [FakeDataSource - Generate Synthetic Data](#1-fakedatasource---generate-synthetic-data) 7 | 2. [GitHubDataSource - Read GitHub Pull Requests](#2-githubdatasource---read-github-pull-requests) 8 | 3. [GoogleSheetsDataSource - Read Public Google Sheets](#3-googlesheetsdatasource---read-public-google-sheets) 9 | 4. [HuggingFaceDataSource - Load Datasets from Hugging Face Hub](#4-huggingfacedatasource---load-datasets-from-hugging-face-hub) 10 | 5. [StockDataSource - Fetch Stock Market Data](#5-stockdatasource---fetch-stock-market-data) 11 | 6. [OpenSkyDataSource - Stream Live Flight Data](#6-openskydatasource---stream-live-flight-data) 12 | 7. [KaggleDataSource - Load Kaggle Datasets](#7-kaggledatasource---load-kaggle-datasets) 13 | 8. [ArrowDataSource - Read Apache Arrow Files](#8-arrowdatasource---read-apache-arrow-files) 14 | 9. [LanceDataSource - Vector Database Format](#9-lancedatasource---vector-database-format) 15 | 16 | ## 1. FakeDataSource - Generate Synthetic Data 17 | 18 | Generate fake data using the Faker library for testing and development.
19 | 20 | ### Installation 21 | ```bash 22 | pip install pyspark-data-sources[faker] 23 | ``` 24 | 25 | ### Basic Usage 26 | ```python 27 | from pyspark_datasources import FakeDataSource 28 | 29 | spark.dataSource.register(FakeDataSource) 30 | 31 | # Basic usage with default schema 32 | df = spark.read.format("fake").load() 33 | df.show() 34 | # +-----------+----------+-------+-------+ 35 | # | name| date|zipcode| state| 36 | # +-----------+----------+-------+-------+ 37 | # |Carlos Cobb|2018-07-15| 73003|Indiana| 38 | # | Eric Scott|1991-08-22| 10085| Idaho| 39 | # | Amy Martin|1988-10-28| 68076| Oregon| 40 | # +-----------+----------+-------+-------+ 41 | ``` 42 | 43 | ### Custom Schema 44 | Field names must match Faker provider methods. See [Faker documentation](https://faker.readthedocs.io/en/master/providers.html) for available providers. 45 | 46 | ```python 47 | # Custom schema - field names must match Faker provider methods 48 | df = spark.read.format("fake") \ 49 | .schema("name string, email string, phone_number string, company string") \ 50 | .option("numRows", 5) \ 51 | .load() 52 | df.show(truncate=False) 53 | # +------------------+-------------------------+----------------+-----------------+ 54 | # |name |email |phone_number |company | 55 | # +------------------+-------------------------+----------------+-----------------+ 56 | # |Christine Sampson |johnsonjeremy@example.com|+1-673-684-4608 |Hernandez-Nguyen | 57 | # |Yolanda Brown |williamlowe@example.net |(262)562-4152 |Miller-Hernandez | 58 | # |Joshua Hernandez |mary38@example.com |7785366623 |Davis Group | 59 | # |Joseph Gallagher |katiepatterson@example.net|+1-648-619-0997 |Brown-Fleming | 60 | # |Tina Morrison |johnbell@example.com |(684)329-3298 |Sherman PLC | 61 | # +------------------+-------------------------+----------------+-----------------+ 62 | ``` 63 | 64 | ### Streaming Usage 65 | ```python 66 | # Streaming usage - generates continuous data 67 | stream = spark.readStream.format("fake") \ 68 | .option("rowsPerMicrobatch", 10) \ 69 | .load() 70 | 71 | query = stream.writeStream \ 72 | .format("console") \ 73 | .outputMode("append") \ 74 | .trigger(processingTime="5 seconds") \ 75 | .start() 76 | ``` 77 | 78 | ## 2. GitHubDataSource - Read GitHub Pull Requests 79 | 80 | Fetch pull request data from GitHub repositories using the public GitHub API. 
81 | 82 | ### Basic Usage 83 | ```python 84 | from pyspark_datasources import GithubDataSource 85 | 86 | spark.dataSource.register(GithubDataSource) 87 | 88 | # Read pull requests from a public repository 89 | df = spark.read.format("github").load("apache/spark") 90 | df.select("number", "title", "state", "user", "created_at").show(truncate=False) 91 | # +------+--------------------------------------------+------+---------------+--------------------+ 92 | # |number|title |state |user |created_at | 93 | # +------+--------------------------------------------+------+---------------+--------------------+ 94 | # |48730 |[SPARK-49998] Fix JSON parsing regression |open |john_doe |2024-11-20T10:15:30Z| 95 | # |48729 |[SPARK-49997] Improve error messages |merged|contributor123 |2024-11-19T15:22:45Z| 96 | # +------+--------------------------------------------+------+---------------+--------------------+ 97 | 98 | # Schema of the GitHub data source 99 | df.printSchema() 100 | # root 101 | # |-- number: integer (nullable = true) 102 | # |-- title: string (nullable = true) 103 | # |-- state: string (nullable = true) 104 | # |-- user: string (nullable = true) 105 | # |-- created_at: string (nullable = true) 106 | # |-- updated_at: string (nullable = true) 107 | ``` 108 | 109 | ### Notes 110 | - Uses the public GitHub API (rate limited to 60 requests/hour for unauthenticated requests) 111 | - Returns the most recent pull requests 112 | - Repository format: "owner/repository" 113 | 114 | ## 3. GoogleSheetsDataSource - Read Public Google Sheets 115 | 116 | Read data from publicly accessible Google Sheets. 117 | 118 | ### Basic Usage 119 | ```python 120 | from pyspark_datasources import GoogleSheetsDataSource 121 | 122 | spark.dataSource.register(GoogleSheetsDataSource) 123 | 124 | # Read from a public Google Sheet 125 | sheet_url = "https://docs.google.com/spreadsheets/d/1H7bKPGpAXbPRhTYFxqg1h6FmKl5ZhCTM_5OlAqCHfVs" 126 | df = spark.read.format("googlesheets").load(sheet_url) 127 | df.show() 128 | # +------+-------+--------+ 129 | # |Name |Age |City | 130 | # +------+-------+--------+ 131 | # |Alice |25 |NYC | 132 | # |Bob |30 |LA | 133 | # |Carol |28 |Chicago | 134 | # +------+-------+--------+ 135 | ``` 136 | 137 | ### Specify Sheet Name or Index 138 | ```python 139 | # Read a specific sheet by name 140 | df = spark.read.format("googlesheets") \ 141 | .option("sheetName", "Sheet2") \ 142 | .load(sheet_url) 143 | 144 | # Or by index (0-based) 145 | df = spark.read.format("googlesheets") \ 146 | .option("sheetIndex", "1") \ 147 | .load(sheet_url) 148 | ``` 149 | 150 | ### Requirements 151 | - Sheet must be publicly accessible (anyone with link can view) 152 | - First row is used as column headers 153 | - Data is read as strings by default 154 | 155 | ## 4. HuggingFaceDataSource - Load Datasets from Hugging Face Hub 156 | 157 | Access thousands of datasets from the Hugging Face Hub. 
158 | 159 | ### Installation 160 | ```bash 161 | pip install pyspark-data-sources[huggingface] 162 | ``` 163 | 164 | ### Basic Usage 165 | ```python 166 | from pyspark_datasources import HuggingFaceDataSource 167 | 168 | spark.dataSource.register(HuggingFaceDataSource) 169 | 170 | # Load a dataset from Hugging Face 171 | df = spark.read.format("huggingface").load("imdb") 172 | df.select("text", "label").show(2, truncate=False) 173 | # +--------------------------------------------------+-----+ 174 | # |text |label| 175 | # +--------------------------------------------------+-----+ 176 | # |I rented this movie last night and I must say... |0 | 177 | # |This film is absolutely fantastic! The acting... |1 | 178 | # +--------------------------------------------------+-----+ 179 | ``` 180 | 181 | ### Advanced Options 182 | ```python 183 | # Load specific split 184 | df = spark.read.format("huggingface") \ 185 | .option("split", "train") \ 186 | .load("squad") 187 | 188 | # Load dataset with configuration 189 | df = spark.read.format("huggingface") \ 190 | .option("config", "plain_text") \ 191 | .load("wikipedia") 192 | 193 | # Load specific subset 194 | df = spark.read.format("huggingface") \ 195 | .option("split", "validation") \ 196 | .option("config", "en") \ 197 | .load("wikipedia") 198 | ``` 199 | 200 | ### Performance Tips 201 | - The data source automatically partitions large datasets for parallel processing 202 | - First load may be slow as it downloads and caches the dataset 203 | - Subsequent reads use the cached data 204 | 205 | ## 5. StockDataSource - Fetch Stock Market Data 206 | 207 | Get stock market data from Alpha Vantage API. 208 | 209 | ### Setup 210 | Obtain a free API key from [Alpha Vantage](https://www.alphavantage.co/support/#api-key). 211 | 212 | ### Basic Usage 213 | ```python 214 | from pyspark_datasources import StockDataSource 215 | 216 | spark.dataSource.register(StockDataSource) 217 | 218 | # Fetch stock data (requires Alpha Vantage API key) 219 | df = spark.read.format("stock") \ 220 | .option("symbols", "AAPL,GOOGL,MSFT") \ 221 | .option("api_key", "YOUR_API_KEY") \ 222 | .load() 223 | 224 | df.show() 225 | # +------+----------+------+------+------+------+--------+ 226 | # |symbol|timestamp |open |high |low |close |volume | 227 | # +------+----------+------+------+------+------+--------+ 228 | # |AAPL |2024-11-20|175.50|178.25|175.00|177.80|52341200| 229 | # |GOOGL |2024-11-20|142.30|143.90|141.50|143.25|18234500| 230 | # |MSFT |2024-11-20|425.10|428.75|424.50|427.90|12456300| 231 | # +------+----------+------+------+------+------+--------+ 232 | 233 | # Schema 234 | df.printSchema() 235 | # root 236 | # |-- symbol: string (nullable = true) 237 | # |-- timestamp: string (nullable = true) 238 | # |-- open: double (nullable = true) 239 | # |-- high: double (nullable = true) 240 | # |-- low: double (nullable = true) 241 | # |-- close: double (nullable = true) 242 | # |-- volume: long (nullable = true) 243 | ``` 244 | 245 | ### Options 246 | - `symbols`: Comma-separated list of stock symbols 247 | - `api_key`: Your Alpha Vantage API key 248 | - `function`: Time series function (default: "TIME_SERIES_DAILY") 249 | 250 | ## 6. OpenSkyDataSource - Stream Live Flight Data 251 | 252 | Stream real-time flight tracking data from OpenSky Network. 
253 | 254 | ### Streaming Usage 255 | ```python 256 | from pyspark_datasources import OpenSkyDataSource 257 | 258 | spark.dataSource.register(OpenSkyDataSource) 259 | 260 | # Stream flight data for a specific region 261 | stream = spark.readStream.format("opensky") \ 262 | .option("region", "EUROPE") \ 263 | .load() 264 | 265 | # Process the stream 266 | query = stream.select("icao24", "callsign", "origin_country", "longitude", "latitude", "altitude") \ 267 | .writeStream \ 268 | .format("console") \ 269 | .outputMode("append") \ 270 | .trigger(processingTime="10 seconds") \ 271 | .start() 272 | 273 | # Sample output: 274 | # +------+--------+--------------+---------+--------+--------+ 275 | # |icao24|callsign|origin_country|longitude|latitude|altitude| 276 | # +------+--------+--------------+---------+--------+--------+ 277 | # |4b1806|SWR123 |Switzerland |8.5123 |47.3765 |10058.4 | 278 | # |3c6750|DLH456 |Germany |13.4050 |52.5200 |11887.2 | 279 | # +------+--------+--------------+---------+--------+--------+ 280 | ``` 281 | 282 | ### Batch Usage 283 | ```python 284 | # Read current flight data as batch 285 | df = spark.read.format("opensky") \ 286 | .option("region", "USA") \ 287 | .load() 288 | 289 | df.count() # Number of flights currently in USA airspace 290 | ``` 291 | 292 | ### Region Options 293 | - `EUROPE`: European airspace 294 | - `USA`: United States airspace 295 | - `ASIA`: Asian airspace 296 | - `WORLD`: Global (all flights) 297 | 298 | ## 7. KaggleDataSource - Load Kaggle Datasets 299 | 300 | Download and read datasets from Kaggle. 301 | 302 | ### Setup 303 | 1. Install Kaggle API: `pip install pyspark-data-sources[kaggle]` 304 | 2. Get API credentials from https://www.kaggle.com/account 305 | 3. Place credentials in `~/.kaggle/kaggle.json` 306 | 307 | ### Basic Usage 308 | ```python 309 | from pyspark_datasources import KaggleDataSource 310 | 311 | spark.dataSource.register(KaggleDataSource) 312 | 313 | # Load a Kaggle dataset 314 | df = spark.read.format("kaggle").load("titanic") 315 | df.select("PassengerId", "Name", "Age", "Survived").show(5) 316 | # +-----------+--------------------+----+--------+ 317 | # |PassengerId|Name |Age |Survived| 318 | # +-----------+--------------------+----+--------+ 319 | # |1 |Braund, Mr. Owen |22.0|0 | 320 | # |2 |Cumings, Mrs. John |38.0|1 | 321 | # |3 |Heikkinen, Miss Laina|26.0|1 | 322 | # |4 |Futrelle, Mrs Jacques|35.0|1 | 323 | # |5 |Allen, Mr. William |35.0|0 | 324 | # +-----------+--------------------+----+--------+ 325 | ``` 326 | 327 | ### Multi-file Datasets 328 | ```python 329 | # Load specific file from multi-file dataset 330 | df = spark.read.format("kaggle") \ 331 | .option("file", "train.csv") \ 332 | .load("competitionname/datasetname") 333 | 334 | # Load all CSV files 335 | df = spark.read.format("kaggle") \ 336 | .option("pattern", "*.csv") \ 337 | .load("multi-file-dataset") 338 | ``` 339 | 340 | ## 8. ArrowDataSource - Read Apache Arrow Files 341 | 342 | Efficiently read Arrow format files with zero-copy operations. 
343 | 344 | ### Installation 345 | ```bash 346 | pip install pyspark-data-sources[arrow] 347 | ``` 348 | 349 | ### Usage 350 | ```python 351 | from pyspark_datasources import ArrowDataSource 352 | 353 | spark.dataSource.register(ArrowDataSource) 354 | 355 | # Read Arrow files 356 | df = spark.read.format("arrow").load("/path/to/data.arrow") 357 | 358 | # Read multiple Arrow files 359 | df = spark.read.format("arrow").load("/path/to/arrow/files/*.arrow") 360 | 361 | # The Arrow reader is optimized for performance 362 | df.count() # Fast counting due to Arrow metadata 363 | 364 | # Schema is preserved from Arrow files 365 | df.printSchema() 366 | ``` 367 | 368 | ### Performance Benefits 369 | - Zero-copy reads when possible 370 | - Preserves Arrow metadata 371 | - Efficient columnar data access 372 | - Automatic schema inference 373 | 374 | ## 9. LanceDataSource - Vector Database Format 375 | 376 | Read and write Lance format data, optimized for vector/ML workloads. 377 | 378 | ### Installation 379 | ```bash 380 | pip install pyspark-data-sources[lance] 381 | ``` 382 | 383 | ### Write Data 384 | ```python 385 | from pyspark_datasources import LanceDataSource 386 | 387 | spark.dataSource.register(LanceDataSource) 388 | 389 | # Prepare your DataFrame 390 | df = spark.range(1000).selectExpr( 391 | "id", 392 | "array(rand(), rand(), rand()) as vector", 393 | "rand() as score" 394 | ) 395 | 396 | # Write DataFrame to Lance format 397 | df.write.format("lance") \ 398 | .mode("overwrite") \ 399 | .save("/path/to/lance/dataset") 400 | ``` 401 | 402 | ### Read Data 403 | ```python 404 | # Read Lance dataset 405 | lance_df = spark.read.format("lance") \ 406 | .load("/path/to/lance/dataset") 407 | 408 | lance_df.show(5) 409 | 410 | # Lance preserves vector columns efficiently 411 | lance_df.printSchema() 412 | # root 413 | # |-- id: long (nullable = false) 414 | # |-- vector: array (nullable = true) 415 | # | |-- element: double (containsNull = true) 416 | # |-- score: double (nullable = true) 417 | ``` 418 | 419 | ### Features 420 | - Optimized for vector/embedding data 421 | - Efficient storage of array columns 422 | - Fast random access 423 | - Version control built-in 424 | 425 | ## Common Patterns 426 | 427 | ### Error Handling 428 | Most data sources will raise informative errors when required options are missing: 429 | 430 | ```python 431 | # This will raise an error about missing API key 432 | df = spark.read.format("stock") \ 433 | .option("symbols", "AAPL") \ 434 | .load() 435 | # ValueError: api_key option is required for StockDataSource 436 | ``` 437 | 438 | ### Schema Inference vs Specification 439 | Some data sources support both automatic schema inference and explicit schema specification: 440 | 441 | ```python 442 | # Automatic schema inference 443 | df = spark.read.format("fake").load() 444 | 445 | # Explicit schema specification 446 | from pyspark.sql.types import StructType, StructField, StringType 447 | 448 | schema = StructType([ 449 | StructField("first_name", StringType()), 450 | StructField("last_name", StringType()), 451 | StructField("email", StringType()) 452 | ]) 453 | 454 | df = spark.read.format("fake") \ 455 | .schema(schema) \ 456 | .load() 457 | ``` 458 | 459 | ### Partitioning for Performance 460 | Many data sources automatically partition data for parallel processing: 461 | 462 | ```python 463 | # HuggingFace automatically partitions large datasets 464 | df = spark.read.format("huggingface").load("wikipedia") 465 | df.rdd.getNumPartitions() # Returns number of 
partitions 466 | 467 | # You can control partitioning in some data sources 468 | df = spark.read.format("fake") \ 469 | .option("numPartitions", "8") \ 470 | .load() 471 | ``` 472 | 473 | ## Troubleshooting 474 | 475 | ### Common Issues 476 | 477 | 1. **Missing Dependencies** 478 | ```bash 479 | # Install specific extras 480 | pip install pyspark-data-sources[faker,huggingface,kaggle] 481 | ``` 482 | 483 | 2. **API Rate Limits** 484 | - GitHub: 60 requests/hour (unauthenticated) 485 | - Alpha Vantage: 5 requests/minute (free tier) 486 | - Use caching or implement retry logic 487 | 488 | 3. **Network Issues** 489 | - Most data sources require internet access 490 | - Check firewall/proxy settings 491 | - Some sources support offline mode after initial cache 492 | 493 | 4. **Memory Issues with Large Datasets** 494 | - Use partitioning for large datasets 495 | - Consider sampling: `df.sample(0.1)` 496 | - Increase Spark executor memory 497 | 498 | For more help, see the [Development Guide](../contributing/DEVELOPMENT.md) or open an issue on GitHub. -------------------------------------------------------------------------------- /docs/simple-stream-reader-architecture.md: -------------------------------------------------------------------------------- 1 | # SimpleDataSourceStreamReader Architecture 2 | 3 | ## Overview 4 | 5 | `SimpleDataSourceStreamReader` is a lightweight streaming data source reader in PySpark designed for scenarios with small data volumes and low throughput requirements. Unlike the standard `DataSourceStreamReader`, it executes entirely on the driver node, trading scalability for simplicity. 6 | 7 | ## Key Architecture Components 8 | 9 | ### Python-Side Components 10 | 11 | #### SimpleDataSourceStreamReader (datasource.py) 12 | The user-facing API with three core methods: 13 | - `initialOffset()`: Returns the starting position for a new streaming query 14 | - `read(start)`: Reads all available data from a given offset and returns both the data and the next offset 15 | - `readBetweenOffsets(start, end)`: Re-reads data deterministically for failure recovery 16 | 17 | #### _SimpleStreamReaderWrapper (datasource_internal.py) 18 | A private wrapper that implements the prefetch-and-cache pattern: 19 | - Maintains `current_offset` to track reading progress 20 | - Caches prefetched data in memory on the driver 21 | - Converts simple reader interface to standard streaming reader interface 22 | 23 | ### Scala-Side Components 24 | 25 | #### PythonMicroBatchStream (PythonMicroBatchStream.scala) 26 | Manages the micro-batch execution: 27 | - Creates and manages `PythonStreamingSourceRunner` for Python communication 28 | - Stores prefetched data in BlockManager with `PythonStreamBlockId` 29 | - Handles offset management and partition planning 30 | 31 | #### PythonStreamingSourceRunner (PythonStreamingSourceRunner.scala) 32 | The bridge between JVM and Python: 33 | - Spawns a Python worker process running `python_streaming_source_runner.py` 34 | - Serializes/deserializes data using Arrow format 35 | - Manages RPC-style communication for method invocations 36 | 37 | ## Data Flow and Lifecycle 38 | 39 | ### Query Initialization 40 | 1. Spark creates `PythonMicroBatchStream` when a streaming query starts 41 | 2. `PythonStreamingSourceRunner` spawns a Python worker process 42 | 3. Python worker instantiates the `SimpleDataSourceStreamReader` 43 | 4. Initial offset is obtained via `initialOffset()` call 44 | 45 | ### Micro-batch Execution (per trigger) 46 | 47 | #### 1. 
Offset Discovery (Driver) 48 | - Spark calls `latestOffset()` on `PythonMicroBatchStream` 49 | - Runner invokes Python's `latestOffset()` via RPC 50 | - Wrapper calls `simple_reader.read(current_offset)` to prefetch data 51 | - Data and new offset are returned and cached 52 | 53 | #### 2. Data Caching (Driver) 54 | - Prefetched records are converted to Arrow batches 55 | - Data is stored in BlockManager with a unique `PythonStreamBlockId` 56 | - Cache entry maintains mapping of (start_offset, end_offset) → data 57 | 58 | #### 3. Partition Planning (Driver) 59 | - `planInputPartitions(start, end)` creates a single `PythonStreamingInputPartition` 60 | - Partition references the cached block ID 61 | - No actual data distribution happens (single partition on driver) 62 | 63 | #### 4. Data Reading (Executor) 64 | - Executor retrieves cached data from BlockManager using block ID 65 | - Data is already in Arrow format for efficient processing 66 | - Records are converted to internal rows for downstream processing 67 | 68 | ## Integration with Spark Structured Streaming APIs 69 | 70 | ### User API Integration 71 | 72 | ```python 73 | # User defines a SimpleDataSourceStreamReader 74 | class MyStreamReader(SimpleDataSourceStreamReader): 75 | def initialOffset(self): 76 | return {"position": 0} 77 | 78 | def read(self, start): 79 | # Read data from source 80 | data = fetch_data_since(start["position"]) 81 | new_offset = {"position": start["position"] + len(data)} 82 | return (iter(data), new_offset) 83 | 84 | def readBetweenOffsets(self, start, end): 85 | # Re-read for failure recovery 86 | return fetch_data_between(start["position"], end["position"]) 87 | 88 | # Register and use with Spark 89 | class MyDataSource(DataSource): 90 | def simpleStreamReader(self, schema): 91 | return MyStreamReader() 92 | 93 | spark.dataSource.register(MyDataSource) 94 | df = spark.readStream.format("my_source").load() 95 | query = df.writeStream.format("console").start() 96 | ``` 97 | 98 | ### Streaming Engine Integration 99 | 100 | 1. **Trigger Processing**: Works with all trigger modes (ProcessingTime, Once, AvailableNow) 101 | 2. **Offset Management**: Offsets are checkpointed to WAL for exactly-once semantics 102 | 3. **Failure Recovery**: Uses `readBetweenOffsets()` to replay uncommitted batches 103 | 4. 
**Commit Protocol**: After successful batch, `commit(offset)` is called for cleanup 104 | 105 | ## Execution Flow Diagram 106 | 107 | ``` 108 | Driver Node Python Worker Executors 109 | ----------- ------------- --------- 110 | PythonMicroBatchStream 111 | | 112 | ├─> latestOffset() ──────────> PythonStreamingSourceRunner 113 | | | 114 | | ├─> RPC: LATEST_OFFSET ──> SimpleStreamReaderWrapper 115 | | | | 116 | | | ├─> read(current_offset) 117 | | | | └─> (data, new_offset) 118 | | | | 119 | | |<── Arrow batches ────────────┘ 120 | | | 121 | ├─> Cache in BlockManager <───────────┘ 122 | | (PythonStreamBlockId) 123 | | 124 | ├─> planInputPartitions() 125 | | └─> Single partition with BlockId 126 | | 127 | └─> createReaderFactory() ─────────────────────────────────────> Read from BlockManager 128 | | 129 | └─> Process records 130 | ``` 131 | 132 | ## Key Design Decisions and Trade-offs 133 | 134 | ### Advantages 135 | - **Simplicity**: No need to implement partitioning logic 136 | - **Consistency**: All data reading happens in one place (driver) 137 | - **Efficiency for small data**: Avoids overhead of distributed execution 138 | - **Easy offset management**: Single reader maintains consistent view of progress 139 | - **Quick development**: Minimal boilerplate for simple streaming sources 140 | 141 | ### Limitations 142 | - **Not scalable**: All data flows through driver (bottleneck) 143 | - **Memory constraints**: Driver must cache entire micro-batch 144 | - **Single point of failure**: Driver failure affects data reading 145 | - **Network overhead**: Data must be transferred from driver to executors 146 | - **Throughput ceiling**: Limited by driver's processing capacity 147 | 148 | ### Important Note from Source Code 149 | From datasource.py: 150 | > "Because SimpleDataSourceStreamReader read records in Spark driver node to determine end offset of each batch without partitioning, it is only supposed to be used in lightweight use cases where input rate and batch size is small." 
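One practical consequence of this note is to bound how much data `read()` returns per micro-batch. The sketch below caps the prefetch size on the driver; the `maxRecordsPerBatch` option name and the in-memory `_fetch` helper are hypothetical stand-ins for a real source:

```python
from pyspark.sql.datasource import SimpleDataSourceStreamReader


class CappedStreamReader(SimpleDataSourceStreamReader):
    """Sketch of a reader that limits how many records the driver prefetches per batch."""

    def __init__(self, options: dict):
        # Hypothetical option; all option values arrive as strings.
        self.max_records = int(options.get("maxRecordsPerBatch", "500"))

    def initialOffset(self) -> dict:
        return {"position": 0}

    def _fetch(self, start_pos: int, limit: int):
        # Stand-in for a real paged source call (API, queue, file listing, ...).
        return [(i, f"event-{i}") for i in range(start_pos, start_pos + limit)]

    def read(self, start: dict):
        pos = start["position"]
        records = self._fetch(pos, self.max_records)  # never exceeds the cap
        return iter(records), {"position": pos + len(records)}

    def readBetweenOffsets(self, start: dict, end: dict):
        # Deterministic replay of exactly [start, end) for failure recovery.
        return iter(self._fetch(start["position"], end["position"] - start["position"]))
```

Keeping the per-batch cap small keeps the driver-side cache, and the Arrow payload shipped to executors, within the memory limits discussed below.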
151 | 152 | ## Use Cases 153 | 154 | ### Ideal for: 155 | - Configuration change streams 156 | - Small lookup table updates 157 | - Low-volume event streams (< 1000 records/sec) 158 | - Prototyping and testing streaming applications 159 | - REST API polling with low frequency 160 | - File system monitoring for small files 161 | - Message queue consumers with low throughput 162 | 163 | ### Not suitable for: 164 | - High-throughput data sources (use `DataSourceStreamReader` instead) 165 | - Large batch sizes that exceed driver memory 166 | - Sources requiring parallel reads for performance 167 | - Production workloads with high availability requirements 168 | - Kafka topics with high message rates 169 | - Large file streaming 170 | 171 | ## Implementation Example: File Monitor 172 | 173 | ```python 174 | import os 175 | import json 176 | from typing import Iterator, Tuple, Dict 177 | from pyspark.sql.datasource import SimpleDataSourceStreamReader 178 | 179 | class FileMonitorStreamReader(SimpleDataSourceStreamReader): 180 | def __init__(self, path: str): 181 | self.path = path 182 | 183 | def initialOffset(self) -> Dict: 184 | # Start with empty file list 185 | return {"processed_files": []} 186 | 187 | def read(self, start: Dict) -> Tuple[Iterator[Tuple], Dict]: 188 | processed = set(start.get("processed_files", [])) 189 | current_files = set(os.listdir(self.path)) 190 | new_files = current_files - processed 191 | 192 | # Read content from new files 193 | data = [] 194 | for filename in new_files: 195 | filepath = os.path.join(self.path, filename) 196 | if os.path.isfile(filepath): 197 | with open(filepath, 'r') as f: 198 | content = f.read() 199 | data.append((filename, content)) 200 | 201 | # Update offset 202 | new_offset = {"processed_files": list(current_files)} 203 | 204 | return (iter(data), new_offset) 205 | 206 | def readBetweenOffsets(self, start: Dict, end: Dict) -> Iterator[Tuple]: 207 | # For recovery: re-read files that were added between start and end 208 | start_files = set(start.get("processed_files", [])) 209 | end_files = set(end.get("processed_files", [])) 210 | files_to_read = end_files - start_files 211 | 212 | data = [] 213 | for filename in files_to_read: 214 | filepath = os.path.join(self.path, filename) 215 | if os.path.exists(filepath): 216 | with open(filepath, 'r') as f: 217 | content = f.read() 218 | data.append((filename, content)) 219 | 220 | return iter(data) 221 | ``` 222 | 223 | ## Performance Considerations 224 | 225 | ### Memory Management 226 | - Driver memory should be configured to handle maximum expected batch size 227 | - Use `spark.driver.memory` and `spark.driver.maxResultSize` appropriately 228 | - Monitor driver memory usage during streaming 229 | 230 | ### Optimization Tips 231 | 1. **Batch Size Control**: Implement rate limiting in `read()` method 232 | 2. **Data Compression**: Use efficient serialization formats (Parquet, Arrow) 233 | 3. **Offset Design**: Keep offset structure simple and small 234 | 4. **Caching Strategy**: Clear old cache entries in `commit()` method 235 | 5. 
**Error Handling**: Implement robust error handling in read methods 236 | 237 | ### Monitoring 238 | Key metrics to monitor: 239 | - Driver memory usage 240 | - Batch processing time 241 | - Records per batch 242 | - Offset checkpoint frequency 243 | - Cache hit/miss ratio 244 | 245 | ## Comparison with DataSourceStreamReader 246 | 247 | | Feature | SimpleDataSourceStreamReader | DataSourceStreamReader | 248 | |---------|------------------------------|------------------------| 249 | | Execution Location | Driver only | Distributed (executors) | 250 | | Partitioning | Single partition | Multiple partitions | 251 | | Scalability | Limited | High | 252 | | Implementation Complexity | Low | High | 253 | | Memory Requirements | Driver-heavy | Distributed | 254 | | Suitable Data Volume | < 10MB per batch | Unlimited | 255 | | Parallelism | None | Configurable | 256 | | Use Case | Prototyping, small streams | Production, large streams | 257 | 258 | ## Conclusion 259 | 260 | `SimpleDataSourceStreamReader` provides an elegant solution for integrating small-scale streaming data sources with Spark's Structured Streaming engine. By executing entirely on the driver, it simplifies development while maintaining full compatibility with Spark's streaming semantics. However, users must carefully consider the scalability limitations and ensure their use case fits within the architectural constraints of driver-side execution. 261 | 262 | For production systems with high throughput requirements, the standard `DataSourceStreamReader` with proper partitioning should be used instead. -------------------------------------------------------------------------------- /examples/salesforce_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Salesforce Datasource Example 5 | 6 | This example demonstrates how to use the SalesforceDataSource as a streaming datasource 7 | to write data from various sources to Salesforce objects. 
8 | 9 | Requirements: 10 | - PySpark 4.0+ 11 | - simple-salesforce library 12 | - Valid Salesforce credentials 13 | 14 | Setup: 15 | pip install pyspark simple-salesforce 16 | 17 | Environment Variables: 18 | export SALESFORCE_USERNAME="your-username@company.com" 19 | export SALESFORCE_PASSWORD="your-password" 20 | export SALESFORCE_SECURITY_TOKEN="your-security-token" 21 | """ 22 | 23 | import os 24 | import sys 25 | import tempfile 26 | import csv 27 | from pyspark.sql import SparkSession 28 | from pyspark.sql.functions import col, lit, current_timestamp 29 | from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType 30 | 31 | 32 | def check_credentials(): 33 | """Check if Salesforce credentials are available""" 34 | username = os.getenv("SALESFORCE_USERNAME") 35 | password = os.getenv("SALESFORCE_PASSWORD") 36 | security_token = os.getenv("SALESFORCE_SECURITY_TOKEN") 37 | 38 | if not all([username, password, security_token]): 39 | print("❌ Missing Salesforce credentials!") 40 | print("Please set the following environment variables:") 41 | print(" export SALESFORCE_USERNAME='your-username@company.com'") 42 | print(" export SALESFORCE_PASSWORD='your-password'") 43 | print(" export SALESFORCE_SECURITY_TOKEN='your-security-token'") 44 | return False, None, None, None 45 | 46 | print(f"✅ Using Salesforce credentials for: {username}") 47 | return True, username, password, security_token 48 | 49 | 50 | def example_1_rate_source_to_accounts(): 51 | """Example 1: Stream from rate source to Salesforce Accounts""" 52 | print("\n" + "=" * 60) 53 | print("EXAMPLE 1: Rate Source → Salesforce Accounts") 54 | print("=" * 60) 55 | 56 | has_creds, username, password, security_token = check_credentials() 57 | if not has_creds: 58 | return 59 | 60 | spark = ( 61 | SparkSession.builder.appName("SalesforceExample1") 62 | .config("spark.sql.shuffle.partitions", "2") 63 | .getOrCreate() 64 | ) 65 | 66 | try: 67 | # Register Salesforce Datasource 68 | from pyspark_datasources.salesforce import SalesforceDataSource 69 | 70 | spark.dataSource.register(SalesforceDataSource) 71 | print("✅ Salesforce datasource registered") 72 | 73 | # Create streaming data from rate source 74 | streaming_df = spark.readStream.format("rate").option("rowsPerSecond", 2).load() 75 | 76 | # Transform to Account format 77 | account_data = streaming_df.select( 78 | col("timestamp").cast("string").alias("Name"), 79 | lit("Technology").alias("Industry"), 80 | (col("value") * 10000).cast("double").alias("AnnualRevenue"), 81 | ) 82 | 83 | print("📊 Starting streaming write to Salesforce Accounts...") 84 | 85 | # Write to Salesforce 86 | query = ( 87 | account_data.writeStream.format("pyspark.datasource.salesforce") 88 | .option("username", username) 89 | .option("password", password) 90 | .option("security_token", security_token) 91 | .option("salesforce_object", "Account") 92 | .option("batch_size", "10") 93 | .option("checkpointLocation", "/tmp/salesforce_example1_checkpoint") 94 | .trigger(once=True) 95 | .start() 96 | ) 97 | 98 | # Wait for completion 99 | query.awaitTermination(timeout=60) 100 | 101 | # Show results 102 | progress = query.lastProgress 103 | if progress: 104 | sources = progress.get("sources", []) 105 | if sources: 106 | records = sources[0].get("numInputRows", 0) 107 | print(f"✅ Successfully wrote {records} Account records to Salesforce") 108 | else: 109 | print("✅ Streaming completed successfully") 110 | 111 | except Exception as e: 112 | print(f"❌ Error: {e}") 113 | finally: 114 | 
spark.stop() 115 | 116 | 117 | def example_2_csv_to_contacts(): 118 | """Example 2: Stream from CSV files to Salesforce Contacts""" 119 | print("\n" + "=" * 60) 120 | print("EXAMPLE 2: CSV Files → Salesforce Contacts") 121 | print("=" * 60) 122 | 123 | has_creds, username, password, security_token = check_credentials() 124 | if not has_creds: 125 | return 126 | 127 | # Create temporary CSV directory 128 | csv_dir = tempfile.mkdtemp(prefix="salesforce_contacts_") 129 | print(f"📁 CSV directory: {csv_dir}") 130 | 131 | spark = ( 132 | SparkSession.builder.appName("SalesforceExample2") 133 | .config("spark.sql.shuffle.partitions", "2") 134 | .getOrCreate() 135 | ) 136 | 137 | try: 138 | # Register Salesforce datasource 139 | from pyspark_datasources.salesforce import SalesforceDataSource 140 | 141 | spark.dataSource.register(SalesforceDataSource) 142 | 143 | # Create sample CSV data 144 | contact_data = [ 145 | ["John", "Doe", "john.doe@example.com", "555-1234"], 146 | ["Jane", "Smith", "jane.smith@example.com", "555-5678"], 147 | ["Bob", "Johnson", "bob.johnson@example.com", "555-9012"], 148 | ] 149 | 150 | headers = ["FirstName", "LastName", "Email", "Phone"] 151 | csv_file = os.path.join(csv_dir, "contacts.csv") 152 | 153 | # Write CSV file 154 | with open(csv_file, "w", newline="") as f: 155 | writer = csv.writer(f) 156 | writer.writerow(headers) 157 | writer.writerows(contact_data) 158 | 159 | print(f"📝 Created CSV file with {len(contact_data)} contacts") 160 | 161 | # Define schema for CSV 162 | schema = StructType( 163 | [ 164 | StructField("FirstName", StringType(), True), 165 | StructField("LastName", StringType(), True), 166 | StructField("Email", StringType(), True), 167 | StructField("Phone", StringType(), True), 168 | ] 169 | ) 170 | 171 | # Stream from CSV 172 | streaming_df = ( 173 | spark.readStream.format("csv").option("header", "true").schema(schema).load(csv_dir) 174 | ) 175 | 176 | print("📊 Starting streaming write to Salesforce Contacts...") 177 | 178 | # Write to Salesforce with custom schema 179 | query = ( 180 | streaming_df.writeStream.format("pyspark.datasource.salesforce") 181 | .option("username", username) 182 | .option("password", password) 183 | .option("security_token", security_token) 184 | .option("salesforce_object", "Contact") 185 | .option( 186 | "schema", "FirstName STRING, LastName STRING NOT NULL, Email STRING, Phone STRING" 187 | ) 188 | .option("batch_size", "5") 189 | .option("checkpointLocation", "/tmp/salesforce_example2_checkpoint") 190 | .trigger(once=True) 191 | .start() 192 | ) 193 | 194 | # Wait for completion 195 | query.awaitTermination(timeout=60) 196 | 197 | # Show results 198 | progress = query.lastProgress 199 | if progress: 200 | sources = progress.get("sources", []) 201 | if sources: 202 | records = sources[0].get("numInputRows", 0) 203 | print(f"✅ Successfully wrote {records} Contact records to Salesforce") 204 | else: 205 | print("✅ Streaming completed successfully") 206 | 207 | except Exception as e: 208 | print(f"❌ Error: {e}") 209 | finally: 210 | spark.stop() 211 | # Cleanup 212 | import shutil 213 | 214 | if os.path.exists(csv_dir): 215 | shutil.rmtree(csv_dir) 216 | 217 | 218 | def example_3_checkpoint_demonstration(): 219 | """Example 3: Demonstrate checkpoint functionality with incremental data""" 220 | print("\n" + "=" * 60) 221 | print("EXAMPLE 3: Checkpoint Functionality Demonstration") 222 | print("=" * 60) 223 | 224 | has_creds, username, password, security_token = check_credentials() 225 | if not has_creds: 226 | return 
227 | 228 | # Create temporary directories 229 | csv_dir = tempfile.mkdtemp(prefix="salesforce_checkpoint_") 230 | checkpoint_dir = tempfile.mkdtemp(prefix="checkpoint_") 231 | 232 | print(f"📁 CSV directory: {csv_dir}") 233 | print(f"📁 Checkpoint directory: {checkpoint_dir}") 234 | 235 | try: 236 | # Phase 1: Create initial data and first stream 237 | print("\n📊 Phase 1: Creating initial batch and first stream...") 238 | 239 | initial_data = [ 240 | [1, "InitialCorp_A", "Tech", 100000.0], 241 | [2, "InitialCorp_B", "Finance", 200000.0], 242 | [3, "InitialCorp_C", "Healthcare", 300000.0], 243 | ] 244 | 245 | headers = ["id", "name", "industry", "revenue"] 246 | csv_file1 = os.path.join(csv_dir, "batch_001.csv") 247 | 248 | with open(csv_file1, "w", newline="") as f: 249 | writer = csv.writer(f) 250 | writer.writerow(headers) 251 | writer.writerows(initial_data) 252 | 253 | # First stream 254 | spark = ( 255 | SparkSession.builder.appName("SalesforceCheckpointDemo") 256 | .config("spark.sql.shuffle.partitions", "2") 257 | .getOrCreate() 258 | ) 259 | 260 | from pyspark_datasources.salesforce import SalesforceDataSource 261 | 262 | spark.dataSource.register(SalesforceDataSource) 263 | 264 | schema = StructType( 265 | [ 266 | StructField("id", IntegerType(), True), 267 | StructField("name", StringType(), True), 268 | StructField("industry", StringType(), True), 269 | StructField("revenue", DoubleType(), True), 270 | ] 271 | ) 272 | 273 | streaming_df1 = ( 274 | spark.readStream.format("csv").option("header", "true").schema(schema).load(csv_dir) 275 | ) 276 | 277 | account_df1 = streaming_df1.select( 278 | col("name").alias("Name"), 279 | col("industry").alias("Industry"), 280 | col("revenue").alias("AnnualRevenue"), 281 | ) 282 | 283 | query1 = ( 284 | account_df1.writeStream.format("pyspark.datasource.salesforce") 285 | .option("username", username) 286 | .option("password", password) 287 | .option("security_token", security_token) 288 | .option("salesforce_object", "Account") 289 | .option("batch_size", "10") 290 | .option("checkpointLocation", checkpoint_dir) 291 | .trigger(once=True) 292 | .start() 293 | ) 294 | 295 | query1.awaitTermination(timeout=60) 296 | 297 | progress1 = query1.lastProgress 298 | records1 = 0 299 | if progress1 and progress1.get("sources"): 300 | records1 = progress1["sources"][0].get("numInputRows", 0) 301 | 302 | print(f" ✅ First stream processed {records1} records") 303 | 304 | # Phase 2: Add new data and second stream 305 | print("\n📊 Phase 2: Adding new batch and second stream...") 306 | 307 | import time 308 | 309 | time.sleep(2) # Brief pause 310 | 311 | new_data = [[4, "NewCorp_D", "Energy", 400000.0], [5, "NewCorp_E", "Retail", 500000.0]] 312 | 313 | csv_file2 = os.path.join(csv_dir, "batch_002.csv") 314 | with open(csv_file2, "w", newline="") as f: 315 | writer = csv.writer(f) 316 | writer.writerow(headers) 317 | writer.writerows(new_data) 318 | 319 | # Second stream with same checkpoint 320 | streaming_df2 = ( 321 | spark.readStream.format("csv").option("header", "true").schema(schema).load(csv_dir) 322 | ) 323 | 324 | account_df2 = streaming_df2.select( 325 | col("name").alias("Name"), 326 | col("industry").alias("Industry"), 327 | col("revenue").alias("AnnualRevenue"), 328 | ) 329 | 330 | query2 = ( 331 | account_df2.writeStream.format("pyspark.datasource.salesforce") 332 | .option("username", username) 333 | .option("password", password) 334 | .option("security_token", security_token) 335 | .option("salesforce_object", "Account") 336 | 
.option("batch_size", "10") 337 | .option("checkpointLocation", checkpoint_dir) 338 | .trigger(once=True) 339 | .start() 340 | ) 341 | 342 | query2.awaitTermination(timeout=60) 343 | 344 | progress2 = query2.lastProgress 345 | records2 = 0 346 | if progress2 and progress2.get("sources"): 347 | records2 = progress2["sources"][0].get("numInputRows", 0) 348 | 349 | print(f" ✅ Second stream processed {records2} records") 350 | 351 | # Analyze checkpoint functionality 352 | print(f"\n📈 Checkpoint Analysis:") 353 | print(f" - First stream: {records1} records (initial batch)") 354 | print(f" - Second stream: {records2} records (new batch)") 355 | print(f" - Total: {records1 + records2} records") 356 | 357 | if records1 == 3 and records2 == 2: 358 | print(" ✅ PERFECT: Exactly-once processing achieved!") 359 | print(" ✅ Checkpoint functionality working correctly") 360 | else: 361 | print(" ⚠️ Results may vary due to timing or file detection") 362 | 363 | except Exception as e: 364 | print(f"❌ Error: {e}") 365 | finally: 366 | spark.stop() 367 | # Cleanup 368 | import shutil 369 | 370 | if os.path.exists(csv_dir): 371 | shutil.rmtree(csv_dir) 372 | if os.path.exists(checkpoint_dir): 373 | shutil.rmtree(checkpoint_dir) 374 | 375 | 376 | def example_4_custom_object(): 377 | """Example 4: Write to custom Salesforce object""" 378 | print("\n" + "=" * 60) 379 | print("EXAMPLE 4: Custom Salesforce Object") 380 | print("=" * 60) 381 | 382 | has_creds, username, password, security_token = check_credentials() 383 | if not has_creds: 384 | return 385 | 386 | print("📝 This example shows how to write to custom Salesforce objects") 387 | print(" Note: Make sure your custom object exists in Salesforce") 388 | 389 | spark = ( 390 | SparkSession.builder.appName("SalesforceCustomObjectExample") 391 | .config("spark.sql.shuffle.partitions", "2") 392 | .getOrCreate() 393 | ) 394 | 395 | try: 396 | from pyspark_datasources.salesforce import SalesforceDataSource 397 | 398 | spark.dataSource.register(SalesforceDataSource) 399 | 400 | # Create sample data for custom object 401 | streaming_df = spark.readStream.format("rate").option("rowsPerSecond", 1).load() 402 | 403 | # Transform for custom object (example: Product__c) 404 | custom_data = streaming_df.select( 405 | col("value").cast("string").alias("Product_Code__c"), 406 | lit("Sample Product").alias("Name"), 407 | (col("value") * 29.99).cast("double").alias("Price__c"), 408 | current_timestamp().alias("Created_Date__c"), 409 | ) 410 | 411 | print("📊 Example configuration for custom object...") 412 | print(" Custom Object: Product__c") 413 | print(" Fields: Product_Code__c, Name, Price__c, Created_Date__c") 414 | print("\n Note: Uncomment and modify the following code for your custom object:") 415 | 416 | # Example code (commented out since custom object may not exist) 417 | print(""" 418 | query = custom_data.writeStream \\ 419 | .format("pyspark.datasource.salesforce") \\ 420 | .option("username", username) \\ 421 | .option("password", password) \\ 422 | .option("security_token", security_token) \\ 423 | .option("salesforce_object", "Product__c") \\ 424 | .option("schema", "Product_Code__c STRING, Name STRING, Price__c DOUBLE, Created_Date__c TIMESTAMP") \\ 425 | .option("batch_size", "20") \\ 426 | .option("checkpointLocation", "/tmp/custom_object_checkpoint") \\ 427 | .trigger(processingTime="10 seconds") \\ 428 | .start() 429 | """) 430 | 431 | print("✅ Custom object example configuration shown") 432 | 433 | except Exception as e: 434 | print(f"❌ Error: {e}") 435 
| finally: 436 | spark.stop() 437 | 438 | 439 | def main(): 440 | """Run all examples""" 441 | print("🚀 Salesforce Datasource Examples") 442 | print("This demonstrates various ways to use the Salesforce streaming datasource") 443 | 444 | try: 445 | # Run examples 446 | example_1_rate_source_to_accounts() 447 | example_2_csv_to_contacts() 448 | example_3_checkpoint_demonstration() 449 | example_4_custom_object() 450 | 451 | print("\n" + "=" * 60) 452 | print("✅ All examples completed!") 453 | print("=" * 60) 454 | print("\n💡 Key takeaways:") 455 | print(" - Salesforce datasource supports various input sources (rate, CSV, etc.)") 456 | print(" - Checkpoint functionality enables exactly-once processing") 457 | print(" - Custom schemas allow flexibility for different Salesforce objects") 458 | print(" - Batch processing optimizes Salesforce API usage") 459 | print(" - Error handling provides fallback to individual record creation") 460 | 461 | except KeyboardInterrupt: 462 | print("\n⏹️ Examples interrupted by user") 463 | except Exception as e: 464 | print(f"\n❌ Unexpected error: {e}") 465 | 466 | 467 | if __name__ == "__main__": 468 | main() 469 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pyspark-data-sources" 3 | version = "0.1.10" 4 | description = "Custom Spark data sources for reading and writing data in Apache Spark, using the Python Data Source API" 5 | authors = ["allisonwang-db "] 6 | license = "Apache License 2.0" 7 | readme = "README.md" 8 | packages = [ 9 | { include = "pyspark_datasources" }, 10 | ] 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.9,<3.13" 14 | pyarrow = ">=11.0.0" 15 | requests = "^2.31.0" 16 | faker = "^23.1.0" 17 | mkdocstrings = {extras = ["python"], version = "^0.28.0"} 18 | datasets = {version = "^2.17.0", optional = true} 19 | databricks-sdk = {version = "^0.28.0", optional = true} 20 | kagglehub = {extras = ["pandas-datasets"], version = "^0.3.10", optional = true} 21 | simple-salesforce = {version = "^1.12.0", optional = true} 22 | pynacl = {version = "^1.5.0", optional = true} 23 | 24 | [tool.poetry.extras] 25 | faker = ["faker"] 26 | datasets = ["datasets"] 27 | databricks = ["databricks-sdk"] 28 | kaggle = ["kagglehub"] 29 | lance = ["pylance"] 30 | robinhood = ["pynacl"] 31 | salesforce = ["simple-salesforce"] 32 | all = ["faker", "datasets", "databricks-sdk", "kagglehub", "pynacl", "simple-salesforce"] 33 | 34 | [tool.poetry.group.dev.dependencies] 35 | pytest = "^8.0.0" 36 | pytest-cov = "^4.0.0" 37 | grpcio = "^1.60.1" 38 | grpcio-status = "^1.60.1" 39 | pandas = "^2.2.0" 40 | mkdocs-material = "^9.5.40" 41 | pyspark = "4.0.0" 42 | ruff = "^0.6.0" 43 | 44 | [tool.ruff] 45 | line-length = 100 46 | target-version = "py39" 47 | 48 | [tool.ruff.format] 49 | quote-style = "double" 50 | indent-style = "space" 51 | 52 | [tool.ruff.lint] 53 | select = ["E", "F", "I", "UP"] 54 | ignore = [] 55 | 56 | [build-system] 57 | requires = ["poetry-core"] 58 | build-backend = "poetry.core.masonry.api" 59 | -------------------------------------------------------------------------------- /pyspark_datasources/__init__.py: -------------------------------------------------------------------------------- 1 | from .arrow import ArrowDataSource 2 | from .fake import FakeDataSource 3 | from .github import GithubDataSource 4 | from .googlesheets import GoogleSheetsDataSource 5 | from .huggingface import 
HuggingFaceDatasets 6 | from .kaggle import KaggleDataSource 7 | from .opensky import OpenSkyDataSource 8 | from .robinhood import RobinhoodDataSource 9 | from .salesforce import SalesforceDataSource 10 | from .simplejson import SimpleJsonDataSource 11 | from .stock import StockDataSource 12 | from .jsonplaceholder import JSONPlaceholderDataSource 13 | -------------------------------------------------------------------------------- /pyspark_datasources/arrow.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterator, Union 2 | import os 3 | import glob 4 | 5 | import pyarrow as pa 6 | from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition 7 | from pyspark.sql.types import StructType 8 | 9 | 10 | class ArrowDataSource(DataSource): 11 | """ 12 | A data source for reading Apache Arrow files (.arrow) using PyArrow. 13 | 14 | This data source supports reading Arrow IPC files from local filesystem or 15 | cloud storage, leveraging PyArrow's efficient columnar format and returning 16 | PyArrow RecordBatch objects for optimal performance with PySpark's Arrow integration. 17 | 18 | Name: `arrow` 19 | 20 | Path Support 21 | ----------- 22 | Supports various path patterns in the load() method: 23 | - Single file: "/path/to/file.arrow" 24 | - Glob patterns: "/path/to/*.arrow" or "/path/to/data*.arrow" 25 | - Directory: "/path/to/directory" (reads all .arrow files) 26 | 27 | Partitioning Strategy 28 | -------------------- 29 | The data source creates one partition per file for parallel processing: 30 | - Single file: 1 partition 31 | - Multiple files: N partitions (one per file) 32 | - Directory: N partitions (one per .arrow file found) 33 | 34 | This enables Spark to process multiple files in parallel across different 35 | executor cores, improving performance for large datasets. 
36 | 37 | Performance Notes 38 | ---------------- 39 | - Returns PyArrow RecordBatch objects for zero-copy data transfer 40 | - Leverages PySpark 4.0's enhanced Arrow integration 41 | - For DataFrames created in Spark, consider using the new df.to_arrow() method 42 | in PySpark 4.0+ for efficient Arrow conversion 43 | 44 | Examples 45 | -------- 46 | Register the data source: 47 | 48 | >>> from pyspark_datasources import ArrowDataSource 49 | >>> spark.dataSource.register(ArrowDataSource) 50 | 51 | Read a single Arrow file: 52 | 53 | >>> df = spark.read.format("arrow").load("/path/to/employees.arrow") 54 | >>> df.show() 55 | +---+-----------+---+-------+----------+------+ 56 | | id| name|age| salary|department|active| 57 | +---+-----------+---+-------+----------+------+ 58 | | 1|Alice Smith| 28|65000.0| Tech| true| 59 | +---+-----------+---+-------+----------+------+ 60 | 61 | Read multiple files with glob pattern (creates multiple partitions): 62 | 63 | >>> df = spark.read.format("arrow").load("/data/sales/sales_*.arrow") 64 | >>> df.show() 65 | >>> print(f"Number of partitions: {df.rdd.getNumPartitions()}") 66 | 67 | Read all Arrow files in a directory: 68 | 69 | >>> df = spark.read.format("arrow").load("/data/warehouse/") 70 | >>> df.show() 71 | 72 | Working with the result DataFrame and PySpark 4.0 Arrow integration: 73 | 74 | >>> df = spark.read.format("arrow").load("/path/to/data.arrow") 75 | >>> 76 | >>> # Process with Spark 77 | >>> result = df.filter(df.age > 25).groupBy("department").count() 78 | >>> result.show() 79 | >>> 80 | >>> # Convert back to Arrow using PySpark 4.0+ feature 81 | >>> arrow_table = result.to_arrow() # New in PySpark 4.0+ 82 | >>> print(f"Arrow table: {arrow_table}") 83 | 84 | Schema inference example: 85 | 86 | >>> # Schema is automatically inferred from the first file 87 | >>> df = spark.read.format("arrow").load("/path/to/*.arrow") 88 | >>> df.printSchema() 89 | root 90 | |-- product_id: long (nullable = true) 91 | |-- product_name: string (nullable = true) 92 | |-- price: double (nullable = true) 93 | """ 94 | 95 | @classmethod 96 | def name(cls): 97 | return "arrow" 98 | 99 | def schema(self) -> StructType: 100 | path = self.options.get("path") 101 | if not path: 102 | raise ValueError("Path option is required for Arrow data source") 103 | 104 | # Get the first file to determine schema 105 | files = self._get_files(path) 106 | if not files: 107 | raise ValueError(f"No files found at path: {path}") 108 | 109 | # Read schema from first file (Arrow IPC format) 110 | with pa.ipc.open_file(files[0]) as reader: 111 | table = reader.read_all() 112 | 113 | # Convert PyArrow schema to Spark schema using PySpark utility 114 | from pyspark.sql.pandas.types import from_arrow_schema 115 | 116 | return from_arrow_schema(table.schema) 117 | 118 | def reader(self, schema: StructType) -> "ArrowDataSourceReader": 119 | return ArrowDataSourceReader(schema, self.options) 120 | 121 | def _get_files(self, path: str) -> List[str]: 122 | """Get list of files matching the path pattern.""" 123 | if os.path.isfile(path): 124 | return [path] 125 | elif os.path.isdir(path): 126 | # Find all arrow files in directory 127 | arrow_files = glob.glob(os.path.join(path, "*.arrow")) 128 | return sorted(arrow_files) 129 | else: 130 | # Treat as glob pattern 131 | return sorted(glob.glob(path)) 132 | 133 | 134 | class ArrowDataSourceReader(DataSourceReader): 135 | """Reader for Arrow data source.""" 136 | 137 | def __init__(self, schema: StructType, options: dict) -> None: 138 | self.schema = 
schema 139 | self.options = options 140 | self.path = options.get("path") 141 | if not self.path: 142 | raise ValueError("Path option is required") 143 | 144 | def partitions(self) -> List[InputPartition]: 145 | """Create partitions, one per file for parallel reading.""" 146 | data_source = ArrowDataSource(self.options) 147 | files = data_source._get_files(self.path) 148 | return [InputPartition(file_path) for file_path in files] 149 | 150 | def read(self, partition: InputPartition) -> Iterator[pa.RecordBatch]: 151 | """Read data from a single file partition, returning PyArrow RecordBatch.""" 152 | file_path = partition.value 153 | 154 | try: 155 | # Read Arrow IPC file 156 | with pa.ipc.open_file(file_path) as reader: 157 | for i in range(reader.num_record_batches): 158 | batch = reader.get_batch(i) 159 | yield batch 160 | except Exception as e: 161 | raise RuntimeError(f"Failed to read Arrow file {file_path}: {str(e)}") 162 | -------------------------------------------------------------------------------- /pyspark_datasources/fake.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pyspark.sql.datasource import ( 4 | DataSource, 5 | DataSourceReader, 6 | DataSourceStreamReader, 7 | InputPartition, 8 | ) 9 | from pyspark.sql.types import StringType, StructType 10 | 11 | 12 | def _validate_faker_schema(schema): 13 | # Verify the library is installed correctly. 14 | try: 15 | from faker import Faker 16 | except ImportError: 17 | raise Exception("You need to install `faker` to use the fake datasource.") 18 | 19 | fake = Faker() 20 | for field in schema.fields: 21 | try: 22 | getattr(fake, field.name)() 23 | except AttributeError: 24 | raise Exception( 25 | f"Unable to find a method called `{field.name}` in faker. " 26 | f"Please check Faker's documentation to see supported methods." 27 | ) 28 | if field.dataType != StringType(): 29 | raise Exception( 30 | f"Field `{field.name}` is not a StringType. " 31 | f"Only StringType is supported in the fake datasource." 32 | ) 33 | 34 | 35 | class FakeDataSource(DataSource): 36 | """ 37 | A fake data source for PySpark to generate synthetic data using the `faker` library. 38 | 39 | This data source allows specifying a schema with field names that correspond to `faker` 40 | providers to generate random data for testing and development purposes. 41 | 42 | The default schema is `name string, date string, zipcode string, state string`, and the 43 | default number of rows is `3`. Both can be customized by users. 44 | 45 | Name: `fake` 46 | 47 | Notes 48 | ----- 49 | - The fake data source relies on the `faker` library. Make sure it is installed and accessible. 50 | - Only string type fields are supported, and each field name must correspond to a method name in 51 | the `faker` library. 52 | - When using the stream reader, `numRows` is the number of rows per microbatch. 53 | 54 | Examples 55 | -------- 56 | Register the data source. 
57 | 58 | >>> from pyspark_datasources import FakeDataSource 59 | >>> spark.dataSource.register(FakeDataSource) 60 | 61 | Use the fake datasource with the default schema and default number of rows: 62 | 63 | >>> spark.read.format("fake").load().show() 64 | +-----------+----------+-------+-------+ 65 | | name| date|zipcode| state| 66 | +-----------+----------+-------+-------+ 67 | |Carlos Cobb|2018-07-15| 73003|Indiana| 68 | | Eric Scott|1991-08-22| 10085| Idaho| 69 | | Amy Martin|1988-10-28| 68076| Oregon| 70 | +-----------+----------+-------+-------+ 71 | 72 | Use the fake datasource with a custom schema: 73 | 74 | >>> spark.read.format("fake").schema("name string, company string").load().show() 75 | +---------------------+--------------+ 76 | |name |company | 77 | +---------------------+--------------+ 78 | |Tanner Brennan |Adams Group | 79 | |Leslie Maxwell |Santiago Group| 80 | |Mrs. Jacqueline Brown|Maynard Inc | 81 | +---------------------+--------------+ 82 | 83 | Use the fake datasource with a different number of rows: 84 | 85 | >>> spark.read.format("fake").option("numRows", 5).load().show() 86 | +--------------+----------+-------+------------+ 87 | | name| date|zipcode| state| 88 | +--------------+----------+-------+------------+ 89 | | Pam Mitchell|1988-10-20| 23788| Tennessee| 90 | |Melissa Turner|1996-06-14| 30851| Nevada| 91 | | Brian Ramsey|2021-08-21| 55277| Washington| 92 | | Caitlin Reed|1983-06-22| 89813|Pennsylvania| 93 | | Douglas James|2007-01-18| 46226| Alabama| 94 | +--------------+----------+-------+------------+ 95 | 96 | Streaming fake data: 97 | 98 | >>> stream = spark.readStream.format("fake").load().writeStream.format("console").start() 99 | Batch: 0 100 | +--------------+----------+-------+------------+ 101 | | name| date|zipcode| state| 102 | +--------------+----------+-------+------------+ 103 | | Tommy Diaz|1976-11-17| 27627|South Dakota| 104 | |Jonathan Perez|1986-02-23| 81307|Rhode Island| 105 | | Julia Farmer|1990-10-10| 40482| Virginia| 106 | +--------------+----------+-------+------------+ 107 | Batch: 1 108 | ... 109 | >>> stream.stop() 110 | """ 111 | 112 | @classmethod 113 | def name(cls): 114 | return "fake" 115 | 116 | def schema(self): 117 | return "name string, date string, zipcode string, state string" 118 | 119 | def reader(self, schema: StructType) -> "FakeDataSourceReader": 120 | _validate_faker_schema(schema) 121 | return FakeDataSourceReader(schema, self.options) 122 | 123 | def streamReader(self, schema) -> "FakeDataSourceStreamReader": 124 | _validate_faker_schema(schema) 125 | return FakeDataSourceStreamReader(schema, self.options) 126 | 127 | 128 | class FakeDataSourceReader(DataSourceReader): 129 | def __init__(self, schema, options) -> None: 130 | self.schema: StructType = schema 131 | self.options = options 132 | 133 | def read(self, partition): 134 | from faker import Faker 135 | 136 | fake = Faker() 137 | # Note: every value in this `self.options` dictionary is a string. 
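        # For example, .option("numRows", 5) reaches this reader as the string "5",
        # which is why the value is cast with int() below before generating rows.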
138 | num_rows = int(self.options.get("numRows", 3)) 139 | for _ in range(num_rows): 140 | row = [] 141 | for field in self.schema.fields: 142 | value = getattr(fake, field.name)() 143 | row.append(value) 144 | yield tuple(row) 145 | 146 | 147 | class FakeDataSourceStreamReader(DataSourceStreamReader): 148 | def __init__(self, schema, options) -> None: 149 | self.schema: StructType = schema 150 | self.rows_per_microbatch = int(options.get("numRows", 3)) 151 | self.options = options 152 | self.offset = 0 153 | 154 | def initialOffset(self) -> dict: 155 | return {"offset": 0} 156 | 157 | def latestOffset(self) -> dict: 158 | self.offset += self.rows_per_microbatch 159 | return {"offset": self.offset} 160 | 161 | def partitions(self, start, end) -> List[InputPartition]: 162 | return [InputPartition(end["offset"] - start["offset"])] 163 | 164 | def read(self, partition): 165 | from faker import Faker 166 | 167 | fake = Faker() 168 | for _ in range(partition.value): 169 | row = [] 170 | for field in self.schema.fields: 171 | value = getattr(fake, field.name)() 172 | row.append(value) 173 | yield tuple(row) 174 | -------------------------------------------------------------------------------- /pyspark_datasources/github.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | from pyspark.sql import Row 4 | from pyspark.sql.datasource import DataSource, DataSourceReader 5 | 6 | 7 | class GithubDataSource(DataSource): 8 | """ 9 | A DataSource for reading pull requests data from Github. 10 | 11 | Name: `github` 12 | 13 | Schema: `id int, title string, author string, created_at string, updated_at string` 14 | 15 | Examples 16 | -------- 17 | Register the data source. 18 | 19 | >>> from pyspark_datasources import GithubDataSource 20 | >>> spark.dataSource.register(GithubDataSource) 21 | 22 | Load pull requests data from a public Github repository. 23 | 24 | >>> spark.read.format("github").load("apache/spark").show() 25 | +---+--------------------+--------+--------------------+--------------------+ 26 | | id| title| author| created_at| updated_at| 27 | +---+--------------------+--------+--------------------+--------------------+ 28 | | 1|Initial commit | matei |2014-02-03T18:47:...|2014-02-03T18:47:...| 29 | |...| ...| ...| ...| ...| 30 | +---+--------------------+--------+--------------------+--------------------+ 31 | 32 | Load pull requests data from a private Github repository. 
33 | 34 | >>> spark.read.format("github").option("token", "your-token").load("owner/repo").show() 35 | """ 36 | 37 | @classmethod 38 | def name(self): 39 | return "github" 40 | 41 | def schema(self): 42 | return "id int, title string, author string, created_at string, updated_at string" 43 | 44 | def reader(self, schema): 45 | return GithubPullRequestReader(self.options) 46 | 47 | 48 | class GithubPullRequestReader(DataSourceReader): 49 | def __init__(self, options): 50 | self.token = options.get("token") 51 | self.repo = options.get("path") 52 | if self.repo is None: 53 | raise Exception(f"Must specify a repo in `.load()` method.") 54 | 55 | def read(self, partition): 56 | header = { 57 | "Accept": "application/vnd.github+json", 58 | } 59 | if self.token is not None: 60 | header["Authorization"] = f"Bearer {self.token}" 61 | url = f"https://api.github.com/repos/{self.repo}/pulls" 62 | response = requests.get(url) 63 | response.raise_for_status() 64 | prs = response.json() 65 | for pr in prs: 66 | yield Row( 67 | id=pr.get("number"), 68 | title=pr.get("title"), 69 | author=pr.get("user", {}).get("login"), 70 | created_at=pr.get("created_at"), 71 | updated_at=pr.get("updated_at"), 72 | ) 73 | -------------------------------------------------------------------------------- /pyspark_datasources/googlesheets.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | from pyspark.sql.datasource import DataSource, DataSourceReader 5 | from pyspark.sql.types import StringType, StructField, StructType 6 | 7 | 8 | @dataclass 9 | class Sheet: 10 | """ 11 | A dataclass to identify a Google Sheets document. 12 | 13 | Attributes 14 | ---------- 15 | spreadsheet_id : str 16 | The ID of the Google Sheets document. 17 | sheet_id : str, optional 18 | The ID of the worksheet within the document. 19 | """ 20 | 21 | spreadsheet_id: str 22 | sheet_id: Optional[str] = None # if None, the first sheet is used 23 | 24 | @classmethod 25 | def from_url(cls, url: str) -> "Sheet": 26 | """ 27 | Converts a Google Sheets URL to a Sheet object. 28 | """ 29 | from urllib.parse import parse_qs, urlparse 30 | 31 | parsed = urlparse(url) 32 | if parsed.netloc != "docs.google.com" or not parsed.path.startswith("/spreadsheets/d/"): 33 | raise ValueError("URL is not a Google Sheets URL") 34 | qs = parse_qs(parsed.query) 35 | spreadsheet_id = parsed.path.split("/")[3] 36 | if "gid" in qs: 37 | sheet_id = qs["gid"][0] 38 | else: 39 | sheet_id = None 40 | return cls(spreadsheet_id, sheet_id) 41 | 42 | def get_query_url(self, query: Optional[str] = None): 43 | """ 44 | Gets the query url that returns the results of the query as a CSV file. 45 | 46 | If no query is provided, returns the entire sheet. 47 | If sheet ID is None, uses the first sheet. 48 | 49 | See https://developers.google.com/chart/interactive/docs/querylanguage 50 | """ 51 | from urllib.parse import urlencode 52 | 53 | path = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/gviz/tq" 54 | url_query = {"tqx": "out:csv"} 55 | if self.sheet_id: 56 | url_query["gid"] = self.sheet_id 57 | if query: 58 | url_query["tq"] = query 59 | return f"{path}?{urlencode(url_query)}" 60 | 61 | 62 | @dataclass 63 | class Parameters: 64 | sheet: Sheet 65 | has_header: bool 66 | 67 | 68 | class GoogleSheetsDataSource(DataSource): 69 | """ 70 | A DataSource for reading table from public Google Sheets. 
71 | 72 | Name: `googlesheets` 73 | 74 | Schema: By default, all columns are treated as strings and the header row defines the column names. 75 | 76 | Options 77 | -------- 78 | - `url`: The URL of the Google Sheets document. 79 | - `path`: The ID of the Google Sheets document. 80 | - `sheet_id`: The ID of the worksheet within the document. 81 | - `has_header`: Whether the sheet has a header row. Default is `true`. 82 | 83 | Either `url` or `path` must be specified, but not both. 84 | 85 | Examples 86 | -------- 87 | Register the data source. 88 | 89 | >>> from pyspark_datasources import GoogleSheetsDataSource 90 | >>> spark.dataSource.register(GoogleSheetsDataSource) 91 | 92 | Load data from a public Google Sheets document using `path` and optional `sheet_id`. 93 | 94 | >>> spreadsheet_id = "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0" 95 | >>> spark.read.format("googlesheets").options(sheet_id="0").load(spreadsheet_id).show() 96 | +-------+---------+---------+-------+ 97 | |country| latitude|longitude| name| 98 | +-------+---------+---------+-------+ 99 | | AD|42.546245| 1.601554|Andorra| 100 | | ...| ...| ...| ...| 101 | +-------+---------+---------+-------+ 102 | 103 | Load data from a public Google Sheets document using `url`. 104 | 105 | >>> url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0#gid=0" 106 | >>> spark.read.format("googlesheets").options(url=url).load().show() 107 | +-------+---------+--------+-------+ 108 | |country| latitude|ongitude| name| 109 | +-------+---------+--------+-------+ 110 | | AD|42.546245|1.601554|Andorra| 111 | | ...| ...| ...| ...| 112 | +-------+---------+--------+-------+ 113 | 114 | Specify custom schema. 115 | 116 | >>> schema = "id string, lat double, long double, name string" 117 | >>> spark.read.format("googlesheets").schema(schema).options(url=url).load().show() 118 | +---+---------+--------+-------+ 119 | | id| lat| long| name| 120 | +---+---------+--------+-------+ 121 | | AD|42.546245|1.601554|Andorra| 122 | |...| ...| ...| ...| 123 | +---+---------+--------+-------+ 124 | 125 | Treat first row as data instead of header. 
126 | 127 | >>> schema = "c1 string, c2 string, c3 string, c4 string" 128 | >>> spark.read.format("googlesheets").schema(schema).options(url=url, has_header="false").load().show() 129 | +-------+---------+---------+-------+ 130 | | c1| c2| c3| c4| 131 | +-------+---------+---------+-------+ 132 | |country| latitude|longitude| name| 133 | | AD|42.546245| 1.601554|Andorra| 134 | | ...| ...| ...| ...| 135 | +-------+---------+---------+-------+ 136 | """ 137 | 138 | @classmethod 139 | def name(self): 140 | return "googlesheets" 141 | 142 | def __init__(self, options: Dict[str, str]): 143 | if "url" in options: 144 | sheet = Sheet.from_url(options.pop("url")) 145 | elif "path" in options: 146 | sheet = Sheet(options.pop("path"), options.pop("sheet_id", None)) 147 | else: 148 | raise ValueError("You must specify either `url` or `path` (spreadsheet ID).") 149 | has_header = options.pop("has_header", "true").lower() == "true" 150 | self.parameters = Parameters(sheet, has_header) 151 | 152 | def schema(self) -> StructType: 153 | if not self.parameters.has_header: 154 | raise ValueError("Custom schema is required when `has_header` is false") 155 | 156 | import pandas as pd 157 | 158 | # Read schema from the first row of the sheet 159 | df = pd.read_csv(self.parameters.sheet.get_query_url("select * limit 1")) 160 | return StructType([StructField(col, StringType()) for col in df.columns]) 161 | 162 | def reader(self, schema: StructType) -> DataSourceReader: 163 | return GoogleSheetsReader(self.parameters, schema) 164 | 165 | 166 | class GoogleSheetsReader(DataSourceReader): 167 | def __init__(self, parameters: Parameters, schema: StructType): 168 | self.parameters = parameters 169 | self.schema = schema 170 | 171 | def read(self, partition): 172 | from urllib.request import urlopen 173 | 174 | from pyarrow import csv 175 | from pyspark.sql.pandas.types import to_arrow_schema 176 | 177 | # Specify column types based on the schema 178 | convert_options = csv.ConvertOptions( 179 | column_types=to_arrow_schema(self.schema), 180 | ) 181 | read_options = csv.ReadOptions( 182 | column_names=self.schema.fieldNames(), # Rename columns 183 | skip_rows=( 184 | 1 if self.parameters.has_header else 0 # Skip header row if present 185 | ), 186 | ) 187 | with urlopen(self.parameters.sheet.get_query_url()) as file: 188 | yield from csv.read_csv( 189 | file, read_options=read_options, convert_options=convert_options 190 | ).to_batches() 191 | -------------------------------------------------------------------------------- /pyspark_datasources/huggingface.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql.datasource import DataSource, DataSourceReader 2 | from pyspark.sql.types import StructType, StructField, StringType 3 | 4 | 5 | class HuggingFaceDatasets(DataSource): 6 | """ 7 | An example data source for reading HuggingFace Datasets in Spark. 8 | 9 | This data source allows reading public datasets from the HuggingFace Hub directly into Spark 10 | DataFrames. The schema is automatically inferred from the dataset features. The split can be 11 | specified using the `split` option. The default split is `train`. 12 | 13 | Name: `huggingface` 14 | 15 | Notes: 16 | ----- 17 | - Please use the official HuggingFace Datasets API: https://github.com/huggingface/pyspark_huggingface. 18 | - The HuggingFace `datasets` library is required to use this data source. Make sure it is installed. 19 | - If the schema is automatically inferred, it will use string type for all fields. 
20 | - Currently it can only be used with public datasets. Private or gated ones are not supported. 21 | 22 | Examples 23 | -------- 24 | Register the data source. 25 | 26 | >>> from pyspark_datasources import HuggingFaceDatasets 27 | >>> spark.dataSource.register(HuggingFaceDatasets) 28 | 29 | Load a public dataset from the HuggingFace Hub. 30 | 31 | >>> spark.read.format("huggingface").load("imdb").show() 32 | +--------------------+-----+ 33 | | text|label| 34 | +--------------------+-----+ 35 | |I rented I AM CUR...| 0| 36 | |"I Am Curious: Ye...| 0| 37 | |... | ...| 38 | +--------------------+-----+ 39 | 40 | Load a specific split from a public dataset from the HuggingFace Hub. 41 | 42 | >>> spark.read.format("huggingface").option("split", "test").load("imdb").show() 43 | +--------------------+-----+ 44 | | text|label| 45 | +--------------------+-----+ 46 | |I love sci-fi and...| 0| 47 | |Worth the enterta...| 0| 48 | |... | ...| 49 | +--------------------+-----+ 50 | """ 51 | 52 | def __init__(self, options): 53 | super().__init__(options) 54 | if "path" not in options or not options["path"]: 55 | raise Exception("You must specify a dataset name in`.load()`.") 56 | 57 | @classmethod 58 | def name(cls): 59 | return "huggingface" 60 | 61 | def schema(self): 62 | # The imports must be inside the method to be serializable. 63 | from datasets import load_dataset_builder 64 | 65 | dataset_name = self.options["path"] 66 | ds_builder = load_dataset_builder(dataset_name) 67 | features = ds_builder.info.features 68 | if features is None: 69 | raise Exception( 70 | "Unable to automatically determine the schema using the dataset features. " 71 | "Please specify the schema manually using `.schema()`." 72 | ) 73 | schema = StructType() 74 | for key, value in features.items(): 75 | # For simplicity, use string for all values. 76 | schema.add(StructField(key, StringType(), True)) 77 | return schema 78 | 79 | def reader(self, schema: StructType) -> "DataSourceReader": 80 | return HuggingFaceDatasetsReader(schema, self.options) 81 | 82 | 83 | class HuggingFaceDatasetsReader(DataSourceReader): 84 | def __init__(self, schema: StructType, options: dict): 85 | self.schema = schema 86 | self.dataset_name = options["path"] 87 | # TODO: validate the split value. 88 | self.split = options.get("split", "train") # Default using train split. 89 | 90 | def read(self, partition): 91 | from datasets import load_dataset 92 | 93 | columns = [field.name for field in self.schema.fields] 94 | iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True) 95 | for example in iter_dataset: 96 | yield tuple([example.get(column) for column in columns]) 97 | -------------------------------------------------------------------------------- /pyspark_datasources/jsonplaceholder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Iterator 2 | import requests 3 | from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition 4 | from pyspark.sql.types import StructType 5 | from pyspark.sql import Row 6 | 7 | 8 | class JSONPlaceholderDataSource(DataSource): 9 | """ 10 | A PySpark data source for JSONPlaceholder API. 11 | 12 | JSONPlaceholder is a free fake REST API for testing and prototyping. 13 | This data source provides access to posts, users, todos, comments, albums, and photos. 
14 | 15 | Supported endpoints: 16 | - posts: Blog posts with userId, id, title, body 17 | - users: User profiles with complete information 18 | - todos: Todo items with userId, id, title, completed 19 | - comments: Comments with postId, id, name, email, body 20 | - albums: Albums with userId, id, title 21 | - photos: Photos with albumId, id, title, url, thumbnailUrl 22 | 23 | Name: `jsonplaceholder` 24 | 25 | Examples 26 | -------- 27 | Register the data source: 28 | 29 | >>> spark.dataSource.register(JSONPlaceholderDataSource) 30 | 31 | Read posts (default): 32 | 33 | >>> spark.read.format("jsonplaceholder").load().show() 34 | 35 | Read users: 36 | 37 | >>> spark.read.format("jsonplaceholder").option("endpoint", "users").load().show() 38 | 39 | Read with limit: 40 | 41 | >>> spark.read.format("jsonplaceholder").option("endpoint", "todos").option("limit", "5").load().show() 42 | 43 | Read specific item: 44 | 45 | >>> spark.read.format("jsonplaceholder").option("endpoint", "posts").option("id", "1").load().show() 46 | 47 | Referential Integrity 48 | ------------------- 49 | The data source supports joining related datasets: 50 | 51 | 1. Posts and Users relationship: 52 | posts.userId = users.id 53 | >>> posts_df = spark.read.format("jsonplaceholder").option("endpoint", "posts").load() 54 | >>> users_df = spark.read.format("jsonplaceholder").option("endpoint", "users").load() 55 | >>> posts_with_authors = posts_df.join(users_df, posts_df.userId == users_df.id) 56 | 57 | 2. Posts and Comments relationship: 58 | comments.postId = posts.id 59 | >>> comments_df = spark.read.format("jsonplaceholder").option("endpoint", "comments").load() 60 | >>> posts_with_comments = posts_df.join(comments_df, posts_df.id == comments_df.postId) 61 | 62 | 3. Users, Albums and Photos relationship: 63 | albums.userId = users.id 64 | photos.albumId = albums.id 65 | >>> albums_df = spark.read.format("jsonplaceholder").option("endpoint", "albums").load() 66 | >>> photos_df = spark.read.format("jsonplaceholder").option("endpoint", "photos").load() 67 | >>> user_albums = users_df.join(albums_df, users_df.id == albums_df.userId) 68 | >>> user_photos = user_albums.join(photos_df, albums_df.id == photos_df.albumId) 69 | """ 70 | 71 | @classmethod 72 | def name(cls) -> str: 73 | return "jsonplaceholder" 74 | 75 | def __init__(self, options=None): 76 | self.options = options or {} 77 | 78 | def schema(self) -> str: 79 | """ Returns the schema for the selected endpoint.""" 80 | schemas = { 81 | "posts": "userId INT, id INT, title STRING, body STRING", 82 | "users": ("id INT, name STRING, username STRING, email STRING, phone STRING, " 83 | "website STRING, address_street STRING, address_suite STRING, " 84 | "address_city STRING, address_zipcode STRING, address_geo_lat STRING, " 85 | "address_geo_lng STRING, company_name STRING, company_catchPhrase STRING, " 86 | "company_bs STRING"), 87 | "todos": "userId INT, id INT, title STRING, completed BOOLEAN", 88 | "comments": "postId INT, id INT, name STRING, email STRING, body STRING", 89 | "albums": "userId INT, id INT, title STRING", 90 | "photos": "albumId INT, id INT, title STRING, url STRING, thumbnailUrl STRING" 91 | } 92 | 93 | endpoint = self.options.get("endpoint", "posts") 94 | return schemas.get(endpoint, schemas["posts"]) 95 | 96 | def reader(self, schema: StructType) -> DataSourceReader: 97 | return JSONPlaceholderReader(self.options) 98 | 99 | 100 | class JSONPlaceholderReader(DataSourceReader): 101 | """Reader implementation for JSONPlaceholder API""" 102 | 103 | def 
__init__(self, options: Dict[str, str]): 104 | self.options = options 105 | self.base_url = "https://jsonplaceholder.typicode.com" 106 | 107 | self.endpoint = self.options.get("endpoint", "posts") 108 | self.limit = self.options.get("limit") 109 | self.id = self.options.get("id") 110 | 111 | def partitions(self) -> List[InputPartition]: 112 | return [InputPartition(0)] 113 | 114 | def read(self, partition: InputPartition) -> Iterator[Row]: 115 | url = f"{self.base_url}/{self.endpoint}" 116 | 117 | if self.id: 118 | url += f"/{self.id}" 119 | 120 | params = {} 121 | if self.limit and not self.id: 122 | params["_limit"] = self.limit 123 | 124 | try: 125 | response = requests.get(url, params=params, timeout=30) 126 | response.raise_for_status() 127 | 128 | data = response.json() 129 | 130 | if isinstance(data, dict): 131 | data = [data] 132 | elif not isinstance(data, list): 133 | data = [] 134 | 135 | return iter([self._process_item(item) for item in data]) 136 | 137 | except requests.RequestException as e: 138 | print(f"Failed to fetch data from {url}: {e}") 139 | return iter([]) 140 | except ValueError as e: 141 | print(f"Failed to parse JSON from {url}: {e}") 142 | return iter([]) 143 | except Exception as e: 144 | print(f"Unexpected error while reading data: {e}") 145 | return iter([]) 146 | 147 | def _process_item(self, item: Dict[str, Any]) -> Row: 148 | """Process individual items based on endpoint type""" 149 | 150 | def _process_posts(item): 151 | return Row( 152 | userId=item.get("userId"), 153 | id=item.get("id"), 154 | title=item.get("title", ""), 155 | body=item.get("body", "") 156 | ) 157 | 158 | def _process_users(item): 159 | address = item.get("address", {}) 160 | geo = address.get("geo", {}) 161 | company = item.get("company", {}) 162 | 163 | return Row( 164 | id=item.get("id"), 165 | name=item.get("name", ""), 166 | username=item.get("username", ""), 167 | email=item.get("email", ""), 168 | phone=item.get("phone", ""), 169 | website=item.get("website", ""), 170 | address_street=address.get("street", ""), 171 | address_suite=address.get("suite", ""), 172 | address_city=address.get("city", ""), 173 | address_zipcode=address.get("zipcode", ""), 174 | address_geo_lat=geo.get("lat", ""), 175 | address_geo_lng=geo.get("lng", ""), 176 | company_name=company.get("name", ""), 177 | company_catchPhrase=company.get("catchPhrase", ""), 178 | company_bs=company.get("bs", "") 179 | ) 180 | 181 | def _process_todos(item): 182 | return Row( 183 | userId=item.get("userId"), 184 | id=item.get("id"), 185 | title=item.get("title", ""), 186 | completed=item.get("completed", False) 187 | ) 188 | 189 | def _process_comments(item): 190 | return Row( 191 | postId=item.get("postId"), 192 | id=item.get("id"), 193 | name=item.get("name", ""), 194 | email=item.get("email", ""), 195 | body=item.get("body", "") 196 | ) 197 | 198 | def _process_albums(item): 199 | return Row( 200 | userId=item.get("userId"), 201 | id=item.get("id"), 202 | title=item.get("title", "") 203 | ) 204 | 205 | def _process_photos(item): 206 | return Row( 207 | albumId=item.get("albumId"), 208 | id=item.get("id"), 209 | title=item.get("title", ""), 210 | url=item.get("url", ""), 211 | thumbnailUrl=item.get("thumbnailUrl", "") 212 | ) 213 | 214 | processors = { 215 | "posts": _process_posts, 216 | "users": _process_users, 217 | "todos": _process_todos, 218 | "comments": _process_comments, 219 | "albums": _process_albums, 220 | "photos": _process_photos 221 | } 222 | 223 | processor = processors.get(self.endpoint, _process_posts) 
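        # Endpoints without a dedicated processor fall back to the posts processor,
        # mirroring the default schema returned by JSONPlaceholderDataSource.schema().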
224 | return processor(item) -------------------------------------------------------------------------------- /pyspark_datasources/kaggle.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from functools import cached_property 3 | from typing import TYPE_CHECKING, Iterator 4 | 5 | from pyspark.sql.datasource import DataSource, DataSourceReader 6 | from pyspark.sql.pandas.types import from_arrow_schema 7 | from pyspark.sql.types import StructType 8 | 9 | if TYPE_CHECKING: 10 | import pyarrow as pa 11 | 12 | 13 | class KaggleDataSource(DataSource): 14 | """ 15 | A DataSource for reading Kaggle datasets in Spark. 16 | 17 | This data source allows reading datasets from Kaggle directly into Spark DataFrames. 18 | 19 | Name: `kaggle` 20 | 21 | Options 22 | ------- 23 | - `handle`: The dataset handle on Kaggle, in the form of `{owner_slug}/{dataset_slug}` 24 | or `{owner_slug}/{dataset_slug}/versions/{version_number}` 25 | - `path`: The path to a file within the dataset. 26 | - `username`: The Kaggle username for authentication. 27 | - `key`: The Kaggle API key for authentication. 28 | 29 | Notes: 30 | ----- 31 | - The `kagglehub` library is required to use this data source. Make sure it is installed. 32 | - To read private datasets or datasets that require user authentication, `username` and `key` must be provided. 33 | - Currently all data is read from a single partition. 34 | 35 | Examples 36 | -------- 37 | Register the data source. 38 | 39 | >>> from pyspark_datasources import KaggleDataSource 40 | >>> spark.dataSource.register(KaggleDataSource) 41 | 42 | Load a public dataset from Kaggle. 43 | 44 | >>> spark.read.format("kaggle").options(handle="yasserh/titanic-dataset").load("Titanic-Dataset.csv").select("Name").show() 45 | +--------------------+ 46 | | Name| 47 | +--------------------+ 48 | |Braund, Mr. Owen ...| 49 | |Cumings, Mrs. Joh...| 50 | |... | 51 | +--------------------+ 52 | 53 | Load a private dataset with authentication. 54 | 55 | >>> spark.read.format("kaggle").options( 56 | ... username="myaccount", 57 | ... key="", 58 | ... handle="myaccount/my-private-dataset", 59 | ... 
).load("file.csv").show() 60 | """ 61 | 62 | @classmethod 63 | def name(cls) -> str: 64 | return "kaggle" 65 | 66 | @cached_property 67 | def _data(self) -> "pa.Table": 68 | import ast 69 | import os 70 | 71 | import pyarrow as pa 72 | 73 | handle = self.options.pop("handle") 74 | path = self.options.pop("path") 75 | username = self.options.pop("username", None) 76 | key = self.options.pop("key", None) 77 | if username or key: 78 | if not (username and key): 79 | raise ValueError("Both username and key must be provided to authenticate.") 80 | os.environ["KAGGLE_USERNAME"] = username 81 | os.environ["KAGGLE_KEY"] = key 82 | 83 | kwargs = {k: ast.literal_eval(v) for k, v in self.options.items()} 84 | 85 | # Cache in a temporary directory to avoid writing to ~ which may be read-only 86 | with tempfile.TemporaryDirectory() as tmpdir: 87 | os.environ["KAGGLEHUB_CACHE"] = tmpdir 88 | import kagglehub 89 | 90 | df = kagglehub.dataset_load( 91 | kagglehub.KaggleDatasetAdapter.PANDAS, 92 | handle, 93 | path, 94 | **kwargs, 95 | ) 96 | return pa.Table.from_pandas(df) 97 | 98 | def schema(self) -> StructType: 99 | return from_arrow_schema(self._data.schema) 100 | 101 | def reader(self, schema: StructType) -> "KaggleDataReader": 102 | return KaggleDataReader(self) 103 | 104 | 105 | class KaggleDataReader(DataSourceReader): 106 | def __init__(self, source: KaggleDataSource): 107 | self.source = source 108 | 109 | def read(self, partition) -> Iterator["pa.RecordBatch"]: 110 | yield from self.source._data.to_batches() 111 | -------------------------------------------------------------------------------- /pyspark_datasources/lance.py: -------------------------------------------------------------------------------- 1 | import lance 2 | import pyarrow as pa 3 | 4 | from dataclasses import dataclass 5 | from typing import Iterator, List 6 | from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage 7 | from pyspark.sql.pandas.types import to_arrow_schema 8 | 9 | 10 | class LanceSink(DataSource): 11 | """ 12 | Write a Spark DataFrame into Lance format: https://lancedb.github.io/lance/index.html 13 | 14 | Note this requires Spark master branch nightly build to support `DataSourceArrowWriter`. 15 | 16 | Examples 17 | -------- 18 | Register the data source: 19 | 20 | >>> from pyspark_datasources import LanceSink 21 | >>> spark.dataSource.register(LanceSink) 22 | 23 | Create a Spark dataframe with 2 partitions: 24 | 25 | >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], schema="id int, value string") 26 | 27 | Save the dataframe in lance format: 28 | 29 | >>> df.write.format("lance").mode("append").save("/tmp/test_lance") 30 | /tmp/test_lance 31 | _transactions _versions data 32 | 33 | Then you can use lance API to read the dataset: 34 | 35 | >>> import lance 36 | >>> ds = lance.LanceDataset("/tmp/test_lance") 37 | >>> ds.to_table().to_pandas() 38 | id 39 | 0 0 40 | 1 1 41 | 2 2 42 | 43 | Notes 44 | ----- 45 | - Currently this only works with Spark local mode. Cluster mode is not supported. 
46 | """ 47 | 48 | @classmethod 49 | def name(cls) -> str: 50 | return "lance" 51 | 52 | def writer(self, schema, overwrite: bool): 53 | if overwrite: 54 | raise Exception("Overwrite mode is not supported") 55 | if "path" not in self.options: 56 | raise Exception("Dataset URI must be specified when calling save()") 57 | return LanceWriter(schema, overwrite, self.options) 58 | 59 | 60 | @dataclass 61 | class LanceCommitMessage(WriterCommitMessage): 62 | fragments: List[lance.FragmentMetadata] 63 | 64 | 65 | class LanceWriter(DataSourceArrowWriter): 66 | def __init__(self, schema, overwrite, options): 67 | self.options = options 68 | self.schema = schema # Spark Schema (pyspark.sql.types.StructType) 69 | self.arrow_schema = to_arrow_schema(schema) # Arrow schema (pa.StructType) 70 | self.uri = options["path"] 71 | assert not overwrite 72 | self.read_version = self._get_read_version() 73 | 74 | def _get_read_version(self): 75 | try: 76 | ds = lance.LanceDataset(self.uri) 77 | return ds.version 78 | except Exception: 79 | return None 80 | 81 | def write(self, iterator: Iterator[pa.RecordBatch]): 82 | reader = pa.RecordBatchReader.from_batches(self.arrow_schema, iterator) 83 | fragments = lance.fragment.write_fragments(reader, self.uri, schema=self.arrow_schema) 84 | return LanceCommitMessage(fragments=fragments) 85 | 86 | def commit(self, messages): 87 | fragments = [fragment for msg in messages for fragment in msg.fragments] 88 | if self.read_version: 89 | # This means the dataset already exists. 90 | op = lance.LanceOperation.Append(fragments) 91 | else: 92 | # Create a new dataset. 93 | schema = to_arrow_schema(self.schema) 94 | op = lance.LanceOperation.Overwrite(schema, fragments) 95 | lance.LanceDataset.commit(self.uri, op, read_version=self.read_version) 96 | -------------------------------------------------------------------------------- /pyspark_datasources/robinhood.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Optional, Generator, Union 3 | import requests 4 | import json 5 | import base64 6 | import datetime 7 | 8 | from pyspark.sql import Row 9 | from pyspark.sql.types import StructType 10 | from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition 11 | 12 | 13 | @dataclass 14 | class CryptoPair(InputPartition): 15 | """Represents a single crypto trading pair partition for parallel processing.""" 16 | 17 | symbol: str 18 | 19 | 20 | class RobinhoodDataReader(DataSourceReader): 21 | """Reader implementation for Robinhood Crypto API data source.""" 22 | 23 | def __init__(self, schema: StructType, options: Dict[str, str]) -> None: 24 | self.schema = schema 25 | self.options = options 26 | 27 | # Required API authentication 28 | self.api_key = options.get("api_key") 29 | self.private_key_base64 = options.get("private_key") 30 | 31 | if not self.api_key or not self.private_key_base64: 32 | raise ValueError( 33 | "Robinhood Crypto API requires both 'api_key' and 'private_key' options. " 34 | "The private_key should be base64-encoded. " 35 | "Get your API credentials from https://docs.robinhood.com/crypto/trading/" 36 | ) 37 | 38 | # Initialize NaCl signing key 39 | try: 40 | from nacl.signing import SigningKey 41 | 42 | private_key_seed = base64.b64decode(self.private_key_base64) 43 | self.signing_key = SigningKey(private_key_seed) 44 | except ImportError: 45 | raise ImportError( 46 | "PyNaCl library is required for Robinhood Crypto API authentication. 
" 47 | "Install it with: pip install pynacl" 48 | ) 49 | except Exception as e: 50 | raise ValueError(f"Invalid private key format: {str(e)}") 51 | 52 | # Crypto API base URL (configurable for testing) 53 | self.base_url = options.get("base_url", "https://trading.robinhood.com") 54 | 55 | def _get_current_timestamp(self) -> int: 56 | """Get current UTC timestamp.""" 57 | return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp()) 58 | 59 | def _generate_signature(self, timestamp: int, method: str, path: str, body: str = "") -> str: 60 | """Generate NaCl signature for API authentication following Robinhood's specification.""" 61 | # Official Robinhood signature format: f"{api_key}{current_timestamp}{path}{method}{body}" 62 | # For GET requests with no body, omit the body parameter 63 | if method.upper() == "GET" and not body: 64 | message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}" 65 | else: 66 | message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}{body}" 67 | 68 | signed = self.signing_key.sign(message_to_sign.encode("utf-8")) 69 | signature = base64.b64encode(signed.signature).decode("utf-8") 70 | return signature 71 | 72 | def _make_authenticated_request( 73 | self, 74 | method: str, 75 | path: str, 76 | params: Optional[Dict[str, str]] = None, 77 | json_data: Optional[Dict] = None, 78 | ) -> Optional[Dict]: 79 | """Make an authenticated request to the Robinhood Crypto API.""" 80 | timestamp = self._get_current_timestamp() 81 | url = self.base_url + path 82 | 83 | # Prepare request body for signature (only for non-GET requests) 84 | body = "" 85 | if method.upper() != "GET" and json_data: 86 | body = json.dumps(json_data, separators=(",", ":")) # Compact JSON format 87 | 88 | # Generate signature 89 | signature = self._generate_signature(timestamp, method, path, body) 90 | 91 | # Set authentication headers 92 | headers = { 93 | "x-api-key": self.api_key, 94 | "x-signature": signature, 95 | "x-timestamp": str(timestamp), 96 | } 97 | 98 | try: 99 | # Make request 100 | if method.upper() == "GET": 101 | response = requests.get(url, headers=headers, params=params, timeout=10) 102 | elif method.upper() == "POST": 103 | headers["Content-Type"] = "application/json" 104 | response = requests.post(url, headers=headers, json=json_data, timeout=10) 105 | else: 106 | response = requests.request( 107 | method, url, headers=headers, params=params, json=json_data, timeout=10 108 | ) 109 | 110 | response.raise_for_status() 111 | return response.json() 112 | except requests.RequestException as e: 113 | print(f"Error making API request to {path}: {e}") 114 | return None 115 | 116 | @staticmethod 117 | def _get_query_params(key: str, *args: str) -> str: 118 | """Build query parameters for API requests.""" 119 | if not args: 120 | return "" 121 | params = [f"{key}={arg}" for arg in args if arg] 122 | return "?" 
+ "&".join(params) 123 | 124 | def partitions(self) -> List[CryptoPair]: 125 | """Create partitions for parallel processing of crypto pairs.""" 126 | # Use specified symbols from path 127 | symbols_str = self.options.get("path", "") 128 | if not symbols_str: 129 | raise ValueError("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')") 130 | 131 | # Split symbols by comma and create partitions 132 | symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")] 133 | # Ensure proper format (e.g., BTC-USD) 134 | formatted_symbols = [] 135 | for symbol in symbols: 136 | if symbol and "-" not in symbol: 137 | symbol = f"{symbol}-USD" # Default to USD pair 138 | if symbol: 139 | formatted_symbols.append(symbol) 140 | 141 | return [CryptoPair(symbol=symbol) for symbol in formatted_symbols] 142 | 143 | def read(self, partition: CryptoPair) -> Generator[Row, None, None]: 144 | """Read crypto data for a single trading pair partition.""" 145 | symbol = partition.symbol 146 | 147 | try: 148 | yield from self._read_crypto_pair_data(symbol) 149 | except Exception as e: 150 | # Log error but don't fail the entire job 151 | print(f"Warning: Failed to fetch data for {symbol}: {str(e)}") 152 | 153 | def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]: 154 | """Fetch cryptocurrency market data for a given trading pair.""" 155 | try: 156 | # Get best bid/ask data for the trading pair using query parameters 157 | path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}" 158 | market_data = self._make_authenticated_request("GET", path) 159 | 160 | if market_data and "results" in market_data: 161 | for quote in market_data["results"]: 162 | # Parse numeric values safely 163 | def safe_float( 164 | value: Union[str, int, float, None], default: float = 0.0 165 | ) -> float: 166 | if value is None or value == "": 167 | return default 168 | try: 169 | return float(value) 170 | except (ValueError, TypeError): 171 | return default 172 | 173 | # Extract market data fields from best bid/ask response 174 | # Use the correct field names from the API response 175 | price = safe_float(quote.get("price")) 176 | bid_price = safe_float(quote.get("bid_inclusive_of_sell_spread")) 177 | ask_price = safe_float(quote.get("ask_inclusive_of_buy_spread")) 178 | 179 | yield Row( 180 | symbol=symbol, 181 | price=price, 182 | bid_price=bid_price, 183 | ask_price=ask_price, 184 | updated_at=quote.get("timestamp", ""), 185 | ) 186 | else: 187 | print(f"Warning: No market data found for {symbol}") 188 | 189 | except requests.exceptions.RequestException as e: 190 | print(f"Network error fetching data for {symbol}: {str(e)}") 191 | except (ValueError, KeyError) as e: 192 | print(f"Data parsing error for {symbol}: {str(e)}") 193 | except Exception as e: 194 | print(f"Unexpected error fetching data for {symbol}: {str(e)}") 195 | 196 | 197 | class RobinhoodDataSource(DataSource): 198 | """ 199 | A data source for reading cryptocurrency data from Robinhood Crypto API. 200 | 201 | This data source allows you to fetch real-time cryptocurrency market data, 202 | trading pairs, and price information using Robinhood's official Crypto API. 203 | It implements proper API key authentication and signature-based security. 
204 | 205 | Name: `robinhood` 206 | 207 | Schema: `symbol string, price double, bid_price double, ask_price double, updated_at string` 208 | 209 | Examples 210 | -------- 211 | Register the data source: 212 | 213 | >>> from pyspark_datasources import RobinhoodDataSource 214 | >>> spark.dataSource.register(RobinhoodDataSource) 215 | 216 | Load cryptocurrency market data with API authentication: 217 | 218 | >>> df = spark.read.format("robinhood") \\ 219 | ... .option("api_key", "your-api-key") \\ 220 | ... .option("private_key", "your-base64-private-key") \\ 221 | ... .load("BTC-USD,ETH-USD,DOGE-USD") 222 | >>> df.show() 223 | +--------+--------+---------+---------+--------------------+ 224 | | symbol| price|bid_price|ask_price| updated_at| 225 | +--------+--------+---------+---------+--------------------+ 226 | |BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...| 227 | |ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...| 228 | |DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...| 229 | +--------+--------+---------+---------+--------------------+ 230 | 231 | 232 | 233 | Options 234 | ------- 235 | - api_key: string (required) — Robinhood Crypto API key. 236 | - private_key: string (required) — Base64-encoded Ed25519 private key seed. 237 | - base_url: string (optional, default "https://trading.robinhood.com") — Override for sandbox/testing. 238 | 239 | Errors 240 | ------ 241 | - Raises ValueError when required options are missing or private_key is invalid. 242 | - Network/API errors are logged and skipped per symbol; no rows are emitted for failed symbols. 243 | 244 | Partitioning 245 | ------------ 246 | - One partition per requested trading pair (e.g., "BTC-USD,ETH-USD"). Symbols are uppercased and auto-appended with "-USD" if missing pair format. 247 | 248 | Arrow 249 | ----- 250 | - Rows are yielded directly; Arrow-based batches can be added in future for improved performance. 251 | 252 | Notes 253 | ----- 254 | - Requires 'pynacl' for Ed25519 signing: pip install pynacl 255 | - Refer to official Robinhood documentation for authentication details. 256 | """ 257 | 258 | @classmethod 259 | def name(cls) -> str: 260 | return "robinhood" 261 | 262 | def schema(self) -> str: 263 | return "symbol string, price double, bid_price double, ask_price double, updated_at string" 264 | 265 | def reader(self, schema: StructType) -> RobinhoodDataReader: 266 | return RobinhoodDataReader(schema, self.options) 267 | -------------------------------------------------------------------------------- /pyspark_datasources/salesforce.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Dict, List, Any 4 | 5 | from pyspark.sql.types import StructType 6 | from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, WriterCommitMessage 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @dataclass 12 | class SalesforceCommitMessage(WriterCommitMessage): 13 | """Commit message for Salesforce write operations.""" 14 | 15 | records_written: int 16 | batch_id: int 17 | 18 | 19 | class SalesforceDataSource(DataSource): 20 | """ 21 | A Salesforce streaming datasource for PySpark to write data to Salesforce objects. 22 | 23 | This datasource enables writing streaming data from Spark to Salesforce using the 24 | Salesforce REST API. It supports common Salesforce objects like Account, Contact, 25 | Opportunity, and custom objects. 
26 | 27 | Note: This is a write-only datasource, not a full bidirectional data source. 28 | 29 | Name: `salesforce` 30 | 31 | Notes 32 | ----- 33 | - Requires the `simple-salesforce` library for Salesforce API integration 34 | - **Write-only datasource**: Only supports streaming write operations (no read operations) 35 | - Uses Salesforce username/password/security token authentication 36 | - Supports batch writing with Salesforce Composite Tree API for efficient processing 37 | - Implements exactly-once semantics through Spark's checkpoint mechanism 38 | - If a streaming write job fails and is resumed from the checkpoint, 39 | it will not overwrite records already written in Salesforce; 40 | it resumes from the last committed offset. 41 | However, if records were written to Salesforce but not yet committed at the time of failure, 42 | duplicate records may occur after recovery. 43 | 44 | Parameters 45 | ---------- 46 | username : str 47 | Salesforce username (email address) 48 | password : str 49 | Salesforce password 50 | security_token : str 51 | Salesforce security token (obtained from Salesforce setup) 52 | salesforce_object : str, optional 53 | Target Salesforce object name (default: "Account") 54 | batch_size : str, optional 55 | Number of records to process per batch (default: "200") 56 | instance_url : str, optional 57 | Custom Salesforce instance URL (auto-detected if not provided) 58 | schema : str, optional 59 | Custom schema definition for the Salesforce object. If not provided, 60 | uses the default Account schema. Should be in Spark SQL DDL format. 61 | Example: "Name STRING NOT NULL, Industry STRING, AnnualRevenue DOUBLE" 62 | 63 | Examples 64 | -------- 65 | Register the Salesforce Datasource: 66 | 67 | >>> from pyspark_datasources import SalesforceDataSource 68 | >>> spark.dataSource.register(SalesforceDataSource) 69 | 70 | Write streaming data to Salesforce Accounts: 71 | 72 | >>> from pyspark.sql import SparkSession 73 | >>> from pyspark.sql.functions import col, lit 74 | >>> 75 | >>> spark = SparkSession.builder.appName("SalesforceExample").getOrCreate() 76 | >>> spark.dataSource.register(SalesforceDataSource) 77 | >>> 78 | >>> # Create sample streaming data 79 | >>> streaming_df = spark.readStream.format("rate").load() 80 | >>> account_data = streaming_df.select( 81 | ... col("value").cast("string").alias("Name"), 82 | ... lit("Technology").alias("Industry"), 83 | ... (col("value") * 100000).cast("double").alias("AnnualRevenue") 84 | ... ) 85 | >>> 86 | >>> # Write to Salesforce using the datasource 87 | >>> query = account_data.writeStream \\ 88 | ... .format("pyspark.datasource.salesforce") \\ 89 | ... .option("username", "your-username@company.com") \\ 90 | ... .option("password", "your-password") \\ 91 | ... .option("security_token", "your-security-token") \\ 92 | ... .option("salesforce_object", "Account") \\ 93 | ... .option("batch_size", "100") \\ 94 | ... .option("checkpointLocation", "/path/to/checkpoint") \\ 95 | ... .start() 96 | 97 | Write to Salesforce Contacts: 98 | 99 | >>> contact_data = streaming_df.select( 100 | ... col("value").cast("string").alias("FirstName"), 101 | ... lit("Doe").alias("LastName"), 102 | ... lit("contact@example.com").alias("Email") 103 | ... ) 104 | >>> 105 | >>> query = contact_data.writeStream \\ 106 | ... .format("pyspark.datasource.salesforce") \\ 107 | ... .option("username", "your-username@company.com") \\ 108 | ... .option("password", "your-password") \\ 109 | ... 
.option("security_token", "your-security-token") \\ 110 | ... .option("salesforce_object", "Contact") \\ 111 | ... .option("checkpointLocation", "/path/to/checkpoint") \\ 112 | ... .start() 113 | 114 | Write to custom Salesforce objects: 115 | 116 | >>> custom_data = streaming_df.select( 117 | ... col("value").cast("string").alias("Custom_Field__c"), 118 | ... lit("Custom Value").alias("Another_Field__c") 119 | ... ) 120 | >>> 121 | >>> query = custom_data.writeStream \\ 122 | ... .format("pyspark.datasource.salesforce") \\ 123 | ... .option("username", "your-username@company.com") \\ 124 | ... .option("password", "your-password") \\ 125 | ... .option("security_token", "your-security-token") \\ 126 | ... .option("salesforce_object", "Custom_Object__c") \\ 127 | ... .option("checkpointLocation", "/path/to/checkpoint") \\ 128 | ... .start() 129 | 130 | Using custom schema for specific Salesforce objects: 131 | 132 | >>> # Define schema for Contact object as a DDL string 133 | >>> contact_schema = "FirstName STRING NOT NULL, LastName STRING NOT NULL, Email STRING, Phone STRING" 134 | >>> 135 | >>> query = contact_data.writeStream \\ 136 | ... .format("pyspark.datasource.salesforce") \\ 137 | ... .option("username", "your-username@company.com") \\ 138 | ... .option("password", "your-password") \\ 139 | ... .option("security_token", "your-security-token") \\ 140 | ... .option("salesforce_object", "Contact") \\ 141 | ... .option("schema", "FirstName STRING NOT NULL, LastName STRING NOT NULL, Email STRING, Phone STRING") \\ 142 | ... .option("batch_size", "50") \\ 143 | ... .option("checkpointLocation", "/path/to/checkpoint") \\ 144 | ... .start() 145 | 146 | Using schema with Opportunity object: 147 | 148 | >>> opportunity_data = streaming_df.select( 149 | ... col("name").alias("Name"), 150 | ... col("amount").alias("Amount"), 151 | ... col("stage").alias("StageName"), 152 | ... col("close_date").alias("CloseDate") 153 | ... ) 154 | >>> 155 | >>> query = opportunity_data.writeStream \\ 156 | ... .format("pyspark.datasource.salesforce") \\ 157 | ... .option("username", "your-username@company.com") \\ 158 | ... .option("password", "your-password") \\ 159 | ... .option("security_token", "your-security-token") \\ 160 | ... .option("salesforce_object", "Opportunity") \\ 161 | ... .option("schema", "Name STRING NOT NULL, Amount DOUBLE, StageName STRING NOT NULL, CloseDate DATE") \\ 162 | ... .option("checkpointLocation", "/path/to/checkpoint") \\ 163 | ... .start() 164 | 165 | Key Features: 166 | 167 | - **Write-only datasource**: Designed specifically for writing data to Salesforce 168 | - **Batch processing**: Uses Salesforce Composite Tree API for efficient bulk writes 169 | - **Exactly-once semantics**: Integrates with Spark's checkpoint mechanism 170 | - **Error handling**: Graceful fallback to individual record creation if batch fails 171 | - **Flexible schema**: Supports any Salesforce object with custom schema definition 172 | """ 173 | 174 | @classmethod 175 | def name(cls) -> str: 176 | """Return the short name for this Salesforce datasource.""" 177 | return "pyspark.datasource.salesforce" 178 | 179 | def schema(self) -> str: 180 | """ 181 | Return the schema for Salesforce objects. 182 | 183 | If the user provides a 'schema' option, use it. 184 | Otherwise, return the default Account schema. 
185 | """ 186 | user_schema = self.options.get("schema") 187 | if user_schema: 188 | return user_schema 189 | return """ 190 | Name STRING NOT NULL, 191 | Industry STRING, 192 | Phone STRING, 193 | Website STRING, 194 | AnnualRevenue DOUBLE, 195 | NumberOfEmployees INT, 196 | BillingStreet STRING, 197 | BillingCity STRING, 198 | BillingState STRING, 199 | BillingPostalCode STRING, 200 | BillingCountry STRING 201 | """ 202 | 203 | def streamWriter(self, schema: StructType, overwrite: bool) -> "SalesforceStreamWriter": 204 | """Create a stream writer for Salesforce datasource integration.""" 205 | return SalesforceStreamWriter(schema, self.options) 206 | 207 | 208 | class SalesforceStreamWriter(DataSourceStreamWriter): 209 | """Stream writer implementation for Salesforce datasource integration.""" 210 | 211 | def __init__(self, schema: StructType, options: Dict[str, str]): 212 | self.schema = schema 213 | self.options = options 214 | 215 | # Extract Salesforce configuration 216 | self.username = options.get("username") 217 | self.password = options.get("password") 218 | self.security_token = options.get("security_token") 219 | self.instance_url = options.get("instance_url") 220 | self.salesforce_object = options.get("salesforce_object", "Account") 221 | self.batch_size = int(options.get("batch_size", "200")) 222 | 223 | # Validate required options 224 | if not all([self.username, self.password, self.security_token]): 225 | raise ValueError( 226 | "Salesforce username, password, and security_token are required. " 227 | "Set them using .option() method in your streaming query." 228 | ) 229 | 230 | logger.info(f"Initializing Salesforce writer for object '{self.salesforce_object}'") 231 | 232 | def write(self, iterator) -> SalesforceCommitMessage: 233 | """Write data to Salesforce.""" 234 | # Import here to avoid serialization issues 235 | try: 236 | from simple_salesforce import Salesforce 237 | except ImportError: 238 | raise ImportError( 239 | "simple-salesforce library is required for Salesforce integration. 
" 240 | "Install it with: pip install simple-salesforce" 241 | ) 242 | 243 | from pyspark import TaskContext 244 | 245 | # Get task context for batch identification 246 | context = TaskContext.get() 247 | batch_id = context.taskAttemptId() 248 | 249 | # Connect to Salesforce 250 | try: 251 | sf_kwargs = { 252 | "username": self.username, 253 | "password": self.password, 254 | "security_token": self.security_token, 255 | } 256 | if self.instance_url: 257 | sf_kwargs["instance_url"] = self.instance_url 258 | 259 | sf = Salesforce(**sf_kwargs) 260 | logger.info(f"✓ Connected to Salesforce (batch {batch_id})") 261 | except Exception as e: 262 | logger.error(f"Failed to connect to Salesforce: {str(e)}") 263 | raise ConnectionError(f"Salesforce connection failed: {str(e)}") 264 | 265 | # Convert rows to Salesforce records and write in batches to avoid memory issues 266 | records_buffer = [] 267 | total_records_written = 0 268 | 269 | def flush_buffer(): 270 | nonlocal total_records_written 271 | if records_buffer: 272 | try: 273 | written = self._write_to_salesforce(sf, records_buffer, batch_id) 274 | logger.info( 275 | f"✅ Batch {batch_id}: Successfully wrote {written} records (buffer flush)" 276 | ) 277 | total_records_written += written 278 | except Exception as e: 279 | logger.error( 280 | f"❌ Batch {batch_id}: Failed to write records during buffer flush: {str(e)}" 281 | ) 282 | raise 283 | records_buffer.clear() 284 | 285 | for row in iterator: 286 | try: 287 | record = self._convert_row_to_salesforce_record(row) 288 | if record: # Only add non-empty records 289 | records_buffer.append(record) 290 | if len(records_buffer) >= self.batch_size: 291 | flush_buffer() 292 | except Exception as e: 293 | logger.warning(f"Failed to convert row to Salesforce record: {str(e)}") 294 | 295 | # Flush any remaining records in the buffer 296 | if records_buffer: 297 | flush_buffer() 298 | 299 | if total_records_written == 0: 300 | logger.info(f"No valid records to write in batch {batch_id}") 301 | else: 302 | logger.info( 303 | f"✅ Batch {batch_id}: Successfully wrote {total_records_written} records (total)" 304 | ) 305 | 306 | return SalesforceCommitMessage(records_written=total_records_written, batch_id=batch_id) 307 | 308 | def _convert_row_to_salesforce_record(self, row) -> Dict[str, Any]: 309 | """Convert a Spark Row to a Salesforce record format.""" 310 | record = {} 311 | 312 | for field in self.schema.fields: 313 | field_name = field.name 314 | try: 315 | # Use getattr for safe field access 316 | value = getattr(row, field_name, None) 317 | 318 | if value is not None: 319 | # Convert value based on field type 320 | if hasattr(value, "isoformat"): # datetime objects 321 | record[field_name] = value.isoformat() 322 | elif isinstance(value, (int, float)): 323 | record[field_name] = value 324 | else: 325 | record[field_name] = str(value) 326 | 327 | except Exception as e: 328 | logger.warning(f"Failed to convert field '{field_name}': {str(e)}") 329 | 330 | return record 331 | 332 | def _write_to_salesforce(self, sf, records: List[Dict[str, Any]], batch_id: int) -> int: 333 | """Write records to Salesforce using REST API.""" 334 | success_count = 0 335 | 336 | # Process records in batches using sObject Collections API 337 | for i in range(0, len(records), self.batch_size): 338 | batch_records = records[i : i + self.batch_size] 339 | 340 | try: 341 | # Use Composite Tree API for batch creation (up to 200 records) 342 | # Prepare records for batch API 343 | collection_records = [] 344 | for idx, record in 
enumerate(batch_records): 345 | # Add required attributes for Composite Tree API 346 | record_with_attributes = { 347 | "attributes": { 348 | "type": self.salesforce_object, 349 | "referenceId": f"ref{i + idx}", 350 | }, 351 | **record, 352 | } 353 | collection_records.append(record_with_attributes) 354 | 355 | # Make batch API call using Composite Tree API 356 | # This API is specifically designed for batch inserts 357 | payload = {"records": collection_records} 358 | 359 | response = sf.restful( 360 | f"composite/tree/{self.salesforce_object}", method="POST", json=payload 361 | ) 362 | 363 | # Count successful records 364 | # Composite Tree API returns a different response format 365 | if isinstance(response, dict): 366 | # Check if the batch was successful 367 | if response.get("hasErrors", True) is False: 368 | # All records in the batch were created successfully 369 | success_count += len(batch_records) 370 | else: 371 | # Some records failed, check individual results 372 | results = response.get("results", []) 373 | for result in results: 374 | if "id" in result: 375 | success_count += 1 376 | else: 377 | errors = result.get("errors", []) 378 | for error in errors: 379 | logger.warning( 380 | f"Failed to create record {result.get('referenceId', 'unknown')}: {error.get('message', 'Unknown error')}" 381 | ) 382 | else: 383 | logger.error(f"Unexpected response format: {response}") 384 | 385 | except Exception as e: 386 | logger.error( 387 | f"Error in batch creation for batch {i//self.batch_size + 1}: {str(e)}" 388 | ) 389 | # Fallback to individual record creation for this batch 390 | try: 391 | sf_object = getattr(sf, self.salesforce_object) 392 | for j, record in enumerate(batch_records): 393 | try: 394 | # Create the record in Salesforce 395 | result = sf_object.create(record) 396 | 397 | if result.get("success"): 398 | success_count += 1 399 | else: 400 | logger.warning( 401 | f"Failed to create record {i+j}: {result.get('errors', 'Unknown error')}" 402 | ) 403 | 404 | except Exception as e: 405 | logger.error(f"Error creating record {i+j}: {str(e)}") 406 | except AttributeError: 407 | raise ValueError(f"Salesforce object '{self.salesforce_object}' not found") 408 | 409 | # Log progress for large batches 410 | if len(records) > 50 and (i + self.batch_size) % 100 == 0: 411 | logger.info( 412 | f"Batch {batch_id}: Processed {i + self.batch_size}/{len(records)} records" 413 | ) 414 | 415 | return success_count 416 | -------------------------------------------------------------------------------- /pyspark_datasources/simplejson.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import time 4 | 5 | from dataclasses import dataclass 6 | from typing import Dict, List 7 | 8 | from pyspark.sql.types import StructType 9 | from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage 10 | 11 | 12 | class SimpleJsonDataSource(DataSource): 13 | """ 14 | A simple json writer for writing data to Databricks DBFS. 15 | 16 | Examples 17 | -------- 18 | 19 | >>> import pyspark.sql.functions as sf 20 | >>> df = spark.range(0, 10, 1, 2).withColumn("value", sf.expr("concat('value_', id)")) 21 | 22 | Register the data source. 23 | 24 | >>> from pyspark_datasources import SimpleJsonDataSource 25 | >>> spark.dataSource.register(SimpleJsonDataSource) 26 | 27 | Append the DataFrame to a DBFS path as json files. 28 | 29 | >>> ( 30 | ... df.write.format("simplejson") 31 | ... .mode("append") 32 | ... 
.option("databricks_url", "https://your-databricks-instance.cloud.databricks.com") 33 | ... .option("databricks_token", "your-token") 34 | ... .save("/path/to/output") 35 | ... ) 36 | 37 | Overwrite the DataFrame to a DBFS path as json files. 38 | 39 | >>> ( 40 | ... df.write.format("simplejson") 41 | ... .mode("overwrite") 42 | ... .option("databricks_url", "https://your-databricks-instance.cloud.databricks.com") 43 | ... .option("databricks_token", "your-token") 44 | ... .save("/path/to/output") 45 | ... ) 46 | """ 47 | 48 | @classmethod 49 | def name(self) -> str: 50 | return "simplejson" 51 | 52 | def writer(self, schema: StructType, overwrite: bool): 53 | return SimpleJsonWriter(schema, self.options, overwrite) 54 | 55 | 56 | @dataclass 57 | class CommitMessage(WriterCommitMessage): 58 | output_path: str 59 | 60 | 61 | class SimpleJsonWriter(DataSourceWriter): 62 | def __init__(self, schema: StructType, options: Dict, overwrite: bool): 63 | self.overwrite = overwrite 64 | self.databricks_url = options.get("databricks_url") 65 | self.databricks_token = options.get("databricks_token") 66 | if not self.databricks_url or not self.databricks_token: 67 | raise Exception("Databricks URL and token must be specified") 68 | self.path = options.get("path") 69 | if not self.path: 70 | raise Exception("You must specify an output path") 71 | 72 | def write(self, iterator): 73 | # Important: Always import non-serializable libraries inside the `write` method. 74 | from pyspark import TaskContext 75 | from databricks.sdk import WorkspaceClient 76 | 77 | # Consume all input rows and dump them as json. 78 | rows = [row.asDict() for row in iterator] 79 | json_data = json.dumps(rows) 80 | f = io.BytesIO(json_data.encode("utf-8")) 81 | 82 | context = TaskContext.get() 83 | id = context.taskAttemptId() 84 | file_path = f"{self.path}/{id}_{time.time_ns()}.json" 85 | 86 | # Upload to DFBS. 87 | w = WorkspaceClient(host=self.databricks_url, token=self.databricks_token) 88 | w.dbfs.upload(file_path, f) 89 | 90 | return CommitMessage(output_path=file_path) 91 | 92 | def commit(self, messages: List[CommitMessage]): 93 | from databricks.sdk import WorkspaceClient 94 | 95 | w = WorkspaceClient(host=self.databricks_url, token=self.databricks_token) 96 | paths = [message.output_path for message in messages] 97 | 98 | if self.overwrite: 99 | # Remove all files in the current directory except for the newly written files. 
100 | for file in w.dbfs.list(self.path): 101 | if file.path not in paths: 102 | print(f"[Overwrite] Removing file {file.path}") 103 | w.dbfs.delete(file.path) 104 | 105 | # Write a success file 106 | file_path = f"{self.path}/_SUCCESS" 107 | f = io.BytesIO(b"success") 108 | w.dbfs.upload(file_path, f, overwrite=True) 109 | 110 | def abort(self, messages: List[CommitMessage]): 111 | from databricks.sdk import WorkspaceClient 112 | 113 | w = WorkspaceClient(host=self.databricks_url, token=self.databricks_token) 114 | # Clean up the newly written files 115 | for message in messages: 116 | if message is not None: 117 | print(f"[Abort] Removing partially written file: {message.output_path}") 118 | w.dbfs.delete(message.output_path) 119 | -------------------------------------------------------------------------------- /pyspark_datasources/stock.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | 4 | from pyspark.sql.types import StructType 5 | from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition 6 | 7 | 8 | class StockDataSource(DataSource): 9 | """ 10 | A data source for reading stock data using the Alpha Vantage API. 11 | 12 | Examples 13 | -------- 14 | 15 | Load the daily stock data for SPY: 16 | 17 | >>> df = spark.read.format("stock").option("api_key", "your-key").load("SPY") 18 | >>> df.show(n=5) 19 | +----------+------+------+------+------+--------+------+ 20 | | date| open| high| low| close| volume|symbol| 21 | +----------+------+------+------+------+--------+------+ 22 | |2024-06-04|526.46|529.15|524.96|528.39|33898396| SPY| 23 | |2024-06-03|529.02|529.31| 522.6| 527.8|46835702| SPY| 24 | |2024-05-31|523.59| 527.5|518.36|527.37|90785755| SPY| 25 | |2024-05-30|524.52| 525.2|521.33|522.61|46468510| SPY| 26 | |2024-05-29|525.68|527.31|525.37| 526.1|45190323| SPY| 27 | +----------+------+------+------+------+--------+------+ 28 | """ 29 | 30 | @classmethod 31 | def name(cls) -> str: 32 | return "stock" 33 | 34 | def schema(self) -> str: 35 | return ( 36 | "date string, open double, high double, " 37 | "low double, close double, volume long, symbol string" 38 | ) 39 | 40 | def reader(self, schema): 41 | return StockDataReader(schema, self.options) 42 | 43 | 44 | @dataclass 45 | class Symbol(InputPartition): 46 | name: str 47 | 48 | 49 | class StockDataReader(DataSourceReader): 50 | def __init__(self, schema: StructType, options: Dict): 51 | self.schema = schema 52 | self.options = options 53 | self.api_key = options.get("api_key") 54 | if not self.api_key: 55 | raise Exception("API Key is required to load the data.") 56 | # The name of the time-series. See https://www.alphavantage.co/documentation/ for more info. 57 | self.function = options.get("function", "TIME_SERIES_DAILY") 58 | if self.function not in ("TIME_SERIES_DAILY", "TIME_SERIES_WEEKLY", "TIME_SERIES_MONTHLY"): 59 | raise Exception(f"Function `{self.function}` is not supported.") 60 | 61 | def partitions(self): 62 | names = self.options["path"] 63 | # Split the names by comma and create a partition for each symbol. 64 | return [Symbol(name.strip()) for name in names.split(",")] 65 | 66 | def read(self, partition: Symbol): 67 | import requests 68 | 69 | symbol = partition.name 70 | resp = requests.get( 71 | f"https://www.alphavantage.co/query?"
72 | f"function={self.function}&symbol={symbol}&apikey={self.api_key}" 73 | ) 74 | resp.raise_for_status() 75 | data = resp.json() 76 | key_name = next(key for key in data.keys() if key != "Meta Data") 77 | for date, info in data[key_name].items(): 78 | yield ( 79 | date, 80 | float(info["1. open"]), 81 | float(info["2. high"]), 82 | float(info["3. low"]), 83 | float(info["4. close"]), 84 | float(info["5. volume"]), 85 | symbol, 86 | ) 87 | -------------------------------------------------------------------------------- /pyspark_datasources/weather.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import requests 4 | 5 | from pyspark.sql.datasource import DataSource, SimpleDataSourceStreamReader 6 | from pyspark.sql.types import StructType, StructField, DoubleType, StringType 7 | 8 | 9 | class WeatherDataSource(DataSource): 10 | """ 11 | A custom PySpark data source for fetching weather data from tomorrow.io for given 12 | locations (latitude, longitude). 13 | 14 | Options 15 | ------- 16 | 17 | - locations: specify a list of (latitude, longitude) tuples. 18 | - apikey: specify the API key for the weather service (tomorrow.io). 19 | - frequency: specify the frequency of the data ("minutely", "hourly", "daily"). 20 | Default is "minutely". 21 | 22 | Examples 23 | -------- 24 | 25 | Register the data source. 26 | 27 | >>> from pyspark_datasources import WeatherDataSource 28 | >>> spark.dataSource.register(WeatherDataSource) 29 | 30 | 31 | Define the options for the custom data source 32 | 33 | >>> options = { 34 | ... "locations": "[(37.7749, -122.4194), (40.7128, -74.0060)]", # San Francisco and New York 35 | ... "apikey": "your_api_key_here", 36 | ... } 37 | 38 | Create a DataFrame using the custom weather data source 39 | 40 | >>> weather_df = spark.readStream.format("weather").options(**options).load() 41 | 42 | Stream weather data and print the results to the console in real-time. 43 | 44 | >>> query = weather_df.writeStream.format("console").trigger(availableNow=True).start() 45 | """ 46 | 47 | @classmethod 48 | def name(cls): 49 | """Returns the name of the data source.""" 50 | return "weather" 51 | 52 | def __init__(self, options): 53 | """Initialize with options provided.""" 54 | self.options = options 55 | self.frequency = options.get("frequency", "minutely") 56 | if self.frequency not in ["minutely", "hourly", "daily"]: 57 | raise ValueError(f"Unsupported frequency: {self.frequency}") 58 | 59 | def schema(self): 60 | """Defines the output schema of the data source.""" 61 | return StructType( 62 | [ 63 | StructField("latitude", DoubleType(), True), 64 | StructField("longitude", DoubleType(), True), 65 | StructField("weather", StringType(), True), 66 | StructField("timestamp", StringType(), True), 67 | ] 68 | ) 69 | 70 | def simpleStreamReader(self, schema: StructType): 71 | """Returns an instance of the reader for this data source.""" 72 | return WeatherSimpleStreamReader(schema, self.options) 73 | 74 | 75 | class WeatherSimpleStreamReader(SimpleDataSourceStreamReader): 76 | def initialOffset(self): 77 | """ 78 | Returns the initial offset for reading, which serves as the starting point for 79 | the streaming data source. 80 | 81 | The initial offset is returned as a dictionary where each key is a unique identifier 82 | for a specific (latitude, longitude) pair, and each value is a timestamp string 83 | (in ISO 8601 format) representing the point in time from which data should start being 84 | read. 
85 | 86 | Example: 87 | For locations [(37.7749, -122.4194), (40.7128, -74.0060)], the offset might look like: 88 | { 89 | "offset_37.7749_-122.4194": "2024-09-01T00:00:00Z", 90 | "offset_40.7128_-74.0060": "2024-09-01T00:00:00Z" 91 | } 92 | """ 93 | return {f"offset_{lat}_{long}": "2024-09-01T00:00:00Z" for (lat, long) in self.locations} 94 | 95 | @staticmethod 96 | def _parse_locations(locations_str: str): 97 | """Converts string representation of list of tuples to actual list of tuples.""" 98 | return [tuple(map(float, x)) for x in ast.literal_eval(locations_str)] 99 | 100 | def __init__(self, schema: StructType, options: dict): 101 | """Initialize with schema and options.""" 102 | super().__init__() 103 | self.schema = schema 104 | self.locations = self._parse_locations(options.get("locations", "[]")) 105 | self.api_key = options.get("apikey", "") 106 | self.current = 0 107 | self.frequency = options.get("frequency", "minutely") 108 | self.session = requests.Session() # Use a session for connection pooling 109 | 110 | def read(self, start: dict): 111 | """Reads data starting from the given offset.""" 112 | data = [] 113 | new_offset = {} 114 | for lat, long in self.locations: 115 | start_ts = start[f"offset_{lat}_{long}"] 116 | weather = self._fetch_weather(lat, long, self.api_key, self.session)[self.frequency] 117 | for entry in weather: 118 | # Start time is exclusive and end time is inclusive. 119 | if entry["time"] > start_ts: 120 | data.append((lat, long, json.dumps(entry["values"]), entry["time"])) 121 | new_offset.update({f"offset_{lat}_{long}": weather[-1]["time"]}) 122 | return (data, new_offset) 123 | 124 | @staticmethod 125 | def _fetch_weather(lat: float, long: float, api_key: str, session): 126 | """Fetches weather data for the given latitude and longitude using a REST API.""" 127 | url = f"https://api.tomorrow.io/v4/weather/forecast?location={lat},{long}&apikey={api_key}" 128 | response = session.get(url) 129 | response.raise_for_status() 130 | return response.json()["timelines"] 131 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allisonwang-db/pyspark-data-sources/4e877ceb342113c9d9cdfbb90371d2c1441a8b3b/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_data_sources.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tempfile 3 | import os 4 | import pyarrow as pa 5 | 6 | from pyspark.sql import SparkSession 7 | from pyspark_datasources import * 8 | 9 | 10 | @pytest.fixture 11 | def spark(): 12 | spark = SparkSession.builder.getOrCreate() 13 | yield spark 14 | 15 | 16 | def test_github_datasource(spark): 17 | spark.dataSource.register(GithubDataSource) 18 | df = spark.read.format("github").load("apache/spark") 19 | prs = df.collect() 20 | assert len(prs) > 0 21 | 22 | 23 | def test_fake_datasource_stream(spark): 24 | spark.dataSource.register(FakeDataSource) 25 | ( 26 | spark.readStream.format("fake") 27 | .load() 28 | .writeStream.format("memory") 29 | .queryName("result") 30 | .trigger(once=True) 31 | .start() 32 | .awaitTermination() 33 | ) 34 | spark.sql("SELECT * FROM result").show() 35 | assert spark.sql("SELECT * FROM result").count() == 3 36 | 37 | 38 | def test_fake_datasource(spark): 39 | spark.dataSource.register(FakeDataSource) 40 | df = spark.read.format("fake").load() 
41 | df.show() 42 | assert df.count() == 3 43 | assert len(df.columns) == 4 44 | 45 | 46 | def test_kaggle_datasource(spark): 47 | spark.dataSource.register(KaggleDataSource) 48 | df = ( 49 | spark.read.format("kaggle") 50 | .options(handle="yasserh/titanic-dataset") 51 | .load("Titanic-Dataset.csv") 52 | ) 53 | df.show() 54 | assert df.count() == 891 55 | assert len(df.columns) == 12 56 | 57 | 58 | def test_opensky_datasource_stream(spark): 59 | spark.dataSource.register(OpenSkyDataSource) 60 | ( 61 | spark.readStream.format("opensky") 62 | .option("region", "EUROPE") 63 | .load() 64 | .writeStream.format("memory") 65 | .queryName("opensky_result") 66 | .trigger(once=True) 67 | .start() 68 | .awaitTermination() 69 | ) 70 | result = spark.sql("SELECT * FROM opensky_result") 71 | result.show() 72 | assert len(result.columns) == 18 # Check schema has expected number of fields 73 | assert result.count() > 0 # Verify we got some data 74 | 75 | def test_salesforce_datasource_registration(spark): 76 | """Test that Salesforce DataSource can be registered and validates required options.""" 77 | spark.dataSource.register(SalesforceDataSource) 78 | 79 | # Test that the datasource is registered with correct name 80 | assert SalesforceDataSource.name() == "pyspark.datasource.salesforce" 81 | 82 | # Test that the data source is streaming-only (no batch writer) 83 | from pyspark.sql.functions import lit 84 | 85 | try: 86 | # Try to use batch write - should fail since we only support streaming 87 | df = spark.range(1).select( 88 | lit("Test Company").alias("Name"), 89 | lit("Technology").alias("Industry"), 90 | lit(50000.0).alias("AnnualRevenue"), 91 | ) 92 | 93 | df.write.format("pyspark.datasource.salesforce").mode("append").save() 94 | assert False, "Should have raised error - Salesforce DataSource only supports streaming" 95 | except Exception as e: 96 | # This is expected - Salesforce DataSource only supports streaming writes 97 | error_msg = str(e).lower() 98 | # The error can be about unsupported mode or missing writer 99 | assert "unsupported" in error_msg or "writer" in error_msg or "not implemented" in error_msg 100 | 101 | 102 | def test_arrow_datasource_single_file(spark): 103 | """Test reading a single Arrow file.""" 104 | spark.dataSource.register(ArrowDataSource) 105 | 106 | # Create test data 107 | test_data = pa.table( 108 | {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "age": [25, 30, 35]} 109 | ) 110 | 111 | # Write to temporary Arrow file 112 | with tempfile.NamedTemporaryFile(suffix=".arrow", delete=False) as tmp_file: 113 | tmp_path = tmp_file.name 114 | 115 | try: 116 | with pa.ipc.new_file(tmp_path, test_data.schema) as writer: 117 | writer.write_table(test_data) 118 | 119 | # Read using Arrow data source 120 | df = spark.read.format("arrow").load(tmp_path) 121 | 122 | # Verify results 123 | assert df.count() == 3 124 | assert len(df.columns) == 3 125 | assert set(df.columns) == {"id", "name", "age"} 126 | 127 | # Verify data content 128 | rows = df.collect() 129 | assert len(rows) == 3 130 | assert rows[0]["name"] == "Alice" 131 | 132 | finally: 133 | # Clean up 134 | if os.path.exists(tmp_path): 135 | os.unlink(tmp_path) 136 | 137 | 138 | def test_arrow_datasource_multiple_files(spark): 139 | """Test reading multiple Arrow files from a directory.""" 140 | spark.dataSource.register(ArrowDataSource) 141 | 142 | # Create test data for multiple files 143 | test_data1 = pa.table( 144 | {"id": [1, 2], "name": ["Alice", "Bob"], "department": ["Engineering", "Sales"]} 145 | ) 
146 | 147 | test_data2 = pa.table( 148 | {"id": [3, 4], "name": ["Charlie", "Diana"], "department": ["Marketing", "HR"]} 149 | ) 150 | 151 | # Create temporary directory 152 | with tempfile.TemporaryDirectory() as tmp_dir: 153 | # Write multiple Arrow files 154 | file1_path = os.path.join(tmp_dir, "data1.arrow") 155 | file2_path = os.path.join(tmp_dir, "data2.arrow") 156 | 157 | with pa.ipc.new_file(file1_path, test_data1.schema) as writer: 158 | writer.write_table(test_data1) 159 | 160 | with pa.ipc.new_file(file2_path, test_data2.schema) as writer: 161 | writer.write_table(test_data2) 162 | 163 | # Read using Arrow data source from directory 164 | df = spark.read.format("arrow").load(tmp_dir) 165 | 166 | # Verify results 167 | assert df.count() == 4 # 2 rows from each file 168 | assert len(df.columns) == 3 169 | assert set(df.columns) == {"id", "name", "department"} 170 | 171 | # Verify partitioning (should have 2 partitions, one per file) 172 | assert df.rdd.getNumPartitions() == 2 173 | 174 | # Verify all data is present 175 | rows = df.collect() 176 | names = {row["name"] for row in rows} 177 | assert names == {"Alice", "Bob", "Charlie", "Diana"} 178 | 179 | def test_jsonplaceholder_posts(spark): 180 | spark.dataSource.register(JSONPlaceholderDataSource) 181 | posts_df = spark.read.format("jsonplaceholder").option("endpoint", "posts").load() 182 | assert posts_df.count() > 0 # Ensure we have some posts 183 | 184 | 185 | def test_jsonplaceholder_referential_integrity(spark): 186 | spark.dataSource.register(JSONPlaceholderDataSource) 187 | users_df = spark.read.format("jsonplaceholder").option("endpoint", "users").load() 188 | assert users_df.count() > 0 # Ensure we have some users 189 | posts_df = spark.read.format("jsonplaceholder").option("endpoint", "posts").load() 190 | posts_with_authors = posts_df.join(users_df, posts_df.userId == users_df.id) 191 | assert posts_with_authors.count() > 0 # Ensure join is valid and we have posts with authors 192 | -------------------------------------------------------------------------------- /tests/test_google_sheets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pyspark.errors.exceptions.captured import AnalysisException, PythonException 3 | from pyspark.sql import SparkSession 4 | 5 | from pyspark_datasources import GoogleSheetsDataSource 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def spark(): 10 | spark = SparkSession.builder.getOrCreate() 11 | spark.dataSource.register(GoogleSheetsDataSource) 12 | yield spark 13 | 14 | 15 | def test_url(spark): 16 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797" 17 | df = spark.read.format("googlesheets").options(url=url).load() 18 | df.show() 19 | assert df.count() == 2 20 | assert len(df.columns) == 2 21 | assert df.schema.simpleString() == "struct" 22 | 23 | 24 | def test_spreadsheet_id(spark): 25 | df = spark.read.format("googlesheets").load("10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0") 26 | df.show() 27 | assert df.count() == 2 28 | assert len(df.columns) == 2 29 | 30 | 31 | def test_missing_options(spark): 32 | with pytest.raises(AnalysisException) as excinfo: 33 | spark.read.format("googlesheets").load() 34 | assert "ValueError" in str(excinfo.value) 35 | 36 | 37 | def test_mutual_exclusive_options(spark): 38 | with pytest.raises(AnalysisException) as excinfo: 39 | spark.read.format("googlesheets").options( 40 | url="a", 41 | spreadsheet_id="b", 42 | 
).load() 43 | assert "ValueError" in str(excinfo.value) 44 | 45 | 46 | def test_custom_schema(spark): 47 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797" 48 | df = spark.read.format("googlesheets").options(url=url).schema("a double, b string").load() 49 | df.show() 50 | assert df.count() == 2 51 | assert len(df.columns) == 2 52 | assert df.schema.simpleString() == "struct<a:double,b:string>" 53 | 54 | 55 | def test_custom_schema_mismatch_count(spark): 56 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797" 57 | df = spark.read.format("googlesheets").options(url=url).schema("a double").load() 58 | with pytest.raises(PythonException) as excinfo: 59 | df.show() 60 | assert "CSV parse error" in str(excinfo.value) 61 | 62 | 63 | def test_unnamed_column(spark): 64 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1579451727#gid=1579451727" 65 | df = spark.read.format("googlesheets").options(url=url).load() 66 | df.show() 67 | assert df.count() == 1 68 | assert df.columns == ["Unnamed: 0", "1", "Unnamed: 2"] 69 | 70 | 71 | def test_duplicate_column(spark): 72 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1875209731#gid=1875209731" 73 | df = spark.read.format("googlesheets").options(url=url).load() 74 | df.show() 75 | assert df.count() == 1 76 | assert df.columns == ["a", "a.1"] 77 | 78 | 79 | def test_no_header_row(spark): 80 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1579451727#gid=1579451727" 81 | df = ( 82 | spark.read.format("googlesheets") 83 | .schema("a int, b int, c int") 84 | .options(url=url, has_header="false") 85 | .load() 86 | ) 87 | df.show() 88 | assert df.count() == 2 89 | assert len(df.columns) == 3 90 | 91 | 92 | def test_empty(spark): 93 | url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=2123944555#gid=2123944555" 94 | with pytest.raises(AnalysisException) as excinfo: 95 | spark.read.format("googlesheets").options(url=url).load() 96 | assert "EmptyDataError" in str(excinfo.value) 97 | -------------------------------------------------------------------------------- /tests/test_robinhood.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from unittest.mock import Mock, patch 4 | from pyspark.sql import SparkSession, Row 5 | from pyspark.errors.exceptions.captured import AnalysisException 6 | 7 | from pyspark_datasources import RobinhoodDataSource 8 | 9 | 10 | @pytest.fixture 11 | def spark(): 12 | """Create SparkSession for testing.""" 13 | spark = SparkSession.builder.getOrCreate() 14 | spark.dataSource.register(RobinhoodDataSource) 15 | yield spark 16 | 17 | 18 | def test_robinhood_datasource_registration(spark): 19 | """Test that RobinhoodDataSource can be registered.""" 20 | # Test registration 21 | assert RobinhoodDataSource.name() == "robinhood" 22 | 23 | # Test schema 24 | expected_schema = ( 25 | "symbol string, price double, bid_price double, ask_price double, updated_at string" 26 | ) 27 | datasource = RobinhoodDataSource({}) 28 | assert datasource.schema() == expected_schema 29 | 30 | 31 | def test_robinhood_missing_credentials(spark): 32 | """Test that missing API credentials raises an error.""" 33 | with pytest.raises(AnalysisException) as excinfo: 34 |
df = spark.read.format("robinhood").load("BTC-USD") 35 | df.collect() # Trigger execution 36 | 37 | assert "ValueError" in str(excinfo.value) and ( 38 | "api_key" in str(excinfo.value) or "private_key" in str(excinfo.value) 39 | ) 40 | 41 | 42 | def test_robinhood_missing_symbols(spark): 43 | """Test that missing symbols raises an error.""" 44 | with pytest.raises(AnalysisException) as excinfo: 45 | df = ( 46 | spark.read.format("robinhood") 47 | .option("api_key", "test-key") 48 | .option("private_key", "FAPmPMsQqDFOFiRvpUMJ6BC5eFOh/tPx7qcTYGKc8nE=") 49 | .load("") 50 | ) 51 | df.collect() # Trigger execution 52 | 53 | assert "ValueError" in str(excinfo.value) and "crypto pairs" in str(excinfo.value) 54 | 55 | 56 | def test_robinhood_invalid_private_key_format(spark): 57 | """Test that invalid private key format raises proper error.""" 58 | with pytest.raises(AnalysisException) as excinfo: 59 | df = ( 60 | spark.read.format("robinhood") 61 | .option("api_key", "test-key") 62 | .option("private_key", "invalid-key-format") 63 | .load("BTC-USD") 64 | ) 65 | df.collect() # Trigger execution 66 | 67 | assert "Invalid private key format" in str(excinfo.value) 68 | 69 | 70 | def test_robinhood_btc_data(spark): 71 | """Test BTC-USD data retrieval with registered API key - REQUIRES API CREDENTIALS.""" 72 | # Get credentials from environment variables 73 | api_key = os.environ.get("ROBINHOOD_API_KEY") 74 | private_key = os.environ.get("ROBINHOOD_PRIVATE_KEY") 75 | 76 | if not api_key or not private_key: 77 | pytest.skip( 78 | "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables required for real API tests" 79 | ) 80 | 81 | # Test loading BTC-USD data 82 | df = ( 83 | spark.read.format("robinhood") 84 | .option("api_key", api_key) 85 | .option("private_key", private_key) 86 | .load("BTC-USD") 87 | ) 88 | 89 | rows = df.collect() 90 | print(f"Retrieved {len(rows)} rows") 91 | 92 | # CRITICAL: Test MUST fail if no data is returned 93 | assert len(rows) > 0, "TEST FAILED: No data returned! Expected at least 1 BTC-USD record." 
94 | 95 | for i, row in enumerate(rows): 96 | print(f"Row {i + 1}: {row}") 97 | 98 | # Validate data structure 99 | assert row.symbol == "BTC-USD", f"Expected BTC-USD, got {row.symbol}" 100 | assert isinstance(row.price, (int, float)), ( 101 | f"Price should be numeric, got {type(row.price)}" 102 | ) 103 | assert row.price > 0, f"Price should be > 0, got {row.price}" 104 | assert isinstance(row.bid_price, (int, float)), ( 105 | f"Bid price should be numeric, got {type(row.bid_price)}" 106 | ) 107 | assert isinstance(row.ask_price, (int, float)), ( 108 | f"Ask price should be numeric, got {type(row.ask_price)}" 109 | ) 110 | assert isinstance(row.updated_at, str), ( 111 | f"Updated timestamp should be string, got {type(row.updated_at)}" 112 | ) 113 | 114 | 115 | def test_robinhood_multiple_crypto_pairs(spark): 116 | """Test multi-crypto data retrieval with registered API key - REQUIRES API CREDENTIALS.""" 117 | # Get credentials from environment variables 118 | api_key = os.environ.get("ROBINHOOD_API_KEY") 119 | private_key = os.environ.get("ROBINHOOD_PRIVATE_KEY") 120 | 121 | if not api_key or not private_key: 122 | pytest.skip( 123 | "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables required for real API tests" 124 | ) 125 | 126 | # Test loading multiple crypto pairs 127 | df = ( 128 | spark.read.format("robinhood") 129 | .option("api_key", api_key) 130 | .option("private_key", private_key) 131 | .load("BTC-USD,ETH-USD,DOGE-USD") 132 | ) 133 | 134 | rows = df.collect() 135 | print(f"Retrieved {len(rows)} rows") 136 | 137 | # CRITICAL: Test MUST fail if no data is returned 138 | assert len(rows) > 0, "TEST FAILED: No data returned! Expected at least 1 crypto record." 139 | 140 | # CRITICAL: Should get data for all 3 requested pairs 141 | assert len(rows) >= 3, f"TEST FAILED: Expected 3 crypto pairs, got {len(rows)} records." 142 | 143 | symbols_found = set() 144 | 145 | for i, row in enumerate(rows): 146 | symbols_found.add(row.symbol) 147 | print(f"Row {i + 1}: {row}") 148 | 149 | # Validate each record 150 | assert isinstance(row.symbol, str), f"Symbol should be string, got {type(row.symbol)}" 151 | assert isinstance(row.price, (int, float)), ( 152 | f"Price should be numeric, got {type(row.price)}" 153 | ) 154 | assert row.price > 0, f"Price should be > 0, got {row.price}" 155 | assert isinstance(row.bid_price, (int, float)), ( 156 | f"Bid price should be numeric, got {type(row.bid_price)}" 157 | ) 158 | assert isinstance(row.ask_price, (int, float)), ( 159 | f"Ask price should be numeric, got {type(row.ask_price)}" 160 | ) 161 | assert isinstance(row.updated_at, str), ( 162 | f"Updated timestamp should be string, got {type(row.updated_at)}" 163 | ) 164 | 165 | # Test passes only if we have real data for the requested pairs 166 | assert len(symbols_found) >= 3, ( 167 | f"Expected at least 3 different symbols, got {len(symbols_found)}: {symbols_found}" 168 | ) 169 | --------------------------------------------------------------------------------