├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── documentation.md │ └── feature_request.md └── workflows │ ├── add-labels.yaml │ └── release-uberjar.yaml ├── .gitignore ├── .scalafmt.conf ├── LICENSE ├── README.md ├── build.sbt ├── project ├── build.properties └── plugins.sbt ├── src ├── it │ ├── resources │ │ ├── sample1.jsonl │ │ └── sample2.jsonl │ └── scala │ │ └── io │ │ └── pinecone │ │ └── spark │ │ └── pinecone │ │ ├── BatchUpsertExample.scala │ │ └── StreamUpsertExample.scala ├── main │ └── scala │ │ └── io │ │ └── pinecone │ │ └── spark │ │ └── pinecone │ │ ├── Pinecone.scala │ │ ├── PineconeBatchWriter.scala │ │ ├── PineconeDataWriter.scala │ │ ├── PineconeDataWriterFactory.scala │ │ ├── PineconeIndex.scala │ │ ├── PineconeOptions.scala │ │ ├── PineconeStreamingWriter.scala │ │ ├── PineconeWrite.scala │ │ ├── PineconeWriteBuilder.scala │ │ └── package.scala └── test │ ├── resources │ ├── invalidUpsertInput1.jsonl │ ├── invalidUpsertInput2.jsonl │ ├── invalidUpsertInput3.jsonl │ ├── invalidUpsertInput4.jsonl │ ├── invalidUpsertInput5.jsonl │ ├── invalidUpsertInput6.jsonl │ ├── invalidUpsertInput7.jsonl │ ├── invalidUpsertInput8.jsonl │ └── invalidUpsertInput9.jsonl │ └── scala │ └── io │ └── pinecone │ └── spark │ └── pinecone │ ├── ParseCommonSchemaTest.scala │ └── ParseMetadataSpec.scala └── version.sbt /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug] " 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is this a new bug?** 11 | In other words: Is this an error, flaw, failure or fault? Please search Github issues and check our [Community Forum](https://community.pinecone.io/) to see if someone has already reported the bug you encountered. 12 | 13 | If this is a request for help or troubleshooting code in your own Pinecone project, please join the [Pinecone Community Forum](https://community.pinecone.io/). 14 | 15 | - [ ] I believe this is a new bug 16 | - [ ] I have searched the existing Github issues and Community Forum, and I could not find an existing post for this bug 17 | 18 | **Describe the bug** 19 | Describe the functionality that was working before but is broken now. 20 | 21 | **Error information** 22 | If you have one, please include the full stack trace here. If not, please share as much as you can about the error. 23 | 24 | **Steps to reproduce the issue locally** 25 | Include steps to reproduce the issue here. If you have sample code or a script that can be used to replicate this issue, please include that as well (including any dependent files to run the code). 26 | 27 | **Environment** 28 | * Scala Version: 29 | * Spark version: 30 | * Spark Pinecone Connector version: 31 | 32 | **Additional context** 33 | Add any other context about the problem here. 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Pinecone Community Forum 4 | url: https://community.pinecone.io/ 5 | about: For support, please see the community forum. 
6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation 3 | about: Report an issue in our docs 4 | title: "[Docs] " 5 | labels: 'documentation' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Description** 11 | Describe the issue that you've encountered with our documentation. 12 | 13 | **Suggested solution** 14 | Describe how this issue could be fixed or improved. 15 | **Link to page** 16 | Add a link to the exact documentation page where the issue occurred. 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: 'enhancement' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **What motivated you to submit this feature request?** 11 | A clear and concise description of why you are requesting this feature - e.g. "Being able to do x would allow me to..." 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/add-labels.yaml: -------------------------------------------------------------------------------- 1 | name: Label issues 2 | on: 3 | issues: 4 | types: 5 | - reopened 6 | - opened 7 | jobs: 8 | label_issues: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | issues: write 12 | steps: 13 | - run: gh issue edit "$NUMBER" --add-label "$LABELS" 14 | env: 15 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 16 | GH_REPO: ${{ github.repository }} 17 | NUMBER: ${{ github.event.issue.number }} 18 | LABELS: status:needs-triage 19 | -------------------------------------------------------------------------------- /.github/workflows/release-uberjar.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Uber JAR 2 | 3 | on: [workflow_dispatch] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout code 10 | uses: actions/checkout@v2 11 | 12 | - name: Set up JDK 13 | uses: actions/setup-java@v1 14 | with: 15 | distribution: temurin 16 | java-version: 8 17 | 18 | - name: Build and Test 19 | run: sbt -v +test 20 | 21 | - name: Build Uber JAR 22 | run: sbt assembly 23 | 24 | - name: Get Version 25 | id: version 26 | run: echo ::set-output name=snapshot::$(sbt -no-colors 'print version' | tail -n 1) 27 | 28 | - name: Show Version 29 | run: echo $(sbt -no-colors 'print version' | tail -n 1) 30 | 31 | - name: Output Version 32 | run: echo "JAR_VERSION=${{ steps.version.outputs.snapshot }}" >> $GITHUB_ENV 33 | 34 | - name: Publish Uber JAR 35 | uses: actions/upload-artifact@v2 36 | with: 37 | name: spark-pinecone-uber-jar-${{ env.JAR_VERSION }}.jar 38 | path: /home/runner/work/spark-pinecone/spark-pinecone/target/scala-2.12/spark-pinecone-${{ env.JAR_VERSION }}.jar 39 | 40 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | .idea/ 2 | .bsp/ 3 | 4 | *.class 5 | *.log 6 | 7 | dist/* 8 | target/ 9 | lib_managed/ 10 | src_managed/ 11 | project/boot/ 12 | project/plugins/project/ 13 | .history 14 | .cache 15 | .lib/ 16 | 17 | .DS_Store 18 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version = 3.4.3 2 | 3 | align.preset = more 4 | maxColumn = 100 5 | runner.dialect = scala212 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spark-pinecone 2 | The official [pinecone.io](https://pinecone.io) spark connector. 3 | 4 | ## Features 5 | - Please note that the connector's write operation is not atomic - some vectors might be written while others aren't if the operation is stopped or if it fails. 6 | In practice this shouldn't cause a serious issue. Pinecone is an idempotent key-value store. 
Re-running the job will result in the desired state without the need to clear the index or compute a delta from the source data. 7 | - The connector supports batch and streaming writes of data into Pinecone using a specific schema (see the examples below). 8 | For streaming pipelines, you can either write the stream directly to Pinecone or use a function like `foreachBatch` with the batch writer. 9 | 10 | ## Support 11 | This client currently supports Spark 3.5.0, Scala 2.12.X or 2.13.X, and Java 8+. 12 | - For Scala 2.12, use `spark-pinecone_2.12.jar`: https://central.sonatype.com/artifact/io.pinecone/spark-pinecone_2.12. 13 | - For Scala 2.13, use `spark-pinecone_2.13.jar`: https://central.sonatype.com/artifact/io.pinecone/spark-pinecone_2.13. 14 | 15 | Make sure to add the correct JAR file to your project's dependencies according to your Scala version. 16 | 17 | ### Databricks and friends 18 | Because Databricks provides its own versions of several libraries, please use the assembly JAR from S3 to avoid dependency conflicts. 19 | S3 paths for the assembly JAR: 20 | 1. v1.2.0 (latest): s3://pinecone-jars/1.2.0/spark-pinecone-uberjar.jar 21 | 2. v1.1.0: s3://pinecone-jars/1.1.0/spark-pinecone-uberjar.jar 22 | 3. v1.0.0: s3://pinecone-jars/1.0.0/spark-pinecone-uberjar.jar 23 | 4. v0.2.2: s3://pinecone-jars/0.2.2/spark-pinecone-uberjar.jar 24 | 5. v0.2.1: s3://pinecone-jars/0.2.1/spark-pinecone-uberjar.jar 25 | 6. v0.1.4: s3://pinecone-jars/spark-pinecone-uberjar.jar 26 | 27 | ## Example 28 | To connect to Pinecone with Spark, you'll have to retrieve the API key from [your Pinecone console](https://app.pinecone.io). 29 | Navigate to your project and click the "API Keys" button on the sidebar. The sample.jsonl file used in the examples below 30 | can be found [here](https://github.com/pinecone-io/spark-pinecone/blob/main/src/it/resources/sample.jsonl). 31 | 32 | ### Batch upsert 33 | Below are examples in Python and Scala for batch upserting vectors into Pinecone.
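In addition to `pinecone.apiKey`, `pinecone.indexName`, and `pinecone.sourceTag`, the writer reads an optional `pinecone.batchSize` option (see `PineconeOptions.scala`; the default is 100), which caps how many vectors are buffered per upsert request. The sketch below shows how it could be passed alongside the other options; it reuses `df` and `pineconeOptions` from the Scala example that follows, and the value `200` is purely illustrative.

```scala
// Minimal sketch: same batch write as the Scala example below, with the optional
// batch size override. "200" is a hypothetical value; the default is 100.
df.write
  .options(pineconeOptions) // pinecone.apiKey, pinecone.indexName, pinecone.sourceTag
  .option("pinecone.batchSize", "200")
  .format("io.pinecone.spark.pinecone.Pinecone")
  .mode(SaveMode.Append)
  .save()
```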
34 | 35 | #### Python 36 | ```python 37 | from pyspark import SparkConf 38 | from pyspark.sql import SparkSession 39 | from pyspark.sql.types import StructType, StructField, ArrayType, FloatType, StringType, LongType 40 | 41 | # Your API key and index name 42 | api_key = "PINECONE_API_KEY" 43 | index_name = "PINECONE_INDEX_NAME" 44 | source_tag = "PINECONE_SOURCE_TAG" 45 | 46 | COMMON_SCHEMA = StructType([ 47 | StructField("id", StringType(), False), 48 | StructField("namespace", StringType(), True), 49 | StructField("values", ArrayType(FloatType(), False), False), 50 | StructField("metadata", StringType(), True), 51 | StructField("sparse_values", StructType([ 52 | StructField("indices", ArrayType(LongType(), False), False), 53 | StructField("values", ArrayType(FloatType(), False), False) 54 | ]), True) 55 | ]) 56 | 57 | # Initialize Spark 58 | spark = SparkSession.builder.getOrCreate() 59 | 60 | # Read the file and apply the schema 61 | df = spark.read \ 62 | .option("multiLine", True) \ 63 | .option("mode", "PERMISSIVE") \ 64 | .schema(COMMON_SCHEMA) \ 65 | .json("src/test/resources/sample.jsonl") 66 | 67 | # Show if the read was successful 68 | df.show() 69 | 70 | # Write the DataFrame to Pinecone in batches 71 | df.write \ 72 | .option("pinecone.apiKey", api_key) \ 73 | .option("pinecone.indexName", index_name) \ 74 | .option("pinecone.sourceTag", source_tag) \ 75 | .format("io.pinecone.spark.pinecone.Pinecone") \ 76 | .mode("append") \ 77 | .save() 78 | ``` 79 | 80 | #### Scala 81 | ```scala 82 | import io.pinecone.spark.pinecone.{COMMON_SCHEMA, PineconeOptions} 83 | import org.apache.spark.SparkConf 84 | import org.apache.spark.sql.{SaveMode, SparkSession} 85 | 86 | object MainApp extends App { 87 | // Your API key and index name 88 | val apiKey = "PINECONE_API_KEY" 89 | val indexName = "PINECONE_INDEX_NAME" 90 | val sourceTag = "PINECONE_SOURCE_TAG" 91 | 92 | // Configure Spark to run locally with all available cores 93 | val conf = new SparkConf() 94 | .setMaster("local[*]") 95 | 96 | // Create a Spark session with the defined configuration 97 | val spark = SparkSession.builder().config(conf).getOrCreate() 98 | 99 | // Read the JSON file into a DataFrame, applying the COMMON_SCHEMA 100 | val df = spark.read 101 | .option("multiLine", value = true) 102 | .option("mode", "PERMISSIVE") 103 | .schema(COMMON_SCHEMA) 104 | .json("src/test/resources/sample.jsonl") // path to sample.jsonl 105 | 106 | // Define Pinecone options as a Map 107 | val pineconeOptions = Map( 108 | PineconeOptions.PINECONE_API_KEY_CONF -> apiKey, 109 | PineconeOptions.PINECONE_INDEX_NAME_CONF -> indexName, 110 | PineconeOptions.PINECONE_SOURCE_TAG_CONF -> sourceTag 111 | ) 112 | 113 | // Show if the read was successful 114 | df.show(df.count().toInt) 115 | 116 | // Write the DataFrame to Pinecone using the defined options in batches 117 | df.write 118 | .options(pineconeOptions) 119 | .format("io.pinecone.spark.pinecone.Pinecone") 120 | .mode(SaveMode.Append) 121 | .save() 122 | } 123 | ``` 124 | 125 | 126 | ### Stream upsert 127 | Below are examples in Python and Scala for streaming upserts of vectors into Pinecone.
128 | 129 | #### Python 130 | ```python 131 | from pyspark.sql import SparkSession 132 | from pyspark.sql.types import StructType, StructField, ArrayType, FloatType, StringType, LongType 133 | import os 134 | 135 | # Your API key and index name 136 | api_key = "PINECONE_API_KEY" 137 | index_name = "PINECONE_INDEX_NAME" 138 | source_tag = "PINECONE_SOURCE_TAG" 139 | 140 | COMMON_SCHEMA = StructType([ 141 | StructField("id", StringType(), False), 142 | StructField("namespace", StringType(), True), 143 | StructField("values", ArrayType(FloatType(), False), False), 144 | StructField("metadata", StringType(), True), 145 | StructField("sparse_values", StructType([ 146 | StructField("indices", ArrayType(LongType(), False), False), 147 | StructField("values", ArrayType(FloatType(), False), False) 148 | ]), True) 149 | ]) 150 | 151 | # Initialize Spark session 152 | spark = SparkSession.builder \ 153 | .appName("StreamUpsertExample") \ 154 | .config("spark.sql.shuffle.partitions", 3) \ 155 | .master("local") \ 156 | .getOrCreate() 157 | 158 | # Read the stream of JSON files from the input directory, applying the schema 159 | lines = spark.readStream \ 160 | .option("multiLine", True) \ 161 | .option("mode", "PERMISSIVE") \ 162 | .schema(COMMON_SCHEMA) \ 163 | .json("path/to/input/directory/") 164 | 165 | # Write the stream to Pinecone using the defined options 166 | upsert = lines.writeStream \ 167 | .format("io.pinecone.spark.pinecone.Pinecone") \ 168 | .option("pinecone.apiKey", api_key) \ 169 | .option("pinecone.indexName", index_name) \ 170 | .option("pinecone.sourceTag", source_tag) \ 171 | .option("checkpointLocation", "path/to/checkpoint/dir") \ 172 | .outputMode("append") \ 173 | .start() 174 | 175 | upsert.awaitTermination() 176 | ``` 177 | 178 | #### Scala 179 | ```scala 180 | import io.pinecone.spark.pinecone.{COMMON_SCHEMA, PineconeOptions} 181 | import org.apache.spark.SparkConf 182 | import org.apache.spark.sql.{SaveMode, SparkSession} 183 | 184 | object MainApp extends App { 185 | // Your API key, index name, and source tag 186 | val apiKey = "PINECONE_API_KEY" 187 | val indexName = "PINECONE_INDEX_NAME" 188 | val sourceTag = "PINECONE_SOURCE_TAG" 189 | // Create a Spark session 190 | val spark = SparkSession.builder() 191 | .appName("StreamUpsertExample") 192 | .config("spark.sql.shuffle.partitions", 3) 193 | .master("local") 194 | .getOrCreate() 195 | 196 | // Read the JSON files from the input directory into a streaming DataFrame, applying the COMMON_SCHEMA 197 | val lines = spark.readStream 198 | .option("multiLine", value = true) 199 | .option("mode", "PERMISSIVE") 200 | .schema(COMMON_SCHEMA) 201 | .json("path/to/input/directory/") 202 | 203 | // Define Pinecone options as a Map 204 | val pineconeOptions = Map( 205 | PineconeOptions.PINECONE_API_KEY_CONF -> apiKey, 206 | PineconeOptions.PINECONE_INDEX_NAME_CONF -> indexName, 207 | PineconeOptions.PINECONE_SOURCE_TAG_CONF -> sourceTag 208 | ) 209 | 210 | // Write the stream to Pinecone using the defined options 211 | val upsert = lines 212 | .writeStream 213 | .format("io.pinecone.spark.pinecone.Pinecone") 214 | .options(pineconeOptions) 215 | .option("checkpointLocation", "path/to/checkpoint/dir") 216 | .outputMode("append") 217 | .start() 218 | 219 | upsert.awaitTermination() 220 | } 221 | ``` 222 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import ReleaseTransformations._ 2 | 3 |
lazy val sparkVersion = "3.5.0" 4 | 5 | ThisBuild / sonatypeCredentialHost := "s01.oss.sonatype.org" 6 | ThisBuild / sonatypeRepository := "https://s01.oss.sonatype.org/service/local" 7 | 8 | lazy val root = (project in file(".")) 9 | .configs(IntegrationTest) 10 | .settings( 11 | name := "spark-pinecone", 12 | organizationName := "Pinecone Systems", 13 | organizationHomepage := Some(url("http://pinecone.io/")), 14 | organization := "io.pinecone", 15 | licenses := Seq(("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))), 16 | description := "A spark connector for the Pinecone Vector Database", 17 | developers := List( 18 | Developer( 19 | "adamgs", 20 | "Adam Gutglick", 21 | "adam@pinecone.io", 22 | url("https://github.com/pinecone-io") 23 | ), 24 | Developer( 25 | "rajat08", 26 | "Rajat Tripathi", 27 | "rajat@pinecone.io", 28 | url("https://github.com/pinecone-io") 29 | ), 30 | Developer( 31 | "rohanshah18", 32 | "Rohan Shah", 33 | "rohan.s@pinecone.io", 34 | url("https://github.com/pinecone-io") 35 | ) 36 | ), 37 | versionScheme := Some("semver-spec"), 38 | scalaVersion := "2.12.15", 39 | scmInfo := Some( 40 | ScmInfo( 41 | url("https://github.com/pinecone-io/spark-pinecone"), 42 | "scm:git:git@github.com:pinecone-io/spark-pinecone.git" 43 | ) 44 | ), 45 | homepage := Some(url("https://github.com/pinecone-io/spark-pinecone")), 46 | Defaults.itSettings, 47 | crossScalaVersions := Seq("2.12.15", "2.13.8"), 48 | javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), 49 | libraryDependencies ++= Seq( 50 | "io.pinecone" % "pinecone-client" % "1.2.2", 51 | "org.scalatest" %% "scalatest" % "3.2.11" % "it,test", 52 | "org.apache.spark" %% "spark-core" % sparkVersion % "provided,test", 53 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided,test", 54 | "org.apache.spark" %% "spark-catalyst" % sparkVersion % "provided,test" 55 | ), 56 | Test / fork := true, 57 | assembly / assemblyShadeRules := Seq( 58 | ShadeRule 59 | .rename("com.google.protobuf.**" -> "shaded.protobuf.@1") 60 | .inAll, 61 | ShadeRule 62 | .rename("com.google.common.**" -> "shaded.guava.@1") 63 | .inAll 64 | ), 65 | assembly / assemblyMergeStrategy := { 66 | case PathList("META-INF", xs@_*) => MergeStrategy.concat 67 | case x => MergeStrategy.first 68 | }, 69 | // Build assembly jar, this builds an uberJar with all dependencies 70 | assembly / assemblyJarName := s"${name.value}-${version.value}.jar", 71 | assembly / artifact := { 72 | val art = (assembly / artifact).value 73 | art.withClassifier(Some("assembly")) 74 | }, 75 | addArtifact(assembly / artifact, assembly), 76 | publishLocal / skip := true, 77 | ThisBuild / publishMavenStyle := true, 78 | // Expects credentials stored in ~/.sbt/sonatype_credentials. 
This is a standard practice 79 | credentials += Credentials(Path.userHome / ".sbt" / "sonatype_credentials"), 80 | releaseCrossBuild := true, // true if you cross-build the project for multiple Scala versions 81 | // ToDo: remove this once the databricks issue is resolved 82 | assembly / publishTo := sonatypePublishToBundle.value, 83 | publishTo := sonatypePublishToBundle.value, 84 | releaseProcess := Seq[ReleaseStep]( 85 | checkSnapshotDependencies, 86 | inquireVersions, 87 | runClean, 88 | setReleaseVersion, 89 | commitReleaseVersion, 90 | tagRelease, 91 | releaseStepCommandAndRemaining("+publishSigned"), 92 | releaseStepCommand("sonatypeBundleRelease"), 93 | pushChanges 94 | ) 95 | ) 96 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.6.2 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") 2 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.12") 3 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1") 4 | addSbtPlugin("com.github.sbt" % "sbt-release" % "1.1.0") 5 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0") -------------------------------------------------------------------------------- /src/it/resources/sample1.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "namespace": "example-namespace1", 5 | "values": [ 6 | 1, 7 | 2, 8 | 3 9 | ], 10 | "metadata": { 11 | "hello": [ 12 | "world", 13 | "you" 14 | ], 15 | "some_string": "or not", 16 | "actual_number": 5.2, 17 | "round": 3 18 | }, 19 | "sparse_values": { 20 | "indices": [ 21 | 0, 22 | 2, 23 | 4294967295 24 | ], 25 | "values": [ 26 | 4.5, 27 | 5.5, 28 | 5 29 | ] 30 | } 31 | }, 32 | { 33 | "id": "v2", 34 | "values": [ 35 | 3, 36 | 2, 37 | 1 38 | ] 39 | }, 40 | { 41 | "namespace": "example-namespace1", 42 | "values": [ 43 | 1, 44 | 4, 45 | 9 46 | ], 47 | "id": "v3", 48 | "metadata": "" 49 | }, 50 | { 51 | "id": "v4", 52 | "namespace": "example-namespace1", 53 | "values": [ 54 | 1, 55 | 1, 56 | 2 57 | ], 58 | "metadata": { 59 | "key": "value" 60 | } 61 | }, 62 | { 63 | "id": "v5", 64 | "namespace": "example-namespace1", 65 | "values": [ 66 | 3, 67 | 5, 68 | 8 69 | ], 70 | "metadata": { 71 | "key": "value" 72 | }, 73 | "sparse_values": { 74 | "indices": [ 75 | 1 76 | ], 77 | "values": [ 78 | 4 79 | ] 80 | } 81 | }, 82 | { 83 | "id": "v6", 84 | "namespace": "example-namespace1", 85 | "values": [ 86 | 13, 87 | 21, 88 | 34 89 | ], 90 | "metadata": "" 91 | }, 92 | { 93 | "id": "v7", 94 | "namespace": "example-namespace1", 95 | "values": [ 96 | 5, 97 | 6, 98 | 7 99 | ], 100 | "metadata": { 101 | "hello": [ 102 | "world", 103 | "you" 104 | ], 105 | "some_string": "or not", 106 | "actual_number": 5.2, 107 | "round": 3 108 | } 109 | } 110 | ] -------------------------------------------------------------------------------- /src/it/resources/sample2.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v8", 4 | "namespace": "example-namespace2", 5 | "values": [ 6 | 1, 7 | 2, 8 | 3 9 | ], 10 | "metadata": { 11 | "hello": [ 12 | "world", 13 | "you" 14 | ], 15 | "some_string": "or not", 16 | "actual_number": 5.2, 17 | "round": 3 18 | }, 19 | "sparse_values": 
{ 20 | "indices": [ 21 | 0, 22 | 2, 23 | 4294967295 24 | ], 25 | "values": [ 26 | 4.5, 27 | 5.5, 28 | 5 29 | ] 30 | } 31 | }, 32 | { 33 | "id": "v9", 34 | "values": [ 35 | 3, 36 | 2, 37 | 1 38 | ] 39 | }, 40 | { 41 | "namespace": "example-namespace2", 42 | "values": [ 43 | 1, 44 | 4, 45 | 9 46 | ], 47 | "id": "v10", 48 | "metadata": "" 49 | }, 50 | { 51 | "id": "v11", 52 | "namespace": "example-namespace2", 53 | "values": [ 54 | 1, 55 | 1, 56 | 2 57 | ], 58 | "metadata": { 59 | "key": "value" 60 | } 61 | }, 62 | { 63 | "id": "v12", 64 | "namespace": "example-namespace2", 65 | "values": [ 66 | 3, 67 | 5, 68 | 8 69 | ], 70 | "metadata": { 71 | "key": "value" 72 | }, 73 | "sparse_values": { 74 | "indices": [ 75 | 1 76 | ], 77 | "values": [ 78 | 4 79 | ] 80 | } 81 | }, 82 | { 83 | "id": "v13", 84 | "namespace": "example-namespace2", 85 | "values": [ 86 | 13, 87 | 21, 88 | 34 89 | ], 90 | "metadata": "" 91 | }, 92 | { 93 | "id": "v14", 94 | "namespace": "example-namespace2", 95 | "values": [ 96 | 5, 97 | 6, 98 | 7 99 | ], 100 | "metadata": { 101 | "hello": [ 102 | "world", 103 | "you" 104 | ], 105 | "some_string": "or not", 106 | "actual_number": 5.2, 107 | "round": 3 108 | } 109 | } 110 | ] -------------------------------------------------------------------------------- /src/it/scala/io/pinecone/spark/pinecone/BatchUpsertExample.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.sql.{SaveMode, SparkSession} 5 | import org.scalatest.flatspec.AnyFlatSpec 6 | import org.scalatest.matchers.should 7 | 8 | class BatchUpsertExample extends AnyFlatSpec with should.Matchers { 9 | "Run" should "just work" in { 10 | val conf = new SparkConf() 11 | .setMaster("local[*]") 12 | val spark = SparkSession.builder().config(conf).getOrCreate() 13 | 14 | val df = spark.read 15 | .option("multiLine", value = true) 16 | .option("mode", "PERMISSIVE") 17 | .schema(COMMON_SCHEMA) 18 | .json("src/it/resources/sample1.jsonl") 19 | .repartition(2) 20 | 21 | df.count() should be(7) 22 | 23 | val pineconeOptions = Map( 24 | PineconeOptions.PINECONE_API_KEY_CONF -> System.getenv("PINECONE_API_KEY"), 25 | PineconeOptions.PINECONE_INDEX_NAME_CONF -> System.getenv("PINECONE_INDEX"), 26 | PineconeOptions.PINECONE_SOURCE_TAG_CONF -> System.getenv("PINECONE_SOURCE_TAG") 27 | ) 28 | 29 | df.write 30 | .format("io.pinecone.spark.pinecone.Pinecone") 31 | .options(pineconeOptions) 32 | .mode(SaveMode.Append) 33 | .save() 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/it/scala/io/pinecone/spark/pinecone/StreamUpsertExample.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.SparkSession 4 | 5 | object StreamUpsertExample { 6 | 7 | def main(args: Array[String]): Unit = { 8 | 9 | val spark = SparkSession.builder().appName("StreamUpsertExample") 10 | .config("spark.sql.shuffle.partitions", 3) 11 | .master("local") 12 | .getOrCreate() 13 | 14 | val lines = spark.readStream 15 | .option("multiLine", value = true) 16 | .option("mode", "PERMISSIVE") 17 | .schema(COMMON_SCHEMA) 18 | .json("src/it/resources/") 19 | 20 | val pineconeOptions = Map( 21 | PineconeOptions.PINECONE_API_KEY_CONF -> System.getenv("PINECONE_API_KEY"), 22 | PineconeOptions.PINECONE_INDEX_NAME_CONF -> System.getenv("PINECONE_INDEX"), 23 | 
PineconeOptions.PINECONE_SOURCE_TAG_CONF -> System.getenv("PINECONE_SOURCE_TAG") 24 | ) 25 | 26 | val upsert = lines 27 | .writeStream 28 | .format("io.pinecone.spark.pinecone.Pinecone") 29 | .options(pineconeOptions) 30 | .option("checkpointLocation", "path/to/checkpoint/dir") 31 | .outputMode("append") 32 | .start() 33 | 34 | upsert.awaitTermination() 35 | } 36 | } -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/Pinecone.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.connector.catalog.{Table, TableProvider} 4 | import org.apache.spark.sql.connector.expressions.Transform 5 | import org.apache.spark.sql.types.StructType 6 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 7 | 8 | import java.util 9 | 10 | case class Pinecone() extends TableProvider { 11 | override def inferSchema(options: CaseInsensitiveStringMap): StructType = 12 | COMMON_SCHEMA 13 | 14 | override def getTable( 15 | schema: StructType, 16 | partitioning: Array[Transform], 17 | properties: util.Map[String, String] 18 | ): Table = { 19 | val pineconeOptions = new PineconeOptions(new CaseInsensitiveStringMap(properties)) 20 | PineconeIndex(pineconeOptions) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeBatchWriter.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.connector.write.{ 4 | BatchWrite, 5 | DataWriterFactory, 6 | PhysicalWriteInfo, 7 | WriterCommitMessage 8 | } 9 | import org.slf4j.LoggerFactory 10 | 11 | case class PineconeBatchWriter(pineconeOptions: PineconeOptions) extends BatchWrite { 12 | private val log = LoggerFactory.getLogger(getClass) 13 | 14 | override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { 15 | PineconeDataWriterFactory(pineconeOptions) 16 | } 17 | 18 | override def commit(messages: Array[WriterCommitMessage]): Unit = { 19 | val totalVectorsWritten = messages.map(_.asInstanceOf[PineconeCommitMessage].vectorCount).sum 20 | 21 | log.info( 22 | s"""A total of $totalVectorsWritten vectors written to index "${pineconeOptions.indexName}"""" 23 | ) 24 | } 25 | 26 | override def abort(messages: Array[WriterCommitMessage]): Unit = {} 27 | 28 | // Pinecone is inherently a key-value store, so running an upsert operation with 29 | // vectors with the same ID due to speculative execution or other mechanisms will result 30 | // in the same end result. It is the pipeline's developer responsibility to ensure that the initial 31 | // data doesn't have any vectors with shared IDs, a case which might result in a different outcome 32 | // each run. 
33 | override def useCommitCoordinator(): Boolean = false 34 | 35 | override def toString: String = s"""PineconeBatchWriter(index="${pineconeOptions.indexName}")""" 36 | 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeDataWriter.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import io.pinecone.proto.{SparseValues, UpsertRequest, Vector => PineconeVector} 4 | import io.pinecone.clients.{Pinecone => PineconeClient} 5 | import io.pinecone.configs.{PineconeConfig, PineconeConnection} 6 | import org.apache.spark.sql.catalyst.InternalRow 7 | import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} 8 | import org.slf4j.LoggerFactory 9 | 10 | import scala.collection.JavaConverters._ 11 | import scala.collection.mutable 12 | 13 | case class PineconeDataWriter( 14 | partitionId: Int, 15 | taskId: Long, 16 | options: PineconeOptions 17 | ) extends DataWriter[InternalRow] 18 | with Serializable { 19 | private val log = LoggerFactory.getLogger(getClass) 20 | private val config: PineconeConfig = new PineconeConfig(options.apiKey, options.sourceTag) 21 | private val pinecone: PineconeClient = new PineconeClient.Builder(options.apiKey).build() 22 | config.setHost(pinecone.describeIndex(options.indexName).getHost) 23 | private val conn: PineconeConnection = new PineconeConnection(config) 24 | private var upsertBuilderMap = mutable.Map[String, UpsertRequest.Builder]() 25 | private var currentVectorsInBatch = 0 26 | private var totalVectorSize = 0 27 | 28 | private val maxBatchSize = options.maxBatchSize 29 | 30 | // Reporting vars 31 | var totalVectorsWritten = 0 32 | 33 | override def write(record: InternalRow): Unit = { 34 | try { 35 | val id = record.getUTF8String(0).toString 36 | val namespace = if (!record.isNullAt(1)) record.getUTF8String(1).toString else "" 37 | val values = record.getArray(2).toFloatArray().map(float2Float).toIterable 38 | 39 | if (id.length > MAX_ID_LENGTH) { 40 | throw VectorIdTooLongException(id) 41 | } 42 | 43 | val vectorBuilder = PineconeVector 44 | .newBuilder() 45 | .setId(id) 46 | 47 | if (values.nonEmpty) { 48 | vectorBuilder.addAllValues(values.asJava) 49 | } 50 | 51 | if (!record.isNullAt(3)) { 52 | val metadata = record.getUTF8String(3).toString 53 | val metadataStruct = parseAndValidateMetadata(id, metadata) 54 | vectorBuilder.setMetadata(metadataStruct) 55 | } 56 | 57 | if (!record.isNullAt(4)) { 58 | val sparseVectorStruct = record.getStruct(4, 2) 59 | if (!sparseVectorStruct.isNullAt(0) && !sparseVectorStruct.isNullAt(1)) { 60 | val sparseIndices = sparseVectorStruct.getArray(0).toLongArray() 61 | 62 | sparseIndices.find { index => 63 | if (index < 0 || index > 0xFFFFFFFFL) { 64 | throw new IllegalArgumentException(s"Sparse index $index is out of range for unsigned 32-bit integers.") 65 | } 66 | false 67 | } 68 | 69 | val sparseId = sparseIndices.map(_.toInt).map(int2Integer).toIterable 70 | val sparseValues = sparseVectorStruct.getArray(1).toFloatArray().map(float2Float).toIterable 71 | 72 | val sparseDataBuilder = SparseValues.newBuilder() 73 | .addAllIndices(sparseId.asJava) 74 | .addAllValues(sparseValues.asJava) 75 | 76 | vectorBuilder.setSparseValues(sparseDataBuilder.build()) 77 | } 78 | } 79 | 80 | val vector = vectorBuilder 81 | .build() 82 | 83 | if ((currentVectorsInBatch == maxBatchSize) || 84 | (totalVectorSize + vector.getSerializedSize >= 
MAX_REQUEST_SIZE) // If the vector will push the request over the size limit 85 | ) { 86 | flushBatchToIndex() 87 | } 88 | 89 | val builder = upsertBuilderMap 90 | .getOrElseUpdate( 91 | namespace, { 92 | UpsertRequest.newBuilder().setNamespace(namespace) 93 | } 94 | ) 95 | 96 | builder.addVectors(vector) 97 | upsertBuilderMap.update(namespace, builder) 98 | currentVectorsInBatch += 1 99 | totalVectorSize += vector.getSerializedSize 100 | } catch { 101 | case e: NullPointerException => 102 | log.error(s"Null values in rows: ${e.getMessage}") 103 | throw NullValueException("") 104 | } 105 | } 106 | 107 | override def commit(): WriterCommitMessage = { 108 | flushBatchToIndex() 109 | 110 | log.debug(s"taskId=$taskId partitionId=$partitionId totalVectorsUpserted=$totalVectorsWritten") 111 | 112 | PineconeCommitMessage(totalVectorsWritten) 113 | } 114 | 115 | override def abort(): Unit = { 116 | log.error( 117 | s"PineconeDataWriter(taskId=$taskId, partitionId=$partitionId) encountered an unhandled error and is shutting down" 118 | ) 119 | cleanup() 120 | } 121 | 122 | override def close(): Unit = { 123 | cleanup() 124 | } 125 | 126 | /** Frees up all resources before the Writer is shutdown 127 | */ 128 | private def cleanup(): Unit = { 129 | conn.close() 130 | } 131 | 132 | /** Sends all data pinecone and resets the Writer's state. 133 | */ 134 | private def flushBatchToIndex(): Unit = { 135 | log.debug(s"Sending ${upsertBuilderMap.size} requests to Pinecone index") 136 | for (builder <- upsertBuilderMap.values) { 137 | val request = builder.build() 138 | val response = conn.getBlockingStub.upsert(request) 139 | log.debug(s"Upserted ${response.getUpsertedCount} vectors to ${options.indexName}") 140 | totalVectorsWritten += response.getUpsertedCount 141 | } 142 | 143 | log.debug(s"Upsert operation was successful") 144 | 145 | upsertBuilderMap = mutable.Map() 146 | currentVectorsInBatch = 0 147 | totalVectorSize = 0 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeDataWriterFactory.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.catalyst.InternalRow 4 | import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} 5 | import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory 6 | 7 | case class PineconeDataWriterFactory(pineconeOptions: PineconeOptions) 8 | extends DataWriterFactory 9 | with StreamingDataWriterFactory 10 | with Serializable { 11 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { 12 | PineconeDataWriter(partitionId, taskId, pineconeOptions) 13 | } 14 | 15 | override def createWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { 16 | PineconeDataWriter(partitionId, taskId, pineconeOptions) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeIndex.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCapability} 4 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} 5 | import org.apache.spark.sql.types.StructType 6 | 7 | import scala.collection.JavaConverters._ 8 | import java.util 9 | import 
scala.collection.immutable.HashSet 10 | 11 | case class PineconeIndex(pineconeOptions: PineconeOptions) extends SupportsWrite { 12 | override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = 13 | PineconeWriteBuilder(pineconeOptions) 14 | 15 | override def schema(): StructType = COMMON_SCHEMA 16 | 17 | override def capabilities(): util.Set[TableCapability] = Set( 18 | TableCapability.BATCH_WRITE, 19 | TableCapability.STREAMING_WRITE 20 | ).asJava 21 | 22 | override def name(): String = pineconeOptions.indexName 23 | } 24 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeOptions.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import PineconeOptions._ 4 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 5 | 6 | class PineconeOptions(config: CaseInsensitiveStringMap) extends Serializable { 7 | private val DEFAULT_BATCH_SIZE = 100 8 | 9 | val maxBatchSize: Int = 10 | config 11 | .getInt(PINECONE_BATCH_SIZE_CONF, DEFAULT_BATCH_SIZE) 12 | 13 | val apiKey: String = getKey(PINECONE_API_KEY_CONF, config) 14 | val indexName: String = getKey(PINECONE_INDEX_NAME_CONF, config) 15 | val sourceTag: String = getSourceTag(config) 16 | 17 | private def getKey(key: String, config: CaseInsensitiveStringMap): String = { 18 | Option(config.get(key)).getOrElse( 19 | throw new RuntimeException(s"Missing required parameter $key") 20 | ) 21 | } 22 | 23 | private def getSourceTag(config: CaseInsensitiveStringMap): String = { 24 | val value = Option(config.get(PINECONE_SOURCE_TAG_CONF)).getOrElse("") 25 | s"spark_$value" 26 | } 27 | } 28 | 29 | object PineconeOptions { 30 | val PINECONE_BATCH_SIZE_CONF: String = "pinecone.batchSize" 31 | val PINECONE_API_KEY_CONF: String = "pinecone.apiKey" 32 | val PINECONE_INDEX_NAME_CONF: String = "pinecone.indexName" 33 | val PINECONE_SOURCE_TAG_CONF: String = "pinecone.sourceTag" 34 | } 35 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeStreamingWriter.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMessage} 4 | import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} 5 | import org.slf4j.LoggerFactory 6 | 7 | case class PineconeStreamingWriter(pineconeOptions: PineconeOptions) extends StreamingWrite { 8 | private val log = LoggerFactory.getLogger(getClass) 9 | 10 | override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = { 11 | PineconeDataWriterFactory(pineconeOptions) 12 | } 13 | 14 | override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { 15 | val totalVectorsWritten = messages.map(_.asInstanceOf[PineconeCommitMessage].vectorCount).sum 16 | 17 | log.info( 18 | s"""Epoch $epochId: A total of $totalVectorsWritten vectors written to index "${pineconeOptions.indexName}"""" 19 | ) 20 | } 21 | 22 | override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { 23 | log.error(s"Epoch $epochId: Write operation aborted") 24 | } 25 | 26 | override def toString: String = s"PineconeStreamingWriter(index=${pineconeOptions.indexName})" 27 | } -------------------------------------------------------------------------------- 
/src/main/scala/io/pinecone/spark/pinecone/PineconeWrite.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.connector.write.{Write, BatchWrite} 4 | import org.apache.spark.sql.connector.write.streaming.StreamingWrite 5 | 6 | case class PineconeWrite(pineconeOptions: PineconeOptions) extends Write with Serializable { 7 | override def toBatch: BatchWrite = PineconeBatchWriter(pineconeOptions) 8 | override def toStreaming: StreamingWrite = PineconeStreamingWriter(pineconeOptions) 9 | } 10 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/PineconeWriteBuilder.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} 4 | 5 | case class PineconeWriteBuilder(pineconeOptions: PineconeOptions) 6 | extends WriteBuilder 7 | with Serializable { 8 | override def build: Write = PineconeWrite(pineconeOptions) 9 | } 10 | -------------------------------------------------------------------------------- /src/main/scala/io/pinecone/spark/pinecone/package.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper 4 | import com.google.protobuf.{ListValue, Struct, Value} 5 | import org.apache.spark.sql.connector.write.WriterCommitMessage 6 | import org.apache.spark.sql.types.{ArrayType, FloatType, LongType, StringType, StructField, StructType} 7 | 8 | import scala.collection.JavaConverters._ 9 | 10 | package object pinecone { 11 | val COMMON_SCHEMA: StructType = 12 | new StructType() 13 | .add("id", StringType, nullable = false) 14 | .add("namespace", StringType, nullable = true) 15 | .add("values", ArrayType(FloatType, containsNull = false), nullable = false) 16 | .add("metadata", StringType, nullable = true) 17 | .add("sparse_values", StructType( 18 | Array( 19 | StructField("indices", ArrayType(LongType, containsNull = false), nullable = false), 20 | StructField("values", ArrayType(FloatType, containsNull = false), nullable = false) 21 | ) 22 | ), nullable = true) 23 | 24 | private[pinecone] val MAX_ID_LENGTH = 512 25 | private[pinecone] val MAX_METADATA_SIZE = 40 * math.pow(10, 3) // 40KB 26 | private[pinecone] val MAX_REQUEST_SIZE = 2 * math.pow(10, 6) // 2MB 27 | 28 | /** Parses the metadata of a vector from a JSON string representation to a ProtoBuf struct 29 | * 30 | * @param vectorId 31 | * the ID of the vector, used for error reporting purposes 32 | * @param metadataStr 33 | * \- the JSON string representing the vector's metadata 34 | * @return 35 | */ 36 | private[pinecone] def parseAndValidateMetadata(vectorId: String, metadataStr: String): Struct = { 37 | val structBuilder = Struct.newBuilder() 38 | val mapper = new ObjectMapper() 39 | 40 | val jsonTree = mapper.readTree(metadataStr) 41 | 42 | for (jsonField <- jsonTree.fields().asScala) { 43 | val key = jsonField.getKey 44 | val value = jsonField.getValue 45 | 46 | if (value.isTextual) { 47 | structBuilder.putFields(key, Value.newBuilder().setStringValue(value.asText()).build()) 48 | } else if (value.isNumber) { 49 | structBuilder.putFields(key, Value.newBuilder().setNumberValue(value.floatValue()).build()) 50 | } else if (value.isBoolean) { 51 | structBuilder.putFields(key, 
Value.newBuilder().setBoolValue(value.booleanValue()).build()) 52 | } else if (value.isArray && value.elements().asScala.toArray.forall(_.isTextual)) { 53 | val arrayElements = value.elements().asScala.toArray 54 | val listValueBuilder = ListValue.newBuilder() 55 | listValueBuilder.addAllValues( 56 | arrayElements 57 | .map(element => Value.newBuilder().setStringValue(element.textValue()).build()) 58 | .toIterable 59 | .asJava 60 | ) 61 | 62 | structBuilder.putFields( 63 | key, 64 | Value.newBuilder().setListValue(listValueBuilder.build()).build() 65 | ) 66 | } else { 67 | throw InvalidVectorMetadataException(vectorId, key) 68 | } 69 | } 70 | 71 | val finalStruct = structBuilder.build() 72 | 73 | if (finalStruct.getSerializedSize >= MAX_METADATA_SIZE) { 74 | throw PineconeMetadataTooLarge(vectorId) 75 | } 76 | 77 | finalStruct 78 | } 79 | 80 | case class PineconeCommitMessage(vectorCount: Int) extends WriterCommitMessage 81 | 82 | trait PineconeException extends Exception { 83 | def getMessage: String 84 | 85 | def vectorId: String 86 | } 87 | 88 | case class InvalidVectorMetadataException(vectorId: String, jsonKey: String) 89 | extends PineconeException { 90 | override def getMessage: String = 91 | s"Vector with ID '$vectorId' has invalid metadata field '$jsonKey'. Please refer to the Pinecone.io docs for a longer explanation." 92 | } 93 | 94 | case class PineconeMetadataTooLarge(vectorId: String) extends PineconeException { 95 | override def getMessage: String = 96 | s"Metadata for vector with ID '$vectorId' exceeded the maximum metadata size of $MAX_METADATA_SIZE bytes" 97 | } 98 | 99 | case class VectorIdTooLongException(vectorId: String) extends PineconeException { 100 | override def getMessage: String = 101 | s"Vector with ID starting with ${vectorId.substring(0, 8)}. Must be 512 characters or less. actual: ${vectorId.length}" 102 | } 103 | 104 | case class NullValueException(vectorId: String) extends PineconeException { 105 | override def getMessage: String = "Null id or value column found in row. Please ensure id and values are not null." 
106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput1.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "values": [ 4 | 1, 5 | 2, 6 | 3 7 | ] 8 | } 9 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput2.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1" 4 | } 5 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput3.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | } 11 | } 12 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput4.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | "values": [ 11 | 100, 12 | 101 13 | ] 14 | } 15 | } 16 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput5.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | "indices": [ 11 | 1, 12 | 2 13 | ] 14 | } 15 | } 16 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput6.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | "indices": [ 11 | null 12 | ], 13 | "values": [ 14 | 1 15 | ] 16 | } 17 | } 18 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput7.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | "indices": [ 11 | 1 12 | ], 13 | "values": [ 14 | null 15 | ] 16 | } 17 | } 18 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput8.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | "indices": [ 11 | -1 12 | ], 13 | "values": [ 14 | 100 15 | ] 16 | } 17 | } 18 | ] -------------------------------------------------------------------------------- /src/test/resources/invalidUpsertInput9.jsonl: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "v1", 4 | "values": [ 5 | 3, 6 | 2, 7 | 1 8 | ], 9 | "sparse_values": { 10 | "indices": [ 11 | 4294967296 12 | ], 13 | "values": [ 14 | 100 15 | ] 16 | } 17 | } 18 | ] -------------------------------------------------------------------------------- /src/test/scala/io/pinecone/spark/pinecone/ParseCommonSchemaTest.scala: -------------------------------------------------------------------------------- 1 | package io.pinecone.spark.pinecone 2 | 3 | import org.apache.spark.sql.{SaveMode, SparkSession} 4 | import org.scalatest.flatspec.AnyFlatSpec 5 
| import org.scalatest.matchers.should 6 | 7 | class ParseCommonSchemaTest extends AnyFlatSpec with should.Matchers { 8 | private val spark: SparkSession = SparkSession.builder() 9 | .appName("SchemaValidationTest") 10 | .master("local[2]") 11 | .getOrCreate() 12 | 13 | private val inputFilePath = System.getProperty("user.dir") + "/src/test/resources" 14 | 15 | private val apiKey = "some_api_key" 16 | private val indexName = "step-test" 17 | 18 | private val pineconeOptions: Map[String, String] = Map( 19 | PineconeOptions.PINECONE_API_KEY_CONF -> apiKey, 20 | PineconeOptions.PINECONE_INDEX_NAME_CONF -> indexName 21 | ) 22 | 23 | def afterAll(): Unit = { 24 | if (spark != null) { 25 | spark.stop() 26 | } 27 | } 28 | 29 | def testInvalidJSON(file: String, testName: String): Unit = { 30 | it should testName in { 31 | val sparkException = intercept[org.apache.spark.SparkException] { 32 | val df = spark.read 33 | .option("multiLine", value = true) 34 | .option("mode", "PERMISSIVE") 35 | .schema(COMMON_SCHEMA) 36 | .json(file) 37 | .repartition(2) 38 | 39 | df.write 40 | .options(pineconeOptions) 41 | .format("io.pinecone.spark.pinecone.Pinecone") 42 | .mode(SaveMode.Append) 43 | .save() 44 | } 45 | sparkException 46 | .getCause 47 | .toString should include("java.lang.NullPointerException: Null value appeared in non-nullable field:") 48 | } 49 | } 50 | 51 | def testInvalidSparseIndices(file: String, testName: String): Unit = { 52 | it should testName in { 53 | val sparkException = intercept[org.apache.spark.SparkException] { 54 | val df = spark.read 55 | .option("multiLine", value = true) 56 | .option("mode", "PERMISSIVE") 57 | .schema(COMMON_SCHEMA) 58 | .json(file) 59 | .repartition(2) 60 | 61 | df.write 62 | .options(pineconeOptions) 63 | .format("io.pinecone.spark.pinecone.Pinecone") 64 | .mode(SaveMode.Append) 65 | .save() 66 | } 67 | sparkException.getCause.toString should include("java.lang.IllegalArgumentException:") 68 | sparkException.getCause.toString should include("is out of range for unsigned 32-bit integers") 69 | } 70 | } 71 | 72 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput1.jsonl", 73 | "throw exception for missing id") 74 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput2.jsonl", 75 | "throw exception for missing values") 76 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput3.jsonl", 77 | "throw exception for missing sparse vector indices and values if sparse_values is defined") 78 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput4.jsonl", 79 | "throw exception for missing sparse vector indices if sparse_values and its values are defined") 80 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput5.jsonl", 81 | "throw exception for missing sparse vector values if sparse_values and its indices are defined") 82 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput6.jsonl", 83 | "throw exception for null in sparse vector indices") 84 | testInvalidJSON(s"$inputFilePath/invalidUpsertInput7.jsonl", 85 | "throw exception for null in sparse vector values") 86 | testInvalidSparseIndices(s"$inputFilePath/invalidUpsertInput8.jsonl", 87 | "throw exception for invalid sparse vector indices") 88 | testInvalidSparseIndices(s"$inputFilePath/invalidUpsertInput9.jsonl", 89 | "throw exception for invalid sparse vector indices2") 90 | } 91 | -------------------------------------------------------------------------------- /src/test/scala/io/pinecone/spark/pinecone/ParseMetadataSpec.scala: -------------------------------------------------------------------------------- 1 | package 
io.pinecone.spark.pinecone 2 | 3 | import com.fasterxml.jackson.core.JsonParseException 4 | import org.scalatest.flatspec.AnyFlatSpec 5 | import org.scalatest.matchers.should 6 | 7 | class ParseMetadataSpec extends AnyFlatSpec with should.Matchers { 8 | val mockVectorId = "id" 9 | 10 | it should "throw an InvalidVectorMetadataException if the vector has invalid metadata (list of ints)" in { 11 | val metadataWithIntList = 12 | s""" 13 | |{ "field": [1, 2] } 14 | |""".stripMargin 15 | 16 | a[InvalidVectorMetadataException] should be thrownBy { 17 | parseAndValidateMetadata(mockVectorId, metadataWithIntList) 18 | } 19 | } 20 | 21 | it should "throw an InvalidVectorMetadataException if the vector has invalid metadata (list with mixed types)" in { 22 | val metadataWithMixedList = 23 | s""" 24 | |{ "field": [1, "2"] } 25 | |""".stripMargin 26 | 27 | a[InvalidVectorMetadataException] should be thrownBy { 28 | parseAndValidateMetadata(mockVectorId, metadataWithMixedList) 29 | } 30 | } 31 | 32 | it should "throw an error when the JSON string is invalid JSON (no closing braces)" in { 33 | val partialJson = 34 | """ 35 | |{ "hello": "world" 36 | |""".stripMargin 37 | a[JsonParseException] should be thrownBy { 38 | parseAndValidateMetadata(mockVectorId, partialJson) 39 | } 40 | } 41 | 42 | it should "throw an error when the JSON string is invalid JSON (use of single quotes)" in { 43 | val singleQuotedJson = 44 | """ 45 | |{'hello': "world"} 46 | |""".stripMargin 47 | a[JsonParseException] should be thrownBy { 48 | parseAndValidateMetadata(mockVectorId, singleQuotedJson) 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /version.sbt: -------------------------------------------------------------------------------- 1 | ThisBuild / version := "1.2.0" 2 | --------------------------------------------------------------------------------
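
A minimal batch-upsert sketch, modeled on the write path exercised in ParseCommonSchemaTest above and placed in the connector's own package (as the integration examples under src/it are) so it can reference COMMON_SCHEMA and PineconeOptions directly. The object name, API key lookup, index name, and input path are placeholders for illustration, not values taken from the repository.

package io.pinecone.spark.pinecone

import org.apache.spark.sql.{SaveMode, SparkSession}

object BatchUpsertSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("pinecone-batch-upsert-sketch")
      .master("local[2]")
      .getOrCreate()

    // Placeholder credentials: supply a real API key and index name for your project.
    val pineconeOptions: Map[String, String] = Map(
      PineconeOptions.PINECONE_API_KEY_CONF -> sys.env.getOrElse("PINECONE_API_KEY", "your-api-key"),
      PineconeOptions.PINECONE_INDEX_NAME_CONF -> "your-index-name"
    )

    // Each record must follow COMMON_SCHEMA: a non-null id and values array, with optional
    // namespace, sparse_values, and a metadata JSON string whose fields are strings, numbers,
    // booleans, or lists of strings (anything else is rejected by parseAndValidateMetadata).
    val df = spark.read
      .option("multiLine", value = true)
      .option("mode", "PERMISSIVE")
      .schema(COMMON_SCHEMA)
      .json("/path/to/vectors.jsonl") // placeholder path

    df.write
      .options(pineconeOptions)
      .format("io.pinecone.spark.pinecone.Pinecone")
      .mode(SaveMode.Append)
      .save()

    spark.stop()
  }
}

For end-to-end runs against a live index, see BatchUpsertExample.scala and StreamUpsertExample.scala under src/it/scala/io/pinecone/spark/pinecone/.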