├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── build.sbt ├── project ├── build.properties └── plugins.sbt ├── src ├── main │ └── scala │ │ └── com │ │ └── github │ │ └── jparkie │ │ └── spark │ │ └── elasticsearch │ │ ├── SparkEsBulkWriter.scala │ │ ├── SparkEsMapper.scala │ │ ├── SparkEsSerializer.scala │ │ ├── conf │ │ ├── SparkEsMapperConf.scala │ │ ├── SparkEsTransportClientConf.scala │ │ └── SparkEsWriteConf.scala │ │ ├── sql │ │ ├── SparkEsDataFrameFunctions.scala │ │ ├── SparkEsDataFrameMapper.scala │ │ ├── SparkEsDataFrameSerializer.scala │ │ └── package.scala │ │ ├── transport │ │ ├── SparkEsTransportClientManager.scala │ │ └── SparkEsTransportClientProxy.scala │ │ └── util │ │ ├── SparkEsConfParam.scala │ │ └── SparkEsException.scala └── test │ └── scala │ └── com │ └── github │ └── jparkie │ └── spark │ └── elasticsearch │ ├── ElasticSearchServer.scala │ ├── SparkEsBulkWriterSpec.scala │ ├── conf │ ├── SparkEsMapperConfSpec.scala │ ├── SparkEsTransportClientConfSpec.scala │ └── SparkEsWriteConfSpec.scala │ ├── sql │ ├── PackageSpec.scala │ ├── SparkEsDataFrameMapperSpec.scala │ └── SparkEsDataFrameSerializerSpec.scala │ └── transport │ ├── SparkEsTransportClientManagerSpec.scala │ └── SparkEsTransportClientProxySpec.scala └── version.sbt /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/intellij,scala 2 | 3 | ### Intellij ### 4 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 5 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 6 | 7 | # Override All: 8 | .idea/ 9 | 10 | # User-specific stuff: 11 | .idea/workspace.xml 12 | .idea/tasks.xml 13 | .idea/dictionaries 14 | .idea/vcs.xml 15 | .idea/jsLibraryMappings.xml 16 | 17 | # Sensitive or high-churn files: 18 | .idea/dataSources.ids 19 | .idea/dataSources.xml 20 | .idea/sqlDataSources.xml 21 | .idea/dynamic.xml 22 | .idea/uiDesigner.xml 23 | 24 | # Gradle: 25 | .idea/gradle.xml 26 | .idea/libraries 27 | 28 | # Mongo Explorer plugin: 29 | .idea/mongoSettings.xml 30 | 31 | ## File-based project format: 32 | *.iws 33 | 34 | ## Plugin-specific files: 35 | 36 | # IntelliJ 37 | /out/ 38 | 39 | # mpeltonen/sbt-idea plugin 40 | .idea_modules/ 41 | 42 | # JIRA plugin 43 | atlassian-ide-plugin.xml 44 | 45 | # Crashlytics plugin (for Android Studio and IntelliJ) 46 | com_crashlytics_export_strings.xml 47 | crashlytics.properties 48 | crashlytics-build.properties 49 | fabric.properties 50 | 51 | 52 | ### Scala ### 53 | *.class 54 | *.log 55 | 56 | # sbt specific 57 | .cache 58 | .history 59 | .lib/ 60 | dist/* 61 | target/ 62 | lib_managed/ 63 | src_managed/ 64 | project/boot/ 65 | project/plugins/project/ 66 | 67 | # Scala-IDE specific 68 | .scala_dependencies 69 | .worksheet -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | 3 | scala: 4 | - 2.10.6 5 | - 2.11.7 6 | 7 | jdk: 8 | - oraclejdk7 9 | - oraclejdk8 10 | 11 | sudo: false 12 | 13 | cache: 14 | directories: 15 | - $HOME/.ivy2 16 | 17 | script: 18 | - sbt clean coverage test 19 | 20 | after_success: 21 | - bash <(curl -s https://codecov.io/bash) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | 
Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | 177 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spark2Elasticsearch 2 | 3 | Spark Library for Bulk Loading into Elasticsearch 4 | 5 | [![Build Status](https://travis-ci.org/jparkie/Spark2Elasticsearch.svg?branch=master)](https://travis-ci.org/jparkie/Spark2Elasticsearch) 6 | 7 | ## Requirements 8 | 9 | Spark2Elasticsearch supports Spark 1.4 and above. 10 | 11 | | Spark2Elasticsearch Version | Elasticsearch Version | 12 | | --------------------------- | --------------------- | 13 | | `2.0.X` | `2.0.X` | 14 | | `2.1.X` | `2.1.X` | 15 | 16 | ## Downloads 17 | 18 | #### SBT 19 | ```scala 20 | libraryDependencies += "com.github.jparkie" %% "spark2elasticsearch" % "2.0.2" 21 | ``` 22 | 23 | Or: 24 | 25 | ```scala 26 | libraryDependencies += "com.github.jparkie" %% "spark2elasticsearch" % "2.1.2" 27 | ``` 28 | 29 | Add the following resolvers if needed: 30 | 31 | ```scala 32 | resolvers += "Sonatype OSS Releases" at "https://oss.sonatype.org/content/repositories/releases" 33 | resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots" 34 | ``` 35 | 36 | #### Maven 37 | ```xml 38 | <dependency> 39 | <groupId>com.github.jparkie</groupId> 40 | <artifactId>spark2elasticsearch_2.10</artifactId> 41 | <version>x.y.z-SNAPSHOT</version> 42 | </dependency> 43 | ``` 44 | 45 | It is planned for Spark2Elasticsearch to be available on the following: 46 | - http://spark-packages.org/ 47 | 48 | ## Features 49 | - Utilizes the Elasticsearch Java API with a `TransportClient` to bulk load data from a `DataFrame` into Elasticsearch. 50 | 51 | ## Usage 52 | 53 | ### Bulk Loading into Elasticsearch 54 | 55 | ```scala 56 | // Import the following to have access to the `bulkLoadToEs()` function. 57 | import com.github.jparkie.spark.elasticsearch.sql._ 58 | 59 | val sparkConf = new SparkConf() 60 | val sc = SparkContext.getOrCreate(sparkConf) 61 | val sqlContext = SQLContext.getOrCreate(sc) 62 | 63 | val df = sqlContext.read.parquet("") 64 | 65 | // Specify the `index` and the `type` to write. 66 | df.bulkLoadToEs( 67 | esIndex = "twitter", 68 | esType = "tweets" 69 | ) 70 | ``` 71 | 72 | For more, refer to [SparkEsDataFrameFunctions.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameFunctions.scala). 73 | 74 | ## Configurations 75 | 76 | When adding configurations through spark-submit, prefix property names with `spark.`. 77 | 78 | ### SparkEsMapperConf 79 | 80 | For more, refer to [SparkEsMapperConf.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsMapperConf.scala). 81 | 82 | | Property Name | Default | Description | 83 | | ------------------------- |:-------:| ------------| 84 | | `es.mapping.id` | None | The document field/property name containing the document id. | 85 | | `es.mapping.parent` | None | The document field/property name containing the document parent. To specify a constant, use the `<CONSTANT>` format. | 86 | | `es.mapping.version` | None | The document field/property name containing the document version. To specify a constant, use the `<CONSTANT>` format. | 87 | | `es.mapping.version.type` | None | Indicates the type of versioning used. http://www.elastic.co/guide/en/elasticsearch/reference/2.0/docs-index_.html#_version_types If es.mapping.version is undefined (default), its value is unspecified.
If es.mapping.version is specified, its value becomes external. | 88 | | `es.mapping.routing` | None | The document field/property name containing the document routing. To specify a constant, use the `<CONSTANT>` format. | 89 | | `es.mapping.ttl` | None | The document field/property name containing the document time-to-live. To specify a constant, use the `<CONSTANT>` format. | 90 | | `es.mapping.timestamp` | None | The document field/property name containing the document timestamp. To specify a constant, use the `<CONSTANT>` format. | 91 | 92 | ### SparkEsTransportClientConf 93 | 94 | For more, refer to [SparkEsTransportClientConf.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsTransportClientConf.scala). 95 | 96 | | Property Name | Default | Description | 97 | | -------------------------------------------- |:----------:| ------------| 98 | | `es.nodes` | *Required* | The minimum set of hosts to connect to when establishing a client. A comma-separated list of `host` or `host:port` entries. | 99 | | `es.port` | 9300 | The port to connect to when establishing a client, for hosts that do not specify one. | 100 | | `es.cluster.name` | None | The name of the Elasticsearch cluster to connect to. | 101 | | `es.client.transport.sniff` | None | If set to true, the client discovers and connects to other nodes of the cluster. | 102 | | `es.client.transport.ignore_cluster_name` | None | Set to true to ignore cluster name validation of connected nodes. | 103 | | `es.client.transport.ping_timeout` | 5s | The time to wait for a ping response from a node. | 104 | | `es.client.transport.nodes_sampler_interval` | 5s | How often to sample / ping the nodes listed and connected. | 105 | 106 | ### SparkEsWriteConf 107 | 108 | For more, refer to [SparkEsWriteConf.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsWriteConf.scala). 109 | 110 | | Property Name | Default | Description | 111 | | ----------------------------- |:-------:| ------------| 112 | | `es.batch.size.entries` | 1000 | The number of IndexRequests to batch in one request. | 113 | | `es.batch.size.bytes` | 5 | The maximum size of a batch, in MB. | 114 | | `es.batch.concurrent.request` | 1 | The number of concurrent requests in flight. | 115 | | `es.batch.flush.timeout` | 10 | The maximum time in seconds to wait while closing a BulkProcessor. | 116 | 117 | ## Documentation 118 | 119 | Scaladocs are currently unavailable.
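In the meantime, the following end-to-end sketch ties together the usage and configuration sections above. All values shown (node address, cluster name, input path, index, and type) are illustrative; any of the `es.*` properties may instead be supplied through `spark-submit` with a `spark.` prefix.

```scala
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.SQLContext

// Import to enable the `bulkLoadToEs()` extension on DataFrame.
import com.github.jparkie.spark.elasticsearch.sql._

val sparkConf = new SparkConf()
  .set("es.nodes", "127.0.0.1:9300")       // Required: initial transport hosts.
  .set("es.cluster.name", "elasticsearch") // Optional: name of the target cluster.
  .set("es.batch.size.entries", "1000")    // Optional: IndexRequests per bulk request.

val sc = SparkContext.getOrCreate(sparkConf)
val sqlContext = SQLContext.getOrCreate(sc)

val df = sqlContext.read.parquet("/path/to/input.parquet")

df.bulkLoadToEs(
  esIndex = "twitter",
  esType = "tweets"
)
```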
120 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import com.typesafe.sbt.SbtScalariform 2 | import com.typesafe.sbt.SbtScalariform._ 3 | import scalariform.formatter.preferences._ 4 | 5 | /** 6 | * Organization: 7 | */ 8 | organization := "com.github.jparkie" 9 | organizationName := "jparkie" 10 | 11 | /** 12 | * Library Meta: 13 | */ 14 | name := "Spark2Elasticsearch" 15 | licenses := Seq(("Apache License, Version 2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))) 16 | 17 | /** 18 | * Scala: 19 | */ 20 | scalaVersion := "2.10.6" 21 | crossScalaVersions := Seq("2.10.6", "2.11.7") 22 | 23 | /** 24 | * Library Dependencies: 25 | */ 26 | 27 | // Exclusion Rules: 28 | val guavaRule = ExclusionRule("com.google.guava", "guava") 29 | val sparkNetworkCommonRule = ExclusionRule("org.apache.spark", "spark-network-common") 30 | 31 | // Versions: 32 | val SparkVersion = "1.4.1" 33 | val SparkTestVersion = "1.4.1_0.3.0" 34 | val ScalaTestVersion = "2.2.4" 35 | val ElasticsearchVersion = "2.0.2" 36 | val Slf4jVersion = "1.7.10" 37 | 38 | // Dependencies: 39 | val sparkCore = "org.apache.spark" %% "spark-core" % SparkVersion % "provided" excludeAll(sparkNetworkCommonRule, guavaRule) 40 | val sparkSql = "org.apache.spark" %% "spark-sql" % SparkVersion % "provided" excludeAll(sparkNetworkCommonRule, guavaRule) 41 | val sparkTest = "com.holdenkarau" %% "spark-testing-base" % SparkTestVersion % "test" 42 | val scalaTest = "org.scalatest" %% "scalatest" % ScalaTestVersion % "test" 43 | val elasticsearch = "org.elasticsearch" % "elasticsearch" % ElasticsearchVersion 44 | val slf4j = "org.slf4j" % "slf4j-api" % Slf4jVersion 45 | 46 | libraryDependencies ++= Seq(sparkCore, sparkSql, sparkTest, scalaTest, elasticsearch, slf4j) 47 | 48 | /** 49 | * Tests: 50 | */ 51 | parallelExecution in Test := false 52 | 53 | /** 54 | * Scalariform: 55 | */ 56 | SbtScalariform.scalariformSettings 57 | ScalariformKeys.preferences := FormattingPreferences() 58 | .setPreference(RewriteArrowSymbols, false) 59 | .setPreference(AlignParameters, true) 60 | .setPreference(AlignSingleLineCaseStatements, true) 61 | .setPreference(SpacesAroundMultiImports, true) 62 | 63 | /** 64 | * Scoverage: 65 | */ 66 | coverageEnabled in Test := true 67 | 68 | /** 69 | * Publishing to Sonatype: 70 | */ 71 | publishMavenStyle := true 72 | 73 | publishArtifact in Test := false 74 | 75 | publishTo := { 76 | val nexus = "https://oss.sonatype.org/" 77 | if (isSnapshot.value) 78 | Some("snapshots" at nexus + "content/repositories/snapshots") 79 | else 80 | Some("releases" at nexus + "service/local/staging/deploy/maven2") 81 | } 82 | 83 | pomExtra := { 84 | https://github.com/jparkie/Spark2Elasticsearch 85 | 86 | git@github.com:jparkie/Spark2Elasticsearch.git 87 | scm:git:git@github.com:jparkie/Spark2Elasticsearch.git 88 | 89 | 90 | 91 | jparkie 92 | Jacob Park 93 | https://github.com/jparkie 94 | 95 | 96 | } 97 | 98 | /** 99 | * Release: 100 | */ 101 | import ReleaseTransformations._ 102 | 103 | releasePublishArtifactsAction := PgpKeys.publishSigned.value 104 | 105 | releaseProcess := Seq[ReleaseStep]( 106 | checkSnapshotDependencies, 107 | inquireVersions, 108 | runTest, 109 | setReleaseVersion, 110 | commitReleaseVersion, 111 | tagRelease, 112 | publishArtifacts, 113 | setNextVersion, 114 | commitNextVersion, 115 | pushChanges 116 | ) 117 | 
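/**
 * Common invocations (a non-exhaustive sketch; the commands come from the plugins
 * and settings declared above and are not defined in this file):
 *
 *   sbt +test             - run the test suite for every entry in crossScalaVersions
 *   sbt +coverageReport   - generate Scoverage reports per Scala version
 *   sbt release           - run the sbt-release steps listed in releaseProcess
 */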
-------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 0.13.8 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" 2 | resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" 3 | resolvers += "Sonatype OSS Releases" at "https://oss.sonatype.org/service/local/staging/deploy/maven2" 4 | 5 | addSbtPlugin("org.scalariform" % "sbt-scalariform" % "1.6.0") 6 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") 7 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.2") 8 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.3.5") -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/SparkEsBulkWriter.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch 2 | 3 | import java.util.concurrent.TimeUnit 4 | 5 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsWriteConf 6 | import com.github.jparkie.spark.elasticsearch.util.SparkEsException 7 | import org.apache.spark.{ Logging, TaskContext } 8 | import org.elasticsearch.action.bulk.{ BulkProcessor, BulkRequest, BulkResponse } 9 | import org.elasticsearch.action.index.IndexRequest 10 | import org.elasticsearch.action.update.UpdateRequest 11 | import org.elasticsearch.client.Client 12 | import org.elasticsearch.common.unit.{ ByteSizeUnit, ByteSizeValue } 13 | 14 | class SparkEsBulkWriter[T]( 15 | esIndex: String, 16 | esType: String, 17 | esClient: () => Client, 18 | sparkEsSerializer: SparkEsSerializer[T], 19 | sparkEsMapper: SparkEsMapper[T], 20 | sparkEsWriteConf: SparkEsWriteConf 21 | ) extends Serializable with Logging { 22 | /** 23 | * Logs the executionId, number of requests, size, and latency of flushes. 
24 | */ 25 | class SparkEsBulkProcessorListener() extends BulkProcessor.Listener { 26 | override def beforeBulk(executionId: Long, request: BulkRequest): Unit = { 27 | logInfo(s"For executionId ($executionId), executing ${request.numberOfActions()} actions of estimate size ${request.estimatedSizeInBytes()} in bytes.") 28 | } 29 | 30 | override def afterBulk(executionId: Long, request: BulkRequest, response: BulkResponse): Unit = { 31 | logInfo(s"For executionId ($executionId), executed ${request.numberOfActions()} in ${response.getTookInMillis} milliseconds.") 32 | 33 | if (response.hasFailures) { 34 | throw new SparkEsException(response.buildFailureMessage()) 35 | } 36 | } 37 | 38 | override def afterBulk(executionId: Long, request: BulkRequest, failure: Throwable): Unit = { 39 | logError(s"For executionId ($executionId), BulkRequest failed.", failure) 40 | 41 | throw new SparkEsException(failure.getMessage, failure) 42 | } 43 | } 44 | 45 | private[elasticsearch] def logDuration(closure: () => Unit): Unit = { 46 | val localStartTime = System.nanoTime() 47 | 48 | closure() 49 | 50 | val localEndTime = System.nanoTime() 51 | 52 | val differenceTime = localEndTime - localStartTime 53 | logInfo(s"Elasticsearch Task completed in ${TimeUnit.MILLISECONDS.convert(differenceTime, TimeUnit.NANOSECONDS)} milliseconds.") 54 | } 55 | 56 | private[elasticsearch] def createBulkProcessor(): BulkProcessor = { 57 | val esBulkProcessorListener = new SparkEsBulkProcessorListener() 58 | val esBulkProcessor = BulkProcessor.builder(esClient(), esBulkProcessorListener) 59 | .setBulkActions(sparkEsWriteConf.bulkActions) 60 | .setBulkSize(new ByteSizeValue(sparkEsWriteConf.bulkSizeInMB, ByteSizeUnit.MB)) 61 | .setConcurrentRequests(sparkEsWriteConf.concurrentRequests) 62 | .build() 63 | 64 | esBulkProcessor 65 | } 66 | 67 | private[elasticsearch] def closeBulkProcessor(bulkProcessor: BulkProcessor): Unit = { 68 | val isClosed = bulkProcessor.awaitClose(sparkEsWriteConf.flushTimeoutInSeconds, TimeUnit.SECONDS) 69 | if (isClosed) { 70 | logInfo("Closed Elasticsearch Bulk Processor.") 71 | } else { 72 | logError("Elasticsearch Bulk Processor failed to close.") 73 | } 74 | } 75 | 76 | private[elasticsearch] def applyMappings(currentRow: T, indexRequest: IndexRequest): Unit = { 77 | sparkEsMapper.extractMappingId(currentRow).foreach(indexRequest.id) 78 | sparkEsMapper.extractMappingParent(currentRow).foreach(indexRequest.parent) 79 | sparkEsMapper.extractMappingVersion(currentRow).foreach(indexRequest.version) 80 | sparkEsMapper.extractMappingVersionType(currentRow).foreach(indexRequest.versionType) 81 | sparkEsMapper.extractMappingRouting(currentRow).foreach(indexRequest.routing) 82 | sparkEsMapper.extractMappingTTLInMillis(currentRow).foreach(indexRequest.ttl(_)) 83 | sparkEsMapper.extractMappingTimestamp(currentRow).foreach(indexRequest.timestamp) 84 | } 85 | 86 | /** 87 | * Upserts T to Elasticsearch by establishing a TransportClient and BulkProcessor. 88 | * 89 | * @param taskContext The TaskContext provided by the Spark DAGScheduler. 90 | * @param data The set of T to persist. 
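 *
 * A scheduling sketch, mirroring how SparkEsDataFrameFunctions drives this writer
 * (`sparkEsWriter` and `dataFrame` are assumed to be in scope):
 * {{{
 *   sparkContext.runJob(dataFrame.rdd, sparkEsWriter.write _)
 * }}}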
91 | */ 92 | def write(taskContext: TaskContext, data: Iterator[T]): Unit = logDuration { () => 93 | val esBulkProcessor = createBulkProcessor() 94 | 95 | for (currentRow <- data) { 96 | val currentIndexRequest = new IndexRequest(esIndex, esType) 97 | .source(sparkEsSerializer.write(currentRow)) 98 | 99 | applyMappings(currentRow, currentIndexRequest) 100 | 101 | val currentId = currentIndexRequest.id() 102 | val currentParent = currentIndexRequest.parent() 103 | val currentVersion = currentIndexRequest.version() 104 | val currentVersionType = currentIndexRequest.versionType() 105 | val currentRouting = currentIndexRequest.routing() 106 | 107 | val currentUpsertRequest = new UpdateRequest(esIndex, esType, currentId) 108 | .parent(currentParent) 109 | .version(currentVersion) 110 | .versionType(currentVersionType) 111 | .routing(currentRouting) 112 | .doc(currentIndexRequest) 113 | .docAsUpsert(true) 114 | 115 | esBulkProcessor.add(currentUpsertRequest) 116 | } 117 | 118 | closeBulkProcessor(esBulkProcessor) 119 | } 120 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/SparkEsMapper.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch 2 | 3 | import org.elasticsearch.index.VersionType 4 | 5 | /** 6 | * Extracts mappings from a T for an IndexRequest. 7 | * 8 | * @tparam T Object to extract mappings. 9 | */ 10 | trait SparkEsMapper[T] extends Serializable { 11 | /** 12 | * Extracts the document field/property name containing the document id. 13 | * 14 | * @param value Object to extract mappings. 15 | * @return The document field/property name containing the document id. 16 | */ 17 | def extractMappingId(value: T): Option[String] 18 | 19 | /** 20 | * Extracts the document field/property name containing the document parent. 21 | * 22 | * @param value Object to extract mappings. 23 | * @return The document field/property name containing the document parent. 24 | */ 25 | def extractMappingParent(value: T): Option[String] 26 | 27 | /** 28 | * Extracts the document field/property name containing the document version. 29 | * 30 | * @param value Object to extract mappings. 31 | * @return The document field/property name containing the document version. 32 | */ 33 | def extractMappingVersion(value: T): Option[Long] 34 | 35 | /** 36 | * Extracts the type of versioning used. 37 | * 38 | * @param value Object to extract mappings. 39 | * @return The type of versioning used.. 40 | */ 41 | def extractMappingVersionType(value: T): Option[VersionType] 42 | 43 | /** 44 | * Extracts the document field/property name containing the document routing. 45 | * 46 | * @param value Object to extract mappings. 47 | * @return The document field/property name containing the document routing. 48 | */ 49 | def extractMappingRouting(value: T): Option[String] 50 | 51 | /** 52 | * Extracts the document field/property name containing the document time-to-live. 53 | * 54 | * @param value Object to extract mappings. 55 | * @return The document field/property name containing the document time-to-live. 56 | */ 57 | def extractMappingTTLInMillis(value: T): Option[Long] 58 | 59 | /** 60 | * Extracts the document field/property name containing the document timestamp. 61 | * 62 | * @param value Object to extract mappings. 63 | * @return The document field/property name containing the document timestamp. 
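 *         (For the Row-based implementation of this trait used by the sql package, see SparkEsDataFrameMapper.)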
64 | */ 65 | def extractMappingTimestamp(value: T): Option[String] 66 | } 67 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/SparkEsSerializer.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch 2 | 3 | /** 4 | * Serializes a T into an Array[Byte] for an IndexRequest. 5 | * 6 | * @tparam T T Object to serialize. 7 | */ 8 | trait SparkEsSerializer[T] extends Serializable { 9 | /** 10 | * Serialize a T from a DataFrame into an Array[Byte]. 11 | * 12 | * @param value A T 13 | * @return The source T as Array[Byte]. 14 | */ 15 | def write(value: T): Array[Byte] 16 | } 17 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsMapperConf.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.conf 2 | 3 | import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam 4 | import org.apache.spark.SparkConf 5 | 6 | /** 7 | * https://www.elastic.co/guide/en/elasticsearch/hadoop/current/configuration.html#cfg-mapping 8 | * 9 | * @param esMappingId The document field/property name containing the document id. 10 | * @param esMappingParent The document field/property name containing the document parent. 11 | * To specify a constant, use the format. 12 | * @param esMappingVersion The document field/property name containing the document version. 13 | * To specify a constant, use the format. 14 | * @param esMappingVersionType Indicates the type of versioning used. 15 | * http://www.elastic.co/guide/en/elasticsearch/reference/2.0/docs-index_.html#_version_types 16 | * If es.mapping.version is undefined (default), its value is unspecified. 17 | * If es.mapping.version is specified, its value becomes external. 18 | * @param esMappingRouting The document field/property name containing the document routing. 19 | * To specify a constant, use the format. 20 | * @param esMappingTTLInMillis The document field/property name containing the document time-to-live. 21 | * To specify a constant, use the format. 22 | * @param esMappingTimestamp The document field/property name containing the document timestamp. 23 | * To specify a constant, use the format. 
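 *
 * A construction sketch (the field name "id" is illustrative):
 * {{{
 *   val mapperConf = SparkEsMapperConf.fromSparkConf(sparkConf)
 *   // or explicitly:
 *   val explicitConf = SparkEsMapperConf(
 *     esMappingId = Some("id"), esMappingParent = None, esMappingVersion = None,
 *     esMappingVersionType = None, esMappingRouting = None,
 *     esMappingTTLInMillis = None, esMappingTimestamp = None)
 * }}}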
24 | */ 25 | 26 | case class SparkEsMapperConf( 27 | esMappingId: Option[String], 28 | esMappingParent: Option[String], 29 | esMappingVersion: Option[String], 30 | esMappingVersionType: Option[String], 31 | esMappingRouting: Option[String], 32 | esMappingTTLInMillis: Option[String], 33 | esMappingTimestamp: Option[String] 34 | ) extends Serializable 35 | 36 | object SparkEsMapperConf { 37 | val CONSTANT_FIELD_REGEX = """\<([^>]+)\>""".r 38 | 39 | val ES_MAPPING_ID = SparkEsConfParam[Option[String]]( 40 | name = "es.mapping.id", 41 | default = None 42 | ) 43 | val ES_MAPPING_PARENT = SparkEsConfParam[Option[String]]( 44 | name = "es.mapping.parent", 45 | default = None 46 | ) 47 | val ES_MAPPING_VERSION = SparkEsConfParam[Option[String]]( 48 | name = "es.mapping.version", 49 | default = None 50 | ) 51 | val ES_MAPPING_VERSION_TYPE = SparkEsConfParam[Option[String]]( 52 | name = "es.mapping.version.type", 53 | default = None 54 | ) 55 | val ES_MAPPING_ROUTING = SparkEsConfParam[Option[String]]( 56 | name = "es.mapping.routing", 57 | default = None 58 | ) 59 | val ES_MAPPING_TTL_IN_MILLIS = SparkEsConfParam[Option[String]]( 60 | name = "es.mapping.ttl", 61 | default = None 62 | ) 63 | val ES_MAPPING_TIMESTAMP = SparkEsConfParam[Option[String]]( 64 | name = "es.mapping.timestamp", 65 | default = None 66 | ) 67 | 68 | /** 69 | * Extracts SparkEsMapperConf from a SparkConf. 70 | * 71 | * @param sparkConf A SparkConf. 72 | * @return A SparkEsMapperConf from a SparkConf. 73 | */ 74 | def fromSparkConf(sparkConf: SparkConf): SparkEsMapperConf = { 75 | SparkEsMapperConf( 76 | esMappingId = ES_MAPPING_ID.fromConf(sparkConf)((sc, name) => sc.getOption(name)), 77 | esMappingParent = ES_MAPPING_PARENT.fromConf(sparkConf)((sc, name) => sc.getOption(name)), 78 | esMappingVersion = ES_MAPPING_VERSION.fromConf(sparkConf)((sc, name) => sc.getOption(name)), 79 | esMappingVersionType = ES_MAPPING_VERSION_TYPE.fromConf(sparkConf)((sc, name) => sc.getOption(name)), 80 | esMappingRouting = ES_MAPPING_ROUTING.fromConf(sparkConf)((sc, name) => sc.getOption(name)), 81 | esMappingTTLInMillis = ES_MAPPING_TTL_IN_MILLIS.fromConf(sparkConf)((sc, name) => sc.getOption(name)), 82 | esMappingTimestamp = ES_MAPPING_TIMESTAMP.fromConf(sparkConf)((sc, name) => sc.getOption(name)) 83 | ) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsTransportClientConf.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.conf 2 | 3 | import java.net.InetSocketAddress 4 | 5 | import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam 6 | import org.apache.spark.SparkConf 7 | 8 | import scala.collection.mutable 9 | 10 | /** 11 | * Configurations for EsNativeDataFrameWriter's TransportClient. 12 | * 13 | * @param transportAddresses The minimum set of hosts to connect to when establishing a client. 14 | * CONFIG_CLIENT_TRANSPORT_SNIFF is enabled by default. 15 | * @param transportPort The port to connect when establishing a client. 16 | * @param transportSettings Miscellaneous settings for the TransportClient. 17 | * Empty by default. 
18 | */ 19 | case class SparkEsTransportClientConf( 20 | transportAddresses: Seq[String], 21 | transportPort: Int, 22 | transportSettings: Map[String, String] 23 | ) extends Serializable 24 | 25 | object SparkEsTransportClientConf { 26 | val CONFIG_CLUSTER_NAME = "cluster.name" 27 | val CONFIG_CLIENT_TRANSPORT_SNIFF = "client.transport.sniff" 28 | val CONFIG_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME = "client.transport.ignore_cluster_name" 29 | val CONFIG_CLIENT_TRANSPORT_PING_TIMEOUT = "client.transport.ping_timeout" 30 | val CONFIG_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL = "client.transport.nodes_sampler_interval" 31 | 32 | val ES_NODES = SparkEsConfParam[Seq[String]]( 33 | name = "es.nodes", 34 | default = Seq.empty[String] 35 | ) 36 | val ES_PORT = SparkEsConfParam[Int]( 37 | name = "es.port", 38 | default = 9300 39 | ) 40 | val ES_CLUSTER_NAME = SparkEsConfParam[String]( 41 | name = s"es.$CONFIG_CLUSTER_NAME", 42 | default = null 43 | ) 44 | val ES_CLIENT_TRANSPORT_SNIFF = SparkEsConfParam[String]( 45 | name = s"es.$CONFIG_CLIENT_TRANSPORT_SNIFF", 46 | default = null 47 | ) 48 | val ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME = SparkEsConfParam[String]( 49 | name = s"es.$CONFIG_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME", 50 | default = null 51 | ) 52 | val ES_CLIENT_TRANSPORT_PING_TIMEOUT = SparkEsConfParam[String]( 53 | name = s"es.$CONFIG_CLIENT_TRANSPORT_PING_TIMEOUT", 54 | default = null 55 | ) 56 | val ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL = SparkEsConfParam[String]( 57 | name = s"es.$CONFIG_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL", 58 | default = null 59 | ) 60 | 61 | def getTransportAddresses(transportAddresses: Seq[String], transportPort: Int): Seq[InetSocketAddress] = { 62 | transportAddresses match { 63 | case null | Nil => throw new IllegalArgumentException("A contact point list cannot be empty.") 64 | case hosts => hosts map { 65 | ipWithPort => 66 | ipWithPort.split(":") match { 67 | case Array(actualHost, actualPort) => 68 | new InetSocketAddress(actualHost, actualPort.toInt) 69 | case Array(actualHost) => 70 | new InetSocketAddress(actualHost, transportPort) 71 | case errorMessage => 72 | throw new IllegalArgumentException(s"A contact point should have the form [host:port] or [host] but was: $errorMessage.") 73 | } 74 | } 75 | } 76 | } 77 | 78 | /** 79 | * Extracts SparkEsTransportClientConf from a SparkConf. 80 | * 81 | * @param sparkConf A SparkConf. 82 | * @return A SparkEsTransportClientConf from a SparkConf. 
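 *
 *         A minimal sketch (host and port are illustrative; es.nodes is required):
 *         {{{
 *           val sparkConf = new SparkConf().set("es.nodes", "127.0.0.1:9300")
 *           val clientConf = SparkEsTransportClientConf.fromSparkConf(sparkConf)
 *         }}}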
83 | */ 84 | def fromSparkConf(sparkConf: SparkConf): SparkEsTransportClientConf = { 85 | val tempEsNodes = ES_NODES.fromConf(sparkConf)((sc, name) => sc.get(name).split(",")) 86 | val tempEsPort = ES_PORT.fromConf(sparkConf)((sc, name) => sc.getInt(name, ES_PORT.default)) 87 | val tempSettings = mutable.HashMap.empty[String, String] 88 | 89 | require( 90 | tempEsNodes.nonEmpty, 91 | s"""No nodes defined in property ${ES_NODES.name} is in SparkConf.""".stripMargin 92 | ) 93 | 94 | if (sparkConf.contains(ES_CLUSTER_NAME.name) || sparkConf.contains(s"spark.${ES_CLUSTER_NAME.name}")) 95 | tempSettings.put(CONFIG_CLUSTER_NAME, ES_CLUSTER_NAME.fromConf(sparkConf)((sc, name) => sc.get(name))) 96 | 97 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_SNIFF.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_SNIFF.name}")) 98 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_SNIFF, ES_CLIENT_TRANSPORT_SNIFF.fromConf(sparkConf)((sc, name) => sc.get(name))) 99 | 100 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME.name}")) 101 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME, ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME.fromConf(sparkConf)((sc, name) => sc.get(name))) 102 | 103 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_PING_TIMEOUT.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_PING_TIMEOUT.name}")) 104 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_PING_TIMEOUT, ES_CLIENT_TRANSPORT_PING_TIMEOUT.fromConf(sparkConf)((sc, name) => sc.get(name))) 105 | 106 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.name}")) 107 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL, ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.fromConf(sparkConf)((sc, name) => sc.get(name))) 108 | 109 | SparkEsTransportClientConf( 110 | transportAddresses = tempEsNodes, 111 | transportPort = tempEsPort, 112 | transportSettings = tempSettings.toMap 113 | ) 114 | } 115 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsWriteConf.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.conf 2 | 3 | import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam 4 | import org.apache.spark.SparkConf 5 | 6 | /** 7 | * Configurations for EsNativeDataFrameWriter's BulkProcessor. 8 | * 9 | * @param bulkActions The number of IndexRequests to batch in one request. 10 | * @param bulkSizeInMB The maximum size in MB of a batch. 11 | * @param concurrentRequests The number of concurrent requests in flight. 12 | * @param flushTimeoutInSeconds The maximum time in seconds to wait while closing a BulkProcessor. 
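 *
 * A construction sketch showing the defaults applied by fromSparkConf when no
 * es.batch.* properties are set:
 * {{{
 *   SparkEsWriteConf(
 *     bulkActions = 1000,
 *     bulkSizeInMB = 5,
 *     concurrentRequests = 1,
 *     flushTimeoutInSeconds = 10)
 * }}}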
13 | */ 14 | case class SparkEsWriteConf( 15 | bulkActions: Int, 16 | bulkSizeInMB: Int, 17 | concurrentRequests: Int, 18 | flushTimeoutInSeconds: Long 19 | ) extends Serializable 20 | 21 | object SparkEsWriteConf { 22 | val BULK_ACTIONS = SparkEsConfParam[Int]( 23 | name = "es.batch.size.entries", 24 | default = 1000 25 | ) 26 | val BULK_SIZE_IN_MB = SparkEsConfParam[Int]( 27 | name = "es.batch.size.bytes", 28 | default = 5 29 | ) 30 | val CONCURRENT_REQUESTS = SparkEsConfParam[Int]( 31 | name = "es.batch.concurrent.request", 32 | default = 1 33 | ) 34 | val FLUSH_TIMEOUT_IN_SECONDS = SparkEsConfParam[Long]( 35 | name = "es.batch.flush.timeout", 36 | default = 10 37 | ) 38 | 39 | /** 40 | * Extracts SparkEsTransportClientConf from a SparkConf. 41 | * 42 | * @param sparkConf A SparkConf. 43 | * @return A SparkEsTransportClientConf from a SparkConf. 44 | */ 45 | def fromSparkConf(sparkConf: SparkConf): SparkEsWriteConf = { 46 | SparkEsWriteConf( 47 | bulkActions = BULK_ACTIONS.fromConf(sparkConf)((sc, name) => sc.getInt(name, BULK_ACTIONS.default)), 48 | bulkSizeInMB = BULK_SIZE_IN_MB.fromConf(sparkConf)((sc, name) => sc.getInt(name, BULK_SIZE_IN_MB.default)), 49 | concurrentRequests = CONCURRENT_REQUESTS.fromConf(sparkConf)((sc, name) => sc.getInt(name, CONCURRENT_REQUESTS.default)), 50 | flushTimeoutInSeconds = FLUSH_TIMEOUT_IN_SECONDS.fromConf(sparkConf)((sc, name) => sc.getLong(name, FLUSH_TIMEOUT_IN_SECONDS.default)) 51 | ) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameFunctions.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.sql 2 | 3 | import com.github.jparkie.spark.elasticsearch.SparkEsBulkWriter 4 | import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsTransportClientConf, SparkEsWriteConf } 5 | import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager 6 | import org.apache.spark.sql.{ DataFrame, Row } 7 | 8 | /** 9 | * Extension of DataFrame with 'bulkLoadToEs()' function. 10 | * 11 | * @param dataFrame The DataFrame to lift into extension. 12 | */ 13 | class SparkEsDataFrameFunctions(dataFrame: DataFrame) extends Serializable { 14 | /** 15 | * SparkContext to schedule SparkEsWriter Tasks. 16 | */ 17 | private[sql] val sparkContext = dataFrame.sqlContext.sparkContext 18 | 19 | /** 20 | * Upserts DataFrame into Elasticsearch with the Java API utilizing a TransportClient. 21 | * 22 | * @param esIndex Index of DataFrame in Elasticsearch. 23 | * @param esType Type of DataFrame in Elasticsearch. 24 | * @param sparkEsTransportClientConf Configurations for the TransportClient. 25 | * @param sparkEsMapperConf Configurations for IndexRequest. 26 | * @param sparkEsWriteConf Configurations for the BulkProcessor. 27 | * Empty by default. 
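 *
 * A usage sketch (index and type names are illustrative):
 * {{{
 *   import com.github.jparkie.spark.elasticsearch.sql._
 *
 *   dataFrame.bulkLoadToEs(esIndex = "twitter", esType = "tweets")
 * }}}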
28 | */ 29 | def bulkLoadToEs( 30 | esIndex: String, 31 | esType: String, 32 | sparkEsTransportClientConf: SparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(sparkContext.getConf), 33 | sparkEsMapperConf: SparkEsMapperConf = SparkEsMapperConf.fromSparkConf(sparkContext.getConf), 34 | sparkEsWriteConf: SparkEsWriteConf = SparkEsWriteConf.fromSparkConf(sparkContext.getConf) 35 | )(implicit sparkEsTransportClientManager: SparkEsTransportClientManager = sparkEsTransportClientManager): Unit = { 36 | val sparkEsWriter = new SparkEsBulkWriter[Row]( 37 | esIndex = esIndex, 38 | esType = esType, 39 | esClient = () => sparkEsTransportClientManager.getTransportClient(sparkEsTransportClientConf), 40 | sparkEsSerializer = new SparkEsDataFrameSerializer(dataFrame.schema), 41 | sparkEsMapper = new SparkEsDataFrameMapper(sparkEsMapperConf), 42 | sparkEsWriteConf = sparkEsWriteConf 43 | ) 44 | 45 | sparkContext.runJob(dataFrame.rdd, sparkEsWriter.write _) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameMapper.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.sql 2 | 3 | import com.github.jparkie.spark.elasticsearch.SparkEsMapper 4 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsMapperConf 5 | import org.apache.spark.sql.Row 6 | import org.elasticsearch.index.VersionType 7 | 8 | /** 9 | * Extracts mappings from a Row for an IndexRequest. 10 | * 11 | * @param mapperConf Configurations for IndexRequest. 12 | */ 13 | class SparkEsDataFrameMapper(mapperConf: SparkEsMapperConf) extends SparkEsMapper[Row] { 14 | import SparkEsDataFrameMapper._ 15 | import SparkEsMapperConf._ 16 | 17 | /** 18 | * Extracts the document field/property name containing the document id. 19 | * 20 | * @param value Object to extract mappings. 21 | * @return The document field/property name containing the document id. 22 | */ 23 | override def extractMappingId(value: Row): Option[String] = { 24 | mapperConf.esMappingId 25 | .map(currentFieldName => value.getAsToString(currentFieldName)) 26 | } 27 | 28 | /** 29 | * Extracts the document field/property name containing the document parent. 30 | * 31 | * @param value Object to extract mappings. 32 | * @return The document field/property name containing the document parent. 33 | */ 34 | override def extractMappingParent(value: Row): Option[String] = { 35 | mapperConf.esMappingParent.map { 36 | case CONSTANT_FIELD_REGEX(constantValue) => 37 | constantValue 38 | case currentFieldName => 39 | value.getAsToString(currentFieldName) 40 | } 41 | } 42 | 43 | /** 44 | * Extracts the document field/property name containing the document version. 45 | * 46 | * @param value Object to extract mappings. 47 | * @return The document field/property name containing the document version. 48 | */ 49 | override def extractMappingVersion(value: Row): Option[Long] = { 50 | mapperConf.esMappingVersion.map { 51 | case CONSTANT_FIELD_REGEX(constantValue) => 52 | constantValue.toLong 53 | case currentFieldName => 54 | value.getAsToString(currentFieldName).toLong 55 | } 56 | } 57 | 58 | /** 59 | * Extracts the type of versioning used. 60 | * 61 | * @param value Object to extract mappings. 62 | * @return The type of versioning used.. 
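 *         Note: a VersionType is returned only when es.mapping.version is also defined;
 *         otherwise any configured es.mapping.version.type is ignored.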
63 | */ 64 | override def extractMappingVersionType(value: Row): Option[VersionType] = { 65 | mapperConf.esMappingVersion 66 | .flatMap(_ => mapperConf.esMappingVersionType.map(VersionType.fromString)) 67 | } 68 | 69 | /** 70 | * Extracts the document field/property name containing the document routing. 71 | * 72 | * @param value Object to extract mappings. 73 | * @return The document field/property name containing the document routing. 74 | */ 75 | override def extractMappingRouting(value: Row): Option[String] = { 76 | mapperConf.esMappingRouting.map { 77 | case CONSTANT_FIELD_REGEX(constantValue) => 78 | constantValue 79 | case currentFieldName => 80 | value.getAsToString(currentFieldName) 81 | } 82 | } 83 | 84 | /** 85 | * Extracts the document field/property name containing the document time-to-live. 86 | * 87 | * @param value Object to extract mappings. 88 | * @return The document field/property name containing the document time-to-live. 89 | */ 90 | override def extractMappingTTLInMillis(value: Row): Option[Long] = { 91 | mapperConf.esMappingTTLInMillis.map { 92 | case CONSTANT_FIELD_REGEX(constantValue) => 93 | constantValue.toLong 94 | case currentFieldName => 95 | value.getAsToString(currentFieldName).toLong 96 | } 97 | } 98 | 99 | /** 100 | * Extracts the document field/property name containing the document timestamp. 101 | * 102 | * @param value Object to extract mappings. 103 | * @return The document field/property name containing the document timestamp. 104 | */ 105 | override def extractMappingTimestamp(value: Row): Option[String] = { 106 | mapperConf.esMappingTimestamp.map { 107 | case CONSTANT_FIELD_REGEX(constantValue) => 108 | constantValue 109 | case currentFieldName => 110 | value.getAsToString(currentFieldName) 111 | } 112 | } 113 | } 114 | 115 | object SparkEsDataFrameMapper { 116 | /** 117 | * Adds method to retrieve field as String through Any's toString. 118 | * 119 | * @param currentRow Row object. 120 | */ 121 | implicit class RichRow(currentRow: Row) { 122 | // TODO: Find a better and safer way. 123 | def getAsToString(fieldName: String): String = { 124 | currentRow.getAs[Any](fieldName).toString 125 | } 126 | } 127 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameSerializer.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.sql 2 | 3 | import java.sql.{ Date, Timestamp } 4 | 5 | import com.github.jparkie.spark.elasticsearch.SparkEsSerializer 6 | import org.apache.spark.sql.Row 7 | import org.apache.spark.sql.types._ 8 | import org.elasticsearch.common.xcontent.{ XContentBuilder, XContentFactory } 9 | 10 | import scala.collection.JavaConverters._ 11 | 12 | /** 13 | * Serializes a Row from a DataFrame into an Array[Byte]. 14 | * 15 | * @param schema The StructType of a DataFrame. 16 | */ 17 | class SparkEsDataFrameSerializer(schema: StructType) extends SparkEsSerializer[Row] { 18 | /** 19 | * Serializes a Row from a DataFrame into an Array[Byte]. 20 | * 21 | * @param value A Row. 22 | * @return The source JSON as Array[Byte]. 
23 | */ 24 | override def write(value: Row): Array[Byte] = { 25 | val currentJsonBuilder = XContentFactory.jsonBuilder() 26 | 27 | write(schema, value, currentJsonBuilder) 28 | 29 | currentJsonBuilder 30 | .bytes() 31 | .toBytes 32 | } 33 | 34 | private[sql] def write(dataType: DataType, value: Any, builder: XContentBuilder): XContentBuilder = { 35 | dataType match { 36 | case structType @ StructType(_) => writeStruct(structType, value, builder) 37 | case arrayType @ ArrayType(_, _) => writeArray(arrayType, value, builder) 38 | case mapType @ MapType(_, _, _) => writeMap(mapType, value, builder) 39 | case _ => writePrimitive(dataType, value, builder) 40 | } 41 | } 42 | 43 | private[sql] def writeStruct(structType: StructType, value: Any, builder: XContentBuilder): XContentBuilder = { 44 | value match { 45 | case currentRow: Row => 46 | builder.startObject() 47 | 48 | structType.fields.view.zipWithIndex foreach { 49 | case (field, index) => 50 | builder.field(field.name) 51 | if (currentRow.isNullAt(index)) { 52 | builder.nullValue() 53 | } else { 54 | write(field.dataType, currentRow(index), builder) 55 | } 56 | } 57 | 58 | builder.endObject() 59 | } 60 | 61 | builder 62 | } 63 | 64 | private[sql] def writeArray(arrayType: ArrayType, value: Any, builder: XContentBuilder): XContentBuilder = { 65 | value match { 66 | case array: Array[_] => 67 | serializeArray(arrayType.elementType, array, builder) 68 | case seq: Seq[_] => 69 | serializeArray(arrayType.elementType, seq, builder) 70 | case _ => 71 | throw new IllegalArgumentException(s"Unknown ArrayType: $value.") 72 | } 73 | } 74 | 75 | private[sql] def serializeArray(dataType: DataType, value: Seq[_], builder: XContentBuilder): XContentBuilder = { 76 | // TODO: Consider utilizing builder.value(Iterable[_]). 77 | builder.startArray() 78 | 79 | if (value != null) { 80 | value foreach { element => 81 | write(dataType, element, builder) 82 | } 83 | } 84 | 85 | builder.endArray() 86 | builder 87 | } 88 | 89 | private[sql] def writeMap(mapType: MapType, value: Any, builder: XContentBuilder): XContentBuilder = { 90 | value match { 91 | case scalaMap: scala.collection.Map[_, _] => 92 | serializeMap(mapType, scalaMap, builder) 93 | case javaMap: java.util.Map[_, _] => 94 | serializeMap(mapType, javaMap.asScala, builder) 95 | case _ => 96 | throw new IllegalArgumentException(s"Unknown MapType: $value.") 97 | } 98 | } 99 | 100 | private[sql] def serializeMap(mapType: MapType, value: scala.collection.Map[_, _], builder: XContentBuilder): XContentBuilder = { 101 | // TODO: Consider utilizing builder.value(Map[_, AnyRef]). 
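    // Keys are written with toString; values are serialized recursively according to mapType.valueType.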
102 | builder.startObject() 103 | 104 | for ((currentKey, currentValue) <- value) { 105 | builder.field(currentKey.toString) 106 | write(mapType.valueType, currentValue, builder) 107 | } 108 | 109 | builder.endObject() 110 | builder 111 | } 112 | 113 | private[sql] def writePrimitive(dataType: DataType, value: Any, builder: XContentBuilder): XContentBuilder = { 114 | dataType match { 115 | case BinaryType => builder.value(value.asInstanceOf[Array[Byte]]) 116 | case BooleanType => builder.value(value.asInstanceOf[Boolean]) 117 | case ByteType => builder.value(value.asInstanceOf[Byte]) 118 | case ShortType => builder.value(value.asInstanceOf[Short]) 119 | case IntegerType => builder.value(value.asInstanceOf[Int]) 120 | case LongType => builder.value(value.asInstanceOf[Long]) 121 | case DoubleType => builder.value(value.asInstanceOf[Double]) 122 | case FloatType => builder.value(value.asInstanceOf[Float]) 123 | case TimestampType => builder.value(value.asInstanceOf[Timestamp].getTime) 124 | case DateType => builder.value(value.asInstanceOf[Date].getTime) 125 | case StringType => builder.value(value.toString) 126 | case _ => 127 | throw new IllegalArgumentException(s"Unknown DataType: $value.") 128 | } 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/sql/package.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch 2 | 3 | import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager 4 | import org.apache.spark.sql.DataFrame 5 | 6 | package object sql { 7 | implicit val sparkEsTransportClientManager = SparkEsTransportClientManager 8 | 9 | /** 10 | * Implicitly lift a DataFrame with SparkEsDataFrameFunctions. 11 | * 12 | * @param dataFrame A DataFrame to lift. 13 | * @return Enriched DataFrame with SparkEsDataFrameFunctions. 
14 | */ 15 | implicit def sparkEsDataFrameFunctions(dataFrame: DataFrame): SparkEsDataFrameFunctions = new SparkEsDataFrameFunctions(dataFrame) 16 | } 17 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientManager.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.transport 2 | 3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf 4 | import org.apache.spark.Logging 5 | import org.elasticsearch.client.Client 6 | import org.elasticsearch.client.transport.TransportClient 7 | import org.elasticsearch.common.settings.Settings 8 | import org.elasticsearch.common.transport.InetSocketTransportAddress 9 | 10 | import scala.collection.mutable 11 | 12 | private[elasticsearch] trait SparkEsTransportClientManager extends Serializable with Logging { 13 | @transient 14 | private[transport] val internalTransportClients = mutable.HashMap.empty[SparkEsTransportClientConf, TransportClient] 15 | 16 | private[transport] def buildTransportSettings(clientConf: SparkEsTransportClientConf): Settings = { 17 | val esSettingsBuilder = Settings.builder() 18 | 19 | clientConf.transportSettings foreach { currentSetting => 20 | esSettingsBuilder.put(currentSetting._1, currentSetting._2) 21 | } 22 | 23 | esSettingsBuilder.build() 24 | } 25 | 26 | private[transport] def buildTransportClient(clientConf: SparkEsTransportClientConf, esSettings: Settings): TransportClient = { 27 | import SparkEsTransportClientConf._ 28 | 29 | val esClient = TransportClient.builder() 30 | .settings(esSettings) 31 | .build() 32 | 33 | getTransportAddresses(clientConf.transportAddresses, clientConf.transportPort) foreach { inetSocketAddress => 34 | esClient.addTransportAddresses(new InetSocketTransportAddress(inetSocketAddress)) 35 | } 36 | 37 | sys.addShutdownHook { 38 | logInfo("Closed Elasticsearch Transport Client.") 39 | 40 | esClient.close() 41 | } 42 | 43 | logInfo(s"Connected to the following Elasticsearch nodes: ${esClient.connectedNodes()}.") 44 | 45 | esClient 46 | } 47 | 48 | /** 49 | * Gets or creates a TransportClient per JVM. 50 | * 51 | * @param clientConf Settings and initial endpoints for connection. 52 | * @return SparkEsTransportClientProxy as Client. 53 | */ 54 | def getTransportClient(clientConf: SparkEsTransportClientConf): Client = synchronized { 55 | internalTransportClients.get(clientConf) match { 56 | case Some(transportClient) => 57 | new SparkEsTransportClientProxy(transportClient) 58 | case None => 59 | val transportSettings = buildTransportSettings(clientConf) 60 | val transportClient = buildTransportClient(clientConf, transportSettings) 61 | internalTransportClients.put(clientConf, transportClient) 62 | new SparkEsTransportClientProxy(transportClient) 63 | } 64 | } 65 | 66 | /** 67 | * Evicts and closes a TransportClient. 68 | * 69 | * @param clientConf Settings and initial endpoints for connection. 
70 | */ 71 | def closeTransportClient(clientConf: SparkEsTransportClientConf): Unit = synchronized { 72 | internalTransportClients.remove(clientConf) match { 73 | case Some(transportClient) => 74 | transportClient.close() 75 | case None => 76 | logError(s"No TransportClient for $clientConf.") 77 | } 78 | } 79 | } 80 | 81 | object SparkEsTransportClientManager extends SparkEsTransportClientManager -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientProxy.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.transport 2 | 3 | import com.github.jparkie.spark.elasticsearch.util.SparkEsException 4 | import org.elasticsearch.client.{ FilterClient, Client } 5 | 6 | /** 7 | * Restricts access to the underlying TransportClient by disabling close(); close clients through SparkEsTransportClientManager instead. 8 | */ 9 | class SparkEsTransportClientProxy(client: Client) extends FilterClient(client) { 10 | override def close(): Unit = { 11 | throw new SparkEsException("close() is not supported in SparkEsTransportClientProxy. Please close with SparkEsTransportClientManager.") 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/util/SparkEsConfParam.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.util 2 | 3 | import org.apache.spark.SparkConf 4 | 5 | /** 6 | * Defines a parameter for extracting values from SparkConf. 7 | * 8 | * @param name The key in SparkConf. 9 | * @param default The default value to fall back on when the key is missing. 10 | */ 11 | case class SparkEsConfParam[T](name: String, default: T) { 12 | def fromConf(sparkConf: SparkConf)(sparkConfFunc: (SparkConf, String) => T): T = { 13 | if (sparkConf.contains(name)) { 14 | sparkConfFunc(sparkConf, name) 15 | } else if (sparkConf.contains(s"spark.$name")) { 16 | sparkConfFunc(sparkConf, s"spark.$name") 17 | } else { 18 | default 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/scala/com/github/jparkie/spark/elasticsearch/util/SparkEsException.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.util 2 | 3 | /** 4 | * General exceptions captured by EsNativeDataFrameBulkProcessorListener. 5 | * 6 | * @param message the detail message (which is saved for later retrieval 7 | * by the { @link #getMessage()} method). 8 | * @param cause the cause (which is saved for later retrieval by the 9 | * { @link #getCause()} method). (A null value is 10 | * permitted, and indicates that the cause is nonexistent or 11 | * unknown.) 
12 | */ 13 | class SparkEsException(message: String, cause: Throwable) extends RuntimeException(message, cause) with Serializable { 14 | def this() = this(null, null) 15 | def this(message: String) = this(message, null) 16 | def this(cause: Throwable) = this(null, cause) 17 | } -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/ElasticSearchServer.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch 2 | 3 | import java.nio.file.Files 4 | 5 | import org.apache.commons.io.FileUtils 6 | import org.elasticsearch.client.Client 7 | import org.elasticsearch.common.settings.Settings 8 | import org.elasticsearch.node.{ Node, NodeBuilder } 9 | 10 | /** 11 | * Local Elasticsearch server of one node for integration testing. 12 | */ 13 | class ElasticSearchServer { 14 | private val homeDir = Files.createTempDirectory("elasticsearch").toFile 15 | private val dataDir = Files.createTempDirectory("elasticsearch").toFile 16 | 17 | val clusterName = "Spark2Elasticsearch" 18 | 19 | private lazy val internalNode: Node = { 20 | val tempSettings = Settings.builder() 21 | .put("path.home", homeDir.getAbsolutePath) 22 | .put("path.data", dataDir.getAbsolutePath) 23 | .put("es.logger.level", "OFF") 24 | .build() 25 | val tempNode = NodeBuilder.nodeBuilder() 26 | .clusterName(clusterName) 27 | .local(true) 28 | .data(true) 29 | .settings(tempSettings) 30 | .build() 31 | 32 | tempNode 33 | } 34 | 35 | /** 36 | * Fetch a client to the Elasticsearch cluster. 37 | * 38 | * @return A Client. 39 | */ 40 | def client: Client = { 41 | internalNode.client() 42 | } 43 | 44 | /** 45 | * Start the Elasticsearch cluster. 46 | */ 47 | def start(): Unit = { 48 | internalNode.start() 49 | } 50 | 51 | /** 52 | * Stop the Elasticsearch cluster. 53 | */ 54 | def stop(): Unit = { 55 | internalNode.close() 56 | 57 | try { 58 | FileUtils.forceDelete(homeDir) 59 | FileUtils.forceDelete(dataDir) 60 | } catch { 61 | case e: Exception => 62 | // Do Nothing. 63 | } 64 | } 65 | 66 | /** 67 | * Create Index. 68 | * 69 | * @param index The name of Index. 
70 | */ 71 | def createAndWaitForIndex(index: String): Unit = { 72 | client.admin.indices.prepareCreate(index).execute.actionGet() 73 | client.admin.cluster.prepareHealth(index).setWaitForActiveShards(1).execute.actionGet() 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/SparkEsBulkWriterSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch 2 | 3 | import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsWriteConf } 4 | import com.github.jparkie.spark.elasticsearch.sql.{ SparkEsDataFrameMapper, SparkEsDataFrameSerializer } 5 | import com.holdenkarau.spark.testing.SharedSparkContext 6 | import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType } 7 | import org.apache.spark.sql.{ Row, SQLContext } 8 | import org.scalatest.{ MustMatchers, WordSpec } 9 | 10 | class SparkEsBulkWriterSpec extends WordSpec with MustMatchers with SharedSparkContext { 11 | val esServer = new ElasticSearchServer() 12 | 13 | override def beforeAll(): Unit = { 14 | super.beforeAll() 15 | 16 | esServer.start() 17 | } 18 | 19 | override def afterAll(): Unit = { 20 | esServer.stop() 21 | 22 | super.afterAll() 23 | } 24 | 25 | "SparkEsBulkWriter" must { 26 | "execute write() successfully" in { 27 | esServer.createAndWaitForIndex("test_index") 28 | 29 | val sqlContext = new SQLContext(sc) 30 | 31 | val inputSparkEsWriteConf = SparkEsWriteConf( 32 | bulkActions = 10, 33 | bulkSizeInMB = 1, 34 | concurrentRequests = 0, 35 | flushTimeoutInSeconds = 1 36 | ) 37 | val inputMapperConf = SparkEsMapperConf( 38 | esMappingId = Some("id"), 39 | esMappingParent = None, 40 | esMappingVersion = None, 41 | esMappingVersionType = None, 42 | esMappingRouting = None, 43 | esMappingTTLInMillis = None, 44 | esMappingTimestamp = None 45 | ) 46 | val inputSchema = StructType( 47 | Array( 48 | StructField("id", StringType, true), 49 | StructField("parent", StringType, true), 50 | StructField("version", LongType, true), 51 | StructField("routing", StringType, true), 52 | StructField("ttl", LongType, true), 53 | StructField("timestamp", StringType, true), 54 | StructField("value", LongType, true) 55 | ) 56 | ) 57 | val inputData = sc.parallelize { 58 | Array( 59 | Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L), 60 | Row("TEST_ID_1", "TEST_PARENT_2", 2L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 2L), 61 | Row("TEST_ID_1", "TEST_PARENT_3", 3L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 3L), 62 | Row("TEST_ID_1", "TEST_PARENT_4", 4L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 4L), 63 | Row("TEST_ID_1", "TEST_PARENT_5", 5L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 5L), 64 | Row("TEST_ID_5", "TEST_PARENT_6", 6L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 6L), 65 | Row("TEST_ID_6", "TEST_PARENT_7", 7L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 7L), 66 | Row("TEST_ID_7", "TEST_PARENT_8", 8L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 8L), 67 | Row("TEST_ID_8", "TEST_PARENT_9", 9L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 9L), 68 | Row("TEST_ID_9", "TEST_PARENT_10", 10L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 10L), 69 | Row("TEST_ID_10", "TEST_PARENT_11", 11L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 11L) 70 | ) 71 | } 72 | val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema) 73 | 
val inputDataIterator = inputDataFrame.rdd.toLocalIterator 74 | val inputSparkEsBulkWriter = new SparkEsBulkWriter[Row]( 75 | esIndex = "test_index", 76 | esType = "test_type", 77 | esClient = () => esServer.client, 78 | sparkEsSerializer = new SparkEsDataFrameSerializer(inputSchema), 79 | sparkEsMapper = new SparkEsDataFrameMapper(inputMapperConf), 80 | sparkEsWriteConf = inputSparkEsWriteConf 81 | ) 82 | 83 | inputSparkEsBulkWriter.write(null, inputDataIterator) 84 | 85 | val outputGetResponse = esServer.client.prepareGet("test_index", "test_type", "TEST_ID_1").get() 86 | 87 | outputGetResponse.isExists mustEqual true 88 | outputGetResponse.getSource.get("parent").asInstanceOf[String] mustEqual "TEST_PARENT_5" 89 | outputGetResponse.getSource.get("version").asInstanceOf[Integer] mustEqual 5 90 | outputGetResponse.getSource.get("routing").asInstanceOf[String] mustEqual "TEST_ROUTING_1" 91 | outputGetResponse.getSource.get("ttl").asInstanceOf[Integer] mustEqual 86400000 92 | outputGetResponse.getSource.get("timestamp").asInstanceOf[String] mustEqual "TEST_TIMESTAMP_1" 93 | outputGetResponse.getSource.get("value").asInstanceOf[Integer] mustEqual 5 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsMapperConfSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.conf 2 | 3 | import org.apache.spark.SparkConf 4 | import org.scalatest.{ MustMatchers, WordSpec } 5 | 6 | class SparkEsMapperConfSpec extends WordSpec with MustMatchers { 7 | "SparkEsMapperConf" must { 8 | "be extracted from SparkConf successfully" in { 9 | val inputSparkConf = new SparkConf() 10 | .set("es.mapping.id", "TEST_VALUE_1") 11 | .set("es.mapping.parent", "TEST_VALUE_2") 12 | .set("es.mapping.version", "TEST_VALUE_3") 13 | .set("es.mapping.version.type", "TEST_VALUE_4") 14 | .set("es.mapping.routing", "TEST_VALUE_5") 15 | .set("es.mapping.ttl", "TEST_VALUE_6") 16 | .set("es.mapping.timestamp", "TEST_VALUE_7") 17 | 18 | val expectedSparkEsMapperConf = SparkEsMapperConf( 19 | esMappingId = Some("TEST_VALUE_1"), 20 | esMappingParent = Some("TEST_VALUE_2"), 21 | esMappingVersion = Some("TEST_VALUE_3"), 22 | esMappingVersionType = Some("TEST_VALUE_4"), 23 | esMappingRouting = Some("TEST_VALUE_5"), 24 | esMappingTTLInMillis = Some("TEST_VALUE_6"), 25 | esMappingTimestamp = Some("TEST_VALUE_7") 26 | ) 27 | 28 | val outputSparkEsMapperConf = SparkEsMapperConf.fromSparkConf(inputSparkConf) 29 | 30 | outputSparkEsMapperConf mustEqual expectedSparkEsMapperConf 31 | } 32 | 33 | "extract CONSTANT_FIELD_REGEX successfully" in { 34 | val inputString = "" 35 | 36 | val expectedString = "TEST_VALUE_1" 37 | 38 | val outputString = inputString match { 39 | case SparkEsMapperConf.CONSTANT_FIELD_REGEX(outputString) => 40 | outputString 41 | case _ => 42 | fail("CONSTANT_FIELD_REGEX failed.") 43 | } 44 | 45 | outputString mustEqual expectedString 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsTransportClientConfSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.conf 2 | 3 | import java.net.InetSocketAddress 4 | 5 | import org.apache.spark.SparkConf 6 | import org.scalatest.{ MustMatchers, WordSpec } 7 | 8 | class 
SparkEsTransportClientConfSpec extends WordSpec with MustMatchers { 9 | "SparkEsTransportClientConf" must { 10 | "be extracted from SparkConf successfully" in { 11 | val inputSparkConf = new SparkConf() 12 | .set("es.nodes", "127.0.0.1:9000,127.0.0.1:9001,127.0.0.1:9002") 13 | .set("es.port", "1337") 14 | 15 | val expectedSparkEsTransportClientConf = SparkEsTransportClientConf( 16 | transportAddresses = Seq("127.0.0.1:9000", "127.0.0.1:9001", "127.0.0.1:9002"), 17 | transportPort = 1337, 18 | transportSettings = Map.empty[String, String] 19 | ) 20 | 21 | val outputSparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(inputSparkConf) 22 | 23 | outputSparkEsTransportClientConf mustEqual expectedSparkEsTransportClientConf 24 | } 25 | 26 | "be extracted from SparkConf unsuccessfully" in { 27 | val inputSparkConf = new SparkConf() 28 | 29 | val outputException = intercept[IllegalArgumentException] { 30 | SparkEsTransportClientConf.fromSparkConf(inputSparkConf) 31 | } 32 | 33 | outputException.getMessage must include("No nodes defined in property es.nodes is in SparkConf.") 34 | } 35 | 36 | "extract transportSettings successfully" in { 37 | val inputSparkConf = new SparkConf() 38 | .set("es.nodes", "127.0.0.1:9000,127.0.0.1:9001,127.0.0.1:9002") 39 | .set("es.port", "1337") 40 | .set("es.cluster.name", "TEST_VALUE_1") 41 | .set("es.client.transport.sniff", "TEST_VALUE_2") 42 | .set("es.client.transport.ignore_cluster_name", "TEST_VALUE_3") 43 | .set("es.client.transport.ping_timeout", "TEST_VALUE_4") 44 | .set("es.client.transport.nodes_sampler_interval", "TEST_VALUE_5") 45 | 46 | val expectedSparkEsTransportClientConf = SparkEsTransportClientConf( 47 | transportAddresses = Seq("127.0.0.1:9000", "127.0.0.1:9001", "127.0.0.1:9002"), 48 | transportPort = 1337, 49 | transportSettings = Map( 50 | "cluster.name" -> "TEST_VALUE_1", 51 | "client.transport.sniff" -> "TEST_VALUE_2", 52 | "client.transport.ignore_cluster_name" -> "TEST_VALUE_3", 53 | "client.transport.ping_timeout" -> "TEST_VALUE_4", 54 | "client.transport.nodes_sampler_interval" -> "TEST_VALUE_5" 55 | ) 56 | ) 57 | 58 | val outputSparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(inputSparkConf) 59 | 60 | outputSparkEsTransportClientConf mustEqual expectedSparkEsTransportClientConf 61 | } 62 | 63 | "extract transportAddresses as Seq[InetSocketAddress] successfully with port secondly" in { 64 | val inputAddresses = Seq("127.0.0.1:9000", "127.0.0.1:9001", "127.0.0.1:9002") 65 | val inputPort = 1337 66 | 67 | val expectedTransportAddresses = Seq( 68 | new InetSocketAddress("127.0.0.1", 9000), 69 | new InetSocketAddress("127.0.0.1", 9001), 70 | new InetSocketAddress("127.0.0.1", 9002) 71 | ) 72 | 73 | val outputTransportAddresses = SparkEsTransportClientConf.getTransportAddresses(inputAddresses, inputPort) 74 | 75 | outputTransportAddresses mustEqual expectedTransportAddresses 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsWriteConfSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.conf 2 | 3 | import org.apache.spark.SparkConf 4 | import org.scalatest.{ MustMatchers, WordSpec } 5 | 6 | class SparkEsWriteConfSpec extends WordSpec with MustMatchers { 7 | "SparkEsWriteConf" must { 8 | "be extracted from SparkConf successfully" in { 9 | val inputSparkConf = new SparkConf() 10 | 
.set("es.batch.size.entries", "1") 11 | .set("es.batch.size.bytes", "2") 12 | .set("es.batch.concurrent.request", "3") 13 | .set("es.batch.flush.timeout", "4") 14 | 15 | val expectedSparkEsWriteConf = SparkEsWriteConf( 16 | bulkActions = 1, 17 | bulkSizeInMB = 2, 18 | concurrentRequests = 3, 19 | flushTimeoutInSeconds = 4 20 | ) 21 | 22 | val outputSparkEsWriteConf = SparkEsWriteConf.fromSparkConf(inputSparkConf) 23 | 24 | outputSparkEsWriteConf mustEqual expectedSparkEsWriteConf 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/sql/PackageSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.sql 2 | 3 | import com.holdenkarau.spark.testing.SharedSparkContext 4 | import org.apache.spark.sql.SQLContext 5 | import org.scalatest.{ MustMatchers, WordSpec } 6 | 7 | class PackageSpec extends WordSpec with MustMatchers with SharedSparkContext { 8 | "Package com.github.jparkie.spark.elasticsearch.sql" must { 9 | "lift DataFrame into SparkEsDataFrameFunctions" in { 10 | 11 | val sqlContext = new SQLContext(sc) 12 | 13 | val inputData = Seq( 14 | ("TEST_VALUE_1", 1), 15 | ("TEST_VALUE_2", 2), 16 | ("TEST_VALUE_3", 3) 17 | ) 18 | 19 | val outputDataFrame = sqlContext.createDataFrame(inputData) 20 | .toDF("key", "value") 21 | 22 | // If sparkContext is available, DataFrame was lifted into SparkEsDataFrameFunctions. 23 | outputDataFrame.sparkContext 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameMapperSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.sql 2 | 3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsMapperConf 4 | import com.holdenkarau.spark.testing.SharedSparkContext 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType } 7 | import org.apache.spark.sql.{ Row, SQLContext } 8 | import org.elasticsearch.index.VersionType 9 | import org.scalatest.{ MustMatchers, WordSpec } 10 | 11 | class SparkEsDataFrameMapperSpec extends WordSpec with MustMatchers with SharedSparkContext { 12 | def createInputRow(sparkContext: SparkContext): Row = { 13 | val sqlContext = new SQLContext(sparkContext) 14 | 15 | val inputSchema = StructType( 16 | Array( 17 | StructField("id", StringType, true), 18 | StructField("parent", StringType, true), 19 | StructField("version", LongType, true), 20 | StructField("routing", StringType, true), 21 | StructField("ttl", LongType, true), 22 | StructField("timestamp", StringType, true), 23 | StructField("value", LongType, true) 24 | ) 25 | ) 26 | val inputData = sc.parallelize { 27 | Array( 28 | Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L) 29 | ) 30 | } 31 | val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema) 32 | val inputRow = inputDataFrame.first() 33 | 34 | inputRow 35 | } 36 | 37 | "SparkEsDataFrameMapper" must { 38 | "from Row extract noting successfully" in { 39 | val inputRow = createInputRow(sc) 40 | val inputMapperConf = SparkEsMapperConf( 41 | esMappingId = None, 42 | esMappingParent = None, 43 | esMappingVersion = None, 44 | esMappingVersionType = None, 45 | esMappingRouting = None, 46 | 
esMappingTTLInMillis = None, 47 | esMappingTimestamp = None 48 | ) 49 | 50 | val outputSparkEsDataFrameMapper = new SparkEsDataFrameMapper(inputMapperConf) 51 | 52 | outputSparkEsDataFrameMapper.extractMappingId(inputRow) mustEqual None 53 | outputSparkEsDataFrameMapper.extractMappingParent(inputRow) mustEqual None 54 | outputSparkEsDataFrameMapper.extractMappingVersion(inputRow) mustEqual None 55 | outputSparkEsDataFrameMapper.extractMappingVersionType(inputRow) mustEqual None 56 | outputSparkEsDataFrameMapper.extractMappingRouting(inputRow) mustEqual None 57 | outputSparkEsDataFrameMapper.extractMappingTTLInMillis(inputRow) mustEqual None 58 | outputSparkEsDataFrameMapper.extractMappingTimestamp(inputRow) mustEqual None 59 | } 60 | 61 | "from Row execute extractMappingId() successfully" in { 62 | val inputRow = createInputRow(sc) 63 | val inputMapperConf = SparkEsMapperConf( 64 | esMappingId = Some("id"), 65 | esMappingParent = None, 66 | esMappingVersion = None, 67 | esMappingVersionType = None, 68 | esMappingRouting = None, 69 | esMappingTTLInMillis = None, 70 | esMappingTimestamp = None 71 | ) 72 | 73 | val outputSparkEsDataFrameMapper = new SparkEsDataFrameMapper(inputMapperConf) 74 | 75 | outputSparkEsDataFrameMapper.extractMappingId(inputRow) mustEqual Some("TEST_ID_1") 76 | } 77 | 78 | "from Row execute extractMappingParent() successfully" in { 79 | val inputRow = createInputRow(sc) 80 | val inputMapperConf1 = SparkEsMapperConf( 81 | esMappingId = None, 82 | esMappingParent = Some("parent"), 83 | esMappingVersion = None, 84 | esMappingVersionType = None, 85 | esMappingRouting = None, 86 | esMappingTTLInMillis = None, 87 | esMappingTimestamp = None 88 | ) 89 | val inputMapperConf2 = SparkEsMapperConf( 90 | esMappingId = None, 91 | esMappingParent = Some(""), 92 | esMappingVersion = None, 93 | esMappingVersionType = None, 94 | esMappingRouting = None, 95 | esMappingTTLInMillis = None, 96 | esMappingTimestamp = None 97 | ) 98 | 99 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1) 100 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2) 101 | 102 | outputSparkEsDataFrameMapper1.extractMappingParent(inputRow) mustEqual Some("TEST_PARENT_1") 103 | outputSparkEsDataFrameMapper2.extractMappingParent(inputRow) mustEqual Some("TEST_VALUE") 104 | } 105 | 106 | "from Row execute extractMappingVersion() successfully" in { 107 | val inputRow = createInputRow(sc) 108 | val inputMapperConf1 = SparkEsMapperConf( 109 | esMappingId = None, 110 | esMappingParent = None, 111 | esMappingVersion = Some("version"), 112 | esMappingVersionType = None, 113 | esMappingRouting = None, 114 | esMappingTTLInMillis = None, 115 | esMappingTimestamp = None 116 | ) 117 | val inputMapperConf2 = SparkEsMapperConf( 118 | esMappingId = None, 119 | esMappingParent = None, 120 | esMappingVersion = Some("<1337>"), 121 | esMappingVersionType = None, 122 | esMappingRouting = None, 123 | esMappingTTLInMillis = None, 124 | esMappingTimestamp = None 125 | ) 126 | 127 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1) 128 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2) 129 | 130 | outputSparkEsDataFrameMapper1.extractMappingVersion(inputRow) mustEqual Some(1L) 131 | outputSparkEsDataFrameMapper2.extractMappingVersion(inputRow) mustEqual Some(1337L) 132 | } 133 | 134 | "from Row execute extractMappingVersionType() successfully" in { 135 | val inputRow = createInputRow(sc) 136 | val inputMapperConf1 = 
SparkEsMapperConf( 137 | esMappingId = None, 138 | esMappingParent = None, 139 | esMappingVersion = Some(""), 140 | esMappingVersionType = Some("force"), 141 | esMappingRouting = None, 142 | esMappingTTLInMillis = None, 143 | esMappingTimestamp = None 144 | ) 145 | val inputMapperConf2 = SparkEsMapperConf( 146 | esMappingId = None, 147 | esMappingParent = None, 148 | esMappingVersion = None, 149 | esMappingVersionType = Some("force"), 150 | esMappingRouting = None, 151 | esMappingTTLInMillis = None, 152 | esMappingTimestamp = None 153 | ) 154 | 155 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1) 156 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2) 157 | 158 | outputSparkEsDataFrameMapper1.extractMappingVersionType(inputRow) mustEqual Some(VersionType.FORCE) 159 | outputSparkEsDataFrameMapper2.extractMappingVersionType(inputRow) mustEqual None 160 | } 161 | 162 | "from Row execute extractMappingRouting() successfully" in { 163 | val inputRow = createInputRow(sc) 164 | val inputMapperConf1 = SparkEsMapperConf( 165 | esMappingId = None, 166 | esMappingParent = None, 167 | esMappingVersion = None, 168 | esMappingVersionType = None, 169 | esMappingRouting = Some("routing"), 170 | esMappingTTLInMillis = None, 171 | esMappingTimestamp = None 172 | ) 173 | val inputMapperConf2 = SparkEsMapperConf( 174 | esMappingId = None, 175 | esMappingParent = None, 176 | esMappingVersion = None, 177 | esMappingVersionType = None, 178 | esMappingRouting = Some(""), 179 | esMappingTTLInMillis = None, 180 | esMappingTimestamp = None 181 | ) 182 | 183 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1) 184 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2) 185 | 186 | outputSparkEsDataFrameMapper1.extractMappingRouting(inputRow) mustEqual Some("TEST_ROUTING_1") 187 | outputSparkEsDataFrameMapper2.extractMappingRouting(inputRow) mustEqual Some("TEST_VALUE") 188 | } 189 | 190 | "from Row execute extractMappingTTLInMillis() successfully" in { 191 | val inputRow = createInputRow(sc) 192 | val inputMapperConf1 = SparkEsMapperConf( 193 | esMappingId = None, 194 | esMappingParent = None, 195 | esMappingVersion = None, 196 | esMappingVersionType = None, 197 | esMappingRouting = None, 198 | esMappingTTLInMillis = Some("ttl"), 199 | esMappingTimestamp = None 200 | ) 201 | val inputMapperConf2 = SparkEsMapperConf( 202 | esMappingId = None, 203 | esMappingParent = None, 204 | esMappingVersion = None, 205 | esMappingVersionType = None, 206 | esMappingRouting = None, 207 | esMappingTTLInMillis = Some("<1337>"), 208 | esMappingTimestamp = None 209 | ) 210 | 211 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1) 212 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2) 213 | 214 | outputSparkEsDataFrameMapper1.extractMappingTTLInMillis(inputRow) mustEqual Some(86400000L) 215 | outputSparkEsDataFrameMapper2.extractMappingTTLInMillis(inputRow) mustEqual Some(1337L) 216 | } 217 | 218 | "from Row execute extractMappingTimestamp() successfully" in { 219 | val inputRow = createInputRow(sc) 220 | val inputMapperConf1 = SparkEsMapperConf( 221 | esMappingId = None, 222 | esMappingParent = None, 223 | esMappingVersion = None, 224 | esMappingVersionType = None, 225 | esMappingRouting = None, 226 | esMappingTTLInMillis = None, 227 | esMappingTimestamp = Some("timestamp") 228 | ) 229 | val inputMapperConf2 = SparkEsMapperConf( 230 | esMappingId = None, 
231 | esMappingParent = None, 232 | esMappingVersion = None, 233 | esMappingVersionType = None, 234 | esMappingRouting = None, 235 | esMappingTTLInMillis = None, 236 | esMappingTimestamp = Some("") 237 | ) 238 | 239 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1) 240 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2) 241 | 242 | outputSparkEsDataFrameMapper1.extractMappingTimestamp(inputRow) mustEqual Some("TEST_TIMESTAMP_1") 243 | outputSparkEsDataFrameMapper2.extractMappingTimestamp(inputRow) mustEqual Some("TEST_VALUE") 244 | } 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameSerializerSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.sql 2 | 3 | import java.sql.{ Date, Timestamp } 4 | 5 | import com.holdenkarau.spark.testing.SharedSparkContext 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.sql.{ Row, SQLContext } 8 | import org.elasticsearch.common.xcontent.XContentFactory 9 | import org.scalatest.{ MustMatchers, WordSpec } 10 | 11 | class SparkEsDataFrameSerializerSpec extends WordSpec with MustMatchers with SharedSparkContext { 12 | "SparkEsDataFrameSerializer" must { 13 | "execute writeStruct() successfully" in { 14 | val sqlContext = new SQLContext(sc) 15 | 16 | val inputSchema = StructType( 17 | Array( 18 | StructField("id", StringType, true), 19 | StructField("parent", StringType, true), 20 | StructField("version", LongType, true), 21 | StructField("routing", StringType, true), 22 | StructField("ttl", LongType, true), 23 | StructField("timestamp", StringType, true), 24 | StructField("value", LongType, true) 25 | ) 26 | ) 27 | val inputData = sc.parallelize { 28 | Array( 29 | Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L) 30 | ) 31 | } 32 | val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema) 33 | val inputRow = inputDataFrame.first() 34 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(inputSchema) 35 | 36 | val inputBuilder = XContentFactory.jsonBuilder() 37 | val outputBuilder = inputSparkEsDataFrameSerializer.writeStruct(inputSchema, inputRow, inputBuilder) 38 | outputBuilder.string() must include("""{"id":"TEST_ID_1","parent":"TEST_PARENT_1","version":1,"routing":"TEST_ROUTING_1","ttl":86400000,"timestamp":"TEST_TIMESTAMP_1","value":1}""") 39 | inputBuilder.close() 40 | } 41 | 42 | "execute writeArray() successfully" in { 43 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(null) 44 | 45 | val inputArray = Array(1, 2, 3) 46 | val inputBuilder = XContentFactory.jsonBuilder() 47 | val outputBuilder = inputSparkEsDataFrameSerializer.writeArray(ArrayType(IntegerType), inputArray, inputBuilder) 48 | outputBuilder.string() must include("""1,2,3""") 49 | inputBuilder.close() 50 | } 51 | 52 | "execute writeMap() successfully" in { 53 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(null) 54 | 55 | val inputMap = Map( 56 | "TEST_KEY_1" -> "TEST_VALUE_1", 57 | "TEST_KEY_2" -> "TEST_VALUE_3", 58 | "TEST_KEY_3" -> "TEST_VALUE_3" 59 | ) 60 | val inputBuilder = XContentFactory.jsonBuilder() 61 | val outputBuilder = inputSparkEsDataFrameSerializer.writeMap(MapType(StringType, StringType), inputMap, inputBuilder) 62 | outputBuilder.string() must 
include("""{"TEST_KEY_1":"TEST_VALUE_1","TEST_KEY_2":"TEST_VALUE_3","TEST_KEY_3":"TEST_VALUE_3"}""") 63 | inputBuilder.close() 64 | } 65 | 66 | "execute writePrimitive() successfully" in { 67 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(null) 68 | 69 | val inputBuilder1 = XContentFactory.jsonBuilder() 70 | val outputBuilder1 = inputSparkEsDataFrameSerializer.writePrimitive(BinaryType, Array[Byte](1), inputBuilder1) 71 | outputBuilder1.string() must include("AQ==") 72 | inputBuilder1.close() 73 | 74 | val inputBuilder2 = XContentFactory.jsonBuilder() 75 | val outputBuilder2 = inputSparkEsDataFrameSerializer.writePrimitive(BooleanType, true, inputBuilder2) 76 | outputBuilder2.string() mustEqual "true" 77 | inputBuilder2.close() 78 | 79 | val inputBuilder3 = XContentFactory.jsonBuilder() 80 | val outputBuilder3 = inputSparkEsDataFrameSerializer.writePrimitive(ByteType, 1.toByte, inputBuilder3) 81 | outputBuilder3.string() mustEqual "1" 82 | inputBuilder3.close() 83 | 84 | val inputBuilder4 = XContentFactory.jsonBuilder() 85 | val outputBuilder4 = inputSparkEsDataFrameSerializer.writePrimitive(ShortType, 1.toShort, inputBuilder4) 86 | outputBuilder4.string() mustEqual "1" 87 | inputBuilder4.close() 88 | 89 | val inputBuilder5 = XContentFactory.jsonBuilder() 90 | val outputBuilder5 = inputSparkEsDataFrameSerializer.writePrimitive(IntegerType, 1.toInt, inputBuilder5) 91 | outputBuilder5.string() mustEqual "1" 92 | inputBuilder5.close() 93 | 94 | val inputBuilder6 = XContentFactory.jsonBuilder() 95 | val outputBuilder6 = inputSparkEsDataFrameSerializer.writePrimitive(LongType, 1.toLong, inputBuilder6) 96 | outputBuilder6.string() mustEqual "1" 97 | inputBuilder6.close() 98 | 99 | val inputBuilder7 = XContentFactory.jsonBuilder() 100 | val outputBuilder7 = inputSparkEsDataFrameSerializer.writePrimitive(DoubleType, 1.0, inputBuilder7) 101 | outputBuilder7.string() mustEqual "1.0" 102 | inputBuilder7.close() 103 | 104 | val inputBuilder8 = XContentFactory.jsonBuilder() 105 | val outputBuilder8 = inputSparkEsDataFrameSerializer.writePrimitive(FloatType, 1.0F, inputBuilder8) 106 | outputBuilder8.string() mustEqual "1.0" 107 | inputBuilder8.close() 108 | 109 | val inputBuilder9 = XContentFactory.jsonBuilder() 110 | val outputBuilder9 = inputSparkEsDataFrameSerializer.writePrimitive(TimestampType, new Timestamp(834120000000L), inputBuilder9) 111 | outputBuilder9.string() must include("834120000000") 112 | inputBuilder9.close() 113 | 114 | val inputBuilder10 = XContentFactory.jsonBuilder() 115 | val outputBuilder10 = inputSparkEsDataFrameSerializer.writePrimitive(DateType, new Date(834120000000L), inputBuilder10) 116 | outputBuilder10.string() must include("834120000000") 117 | inputBuilder10.close() 118 | 119 | val inputBuilder11 = XContentFactory.jsonBuilder() 120 | val outputBuilder11 = inputSparkEsDataFrameSerializer.writePrimitive(StringType, "TEST_VALUE", inputBuilder11) 121 | outputBuilder11.string() must include("TEST_VALUE") 122 | inputBuilder11.close() 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientManagerSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.transport 2 | 3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf 4 | import org.scalatest.{ MustMatchers, WordSpec } 5 | 6 | class 
SparkEsTransportClientManagerSpec extends WordSpec with MustMatchers { 7 | "SparkEsTransportClientManager" must { 8 | "maintain one unique TransportClient in internalTransportClients" in { 9 | val inputClientConf = SparkEsTransportClientConf( 10 | transportAddresses = Seq("127.0.0.1"), 11 | transportPort = 9300, 12 | transportSettings = Map.empty[String, String] 13 | ) 14 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {} 15 | 16 | inputSparkEsTransportClientManager.getTransportClient(inputClientConf) 17 | inputSparkEsTransportClientManager.getTransportClient(inputClientConf) 18 | 19 | inputSparkEsTransportClientManager.internalTransportClients.size mustEqual 1 20 | 21 | inputSparkEsTransportClientManager.closeTransportClient(inputClientConf) 22 | } 23 | 24 | "return a SparkEsTransportClientProxy when calling getTransportClient()" in { 25 | val inputClientConf = SparkEsTransportClientConf( 26 | transportAddresses = Seq("127.0.0.1"), 27 | transportPort = 9300, 28 | transportSettings = Map.empty[String, String] 29 | ) 30 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {} 31 | 32 | val outputClient = inputSparkEsTransportClientManager.getTransportClient(inputClientConf) 33 | 34 | outputClient.getClass mustEqual classOf[SparkEsTransportClientProxy] 35 | 36 | inputSparkEsTransportClientManager.closeTransportClient(inputClientConf) 37 | } 38 | 39 | "evict TransportClient after calling closeTransportClient" in { 40 | val inputClientConf = SparkEsTransportClientConf( 41 | transportAddresses = Seq("127.0.0.1"), 42 | transportPort = 9300, 43 | transportSettings = Map.empty[String, String] 44 | ) 45 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {} 46 | 47 | inputSparkEsTransportClientManager.getTransportClient(inputClientConf) 48 | 49 | inputSparkEsTransportClientManager.closeTransportClient(inputClientConf) 50 | 51 | inputSparkEsTransportClientManager.internalTransportClients.size mustEqual 0 52 | } 53 | 54 | "returns buildTransportSettings() successfully" in { 55 | val inputClientConf = SparkEsTransportClientConf( 56 | transportAddresses = Seq("127.0.0.1"), 57 | transportPort = 9300, 58 | transportSettings = Map( 59 | "TEST_KEY_1" -> "TEST_VALUE_1", 60 | "TEST_KEY_2" -> "TEST_VALUE_2", 61 | "TEST_KEY_3" -> "TEST_VALUE_3" 62 | ) 63 | ) 64 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {} 65 | 66 | val outputSettings = inputSparkEsTransportClientManager.buildTransportSettings(inputClientConf) 67 | 68 | outputSettings.get("TEST_KEY_1") mustEqual "TEST_VALUE_1" 69 | outputSettings.get("TEST_KEY_2") mustEqual "TEST_VALUE_2" 70 | outputSettings.get("TEST_KEY_3") mustEqual "TEST_VALUE_3" 71 | } 72 | 73 | "returns buildTransportClient() successfully" in { 74 | val inputClientConf = SparkEsTransportClientConf( 75 | transportAddresses = Seq("127.0.0.1"), 76 | transportPort = 9300, 77 | transportSettings = Map.empty[String, String] 78 | ) 79 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {} 80 | 81 | val outputSettings = inputSparkEsTransportClientManager.buildTransportSettings(inputClientConf) 82 | val outputClient = inputSparkEsTransportClientManager.buildTransportClient(inputClientConf, outputSettings) 83 | val outputHost = outputClient.transportAddresses().get(0).getHost 84 | val outputPort = outputClient.transportAddresses().get(0).getPort 85 | 86 | outputHost mustEqual "127.0.0.1" 87 | outputPort mustEqual 9300 88 | 89 | outputClient.close() 90 | } 91 | } 
92 | } 93 | -------------------------------------------------------------------------------- /src/test/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientProxySpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.jparkie.spark.elasticsearch.transport 2 | 3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf 4 | import com.github.jparkie.spark.elasticsearch.util.SparkEsException 5 | import org.scalatest.{ MustMatchers, WordSpec } 6 | 7 | class SparkEsTransportClientProxySpec extends WordSpec with MustMatchers { 8 | "SparkEsTransportClientProxy" must { 9 | "prohibit close() call" in { 10 | val inputClientConf = SparkEsTransportClientConf( 11 | transportAddresses = Seq("127.0.0.1"), 12 | transportPort = 9300, 13 | transportSettings = Map.empty[String, String] 14 | ) 15 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {} 16 | val inputSparkEsTransportClient = inputSparkEsTransportClientManager.getTransportClient(inputClientConf) 17 | val inputSparkEsTransportClientProxy = new SparkEsTransportClientProxy(inputSparkEsTransportClient) 18 | 19 | val outputException = intercept[SparkEsException] { 20 | inputSparkEsTransportClientProxy.close() 21 | } 22 | 23 | outputException.getMessage must include("close() is not supported in SparkEsTransportClientProxy. Please close with SparkEsTransportClientManager.") 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /version.sbt: -------------------------------------------------------------------------------- 1 | version in ThisBuild := "2.0.0-SNAPSHOT" --------------------------------------------------------------------------------