├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── build.sbt
├── project
│   ├── build.properties
│   └── plugins.sbt
├── src
│   ├── main
│   │   └── scala
│   │       └── com
│   │           └── github
│   │               └── jparkie
│   │                   └── spark
│   │                       └── elasticsearch
│   │                           ├── SparkEsBulkWriter.scala
│   │                           ├── SparkEsMapper.scala
│   │                           ├── SparkEsSerializer.scala
│   │                           ├── conf
│   │                           │   ├── SparkEsMapperConf.scala
│   │                           │   ├── SparkEsTransportClientConf.scala
│   │                           │   └── SparkEsWriteConf.scala
│   │                           ├── sql
│   │                           │   ├── SparkEsDataFrameFunctions.scala
│   │                           │   ├── SparkEsDataFrameMapper.scala
│   │                           │   ├── SparkEsDataFrameSerializer.scala
│   │                           │   └── package.scala
│   │                           ├── transport
│   │                           │   ├── SparkEsTransportClientManager.scala
│   │                           │   └── SparkEsTransportClientProxy.scala
│   │                           └── util
│   │                               ├── SparkEsConfParam.scala
│   │                               └── SparkEsException.scala
│   └── test
│       └── scala
│           └── com
│               └── github
│                   └── jparkie
│                       └── spark
│                           └── elasticsearch
│                               ├── ElasticSearchServer.scala
│                               ├── SparkEsBulkWriterSpec.scala
│                               ├── conf
│                               │   ├── SparkEsMapperConfSpec.scala
│                               │   ├── SparkEsTransportClientConfSpec.scala
│                               │   └── SparkEsWriteConfSpec.scala
│                               ├── sql
│                               │   ├── PackageSpec.scala
│                               │   ├── SparkEsDataFrameMapperSpec.scala
│                               │   └── SparkEsDataFrameSerializerSpec.scala
│                               └── transport
│                                   ├── SparkEsTransportClientManagerSpec.scala
│                                   └── SparkEsTransportClientProxySpec.scala
└── version.sbt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/intellij,scala
2 |
3 | ### Intellij ###
4 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
5 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
6 |
7 | # Override All:
8 | .idea/
9 |
10 | # User-specific stuff:
11 | .idea/workspace.xml
12 | .idea/tasks.xml
13 | .idea/dictionaries
14 | .idea/vcs.xml
15 | .idea/jsLibraryMappings.xml
16 |
17 | # Sensitive or high-churn files:
18 | .idea/dataSources.ids
19 | .idea/dataSources.xml
20 | .idea/sqlDataSources.xml
21 | .idea/dynamic.xml
22 | .idea/uiDesigner.xml
23 |
24 | # Gradle:
25 | .idea/gradle.xml
26 | .idea/libraries
27 |
28 | # Mongo Explorer plugin:
29 | .idea/mongoSettings.xml
30 |
31 | ## File-based project format:
32 | *.iws
33 |
34 | ## Plugin-specific files:
35 |
36 | # IntelliJ
37 | /out/
38 |
39 | # mpeltonen/sbt-idea plugin
40 | .idea_modules/
41 |
42 | # JIRA plugin
43 | atlassian-ide-plugin.xml
44 |
45 | # Crashlytics plugin (for Android Studio and IntelliJ)
46 | com_crashlytics_export_strings.xml
47 | crashlytics.properties
48 | crashlytics-build.properties
49 | fabric.properties
50 |
51 |
52 | ### Scala ###
53 | *.class
54 | *.log
55 |
56 | # sbt specific
57 | .cache
58 | .history
59 | .lib/
60 | dist/*
61 | target/
62 | lib_managed/
63 | src_managed/
64 | project/boot/
65 | project/plugins/project/
66 |
67 | # Scala-IDE specific
68 | .scala_dependencies
69 | .worksheet
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 |
3 | scala:
4 | - 2.10.6
5 | - 2.11.7
6 |
7 | jdk:
8 | - oraclejdk7
9 | - oraclejdk8
10 |
11 | sudo: false
12 |
13 | cache:
14 | directories:
15 | - $HOME/.ivy2
16 |
17 | script:
18 | - sbt clean coverage test
19 |
20 | after_success:
21 | - bash <(curl -s https://codecov.io/bash)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spark2Elasticsearch
2 |
3 | Spark Library for Bulk Loading into Elasticsearch
4 |
5 | [![Build Status](https://travis-ci.org/jparkie/Spark2Elasticsearch.svg?branch=master)](https://travis-ci.org/jparkie/Spark2Elasticsearch)
6 |
7 | ## Requirements
8 |
9 | Spark2Elasticsearch supports Spark 1.4 and above.
10 |
11 | | Spark2Elasticsearch Version | Elasticsearch Version |
12 | | --------------------------- | --------------------- |
13 | | `2.0.X` | `2.0.X` |
14 | | `2.1.X` | `2.1.X` |
15 |
16 | ## Downloads
17 |
18 | #### SBT
19 | ```scala
20 | libraryDependencies += "com.github.jparkie" %% "spark2elasticsearch" % "2.0.2"
21 | ```
22 |
23 | Or:
24 |
25 | ```scala
26 | libraryDependencies += "com.github.jparkie" %% "spark2elasticsearch" % "2.1.2"
27 | ```
28 |
29 | Add the following resolvers if needed:
30 |
31 | ```scala
32 | resolvers += "Sonatype OSS Releases" at "https://oss.sonatype.org/content/repositories/releases"
33 | resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots"
34 | ```
35 |
36 | #### Maven
37 | ```xml
38 | <dependency>
39 |   <groupId>com.github.jparkie</groupId>
40 |   <artifactId>spark2elasticsearch_2.10</artifactId>
41 |   <version>x.y.z-SNAPSHOT</version>
42 | </dependency>
43 | ```
44 |
45 | It is planned for Spark2Elasticsearch to be available on the following:
46 | - http://spark-packages.org/
47 |
48 | ## Features
49 | - Utilizes Elasticsearch Java API with a `TransportClient` to bulk load data from a `DataFrame` into Elasticsearch.
50 |
51 | ## Usage
52 |
53 | ### Bulk Loading into Elasticsearch
54 |
55 | ```scala
56 | // Import the following to have access to the `bulkLoadToEs()` function.
57 | import com.github.jparkie.spark.elasticsearch.sql._
58 |
59 | val sparkConf = new SparkConf()
60 | val sc = SparkContext.getOrCreate(sparkConf)
61 | val sqlContext = SQLContext.getOrCreate(sc)
62 |
63 | val df = sqlContext.read.parquet("")
64 |
65 | // Specify the `index` and the `type` to write.
66 | df.bulkLoadToEs(
67 | esIndex = "twitter",
68 | esType = "tweets"
69 | )
70 | ```
71 |
72 | For more details, refer to [SparkEsDataFrameFunctions.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameFunctions.scala).
73 |
74 | ## Configurations
75 |
76 | When adding configurations through `spark-submit`, prefix property names with `spark.`.
77 |
78 | ### SparkEsMapperConf
79 |
80 | For more details, refer to [SparkEsMapperConf.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsMapperConf.scala).
81 |
82 | | Property Name | Default | Description |
83 | | ------------------------- |:-------:| ------------|
84 | | `es.mapping.id` | None | The document field/property name containing the document id. |
85 | | `es.mapping.parent` | None | The document field/property name containing the document parent. To specify a constant, use the `<CONSTANT>` format. |
86 | | `es.mapping.version` | None | The document field/property name containing the document version. To specify a constant, use the `<CONSTANT>` format. |
87 | | `es.mapping.version.type` | None | Indicates the type of versioning used. See http://www.elastic.co/guide/en/elasticsearch/reference/2.0/docs-index_.html#_version_types. If `es.mapping.version` is undefined (default), the version type is left unspecified; if `es.mapping.version` is specified, the version type becomes `external`. |
88 | | `es.mapping.routing` | None | The document field/property name containing the document routing. To specify a constant, use the `<CONSTANT>` format. |
89 | | `es.mapping.ttl` | None | The document field/property name containing the document time-to-live. To specify a constant, use the `<CONSTANT>` format. |
90 | | `es.mapping.timestamp` | None | The document field/property name containing the document timestamp. To specify a constant, use the `<CONSTANT>` format. |
91 |
92 | ### SparkEsTransportClientConf
93 |
94 | For more details, refer to [SparkEsTransportClientConf.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsTransportClientConf.scala).
95 |
96 | | Property Name | Default | Description |
97 | | -------------------------------------------- |:----------:| ------------|
98 | | `es.nodes` | *Required* | The minimum set of hosts to connect to when establishing a client. Comma-separated list of `host` or `host:port` entries. |
99 | | `es.port` | 9300 | The port to connect to when establishing a client. |
100 | | `es.cluster.name` | None | The name of the Elasticsearch cluster to connect to. |
101 | | `es.client.transport.sniff` | None | If set to true, discovers and connects to other nodes of the cluster. |
102 | | `es.client.transport.ignore_cluster_name` | None | Set to true to ignore cluster name validation of connected nodes. |
103 | | `es.client.transport.ping_timeout` | 5s | The time to wait for a ping response from a node. |
104 | | `es.client.transport.nodes_sampler_interval` | 5s | How often to sample / ping the nodes listed and connected. |
105 |
106 | ### SparkEsWriteConf
107 |
108 | For more details, refer to [SparkEsWriteConf.scala](https://github.com/jparkie/Spark2Elasticsearch/blob/master/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsWriteConf.scala).
109 |
110 | | Property Name | Default | Description |
111 | | ----------------------------- |:-------:| ------------|
112 | | `es.batch.size.entries` | 1000 | The number of IndexRequests to batch in one request. |
113 | | `es.batch.size.bytes` | 5 | The maximum size in MB of a batch. |
114 | | `es.batch.concurrent.request` | 1 | The number of concurrent requests in flight. |
115 | | `es.batch.flush.timeout` | 10 | The maximum time in seconds to wait while closing a BulkProcessor. |
116 |
117 | ## Documentation
118 |
119 | Scaladocs are currently unavailable.
120 |
--------------------------------------------------------------------------------
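The Configurations section above lists the property names read by the `*Conf.fromSparkConf` helpers shown later in this dump. Below is a minimal sketch of supplying them programmatically before calling `bulkLoadToEs()`; the host, cluster name, index/type, id field, and parquet path are placeholders, and when passed via `spark-submit --conf` each key would instead carry the `spark.` prefix noted in the README.

```scala
import com.github.jparkie.spark.elasticsearch.sql._
import org.apache.spark.sql.SQLContext
import org.apache.spark.{ SparkConf, SparkContext }

// Bare keys as read by SparkEsTransportClientConf, SparkEsMapperConf, and SparkEsWriteConf.
val sparkConf = new SparkConf()
  .set("es.nodes", "127.0.0.1")            // placeholder host
  .set("es.cluster.name", "elasticsearch") // placeholder cluster name
  .set("es.mapping.id", "id")              // placeholder document id field
  .set("es.batch.size.entries", "1000")

val sc = SparkContext.getOrCreate(sparkConf)
val sqlContext = SQLContext.getOrCreate(sc)

val df = sqlContext.read.parquet("/tmp/tweets.parquet") // placeholder path

df.bulkLoadToEs(esIndex = "twitter", esType = "tweets")
```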
/build.sbt:
--------------------------------------------------------------------------------
1 | import com.typesafe.sbt.SbtScalariform
2 | import com.typesafe.sbt.SbtScalariform._
3 | import scalariform.formatter.preferences._
4 |
5 | /**
6 | * Organization:
7 | */
8 | organization := "com.github.jparkie"
9 | organizationName := "jparkie"
10 |
11 | /**
12 | * Library Meta:
13 | */
14 | name := "Spark2Elasticsearch"
15 | licenses := Seq(("Apache License, Version 2.0", url("http://www.apache.org/licenses/LICENSE-2.0")))
16 |
17 | /**
18 | * Scala:
19 | */
20 | scalaVersion := "2.10.6"
21 | crossScalaVersions := Seq("2.10.6", "2.11.7")
22 |
23 | /**
24 | * Library Dependencies:
25 | */
26 |
27 | // Exclusion Rules:
28 | val guavaRule = ExclusionRule("com.google.guava", "guava")
29 | val sparkNetworkCommonRule = ExclusionRule("org.apache.spark", "spark-network-common")
30 |
31 | // Versions:
32 | val SparkVersion = "1.4.1"
33 | val SparkTestVersion = "1.4.1_0.3.0"
34 | val ScalaTestVersion = "2.2.4"
35 | val ElasticsearchVersion = "2.0.2"
36 | val Slf4jVersion = "1.7.10"
37 |
38 | // Dependencies:
39 | val sparkCore = "org.apache.spark" %% "spark-core" % SparkVersion % "provided" excludeAll(sparkNetworkCommonRule, guavaRule)
40 | val sparkSql = "org.apache.spark" %% "spark-sql" % SparkVersion % "provided" excludeAll(sparkNetworkCommonRule, guavaRule)
41 | val sparkTest = "com.holdenkarau" %% "spark-testing-base" % SparkTestVersion % "test"
42 | val scalaTest = "org.scalatest" %% "scalatest" % ScalaTestVersion % "test"
43 | val elasticsearch = "org.elasticsearch" % "elasticsearch" % ElasticsearchVersion
44 | val slf4j = "org.slf4j" % "slf4j-api" % Slf4jVersion
45 |
46 | libraryDependencies ++= Seq(sparkCore, sparkSql, sparkTest, scalaTest, elasticsearch, slf4j)
47 |
48 | /**
49 | * Tests:
50 | */
51 | parallelExecution in Test := false
52 |
53 | /**
54 | * Scalariform:
55 | */
56 | SbtScalariform.scalariformSettings
57 | ScalariformKeys.preferences := FormattingPreferences()
58 | .setPreference(RewriteArrowSymbols, false)
59 | .setPreference(AlignParameters, true)
60 | .setPreference(AlignSingleLineCaseStatements, true)
61 | .setPreference(SpacesAroundMultiImports, true)
62 |
63 | /**
64 | * Scoverage:
65 | */
66 | coverageEnabled in Test := true
67 |
68 | /**
69 | * Publishing to Sonatype:
70 | */
71 | publishMavenStyle := true
72 |
73 | publishArtifact in Test := false
74 |
75 | publishTo := {
76 | val nexus = "https://oss.sonatype.org/"
77 | if (isSnapshot.value)
78 | Some("snapshots" at nexus + "content/repositories/snapshots")
79 | else
80 | Some("releases" at nexus + "service/local/staging/deploy/maven2")
81 | }
82 |
83 | pomExtra := {
84 |   <url>https://github.com/jparkie/Spark2Elasticsearch</url>
85 |   <scm>
86 |     <url>git@github.com:jparkie/Spark2Elasticsearch.git</url>
87 |     <connection>scm:git:git@github.com:jparkie/Spark2Elasticsearch.git</connection>
88 |   </scm>
89 |   <developers>
90 |     <developer>
91 |       <id>jparkie</id>
92 |       <name>Jacob Park</name>
93 |       <url>https://github.com/jparkie</url>
94 |     </developer>
95 |   </developers>
96 | }
97 |
98 | /**
99 | * Release:
100 | */
101 | import ReleaseTransformations._
102 |
103 | releasePublishArtifactsAction := PgpKeys.publishSigned.value
104 |
105 | releaseProcess := Seq[ReleaseStep](
106 | checkSnapshotDependencies,
107 | inquireVersions,
108 | runTest,
109 | setReleaseVersion,
110 | commitReleaseVersion,
111 | tagRelease,
112 | publishArtifacts,
113 | setNextVersion,
114 | commitNextVersion,
115 | pushChanges
116 | )
117 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 0.13.8
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"
2 | resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
3 | resolvers += "Sonatype OSS Releases" at "https://oss.sonatype.org/service/local/staging/deploy/maven2"
4 |
5 | addSbtPlugin("org.scalariform" % "sbt-scalariform" % "1.6.0")
6 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
7 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.2")
8 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.3.5")
--------------------------------------------------------------------------------
/src/main/scala/com/github/jparkie/spark/elasticsearch/SparkEsBulkWriter.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch
2 |
3 | import java.util.concurrent.TimeUnit
4 |
5 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsWriteConf
6 | import com.github.jparkie.spark.elasticsearch.util.SparkEsException
7 | import org.apache.spark.{ Logging, TaskContext }
8 | import org.elasticsearch.action.bulk.{ BulkProcessor, BulkRequest, BulkResponse }
9 | import org.elasticsearch.action.index.IndexRequest
10 | import org.elasticsearch.action.update.UpdateRequest
11 | import org.elasticsearch.client.Client
12 | import org.elasticsearch.common.unit.{ ByteSizeUnit, ByteSizeValue }
13 |
14 | class SparkEsBulkWriter[T](
15 | esIndex: String,
16 | esType: String,
17 | esClient: () => Client,
18 | sparkEsSerializer: SparkEsSerializer[T],
19 | sparkEsMapper: SparkEsMapper[T],
20 | sparkEsWriteConf: SparkEsWriteConf
21 | ) extends Serializable with Logging {
22 | /**
23 | * Logs the executionId, number of requests, size, and latency of flushes.
24 | */
25 | class SparkEsBulkProcessorListener() extends BulkProcessor.Listener {
26 | override def beforeBulk(executionId: Long, request: BulkRequest): Unit = {
27 | logInfo(s"For executionId ($executionId), executing ${request.numberOfActions()} actions of estimate size ${request.estimatedSizeInBytes()} in bytes.")
28 | }
29 |
30 | override def afterBulk(executionId: Long, request: BulkRequest, response: BulkResponse): Unit = {
31 | logInfo(s"For executionId ($executionId), executed ${request.numberOfActions()} in ${response.getTookInMillis} milliseconds.")
32 |
33 | if (response.hasFailures) {
34 | throw new SparkEsException(response.buildFailureMessage())
35 | }
36 | }
37 |
38 | override def afterBulk(executionId: Long, request: BulkRequest, failure: Throwable): Unit = {
39 | logError(s"For executionId ($executionId), BulkRequest failed.", failure)
40 |
41 | throw new SparkEsException(failure.getMessage, failure)
42 | }
43 | }
44 |
45 | private[elasticsearch] def logDuration(closure: () => Unit): Unit = {
46 | val localStartTime = System.nanoTime()
47 |
48 | closure()
49 |
50 | val localEndTime = System.nanoTime()
51 |
52 | val differenceTime = localEndTime - localStartTime
53 | logInfo(s"Elasticsearch Task completed in ${TimeUnit.MILLISECONDS.convert(differenceTime, TimeUnit.NANOSECONDS)} milliseconds.")
54 | }
55 |
56 | private[elasticsearch] def createBulkProcessor(): BulkProcessor = {
57 | val esBulkProcessorListener = new SparkEsBulkProcessorListener()
58 | val esBulkProcessor = BulkProcessor.builder(esClient(), esBulkProcessorListener)
59 | .setBulkActions(sparkEsWriteConf.bulkActions)
60 | .setBulkSize(new ByteSizeValue(sparkEsWriteConf.bulkSizeInMB, ByteSizeUnit.MB))
61 | .setConcurrentRequests(sparkEsWriteConf.concurrentRequests)
62 | .build()
63 |
64 | esBulkProcessor
65 | }
66 |
67 | private[elasticsearch] def closeBulkProcessor(bulkProcessor: BulkProcessor): Unit = {
68 | val isClosed = bulkProcessor.awaitClose(sparkEsWriteConf.flushTimeoutInSeconds, TimeUnit.SECONDS)
69 | if (isClosed) {
70 | logInfo("Closed Elasticsearch Bulk Processor.")
71 | } else {
72 | logError("Elasticsearch Bulk Processor failed to close.")
73 | }
74 | }
75 |
76 | private[elasticsearch] def applyMappings(currentRow: T, indexRequest: IndexRequest): Unit = {
77 | sparkEsMapper.extractMappingId(currentRow).foreach(indexRequest.id)
78 | sparkEsMapper.extractMappingParent(currentRow).foreach(indexRequest.parent)
79 | sparkEsMapper.extractMappingVersion(currentRow).foreach(indexRequest.version)
80 | sparkEsMapper.extractMappingVersionType(currentRow).foreach(indexRequest.versionType)
81 | sparkEsMapper.extractMappingRouting(currentRow).foreach(indexRequest.routing)
82 | sparkEsMapper.extractMappingTTLInMillis(currentRow).foreach(indexRequest.ttl(_))
83 | sparkEsMapper.extractMappingTimestamp(currentRow).foreach(indexRequest.timestamp)
84 | }
85 |
86 | /**
87 | * Upserts T to Elasticsearch by establishing a TransportClient and BulkProcessor.
88 | *
89 | * @param taskContext The TaskContext provided by the Spark DAGScheduler.
90 | * @param data The set of T to persist.
91 | */
92 | def write(taskContext: TaskContext, data: Iterator[T]): Unit = logDuration { () =>
93 | val esBulkProcessor = createBulkProcessor()
94 |
95 | for (currentRow <- data) {
96 | val currentIndexRequest = new IndexRequest(esIndex, esType)
97 | .source(sparkEsSerializer.write(currentRow))
98 |
99 | applyMappings(currentRow, currentIndexRequest)
100 |
101 | val currentId = currentIndexRequest.id()
102 | val currentParent = currentIndexRequest.parent()
103 | val currentVersion = currentIndexRequest.version()
104 | val currentVersionType = currentIndexRequest.versionType()
105 | val currentRouting = currentIndexRequest.routing()
106 |
107 | val currentUpsertRequest = new UpdateRequest(esIndex, esType, currentId)
108 | .parent(currentParent)
109 | .version(currentVersion)
110 | .versionType(currentVersionType)
111 | .routing(currentRouting)
112 | .doc(currentIndexRequest)
113 | .docAsUpsert(true)
114 |
115 | esBulkProcessor.add(currentUpsertRequest)
116 | }
117 |
118 | closeBulkProcessor(esBulkProcessor)
119 | }
120 | }
--------------------------------------------------------------------------------
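For context, `SparkEsBulkWriter.write` is designed to be handed to `SparkContext.runJob`, once per partition iterator. The sketch below mirrors the wiring that `SparkEsDataFrameFunctions.bulkLoadToEs` (shown later in this dump) performs for `Row`s; `df` is an assumed, pre-existing DataFrame and the index/type names are placeholders.

```scala
import com.github.jparkie.spark.elasticsearch.SparkEsBulkWriter
import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsTransportClientConf, SparkEsWriteConf }
import com.github.jparkie.spark.elasticsearch.sql.{ SparkEsDataFrameMapper, SparkEsDataFrameSerializer }
import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager
import org.apache.spark.sql.Row

// df: an existing DataFrame (assumed).
val sparkConf = df.sqlContext.sparkContext.getConf

val sparkEsWriter = new SparkEsBulkWriter[Row](
  esIndex = "twitter",
  esType = "tweets",
  esClient = () => SparkEsTransportClientManager
    .getTransportClient(SparkEsTransportClientConf.fromSparkConf(sparkConf)),
  sparkEsSerializer = new SparkEsDataFrameSerializer(df.schema),
  sparkEsMapper = new SparkEsDataFrameMapper(SparkEsMapperConf.fromSparkConf(sparkConf)),
  sparkEsWriteConf = SparkEsWriteConf.fromSparkConf(sparkConf)
)

// Each executor builds a BulkProcessor, upserts its partition, and flushes on close.
df.sqlContext.sparkContext.runJob(df.rdd, sparkEsWriter.write _)
```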
/src/main/scala/com/github/jparkie/spark/elasticsearch/SparkEsMapper.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch
2 |
3 | import org.elasticsearch.index.VersionType
4 |
5 | /**
6 | * Extracts mappings from a T for an IndexRequest.
7 | *
8 | * @tparam T Object to extract mappings.
9 | */
10 | trait SparkEsMapper[T] extends Serializable {
11 | /**
12 | * Extracts the document field/property name containing the document id.
13 | *
14 | * @param value Object to extract mappings.
15 | * @return The document field/property name containing the document id.
16 | */
17 | def extractMappingId(value: T): Option[String]
18 |
19 | /**
20 | * Extracts the document field/property name containing the document parent.
21 | *
22 | * @param value Object to extract mappings.
23 | * @return The document field/property name containing the document parent.
24 | */
25 | def extractMappingParent(value: T): Option[String]
26 |
27 | /**
28 | * Extracts the document field/property name containing the document version.
29 | *
30 | * @param value Object to extract mappings.
31 | * @return The document field/property name containing the document version.
32 | */
33 | def extractMappingVersion(value: T): Option[Long]
34 |
35 | /**
36 | * Extracts the type of versioning used.
37 | *
38 | * @param value Object to extract mappings.
39 | * @return The type of versioning used.
40 | */
41 | def extractMappingVersionType(value: T): Option[VersionType]
42 |
43 | /**
44 | * Extracts the document field/property name containing the document routing.
45 | *
46 | * @param value Object to extract mappings.
47 | * @return The document field/property name containing the document routing.
48 | */
49 | def extractMappingRouting(value: T): Option[String]
50 |
51 | /**
52 | * Extracts the document field/property name containing the document time-to-live.
53 | *
54 | * @param value Object to extract mappings.
55 | * @return The document field/property name containing the document time-to-live.
56 | */
57 | def extractMappingTTLInMillis(value: T): Option[Long]
58 |
59 | /**
60 | * Extracts the document field/property name containing the document timestamp.
61 | *
62 | * @param value Object to extract mappings.
63 | * @return The document field/property name containing the document timestamp.
64 | */
65 | def extractMappingTimestamp(value: T): Option[String]
66 | }
67 |
--------------------------------------------------------------------------------
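A minimal sketch of implementing `SparkEsMapper` for a hypothetical `Tweet` case class (not part of the library), mapping only the document id and leaving the remaining extractors empty:

```scala
import com.github.jparkie.spark.elasticsearch.SparkEsMapper
import org.elasticsearch.index.VersionType

// Hypothetical domain class for illustration only.
case class Tweet(id: String, user: String, message: String)

class TweetMapper extends SparkEsMapper[Tweet] {
  // Use the tweet id as the Elasticsearch document id.
  override def extractMappingId(value: Tweet): Option[String] = Some(value.id)

  // No parent, version, routing, TTL, or timestamp mappings in this sketch.
  override def extractMappingParent(value: Tweet): Option[String] = None
  override def extractMappingVersion(value: Tweet): Option[Long] = None
  override def extractMappingVersionType(value: Tweet): Option[VersionType] = None
  override def extractMappingRouting(value: Tweet): Option[String] = None
  override def extractMappingTTLInMillis(value: Tweet): Option[Long] = None
  override def extractMappingTimestamp(value: Tweet): Option[String] = None
}
```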
/src/main/scala/com/github/jparkie/spark/elasticsearch/SparkEsSerializer.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch
2 |
3 | /**
4 | * Serializes a T into an Array[Byte] for an IndexRequest.
5 | *
6 | * @tparam T T Object to serialize.
7 | */
8 | trait SparkEsSerializer[T] extends Serializable {
9 | /**
10 | * Serialize a T from a DataFrame into an Array[Byte].
11 | *
12 | * @param value A T
13 | * @return The source T as Array[Byte].
14 | */
15 | def write(value: T): Array[Byte]
16 | }
17 |
--------------------------------------------------------------------------------
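A matching sketch of `SparkEsSerializer` for the hypothetical `Tweet` above, building the document source with the same `XContentBuilder` API that `SparkEsDataFrameSerializer` uses later in this dump:

```scala
import com.github.jparkie.spark.elasticsearch.SparkEsSerializer
import org.elasticsearch.common.xcontent.XContentFactory

// The hypothetical Tweet from the previous sketch.
case class Tweet(id: String, user: String, message: String)

class TweetSerializer extends SparkEsSerializer[Tweet] {
  override def write(value: Tweet): Array[Byte] = {
    // Build { "user": ..., "message": ... } as the document source bytes.
    XContentFactory.jsonBuilder()
      .startObject()
      .field("user", value.user)
      .field("message", value.message)
      .endObject()
      .bytes()
      .toBytes
  }
}
```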
/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsMapperConf.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.conf
2 |
3 | import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam
4 | import org.apache.spark.SparkConf
5 |
6 | /**
7 | * https://www.elastic.co/guide/en/elasticsearch/hadoop/current/configuration.html#cfg-mapping
8 | *
9 | * @param esMappingId The document field/property name containing the document id.
10 | * @param esMappingParent The document field/property name containing the document parent.
11 | *                        To specify a constant, use the <CONSTANT> format.
12 | * @param esMappingVersion The document field/property name containing the document version.
13 | *                         To specify a constant, use the <CONSTANT> format.
14 | * @param esMappingVersionType Indicates the type of versioning used.
15 | * http://www.elastic.co/guide/en/elasticsearch/reference/2.0/docs-index_.html#_version_types
16 | * If es.mapping.version is undefined (default), its value is unspecified.
17 | * If es.mapping.version is specified, its value becomes external.
18 | * @param esMappingRouting The document field/property name containing the document routing.
19 | *                         To specify a constant, use the <CONSTANT> format.
20 | * @param esMappingTTLInMillis The document field/property name containing the document time-to-live.
21 | *                             To specify a constant, use the <CONSTANT> format.
22 | * @param esMappingTimestamp The document field/property name containing the document timestamp.
23 | *                           To specify a constant, use the <CONSTANT> format.
24 | */
25 |
26 | case class SparkEsMapperConf(
27 | esMappingId: Option[String],
28 | esMappingParent: Option[String],
29 | esMappingVersion: Option[String],
30 | esMappingVersionType: Option[String],
31 | esMappingRouting: Option[String],
32 | esMappingTTLInMillis: Option[String],
33 | esMappingTimestamp: Option[String]
34 | ) extends Serializable
35 |
36 | object SparkEsMapperConf {
37 | val CONSTANT_FIELD_REGEX = """\<([^>]+)\>""".r
38 |
39 | val ES_MAPPING_ID = SparkEsConfParam[Option[String]](
40 | name = "es.mapping.id",
41 | default = None
42 | )
43 | val ES_MAPPING_PARENT = SparkEsConfParam[Option[String]](
44 | name = "es.mapping.parent",
45 | default = None
46 | )
47 | val ES_MAPPING_VERSION = SparkEsConfParam[Option[String]](
48 | name = "es.mapping.version",
49 | default = None
50 | )
51 | val ES_MAPPING_VERSION_TYPE = SparkEsConfParam[Option[String]](
52 | name = "es.mapping.version.type",
53 | default = None
54 | )
55 | val ES_MAPPING_ROUTING = SparkEsConfParam[Option[String]](
56 | name = "es.mapping.routing",
57 | default = None
58 | )
59 | val ES_MAPPING_TTL_IN_MILLIS = SparkEsConfParam[Option[String]](
60 | name = "es.mapping.ttl",
61 | default = None
62 | )
63 | val ES_MAPPING_TIMESTAMP = SparkEsConfParam[Option[String]](
64 | name = "es.mapping.timestamp",
65 | default = None
66 | )
67 |
68 | /**
69 | * Extracts SparkEsMapperConf from a SparkConf.
70 | *
71 | * @param sparkConf A SparkConf.
72 | * @return A SparkEsMapperConf from a SparkConf.
73 | */
74 | def fromSparkConf(sparkConf: SparkConf): SparkEsMapperConf = {
75 | SparkEsMapperConf(
76 | esMappingId = ES_MAPPING_ID.fromConf(sparkConf)((sc, name) => sc.getOption(name)),
77 | esMappingParent = ES_MAPPING_PARENT.fromConf(sparkConf)((sc, name) => sc.getOption(name)),
78 | esMappingVersion = ES_MAPPING_VERSION.fromConf(sparkConf)((sc, name) => sc.getOption(name)),
79 | esMappingVersionType = ES_MAPPING_VERSION_TYPE.fromConf(sparkConf)((sc, name) => sc.getOption(name)),
80 | esMappingRouting = ES_MAPPING_ROUTING.fromConf(sparkConf)((sc, name) => sc.getOption(name)),
81 | esMappingTTLInMillis = ES_MAPPING_TTL_IN_MILLIS.fromConf(sparkConf)((sc, name) => sc.getOption(name)),
82 | esMappingTimestamp = ES_MAPPING_TIMESTAMP.fromConf(sparkConf)((sc, name) => sc.getOption(name))
83 | )
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
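A short sketch of deriving a `SparkEsMapperConf` from a `SparkConf`, assuming `SparkEsConfParam.fromConf` (not shown in this dump) resolves the bare key as the getter lambdas above suggest; the field name `id` is a placeholder and `<1>` uses the `<CONSTANT>` form matched by `CONSTANT_FIELD_REGEX`:

```scala
import com.github.jparkie.spark.elasticsearch.conf.SparkEsMapperConf
import org.apache.spark.SparkConf

val sparkConf = new SparkConf()
  .set("es.mapping.id", "id")          // read the document id from the "id" column
  .set("es.mapping.version", "<1>")    // a constant version via the <CONSTANT> form
  .set("es.mapping.version.type", "external")

val mapperConf = SparkEsMapperConf.fromSparkConf(sparkConf)
// Expected: esMappingId = Some("id"), esMappingVersion = Some("<1>"),
// esMappingVersionType = Some("external"); the remaining fields stay None.
```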
/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsTransportClientConf.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.conf
2 |
3 | import java.net.InetSocketAddress
4 |
5 | import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam
6 | import org.apache.spark.SparkConf
7 |
8 | import scala.collection.mutable
9 |
10 | /**
11 | * Configurations for EsNativeDataFrameWriter's TransportClient.
12 | *
13 | * @param transportAddresses The minimum set of hosts to connect to when establishing a client.
14 | * CONFIG_CLIENT_TRANSPORT_SNIFF is enabled by default.
15 | * @param transportPort The port to connect when establishing a client.
16 | * @param transportSettings Miscellaneous settings for the TransportClient.
17 | * Empty by default.
18 | */
19 | case class SparkEsTransportClientConf(
20 | transportAddresses: Seq[String],
21 | transportPort: Int,
22 | transportSettings: Map[String, String]
23 | ) extends Serializable
24 |
25 | object SparkEsTransportClientConf {
26 | val CONFIG_CLUSTER_NAME = "cluster.name"
27 | val CONFIG_CLIENT_TRANSPORT_SNIFF = "client.transport.sniff"
28 | val CONFIG_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME = "client.transport.ignore_cluster_name"
29 | val CONFIG_CLIENT_TRANSPORT_PING_TIMEOUT = "client.transport.ping_timeout"
30 | val CONFIG_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL = "client.transport.nodes_sampler_interval"
31 |
32 | val ES_NODES = SparkEsConfParam[Seq[String]](
33 | name = "es.nodes",
34 | default = Seq.empty[String]
35 | )
36 | val ES_PORT = SparkEsConfParam[Int](
37 | name = "es.port",
38 | default = 9300
39 | )
40 | val ES_CLUSTER_NAME = SparkEsConfParam[String](
41 | name = s"es.$CONFIG_CLUSTER_NAME",
42 | default = null
43 | )
44 | val ES_CLIENT_TRANSPORT_SNIFF = SparkEsConfParam[String](
45 | name = s"es.$CONFIG_CLIENT_TRANSPORT_SNIFF",
46 | default = null
47 | )
48 | val ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME = SparkEsConfParam[String](
49 | name = s"es.$CONFIG_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME",
50 | default = null
51 | )
52 | val ES_CLIENT_TRANSPORT_PING_TIMEOUT = SparkEsConfParam[String](
53 | name = s"es.$CONFIG_CLIENT_TRANSPORT_PING_TIMEOUT",
54 | default = null
55 | )
56 | val ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL = SparkEsConfParam[String](
57 | name = s"es.$CONFIG_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL",
58 | default = null
59 | )
60 |
61 | def getTransportAddresses(transportAddresses: Seq[String], transportPort: Int): Seq[InetSocketAddress] = {
62 | transportAddresses match {
63 | case null | Nil => throw new IllegalArgumentException("A contact point list cannot be empty.")
64 | case hosts => hosts map {
65 | ipWithPort =>
66 | ipWithPort.split(":") match {
67 | case Array(actualHost, actualPort) =>
68 | new InetSocketAddress(actualHost, actualPort.toInt)
69 | case Array(actualHost) =>
70 | new InetSocketAddress(actualHost, transportPort)
71 | case errorMessage =>
72 | throw new IllegalArgumentException(s"A contact point should have the form [host:port] or [host] but was: $errorMessage.")
73 | }
74 | }
75 | }
76 | }
77 |
78 | /**
79 | * Extracts SparkEsTransportClientConf from a SparkConf.
80 | *
81 | * @param sparkConf A SparkConf.
82 | * @return A SparkEsTransportClientConf from a SparkConf.
83 | */
84 | def fromSparkConf(sparkConf: SparkConf): SparkEsTransportClientConf = {
85 | val tempEsNodes = ES_NODES.fromConf(sparkConf)((sc, name) => sc.get(name).split(","))
86 | val tempEsPort = ES_PORT.fromConf(sparkConf)((sc, name) => sc.getInt(name, ES_PORT.default))
87 | val tempSettings = mutable.HashMap.empty[String, String]
88 |
89 | require(
90 | tempEsNodes.nonEmpty,
91 | s"""No nodes defined in property ${ES_NODES.name} in SparkConf.""".stripMargin
92 | )
93 |
94 | if (sparkConf.contains(ES_CLUSTER_NAME.name) || sparkConf.contains(s"spark.${ES_CLUSTER_NAME.name}"))
95 | tempSettings.put(CONFIG_CLUSTER_NAME, ES_CLUSTER_NAME.fromConf(sparkConf)((sc, name) => sc.get(name)))
96 |
97 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_SNIFF.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_SNIFF.name}"))
98 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_SNIFF, ES_CLIENT_TRANSPORT_SNIFF.fromConf(sparkConf)((sc, name) => sc.get(name)))
99 |
100 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME.name}"))
101 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME, ES_CLIENT_TRANSPORT_IGNORE_CLUSTER_NAME.fromConf(sparkConf)((sc, name) => sc.get(name)))
102 |
103 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_PING_TIMEOUT.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_PING_TIMEOUT.name}"))
104 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_PING_TIMEOUT, ES_CLIENT_TRANSPORT_PING_TIMEOUT.fromConf(sparkConf)((sc, name) => sc.get(name)))
105 |
106 | if (sparkConf.contains(ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.name) || sparkConf.contains(s"spark.${ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.name}"))
107 | tempSettings.put(CONFIG_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL, ES_CLIENT_TRANSPORT_NODES_SAMPLER_INTERVAL.fromConf(sparkConf)((sc, name) => sc.get(name)))
108 |
109 | SparkEsTransportClientConf(
110 | transportAddresses = tempEsNodes,
111 | transportPort = tempEsPort,
112 | transportSettings = tempSettings.toMap
113 | )
114 | }
115 | }
--------------------------------------------------------------------------------
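A short sketch of the address handling above; the hosts and cluster name are placeholders:

```scala
import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf

// Entries may carry their own port or fall back to the shared transport port.
val addresses = SparkEsTransportClientConf.getTransportAddresses(
  transportAddresses = Seq("127.0.0.1", "10.0.0.2:9301"),
  transportPort = 9300
)
// Yields InetSocketAddress(127.0.0.1:9300) and InetSocketAddress(10.0.0.2:9301).

// The configuration can also be built directly instead of via fromSparkConf.
val clientConf = SparkEsTransportClientConf(
  transportAddresses = Seq("127.0.0.1"),
  transportPort = 9300,
  transportSettings = Map("cluster.name" -> "elasticsearch")
)
```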
/src/main/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsWriteConf.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.conf
2 |
3 | import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam
4 | import org.apache.spark.SparkConf
5 |
6 | /**
7 | * Configurations for EsNativeDataFrameWriter's BulkProcessor.
8 | *
9 | * @param bulkActions The number of IndexRequests to batch in one request.
10 | * @param bulkSizeInMB The maximum size in MB of a batch.
11 | * @param concurrentRequests The number of concurrent requests in flight.
12 | * @param flushTimeoutInSeconds The maximum time in seconds to wait while closing a BulkProcessor.
13 | */
14 | case class SparkEsWriteConf(
15 | bulkActions: Int,
16 | bulkSizeInMB: Int,
17 | concurrentRequests: Int,
18 | flushTimeoutInSeconds: Long
19 | ) extends Serializable
20 |
21 | object SparkEsWriteConf {
22 | val BULK_ACTIONS = SparkEsConfParam[Int](
23 | name = "es.batch.size.entries",
24 | default = 1000
25 | )
26 | val BULK_SIZE_IN_MB = SparkEsConfParam[Int](
27 | name = "es.batch.size.bytes",
28 | default = 5
29 | )
30 | val CONCURRENT_REQUESTS = SparkEsConfParam[Int](
31 | name = "es.batch.concurrent.request",
32 | default = 1
33 | )
34 | val FLUSH_TIMEOUT_IN_SECONDS = SparkEsConfParam[Long](
35 | name = "es.batch.flush.timeout",
36 | default = 10
37 | )
38 |
39 | /**
40 | * Extracts SparkEsTransportClientConf from a SparkConf.
41 | *
42 | * @param sparkConf A SparkConf.
43 | * @return A SparkEsTransportClientConf from a SparkConf.
44 | */
45 | def fromSparkConf(sparkConf: SparkConf): SparkEsWriteConf = {
46 | SparkEsWriteConf(
47 | bulkActions = BULK_ACTIONS.fromConf(sparkConf)((sc, name) => sc.getInt(name, BULK_ACTIONS.default)),
48 | bulkSizeInMB = BULK_SIZE_IN_MB.fromConf(sparkConf)((sc, name) => sc.getInt(name, BULK_SIZE_IN_MB.default)),
49 | concurrentRequests = CONCURRENT_REQUESTS.fromConf(sparkConf)((sc, name) => sc.getInt(name, CONCURRENT_REQUESTS.default)),
50 | flushTimeoutInSeconds = FLUSH_TIMEOUT_IN_SECONDS.fromConf(sparkConf)((sc, name) => sc.getLong(name, FLUSH_TIMEOUT_IN_SECONDS.default))
51 | )
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
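A short sketch of both ways to obtain a `SparkEsWriteConf`; the tuned values are illustrative:

```scala
import com.github.jparkie.spark.elasticsearch.conf.SparkEsWriteConf
import org.apache.spark.SparkConf

// With no es.batch.* properties set, the defaults above apply:
// 1000 actions, 5 MB, 1 concurrent request, 10 second flush timeout.
val defaultWriteConf = SparkEsWriteConf.fromSparkConf(new SparkConf())

// Or construct it explicitly to tune the BulkProcessor created by SparkEsBulkWriter.
val tunedWriteConf = SparkEsWriteConf(
  bulkActions = 5000,
  bulkSizeInMB = 10,
  concurrentRequests = 2,
  flushTimeoutInSeconds = 30
)
```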
/src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameFunctions.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.sql
2 |
3 | import com.github.jparkie.spark.elasticsearch.SparkEsBulkWriter
4 | import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsTransportClientConf, SparkEsWriteConf }
5 | import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager
6 | import org.apache.spark.sql.{ DataFrame, Row }
7 |
8 | /**
9 | * Extension of DataFrame with 'bulkLoadToEs()' function.
10 | *
11 | * @param dataFrame The DataFrame to lift into extension.
12 | */
13 | class SparkEsDataFrameFunctions(dataFrame: DataFrame) extends Serializable {
14 | /**
15 | * SparkContext to schedule SparkEsWriter Tasks.
16 | */
17 | private[sql] val sparkContext = dataFrame.sqlContext.sparkContext
18 |
19 | /**
20 | * Upserts DataFrame into Elasticsearch with the Java API utilizing a TransportClient.
21 | *
22 | * @param esIndex Index of DataFrame in Elasticsearch.
23 | * @param esType Type of DataFrame in Elasticsearch.
24 | * @param sparkEsTransportClientConf Configurations for the TransportClient.
25 | * @param sparkEsMapperConf Configurations for IndexRequest.
26 | * @param sparkEsWriteConf Configurations for the BulkProcessor.
27 | * Empty by default.
28 | */
29 | def bulkLoadToEs(
30 | esIndex: String,
31 | esType: String,
32 | sparkEsTransportClientConf: SparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(sparkContext.getConf),
33 | sparkEsMapperConf: SparkEsMapperConf = SparkEsMapperConf.fromSparkConf(sparkContext.getConf),
34 | sparkEsWriteConf: SparkEsWriteConf = SparkEsWriteConf.fromSparkConf(sparkContext.getConf)
35 | )(implicit sparkEsTransportClientManager: SparkEsTransportClientManager = sparkEsTransportClientManager): Unit = {
36 | val sparkEsWriter = new SparkEsBulkWriter[Row](
37 | esIndex = esIndex,
38 | esType = esType,
39 | esClient = () => sparkEsTransportClientManager.getTransportClient(sparkEsTransportClientConf),
40 | sparkEsSerializer = new SparkEsDataFrameSerializer(dataFrame.schema),
41 | sparkEsMapper = new SparkEsDataFrameMapper(sparkEsMapperConf),
42 | sparkEsWriteConf = sparkEsWriteConf
43 | )
44 |
45 | sparkContext.runJob(dataFrame.rdd, sparkEsWriter.write _)
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
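Besides relying on the SparkConf-driven defaults, each configuration can be passed to `bulkLoadToEs()` explicitly. A sketch with placeholder host, index, type, and field names; `df` is an assumed, pre-existing DataFrame:

```scala
import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsTransportClientConf, SparkEsWriteConf }
import com.github.jparkie.spark.elasticsearch.sql._

// df: an existing DataFrame (assumed).
df.bulkLoadToEs(
  esIndex = "twitter",
  esType = "tweets",
  sparkEsTransportClientConf = SparkEsTransportClientConf(
    transportAddresses = Seq("127.0.0.1"), // placeholder host
    transportPort = 9300,
    transportSettings = Map.empty
  ),
  sparkEsMapperConf = SparkEsMapperConf(
    esMappingId = Some("id"),              // placeholder id column
    esMappingParent = None,
    esMappingVersion = None,
    esMappingVersionType = None,
    esMappingRouting = None,
    esMappingTTLInMillis = None,
    esMappingTimestamp = None
  ),
  sparkEsWriteConf = SparkEsWriteConf(
    bulkActions = 1000,
    bulkSizeInMB = 5,
    concurrentRequests = 1,
    flushTimeoutInSeconds = 10
  )
)
```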
/src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameMapper.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.sql
2 |
3 | import com.github.jparkie.spark.elasticsearch.SparkEsMapper
4 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsMapperConf
5 | import org.apache.spark.sql.Row
6 | import org.elasticsearch.index.VersionType
7 |
8 | /**
9 | * Extracts mappings from a Row for an IndexRequest.
10 | *
11 | * @param mapperConf Configurations for IndexRequest.
12 | */
13 | class SparkEsDataFrameMapper(mapperConf: SparkEsMapperConf) extends SparkEsMapper[Row] {
14 | import SparkEsDataFrameMapper._
15 | import SparkEsMapperConf._
16 |
17 | /**
18 | * Extracts the document field/property name containing the document id.
19 | *
20 | * @param value Object to extract mappings.
21 | * @return The document field/property name containing the document id.
22 | */
23 | override def extractMappingId(value: Row): Option[String] = {
24 | mapperConf.esMappingId
25 | .map(currentFieldName => value.getAsToString(currentFieldName))
26 | }
27 |
28 | /**
29 | * Extracts the document field/property name containing the document parent.
30 | *
31 | * @param value Object to extract mappings.
32 | * @return The document field/property name containing the document parent.
33 | */
34 | override def extractMappingParent(value: Row): Option[String] = {
35 | mapperConf.esMappingParent.map {
36 | case CONSTANT_FIELD_REGEX(constantValue) =>
37 | constantValue
38 | case currentFieldName =>
39 | value.getAsToString(currentFieldName)
40 | }
41 | }
42 |
43 | /**
44 | * Extracts the document field/property name containing the document version.
45 | *
46 | * @param value Object to extract mappings.
47 | * @return The document field/property name containing the document version.
48 | */
49 | override def extractMappingVersion(value: Row): Option[Long] = {
50 | mapperConf.esMappingVersion.map {
51 | case CONSTANT_FIELD_REGEX(constantValue) =>
52 | constantValue.toLong
53 | case currentFieldName =>
54 | value.getAsToString(currentFieldName).toLong
55 | }
56 | }
57 |
58 | /**
59 | * Extracts the type of versioning used.
60 | *
61 | * @param value Object to extract mappings.
62 | * @return The type of versioning used.
63 | */
64 | override def extractMappingVersionType(value: Row): Option[VersionType] = {
65 | mapperConf.esMappingVersion
66 | .flatMap(_ => mapperConf.esMappingVersionType.map(VersionType.fromString))
67 | }
68 |
69 | /**
70 | * Extracts the document field/property name containing the document routing.
71 | *
72 | * @param value Object to extract mappings.
73 | * @return The document field/property name containing the document routing.
74 | */
75 | override def extractMappingRouting(value: Row): Option[String] = {
76 | mapperConf.esMappingRouting.map {
77 | case CONSTANT_FIELD_REGEX(constantValue) =>
78 | constantValue
79 | case currentFieldName =>
80 | value.getAsToString(currentFieldName)
81 | }
82 | }
83 |
84 | /**
85 | * Extracts the document field/property name containing the document time-to-live.
86 | *
87 | * @param value Object to extract mappings.
88 | * @return The document field/property name containing the document time-to-live.
89 | */
90 | override def extractMappingTTLInMillis(value: Row): Option[Long] = {
91 | mapperConf.esMappingTTLInMillis.map {
92 | case CONSTANT_FIELD_REGEX(constantValue) =>
93 | constantValue.toLong
94 | case currentFieldName =>
95 | value.getAsToString(currentFieldName).toLong
96 | }
97 | }
98 |
99 | /**
100 | * Extracts the document field/property name containing the document timestamp.
101 | *
102 | * @param value Object to extract mappings.
103 | * @return The document field/property name containing the document timestamp.
104 | */
105 | override def extractMappingTimestamp(value: Row): Option[String] = {
106 | mapperConf.esMappingTimestamp.map {
107 | case CONSTANT_FIELD_REGEX(constantValue) =>
108 | constantValue
109 | case currentFieldName =>
110 | value.getAsToString(currentFieldName)
111 | }
112 | }
113 | }
114 |
115 | object SparkEsDataFrameMapper {
116 | /**
117 | * Adds method to retrieve field as String through Any's toString.
118 | *
119 | * @param currentRow Row object.
120 | */
121 | implicit class RichRow(currentRow: Row) {
122 | // TODO: Find a better and safer way.
123 | def getAsToString(fieldName: String): String = {
124 | currentRow.getAs[Any](fieldName).toString
125 | }
126 | }
127 | }
--------------------------------------------------------------------------------
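A short sketch of how the mapper treats the `<CONSTANT>` form versus a field reference; the conf values are illustrative, and the commented results assume a `Row` taken from a DataFrame that carries an `id` column:

```scala
import com.github.jparkie.spark.elasticsearch.conf.SparkEsMapperConf
import com.github.jparkie.spark.elasticsearch.sql.SparkEsDataFrameMapper

val mapper = new SparkEsDataFrameMapper(SparkEsMapperConf(
  esMappingId = Some("id"),       // field reference: resolved via row.getAs("id")
  esMappingParent = None,
  esMappingVersion = Some("<3>"), // <CONSTANT> form: always version 3
  esMappingVersionType = Some("external"),
  esMappingRouting = None,
  esMappingTTLInMillis = None,
  esMappingTimestamp = None
))

// For a row from a DataFrame with an "id" column:
// mapper.extractMappingId(row)          == Some(<the row's id value>)
// mapper.extractMappingVersion(row)     == Some(3L)
// mapper.extractMappingVersionType(row) == Some(VersionType.EXTERNAL)
```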
/src/main/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameSerializer.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.sql
2 |
3 | import java.sql.{ Date, Timestamp }
4 |
5 | import com.github.jparkie.spark.elasticsearch.SparkEsSerializer
6 | import org.apache.spark.sql.Row
7 | import org.apache.spark.sql.types._
8 | import org.elasticsearch.common.xcontent.{ XContentBuilder, XContentFactory }
9 |
10 | import scala.collection.JavaConverters._
11 |
12 | /**
13 | * Serializes a Row from a DataFrame into an Array[Byte].
14 | *
15 | * @param schema The StructType of a DataFrame.
16 | */
17 | class SparkEsDataFrameSerializer(schema: StructType) extends SparkEsSerializer[Row] {
18 | /**
19 | * Serializes a Row from a DataFrame into an Array[Byte].
20 | *
21 | * @param value A Row.
22 | * @return The source JSON as Array[Byte].
23 | */
24 | override def write(value: Row): Array[Byte] = {
25 | val currentJsonBuilder = XContentFactory.jsonBuilder()
26 |
27 | write(schema, value, currentJsonBuilder)
28 |
29 | currentJsonBuilder
30 | .bytes()
31 | .toBytes
32 | }
33 |
34 | private[sql] def write(dataType: DataType, value: Any, builder: XContentBuilder): XContentBuilder = {
35 | dataType match {
36 | case structType @ StructType(_) => writeStruct(structType, value, builder)
37 | case arrayType @ ArrayType(_, _) => writeArray(arrayType, value, builder)
38 | case mapType @ MapType(_, _, _) => writeMap(mapType, value, builder)
39 | case _ => writePrimitive(dataType, value, builder)
40 | }
41 | }
42 |
43 | private[sql] def writeStruct(structType: StructType, value: Any, builder: XContentBuilder): XContentBuilder = {
44 | value match {
45 | case currentRow: Row =>
46 | builder.startObject()
47 |
48 | structType.fields.view.zipWithIndex foreach {
49 | case (field, index) =>
50 | builder.field(field.name)
51 | if (currentRow.isNullAt(index)) {
52 | builder.nullValue()
53 | } else {
54 | write(field.dataType, currentRow(index), builder)
55 | }
56 | }
57 |
58 | builder.endObject()
59 | }
60 |
61 | builder
62 | }
63 |
64 | private[sql] def writeArray(arrayType: ArrayType, value: Any, builder: XContentBuilder): XContentBuilder = {
65 | value match {
66 | case array: Array[_] =>
67 | serializeArray(arrayType.elementType, array, builder)
68 | case seq: Seq[_] =>
69 | serializeArray(arrayType.elementType, seq, builder)
70 | case _ =>
71 | throw new IllegalArgumentException(s"Unknown ArrayType: $value.")
72 | }
73 | }
74 |
75 | private[sql] def serializeArray(dataType: DataType, value: Seq[_], builder: XContentBuilder): XContentBuilder = {
76 | // TODO: Consider utilizing builder.value(Iterable[_]).
77 | builder.startArray()
78 |
79 | if (value != null) {
80 | value foreach { element =>
81 | write(dataType, element, builder)
82 | }
83 | }
84 |
85 | builder.endArray()
86 | builder
87 | }
88 |
89 | private[sql] def writeMap(mapType: MapType, value: Any, builder: XContentBuilder): XContentBuilder = {
90 | value match {
91 | case scalaMap: scala.collection.Map[_, _] =>
92 | serializeMap(mapType, scalaMap, builder)
93 | case javaMap: java.util.Map[_, _] =>
94 | serializeMap(mapType, javaMap.asScala, builder)
95 | case _ =>
96 | throw new IllegalArgumentException(s"Unknown MapType: $value.")
97 | }
98 | }
99 |
100 | private[sql] def serializeMap(mapType: MapType, value: scala.collection.Map[_, _], builder: XContentBuilder): XContentBuilder = {
101 | // TODO: Consider utilizing builder.value(Map[_, AnyRef]).
102 | builder.startObject()
103 |
104 | for ((currentKey, currentValue) <- value) {
105 | builder.field(currentKey.toString)
106 | write(mapType.valueType, currentValue, builder)
107 | }
108 |
109 | builder.endObject()
110 | builder
111 | }
112 |
113 | private[sql] def writePrimitive(dataType: DataType, value: Any, builder: XContentBuilder): XContentBuilder = {
114 | dataType match {
115 | case BinaryType => builder.value(value.asInstanceOf[Array[Byte]])
116 | case BooleanType => builder.value(value.asInstanceOf[Boolean])
117 | case ByteType => builder.value(value.asInstanceOf[Byte])
118 | case ShortType => builder.value(value.asInstanceOf[Short])
119 | case IntegerType => builder.value(value.asInstanceOf[Int])
120 | case LongType => builder.value(value.asInstanceOf[Long])
121 | case DoubleType => builder.value(value.asInstanceOf[Double])
122 | case FloatType => builder.value(value.asInstanceOf[Float])
123 | case TimestampType => builder.value(value.asInstanceOf[Timestamp].getTime)
124 | case DateType => builder.value(value.asInstanceOf[Date].getTime)
125 | case StringType => builder.value(value.toString)
126 | case _ =>
127 | throw new IllegalArgumentException(s"Unknown DataType: $value.")
128 | }
129 | }
130 | }
131 |
--------------------------------------------------------------------------------
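A small sketch of the serializer on a hand-built `Row`; the schema and values are illustrative. Because `writeStruct` accesses columns by position, a plain `Row` without an attached schema suffices:

```scala
import com.github.jparkie.spark.elasticsearch.sql.SparkEsDataFrameSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }

val schema = StructType(Seq(
  StructField("user", StringType),
  StructField("message", StringType),
  StructField("retweets", LongType)
))

val serializer = new SparkEsDataFrameSerializer(schema)

// Roughly the UTF-8 bytes of {"user":"jparkie","message":"Hello","retweets":0}.
val sourceBytes: Array[Byte] = serializer.write(Row("jparkie", "Hello", 0L))
```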
/src/main/scala/com/github/jparkie/spark/elasticsearch/sql/package.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch
2 |
3 | import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager
4 | import org.apache.spark.sql.DataFrame
5 |
6 | package object sql {
7 | implicit val sparkEsTransportClientManager = SparkEsTransportClientManager
8 |
9 | /**
10 | * Implicitly lift a DataFrame with SparkEsDataFrameFunctions.
11 | *
12 | * @param dataFrame A DataFrame to lift.
13 | * @return Enriched DataFrame with SparkEsDataFrameFunctions.
14 | */
15 | implicit def sparkEsDataFrameFunctions(dataFrame: DataFrame): SparkEsDataFrameFunctions = new SparkEsDataFrameFunctions(dataFrame)
16 | }
17 |
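A short usage sketch of the implicit lift this package object provides; the bulkLoadToEs method name is assumed to be the bulk-write entry point declared in SparkEsDataFrameFunctions.scala (not shown in this excerpt), and the es.nodes/es.port settings mirror the configuration keys exercised in the specs below:

import com.github.jparkie.spark.elasticsearch.sql._
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.SQLContext

// Transport settings are read from SparkConf (see SparkEsTransportClientConfSpec below).
val sparkConf = new SparkConf()
  .setAppName("Spark2ElasticsearchExample")
  .setMaster("local[*]")
  .set("es.nodes", "127.0.0.1")
  .set("es.port", "9300")
val sc = new SparkContext(sparkConf)
val sqlContext = new SQLContext(sc)

val df = sqlContext.createDataFrame(Seq(("TEST_ID_1", 1L))).toDF("id", "value")

// Importing the package object lifts df into SparkEsDataFrameFunctions;
// bulkLoadToEs is an assumed name for the bulk-write entry point.
df.bulkLoadToEs(esIndex = "test_index", esType = "test_type")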
--------------------------------------------------------------------------------
/src/main/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientManager.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.transport
2 |
3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf
4 | import org.apache.spark.Logging
5 | import org.elasticsearch.client.Client
6 | import org.elasticsearch.client.transport.TransportClient
7 | import org.elasticsearch.common.settings.Settings
8 | import org.elasticsearch.common.transport.InetSocketTransportAddress
9 |
10 | import scala.collection.mutable
11 |
12 | private[elasticsearch] trait SparkEsTransportClientManager extends Serializable with Logging {
13 | @transient
14 | private[transport] val internalTransportClients = mutable.HashMap.empty[SparkEsTransportClientConf, TransportClient]
15 |
16 | private[transport] def buildTransportSettings(clientConf: SparkEsTransportClientConf): Settings = {
17 | val esSettingsBuilder = Settings.builder()
18 |
19 | clientConf.transportSettings foreach { currentSetting =>
20 | esSettingsBuilder.put(currentSetting._1, currentSetting._2)
21 | }
22 |
23 | esSettingsBuilder.build()
24 | }
25 |
26 | private[transport] def buildTransportClient(clientConf: SparkEsTransportClientConf, esSettings: Settings): TransportClient = {
27 | import SparkEsTransportClientConf._
28 |
29 | val esClient = TransportClient.builder()
30 | .settings(esSettings)
31 | .build()
32 |
33 | getTransportAddresses(clientConf.transportAddresses, clientConf.transportPort) foreach { inetSocketAddress =>
34 | esClient.addTransportAddresses(new InetSocketTransportAddress(inetSocketAddress))
35 | }
36 |
37 | sys.addShutdownHook {
38 |       esClient.close()
39 |
40 |       logInfo("Closed Elasticsearch Transport Client.")
41 | }
42 |
43 | logInfo(s"Connected to the following Elasticsearch nodes: ${esClient.connectedNodes()}.")
44 |
45 | esClient
46 | }
47 |
48 | /**
49 | * Gets or creates a TransportClient per JVM.
50 | *
51 | * @param clientConf Settings and initial endpoints for connection.
52 | * @return SparkEsTransportClientProxy as Client.
53 | */
54 | def getTransportClient(clientConf: SparkEsTransportClientConf): Client = synchronized {
55 | internalTransportClients.get(clientConf) match {
56 | case Some(transportClient) =>
57 | new SparkEsTransportClientProxy(transportClient)
58 | case None =>
59 | val transportSettings = buildTransportSettings(clientConf)
60 | val transportClient = buildTransportClient(clientConf, transportSettings)
61 | internalTransportClients.put(clientConf, transportClient)
62 | new SparkEsTransportClientProxy(transportClient)
63 | }
64 | }
65 |
66 | /**
67 | * Evicts and closes a TransportClient.
68 | *
69 | * @param clientConf Settings and initial endpoints for connection.
70 | */
71 | def closeTransportClient(clientConf: SparkEsTransportClientConf): Unit = synchronized {
72 | internalTransportClients.remove(clientConf) match {
73 | case Some(transportClient) =>
74 | transportClient.close()
75 | case None =>
76 | logError(s"No TransportClient for $clientConf.")
77 | }
78 | }
79 | }
80 |
81 | object SparkEsTransportClientManager extends SparkEsTransportClientManager
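A lifecycle sketch for the client manager, mirroring the specs further below; the address, port, and cluster.name values are illustrative:

import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf
import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager

// One TransportClient is cached per SparkEsTransportClientConf per JVM; the
// returned value is a SparkEsTransportClientProxy, so close() must go through the manager.
val clientConf = SparkEsTransportClientConf(
  transportAddresses = Seq("127.0.0.1"),
  transportPort = 9300,
  transportSettings = Map("cluster.name" -> "Spark2Elasticsearch")
)

val client = SparkEsTransportClientManager.getTransportClient(clientConf)
try {
  client.admin().cluster().prepareHealth().get() // any org.elasticsearch.client.Client call
} finally {
  SparkEsTransportClientManager.closeTransportClient(clientConf)
}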
--------------------------------------------------------------------------------
/src/main/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientProxy.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.transport
2 |
3 | import com.github.jparkie.spark.elasticsearch.util.SparkEsException
4 | import org.elasticsearch.client.{ FilterClient, Client }
5 |
6 | /**
7 | * Restrict access to TransportClient by disabling close() without use of SparkEsTransportClientManager.
8 | */
9 | class SparkEsTransportClientProxy(client: Client) extends FilterClient(client) {
10 | override def close(): Unit = {
11 | throw new SparkEsException("close() is not supported in SparkEsTransportClientProxy. Please close with SparkEsTransportClientManager.")
12 | }
13 | }
14 |
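A brief sketch of the proxy's behavior; underlyingClient stands in for an already-built TransportClient (hypothetical here):

// FilterClient forwards every Client method to the wrapped client;
// only close() is overridden above to fail fast with SparkEsException.
val proxy = new SparkEsTransportClientProxy(underlyingClient) // underlyingClient: hypothetical TransportClient
proxy.admin().cluster().prepareHealth().get() // delegated to underlyingClient
proxy.close()                                 // throws SparkEsException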
--------------------------------------------------------------------------------
/src/main/scala/com/github/jparkie/spark/elasticsearch/util/SparkEsConfParam.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.util
2 |
3 | import org.apache.spark.SparkConf
4 |
5 | /**
6 | * Defines parameter to extract values from SparkConf.
7 | *
8 | * @param name The key in SparkConf.
9 | * @param default The default value to fallback on missing key.
10 | */
11 | case class SparkEsConfParam[T](name: String, default: T) {
12 | def fromConf(sparkConf: SparkConf)(sparkConfFunc: (SparkConf, String) => T): T = {
13 | if (sparkConf.contains(name)) {
14 | sparkConfFunc(sparkConf, name)
15 | } else if (sparkConf.contains(s"spark.$name")) {
16 | sparkConfFunc(sparkConf, s"spark.$name")
17 | } else {
18 | default
19 | }
20 | }
21 | }
22 |
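A small sketch of how a SparkEsConfParam resolves values, honoring both the bare key and its spark.-prefixed variant; the es.port name and 9300 default are illustrative:

import org.apache.spark.SparkConf
import com.github.jparkie.spark.elasticsearch.util.SparkEsConfParam

val portParam = SparkEsConfParam[Int](name = "es.port", default = 9300)

// "spark.es.port" is also accepted because fromConf() falls back to the spark.-prefixed key.
val sparkConf = new SparkConf().set("spark.es.port", "9301")
val port = portParam.fromConf(sparkConf)((conf, key) => conf.getInt(key, portParam.default))
// port == 9301; with neither key set, port would be 9300.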
--------------------------------------------------------------------------------
/src/main/scala/com/github/jparkie/spark/elasticsearch/util/SparkEsException.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.util
2 |
3 | /**
4 |  * General exception thrown by Spark2Elasticsearch components such as SparkEsTransportClientProxy.
5 | *
6 | * @param message the detail message (which is saved for later retrieval
7 |  *                by the {@link #getMessage()} method).
8 | * @param cause the cause (which is saved for later retrieval by the
9 |  *              {@link #getCause()} method). (A null value is
10 | * permitted, and indicates that the cause is nonexistent or
11 | * unknown.)
12 | */
13 | class SparkEsException(message: String, cause: Throwable) extends RuntimeException(message, cause) with Serializable {
14 | def this() = this(null, null)
15 | def this(message: String) = this(message, null)
16 | def this(cause: Throwable) = this(null, cause)
17 | }
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/ElasticSearchServer.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch
2 |
3 | import java.nio.file.Files
4 |
5 | import org.apache.commons.io.FileUtils
6 | import org.elasticsearch.client.Client
7 | import org.elasticsearch.common.settings.Settings
8 | import org.elasticsearch.node.{ Node, NodeBuilder }
9 |
10 | /**
11 | * Local Elasticsearch server of one node for integration testing.
12 | */
13 | class ElasticSearchServer {
14 | private val homeDir = Files.createTempDirectory("elasticsearch").toFile
15 | private val dataDir = Files.createTempDirectory("elasticsearch").toFile
16 |
17 | val clusterName = "Spark2Elasticsearch"
18 |
19 | private lazy val internalNode: Node = {
20 | val tempSettings = Settings.builder()
21 | .put("path.home", homeDir.getAbsolutePath)
22 | .put("path.data", dataDir.getAbsolutePath)
23 | .put("es.logger.level", "OFF")
24 | .build()
25 | val tempNode = NodeBuilder.nodeBuilder()
26 | .clusterName(clusterName)
27 | .local(true)
28 | .data(true)
29 | .settings(tempSettings)
30 | .build()
31 |
32 | tempNode
33 | }
34 |
35 | /**
36 | * Fetch a client to the Elasticsearch cluster.
37 | *
38 | * @return A Client.
39 | */
40 | def client: Client = {
41 | internalNode.client()
42 | }
43 |
44 | /**
45 | * Start the Elasticsearch cluster.
46 | */
47 | def start(): Unit = {
48 | internalNode.start()
49 | }
50 |
51 | /**
52 | * Stop the Elasticsearch cluster.
53 | */
54 | def stop(): Unit = {
55 | internalNode.close()
56 |
57 | try {
58 | FileUtils.forceDelete(homeDir)
59 | FileUtils.forceDelete(dataDir)
60 | } catch {
61 |       case _: Exception =>
62 |         // Ignore failures when deleting the temporary directories.
63 | }
64 | }
65 |
66 | /**
67 | * Create Index.
68 | *
69 | * @param index The name of Index.
70 | */
71 | def createAndWaitForIndex(index: String): Unit = {
72 | client.admin.indices.prepareCreate(index).execute.actionGet()
73 | client.admin.cluster.prepareHealth(index).setWaitForActiveShards(1).execute.actionGet()
74 | }
75 | }
76 |
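A standalone lifecycle sketch for this embedded server, using the same Elasticsearch 2.x client API as above; in the specs that follow, start()/stop() are instead wired into beforeAll/afterAll:

val esServer = new ElasticSearchServer()
esServer.start()
esServer.createAndWaitForIndex("test_index")

// Index a document through the node-local client, then shut everything down.
val indexResponse = esServer.client.prepareIndex("test_index", "test_type", "1")
  .setSource("""{"value":1}""")
  .get()

esServer.stop()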
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/SparkEsBulkWriterSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch
2 |
3 | import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsWriteConf }
4 | import com.github.jparkie.spark.elasticsearch.sql.{ SparkEsDataFrameMapper, SparkEsDataFrameSerializer }
5 | import com.holdenkarau.spark.testing.SharedSparkContext
6 | import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
7 | import org.apache.spark.sql.{ Row, SQLContext }
8 | import org.scalatest.{ MustMatchers, WordSpec }
9 |
10 | class SparkEsBulkWriterSpec extends WordSpec with MustMatchers with SharedSparkContext {
11 | val esServer = new ElasticSearchServer()
12 |
13 | override def beforeAll(): Unit = {
14 | super.beforeAll()
15 |
16 | esServer.start()
17 | }
18 |
19 | override def afterAll(): Unit = {
20 | esServer.stop()
21 |
22 | super.afterAll()
23 | }
24 |
25 | "SparkEsBulkWriter" must {
26 | "execute write() successfully" in {
27 | esServer.createAndWaitForIndex("test_index")
28 |
29 | val sqlContext = new SQLContext(sc)
30 |
31 | val inputSparkEsWriteConf = SparkEsWriteConf(
32 | bulkActions = 10,
33 | bulkSizeInMB = 1,
34 | concurrentRequests = 0,
35 | flushTimeoutInSeconds = 1
36 | )
37 | val inputMapperConf = SparkEsMapperConf(
38 | esMappingId = Some("id"),
39 | esMappingParent = None,
40 | esMappingVersion = None,
41 | esMappingVersionType = None,
42 | esMappingRouting = None,
43 | esMappingTTLInMillis = None,
44 | esMappingTimestamp = None
45 | )
46 | val inputSchema = StructType(
47 | Array(
48 | StructField("id", StringType, true),
49 | StructField("parent", StringType, true),
50 | StructField("version", LongType, true),
51 | StructField("routing", StringType, true),
52 | StructField("ttl", LongType, true),
53 | StructField("timestamp", StringType, true),
54 | StructField("value", LongType, true)
55 | )
56 | )
57 | val inputData = sc.parallelize {
58 | Array(
59 | Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L),
60 | Row("TEST_ID_1", "TEST_PARENT_2", 2L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 2L),
61 | Row("TEST_ID_1", "TEST_PARENT_3", 3L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 3L),
62 | Row("TEST_ID_1", "TEST_PARENT_4", 4L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 4L),
63 | Row("TEST_ID_1", "TEST_PARENT_5", 5L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 5L),
64 | Row("TEST_ID_5", "TEST_PARENT_6", 6L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 6L),
65 | Row("TEST_ID_6", "TEST_PARENT_7", 7L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 7L),
66 | Row("TEST_ID_7", "TEST_PARENT_8", 8L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 8L),
67 | Row("TEST_ID_8", "TEST_PARENT_9", 9L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 9L),
68 | Row("TEST_ID_9", "TEST_PARENT_10", 10L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 10L),
69 | Row("TEST_ID_10", "TEST_PARENT_11", 11L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 11L)
70 | )
71 | }
72 | val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema)
73 | val inputDataIterator = inputDataFrame.rdd.toLocalIterator
74 | val inputSparkEsBulkWriter = new SparkEsBulkWriter[Row](
75 | esIndex = "test_index",
76 | esType = "test_type",
77 | esClient = () => esServer.client,
78 | sparkEsSerializer = new SparkEsDataFrameSerializer(inputSchema),
79 | sparkEsMapper = new SparkEsDataFrameMapper(inputMapperConf),
80 | sparkEsWriteConf = inputSparkEsWriteConf
81 | )
82 |
83 | inputSparkEsBulkWriter.write(null, inputDataIterator)
84 |
85 | val outputGetResponse = esServer.client.prepareGet("test_index", "test_type", "TEST_ID_1").get()
86 |
87 | outputGetResponse.isExists mustEqual true
88 | outputGetResponse.getSource.get("parent").asInstanceOf[String] mustEqual "TEST_PARENT_5"
89 | outputGetResponse.getSource.get("version").asInstanceOf[Integer] mustEqual 5
90 | outputGetResponse.getSource.get("routing").asInstanceOf[String] mustEqual "TEST_ROUTING_1"
91 | outputGetResponse.getSource.get("ttl").asInstanceOf[Integer] mustEqual 86400000
92 | outputGetResponse.getSource.get("timestamp").asInstanceOf[String] mustEqual "TEST_TIMESTAMP_1"
93 | outputGetResponse.getSource.get("value").asInstanceOf[Integer] mustEqual 5
94 | }
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsMapperConfSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.conf
2 |
3 | import org.apache.spark.SparkConf
4 | import org.scalatest.{ MustMatchers, WordSpec }
5 |
6 | class SparkEsMapperConfSpec extends WordSpec with MustMatchers {
7 | "SparkEsMapperConf" must {
8 | "be extracted from SparkConf successfully" in {
9 | val inputSparkConf = new SparkConf()
10 | .set("es.mapping.id", "TEST_VALUE_1")
11 | .set("es.mapping.parent", "TEST_VALUE_2")
12 | .set("es.mapping.version", "TEST_VALUE_3")
13 | .set("es.mapping.version.type", "TEST_VALUE_4")
14 | .set("es.mapping.routing", "TEST_VALUE_5")
15 | .set("es.mapping.ttl", "TEST_VALUE_6")
16 | .set("es.mapping.timestamp", "TEST_VALUE_7")
17 |
18 | val expectedSparkEsMapperConf = SparkEsMapperConf(
19 | esMappingId = Some("TEST_VALUE_1"),
20 | esMappingParent = Some("TEST_VALUE_2"),
21 | esMappingVersion = Some("TEST_VALUE_3"),
22 | esMappingVersionType = Some("TEST_VALUE_4"),
23 | esMappingRouting = Some("TEST_VALUE_5"),
24 | esMappingTTLInMillis = Some("TEST_VALUE_6"),
25 | esMappingTimestamp = Some("TEST_VALUE_7")
26 | )
27 |
28 | val outputSparkEsMapperConf = SparkEsMapperConf.fromSparkConf(inputSparkConf)
29 |
30 | outputSparkEsMapperConf mustEqual expectedSparkEsMapperConf
31 | }
32 |
33 | "extract CONSTANT_FIELD_REGEX successfully" in {
34 |       val inputString = "<TEST_VALUE_1>"
35 |
36 | val expectedString = "TEST_VALUE_1"
37 |
38 | val outputString = inputString match {
39 | case SparkEsMapperConf.CONSTANT_FIELD_REGEX(outputString) =>
40 | outputString
41 | case _ =>
42 | fail("CONSTANT_FIELD_REGEX failed.")
43 | }
44 |
45 | outputString mustEqual expectedString
46 | }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsTransportClientConfSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.conf
2 |
3 | import java.net.InetSocketAddress
4 |
5 | import org.apache.spark.SparkConf
6 | import org.scalatest.{ MustMatchers, WordSpec }
7 |
8 | class SparkEsTransportClientConfSpec extends WordSpec with MustMatchers {
9 | "SparkEsTransportClientConf" must {
10 | "be extracted from SparkConf successfully" in {
11 | val inputSparkConf = new SparkConf()
12 | .set("es.nodes", "127.0.0.1:9000,127.0.0.1:9001,127.0.0.1:9002")
13 | .set("es.port", "1337")
14 |
15 | val expectedSparkEsTransportClientConf = SparkEsTransportClientConf(
16 | transportAddresses = Seq("127.0.0.1:9000", "127.0.0.1:9001", "127.0.0.1:9002"),
17 | transportPort = 1337,
18 | transportSettings = Map.empty[String, String]
19 | )
20 |
21 | val outputSparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(inputSparkConf)
22 |
23 | outputSparkEsTransportClientConf mustEqual expectedSparkEsTransportClientConf
24 | }
25 |
26 | "be extracted from SparkConf unsuccessfully" in {
27 | val inputSparkConf = new SparkConf()
28 |
29 | val outputException = intercept[IllegalArgumentException] {
30 | SparkEsTransportClientConf.fromSparkConf(inputSparkConf)
31 | }
32 |
33 | outputException.getMessage must include("No nodes defined in property es.nodes is in SparkConf.")
34 | }
35 |
36 | "extract transportSettings successfully" in {
37 | val inputSparkConf = new SparkConf()
38 | .set("es.nodes", "127.0.0.1:9000,127.0.0.1:9001,127.0.0.1:9002")
39 | .set("es.port", "1337")
40 | .set("es.cluster.name", "TEST_VALUE_1")
41 | .set("es.client.transport.sniff", "TEST_VALUE_2")
42 | .set("es.client.transport.ignore_cluster_name", "TEST_VALUE_3")
43 | .set("es.client.transport.ping_timeout", "TEST_VALUE_4")
44 | .set("es.client.transport.nodes_sampler_interval", "TEST_VALUE_5")
45 |
46 | val expectedSparkEsTransportClientConf = SparkEsTransportClientConf(
47 | transportAddresses = Seq("127.0.0.1:9000", "127.0.0.1:9001", "127.0.0.1:9002"),
48 | transportPort = 1337,
49 | transportSettings = Map(
50 | "cluster.name" -> "TEST_VALUE_1",
51 | "client.transport.sniff" -> "TEST_VALUE_2",
52 | "client.transport.ignore_cluster_name" -> "TEST_VALUE_3",
53 | "client.transport.ping_timeout" -> "TEST_VALUE_4",
54 | "client.transport.nodes_sampler_interval" -> "TEST_VALUE_5"
55 | )
56 | )
57 |
58 | val outputSparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(inputSparkConf)
59 |
60 | outputSparkEsTransportClientConf mustEqual expectedSparkEsTransportClientConf
61 | }
62 |
63 | "extract transportAddresses as Seq[InetSocketAddress] successfully with port secondly" in {
64 | val inputAddresses = Seq("127.0.0.1:9000", "127.0.0.1:9001", "127.0.0.1:9002")
65 | val inputPort = 1337
66 |
67 | val expectedTransportAddresses = Seq(
68 | new InetSocketAddress("127.0.0.1", 9000),
69 | new InetSocketAddress("127.0.0.1", 9001),
70 | new InetSocketAddress("127.0.0.1", 9002)
71 | )
72 |
73 | val outputTransportAddresses = SparkEsTransportClientConf.getTransportAddresses(inputAddresses, inputPort)
74 |
75 | outputTransportAddresses mustEqual expectedTransportAddresses
76 | }
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/conf/SparkEsWriteConfSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.conf
2 |
3 | import org.apache.spark.SparkConf
4 | import org.scalatest.{ MustMatchers, WordSpec }
5 |
6 | class SparkEsWriteConfSpec extends WordSpec with MustMatchers {
7 | "SparkEsWriteConf" must {
8 | "be extracted from SparkConf successfully" in {
9 | val inputSparkConf = new SparkConf()
10 | .set("es.batch.size.entries", "1")
11 | .set("es.batch.size.bytes", "2")
12 | .set("es.batch.concurrent.request", "3")
13 | .set("es.batch.flush.timeout", "4")
14 |
15 | val expectedSparkEsWriteConf = SparkEsWriteConf(
16 | bulkActions = 1,
17 | bulkSizeInMB = 2,
18 | concurrentRequests = 3,
19 | flushTimeoutInSeconds = 4
20 | )
21 |
22 | val outputSparkEsWriteConf = SparkEsWriteConf.fromSparkConf(inputSparkConf)
23 |
24 | outputSparkEsWriteConf mustEqual expectedSparkEsWriteConf
25 | }
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/sql/PackageSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.sql
2 |
3 | import com.holdenkarau.spark.testing.SharedSparkContext
4 | import org.apache.spark.sql.SQLContext
5 | import org.scalatest.{ MustMatchers, WordSpec }
6 |
7 | class PackageSpec extends WordSpec with MustMatchers with SharedSparkContext {
8 | "Package com.github.jparkie.spark.elasticsearch.sql" must {
9 | "lift DataFrame into SparkEsDataFrameFunctions" in {
10 |
11 | val sqlContext = new SQLContext(sc)
12 |
13 | val inputData = Seq(
14 | ("TEST_VALUE_1", 1),
15 | ("TEST_VALUE_2", 2),
16 | ("TEST_VALUE_3", 3)
17 | )
18 |
19 | val outputDataFrame = sqlContext.createDataFrame(inputData)
20 | .toDF("key", "value")
21 |
22 | // If sparkContext is available, DataFrame was lifted into SparkEsDataFrameFunctions.
23 | outputDataFrame.sparkContext
24 | }
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameMapperSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.sql
2 |
3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsMapperConf
4 | import com.holdenkarau.spark.testing.SharedSparkContext
5 | import org.apache.spark.SparkContext
6 | import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
7 | import org.apache.spark.sql.{ Row, SQLContext }
8 | import org.elasticsearch.index.VersionType
9 | import org.scalatest.{ MustMatchers, WordSpec }
10 |
11 | class SparkEsDataFrameMapperSpec extends WordSpec with MustMatchers with SharedSparkContext {
12 | def createInputRow(sparkContext: SparkContext): Row = {
13 | val sqlContext = new SQLContext(sparkContext)
14 |
15 | val inputSchema = StructType(
16 | Array(
17 | StructField("id", StringType, true),
18 | StructField("parent", StringType, true),
19 | StructField("version", LongType, true),
20 | StructField("routing", StringType, true),
21 | StructField("ttl", LongType, true),
22 | StructField("timestamp", StringType, true),
23 | StructField("value", LongType, true)
24 | )
25 | )
26 | val inputData = sc.parallelize {
27 | Array(
28 | Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L)
29 | )
30 | }
31 | val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema)
32 | val inputRow = inputDataFrame.first()
33 |
34 | inputRow
35 | }
36 |
37 | "SparkEsDataFrameMapper" must {
38 | "from Row extract noting successfully" in {
39 | val inputRow = createInputRow(sc)
40 | val inputMapperConf = SparkEsMapperConf(
41 | esMappingId = None,
42 | esMappingParent = None,
43 | esMappingVersion = None,
44 | esMappingVersionType = None,
45 | esMappingRouting = None,
46 | esMappingTTLInMillis = None,
47 | esMappingTimestamp = None
48 | )
49 |
50 | val outputSparkEsDataFrameMapper = new SparkEsDataFrameMapper(inputMapperConf)
51 |
52 | outputSparkEsDataFrameMapper.extractMappingId(inputRow) mustEqual None
53 | outputSparkEsDataFrameMapper.extractMappingParent(inputRow) mustEqual None
54 | outputSparkEsDataFrameMapper.extractMappingVersion(inputRow) mustEqual None
55 | outputSparkEsDataFrameMapper.extractMappingVersionType(inputRow) mustEqual None
56 | outputSparkEsDataFrameMapper.extractMappingRouting(inputRow) mustEqual None
57 | outputSparkEsDataFrameMapper.extractMappingTTLInMillis(inputRow) mustEqual None
58 | outputSparkEsDataFrameMapper.extractMappingTimestamp(inputRow) mustEqual None
59 | }
60 |
61 | "from Row execute extractMappingId() successfully" in {
62 | val inputRow = createInputRow(sc)
63 | val inputMapperConf = SparkEsMapperConf(
64 | esMappingId = Some("id"),
65 | esMappingParent = None,
66 | esMappingVersion = None,
67 | esMappingVersionType = None,
68 | esMappingRouting = None,
69 | esMappingTTLInMillis = None,
70 | esMappingTimestamp = None
71 | )
72 |
73 | val outputSparkEsDataFrameMapper = new SparkEsDataFrameMapper(inputMapperConf)
74 |
75 | outputSparkEsDataFrameMapper.extractMappingId(inputRow) mustEqual Some("TEST_ID_1")
76 | }
77 |
78 | "from Row execute extractMappingParent() successfully" in {
79 | val inputRow = createInputRow(sc)
80 | val inputMapperConf1 = SparkEsMapperConf(
81 | esMappingId = None,
82 | esMappingParent = Some("parent"),
83 | esMappingVersion = None,
84 | esMappingVersionType = None,
85 | esMappingRouting = None,
86 | esMappingTTLInMillis = None,
87 | esMappingTimestamp = None
88 | )
89 | val inputMapperConf2 = SparkEsMapperConf(
90 | esMappingId = None,
91 |         esMappingParent = Some("<TEST_VALUE>"),
92 | esMappingVersion = None,
93 | esMappingVersionType = None,
94 | esMappingRouting = None,
95 | esMappingTTLInMillis = None,
96 | esMappingTimestamp = None
97 | )
98 |
99 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1)
100 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2)
101 |
102 | outputSparkEsDataFrameMapper1.extractMappingParent(inputRow) mustEqual Some("TEST_PARENT_1")
103 | outputSparkEsDataFrameMapper2.extractMappingParent(inputRow) mustEqual Some("TEST_VALUE")
104 | }
105 |
106 | "from Row execute extractMappingVersion() successfully" in {
107 | val inputRow = createInputRow(sc)
108 | val inputMapperConf1 = SparkEsMapperConf(
109 | esMappingId = None,
110 | esMappingParent = None,
111 | esMappingVersion = Some("version"),
112 | esMappingVersionType = None,
113 | esMappingRouting = None,
114 | esMappingTTLInMillis = None,
115 | esMappingTimestamp = None
116 | )
117 | val inputMapperConf2 = SparkEsMapperConf(
118 | esMappingId = None,
119 | esMappingParent = None,
120 | esMappingVersion = Some("<1337>"),
121 | esMappingVersionType = None,
122 | esMappingRouting = None,
123 | esMappingTTLInMillis = None,
124 | esMappingTimestamp = None
125 | )
126 |
127 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1)
128 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2)
129 |
130 | outputSparkEsDataFrameMapper1.extractMappingVersion(inputRow) mustEqual Some(1L)
131 | outputSparkEsDataFrameMapper2.extractMappingVersion(inputRow) mustEqual Some(1337L)
132 | }
133 |
134 | "from Row execute extractMappingVersionType() successfully" in {
135 | val inputRow = createInputRow(sc)
136 | val inputMapperConf1 = SparkEsMapperConf(
137 | esMappingId = None,
138 | esMappingParent = None,
139 | esMappingVersion = Some(""),
140 | esMappingVersionType = Some("force"),
141 | esMappingRouting = None,
142 | esMappingTTLInMillis = None,
143 | esMappingTimestamp = None
144 | )
145 | val inputMapperConf2 = SparkEsMapperConf(
146 | esMappingId = None,
147 | esMappingParent = None,
148 | esMappingVersion = None,
149 | esMappingVersionType = Some("force"),
150 | esMappingRouting = None,
151 | esMappingTTLInMillis = None,
152 | esMappingTimestamp = None
153 | )
154 |
155 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1)
156 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2)
157 |
158 | outputSparkEsDataFrameMapper1.extractMappingVersionType(inputRow) mustEqual Some(VersionType.FORCE)
159 | outputSparkEsDataFrameMapper2.extractMappingVersionType(inputRow) mustEqual None
160 | }
161 |
162 | "from Row execute extractMappingRouting() successfully" in {
163 | val inputRow = createInputRow(sc)
164 | val inputMapperConf1 = SparkEsMapperConf(
165 | esMappingId = None,
166 | esMappingParent = None,
167 | esMappingVersion = None,
168 | esMappingVersionType = None,
169 | esMappingRouting = Some("routing"),
170 | esMappingTTLInMillis = None,
171 | esMappingTimestamp = None
172 | )
173 | val inputMapperConf2 = SparkEsMapperConf(
174 | esMappingId = None,
175 | esMappingParent = None,
176 | esMappingVersion = None,
177 | esMappingVersionType = None,
178 |         esMappingRouting = Some("<TEST_VALUE>"),
179 | esMappingTTLInMillis = None,
180 | esMappingTimestamp = None
181 | )
182 |
183 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1)
184 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2)
185 |
186 | outputSparkEsDataFrameMapper1.extractMappingRouting(inputRow) mustEqual Some("TEST_ROUTING_1")
187 | outputSparkEsDataFrameMapper2.extractMappingRouting(inputRow) mustEqual Some("TEST_VALUE")
188 | }
189 |
190 | "from Row execute extractMappingTTLInMillis() successfully" in {
191 | val inputRow = createInputRow(sc)
192 | val inputMapperConf1 = SparkEsMapperConf(
193 | esMappingId = None,
194 | esMappingParent = None,
195 | esMappingVersion = None,
196 | esMappingVersionType = None,
197 | esMappingRouting = None,
198 | esMappingTTLInMillis = Some("ttl"),
199 | esMappingTimestamp = None
200 | )
201 | val inputMapperConf2 = SparkEsMapperConf(
202 | esMappingId = None,
203 | esMappingParent = None,
204 | esMappingVersion = None,
205 | esMappingVersionType = None,
206 | esMappingRouting = None,
207 | esMappingTTLInMillis = Some("<1337>"),
208 | esMappingTimestamp = None
209 | )
210 |
211 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1)
212 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2)
213 |
214 | outputSparkEsDataFrameMapper1.extractMappingTTLInMillis(inputRow) mustEqual Some(86400000L)
215 | outputSparkEsDataFrameMapper2.extractMappingTTLInMillis(inputRow) mustEqual Some(1337L)
216 | }
217 |
218 | "from Row execute extractMappingTimestamp() successfully" in {
219 | val inputRow = createInputRow(sc)
220 | val inputMapperConf1 = SparkEsMapperConf(
221 | esMappingId = None,
222 | esMappingParent = None,
223 | esMappingVersion = None,
224 | esMappingVersionType = None,
225 | esMappingRouting = None,
226 | esMappingTTLInMillis = None,
227 | esMappingTimestamp = Some("timestamp")
228 | )
229 | val inputMapperConf2 = SparkEsMapperConf(
230 | esMappingId = None,
231 | esMappingParent = None,
232 | esMappingVersion = None,
233 | esMappingVersionType = None,
234 | esMappingRouting = None,
235 | esMappingTTLInMillis = None,
236 |         esMappingTimestamp = Some("<TEST_VALUE>")
237 | )
238 |
239 | val outputSparkEsDataFrameMapper1 = new SparkEsDataFrameMapper(inputMapperConf1)
240 | val outputSparkEsDataFrameMapper2 = new SparkEsDataFrameMapper(inputMapperConf2)
241 |
242 | outputSparkEsDataFrameMapper1.extractMappingTimestamp(inputRow) mustEqual Some("TEST_TIMESTAMP_1")
243 | outputSparkEsDataFrameMapper2.extractMappingTimestamp(inputRow) mustEqual Some("TEST_VALUE")
244 | }
245 | }
246 | }
247 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/sql/SparkEsDataFrameSerializerSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.sql
2 |
3 | import java.sql.{ Date, Timestamp }
4 |
5 | import com.holdenkarau.spark.testing.SharedSparkContext
6 | import org.apache.spark.sql.types._
7 | import org.apache.spark.sql.{ Row, SQLContext }
8 | import org.elasticsearch.common.xcontent.XContentFactory
9 | import org.scalatest.{ MustMatchers, WordSpec }
10 |
11 | class SparkEsDataFrameSerializerSpec extends WordSpec with MustMatchers with SharedSparkContext {
12 | "SparkEsDataFrameSerializer" must {
13 | "execute writeStruct() successfully" in {
14 | val sqlContext = new SQLContext(sc)
15 |
16 | val inputSchema = StructType(
17 | Array(
18 | StructField("id", StringType, true),
19 | StructField("parent", StringType, true),
20 | StructField("version", LongType, true),
21 | StructField("routing", StringType, true),
22 | StructField("ttl", LongType, true),
23 | StructField("timestamp", StringType, true),
24 | StructField("value", LongType, true)
25 | )
26 | )
27 | val inputData = sc.parallelize {
28 | Array(
29 | Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L)
30 | )
31 | }
32 | val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema)
33 | val inputRow = inputDataFrame.first()
34 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(inputSchema)
35 |
36 | val inputBuilder = XContentFactory.jsonBuilder()
37 | val outputBuilder = inputSparkEsDataFrameSerializer.writeStruct(inputSchema, inputRow, inputBuilder)
38 | outputBuilder.string() must include("""{"id":"TEST_ID_1","parent":"TEST_PARENT_1","version":1,"routing":"TEST_ROUTING_1","ttl":86400000,"timestamp":"TEST_TIMESTAMP_1","value":1}""")
39 | inputBuilder.close()
40 | }
41 |
42 | "execute writeArray() successfully" in {
43 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(null)
44 |
45 | val inputArray = Array(1, 2, 3)
46 | val inputBuilder = XContentFactory.jsonBuilder()
47 | val outputBuilder = inputSparkEsDataFrameSerializer.writeArray(ArrayType(IntegerType), inputArray, inputBuilder)
48 | outputBuilder.string() must include("""1,2,3""")
49 | inputBuilder.close()
50 | }
51 |
52 | "execute writeMap() successfully" in {
53 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(null)
54 |
55 | val inputMap = Map(
56 | "TEST_KEY_1" -> "TEST_VALUE_1",
57 | "TEST_KEY_2" -> "TEST_VALUE_3",
58 | "TEST_KEY_3" -> "TEST_VALUE_3"
59 | )
60 | val inputBuilder = XContentFactory.jsonBuilder()
61 | val outputBuilder = inputSparkEsDataFrameSerializer.writeMap(MapType(StringType, StringType), inputMap, inputBuilder)
62 | outputBuilder.string() must include("""{"TEST_KEY_1":"TEST_VALUE_1","TEST_KEY_2":"TEST_VALUE_3","TEST_KEY_3":"TEST_VALUE_3"}""")
63 | inputBuilder.close()
64 | }
65 |
66 | "execute writePrimitive() successfully" in {
67 | val inputSparkEsDataFrameSerializer = new SparkEsDataFrameSerializer(null)
68 |
69 | val inputBuilder1 = XContentFactory.jsonBuilder()
70 | val outputBuilder1 = inputSparkEsDataFrameSerializer.writePrimitive(BinaryType, Array[Byte](1), inputBuilder1)
71 | outputBuilder1.string() must include("AQ==")
72 | inputBuilder1.close()
73 |
74 | val inputBuilder2 = XContentFactory.jsonBuilder()
75 | val outputBuilder2 = inputSparkEsDataFrameSerializer.writePrimitive(BooleanType, true, inputBuilder2)
76 | outputBuilder2.string() mustEqual "true"
77 | inputBuilder2.close()
78 |
79 | val inputBuilder3 = XContentFactory.jsonBuilder()
80 | val outputBuilder3 = inputSparkEsDataFrameSerializer.writePrimitive(ByteType, 1.toByte, inputBuilder3)
81 | outputBuilder3.string() mustEqual "1"
82 | inputBuilder3.close()
83 |
84 | val inputBuilder4 = XContentFactory.jsonBuilder()
85 | val outputBuilder4 = inputSparkEsDataFrameSerializer.writePrimitive(ShortType, 1.toShort, inputBuilder4)
86 | outputBuilder4.string() mustEqual "1"
87 | inputBuilder4.close()
88 |
89 | val inputBuilder5 = XContentFactory.jsonBuilder()
90 | val outputBuilder5 = inputSparkEsDataFrameSerializer.writePrimitive(IntegerType, 1.toInt, inputBuilder5)
91 | outputBuilder5.string() mustEqual "1"
92 | inputBuilder5.close()
93 |
94 | val inputBuilder6 = XContentFactory.jsonBuilder()
95 | val outputBuilder6 = inputSparkEsDataFrameSerializer.writePrimitive(LongType, 1.toLong, inputBuilder6)
96 | outputBuilder6.string() mustEqual "1"
97 | inputBuilder6.close()
98 |
99 | val inputBuilder7 = XContentFactory.jsonBuilder()
100 | val outputBuilder7 = inputSparkEsDataFrameSerializer.writePrimitive(DoubleType, 1.0, inputBuilder7)
101 | outputBuilder7.string() mustEqual "1.0"
102 | inputBuilder7.close()
103 |
104 | val inputBuilder8 = XContentFactory.jsonBuilder()
105 | val outputBuilder8 = inputSparkEsDataFrameSerializer.writePrimitive(FloatType, 1.0F, inputBuilder8)
106 | outputBuilder8.string() mustEqual "1.0"
107 | inputBuilder8.close()
108 |
109 | val inputBuilder9 = XContentFactory.jsonBuilder()
110 | val outputBuilder9 = inputSparkEsDataFrameSerializer.writePrimitive(TimestampType, new Timestamp(834120000000L), inputBuilder9)
111 | outputBuilder9.string() must include("834120000000")
112 | inputBuilder9.close()
113 |
114 | val inputBuilder10 = XContentFactory.jsonBuilder()
115 | val outputBuilder10 = inputSparkEsDataFrameSerializer.writePrimitive(DateType, new Date(834120000000L), inputBuilder10)
116 | outputBuilder10.string() must include("834120000000")
117 | inputBuilder10.close()
118 |
119 | val inputBuilder11 = XContentFactory.jsonBuilder()
120 | val outputBuilder11 = inputSparkEsDataFrameSerializer.writePrimitive(StringType, "TEST_VALUE", inputBuilder11)
121 | outputBuilder11.string() must include("TEST_VALUE")
122 | inputBuilder11.close()
123 | }
124 | }
125 | }
126 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientManagerSpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.transport
2 |
3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf
4 | import org.scalatest.{ MustMatchers, WordSpec }
5 |
6 | class SparkEsTransportClientManagerSpec extends WordSpec with MustMatchers {
7 | "SparkEsTransportClientManager" must {
8 | "maintain one unique TransportClient in internalTransportClients" in {
9 | val inputClientConf = SparkEsTransportClientConf(
10 | transportAddresses = Seq("127.0.0.1"),
11 | transportPort = 9300,
12 | transportSettings = Map.empty[String, String]
13 | )
14 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {}
15 |
16 | inputSparkEsTransportClientManager.getTransportClient(inputClientConf)
17 | inputSparkEsTransportClientManager.getTransportClient(inputClientConf)
18 |
19 | inputSparkEsTransportClientManager.internalTransportClients.size mustEqual 1
20 |
21 | inputSparkEsTransportClientManager.closeTransportClient(inputClientConf)
22 | }
23 |
24 | "return a SparkEsTransportClientProxy when calling getTransportClient()" in {
25 | val inputClientConf = SparkEsTransportClientConf(
26 | transportAddresses = Seq("127.0.0.1"),
27 | transportPort = 9300,
28 | transportSettings = Map.empty[String, String]
29 | )
30 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {}
31 |
32 | val outputClient = inputSparkEsTransportClientManager.getTransportClient(inputClientConf)
33 |
34 | outputClient.getClass mustEqual classOf[SparkEsTransportClientProxy]
35 |
36 | inputSparkEsTransportClientManager.closeTransportClient(inputClientConf)
37 | }
38 |
39 | "evict TransportClient after calling closeTransportClient" in {
40 | val inputClientConf = SparkEsTransportClientConf(
41 | transportAddresses = Seq("127.0.0.1"),
42 | transportPort = 9300,
43 | transportSettings = Map.empty[String, String]
44 | )
45 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {}
46 |
47 | inputSparkEsTransportClientManager.getTransportClient(inputClientConf)
48 |
49 | inputSparkEsTransportClientManager.closeTransportClient(inputClientConf)
50 |
51 | inputSparkEsTransportClientManager.internalTransportClients.size mustEqual 0
52 | }
53 |
54 | "returns buildTransportSettings() successfully" in {
55 | val inputClientConf = SparkEsTransportClientConf(
56 | transportAddresses = Seq("127.0.0.1"),
57 | transportPort = 9300,
58 | transportSettings = Map(
59 | "TEST_KEY_1" -> "TEST_VALUE_1",
60 | "TEST_KEY_2" -> "TEST_VALUE_2",
61 | "TEST_KEY_3" -> "TEST_VALUE_3"
62 | )
63 | )
64 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {}
65 |
66 | val outputSettings = inputSparkEsTransportClientManager.buildTransportSettings(inputClientConf)
67 |
68 | outputSettings.get("TEST_KEY_1") mustEqual "TEST_VALUE_1"
69 | outputSettings.get("TEST_KEY_2") mustEqual "TEST_VALUE_2"
70 | outputSettings.get("TEST_KEY_3") mustEqual "TEST_VALUE_3"
71 | }
72 |
73 | "returns buildTransportClient() successfully" in {
74 | val inputClientConf = SparkEsTransportClientConf(
75 | transportAddresses = Seq("127.0.0.1"),
76 | transportPort = 9300,
77 | transportSettings = Map.empty[String, String]
78 | )
79 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {}
80 |
81 | val outputSettings = inputSparkEsTransportClientManager.buildTransportSettings(inputClientConf)
82 | val outputClient = inputSparkEsTransportClientManager.buildTransportClient(inputClientConf, outputSettings)
83 | val outputHost = outputClient.transportAddresses().get(0).getHost
84 | val outputPort = outputClient.transportAddresses().get(0).getPort
85 |
86 | outputHost mustEqual "127.0.0.1"
87 | outputPort mustEqual 9300
88 |
89 | outputClient.close()
90 | }
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/jparkie/spark/elasticsearch/transport/SparkEsTransportClientProxySpec.scala:
--------------------------------------------------------------------------------
1 | package com.github.jparkie.spark.elasticsearch.transport
2 |
3 | import com.github.jparkie.spark.elasticsearch.conf.SparkEsTransportClientConf
4 | import com.github.jparkie.spark.elasticsearch.util.SparkEsException
5 | import org.scalatest.{ MustMatchers, WordSpec }
6 |
7 | class SparkEsTransportClientProxySpec extends WordSpec with MustMatchers {
8 | "SparkEsTransportClientProxy" must {
9 | "prohibit close() call" in {
10 | val inputClientConf = SparkEsTransportClientConf(
11 | transportAddresses = Seq("127.0.0.1"),
12 | transportPort = 9300,
13 | transportSettings = Map.empty[String, String]
14 | )
15 | val inputSparkEsTransportClientManager = new SparkEsTransportClientManager {}
16 | val inputSparkEsTransportClient = inputSparkEsTransportClientManager.getTransportClient(inputClientConf)
17 | val inputSparkEsTransportClientProxy = new SparkEsTransportClientProxy(inputSparkEsTransportClient)
18 |
19 | val outputException = intercept[SparkEsException] {
20 | inputSparkEsTransportClientProxy.close()
21 | }
22 |
23 | outputException.getMessage must include("close() is not supported in SparkEsTransportClientProxy. Please close with SparkEsTransportClientManager.")
24 | }
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/version.sbt:
--------------------------------------------------------------------------------
1 | version in ThisBuild := "2.0.0-SNAPSHOT"
--------------------------------------------------------------------------------