├── .gitignore
├── .travis.yml
├── Documentation.md
├── LICENSE
├── README.md
├── pom.xml
└── src
    ├── it
    │   └── scala
    │       └── org
    │           └── apache
    │               └── spark
    │                   └── orientdb
    │                       ├── documents
    │                       │   ├── IntegrationSuiteBase.scala
    │                       │   └── OrientDBIntegrationSuite.scala
    │                       └── graphs
    │                           ├── IntegrationSuiteBase.scala
    │                           └── OrientDBGraphIntegrationSuite.scala
    ├── main
    │   ├── examples
    │   │   └── org
    │   │       └── apache
    │   │           └── spark
    │   │               └── orientdb
    │   │                   ├── documents
    │   │                   │   └── DataFrameTest.scala
    │   │                   └── graphs
    │   │                       └── GraphFrameTest.scala
    │   └── scala
    │       └── org
    │           └── apache
    │               └── spark
    │                   └── orientdb
    │                       ├── documents
    │                       │   ├── Conversions.scala
    │                       │   ├── DefaultSource.scala
    │                       │   ├── FilterPushdown.scala
    │                       │   ├── OrientDBClientFactory.scala
    │                       │   ├── OrientDBCredentials.scala
    │                       │   ├── OrientDBDocumentWrapper.scala
    │                       │   ├── OrientDBRelation.scala
    │                       │   ├── OrientDBWriter.scala
    │                       │   ├── Parameters.scala
    │                       │   └── TableName.scala
    │                       ├── graphs
    │                       │   ├── DefaultSource.scala
    │                       │   ├── OrientDBClientFactory.scala
    │                       │   ├── OrientDBCredentials.scala
    │                       │   ├── OrientDBGraphWrapper.scala
    │                       │   ├── OrientDBRelation.scala
    │                       │   ├── OrientDBWriter.scala
    │                       │   └── Parameters.scala
    │                       └── udts
    │                           ├── EmbeddedListType.scala
    │                           ├── EmbeddedMapType.scala
    │                           ├── EmbeddedSetType.scala
    │                           ├── LinkBagType.scala
    │                           ├── LinkListType.scala
    │                           ├── LinkMapType.scala
    │                           ├── LinkSetType.scala
    │                           └── LinkType.scala
    └── test
        └── scala
            └── org
                └── apache
                    └── spark
                        └── orientdb
                            ├── QueryTest.scala
                            ├── TestUtils.scala
                            ├── documents
                            │   ├── ConversionsSuite.scala
                            │   ├── FilterPushdownSuite.scala
                            │   ├── MockOrientDBDocument.scala
                            │   ├── OrientDBEmbeddedUDTsSourceSuite.scala
                            │   ├── OrientDBLinkUDTsSourceSuite.scala
                            │   ├── OrientDBSourceSuite.scala
                            │   ├── ParametersSuite.scala
                            │   └── TableNameSuite.scala
                            └── graphs
                                ├── MockEdge.scala
                                ├── MockOrientDBGraph.scala
                                ├── MockVertex.scala
                                ├── OrientDBEmbeddedUDTsSourceSuite.scala
                                ├── OrientDBGraphSourceSuite.scala
                                ├── OrientDBLinkUDTsSourceSuite.scala
                                └── ParametersSuite.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | .idea/
3 | *.iml
4 | *.log
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: java
--------------------------------------------------------------------------------
/Documentation.md:
--------------------------------------------------------------------------------
1 | Apache Spark datasource for OrientDB
2 | ====================================
3 |
4 | ## Introduction
5 |
6 | This documentation describes how to connect Apache Spark to OrientDB. [Apache Spark](http://spark.apache.org/) is a widely used engine for large-scale data processing.
7 | Here, we discuss how to use the [**Apache Spark datasource for OrientDB**](https://github.com/sbcd90/spark-orientdb) to leverage Spark's capabilities while using OrientDB as the datastore.
8 |
9 | ## Installation Guide
10 |
11 | To use the Apache Spark datasource for OrientDB inside a Spark application, the following steps need to be performed.
12 |
13 | - Add the repository location to `pom.xml`.
14 |
15 | ```
16 | <repository>
17 |   <id>bintray</id>
18 |   <name>bintray</name>
19 |   <url>https://dl.bintray.com/sbcd90/org.apache.spark/</url>
20 | </repository>
21 | ```
22 |
23 | - Add the datasource as a maven dependency in `pom.xml`.
24 |
25 | ```
26 | <dependency>
27 |   <groupId>org.apache.spark</groupId>
28 |   <artifactId>spark-orientdb-{spark.version}_2.10</artifactId>
29 |   <version>1.3</version>
30 | </dependency>
31 | ```
32 |
33 |
34 | ## Configuration
35 |
36 | The datasource supports Apache Spark 1.6+ and OrientDB 2.2.0+.
37 |
38 | ## API Reference
39 |
40 | The **Apache Spark datasource for OrientDB** allows users to leverage the Apache Spark datasource APIs for reading and writing data from OrientDB. The datasource loads a collection of OrientDB documents into a DataFrame, and an OrientDB graph containing a set of vertices and edges into a GraphFrame.
41 |
42 | The complete API reference for using the **Apache Spark datasource for OrientDB** is provided below.
43 |
44 | ### OrientDB Documents
45 |
46 | #### Write api:
47 |
48 | ```
49 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
50 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
51 | val sqlContext = new SQLContext(sc)
52 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1), Row(2), Row(3), Row(4), Row(5))),
53 |   StructType(Seq(StructField("id", IntegerType))))
54 |   .write
55 | .format("org.apache.spark.orientdb.documents")
56 | .option("dburl", ORIENTDB_CONNECTION_URL)
57 | .option("user", ORIENTDB_USER).option("password", ORIENTDB_PASSWORD)
58 | .option("class", test_class)
59 | .mode(SaveMode.Overwrite)
60 | .save()
61 | ```
62 |
63 | #### Read api:
64 |
65 | ```
66 | import org.apache.spark.sql.SQLContext
67 |
68 | val sqlContext = new SQLContext(sc)
69 | val loadedDf = sqlContext.read
70 | .format("org.apache.spark.orientdb.documents")
71 | .option("dburl", ORIENTDB_CONNECTION_URL)
72 | .option("user", ORIENTDB_USER)
73 | .option("password", ORIENTDB_PASSWORD)
74 | .option("class", test_class)
75 | .option("query", s"select * from $test_table where teststring = 'asdf'")
76 | .load()
77 | ```
78 |
79 | #### Query using OrientDB SQL:
80 |
81 | ```
82 | import org.apache.spark.sql.SQLContext
83 |
84 | val sqlContext = new SQLContext(sc)
85 | val loadedDf = sqlContext.read
86 | .format("org.apache.spark.orientdb.documents")
87 | .option("dburl", ORIENTDB_CONNECTION_URL)
88 | .option("user", ORIENTDB_USER)
89 | .option("password", ORIENTDB_PASSWORD)
90 | .option("class", test_class)
91 | .option("query", s"select * from $test_table where teststring = 'asdf'")
92 | .load()
93 | ```
94 |
95 | ### OrientDB Graphs:
96 |
97 | #### Create Vertex api:
98 |
99 | ```
100 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
101 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
102 | val sqlContext = new SQLContext(sc)
103 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1), Row(2), Row(3), Row(4), Row(5))),
104 |   StructType(Seq(StructField("id", IntegerType))))
105 |   .write
106 | .format("org.apache.spark.orientdb.graphs")
107 | .option("dburl", ORIENTDB_CONNECTION_URL)
108 | .option("user", ORIENTDB_USER)
109 | .option("password", ORIENTDB_PASSWORD)
110 | .option("vertextype", test_vertex_type2)
111 | .mode(SaveMode.Overwrite)
112 | .save()
113 | ```
114 |
115 | #### Create Edge api:
116 |
117 | ```
118 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
119 | import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
120 | val sqlContext = new SQLContext(sc)
121 | sqlContext.createDataFrame(
122 | sc.parallelize(Seq(
123 | Row(1, 2, "friends"),
124 | Row(2, 3, "enemy"),
125 | Row(3, 4, "friends"),
126 | Row(4, 1, "enemy")
127 | )),
128 | StructType(Seq(
129 | StructField("src", IntegerType),
130 | StructField("dst", IntegerType),
131 | StructField("relationship", StringType)
132 | )))
133 | .write
134 | .format("org.apache.spark.orientdb.graphs")
135 | .option("dburl", ORIENTDB_CONNECTION_URL)
136 | .option("user", ORIENTDB_USER)
137 | .option("password", ORIENTDB_PASSWORD)
138 | .option("vertextype", test_vertex_type2)
139 | .option("edgetype", test_edge_type2)
140 | .mode(SaveMode.Overwrite)
141 | .save()
142 | ```
143 |
144 | #### Read Vertex api:
145 |
146 | ```
147 | import org.apache.spark.sql.SQLContext
148 |
149 | val sqlContext = new SQLContext(sc)
150 | val loadedDf = sqlContext.read
151 | .format("org.apache.spark.orientdb.graphs")
152 | .option("dburl", ORIENTDB_CONNECTION_URL)
153 | .option("user", ORIENTDB_USER)
154 | .option("password", ORIENTDB_PASSWORD)
155 | .option("vertextype", test_vertex_type2)
156 | .load()
157 | ```
158 |
159 | #### Read edge api:
160 |
161 | ```
162 | import org.apache.spark.sql.SQLContext
163 |
164 | val sqlContext = new SQLContext(sc)
165 | val loadedDf = sqlContext.read
166 | .format("org.apache.spark.orientdb.graphs")
167 | .option("dburl", ORIENTDB_CONNECTION_URL)
168 | .option("user", ORIENTDB_USER)
169 | .option("password", ORIENTDB_PASSWORD)
170 | .option("edgetype", test_edge_type2)
171 | .load()
172 | ```
173 |
174 | #### Query using OrientDB Graph SQL:
175 |
176 | ```
177 | import org.apache.spark.sql.SQLContext
178 |
179 | val sqlContext = new SQLContext(sc)
180 | val loadedVerticesDf = sqlContext.read
181 | .format("org.apache.spark.orientdb.graphs")
182 | .option("dburl", ORIENTDB_CONNECTION_URL)
183 | .option("user", ORIENTDB_USER)
184 | .option("password", ORIENTDB_PASSWORD)
185 | .option("vertextype", test_vertex_type2)
186 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'")
187 | .load()
188 |
189 | val loadedEdgesDf = sqlContext.read
190 | .format("org.apache.spark.orientdb.graphs")
191 | .option("dburl", ORIENTDB_CONNECTION_URL)
192 | .option("user", ORIENTDB_USER)
193 | .option("password", ORIENTDB_PASSWORD)
194 | .option("edgetype", test_edge_type2)
195 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'")
196 | .load()
197 | ```
198 |
199 | ### Integration with GraphFrames
200 |
201 | ```
202 | import org.apache.spark.sql.SQLContext
203 | import org.graphframes.GraphFrame
204 | val sqlContext = new SQLContext(sc)
205 | val loadedVerticesDf = sqlContext.read
206 | .format("org.apache.spark.orientdb.graphs")
207 | .option("dburl", ORIENTDB_CONNECTION_URL)
208 | .option("user", ORIENTDB_USER)
209 | .option("password", ORIENTDB_PASSWORD)
210 | .option("vertextype", test_vertex_type2)
211 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'")
212 | .load()
213 |
214 | val loadedEdgesDf = sqlContext.read
215 | .format("org.apache.spark.orientdb.graphs")
216 | .option("dburl", ORIENTDB_CONNECTION_URL)
217 | .option("user", ORIENTDB_USER)
218 | .option("password", ORIENTDB_PASSWORD)
219 | .option("edgetype", test_edge_type2)
220 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'")
221 | .load()
222 |
223 | val g = GraphFrame(loadedVerticesDf, loadedEdgesDf)
224 | ```
225 |
226 | ## Examples
227 |
228 | ### An example Spark application with OrientDB Documents
229 |
230 | ```
231 | package org.apache.spark.orientdb.documents
232 |
233 | import org.apache.spark.sql.{SQLContext, SaveMode}
234 | import org.apache.spark.{SparkConf, SparkContext}
235 |
236 | object DataFrameTest extends App {
237 | val conf = new SparkConf().setAppName("DataFrameTest").setMaster("local[*]")
238 | val sc = new SparkContext(conf)
239 | val sqlContext = new SQLContext(sc)
240 |
241 | import sqlContext.implicits._
242 | val df = sc.parallelize(Array(1, 2, 3, 4, 5)).toDF("id")
243 |
244 | df.write.format("org.apache.spark.orientdb.documents")
245 | .option("dburl", "")
246 | .option("user", "****")
247 | .option("password", "****")
248 | .option("class", "test_class")
249 | .mode(SaveMode.Overwrite)
250 | .save()
251 |
252 | val resultDf = sqlContext.read
253 | .format("org.apache.spark.orientdb.documents")
254 | .option("dburl", "")
255 | .option("user", "****")
256 | .option("password", "****")
257 | .option("class", "test_class")
258 | .load()
259 |
260 | resultDf.show()
261 | }
262 | ```
263 |
264 | ### An example Spark application with OrientDB Graph
265 |
266 | ```
267 | package org.apache.spark.orientdb.graphs
268 |
269 | import org.apache.spark.{SparkConf, SparkContext}
270 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
271 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
272 | import org.graphframes.GraphFrame
273 |
274 | object GraphFrameTest extends App {
275 | val conf = new SparkConf().setAppName("MainApplication").setMaster("local[*]")
276 | val sc = new SparkContext(conf)
277 | sc.setLogLevel("WARN")
278 | val sqlContext = new SQLContext(sc)
279 |
280 | import sqlContext.implicits._
281 | val df = sc.parallelize(Array(1, 2, 3, 4, 5)).toDF("id")
282 |
283 | df.write.format("org.apache.spark.orientdb.graphs")
284 | .option("dburl", "")
285 | .option("user", "****")
286 | .option("password", "****")
287 | .option("vertextype", "v104")
288 | .mode(SaveMode.Overwrite)
289 | .save()
290 |
291 | val vertices = sqlContext.read
292 | .format("org.apache.spark.orientdb.graphs")
293 | .option("dburl", "")
294 | .option("user", "****")
295 | .option("password", "****")
296 | .option("vertextype", "v104")
297 | .load()
298 |
299 | var inVertex: Integer = null
300 | var outVertex: Integer = null
301 | vertices.collect().foreach(row => {
302 | if (inVertex == null) {
303 | inVertex = row.getAs[Integer]("id")
304 | }
305 | if (outVertex == null) {
306 | outVertex = row.getAs[Integer]("id")
307 | }
308 | })
309 |
310 | val df1 = sqlContext.createDataFrame(sc.parallelize(Seq(Row("friends", "1", "2"),
311 | Row("enemies", "2", "3"), Row("friends", "3", "1"))),
312 | StructType(List(StructField("relationship", StringType), StructField("src", StringType),
313 | StructField("dst", StringType))))
314 |
315 | df1.write.format("org.apache.spark.orientdb.graphs")
316 | .option("dburl", "")
317 | .option("user", "****")
318 | .option("password", "****")
319 | .option("vertextype", "v104")
320 | .option("edgetype", "e104")
321 | .mode(SaveMode.Overwrite)
322 | .save()
323 |
324 | val edges = sqlContext.read
325 | .format("org.apache.spark.orientdb.graphs")
326 | .option("dburl", "")
327 | .option("user", "****")
328 | .option("password", "****")
329 | .option("edgetype", "e104")
330 | .load()
331 |
332 | edges.show()
333 |
334 | val g = GraphFrame(vertices, edges)
335 | g.inDegrees.show()
336 | println(g.edges.filter("relationship = 'friends'").count())
337 | }
338 | ```
339 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2016 Subhobrata Dey
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | spark-orientdb
2 | ==============
3 | [Build Status](https://travis-ci.org/sbcd90/spark-orientdb) [Latest Version](https://bintray.com/sbcd90/org.apache.spark/spark-orientdb-1.6.2_2.10/_latestVersion)
4 |
5 | Apache Spark datasource for OrientDB
6 |
7 | OrientDB documentation
8 | ======================
9 |
10 | Here is the latest documentation on [OrientDB](http://orientdb.com/orientdb/)
11 |
12 | Compatibility
13 | =============
14 |
15 | `Spark`: 1.6+
16 | `OrientDB`: 2.2.0+
17 |
18 | Getting Started
19 | ===============
20 |
21 | - Add the repository
22 |
23 | ```
24 | <repository>
25 |   <id>bintray</id>
26 |   <name>bintray</name>
27 |   <url>https://dl.bintray.com/sbcd90/org.apache.spark/</url>
28 | </repository>
29 | ```
30 |
31 | ### For Spark 1.6
32 |
33 | - Add the datasource as a maven dependency
34 |
35 | ```
36 | <dependency>
37 |   <groupId>org.apache.spark</groupId>
38 |   <artifactId>spark-orientdb-1.6.2_2.10</artifactId>
39 |   <version>1.3</version>
40 | </dependency>
41 | ```
42 |
43 | ### For Spark 2.0
44 |
45 | - Add the datasource as a maven dependency
46 |
47 | ```
48 | <dependency>
49 |   <groupId>org.apache.spark</groupId>
50 |   <artifactId>spark-orientdb-2.0.0_2.10</artifactId>
51 |   <version>1.4</version>
52 | </dependency>
53 | ```
54 |
55 | ### For Spark 2.1
56 |
57 | ```
58 | <dependency>
59 |   <groupId>org.apache.spark</groupId>
60 |   <artifactId>spark-orientdb-2.1.1_2.11</artifactId>
61 |   <version>1.4</version>
62 | </dependency>
63 | ```
64 |
65 | ### For Spark 2.2
66 |
67 | ```
68 | <dependency>
69 |   <groupId>org.apache.spark</groupId>
70 |   <artifactId>spark-orientdb-2.2.1_2.11</artifactId>
71 |   <version>1.4</version>
72 | </dependency>
73 | ```
74 |
75 | Scala api
76 | =========
77 |
78 | ### OrientDB Documents
79 |
80 | #### Write api:
81 |
82 | ```
83 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
84 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
85 | val sqlContext = new SQLContext(sc)
86 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1), Row(2), Row(3), Row(4), Row(5))),
87 |   StructType(Seq(StructField("id", IntegerType))))
88 |   .write
89 | .format("org.apache.spark.orientdb.documents")
90 | .option("dburl", ORIENTDB_CONNECTION_URL)
91 | .option("user", ORIENTDB_USER).option("password", ORIENTDB_PASSWORD)
92 | .option("class", test_table)
93 | .mode(SaveMode.Overwrite)
94 | .save()
95 | ```
96 |
97 | #### Read api:
98 |
99 | ```
100 | import org.apache.spark.sql.SQLContext
101 |
102 | val sqlContext = new SQLContext(sc)
103 | val loadedDf = sqlContext.read
104 | .format("org.apache.spark.orientdb.documents")
105 | .option("dburl", ORIENTDB_CONNECTION_URL)
106 | .option("user", ORIENTDB_USER)
107 | .option("password", ORIENTDB_PASSWORD)
108 | .option("class", test_table)
109 | .option("query", s"select * from $test_table where teststring = 'asdf'")
110 | .load()
111 | ```
112 |
113 | #### Query using OrientDB SQL:
114 |
115 | ```
116 | import org.apache.spark.sql.SQLContext
117 |
118 | val sqlContext = new SQLContext(sc)
119 | val loadedDf = sqlContext.read
120 | .format("org.apache.spark.orientdb.documents")
121 | .option("dburl", ORIENTDB_CONNECTION_URL)
122 | .option("user", ORIENTDB_USER)
123 | .option("password", ORIENTDB_PASSWORD)
124 | .option("class", test_table)
125 | .option("query", s"select * from $test_table where teststring = 'asdf'")
126 | .load()
127 | ```
128 |
129 | #### Support for Embedded Types (since Spark 2.1 release):
130 |
131 | ```
132 | val testSchemaForEmbeddedUDTs: StructType = {
133 | StructType(Seq(
134 | StructField("embeddedlist", EmbeddedListType),
135 | StructField("embeddedset", EmbeddedSetType),
136 | StructField("embeddedmap", EmbeddedMapType)
137 | ))
138 | }
139 | ```
140 |
141 | ```
142 | val expectedDataForEmbeddedUDTs: Seq[Row] = Seq(
143 | Row(EmbeddedList(Array(1, 1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498,
144 | 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣",
145 | TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1))),
146 | EmbeddedSet(Array(1, 1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498,
147 | 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣",
148 | TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1))),
149 | EmbeddedMap(Map(1 -> 1, 2 -> 1.toByte, 3 -> true, 4 -> TestUtils.toDate(2015, 6, 1), 5 -> 1234152.12312498,
150 | 6 -> 1.0f, 7 -> 42, 8 -> 1239012341823719L, 9 -> 23.toShort, 10 -> "Unicode's樂趣", 11 -> TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1))))
151 | )
152 | ```
153 |
154 | #### Support for Link Types (since Spark 2.1 release):
155 |
156 | ```
157 | val testSchemaForLinkUDTs: StructType = {
158 | StructType(Seq(
159 | StructField("linklist", LinkListType),
160 | StructField("linkset", LinkSetType),
161 | StructField("linkmap", LinkMapType),
162 | StructField("linkbag", LinkBagType)
163 | ))
164 | }
165 | ```
166 |
167 | ```
168 | val expectedDataForLinkUDTs: Seq[Row] = Seq(
169 | Row(LinkList(Array(oDocument1)), LinkSet(Array(oDocument1)), LinkMap(Map("1" -> oDocument1)), LinkBag(Array(oRid1))),
170 | Row(LinkList(Array(oDocument2)), LinkSet(Array(oDocument2)), LinkMap(Map("1" -> oDocument2)), LinkBag(Array(oRid2))),
171 | Row(LinkList(Array(oDocument3)), LinkSet(Array(oDocument3)), LinkMap(Map("1" -> oDocument3)), LinkBag(Array(oRid3))),
172 | Row(LinkList(Array(oDocument4)), LinkSet(Array(oDocument4)), LinkMap(Map("1" -> oDocument4)), LinkBag(Array(oRid4))),
173 | Row(LinkList(Array(oDocument5)), LinkSet(Array(oDocument5)), LinkMap(Map("1" -> oDocument5)), LinkBag(Array(oRid5)))
174 | )
175 | ```
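
Either of the UDT schemas above plugs into the same write/read API shown earlier. Below is a minimal, illustrative sketch (assuming the `sc` and `ORIENTDB_*` values from the previous examples, and using `test_udt_class` as a placeholder class name), reusing the embedded-UDT schema and rows defined above:

```
import org.apache.spark.sql.{Row, SQLContext, SaveMode}

val sqlContext = new SQLContext(sc)

// Build a DataFrame from the UDT rows and schema defined above and persist it
// as an OrientDB class through the documents datasource.
val udtDf = sqlContext.createDataFrame(
  sc.parallelize(expectedDataForEmbeddedUDTs),
  testSchemaForEmbeddedUDTs)

udtDf.write
  .format("org.apache.spark.orientdb.documents")
  .option("dburl", ORIENTDB_CONNECTION_URL)
  .option("user", ORIENTDB_USER)
  .option("password", ORIENTDB_PASSWORD)
  .option("class", "test_udt_class")
  .mode(SaveMode.Overwrite)
  .save()
```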
176 |
177 | ### OrientDB Graphs:
178 |
179 | #### Create Vertex api:
180 |
181 | ```
182 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
183 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
184 | val sqlContext = new SQLContext(sc)
185 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1), Row(2), Row(3), Row(4), Row(5))),
186 |   StructType(Seq(StructField("id", IntegerType))))
187 |   .write
188 | .format("org.apache.spark.orientdb.graphs")
189 | .option("dburl", ORIENTDB_CONNECTION_URL)
190 | .option("user", ORIENTDB_USER)
191 | .option("password", ORIENTDB_PASSWORD)
192 | .option("vertextype", test_vertex_type2)
193 | .mode(SaveMode.Overwrite)
194 | .save()
195 | ```
196 |
197 | #### Create Edge api:
198 |
199 | ```
200 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
201 | import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
202 | val sqlContext = new SQLContext(sc)
203 | sqlContext.createDataFrame(
204 | sc.parallelize(Seq(
205 | Row(1, 2, "friends"),
206 | Row(2, 3, "enemy"),
207 | Row(3, 4, "friends"),
208 | Row(4, 1, "enemy")
209 | )),
210 | StructType(Seq(
211 | StructField("src", IntegerType),
212 | StructField("dst", IntegerType),
213 | StructField("relationship", StringType)
214 | )))
215 | .write
216 | .format("org.apache.spark.orientdb.graphs")
217 | .option("dburl", ORIENTDB_CONNECTION_URL)
218 | .option("user", ORIENTDB_USER)
219 | .option("password", ORIENTDB_PASSWORD)
220 | .option("vertextype", test_vertex_type2)
221 | .option("edgetype", test_edge_type2)
222 | .mode(SaveMode.Overwrite)
223 | .save()
224 | ```
225 |
226 | #### Read Vertex api:
227 |
228 | ```
229 | import org.apache.spark.sql.SQLContext
230 |
231 | val sqlContext = new SQLContext(sc)
232 | val loadedDf = sqlContext.read
233 | .format("org.apache.spark.orientdb.graphs")
234 | .option("dburl", ORIENTDB_CONNECTION_URL)
235 | .option("user", ORIENTDB_USER)
236 | .option("password", ORIENTDB_PASSWORD)
237 | .option("vertextype", test_vertex_type2)
238 | .load()
239 | ```
240 |
241 | #### Read edge api:
242 |
243 | ```
244 | import org.apache.spark.sql.SQLContext
245 |
246 | val sqlContext = new SQLContext(sc)
247 | val loadedDf = sqlContext.read
248 | .format("org.apache.spark.orientdb.graphs")
249 | .option("dburl", ORIENTDB_CONNECTION_URL)
250 | .option("user", ORIENTDB_USER)
251 | .option("password", ORIENTDB_PASSWORD)
252 | .option("edgetype", test_edge_type2)
253 | .load()
254 | ```
255 |
256 | #### Query using OrientDB Graph SQL:
257 |
258 | ```
259 | import org.apache.spark.sql.SQLContext
260 |
261 | val sqlContext = new SQLContext(sc)
262 | val loadedVerticesDf = sqlContext.read
263 | .format("org.apache.spark.orientdb.graphs")
264 | .option("dburl", ORIENTDB_CONNECTION_URL)
265 | .option("user", ORIENTDB_USER)
266 | .option("password", ORIENTDB_PASSWORD)
267 | .option("vertextype", test_vertex_type2)
268 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'")
269 | .load()
270 |
271 | val loadedEdgesDf = sqlContext.read
272 | .format("org.apache.spark.orientdb.graphs")
273 | .option("dburl", ORIENTDB_CONNECTION_URL)
274 | .option("user", ORIENTDB_USER)
275 | .option("password", ORIENTDB_PASSWORD)
276 | .option("edgetype", test_edge_type2)
277 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'")
278 | .load()
279 | ```
280 |
281 | #### Support for embedded types & link types (since Spark 2.1 release)
282 |
283 | The Spark UDTs are available for the OrientDB Graph datasource as well.
284 | Usage is very similar to what is documented for the OrientDB Document datasource.
285 | Examples can be found in the integration tests; a short sketch is shown below.
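
A minimal sketch, assuming the `sc` and `ORIENTDB_*` values from the earlier examples; the vertex type `test_udt_vertex` and the column names are illustrative placeholders:

```
import org.apache.spark.orientdb.udts.{EmbeddedList, EmbeddedListType}
import org.apache.spark.sql.{Row, SQLContext, SaveMode}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val sqlContext = new SQLContext(sc)

// Each vertex carries an embedded-list property in addition to its id.
val vertexDf = sqlContext.createDataFrame(
  sc.parallelize(Seq(
    Row(1, EmbeddedList(Array[Any](1, 2, 3))),
    Row(2, EmbeddedList(Array[Any](4, 5, 6))))),
  StructType(Seq(
    StructField("id", IntegerType),
    StructField("tags", EmbeddedListType))))

vertexDf.write
  .format("org.apache.spark.orientdb.graphs")
  .option("dburl", ORIENTDB_CONNECTION_URL)
  .option("user", ORIENTDB_USER)
  .option("password", ORIENTDB_PASSWORD)
  .option("vertextype", "test_udt_vertex")
  .mode(SaveMode.Overwrite)
  .save()
```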
286 |
287 | ### Integration with GraphFrames
288 |
289 | ```
290 | import org.apache.spark.sql.SQLContext
291 | import org.graphframes.GraphFrame
292 | val sqlContext = new SQLContext(sc)
293 | val loadedVerticesDf = sqlContext.read
294 | .format("org.apache.spark.orientdb.graphs")
295 | .option("dburl", ORIENTDB_CONNECTION_URL)
296 | .option("user", ORIENTDB_USER)
297 | .option("password", ORIENTDB_PASSWORD)
298 | .option("vertextype", test_vertex_type2)
299 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'")
300 | .load()
301 |
302 | val loadedEdgesDf = sqlContext.read
303 | .format("org.apache.spark.orientdb.graphs")
304 | .option("dburl", ORIENTDB_CONNECTION_URL)
305 | .option("user", ORIENTDB_USER)
306 | .option("password", ORIENTDB_PASSWORD)
307 | .option("edgetype", test_edge_type2)
308 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'")
309 | .load()
310 |
311 | val g = GraphFrame(loadedVerticesDf, loadedEdgesDf)
312 | ```
313 |
314 | A full example can be found in the directory `src/main/examples`.
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 |
7 |   <groupId>org.apache.spark</groupId>
8 |   <artifactId>spark-orientdb-${spark.version}_${scala.binary.version}</artifactId>
9 |   <version>1.4</version>
10 |
11 |   <properties>
12 |     <scala.version>2.12.12</scala.version>
13 |     <scala.binary.version>2.12</scala.binary.version>
14 |     <spark.version>3.1.1</spark.version>
15 |     <orientdb.graphdb.version>3.0.13</orientdb.graphdb.version>
16 |     <maven.compiler.source>1.8</maven.compiler.source>
17 |     <maven.compiler.target>1.8</maven.compiler.target>
18 |   </properties>
19 |
20 |   <repositories>
21 |     <repository>
22 |       <id>SparkPackagesRepo</id>
23 |       <name>SparkPackagesRepo</name>
24 |       <url>https://dl.bintray.com/spark-packages/maven</url>
25 |     </repository>
26 |     <repository>
27 |       <id>bintray</id>
28 |       <name>bintray.com</name>
29 |       <url>https://repos.spark-packages.org</url>
30 |     </repository>
31 |   </repositories>
32 |
33 |   <dependencies>
34 |     <dependency>
35 |       <groupId>org.scala-lang</groupId>
36 |       <artifactId>scala-compiler</artifactId>
37 |       <version>${scala.version}</version>
38 |     </dependency>
39 |     <dependency>
40 |       <groupId>com.orientechnologies</groupId>
41 |       <artifactId>orientdb-graphdb</artifactId>
42 |       <version>${orientdb.graphdb.version}</version>
43 |     </dependency>
44 |     <dependency>
45 |       <groupId>org.scalatest</groupId>
46 |       <artifactId>scalatest_${scala.binary.version}</artifactId>
47 |       <version>3.0.0</version>
48 |       <scope>test</scope>
49 |     </dependency>
50 |     <dependency>
51 |       <groupId>org.mockito</groupId>
52 |       <artifactId>mockito-all</artifactId>
53 |       <version>1.10.19</version>
54 |       <scope>test</scope>
55 |     </dependency>
56 |     <dependency>
57 |       <groupId>org.apache.spark</groupId>
58 |       <artifactId>spark-core_${scala.binary.version}</artifactId>
59 |       <version>${spark.version}</version>
60 |     </dependency>
61 |     <dependency>
62 |       <groupId>org.apache.spark</groupId>
63 |       <artifactId>spark-sql_${scala.binary.version}</artifactId>
64 |       <version>${spark.version}</version>
65 |     </dependency>
66 |     <dependency>
67 |       <groupId>org.apache.spark</groupId>
68 |       <artifactId>spark-hive_${scala.binary.version}</artifactId>
69 |       <version>${spark.version}</version>
70 |     </dependency>
71 |     <dependency>
72 |       <groupId>org.apache.spark</groupId>
73 |       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
74 |       <version>${spark.version}</version>
75 |     </dependency>
76 |     <dependency>
77 |       <groupId>graphframes</groupId>
78 |       <artifactId>graphframes</artifactId>
79 |       <version>0.8.1-spark3.0-s_${scala.binary.version}</version>
80 |     </dependency>
81 |   </dependencies>
82 |
83 |   <distributionManagement>
84 |     <repository>
85 |       <id>bintray-sbcd90-spark-orientdb</id>
86 |       <url>https://api.bintray.com/maven/sbcd90/org.apache.spark/spark-orientdb-${spark.version}_${scala.binary.version}/;publish=1</url>
87 |     </repository>
88 |   </distributionManagement>
89 |
90 |   <build>
91 |     <plugins>
92 |       <plugin>
93 |         <groupId>org.apache.maven.plugins</groupId>
94 |         <artifactId>maven-compiler-plugin</artifactId>
95 |         <version>3.5.1</version>
96 |       </plugin>
97 |       <plugin>
98 |         <groupId>org.scala-tools</groupId>
99 |         <artifactId>maven-scala-plugin</artifactId>
100 |         <executions>
101 |           <execution>
102 |             <goals>
103 |               <goal>compile</goal>
104 |               <goal>testCompile</goal>
105 |             </goals>
106 |           </execution>
107 |         </executions>
108 |         <configuration>
109 |           <sourceDir>src/main/scala</sourceDir>
110 |           <jvmArgs>
111 |             <jvmArg>-Xms64m</jvmArg>
112 |             <jvmArg>-Xmx1024m</jvmArg>
113 |           </jvmArgs>
114 |         </configuration>
115 |       </plugin>
116 |       <plugin>
117 |         <groupId>org.apache.maven.plugins</groupId>
118 |         <artifactId>maven-surefire-plugin</artifactId>
119 |         <version>2.19.1</version>
120 |         <configuration>
121 |           <skipTests>true</skipTests>
122 |         </configuration>
123 |       </plugin>
124 |       <plugin>
125 |         <groupId>org.scalatest</groupId>
126 |         <artifactId>scalatest-maven-plugin</artifactId>
127 |         <version>1.0</version>
128 |         <configuration>
129 |           <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
130 |           <junitxml>.</junitxml>
131 |           <filereports>TestSuite.txt</filereports>
132 |         </configuration>
133 |         <executions>
134 |           <execution>
135 |             <id>test</id>
136 |             <goals>
137 |               <goal>test</goal>
138 |             </goals>
139 |           </execution>
140 |         </executions>
141 |       </plugin>
142 |       <plugin>
143 |         <groupId>org.apache.maven.plugins</groupId>
144 |         <artifactId>maven-jar-plugin</artifactId>
145 |         <version>2.3.2</version>
146 |         <configuration>
147 |           <finalName>spark-orientdb-${spark.version}_${scala.binary.version}</finalName>
148 |         </configuration>
149 |       </plugin>
150 |     </plugins>
151 |   </build>
152 | </project>
--------------------------------------------------------------------------------
/src/it/scala/org/apache/spark/orientdb/documents/IntegrationSuiteBase.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.db.document.ODatabaseDocument
4 | import org.apache.spark.SparkContext
5 | import org.apache.spark.orientdb.QueryTest
6 | import org.apache.spark.sql.types.StructType
7 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
8 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
9 |
10 | trait IntegrationSuiteBase
11 | extends QueryTest
12 | with Matchers
13 | with BeforeAndAfterAll
14 | with BeforeAndAfterEach {
15 |
16 | protected def loadConfigFromEnv(envVarName: String): String = {
17 | Option(System.getenv(envVarName)).getOrElse{
18 | fail(s"Must set $envVarName environment variable")
19 | }
20 | }
21 |
22 | protected val ORIENTDB_CONNECTION_URL = loadConfigFromEnv("ORIENTDB_CONN_URL")
23 | protected val ORIENTDB_USER = loadConfigFromEnv("ORIENTDB_USER")
24 | protected val ORIENTDB_PASSWORD = loadConfigFromEnv("ORIENTDB_PASSWORD")
25 |
26 | protected var sc: SparkContext = _
27 | protected var sqlContext: SQLContext = _
28 | protected var connection: ODatabaseDocument = _
29 |
30 | protected val orientDBWrapper = DefaultOrientDBDocumentWrapper
31 |
32 | override def beforeAll(): Unit = {
33 | super.beforeAll()
34 | sc = new SparkContext("local", "OrientDBSourceSuite")
35 |
36 | val parameters = Map("dburl" -> ORIENTDB_CONNECTION_URL,
37 | "user" -> ORIENTDB_USER,
38 | "password" -> ORIENTDB_PASSWORD,
39 | "class" -> "dummy")
40 | connection = orientDBWrapper.getConnection(Parameters.mergeParameters(parameters))
41 | }
42 |
43 | override def afterAll(): Unit = {
44 | try {
45 | connection.close()
46 | } finally {
47 | try {
48 | sc.stop()
49 | } finally {
50 | super.afterAll()
51 | }
52 | }
53 | }
54 |
55 | override protected def beforeEach(): Unit = {
56 | super.beforeEach()
57 | sqlContext = new SQLContext(sc)
58 | }
59 |
60 | def testRoundtripSaveAndLoad(className: String,
61 | df: DataFrame,
62 | expectedSchemaAfterLoad: Option[StructType] = None,
63 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = {
64 | try {
65 | df.write
66 | .format("org.apache.spark.orientdb.documents")
67 | .option("dburl", ORIENTDB_CONNECTION_URL)
68 | .option("user", ORIENTDB_USER)
69 | .option("password", ORIENTDB_PASSWORD)
70 | .option("class", className)
71 | .mode(saveMode)
72 | .save()
73 |
74 | if (!orientDBWrapper.doesClassExists(className)) {
75 | Thread.sleep(1000)
76 | assert(orientDBWrapper.doesClassExists(className))
77 | }
78 |
79 | val loadedDf = sqlContext.read.format("org.apache.spark.orientdb.documents")
80 | .option("dburl", ORIENTDB_CONNECTION_URL)
81 | .option("user", ORIENTDB_USER)
82 | .option("password", ORIENTDB_PASSWORD)
83 | .option("class", className)
84 | .load()
85 | assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema))
86 | checkAnswer(loadedDf, df.collect())
87 | } finally {
88 | orientDBWrapper.delete(null, className, null)
89 | val schema = connection.getMetadata.getSchema
90 | schema.dropClass(className)
91 | }
92 | }
93 | }
--------------------------------------------------------------------------------
/src/it/scala/org/apache/spark/orientdb/graphs/IntegrationSuiteBase.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import com.tinkerpop.blueprints.impls.orient.OrientGraphNoTx
4 | import org.apache.spark.SparkContext
5 | import org.apache.spark.orientdb.QueryTest
6 | import org.apache.spark.sql.types.StructType
7 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
8 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
9 |
10 | trait IntegrationSuiteBase
11 | extends QueryTest
12 | with Matchers
13 | with BeforeAndAfterAll
14 | with BeforeAndAfterEach {
15 |
16 | protected def loadConfigFromEnv(envVarName: String): String = {
17 | Option(System.getenv(envVarName)).getOrElse{
18 | fail(s"Must set $envVarName environment variable")
19 | }
20 | }
21 |
22 | protected val ORIENTDB_CONNECTION_URL = loadConfigFromEnv("ORIENTDB_CONN_URL")
23 | protected val ORIENTDB_USER = loadConfigFromEnv("ORIENTDB_USER")
24 | protected val ORIENTDB_PASSWORD = loadConfigFromEnv("ORIENTDB_PASSWORD")
25 |
26 | protected var sc: SparkContext = _
27 | protected var sqlContext: SQLContext = _
28 | protected var vertex_connection: OrientGraphNoTx = _
29 | protected var edge_connection: OrientGraphNoTx = _
30 |
31 | protected val orientDBGraphVertexWrapper = DefaultOrientDBGraphVertexWrapper
32 | protected val orientDBGraphEdgeWrapper = DefaultOrientDBGraphEdgeWrapper
33 |
34 | override def beforeAll(): Unit = {
35 | super.beforeAll()
36 | sc = new SparkContext("local", "OrientDBSourceSuite")
37 |
38 | val parameters = Map[String, String](
39 | "dburl" -> ORIENTDB_CONNECTION_URL,
40 | "user" -> ORIENTDB_USER,
41 | "password" -> ORIENTDB_PASSWORD)
42 |
43 | val vertexParams: Map[String, String] = parameters ++ Map[String, String]("vertextype" -> "dummy_vertex")
44 | val edgeParams = parameters ++ Map[String, String]("edgetype" -> "dummy_edge")
45 |
46 | vertex_connection = orientDBGraphVertexWrapper
47 | .getConnection(Parameters.mergeParameters(vertexParams))
48 | edge_connection = orientDBGraphEdgeWrapper
49 | .getConnection(Parameters.mergeParameters(edgeParams))
50 | }
51 |
52 | override def afterAll(): Unit = {
53 | try {
54 | orientDBGraphVertexWrapper.close()
55 | orientDBGraphEdgeWrapper.close()
56 | } finally {
57 | try {
58 | sc.stop()
59 | } finally {
60 | super.afterAll()
61 | }
62 | }
63 | }
64 |
65 | override protected def beforeEach(): Unit = {
66 | super.beforeEach()
67 | sqlContext = new SQLContext(sc)
68 | }
69 |
70 | def testRoundtripSaveAndLoadForVertices(vertexType: String,
71 | df: DataFrame,
72 | expectedSchemaAfterLoad: Option[StructType] = None,
73 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = {
74 | try {
75 | df.write
76 | .format("org.apache.spark.orientdb.graphs")
77 | .option("dburl", ORIENTDB_CONNECTION_URL)
78 | .option("user", ORIENTDB_USER)
79 | .option("password", ORIENTDB_PASSWORD)
80 | .option("vertextype", vertexType)
81 | .mode(saveMode)
82 | .save()
83 |
84 | if (!orientDBGraphVertexWrapper.doesVertexTypeExists(vertexType)) {
85 | Thread.sleep(1000)
86 | assert(orientDBGraphVertexWrapper.doesVertexTypeExists(vertexType))
87 | }
88 |
89 | val loadedDf = sqlContext.read
90 | .format("org.apache.spark.orientdb.graphs")
91 | .option("dburl", ORIENTDB_CONNECTION_URL)
92 | .option("user", ORIENTDB_USER)
93 | .option("password", ORIENTDB_PASSWORD)
94 | .option("vertextype", vertexType)
95 | .load()
96 |
97 | loadedDf.schema.fields.foreach(field => assert(df.schema.fields.contains(field)))
98 | checkAnswer(loadedDf, df.collect())
99 | } finally {
100 | orientDBGraphVertexWrapper.delete(vertexType, null)
101 | vertex_connection.dropVertexType(vertexType)
102 | }
103 | }
104 |
105 | def testRoundtripSaveAndLoadForEdges( vertexType: String,
106 | edgeType: String,
107 | vertexDf: DataFrame,
108 | df: DataFrame,
109 | expectedSchemaAfterLoad: Option[StructType] = None,
110 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = {
111 | try {
112 | vertexDf.write
113 | .format("org.apache.spark.orientdb.graphs")
114 | .option("dburl", ORIENTDB_CONNECTION_URL)
115 | .option("user", ORIENTDB_USER)
116 | .option("password", ORIENTDB_PASSWORD)
117 | .option("vertextype", vertexType)
118 | .mode(saveMode)
119 | .save()
120 |
121 | df.write
122 | .format("org.apache.spark.orientdb.graphs")
123 | .option("dburl", ORIENTDB_CONNECTION_URL)
124 | .option("user", ORIENTDB_USER)
125 | .option("password", ORIENTDB_PASSWORD)
126 | .option("vertextype", vertexType)
127 | .option("edgetype", edgeType)
128 | .mode(saveMode)
129 | .save()
130 |
131 | if (!orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType)) {
132 | Thread.sleep(1000)
133 | assert(orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType))
134 | }
135 |
136 | val loadedDf = sqlContext.read
137 | .format("org.apache.spark.orientdb.graphs")
138 | .option("dburl", ORIENTDB_CONNECTION_URL)
139 | .option("user", ORIENTDB_USER)
140 | .option("password", ORIENTDB_PASSWORD)
141 | .option("edgetype", edgeType)
142 | .load()
143 |
144 | loadedDf.schema.fields.foreach(field => assert(df.schema.fields.contains(field)))
145 | assert(loadedDf.count() === df.collect().size)
146 | } finally {
147 | try {
148 | orientDBGraphEdgeWrapper.delete(edgeType, null)
149 | orientDBGraphVertexWrapper.delete(vertexType, null)
150 | } finally {
151 | edge_connection.dropEdgeType(edgeType)
152 | vertex_connection.dropVertexType(vertexType)
153 | }
154 | }
155 | }
156 | }
--------------------------------------------------------------------------------
/src/main/examples/org/apache/spark/orientdb/documents/DataFrameTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import org.apache.spark.sql.{SQLContext, SaveMode, SparkSession}
4 | import org.apache.spark.{SparkConf, SparkContext}
5 |
6 | object DataFrameTest extends App {
7 | val spark = SparkSession.builder().appName("DataFrameTest").master("local[*]").getOrCreate()
8 |
9 | import spark.implicits._
10 | val df = spark.sparkContext.parallelize(Array(1, 2, 3, 4, 5)).toDF("id")
11 |
12 | df.write.format("org.apache.spark.orientdb.documents")
13 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts")
14 | .option("user", "root")
15 | .option("password", "root")
16 | .option("class", "test_class")
17 | .mode(SaveMode.Overwrite)
18 | .save()
19 |
20 | val resultDf = spark.sqlContext.read
21 | .format("org.apache.spark.orientdb.documents")
22 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts")
23 | .option("user", "root")
24 | .option("password", "root")
25 | .option("class", "test_class")
26 | .load()
27 |
28 | resultDf.show()
29 | }
--------------------------------------------------------------------------------
/src/main/examples/org/apache/spark/orientdb/graphs/GraphFrameTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import org.apache.spark.{SparkConf, SparkContext}
4 | import org.apache.spark.sql.{Row, SQLContext, SaveMode, SparkSession}
5 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
6 | import org.graphframes.GraphFrame
7 |
8 | object GraphFrameTest extends App {
9 | val spark = SparkSession.builder().appName("MainApplication").master("local[*]").getOrCreate()
10 |
11 | val sc = spark.sparkContext
12 | sc.setLogLevel("WARN")
13 | val sqlContext = spark.sqlContext
14 |
15 | import sqlContext.implicits._
16 | val df = sc.parallelize(Array(1, 2, 3, 4, 5)).toDF("id")
17 |
18 | df.write.format("org.apache.spark.orientdb.graphs")
19 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts")
20 | .option("user", "root")
21 | .option("password", "root")
22 | .option("vertextype", "v104")
23 | .mode(SaveMode.Overwrite)
24 | .save()
25 |
26 | val vertices = sqlContext.read
27 | .format("org.apache.spark.orientdb.graphs")
28 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts")
29 | .option("user", "root")
30 | .option("password", "root")
31 | .option("vertextype", "v104")
32 | .load()
33 |
34 | var inVertex: Integer = null
35 | var outVertex: Integer = null
36 | vertices.collect().foreach(row => {
37 | if (inVertex == null) {
38 | inVertex = row.getAs[Integer]("id")
39 | }
40 | if (outVertex == null) {
41 | outVertex = row.getAs[Integer]("id")
42 | }
43 | })
44 |
45 | val df1 = sqlContext.createDataFrame(sc.parallelize(Seq(Row("friends", "1", "2"),
46 | Row("enemies", "2", "3"), Row("friends", "3", "1"))),
47 | StructType(List(StructField("relationship", StringType), StructField("src", StringType),
48 | StructField("dst", StringType))))
49 |
50 | df1.write.format("org.apache.spark.orientdb.graphs")
51 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts")
52 | .option("user", "root")
53 | .option("password", "root")
54 | .option("vertextype", "v104")
55 | .option("edgetype", "e104")
56 | .mode(SaveMode.Overwrite)
57 | .save()
58 |
59 | val edges = sqlContext.read
60 | .format("org.apache.spark.orientdb.graphs")
61 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts")
62 | .option("user", "root")
63 | .option("password", "root")
64 | .option("edgetype", "e104")
65 | .load()
66 |
67 | edges.show()
68 |
69 | val g = GraphFrame(vertices, edges)
70 | g.inDegrees.show()
71 | println(g.edges.filter("relationship = 'friends'").count())
72 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/Conversions.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import java.sql.{Date, Timestamp}
4 | import java.text.SimpleDateFormat
5 | import java.util
6 | import java.util.{Locale, Map}
7 | import java.util.function.Consumer
8 |
9 | import com.orientechnologies.orient.core.db.record._
10 | import com.orientechnologies.orient.core.db.record.ridbag.ORidBag
11 | import com.orientechnologies.orient.core.id.ORecordId
12 | import com.orientechnologies.orient.core.metadata.schema.OType
13 | import com.orientechnologies.orient.core.record.ORecord
14 | import com.orientechnologies.orient.core.record.impl.ODocument
15 | import com.tinkerpop.blueprints.{Edge, Vertex}
16 | import org.apache.spark.orientdb.udts._
17 | import org.apache.spark.sql.Row
18 | import org.apache.spark.sql.types._
19 |
20 | import scala.collection.JavaConversions._
21 | import scala.collection.mutable
22 |
23 | private[orientdb] object Conversions {
24 | def sparkDTtoOrientDBDT(dataType: DataType): OType = {
25 | dataType match {
26 | case ByteType => OType.BYTE
27 | case ShortType => OType.SHORT
28 | case IntegerType => OType.INTEGER
29 | case LongType => OType.LONG
30 | case FloatType => OType.FLOAT
31 | case DoubleType => OType.DOUBLE
32 | case _: DecimalType => OType.DECIMAL
33 | case StringType => OType.STRING
34 | case BinaryType => OType.BINARY
35 | case BooleanType => OType.BOOLEAN
36 | case DateType => OType.DATE
37 | case TimestampType => OType.DATETIME
38 | case _: EmbeddedListType => OType.EMBEDDEDLIST
39 | case _: EmbeddedSetType => OType.EMBEDDEDSET
40 | case _: EmbeddedMapType => OType.EMBEDDEDMAP
41 | case _: LinkListType => OType.LINKLIST
42 | case _: LinkSetType => OType.LINKSET
43 | case _: LinkMapType => OType.LINKMAP
44 | case _: LinkBagType => OType.LINKBAG
45 | case _: LinkType => OType.LINK
46 | case other => throw new UnsupportedOperationException(s"Unexpected DataType $dataType")
47 | }
48 | }
49 |
50 | /* def orientDBDTtoSparkDT(dataType: OType): DataType = {
51 | dataType match {
52 | case OType.BYTE => ByteType
53 | case OType.SHORT => ShortType
54 | case OType.INTEGER => IntegerType
55 | case OType.LONG => LongType
56 | case OType.FLOAT => FloatType
57 | case OType.DOUBLE => DoubleType
58 | case OType.DECIMAL =>
59 | DecimalType(DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE)
60 | case OType.STRING => StringType
61 | case OType.BINARY => BinaryType
62 | case OType.BOOLEAN => BooleanType
63 | case OType.DATE => DateType
64 | case OType.DATETIME => TimestampType
65 | case other => throw new UnsupportedOperationException(s"Unexpected DataType $dataType")
66 | }
67 | } */
68 |
69 | def convertRowsToODocuments(row: Row): ODocument = {
70 | val oDocument = new ODocument()
71 | row.schema.fields.foreach(field => {
72 | oDocument.field(field.name, getField(row, field),
73 | Conversions.sparkDTtoOrientDBDT(field.dataType))
74 | })
75 | oDocument
76 | }
77 |
78 | def orientDBDTtoSparkDT(dataType: DataType, field: AnyRef) = {
79 | if (field == null)
80 | field
81 | else {
82 | val dateFormat = new SimpleDateFormat("E MMM dd HH:mm:ss Z yyyy", Locale.ENGLISH)
83 | dataType match {
84 | case ByteType => java.lang.Byte.valueOf(field.toString)
85 | case ShortType => field.toString.toShort
86 | case IntegerType => field.toString.toInt
87 | case LongType => field.toString.toLong
88 | case FloatType => field.toString.toFloat
89 | case DoubleType => field.toString.toDouble
90 | case _: DecimalType => field.asInstanceOf[java.math.BigDecimal]
91 | case StringType => field
92 | case BinaryType => field
93 | case BooleanType => field.toString.toBoolean
94 | case DateType => new Date(dateFormat.parse(field.toString).getTime)
95 | case TimestampType => new Timestamp(dateFormat.parse(field.toString).getTime)
96 | case _: EmbeddedListType =>
97 | var elements = Array[Any]()
98 | field.asInstanceOf[util.ArrayList[Any]].forEach(new Consumer[Any] {
99 | override def accept(t: Any): Unit = elements :+= t
100 | })
101 | new EmbeddedList(elements)
102 | case _: EmbeddedSetType =>
103 | var elements = Array[Any]()
104 | field.asInstanceOf[OTrackedSet[Any]].forEach(new Consumer[Any] {
105 | override def accept(t: Any): Unit = elements :+= t
106 | })
107 | new EmbeddedSet(elements)
108 | case _: EmbeddedMapType =>
109 | var elements = mutable.Map[Any, Any]()
110 | field.asInstanceOf[OTrackedMap[Any]].entrySet().forEach(new Consumer[Map.Entry[AnyRef, Any]] {
111 | override def accept(t: Map.Entry[AnyRef, Any]): Unit = {
112 | elements.put(t.getKey, t.getValue)
113 | }
114 | })
115 | new EmbeddedMap(elements.toMap)
116 | case _: LinkListType =>
117 | var elements = Array[ORecord]()
118 | field.asInstanceOf[ORecordLazyList].forEach(new Consumer[OIdentifiable] {
119 | override def accept(t: OIdentifiable): Unit = t match {
120 | case recordId: ORecordId =>
121 | elements:+= new ODocument(recordId).asInstanceOf[ORecord]
122 | case record: ORecord =>
123 | elements :+= record
124 | }
125 | })
126 | new LinkList(elements)
127 | case _: LinkSetType =>
128 | var elements = Array[ORecord]()
129 | field.asInstanceOf[ORecordLazySet].forEach(new Consumer[OIdentifiable] {
130 | override def accept(t: OIdentifiable): Unit = t match {
131 | case recordId: ORecordId =>
132 | elements:+= new ODocument(recordId).asInstanceOf[ORecord]
133 | case record: ORecord =>
134 | elements :+= record
135 | }
136 | })
137 | new LinkSet(elements)
138 | case _: LinkMapType =>
139 | val elements = mutable.Map[String, ORecord]()
140 | field.asInstanceOf[ORecordLazyMap].entrySet().forEach(new Consumer[Map.Entry[AnyRef, OIdentifiable]] {
141 | override def accept(t: Map.Entry[AnyRef, OIdentifiable]): Unit = {
142 | elements.put(t.getKey.toString, t.getValue match {
143 | case recordId: ORecordId =>
144 | new ODocument(recordId).asInstanceOf[ORecord]
145 | case record: ORecord =>
146 | record
147 | })
148 | }
149 | })
150 | new LinkMap(elements.toMap)
151 | case _: LinkBagType =>
152 | var elements = Array[ORecordId]()
153 | field.asInstanceOf[ORidBag].rawIterator().forEachRemaining(new Consumer[OIdentifiable] {
154 | override def accept(t: OIdentifiable): Unit = t match {
155 | case recordId: ORecordId =>
156 | elements :+= recordId
157 | case other => sys.error(s"LinkBag cannot be of type ${other.getClass}")
158 | }
159 | })
160 | new LinkBag(elements)
161 | case _: LinkType =>
162 | val element = field.asInstanceOf[OIdentifiable].getRecord[ORecord]()
163 | new Link(element)
164 | case other => throw new UnsupportedOperationException(s"Unexpected DataType $dataType")
165 | }
166 | }
167 | }
168 |
169 | def convertRowToGraph(row: Row, count: Int): AnyRef = {
170 | getField(row, row.schema.fields(count))
171 | }
172 |
173 | /* def orientDBDTtoSparkDT(dataType: OType, field: String) = {
174 | if (field == null)
175 | field
176 | else {
177 | val dateFormat = new SimpleDateFormat("E MMM dd HH:mm:ss Z yyyy")
178 | dataType match {
179 | case OType.BYTE => field.toByte
180 | case OType.SHORT => field.toShort
181 | case OType.INTEGER => field.toInt
182 | case OType.LONG => field.toLong
183 | case OType.FLOAT => field.toFloat
184 | case OType.DOUBLE => field.toDouble
185 | case OType.DECIMAL => field.asInstanceOf[java.math.BigDecimal]
186 | case OType.STRING => field
187 | case OType.BINARY => field
188 | case OType.BOOLEAN => field.toBoolean
189 | case OType.DATE => new Date(dateFormat.parse(field).getTime)
190 | case OType.DATETIME => new Timestamp(dateFormat.parse(field).getTime)
191 | case _ => throw new UnsupportedOperationException(s"Unexpected DataType $dataType")
192 | }
193 | }
194 | } */
195 |
196 | def convertODocumentsToRows(oDocument: ODocument, schema: StructType): Row = {
197 | val converted: scala.collection.mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null)
198 | val fieldNames = oDocument.fieldNames()
199 | val fieldValues = oDocument.fieldValues()
200 |
201 | var i = 0
202 | while (i < schema.length) {
203 | if (fieldNames.contains(schema.fields(i).name)) {
204 | val idx = fieldNames.indexOf(schema.fields(i).name)
205 | val value = fieldValues(idx)
206 |
207 | converted(i) = orientDBDTtoSparkDT(schema.fields(i).dataType, value)
208 | } else {
209 | converted(i) = null
210 | }
211 | i = i + 1
212 | }
213 |
214 | Row.fromSeq(converted)
215 | }
216 |
217 | def convertVerticesToRows(vertex: Vertex, schema: StructType): Row = {
218 | val converted: scala.collection.mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null)
219 | val fieldNames = vertex.getPropertyKeys
220 |
221 | var i = 0
222 | while (i < schema.length) {
223 | if (fieldNames.contains(schema.fields(i).name)) {
224 | val value = vertex.getProperty[Object](schema.fields(i).name)
225 |
226 | converted(i) = orientDBDTtoSparkDT(schema.fields(i).dataType, value)
227 | } else {
228 | converted(i) = null
229 | }
230 | i = i + 1
231 | }
232 |
233 | Row.fromSeq(converted)
234 | }
235 |
236 | def convertEdgesToRows(edge: Edge, schema: StructType): Row = {
237 | val converted: scala.collection.mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null)
238 | val fieldNames = edge.getPropertyKeys
239 |
240 | var i = 0
241 | while (i < schema.length) {
242 | if (fieldNames.contains(schema.fields(i).name)) {
243 | val value = edge.getProperty[Object](schema.fields(i).name)
244 |
245 | converted(i) = orientDBDTtoSparkDT(schema.fields(i).dataType, value)
246 | } else {
247 | converted(i) = null
248 | }
249 | i = i + 1
250 | }
251 |
252 | Row.fromSeq(converted)
253 | }
254 |
255 | private def getField(row: Row, field: StructField) = field.dataType.typeName match {
256 | case "embeddedlist" => row.getAs[field.dataType.type](field.name).asInstanceOf[EmbeddedList].elements
257 | case "embeddedset" => row.getAs[field.dataType.type](field.name).asInstanceOf[EmbeddedSet].elements
258 | case "embeddedmap" => mapAsJavaMap(row.getAs[field.dataType.type](field.name).asInstanceOf[EmbeddedMap].elements)
259 | case "linklist" => row.getAs[field.dataType.type](field.name).asInstanceOf[LinkList].elements
260 | case "linkset" => row.getAs[field.dataType.type](field.name).asInstanceOf[LinkSet].elements
261 | case "linkmap" => mapAsJavaMap(row.getAs[field.dataType.type](field.name).asInstanceOf[LinkMap].elements)
262 | case "linkbag" =>
263 | val oRidBag = new ORidBag()
264 | oRidBag.addAll(row.getAs[field.dataType.type ](field.name).asInstanceOf[LinkBag].elements.toSeq)
265 | oRidBag
266 | case "link" =>
267 | row.getAs[field.dataType.type](field.name).asInstanceOf[Link].element
268 | case _ => row.getAs[field.dataType.type](field.name)
269 | }
270 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
4 | import org.apache.spark.sql.types.StructType
5 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
6 | import org.slf4j.LoggerFactory
7 |
8 | class DefaultSource(orientDBWrapper: OrientDBDocumentWrapper,
9 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory)
10 | extends RelationProvider
11 | with SchemaRelationProvider
12 | with CreatableRelationProvider {
13 |
14 | private val log = LoggerFactory.getLogger(getClass)
15 |
16 | def this() = this(DefaultOrientDBDocumentWrapper, orientDBCredentials => new OrientDBClientFactory(orientDBCredentials))
17 |
18 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
19 | val params = Parameters.mergeParameters(parameters)
20 |
21 | if (params.query.isDefined && params.className.isEmpty) {
22 | throw new IllegalArgumentException("Along with the 'query' parameter you must specify either 'class' parameter or"+
23 | " user-defined schema")
24 | }
25 |
26 | OrientDBRelation(orientDBWrapper, orientDBClientFactory, params, None)(sqlContext)
27 | }
28 |
29 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
30 | schema: StructType): BaseRelation = {
31 | val params = Parameters.mergeParameters(parameters)
32 | OrientDBRelation(orientDBWrapper, orientDBClientFactory, params, Some(schema))(sqlContext)
33 | }
34 |
35 | override def createRelation(sqlContext: SQLContext, mode: SaveMode,
36 | parameters: Map[String, String], data: DataFrame): BaseRelation = {
37 | val params = Parameters.mergeParameters(parameters)
38 | val classname = params.className.getOrElse{
39 | throw new IllegalArgumentException("For save operations you must specify an OrientDB class " +
40 | "name with the 'class' parameter")
41 | }
42 |
43 | val clusters = params.clusterNames.getOrElse{
44 | log.warn("Orient DB cluster name not specified. Using default Cluster Id for the class specified")
45 | }
46 |
47 | def tableExists: Boolean = {
48 | val connection = orientDBWrapper.getConnection(params)
49 | try {
50 | orientDBWrapper.doesClassExists(classname)
51 | } finally {
52 | connection.close()
53 | }
54 | }
55 |
56 | val (doSave, dropExisting) = mode match {
57 | case SaveMode.Append => (true, false)
58 | case SaveMode.Overwrite => (true, true)
59 | case SaveMode.ErrorIfExists =>
60 | if (tableExists) {
61 | sys.error(s"Class $classname already exists! (SaveMode is set to ErrorIfExists)")
62 | } else {
63 | (true, false)
64 | }
65 | case SaveMode.Ignore =>
66 | if (tableExists) {
67 | log.info(s"Class $classname already exists. Ignoring save request.")
68 | (false, false)
69 | } else {
70 | (true, false)
71 | }
72 | }
73 |
74 | if (doSave) {
75 | val updatedParams = parameters.updated("overwrite", dropExisting.toString)
76 | new OrientDBWriter(orientDBWrapper, orientDBClientFactory)
77 | .saveToOrientDB(data, mode, Parameters.mergeParameters(updatedParams))
78 | }
79 | createRelation(sqlContext, parameters)
80 | }
81 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/FilterPushdown.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import java.sql.{Date, Timestamp}
4 |
5 | import org.apache.spark.sql.sources._
6 | import org.apache.spark.sql.types._
7 |
8 | private[orientdb] object FilterPushdown {
9 | /**
10 | * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no
11 | * condition will be added to the WHERE clause. If none of the filters can be pushed down then
12 | * an empty string will be returned.
13 | *
14 | * @param schema the schema of the table being queried
15 | * @param filters an array of filters, the conjunction of which is the filter condition for the
16 | * scan.
17 | */
18 | def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = {
19 | val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ")
20 | if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions
21 | }
22 |
23 | /**
24 | * Attempt to convert the given filter into a SQL expression. Returns None if the expression
25 | * could not be converted.
26 | */
27 | def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = {
28 | def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = {
29 | getTypeForAttribute(schema, attr).map { dataType =>
30 | val sqlEscapedValue: String = dataType match {
31 | case StringType => s"'${value.toString.replace("'", "\\'\\'")}'"
32 | case DateType => s"'${value.asInstanceOf[Date]}'"
33 | case TimestampType => s"'${value.asInstanceOf[Timestamp]}'"
34 | case _ => value.toString
35 | }
36 | s"""$attr $comparisonOp $sqlEscapedValue"""
37 | }
38 | }
39 |
40 | filter match {
41 | case EqualTo(attr, value) => buildComparison(attr, value, "=")
42 | case LessThan(attr, value) => buildComparison(attr, value, "<")
43 | case GreaterThan(attr, value) => buildComparison(attr, value, ">")
44 | case LessThanOrEqual(attr, value) => buildComparison(attr, value, "<=")
45 | case GreaterThanOrEqual(attr, value) => buildComparison(attr, value, ">=")
46 | case IsNotNull(attr) =>
47 | getTypeForAttribute(schema, attr).map(dataType => s"""$attr IS NOT NULL""")
48 | case IsNull(attr) =>
49 | getTypeForAttribute(schema, attr).map(dataType => s"""$attr IS NULL""")
50 | case _ => None
51 | }
52 | }
53 |
54 | /**
55 | * Use the given schema to look up the attribute's data type. Returns None if the attribute could
56 | * not be resolved.
57 | */
58 | private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = {
59 | if (schema.fieldNames.contains(attribute)) {
60 | Some(schema(attribute).dataType)
61 | } else {
62 | None
63 | }
64 | }
65 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/OrientDBClientFactory.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.db.OPartitionedDatabasePool
4 | import com.orientechnologies.orient.core.db.document.ODatabaseDocument
5 |
6 | class OrientDBClientFactory(orientDBCredentials: OrientDBCredentials) extends Serializable {
7 | private val db: OPartitionedDatabasePool =
8 | new OPartitionedDatabasePool(orientDBCredentials.dbUrl, orientDBCredentials.username,
9 | orientDBCredentials.password)
10 | private var connection: ODatabaseDocument = _
11 |
12 | def getConnection(): ODatabaseDocument = {
13 | connection = db.acquire()
14 | connection
15 | }
16 |
17 | def closeConnection(): Unit = {
18 | connection.close()
19 | }
20 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/OrientDBCredentials.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | trait OrientDBCredentials extends Serializable {
4 | var dbUrl: String = null
5 | var username: String = null
6 | var password: String = null
7 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/OrientDBDocumentWrapper.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import java.util
4 | import java.util.function.Consumer
5 |
6 | import com.orientechnologies.orient.core.db.document.ODatabaseDocument
7 | import com.orientechnologies.orient.core.metadata.schema.OType
8 | import com.orientechnologies.orient.core.record.OElement
9 | import com.orientechnologies.orient.core.record.impl.ODocument
10 | import com.orientechnologies.orient.core.sql.executor.OResult
11 | import com.orientechnologies.orient.core.sql.query.OSQLSynchQuery
12 | import com.orientechnologies.orient.core.tx.OTransaction.TXTYPE
13 | import org.apache.spark.orientdb.documents.Parameters.MergedParameters
14 | import org.apache.spark.orientdb.udts._
15 | import org.apache.spark.sql.types._
16 |
17 | import scala.collection.JavaConversions._
18 |
19 | class OrientDBDocumentWrapper extends Serializable {
20 | private var connectionPool: Option[OrientDBClientFactory] = None
21 | private var connection: ODatabaseDocument = _
22 |
23 | def openTransaction(): Unit = {
24 | connection.begin(TXTYPE.OPTIMISTIC)
25 | }
26 |
27 | def commitTransaction(): Unit = {
28 | connection.commit()
29 | }
30 |
31 | def rollbackTransaction(): Unit = {
32 | connection.rollback()
33 | }
34 |
35 | /**
36 | * Get instance of Database connection
37 | * @return Database connection instance
38 | */
39 | def getConnection(params: MergedParameters): ODatabaseDocument = {
40 | try {
41 | if (connectionPool.isEmpty) {
42 | connectionPool = Some(new OrientDBClientFactory(new OrientDBCredentials {
43 | this.dbUrl = params.dbUrl.get
44 | this.username = params.credentials.get._1
45 | this.password = params.credentials.get._2
46 | }))
47 | }
48 | connection = connectionPool.get.getConnection()
49 | connection
50 | } catch {
51 | case e: Exception => throw new RuntimeException(s"Connection Exception Occurred: ${e.getMessage}")
52 | }
53 | }
54 |
55 | /**
56 | * Check if a class already exists in OrientDB
57 | * @param classname class name in OrientDB
58 | * @return true if the class exists, false otherwise
59 | */
60 | def doesClassExists(classname: String): Boolean = {
61 | val schema = connection.getMetadata.getSchema
62 | schema.existsClass(classname)
63 | }
64 |
65 | /**
66 | * Create API
67 | * @param cluster cluster name in OrientDB
68 | * @param classname class name in OrientDB
69 | * @param document document to be created
70 | */
71 | def create(cluster: String, classname: String, document: ODocument): Boolean = {
72 | if (connection.save(document, cluster) != null)
73 | return true
74 | false
75 | }
76 |
77 | /**
78 | * Read API
79 | * @param clusters cluster names in OrientDB
80 | * @param classname class name in OrientDB
81 | * @param filters WHERE clause used to restrict the records retrieved
82 | * @return the matching documents
83 | */
84 | def read(clusters: List[String], classname: String, requiredColumns: Array[String],
85 | filters: String, query: String = null): List[ODocument] = {
86 | var documents: java.util.List[ODocument] = new util.ArrayList[ODocument]()
87 | val columns = requiredColumns.mkString(", ")
88 |
89 | if (query == null) {
90 | val oResultSet = connection.query(s"select $columns from $classname $filters")
91 | oResultSet.map(_.toElement).foreach {
92 | case oDocument: ODocument =>
93 | documents.add(oDocument)
94 | case _ =>
95 | throw new RuntimeException("Result is not of type ODocument")
96 | }
97 | } else {
98 | var queryStr = ""
99 |
100 | if (filters != "") {
101 | if (query.contains("WHERE ")) {
102 | val parts = query.split("WHERE ")
103 |
104 | if ( parts.size > 1) {
105 | val firstpart = parts(0)
106 | val secondpart = parts(1)
107 |
108 | queryStr = s"$firstpart $filters and $secondpart"
109 | } else {
110 | queryStr = s"$query $filters"
111 | }
112 | } else if (query.contains("where ")) {
113 | val parts = query.split("where ")
114 |
115 | if ( parts.size > 1) {
116 | val firstpart = parts(0)
117 | val secondpart = parts(1)
118 |
119 | queryStr = s"$firstpart $filters and $secondpart"
120 | } else {
121 | queryStr = s"$query $filters"
122 | }
123 | } else {
124 | queryStr = s"$query $filters"
125 | }
126 | } else queryStr = query
127 | val oResultSet = connection.query(queryStr)
128 | oResultSet.map(_.toElement).foreach {
129 | case oDocument: ODocument =>
130 | documents.add(oDocument)
131 | case _ =>
132 | throw new RuntimeException("Result is not of type ODocument")
133 | }
134 | }
135 |
136 | documents.toList
137 | }
138 |
139 | /**
140 | * Bulk Create API
141 | * @param cluster cluster name in OrientDB
142 | * @param classname class name in OrientDB
143 | * @param documents List of documents to be created
144 | */
145 | def bulkCreate(cluster: String, classname: String, documents: List[ODocument]): Boolean = {
146 | try {
147 | openTransaction()
148 | documents.foreach(document => {
149 | if (!create(cluster, classname, document)) {
150 | rollbackTransaction()
151 | return false
152 | }
153 |
154 | })
155 | commitTransaction()
156 | true
157 | } catch {
158 | case e: Exception =>
159 | rollbackTransaction()
160 | throw new RuntimeException("An exception was thrown: " + e.getMessage)
161 | }
162 | }
163 |
164 | /**
165 | * Update API
166 | * @param cluster cluster name in OrientDB
167 | * @param classname class name in OrientDB
168 | * @param record new record to be updated in the format (fieldName, (fieldValue, fieldType))
169 | * @param filter filter which filters records to be updated in the format (fieldName, (filterOperator, fieldValue))
170 | */
171 | def update(cluster: String, classname: String, record: Map[String, Tuple2[String, OType]],
172 | filter: Map[String, Tuple2[String, String]]): Boolean = {
173 | var filterStr = ""
174 |
175 | val filterLength = filter.size
176 | var count = 1
177 | filter.foreach(p => {
178 | if (count == filterLength)
179 | // TODO: handle non-string types; only string values are supported here
180 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "'"
181 | else
182 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "' and "
183 | count += 1
184 | })
185 |
186 | val documentsToBeUpdated: java.util.List[ODocument] = new util.ArrayList[ODocument]()
187 | val oResultSet = connection.query("select * from " + classname + " where " + filterStr)
188 |
189 | oResultSet.map(_.toElement).foreach {
190 | case oDocument: ODocument =>
191 | documentsToBeUpdated.add(oDocument)
192 | case _ =>
193 | throw new RuntimeException("Result is not of type ODocument")
194 | }
195 |
196 | try {
197 | openTransaction()
198 | documentsToBeUpdated.foreach(document => {
199 | record.foreach(field => {
200 | document.field(field._1, field._2._1, field._2._2)
201 | })
202 | if (!create(cluster, classname, document)) {
203 | rollbackTransaction()
204 | return false
205 | }
206 | })
207 | commitTransaction()
208 | true
209 | } catch {
210 | case e: Exception =>
211 | rollbackTransaction()
212 | throw new RuntimeException("An exception was thrown: " + e.getMessage)
213 | }
214 | }
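// A minimal call sketch for the Update API (assuming "wrapper" is an OrientDBDocumentWrapper with an
// open connection): update every "Person" document whose name equals 'Ann', setting "age" to "31" as an INTEGER.
//
//   wrapper.update("person", "Person",
//     record = Map("age" -> ("31", OType.INTEGER)),
//     filter = Map("name" -> ("=", "Ann")))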
215 |
216 | /**
217 | * Delete API
218 | * @param cluster cluster name in OrientDB
219 | * @param classname class name in OrientDB
220 | * @param filter filter which filters records to be deleted in the format (fieldName, (filterOperator, fieldValue))
221 | */
222 | def delete(cluster: String, classname: String, filter: Map[String, Tuple2[String, String]]): Boolean = {
223 | try {
224 | var filterStr = ""
225 |
226 | if (filter != null) {
227 | val filterLength = filter.size
228 | var count = 1
229 | filter.foreach(p => {
230 | if (count == filterLength)
231 | // TODO: handle non-string types; only string values are supported here
232 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "'"
233 | else
234 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "' and "
235 | count += 1
236 | })
237 | }
238 |
239 | val documentsToBeDeleted: java.util.List[ODocument] = new util.ArrayList[ODocument]()
240 | if (filterStr != "") {
241 | val oResultSet = connection.query("select * from " + classname + " where " + filterStr)
242 | oResultSet.map(_.toElement).foreach {
243 | case oDocument: ODocument =>
244 | documentsToBeDeleted.add(oDocument)
245 | case _ =>
246 | throw new RuntimeException("Result is not of type ODocument")
247 | }
248 | } else {
249 | val oResultSet = connection.query("select * from " + classname)
250 | oResultSet.map(_.toElement).foreach {
251 | case oDocument: ODocument =>
252 | documentsToBeDeleted.add(oDocument)
253 | case _ =>
254 | throw new RuntimeException("Result is not of type ODocument")
255 | }
256 | }
257 |
258 | openTransaction()
259 | documentsToBeDeleted.foreach(document => {
260 | if (connection.delete(document) == null) {
261 | rollbackTransaction()
262 | return false
263 | }
264 | })
265 | commitTransaction()
266 | true
267 | } catch {
268 | case e: Exception =>
269 | rollbackTransaction()
270 | throw new RuntimeException("An exception was thrown: " + e.getMessage)
271 | }
272 | }
273 |
274 | /**
275 | * Resolve OrientDB class metadata
276 | * @param cluster cluster name in OrientDB
277 | * @param classname class name in OrientDB
278 | * @return
279 | */
280 | def resolveTable(cluster: String, classname: String): StructType = {
281 | val oClass = connection.getMetadata.getSchema.getClass(classname)
282 |
283 | val properties = oClass.properties()
284 | val ncols = properties.size()
285 | val fields = new Array[StructField](ncols)
286 | val iterator = properties.iterator()
287 | var i = 0
288 | while (iterator.hasNext) {
289 | val property = iterator.next()
290 | val columnName = property.getName
291 | // OrientDB classes define no primary keys, hence every field is treated as nullable
292 | val nullable = true
293 | val columnType = getCatalystType(property.getType)
294 | fields(i) = StructField(columnName, columnType, nullable)
295 | i = i + 1
296 | }
297 | new StructType(fields)
298 | }
299 |
300 | /**
301 | * execute generic query on OrientDB
302 | */
303 | def genericQuery(query: String): List[ODocument] = {
304 | val documents: java.util.List[ODocument] = connection.query(new OSQLSynchQuery[ODocument](query))
305 | documents.toList
306 | }
307 |
308 | private def getCatalystType(oType: OType): DataType = {
309 | val dataType = oType match {
310 | case OType.BOOLEAN => BooleanType
311 | case OType.INTEGER => IntegerType
312 | case OType.SHORT => ShortType
313 | case OType.LONG => LongType
314 | case OType.FLOAT => FloatType
315 | case OType.DOUBLE => DoubleType
316 | case OType.DATETIME => TimestampType
317 | case OType.STRING => StringType
318 | case OType.BINARY => BinaryType
319 | case OType.BYTE => ByteType
320 | case OType.DATE => DateType
321 | case OType.DECIMAL => DecimalType(38, 18)
322 | case OType.EMBEDDEDLIST => new EmbeddedListType
323 | case OType.EMBEDDEDSET => new EmbeddedSetType
324 | case OType.EMBEDDEDMAP => new EmbeddedMapType
325 | case OType.LINKLIST => new LinkListType
326 | case OType.LINKSET => new LinkSetType
327 | case OType.LINKMAP => new LinkMapType
328 | case OType.LINKBAG => new LinkBagType
329 | case OType.LINK => new LinkType
330 | case OType.ANY => null
331 | }
332 | dataType
333 | }
334 | }
335 |
336 | private[orientdb] object DefaultOrientDBDocumentWrapper extends OrientDBDocumentWrapper
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/OrientDBRelation.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.record.impl.ODocument
4 | import Parameters.MergedParameters
5 | import org.apache.spark.rdd.RDD
6 | import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
7 | import org.apache.spark.sql.types.StructType
8 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
9 | import org.slf4j.LoggerFactory
10 |
11 | private[orientdb] case class OrientDBRelation(
12 | orientDBWrapper: OrientDBDocumentWrapper,
13 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory,
14 | params: MergedParameters,
15 | userSchema: Option[StructType]
16 | ) (@transient val sqlContext: SQLContext)
17 | extends BaseRelation
18 | with PrunedFilteredScan
19 | with InsertableRelation {
20 | private val log = LoggerFactory.getLogger(getClass)
21 |
22 | // TODO: add parameter assertions here if needed
23 |
24 | private val tableNameOrSubQuery = params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
25 |
26 | override lazy val schema: StructType = {
27 | userSchema.getOrElse{
28 | val tableName = params.table.map(_.toString).get
29 | val conn = orientDBWrapper.getConnection(params)
30 |
31 | try {
32 | orientDBWrapper.resolveTable(null, tableName)
33 | } finally {
34 | conn.close()
35 | }
36 | }
37 | }
38 |
39 | override def toString: String = s"""OrientDBRelation($tableNameOrSubQuery)"""
40 |
41 | override def insert(data: DataFrame, overwrite: Boolean): Unit = {
42 | val saveMode = if (overwrite) {
43 | SaveMode.Overwrite
44 | } else {
45 | SaveMode.Append
46 | }
47 | val writer = new OrientDBWriter(orientDBWrapper, orientDBClientFactory)
48 | writer.saveToOrientDB(data, saveMode, params)
49 | }
50 |
51 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
52 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined)
53 | }
54 |
55 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
56 | if (requiredColumns.isEmpty) {
57 | val whereClause = FilterPushdown.buildWhereClause(schema, filters)
58 | var countQuery = s"select count(*) from $tableNameOrSubQuery $whereClause"
59 |
60 | if (params.query.isDefined) {
61 | countQuery = tableNameOrSubQuery.drop(1).dropRight(1)
62 | }
63 |
64 | log.info(s"count query: $countQuery")
65 | val connection = orientDBWrapper.getConnection(params)
66 |
67 | try {
68 | // todo use future
69 | val results = orientDBWrapper.genericQuery(countQuery)
70 | if (params.query.isEmpty && results.nonEmpty) {
71 | val numRows: Long = results.head.field("count")
72 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
73 | val emptyRow = Row.empty
74 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
75 | } else if (params.query.isDefined) {
76 | val numRows: Long = results.length
77 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
78 | val emptyRow = Row.empty
79 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
80 | } else {
81 | throw new IllegalStateException("Cannot read count from OrientDB")
82 | }
83 | } finally {
84 | connection.close()
85 | }
86 | } else {
87 | var classname: String = null
88 | var clusters: List[String] = null
89 | if (params.query.isEmpty) {
90 | classname = params.className match {
91 | case Some(className) => className
92 | case None =>
93 | throw new IllegalArgumentException("To read from OrientDB you must specify an OrientDB class " +
94 | "name with the 'class' parameter")
95 | }
96 | clusters = params.clusterNames match {
97 | case Some(clusterName) => clusterName
98 | case None =>
99 | val connection = orientDBWrapper.getConnection(params)
100 | val schema = connection.getMetadata.getSchema
101 | val currClass = schema.getClass(classname)
102 | List(connection.getClusterNameById(currClass.getDefaultClusterId))
103 | }
104 | }
105 |
106 | val filterStr = FilterPushdown.buildWhereClause(schema, filters)
107 | val connection = orientDBWrapper.getConnection(params)
108 | var oDocuments: List[ODocument] = List()
109 | try {
110 | // todo use Future
111 | if (params.query.isEmpty) {
112 | oDocuments = orientDBWrapper.read(clusters, classname, requiredColumns, filterStr, null)
113 | } else {
114 | oDocuments = orientDBWrapper
115 | .read(null, null, requiredColumns, filterStr, params.query.get)
116 | }
117 | } finally {
118 | connection.close()
119 | }
120 |
121 | if (params.query.isEmpty) {
122 | val prunedSchema = pruneSchema(schema, requiredColumns)
123 | sqlContext.sparkContext.makeRDD(
124 | oDocuments.map(oDocument => Conversions.convertODocumentsToRows(oDocument, prunedSchema))
125 | )
126 | } else {
127 | assert(oDocuments.nonEmpty)
128 | val prunedSchema = pruneSchema(schema, chooseRecordForSchema(oDocuments).fieldNames())
129 | sqlContext.sparkContext.makeRDD(
130 | oDocuments.map(oDocument => Conversions.convertODocumentsToRows(oDocument, prunedSchema))
131 | )
132 | }
133 | }
134 | }
135 |
136 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
137 | new StructType(schema.fields.filter(p => columns.contains(p.name)))
138 | }
139 |
140 | private def chooseRecordForSchema(oDocuments: List[ODocument]): ODocument = {
141 | var maxLen = -1
142 | var idx: ODocument = null
143 | oDocuments.foreach(oDocument =>
144 | if (maxLen < oDocument.fieldNames().length) {
145 | idx = oDocument
146 | maxLen = oDocument.fieldNames().length
147 | })
148 | idx
149 | }
150 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/OrientDBWriter.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.db.document.{ODatabaseDocument, ODatabaseDocumentPool, ODatabaseDocumentTxPooled}
4 | import Parameters.MergedParameters
5 | import org.apache.spark.sql.{DataFrame, Row, SaveMode}
6 | import org.slf4j.LoggerFactory
7 |
8 | private[orientdb] class OrientDBWriter(orientDBWrapper: OrientDBDocumentWrapper,
9 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory)
10 | extends Serializable {
11 | private val log = LoggerFactory.getLogger(getClass)
12 |
13 | private[orientdb] def createOrientDBClass(data: DataFrame, params: MergedParameters): Unit = {
14 | val dfSchema = data.schema
15 | val classname = params.className match {
16 | case Some(className) => className
17 | case None => throw new IllegalArgumentException("For save operations you must specify an OrientDB class " +
18 | "name with the 'class' parameter")
19 | }
20 | var clusters = params.clusterNames match {
21 | case Some(clusterNames) => clusterNames
22 | case None => null
23 | }
24 |
25 | val connector = orientDBWrapper.getConnection(params)
26 | val schema = connector.getMetadata.getSchema
27 | val createdClass = schema.createClass(classname)
28 |
29 | dfSchema.foreach(field => {
30 | if (params.linkedType.nonEmpty && params.linkedType.get.exists(linkType => linkType._1.equals(field.name))) {
31 | createdClass.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType),
32 | connector.getMetadata.getSchema.getClass(params.linkedType.get(field.name).split("-").last))
33 | } else {
34 | createdClass.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType))
35 | }
36 | })
37 |
38 | if (clusters != null) {
39 | clusters.foreach(cluster => createdClass.addCluster(cluster))
40 | } else {
41 | clusters = List(connector.getClusterNameById(createdClass.getDefaultClusterId))
42 | }
43 | }
44 |
45 | private[orientdb] def dropOrientDBClass(params: MergedParameters): Unit = {
46 | val connection = orientDBWrapper.getConnection(params)
47 |
48 | // drop the class if it already exists
49 | val classname = params.className
50 | if (classname.isEmpty) {
51 | throw new IllegalArgumentException("For save operations you must specify an OrientDB class " +
52 | "name with the 'class' parameter")
53 | }
54 | var cluster = params.clusterNames
55 |
56 | // Todo use Future
57 | if (connection.getMetadata.getSchema.existsClass(classname.get)) {
58 | orientDBWrapper.delete(null, classname.get, null)
59 |
60 | val schema = connection.getMetadata.getSchema
61 | schema.dropClass(classname.get)
62 | }
63 | }
64 |
65 | private def doOrientDBLoad(connection: ODatabaseDocument,
66 | data: DataFrame,
67 | params: MergedParameters): Unit = {
68 | // create class if not exists
69 | val classname = params.className
70 | if (classname.isEmpty) {
71 | throw new IllegalArgumentException("For save operations you must specify an OrientDB class " +
72 | "name with the 'class' parameter")
73 | }
74 |
75 | val schema = connection.getMetadata.getSchema
76 |
77 | if (!schema.existsClass(classname.get)) {
78 | createOrientDBClass(data, params)
79 | }
80 |
81 | var clusters = params.clusterNames
82 | if (clusters.isEmpty) {
83 | val schema = connection.getMetadata.getSchema
84 | val currClass = schema.getClass(classname.get)
85 | clusters = Some(List(connection.getClusterNameById(currClass.getDefaultClusterId)))
86 | }
87 |
88 | // Todo use future
89 | // load data into Orient DB
90 | try {
91 | data.foreachPartition((rows: Iterator[Row]) => {
92 | val ownerPool = new ODatabaseDocumentPool(params.dbUrl.get,
93 | params.credentials.get._1, params.credentials.get._2)
94 | val connection = new ODatabaseDocumentTxPooled(ownerPool, params.dbUrl.get,
95 | params.credentials.get._1, params.credentials.get._2)
96 | val rowsList = rows.toList
97 | val rowsPerCluster = rowsList.length % clusters.get.length match {
98 | case 0 => rowsList.length / clusters.get.length
99 | case _ => (rowsList.length / clusters.get.length) + 1
100 | }
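// e.g. 10 rows spread over 3 clusters gives rowsPerCluster = 4, so the clusters receive 4, 4 and 2 rows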
101 |
102 | var countPerCluster = 0
103 | var clusterIdx = 0
104 | rowsList.foreach { row =>
105 | connection.save(Conversions.convertRowsToODocuments(row), clusters.get(clusterIdx))
106 | countPerCluster = countPerCluster + 1
107 |
108 | if (countPerCluster % rowsPerCluster == 0) {
109 | countPerCluster = 0
110 | clusterIdx = clusterIdx + 1
111 | }
112 | }
113 |
114 | connection.close()
115 | })
116 | } catch {
117 | case e: Exception =>
118 | throw new RuntimeException("An exception was thrown: " + e.getMessage)
119 | }
120 | }
121 |
122 | def saveToOrientDB(data: DataFrame, saveMode: SaveMode, params: MergedParameters): Unit = {
123 | val connection = orientDBWrapper.getConnection(params)
124 | try {
125 | if (saveMode == SaveMode.Overwrite) {
126 | dropOrientDBClass(params)
127 | }
128 | doOrientDBLoad(connection, data, params)
129 | } finally {
130 | connection.close()
131 | }
132 | }
133 | }
134 |
135 | object DefaultOrientDBWriter extends OrientDBWriter(
136 | DefaultOrientDBDocumentWrapper,
137 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials))
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/Parameters.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | private[orientdb] object Parameters {
4 | val DEFAULT_PARAMETERS: Map[String, String] = Map(
5 | "overwrite" -> "false"
6 | )
7 |
8 | def mergeParameters(userParameters: Map[String, String]): MergedParameters = {
9 | if (!userParameters.contains("dburl")) {
10 | throw new IllegalArgumentException("An OrientDB URL must be provided with the 'dburl' parameter")
11 | }
12 |
13 | if (!userParameters.contains("class") && !userParameters.contains("query")) {
14 | throw new IllegalArgumentException("You must specify an OrientDB class name with the 'class' " +
15 | "parameter or a query with the 'query' parameter. When using 'query' you must either define your own" +
16 | " schema or specify the OrientDB 'class' name.")
17 | }
18 |
19 | if (!userParameters.contains("user") || !userParameters.contains("password")) {
20 | throw new IllegalArgumentException("An OrientDB username & password must be provided" +
21 | " with the 'user' & 'password' parameters respectively")
22 | }
23 | MergedParameters(DEFAULT_PARAMETERS ++ userParameters)
24 | }
25 |
26 | case class MergedParameters(parameters: Map[String, String]) {
27 |
28 | /**
29 | * OrientDB class from which data is loaded and to which data is written.
30 | */
31 | def table: Option[TableName] = parameters.get("class").map(_.trim).flatMap(dbTable => {
32 | /**
33 | * this case is going to be handled in query variable.
34 | */
35 | if (dbTable.startsWith("(") && dbTable.endsWith(")"))
36 | None
37 | else
38 | Some(TableName(dbTable))
39 | })
40 |
41 | /**
42 | * The OrientDB query to be used when loading data
43 | */
44 | def query: Option[String] = parameters.get("query").orElse({
45 | parameters.get("class")
46 | .map(_.trim)
47 | .filter(t => t.startsWith("(") && t.endsWith(")"))
48 | .map(t => t.drop(1).dropRight(1))
49 | })
50 |
51 | /**
52 | * Username & Password for authenticating with OrientDB.
53 | */
54 | def credentials: Option[(String, String)] = {
55 | for {
56 | username <- parameters.get("user")
57 | password <- parameters.get("password")
58 | } yield (username, password)
59 | }
60 |
61 | /**
62 | * A URL in the format
63 | *
64 | * remote:<hostname>:<port>/<database>
65 | */
66 | def dbUrl: Option[String] = parameters.get("dburl")
67 |
68 | /**
69 | * class name in Orient DB
70 | */
71 | def className: Option[String] = parameters.get("class")
72 |
73 | /**
74 | * cluster name in Orient DB
75 | */
76 | def clusterNames: Option[List[String]] =
77 | parameters.get("clusters").map(_.split(",").toList)
78 |
79 | def linkedType: Option[Map[String, String]] =
80 | if (parameters.exists(paramPair => paramPair._2.contains("linkedType")))
81 | Some(parameters.filter(paramPair => paramPair._2.contains("linkedType")))
82 | else None
83 | }
84 | }
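// A minimal parameter map sketch (the database URL, credentials and class name below are placeholders):
//
//   val params = Parameters.mergeParameters(Map(
//     "dburl" -> "remote:127.0.0.1/test",
//     "user" -> "root", "password" -> "root",
//     "class" -> "Person"))
//   params.table        // Some(TableName("Person"))
//   params.credentials  // Some(("root", "root"))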
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/documents/TableName.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | private[orientdb] case class TableName(var unescapedTableName: String) {
4 | /**
5 | * drop the quotes from the two ends
6 | */
7 | if (unescapedTableName.startsWith("\"") && unescapedTableName.endsWith("\""))
8 | unescapedTableName = unescapedTableName.drop(1).dropRight(1)
9 |
10 | private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"'
11 | def escapedTableName: String = unescapedTableName
12 |
13 | override def toString: String = s"$escapedTableName"
14 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/graphs/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
4 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
5 | import org.apache.spark.sql.types.StructType
6 | import org.slf4j.LoggerFactory
7 |
8 | class DefaultSource( orientDBGraphVertexWrapper: OrientDBGraphVertexWrapper,
9 | orientDBGraphEdgeWrapper: OrientDBGraphEdgeWrapper,
10 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory)
11 | extends RelationProvider
12 | with SchemaRelationProvider
13 | with CreatableRelationProvider {
14 | private val log = LoggerFactory.getLogger(getClass)
15 |
16 | def this() = this(DefaultOrientDBGraphVertexWrapper, DefaultOrientDBGraphEdgeWrapper,
17 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials))
18 |
19 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
20 | val params = Parameters.mergeParameters(parameters)
21 |
22 | if (params.query.isDefined && (params.vertexType.isEmpty && params.edgeType.isEmpty)) {
23 | throw new IllegalArgumentException("Along with the 'query' parameter you must specify either the 'vertextype'" +
24 | " parameter or the 'edgetype' parameter, or a user-defined schema")
25 | }
26 |
27 | if (params.vertexType.isDefined && params.edgeType.isEmpty) {
28 | OrientDBVertexRelation(orientDBGraphVertexWrapper, orientDBClientFactory, params, None)(sqlContext)
29 | } else {
30 | OrientDBEdgeRelation(orientDBGraphEdgeWrapper, orientDBClientFactory, params, None)(sqlContext)
31 | }
32 | }
33 |
34 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
35 | schema: StructType): BaseRelation = {
36 | val params = Parameters.mergeParameters(parameters)
37 |
38 | if (params.vertexType.isDefined && params.edgeType.isEmpty) {
39 | OrientDBVertexRelation(orientDBGraphVertexWrapper, orientDBClientFactory, params, Some(schema))(sqlContext)
40 | } else {
41 | OrientDBEdgeRelation(orientDBGraphEdgeWrapper, orientDBClientFactory, params, Some(schema))(sqlContext)
42 | }
43 | }
44 |
45 | override def createRelation(sqlContext: SQLContext, mode: SaveMode,
46 | parameters: Map[String, String], data: DataFrame): BaseRelation = {
47 | val params = Parameters.mergeParameters(parameters)
48 |
49 | val vertexType = params.vertexType
50 | val edgeType = params.edgeType
51 |
52 | if (vertexType.isEmpty && edgeType.isEmpty) {
53 | throw new IllegalArgumentException("For save operations you must specify an OrientDB Graph Vertex" +
54 | " or Edge type with the 'vertextype' & 'edgetype' parameter respectively")
55 | }
56 |
57 |
58 | def tableExists: Boolean = {
59 | if (vertexType.isDefined && edgeType.isEmpty) {
60 | try {
61 | orientDBGraphVertexWrapper.doesVertexTypeExists(vertexType.get)
62 | } finally {
63 | orientDBGraphVertexWrapper.close()
64 | }
65 | }
66 | else {
67 | try {
68 | orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType.get)
69 | } finally {
70 | orientDBGraphEdgeWrapper.close()
71 | }
72 | }
73 | }
74 |
75 | if (vertexType.isDefined && edgeType.isEmpty) {
76 | val (doSave, dropExisting) = mode match {
77 | case SaveMode.Append => (true, false)
78 | case SaveMode.Overwrite => (true, true)
79 | case SaveMode.ErrorIfExists =>
80 | if (tableExists) {
81 | sys.error(s"Vertex type ${vertexType.get} already exists! (SaveMode is set to ErrorIfExists)")
82 | } else {
83 | (true, false)
84 | }
85 | case SaveMode.Ignore =>
86 | if (tableExists) {
87 | log.info(s"Vertex type ${vertexType.get} already exists. Ignoring save request.")
88 | (false, false)
89 | } else {
90 | (true, false)
91 | }
92 | }
93 |
94 | if (doSave) {
95 | val updatedParams = parameters.updated("overwrite", dropExisting.toString)
96 | new OrientDBVertexWriter(orientDBGraphVertexWrapper, orientDBClientFactory)
97 | .saveToOrientDB(data, mode, Parameters.mergeParameters(updatedParams))
98 | }
99 | createRelation(sqlContext, parameters)
100 | } else {
101 | if (!parameters.contains("vertextype") && parameters.contains("edgetype")) {
102 | throw new IllegalArgumentException("You must specify the Orient DB Vertex type in the 'vertextype'" +
103 | " parameter along with Orient DB Edge type in the 'edgetype' parameter")
104 | }
105 |
106 | try {
107 | val connection = orientDBGraphEdgeWrapper.getConnection(params)
108 | orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType.get)
109 | } finally {
110 | orientDBGraphEdgeWrapper.close()
111 | }
112 |
113 | val (doSave, dropExisting) = mode match {
114 | case SaveMode.Append => (true, false)
115 | case SaveMode.Overwrite => (true, true)
116 | case SaveMode.ErrorIfExists =>
117 | if (tableExists) {
118 | sys.error(s"Edge type ${edgeType.get} already exists! (SaveMode is set to ErrorIfExists)")
119 | } else {
120 | (true, false)
121 | }
122 | case SaveMode.Ignore =>
123 | if (tableExists) {
124 | log.info(s"Edge type ${edgeType.get} already exists. Ignoring save request.")
125 | (false, false)
126 | } else {
127 | (true, false)
128 | }
129 | }
130 |
131 | if (doSave) {
132 | val updatedParams = parameters.updated("overwrite", dropExisting.toString)
133 | new OrientDBEdgeWriter(orientDBGraphEdgeWrapper, orientDBClientFactory)
134 | .saveToOrientDB(data, mode, Parameters.mergeParameters(updatedParams))
135 | }
136 | createRelation(sqlContext, parameters)
137 | }
138 | }
139 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/graphs/OrientDBClientFactory.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import com.tinkerpop.blueprints.impls.orient.{OrientGraphFactory, OrientGraphNoTx}
4 |
5 | class OrientDBClientFactory(orientDBCredentials: OrientDBCredentials) extends Serializable {
6 | private val db: OrientGraphFactory =
7 | new OrientGraphFactory(orientDBCredentials.dbUrl, orientDBCredentials.username,
8 | orientDBCredentials.password)
9 | private var connection: OrientGraphNoTx = _
10 |
11 | def getConnection(): OrientGraphNoTx = {
12 | connection = db.getNoTx
13 | connection
14 | }
15 |
16 | def closeConnection(): Unit = {
17 | db.close()
18 | }
19 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/graphs/OrientDBCredentials.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | trait OrientDBCredentials extends Serializable {
4 | var dbUrl: String = null
5 | var username: String = null
6 | var password: String = null
7 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/graphs/OrientDBRelation.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import com.tinkerpop.blueprints.{Edge, Vertex}
4 | import org.apache.spark.orientdb.documents.{Conversions, FilterPushdown}
5 | import org.apache.spark.orientdb.graphs.Parameters.MergedParameters
6 | import org.apache.spark.rdd.RDD
7 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
8 | import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
9 | import org.apache.spark.sql.types.StructType
10 | import org.slf4j.LoggerFactory
11 |
12 | private[orientdb] case class OrientDBVertexRelation(
13 | orientDBVertexWrapper: OrientDBGraphVertexWrapper,
14 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory,
15 | params: MergedParameters,
16 | userSchema: Option[StructType]
17 | ) (@transient val sqlContext: SQLContext)
18 | extends BaseRelation
19 | with PrunedFilteredScan
20 | with InsertableRelation {
21 | private val log = LoggerFactory.getLogger(getClass)
22 |
23 | private val tableNameOrSubQuery = params.query.map(q => s"($q)").orElse(params.vertexType.map(_.toString)).get
24 |
25 | override lazy val schema: StructType = {
26 | userSchema.getOrElse{
27 | val tableName = params.vertexType.map(_.toString).get
28 | val conn = orientDBVertexWrapper.getConnection(params)
29 |
30 | try {
31 | orientDBVertexWrapper.resolveTable(tableName)
32 | } finally {
33 | orientDBVertexWrapper.close()
34 | }
35 | }
36 | }
37 |
38 | override def toString: String = s"""OrientDBRelation($tableNameOrSubQuery)"""
39 |
40 | override def insert(data: DataFrame, overwrite: Boolean): Unit = {
41 | val saveMode = if (overwrite) {
42 | SaveMode.Overwrite
43 | } else {
44 | SaveMode.Append
45 | }
46 | val writer = new OrientDBVertexWriter(orientDBVertexWrapper,
47 | orientDBClientFactory)
48 | writer.saveToOrientDB(data, saveMode, params)
49 | }
50 |
51 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
52 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined)
53 | }
54 |
55 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
56 | if (requiredColumns.isEmpty) {
57 | val whereClause = FilterPushdown.buildWhereClause(schema, filters)
58 | var countQuery = s"select count(*) from $tableNameOrSubQuery $whereClause"
59 |
60 | if (params.query.isDefined) {
61 | countQuery = tableNameOrSubQuery.drop(1).dropRight(1)
62 | }
63 |
64 | log.info(s"count query: $countQuery")
65 | orientDBVertexWrapper.getConnection(params)
66 |
67 | try {
68 | val results = orientDBVertexWrapper.genericQuery(countQuery)
69 | if (params.query.isEmpty && results.nonEmpty) {
70 | val numRows: Long = results.head.getProperty[Long]("count")
71 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
72 | val emptyRow = Row.empty
73 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
74 | } else if (params.query.isDefined) {
75 | val numRows: Long = results.length
76 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
77 | val emptyRow = Row.empty
78 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
79 | } else {
80 | throw new IllegalArgumentException("Cannot read count for OrientDB graph vertices")
81 | }
82 | } finally {
83 | orientDBVertexWrapper.close()
84 | }
85 | } else {
86 | var vertexTypeName: String = null
87 | if (params.query.isEmpty) {
88 | vertexTypeName = params.vertexType match {
89 | case Some(vertexType) => vertexType
90 | case None =>
91 | throw new IllegalArgumentException("To read from OrientDB you must specify an OrientDB Graph" +
92 | " Vertex type name with the 'vertextype' parameter")
93 | }
94 | }
95 |
96 | val filterStr = FilterPushdown.buildWhereClause(schema, filters)
97 | var oVertices: List[Vertex] = List()
98 | try {
99 | orientDBVertexWrapper.getConnection(params)
100 | if (params.query.isEmpty) {
101 | oVertices = orientDBVertexWrapper.read(vertexTypeName, requiredColumns,
102 | filterStr, null)
103 | } else {
104 | oVertices = orientDBVertexWrapper.read(null, requiredColumns, filterStr,
105 | params.query.get)
106 | }
107 | } finally {
108 | orientDBVertexWrapper.close()
109 | }
110 |
111 | if (params.query.isEmpty) {
112 | val prunedSchema = pruneSchema(schema, requiredColumns)
113 | sqlContext.sparkContext.makeRDD(
114 | oVertices.map(vertex => Conversions.convertVerticesToRows(vertex, prunedSchema))
115 | )
116 | } else {
117 | assert(oVertices.nonEmpty)
118 | val propKeysArray = new Array[String](chooseRecordForSchema(oVertices).getPropertyKeys.size())
119 | val prunedSchema = pruneSchema(schema, chooseRecordForSchema(oVertices)
120 | .getPropertyKeys.toArray[String](propKeysArray))
121 | sqlContext.sparkContext.makeRDD(
122 | oVertices.map(vertex => Conversions.convertVerticesToRows(vertex, prunedSchema))
123 | )
124 | }
125 | }
126 | }
127 |
128 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
129 | new StructType(schema.fields.filter(p => columns.contains(p.name)))
130 | }
131 |
132 | private def chooseRecordForSchema(oVertices: List[Vertex]): Vertex = {
133 | var maxLen = -1
134 | var idx: Vertex = null
135 | oVertices.foreach(oVertex => {
136 | if (maxLen < oVertex.getPropertyKeys.size()) {
137 | idx = oVertex
138 | maxLen = oVertex.getPropertyKeys.size()
139 | }
140 | })
141 | idx
142 | }
143 | }
144 |
145 | private[orientdb] case class OrientDBEdgeRelation(
146 | orientDBEdgeWrapper: OrientDBGraphEdgeWrapper,
147 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory,
148 | params: MergedParameters,
149 | userSchema: Option[StructType]
150 | ) (@transient val sqlContext: SQLContext)
151 | extends BaseRelation
152 | with PrunedFilteredScan
153 | with InsertableRelation {
154 | private val log = LoggerFactory.getLogger(getClass)
155 |
156 | private val tableNameOrSubQuery = params.query.map(q => s"($q)").orElse(params.edgeType.map(_.toString)).get
157 |
158 | override lazy val schema: StructType = {
159 | userSchema.getOrElse{
160 | val tableName = params.edgeType.map(_.toString).get
161 | val conn = orientDBEdgeWrapper.getConnection(params)
162 |
163 | try {
164 | orientDBEdgeWrapper.resolveTable(tableName)
165 | } finally {
166 | orientDBEdgeWrapper.close()
167 | }
168 | }
169 | }
170 |
171 | override def toString: String = s"""OrientDBRelation($tableNameOrSubQuery)"""
172 |
173 | override def insert(data: DataFrame, overwrite: Boolean): Unit = {
174 | val saveMode = if (overwrite) {
175 | SaveMode.Overwrite
176 | } else {
177 | SaveMode.Append
178 | }
179 | val writer = new OrientDBEdgeWriter(orientDBEdgeWrapper, orientDBClientFactory)
180 | writer.saveToOrientDB(data, saveMode, params)
181 | }
182 |
183 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
184 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined)
185 | }
186 |
187 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
188 | if (requiredColumns.isEmpty) {
189 | val whereClause = FilterPushdown.buildWhereClause(schema, filters)
190 | var countQuery = s"select count(*) from $tableNameOrSubQuery $whereClause"
191 |
192 | if (params.query.isDefined) {
193 | countQuery = tableNameOrSubQuery.drop(1).dropRight(1)
194 | }
195 |
196 | log.info(s"count query: $countQuery")
197 | orientDBEdgeWrapper.getConnection(params)
198 |
199 | try {
200 | val results = orientDBEdgeWrapper.genericQuery(countQuery)
201 | if (params.query.isEmpty && results.nonEmpty) {
202 | val numRows: Long = results.head.getProperty[Long]("count")
203 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
204 | val emptyRow = Row.empty
205 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
206 | } else if (params.query.isDefined) {
207 | val numRows: Long = results.length
208 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
209 | val emptyRow = Row.empty
210 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
211 | } else {
212 | throw new IllegalArgumentException("Cannot read count for OrientDB graph edges")
213 | }
214 | } finally {
215 | orientDBEdgeWrapper.close()
216 | }
217 | } else {
218 | var edgeTypeName: String = null
219 | if (params.query.isEmpty) {
220 | edgeTypeName = params.edgeType match {
221 | case Some(edgeType) => edgeType
222 | case None => throw new IllegalArgumentException("To read from OrientDB you must specify an OrientDB Graph" +
223 | " Edge type name with the 'edgetype' parameter")
224 | }
225 | }
226 |
227 | val filterStr = FilterPushdown.buildWhereClause(schema, filters)
228 | var oEdges: List[Edge] = List()
229 |
230 | try {
231 | if (params.query.isEmpty) {
232 | oEdges = orientDBEdgeWrapper.read(edgeTypeName, requiredColumns, filterStr, null)
233 | } else {
234 | oEdges = orientDBEdgeWrapper.read(null, requiredColumns, filterStr, params.query.get)
235 | }
236 | } finally {
237 | orientDBEdgeWrapper.close()
238 | }
239 |
240 | if (params.query.isEmpty) {
241 | val prunedSchema = pruneSchema(schema, requiredColumns)
242 | sqlContext.sparkContext.makeRDD(
243 | oEdges.map(edge => Conversions.convertEdgesToRows(edge, prunedSchema))
244 | )
245 | } else {
246 | assert(oEdges.nonEmpty)
247 | val propKeysArray = new Array[String](chooseRecordForSchema(oEdges).getPropertyKeys.size())
248 | val prunedSchema = pruneSchema(schema,
249 | chooseRecordForSchema(oEdges).getPropertyKeys.toArray[String](propKeysArray))
250 | sqlContext.sparkContext.makeRDD(
251 | oEdges.map(edge => Conversions.convertEdgesToRows(edge, prunedSchema))
252 | )
253 | }
254 | }
255 | }
256 |
257 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
258 | new StructType(schema.fields.filter(p => columns.contains(p.name)))
259 | }
260 |
261 | private def chooseRecordForSchema(oEdges: List[Edge]): Edge = {
262 | var maxLen = -1
263 | var idx: Edge = null
264 | oEdges.foreach(oEdge => {
265 | if (maxLen < oEdge.getPropertyKeys.size()) {
266 | idx = oEdge
267 | maxLen = oEdge.getPropertyKeys.size()
268 | }
269 | })
270 | idx
271 | }
272 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/graphs/OrientDBWriter.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import java.util
4 |
5 | import com.orientechnologies.orient.core.metadata.schema.OClass
6 | import com.orientechnologies.orient.core.sql.OCommandSQL
7 | import com.tinkerpop.blueprints.Vertex
8 | import com.tinkerpop.blueprints.impls.orient.{OrientGraphFactory, OrientGraphNoTx}
9 | import org.apache.spark.orientdb.documents.Conversions
10 | import org.apache.spark.orientdb.graphs.Parameters.MergedParameters
11 | import org.apache.spark.sql.{DataFrame, Row, SaveMode}
12 | import org.slf4j.LoggerFactory
13 |
14 | import scala.collection.JavaConversions._
15 |
16 | private[orientdb] class OrientDBVertexWriter(orientDBWrapper: OrientDBGraphVertexWrapper,
17 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory)
18 | extends Serializable {
19 |
20 | private[orientdb] def createOrientDBVertex(data: DataFrame, params: MergedParameters): Unit = {
21 | val dfSchema = data.schema
22 | val vertexType = params.vertexType match {
23 | case Some(vertexTypeName) => vertexTypeName
24 | case None => throw new IllegalArgumentException("For save operations you must specify an OrientDB Vertex Type" +
25 | " with the 'vertextype' parameter")
26 | }
27 | var cluster = params.cluster match {
28 | case Some(clusterName) => clusterName
29 | case None => null
30 | }
31 |
32 | val connector = orientDBWrapper.getConnection(params)
33 | val createdVertexType = connector.createVertexType(vertexType)
34 |
35 | dfSchema.foreach(field => {
36 | if (params.linkedType.nonEmpty && params.linkedType.get.exists(linkType => linkType._1.equals(field.name))) {
37 | createdVertexType.createProperty(field.name,
38 | Conversions.sparkDTtoOrientDBDT(field.dataType),
39 | connector.getVertexType(params.linkedType.get(field.name).split("-").last))
40 | } else {
41 | createdVertexType.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType))
42 | }
43 |
44 | if (field.name == "id") {
45 | createdVertexType.createIndex(s"${vertexType}Idx", OClass.INDEX_TYPE.UNIQUE, field.name)
46 | }
47 | })
48 | }
49 |
50 | private[orientdb] def dropOrientDBVertex(params: MergedParameters): Unit = {
51 | val connection = orientDBWrapper.getConnection(params)
52 |
53 | val vertexType = params.vertexType
54 | if (vertexType.isEmpty) {
55 | throw new IllegalArgumentException("For save operations you must specify an OrientDB Vertex Type" +
56 | " with the 'vertextype' parameter")
57 | }
58 |
59 | if (connection.getVertexType(vertexType.get) != null) {
60 | orientDBWrapper.delete(vertexType.get, null)
61 | connection.dropVertexType(vertexType.get)
62 | }
63 | }
64 |
65 | private def doOrientDBVertexLoad(connection: OrientGraphNoTx,
66 | data: DataFrame,
67 | params: MergedParameters): Unit = {
68 | val vertexType = params.vertexType
69 | if (vertexType.isEmpty) {
70 | throw new IllegalArgumentException("For save operations you must specify an OrientDB Vertex Type" +
71 | " with the 'vertextype' parameter")
72 | }
73 |
74 | if (connection.getVertexType(vertexType.get) == null) {
75 | createOrientDBVertex(data, params)
76 | }
77 |
78 | try {
79 | data.foreachPartition((rows: Iterator[Row]) => {
80 | val graphFactory = new OrientGraphFactory(params.dbUrl.get,
81 | params.credentials.get._1,
82 | params.credentials.get._2)
83 | val connection = graphFactory.getNoTx
84 |
85 | while (rows.hasNext) {
86 | val row = rows.next()
87 |
88 | val fields = row.schema.fields
89 |
90 | val fieldNames = fields.map(_.name)
91 | if (!fieldNames.contains("id")) {
92 | throw new IllegalArgumentException("'id' is a mandatory parameter " +
93 | "for creating a vertex")
94 | }
95 | val key = row.getAs[Object](fieldNames.indexOf("id"))
96 | val properties = new util.HashMap[String, Object]()
97 | properties.put("id", key)
98 | connection.setStandardElementConstraints(false)
99 | val createdVertex = connection.addVertex(s"class:${params.vertexType.get}", properties)
100 |
101 | var count = 0
102 | while (count < fields.length) {
103 | val sparkType = fields(count).dataType
104 | val orientDBType = Conversions
105 | .sparkDTtoOrientDBDT(sparkType)
106 | createdVertex.setProperty(fields(count).name,
107 | Conversions.convertRowToGraph(row, count), orientDBType)
108 |
109 | count = count + 1
110 | }
111 | createdVertex.moveToClass(params.vertexType.get)
112 | }
113 | graphFactory.close()
114 | })
115 | } catch {
116 | case e: Exception =>
117 | throw new RuntimeException("An exception was thrown: " + e.getMessage)
118 | }
119 | }
120 |
121 | def saveToOrientDB(data: DataFrame, saveMode: SaveMode, params: MergedParameters): Unit = {
122 | val connection = orientDBWrapper.getConnection(params)
123 | try {
124 | if (saveMode == SaveMode.Overwrite) {
125 | dropOrientDBVertex(params)
126 | }
127 | doOrientDBVertexLoad(connection, data, params)
128 | } finally {
129 | orientDBWrapper.close()
130 | }
131 | }
132 | }
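// A minimal vertex-write sketch (URL, credentials and type name are placeholders; the DataFrame "df"
// is assumed to contain an "id" column, which the vertex writer requires):
//
//   df.write.format("org.apache.spark.orientdb.graphs")
//     .option("dburl", "remote:127.0.0.1/<database>")
//     .option("user", "root").option("password", "root")
//     .option("vertextype", "Person")
//     .mode(SaveMode.Overwrite)
//     .save()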
133 |
134 | object DefaultOrientDBVertexWriter extends OrientDBVertexWriter(
135 | DefaultOrientDBGraphVertexWrapper,
136 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials))
137 |
138 | private[orientdb] class OrientDBEdgeWriter(orientDBWrapper: OrientDBGraphEdgeWrapper,
139 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory)
140 | extends Serializable {
141 | private val log = LoggerFactory.getLogger(getClass)
142 |
143 | private[orientdb] def createOrientDBEdge(data: DataFrame, params: MergedParameters): Unit = {
144 | val dfSchema = data.schema
145 | val edgeType = params.edgeType match {
146 | case Some(edgeTypeName) => edgeTypeName
147 | case None => throw new IllegalArgumentException("For save operations you must specify an OrientDB Edge Type" +
148 | " with the 'edgetype' parameter")
149 | }
150 |
151 | val connector = orientDBWrapper.getConnection(params)
152 | val createdEdgeType = connector.createEdgeType(edgeType)
153 | if (!params.lightWeightEdge) {
154 | dfSchema.foreach(field => {
155 | if (params.linkedType.nonEmpty && params.linkedType.get.exists(linkType => linkType._1.equals(field.name))) {
156 | createdEdgeType.createProperty(field.name,
157 | Conversions.sparkDTtoOrientDBDT(field.dataType),
158 | connector.getEdgeType(params.linkedType.get(field.name).split("-").last))
159 | } else
160 | createdEdgeType.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType))
161 | })
162 | }
163 | }
164 |
165 | private[orientdb] def dropOrientDBEdge(params: MergedParameters): Unit = {
166 | val connection = orientDBWrapper.getConnection(params)
167 |
168 | val edgeType = params.edgeType
169 | if (edgeType.isEmpty) {
170 | throw new IllegalArgumentException("For save operations you must specify an OrientDB Edge Type" +
171 | " with the 'edgetype' parameter")
172 | }
173 |
174 | if (connection.getEdgeType(edgeType.get) != null) {
175 | orientDBWrapper.delete(edgeType.get, null)
176 | connection.dropEdgeType(edgeType.get)
177 | }
178 | }
179 |
180 | private def doOrientDBEdgeLoad(connection: OrientGraphNoTx,
181 | data: DataFrame,
182 | params: MergedParameters): Unit = {
183 | val edgeType = params.edgeType
184 | if (edgeType.isEmpty) {
185 | throw new IllegalArgumentException("For save operations you must specify an OrientDB Edge Type" +
186 | " with the 'edgetype' parameter")
187 | }
188 |
189 | if (connection.getEdgeType(edgeType.get) == null) {
190 | createOrientDBEdge(data, params)
191 | }
192 |
193 | val (inVertexType, outVertexType) = params.vertexType match {
194 | case Some(vertexTypeNames) =>
195 | val cols = vertexTypeNames.split(",")
196 | if (cols.length == 1) (cols(0), cols(0))
197 | else if (cols.length == 2) (cols(0), cols(1))
198 | else throw new IllegalArgumentException("More than 2 'vertextype' specified")
199 | case None =>
200 | throw new IllegalArgumentException("Saving edges also requires a vertex type specified by " +
201 | "the 'vertextype' parameter")
202 | }
203 |
204 | try {
205 | data.foreachPartition((rows: Iterator[Row]) => {
206 | val graphFactory = new OrientGraphFactory(params.dbUrl.get,
207 | params.credentials.get._1,
208 | params.credentials.get._2)
209 | val connection = graphFactory.getNoTx
210 |
211 | while (rows.hasNext) {
212 | val row = rows.next()
213 |
214 | val fields = row.schema.fields
215 |
216 | var inVertexName: String = null
217 | try {
218 | inVertexName = row.getAs[Object](fields.map(_.name).indexOf("src")).toString
219 | } catch {
220 | case e: Exception => throw new IllegalArgumentException("'src' is a mandatory parameter " +
221 | "for creating an edge")
222 | }
223 |
224 | var outVertexName: String = null
225 | try {
226 | outVertexName = row.getAs[Object](fields.map(_.name).indexOf("dst")).toString
227 | } catch {
228 | case e: Exception => throw new IllegalArgumentException("'dst' is a mandatory parameter " +
229 | "for creating an edge")
230 | }
231 |
232 | val inVertices: List[Vertex] = connection
233 | .command(new OCommandSQL(s"select * from $inVertexType where id = '$inVertexName'"))
234 | .execute()
235 | .asInstanceOf[java.lang.Iterable[Vertex]]
236 | .toList
237 |
238 | val inVertex: Option[Vertex] =
239 | if (inVertices.isEmpty && params.createVertexIfNotExist) {
240 | // log.info(s"in Vertex $inVertexName does not exist. Creating it...")
241 | val v = connection.addVertex(inVertexType, null)
242 | v.setProperty("id", inVertexName)
243 | Some(v)
244 | } else {
245 | inVertices.headOption
246 | }
247 |
248 | val outVertices: List[Vertex] = connection
249 | .command(new OCommandSQL(s"select * from $outVertexType where id = '$outVertexName'"))
250 | .execute()
251 | .asInstanceOf[java.lang.Iterable[Vertex]]
252 | .toList
253 |
254 | val outVertex: Option[Vertex] =
255 | if (outVertices.isEmpty && params.createVertexIfNotExist) {
256 | // log.info(s"out Vertex $outVertexName does not exist. Creating it...")
257 | val v = connection.addVertex(outVertexType, null)
258 | v.setProperty("id", outVertexName)
259 | Some(v)
260 | } else {
261 | outVertices.headOption
262 | }
263 |
264 | for {
265 | in <- inVertex
266 | out <- outVertex
267 | } {
268 | var id: String = null
269 | for (i <- fields.indices) {
270 | if (fields(i).name == "id") {
271 | id = Conversions.convertRowToGraph(row, i).toString
272 | }
273 | }
274 | val createdEdge = connection.addEdge(id, in, out, params.edgeType.get)
275 | for (i <- fields.indices) {
276 | val sparkType = fields(i).dataType
277 | val orientDBType = Conversions.sparkDTtoOrientDBDT(sparkType)
278 | createdEdge.setProperty(fields(i).name, Conversions.convertRowToGraph(row, i), orientDBType)
279 | }
280 | }
281 |
282 | graphFactory.close()
283 | }
284 | })
285 | } catch {
286 | case e: Exception =>
287 | throw new RuntimeException("An exception was thrown: " + e.getMessage)
288 | }
289 | }
290 |
291 | def saveToOrientDB(data: DataFrame, saveMode: SaveMode, params: MergedParameters): Unit = {
292 | val connection = orientDBWrapper.getConnection(params)
293 |
294 | try {
295 | if (saveMode == SaveMode.Overwrite) {
296 | dropOrientDBEdge(params)
297 | }
298 | doOrientDBEdgeLoad(connection, data, params)
299 | } finally {
300 | orientDBWrapper.close()
301 | }
302 | }
303 | }
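// A minimal edge-write sketch (placeholder values; "edgeDf" is assumed to contain "src" and "dst"
// columns holding the "id" values of vertices of the given vertex type):
//
//   edgeDf.write.format("org.apache.spark.orientdb.graphs")
//     .option("dburl", "remote:127.0.0.1/<database>")
//     .option("user", "root").option("password", "root")
//     .option("vertextype", "Person")
//     .option("edgetype", "Knows")
//     .mode(SaveMode.Append)
//     .save()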
304 |
305 | object DefaultOrientDBEdgeWriter extends OrientDBEdgeWriter(
306 | DefaultOrientDBGraphEdgeWrapper,
307 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials))
308 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/graphs/Parameters.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | private[orientdb] object Parameters {
4 | val DEFAULT_PARAMETERS: Map[String, String] = Map(
5 | "overwrite" -> "false"
6 | )
7 |
8 | def mergeParameters(userParameters: Map[String, String]): MergedParameters = {
9 | if (!userParameters.contains("dburl")) {
10 | throw new IllegalArgumentException("An OrientDB URL must be provided with the 'dburl' parameter")
11 | }
12 |
13 | if (!userParameters.contains("vertextype") && !userParameters.contains("edgetype") && !userParameters.contains("query")) {
14 | throw new IllegalArgumentException("You must specify an OrientDB Vertex type in the 'vertextype'" +
15 | " parameter, an OrientDB Edge type in the 'edgetype' parameter, or a user-specified query using the 'query' parameter")
16 | }
17 |
18 | if (!userParameters.contains("user") || !userParameters.contains("password")) {
19 |       throw new IllegalArgumentException("You must specify both the OrientDB username via the 'user' parameter and" +
20 |         " the OrientDB password via the 'password' parameter")
21 | }
22 | MergedParameters(DEFAULT_PARAMETERS ++ userParameters)
23 | }
24 |
25 | case class MergedParameters(parameters: Map[String, String]) {
26 |
27 |     /**
28 |      * The OrientDB graph vertex type to be used to load & write data
29 |      */
30 |     def vertexType: Option[String] = parameters.get("vertextype")
31 |
32 |     /**
33 |      * The OrientDB graph edge type to be used to load & write data
34 |      */
35 |     def edgeType: Option[String] = parameters.get("edgetype")
36 |
37 |     def createVertexIfNotExist: Boolean =
38 |       parameters.get("createVertexIfNotExist").exists(_ == "true")
39 |
40 |     def lightWeightEdge: Boolean =
41 |       parameters.get("lightWeightEdge").exists(_ == "true")
42 |
43 |     /**
44 |      * The OrientDB graph SQL query to be used for loading data
45 |      */
46 |     def query: Option[String] = parameters.get("query")
47 |
48 | /**
49 | * Username & Password for authentication with OrientDB
50 | */
51 | def credentials: Option[(String, String)] = {
52 | for {
53 | username <- parameters.get("user")
54 | password <- parameters.get("password")
55 | } yield (username, password)
56 | }
57 |
58 | /**
59 |      * A URL in the format
60 |      * remote:<hostname>:<port>/<database>
61 | */
62 | def dbUrl: Option[String] = parameters.get("dburl")
63 |
64 | /**
65 | * cluster name in Orient DB
66 | */
67 |     def cluster: Option[String] = parameters.get("cluster")
68 |
69 | /**
70 | * mention linked properties in the form "vertextype/edgetype" - "linkedType-linked vertextype/edgetype"
71 | */
72 | def linkedType: Option[Map[String, String]] =
73 | if (parameters.exists(paramPair => paramPair._2.contains("linkedType")))
74 | Some(parameters.filter(paramPair => paramPair._2.contains("linkedType")))
75 | else None
76 | }
77 | }
78 |
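As a quick orientation for how these parameters are supplied in practice, here is a minimal, hypothetical read sketch. It assumes that the package name `org.apache.spark.orientdb.graphs` resolves to the `DefaultSource` in this package when used as the Spark format identifier, and it uses a placeholder vertex class name; the option keys mirror the checks performed by `mergeParameters` above.

```
import org.apache.spark.sql.SparkSession

object GraphReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("GraphReadSketch")
      .master("local[*]")
      .getOrCreate()

    // 'dburl', 'user' and 'password' are mandatory; additionally one of
    // 'vertextype', 'edgetype' or 'query' must be present, as enforced above.
    val vertices = spark.read
      .format("org.apache.spark.orientdb.graphs")
      .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") // assumed local OrientDB instance
      .option("user", "root")
      .option("password", "root")
      .option("vertextype", "Person") // hypothetical vertex class
      .load()

    vertices.show()
    spark.stop()
  }
}
```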
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/EmbeddedListType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
6 | import org.apache.spark.sql.types._
7 | import org.apache.spark.unsafe.types.UTF8String
8 |
9 | @SQLUserDefinedType(udt = classOf[EmbeddedListType])
10 | case class EmbeddedList(elements: Array[Any]) extends Serializable {
11 | override def hashCode(): Int = {
12 | var hashCode = 1
13 | val i = elements.iterator
14 | while (i.hasNext) {
15 | val obj = i.next()
16 |
17 | val elemValue = if (obj == null) 0 else obj.hashCode()
18 | hashCode = 31 * hashCode + elemValue
19 | }
20 | hashCode
21 | }
22 |
23 | override def equals(other: scala.Any): Boolean = other match {
24 | case that: EmbeddedList => that.elements.sameElements(this.elements)
25 | case _ => false
26 | }
27 |
28 | override def toString: String = elements.mkString(", ")
29 | }
30 |
31 | class EmbeddedListType extends UserDefinedType[EmbeddedList] {
32 |
33 | override def sqlType: DataType = ArrayType(StringType)
34 |
35 | override def serialize(obj: EmbeddedList): Any = {
36 | new GenericArrayData(obj.elements.map{elem =>
37 | val out = new ByteArrayOutputStream()
38 | val os = new ObjectOutputStream(out)
39 | os.writeObject(elem)
40 | UTF8String.fromBytes(out.toByteArray)
41 | })
42 | }
43 |
44 | override def deserialize(datum: Any): EmbeddedList = {
45 | datum match {
46 | case values: ArrayData =>
47 | new EmbeddedList(values.toArray[UTF8String](StringType).map{ elem =>
48 | val in = new ByteArrayInputStream(elem.getBytes)
49 | val is = new ObjectInputStream(in)
50 | is.readObject()
51 | })
52 | case other => sys.error(s"Cannot deserialize $other")
53 | }
54 | }
55 |
56 | override def userClass: Class[EmbeddedList] = classOf[EmbeddedList]
57 | }
58 |
59 | object EmbeddedListType extends EmbeddedListType
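To illustrate the serialization scheme used by this UDT (each element is Java-serialized and its raw bytes are wrapped in a `UTF8String`, so the Catalyst representation is simply `ArrayType(StringType)`), here is a minimal in-memory round-trip sketch. The object name is hypothetical and it only assumes spark-catalyst on the classpath.

```
import org.apache.spark.orientdb.udts.{EmbeddedList, EmbeddedListType}

object EmbeddedListRoundTrip {
  def main(args: Array[String]): Unit = {
    val udt = new EmbeddedListType

    // serialize() Java-serializes each element into a UTF8String-wrapped byte array;
    // deserialize() reverses the process element by element.
    val original = EmbeddedList(Array(1, "two", 3.0))
    val restored = udt.deserialize(udt.serialize(original))

    // EmbeddedList.equals compares element-by-element via sameElements.
    assert(restored == original)
    println(restored)
  }
}
```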
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/EmbeddedMapType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import org.apache.spark.sql.catalyst.expressions.UnsafeMapData
6 | import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
7 | import org.apache.spark.sql.types._
8 | import org.apache.spark.unsafe.types.UTF8String
9 |
10 | @SQLUserDefinedType(udt = classOf[EmbeddedMapType])
11 | case class EmbeddedMap(elements: Map[Any, Any]) extends Serializable {
12 | override def hashCode(): Int = 1
13 |
14 | override def equals(other: scala.Any): Boolean = other match {
15 | case that: EmbeddedMap => that.elements == this.elements
16 | case _ => false
17 | }
18 |
19 | override def toString: String = elements.mkString(", ")
20 | }
21 |
22 | class EmbeddedMapType extends UserDefinedType[EmbeddedMap] {
23 |
24 | override def sqlType: DataType = MapType(StringType, StringType)
25 |
26 | override def serialize(obj: EmbeddedMap): Any = {
27 | ArrayBasedMapData(obj.elements.keySet.map{ elem =>
28 | val outKey = new ByteArrayOutputStream()
29 | val osKey = new ObjectOutputStream(outKey)
30 | osKey.writeObject(elem)
31 | UTF8String.fromBytes(outKey.toByteArray)
32 | }.toArray,
33 | obj.elements.values.map{ elem =>
34 | val outValue = new ByteArrayOutputStream()
35 | val osValue = new ObjectOutputStream(outValue)
36 | osValue.writeObject(elem)
37 | UTF8String.fromBytes(outValue.toByteArray)
38 | }.toArray)
39 | }
40 |
41 | override def deserialize(datum: Any): EmbeddedMap = {
42 | datum match {
43 | case values: UnsafeMapData =>
44 | new EmbeddedMap(values.keyArray().toArray[UTF8String](StringType).map{ elem =>
45 | val in = new ByteArrayInputStream(elem.getBytes)
46 | val is = new ObjectInputStream(in)
47 | is.readObject()
48 | }.zip(values.valueArray().toArray[UTF8String](StringType).map{ elem =>
49 | val in = new ByteArrayInputStream(elem.getBytes)
50 | val is = new ObjectInputStream(in)
51 | is.readObject()
52 | }).toMap)
53 | case other => sys.error(s"Cannot deserialize $other")
54 | }
55 | }
56 |
57 | override def userClass: Class[EmbeddedMap] = classOf[EmbeddedMap]
58 | }
59 |
60 | object EmbeddedMapType extends EmbeddedMapType
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/EmbeddedSetType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
6 | import org.apache.spark.sql.types._
7 | import org.apache.spark.unsafe.types.UTF8String
8 |
9 | @SQLUserDefinedType(udt = classOf[EmbeddedSetType])
10 | case class EmbeddedSet(elements: Array[Any]) extends Serializable {
11 | override def hashCode(): Int = {
12 | var hashCode = 1
13 | val i = elements.iterator
14 | while (i.hasNext) {
15 | val obj = i.next()
16 |
17 | val elemValue = if (obj == null) 0 else obj.hashCode()
18 | hashCode = 31 * hashCode + elemValue
19 | }
20 | hashCode
21 | }
22 |
23 | override def equals(other: scala.Any): Boolean = other match {
24 | case that: EmbeddedSet => that.elements.sameElements(this.elements)
25 | case _ => false
26 | }
27 |
28 | override def toString: String = elements.mkString(", ")
29 | }
30 |
31 | class EmbeddedSetType extends UserDefinedType[EmbeddedSet] {
32 |
33 | override def sqlType: DataType = ArrayType(StringType)
34 |
35 | override def serialize(obj: EmbeddedSet): Any = {
36 | new GenericArrayData(obj.elements.map{elem =>
37 | val out = new ByteArrayOutputStream()
38 | val os = new ObjectOutputStream(out)
39 | os.writeObject(elem)
40 | UTF8String.fromBytes(out.toByteArray)
41 | })
42 | }
43 |
44 | override def deserialize(datum: Any): EmbeddedSet = {
45 | datum match {
46 | case values: ArrayData =>
47 | new EmbeddedSet(values.toArray[UTF8String](StringType).map{ elem =>
48 | val in = new ByteArrayInputStream(elem.getBytes)
49 | val is = new ObjectInputStream(in)
50 | is.readObject()
51 | })
52 | case other => sys.error(s"Cannot deserialize $other")
53 | }
54 | }
55 |
56 | override def userClass: Class[EmbeddedSet] = classOf[EmbeddedSet]
57 | }
58 |
59 | object EmbeddedSetType extends EmbeddedSetType
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/LinkBagType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import com.orientechnologies.orient.core.id.ORecordId
6 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
7 | import org.apache.spark.sql.types._
8 | import org.apache.spark.unsafe.types.UTF8String
9 |
10 | @SQLUserDefinedType(udt = classOf[LinkBagType])
11 | case class LinkBag(elements: Array[_ <: ORecordId]) extends Serializable {
12 | override def hashCode(): Int = {
13 | val hashCode = 1
14 |     val elemValue = if (elements == null) 0 else elements.toSeq.hashCode()
15 |
16 | 31 * hashCode + elemValue
17 | }
18 |
19 | override def equals(other: scala.Any): Boolean = other match {
20 | case that: LinkBag => that.elements.sameElements(this.elements)
21 | case _ => false
22 | }
23 |
24 |   override def toString: String = elements.mkString(", ")
25 | }
26 |
27 | class LinkBagType extends UserDefinedType[LinkBag] {
28 |
29 | override def sqlType: DataType = ArrayType(StringType)
30 |
31 | override def serialize(obj: LinkBag): Any = {
32 | new GenericArrayData(obj.elements.map{ elem =>
33 | val out = new ByteArrayOutputStream()
34 | val os = new ObjectOutputStream(out)
35 | os.writeObject(elem)
36 | UTF8String.fromBytes(out.toByteArray)
37 | })
38 | }
39 |
40 | override def deserialize(datum: Any): LinkBag = {
41 | datum match {
42 | case values: ArrayData =>
43 | new LinkBag(values.toArray[UTF8String](StringType).map{ elem =>
44 | val in = new ByteArrayInputStream(elem.getBytes)
45 | val is = new ObjectInputStream(in)
46 | is.readObject().asInstanceOf[ORecordId]
47 | })
48 | case other => sys.error(s"Cannot deserialize $other")
49 | }
50 | }
51 |
52 | override def userClass: Class[LinkBag] = classOf[LinkBag]
53 | }
54 |
55 | object LinkBagType extends LinkBagType
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/LinkListType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import com.orientechnologies.orient.core.record.ORecord
6 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
7 | import org.apache.spark.sql.types._
8 | import org.apache.spark.unsafe.types.UTF8String
9 |
10 | @SQLUserDefinedType(udt = classOf[LinkListType])
11 | case class LinkList(elements: Array[_ <: ORecord]) extends Serializable {
12 | override def hashCode(): Int = {
13 | var hashCode = 1
14 | val i = elements.iterator
15 | while (i.hasNext) {
16 | val obj = i.next()
17 |
18 | val elemValue = if (obj == null) 0 else obj.hashCode()
19 | hashCode = 31 * hashCode + elemValue
20 | }
21 | hashCode
22 | }
23 |
24 | override def equals(other: scala.Any): Boolean = other match {
25 | case that: LinkList => that.elements.sameElements(this.elements)
26 | case _ => false
27 | }
28 |
29 | override def toString: String = elements.mkString(", ")
30 | }
31 |
32 | class LinkListType extends UserDefinedType[LinkList] {
33 |
34 | override def sqlType: DataType = ArrayType(StringType)
35 |
36 | override def serialize(obj: LinkList): Any = {
37 | new GenericArrayData(obj.elements.map{ elem =>
38 | val out = new ByteArrayOutputStream()
39 | val os = new ObjectOutputStream(out)
40 | os.writeObject(elem)
41 | UTF8String.fromBytes(out.toByteArray)
42 | })
43 | }
44 |
45 | override def deserialize(datum: Any): LinkList = {
46 | datum match {
47 | case values: ArrayData =>
48 | new LinkList(values.toArray[UTF8String](StringType).map{ elem =>
49 | val in = new ByteArrayInputStream(elem.getBytes)
50 | val is = new ObjectInputStream(in)
51 | is.readObject().asInstanceOf[ORecord]
52 | })
53 | case other => sys.error(s"Cannot deserialize $other")
54 | }
55 | }
56 |
57 | override def userClass: Class[LinkList] = classOf[LinkList]
58 | }
59 |
60 | object LinkListType extends LinkListType
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/LinkMapType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import com.orientechnologies.orient.core.record.ORecord
6 | import org.apache.spark.sql.catalyst.expressions.UnsafeMapData
7 | import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
8 | import org.apache.spark.sql.types._
9 | import org.apache.spark.unsafe.types.UTF8String
10 |
11 | @SQLUserDefinedType(udt = classOf[LinkMapType])
12 | case class LinkMap(elements: Map[String, _ <: ORecord]) extends Serializable {
13 | override def hashCode(): Int = 1
14 |
15 | override def equals(other: scala.Any): Boolean = other match {
16 | case that: LinkMap => that.elements == this.elements
17 | case _ => false
18 | }
19 |
20 | override def toString: String = elements.mkString(", ")
21 | }
22 |
23 | class LinkMapType extends UserDefinedType[LinkMap] {
24 |
25 | override def sqlType: DataType = MapType(StringType, StringType)
26 |
27 | override def serialize(obj: LinkMap): Any = {
28 | ArrayBasedMapData(obj.elements.keySet.map{ elem =>
29 | val outKey = new ByteArrayOutputStream()
30 | val osKey = new ObjectOutputStream(outKey)
31 | osKey.writeObject(elem)
32 | UTF8String.fromBytes(outKey.toByteArray)
33 | }.toArray,
34 | obj.elements.values.map{ elem =>
35 | val outValue = new ByteArrayOutputStream()
36 | val osValue = new ObjectOutputStream(outValue)
37 | osValue.writeObject(elem)
38 | UTF8String.fromBytes(outValue.toByteArray)
39 | }.toArray)
40 | }
41 |
42 | override def deserialize(datum: Any): LinkMap = {
43 | datum match {
44 | case values: UnsafeMapData =>
45 | new LinkMap(values.keyArray().toArray[UTF8String](StringType).map { elem =>
46 | val in = new ByteArrayInputStream(elem.getBytes)
47 | val is = new ObjectInputStream(in)
48 | is.readObject().toString
49 | }.zip(values.valueArray().toArray[UTF8String](StringType).map { elem =>
50 | val in = new ByteArrayInputStream(elem.getBytes)
51 | val is = new ObjectInputStream(in)
52 | is.readObject().asInstanceOf[ORecord]
53 | }).toMap)
54 | case other => sys.error(s"Cannot deserialize $other")
55 | }
56 | }
57 |
58 | override def userClass: Class[LinkMap] = classOf[LinkMap]
59 | }
60 |
61 | object LinkMapType extends LinkMapType
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/LinkSetType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import com.orientechnologies.orient.core.record.ORecord
6 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
7 | import org.apache.spark.sql.types._
8 | import org.apache.spark.unsafe.types.UTF8String
9 |
10 | @SQLUserDefinedType(udt = classOf[LinkSetType])
11 | case class LinkSet(elements: Array[_ <: ORecord]) extends Serializable {
12 | override def hashCode(): Int = {
13 | var hashCode = 1
14 | val i = elements.iterator
15 | while (i.hasNext) {
16 | val obj = i.next()
17 |
18 | val elemValue = if (obj == null) 0 else obj.hashCode()
19 | hashCode = 31 * hashCode + elemValue
20 | }
21 | hashCode
22 | }
23 |
24 | override def equals(other: scala.Any): Boolean = other match {
25 | case that: LinkSet => that.elements.sameElements(this.elements)
26 | case _ => false
27 | }
28 |
29 | override def toString: String = elements.mkString(", ")
30 | }
31 |
32 | class LinkSetType extends UserDefinedType[LinkSet] {
33 |
34 | override def sqlType: DataType = ArrayType(StringType)
35 |
36 | override def serialize(obj: LinkSet): Any = {
37 | new GenericArrayData(obj.elements.map{elem =>
38 | val out = new ByteArrayOutputStream()
39 | val os = new ObjectOutputStream(out)
40 | os.writeObject(elem)
41 | UTF8String.fromBytes(out.toByteArray)
42 | })
43 | }
44 |
45 | override def deserialize(datum: Any): LinkSet = {
46 | datum match {
47 | case values: ArrayData =>
48 | new LinkSet(values.toArray[UTF8String](StringType).map{ elem =>
49 | val in = new ByteArrayInputStream(elem.getBytes)
50 | val is = new ObjectInputStream(in)
51 | is.readObject().asInstanceOf[ORecord]
52 | })
53 | case other => sys.error(s"Cannot deserialize $other")
54 | }
55 | }
56 |
57 | override def userClass: Class[LinkSet] = classOf[LinkSet]
58 | }
59 |
60 | object LinkSetType extends LinkSetType
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/orientdb/udts/LinkType.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.udts
2 |
3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
4 |
5 | import com.orientechnologies.orient.core.db.record.OIdentifiable
6 | import com.orientechnologies.orient.core.id.ORecordId
7 | import com.orientechnologies.orient.core.record.ORecord
8 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
9 | import org.apache.spark.sql.types._
10 | import org.apache.spark.unsafe.types.UTF8String
11 |
12 | @SQLUserDefinedType(udt = classOf[LinkType])
13 | case class Link(element: OIdentifiable) extends Serializable {
14 | override def hashCode(): Int = {
15 | var hashCode = 1
16 |
17 | val elemValue = if (element == null) 0 else element.hashCode()
18 | hashCode = 31 * hashCode + elemValue
19 | hashCode
20 | }
21 |
22 | override def equals(other: scala.Any): Boolean = other match {
23 | case that: Link => that.element.equals(this.element)
24 | case _ => false
25 | }
26 |
27 | override def toString: String = element.toString
28 | }
29 |
30 | class LinkType extends UserDefinedType[Link] {
31 |
32 | override def sqlType: DataType = ArrayType(StringType)
33 |
34 | override def serialize(obj: Link): Any = {
35 | val out = new ByteArrayOutputStream()
36 | val os = new ObjectOutputStream(out)
37 | if (obj.element.isInstanceOf[ORecord])
38 | os.writeObject(obj.element.asInstanceOf[ORecord])
39 | else
40 | os.writeObject(obj.element.asInstanceOf[ORecordId])
41 | new GenericArrayData(Array(UTF8String.fromBytes(out.toByteArray)))
42 | }
43 |
44 | override def deserialize(datum: Any): Link = {
45 | datum match {
46 | case values: ArrayData =>
47 | new Link(values.toArray[UTF8String](StringType).map { elem =>
48 | val in = new ByteArrayInputStream(elem.getBytes)
49 | val is = new ObjectInputStream(in)
50 | val data = is.readObject()
51 | if (data.isInstanceOf[ORecord]) {
52 | data.asInstanceOf[ORecord]
53 | } else {
54 | data.asInstanceOf[ORecordId]
55 | }
56 | }.head)
57 | case other => sys.error(s"Cannot deserialize $other")
58 | }
59 | }
60 |
61 | override def userClass: Class[Link] = classOf[Link]
62 | }
63 |
64 | object LinkType extends LinkType
65 |
66 |
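The same byte-level trick applies to single links: a `Link` may wrap either a full `ORecord` or a bare `ORecordId`, and both are Java-serialized into a one-element string array. Below is a minimal round-trip sketch; the object name is hypothetical.

```
import com.orientechnologies.orient.core.id.ORecordId
import org.apache.spark.orientdb.udts.{Link, LinkType}

object LinkRoundTrip {
  def main(args: Array[String]): Unit = {
    val udt = new LinkType

    // A bare ORecordId takes the ORecordId branch of LinkType.serialize.
    val original = Link(new ORecordId("#10:3"))
    val restored = udt.deserialize(udt.serialize(original))

    // Link.equals delegates to the wrapped OIdentifiable, so the record ids must match.
    assert(restored == original)
    println(restored)
  }
}
```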
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/QueryTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb
2 |
3 | import org.apache.spark.sql.catalyst.plans.logical
4 | import org.apache.spark.sql.{DataFrame, Row}
5 | import org.scalatest.FunSuite
6 |
7 | trait QueryTest extends FunSuite {
8 |
9 | def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
10 | val isSorted = df.queryExecution.logical.collect {case s: logical.Sort => s}.nonEmpty
11 | def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
12 | val converted: Seq[Row] = answer.map { s =>
13 | Row.fromSeq(s.toSeq.map {
14 | case d: java.math.BigDecimal => BigDecimal(d)
15 | case b: Array[Byte] => b.toSeq
16 | case o => o
17 | })
18 | }
19 | if (!isSorted) converted.sortBy(_.toString()) else converted
20 | }
21 |
22 | val sparkAnswer = try df.collect().toSeq catch {
23 | case e: Exception =>
24 | val errorMessage =
25 | s"""
26 | |Exception thrown while executing query:
27 | |${df.queryExecution}
28 | |== Exception ==
29 | |$e
30 | |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
31 | """.stripMargin
32 | fail(errorMessage)
33 | }
34 |
35 | if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
36 | val errorMessage =
37 | s"""
38 | |Results do not match for query:
39 | |${df.queryExecution}
40 | |== Results ==
41 | |${sideBySide(
42 | s"== Correct Answer - ${expectedAnswer.size} ==" +:
43 | prepareAnswer(expectedAnswer).map(_.toString()),
44 | s"== Spark Answer - ${sparkAnswer.size} ==" +:
45 | prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
46 | """.stripMargin
47 | fail(errorMessage)
48 | }
49 | }
50 |
51 | private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = {
52 | val maxLeftSize = left.map(_.length).max
53 | val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("")
54 | val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("")
55 |
56 | leftPadded.zip(rightPadded).map {
57 | case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r
58 | }
59 | }
60 | }
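For context, the test suites below extend this trait and call `checkAnswer` to compare a DataFrame against expected rows; unless the logical plan contains a `Sort`, rows are normalised and string-sorted before comparison, so result order does not matter. A minimal, hypothetical usage sketch:

```
import org.apache.spark.orientdb.QueryTest
import org.apache.spark.sql.{Row, SparkSession}

class CheckAnswerExampleSuite extends QueryTest {
  test("checkAnswer ignores row order for unsorted queries") {
    val spark = SparkSession.builder()
      .appName("CheckAnswerExampleSuite")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "value")

    // The expected rows are deliberately given out of order.
    checkAnswer(df, Seq(Row(2, "b"), Row(1, "a")))
    spark.stop()
  }
}
```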
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/ConversionsSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.metadata.schema.OType
4 | import com.orientechnologies.orient.core.record.impl.ODocument
5 | import org.apache.spark.sql.{Row, SparkSession}
6 | import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
7 | import org.scalatest.FunSuite
8 |
9 | class ConversionsSuite extends FunSuite {
10 |
11 | test("Spark datatype to OrientDB datatype test") {
12 | val orientDBType = Conversions.sparkDTtoOrientDBDT(StringType)
13 | assert(orientDBType === OType.STRING)
14 | }
15 |
16 | test("Convert Spark Row to Orient DB ODocument") {
17 | val expectedData = new ODocument()
18 | expectedData.field("key", 1, OType.INTEGER)
19 | expectedData.field("value", "Spark datasource for Orient DB", OType.STRING)
20 |
21 |     val spark = SparkSession.builder().appName("ConversionsSuite")
22 | .master("local[*]")
23 | .getOrCreate()
24 |
25 | val rows = spark.sqlContext.createDataFrame(spark.sparkContext.parallelize(Seq(Row(1, "Spark datasource for Orient DB"))),
26 | StructType(Array(StructField("key", IntegerType, true),
27 | StructField("value", StringType, true)))).collect()
28 |
29 | val actualData = Conversions.convertRowsToODocuments(rows(0))
30 | assert(expectedData.field[Int]("key") == actualData.field[Int]("key"))
31 | assert(expectedData.field[String]("value") == actualData.field[String]("value"))
32 | spark.close()
33 | }
34 |
35 | test("Convert OrientDB ODocument to Spark Row") {
36 | val oDocument = new ODocument()
37 | oDocument.field("key", 1, OType.INTEGER)
38 | oDocument.field("value", "Orient DB ODocument to Spark Row", OType.STRING)
39 |
40 | val schema = StructType(Array(StructField("key", IntegerType),
41 | StructField("value", StringType)))
42 |
43 | val expectedData = Row(1, "Orient DB ODocument to Spark Row")
44 | val actualData = Conversions.convertODocumentsToRows(oDocument, schema)
45 |
46 | assert(expectedData === actualData)
47 | }
48 |
49 | test("Return field of correct type") {
50 | val field = Conversions.orientDBDTtoSparkDT(IntegerType, "1")
51 | assert(field.isInstanceOf[Int])
52 | }
53 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/FilterPushdownSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import org.apache.spark.sql.sources._
4 | import org.apache.spark.sql.types._
5 | import org.scalatest.FunSuite
6 |
7 | class FilterPushdownSuite extends FunSuite {
8 |
9 | test("buildWhereClause with empty list of filters") {
10 | assert(FilterPushdown.buildWhereClause(StructType(Nil), Seq.empty) === "")
11 | }
12 |
13 | test("buildWhereClause with no filters that can be pushed down") {
14 | assert(FilterPushdown.buildWhereClause(StructType(Nil), Seq(AlwaysTrue, AlwaysTrue)) === "")
15 | }
16 |
17 |   test("buildWhereClause with some filters that cannot be pushed down") {
18 | val whereClause = FilterPushdown.buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), AlwaysTrue))
19 | assert(whereClause === "WHERE test_int = 1")
20 | }
21 |
22 | test("buildWhereClause with multiple filters") {
23 | val filters = Seq(
24 | EqualTo("test_bool", true),
25 | EqualTo("test_string", "Unicode是樂趣"),
26 | GreaterThan("test_double", 1000.0),
27 | LessThan("test_double", Double.MaxValue),
28 | GreaterThanOrEqual("test_float", 1.0f),
29 | LessThanOrEqual("test_int", 43),
30 | IsNotNull("test_int"),
31 | IsNull("test_int")
32 | )
33 |
34 | val whereClause = FilterPushdown.buildWhereClause(testSchema, filters)
35 |
36 | val expectedWhereClause =
37 | """
38 | |WHERE test_bool = true
39 | |AND test_string = 'Unicode是樂趣'
40 | |AND test_double > 1000.0
41 | |AND test_double < 1.7976931348623157E308
42 | |AND test_float >= 1.0
43 | |AND test_int <= 43
44 | |AND test_int IS NOT NULL
45 | |AND test_int IS NULL
46 | """.stripMargin.linesIterator.mkString(" ").trim
47 |
48 | assert(whereClause === expectedWhereClause)
49 | }
50 |
51 | private val testSchema: StructType = StructType(Seq(
52 | StructField("test_byte", ByteType),
53 | StructField("test_bool", BooleanType),
54 | StructField("test_date", DateType),
55 | StructField("test_double", DoubleType),
56 | StructField("test_float", FloatType),
57 | StructField("test_int", IntegerType),
58 | StructField("test_long", LongType),
59 | StructField("test_short", ShortType),
60 | StructField("test_string", StringType),
61 | StructField("test_timestamp", TimestampType)
62 | ))
63 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/MockOrientDBDocument.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.db.document.ODatabaseDocumentTx
4 | import com.orientechnologies.orient.core.record.impl.ODocument
5 | import org.apache.spark.orientdb.documents.Parameters.MergedParameters
6 | import org.apache.spark.sql.types.StructType
7 | import org.mockito.Matchers._
8 | import org.mockito.Mockito._
9 | import org.mockito.invocation.InvocationOnMock
10 | import org.mockito.stubbing.Answer
11 |
12 | class MockOrientDBDocument(existingTablesAndSchemas: Map[String, StructType],
13 | oDocuments: List[ODocument]) {
14 | val documentWrapper: OrientDBDocumentWrapper = spy(new OrientDBDocumentWrapper())
15 |
16 | doAnswer(new Answer[ODatabaseDocumentTx] {
17 | override def answer(invocationOnMock: InvocationOnMock): ODatabaseDocumentTx = {
18 | mock(classOf[ODatabaseDocumentTx], RETURNS_SMART_NULLS)
19 | }
20 | }).when(documentWrapper).getConnection(any(classOf[MergedParameters]))
21 |
22 | doAnswer(new Answer[Boolean] {
23 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
24 | existingTablesAndSchemas.contains(invocationOnMock.getArguments()(1).asInstanceOf[String])
25 | }
26 | }).when(documentWrapper).doesClassExists(any(classOf[String]))
27 |
28 | doAnswer(new Answer[Boolean] {
29 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
30 | true
31 | }
32 | }).when(documentWrapper).create(any(classOf[String]), any(classOf[String]),
33 | any(classOf[ODocument]))
34 |
35 | doAnswer(new Answer[List[ODocument]] {
36 | override def answer(invocationOnMock: InvocationOnMock): List[ODocument] = {
37 | oDocuments
38 | }
39 | }).when(documentWrapper).read(any(classOf[List[String]]), any(classOf[String]), any(classOf[Array[String]]),
40 | any(classOf[String]), any(classOf[String]))
41 |
42 | doAnswer(new Answer[Boolean] {
43 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
44 | true
45 | }
46 | }).when(documentWrapper).delete(any(classOf[String]), any(classOf[String]),
47 | any(classOf[Map[String, Tuple2[String, String]]]))
48 |
49 | doAnswer(new Answer[StructType] {
50 | override def answer(invocationOnMock: InvocationOnMock): StructType = {
51 | existingTablesAndSchemas.get(invocationOnMock.getArguments()(1).asInstanceOf[String]).get
52 | }
53 | }).when(documentWrapper).resolveTable(any(classOf[String]), any(classOf[String]))
54 |
55 | doAnswer(new Answer[List[ODocument]] {
56 | override def answer(invocationOnMock: InvocationOnMock): List[ODocument] = {
57 | oDocuments
58 | }
59 | }).when(documentWrapper).genericQuery(any(classOf[String]))
60 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/OrientDBEmbeddedUDTsSourceSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.db.record.{OTrackedList, OTrackedMap, OTrackedSet}
4 | import com.orientechnologies.orient.core.metadata.schema.OType
5 | import com.orientechnologies.orient.core.record.impl.ODocument
6 | import org.apache.spark.orientdb.udts._
7 | import org.apache.spark.SparkContext
8 | import org.apache.spark.orientdb.{QueryTest, TestUtils}
9 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan}
10 | import org.apache.spark.sql.types.{StructField, StructType}
11 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
12 | import org.mockito.Mockito
13 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
14 |
15 | class OrientDBEmbeddedUDTsSourceSuite extends QueryTest
16 | with BeforeAndAfterAll
17 | with BeforeAndAfterEach {
18 | private var sc: SparkContext = _
19 | private var spark: SparkSession = _
20 | private var sqlContext: SQLContext = _
21 | private var mockOrientDBClient: OrientDBClientFactory = _
22 | private var expectedDataDf: DataFrame = _
23 |
24 | override protected def beforeAll(): Unit = {
25 |     spark = SparkSession.builder().appName("OrientDBEmbeddedUDTsSourceSuite")
26 | .master("local[*]")
27 | .config("", "true")
28 | .getOrCreate()
29 | sc = spark.sparkContext
30 | }
31 |
32 | override protected def afterAll(): Unit = {
33 | if (spark != null) {
34 | spark.stop()
35 | }
36 | }
37 |
38 | override protected def beforeEach(): Unit = {
39 | sqlContext = spark.sqlContext
40 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory],
41 | Mockito.RETURNS_SMART_NULLS)
42 | expectedDataDf = sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedDataForEmbeddedUDTs),
43 | TestUtils.testSchemaForEmbeddedUDTs)
44 | }
45 |
46 | override protected def afterEach(): Unit = {
47 | sqlContext = null
48 | }
49 |
50 | test("Can load output of OrientDB queries") {
51 | val query = "select embeddedset, embeddedmap from test_table"
52 |
53 | val querySchema = StructType(Seq(StructField("embeddedset", EmbeddedSetType),
54 | StructField("embeddedmap", EmbeddedMapType)))
55 |
56 | {
57 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
58 | "user" -> "root",
59 | "password" -> "root",
60 | "class" -> "test_table",
61 | "clusters" -> "test_cluster")
62 |
63 | val iSourceRecord = new ODocument()
64 | iSourceRecord.field("id", 1, OType.INTEGER)
65 |
66 | var oTrackedSet = new OTrackedSet[Int](iSourceRecord)
67 | oTrackedSet.add(1)
68 | var oTrackedMap = new OTrackedMap[Boolean](iSourceRecord)
69 | oTrackedMap.put(1, true)
70 |
71 | val oDoc1 = new ODocument()
72 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
73 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
74 |
75 | oTrackedSet = new OTrackedSet[Int](iSourceRecord)
76 | oTrackedSet.add(2)
77 | oTrackedMap = new OTrackedMap[Boolean](iSourceRecord)
78 | oTrackedMap.put(1, false)
79 |
80 | val oDoc2 = new ODocument()
81 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
82 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
83 |
84 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema),
85 | List(oDoc1, oDoc2))
86 |
87 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
88 | .createRelation(sqlContext, params)
89 | sqlContext.baseRelationToDataFrame(relation).collect()
90 | }
91 |
92 | {
93 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
94 | "user" -> "root",
95 | "password" -> "root",
96 | "class" -> "test_table",
97 | "query" -> query,
98 | "clusters" -> "test_cluster")
99 |
100 | val iSourceRecord = new ODocument()
101 | iSourceRecord.field("id", 1, OType.INTEGER)
102 |
103 | var oTrackedSet = new OTrackedSet[Int](iSourceRecord)
104 | oTrackedSet.add(1)
105 | var oTrackedMap = new OTrackedMap[Boolean](iSourceRecord)
106 | oTrackedMap.put(1, true)
107 |
108 | val oDoc1 = new ODocument()
109 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
110 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
111 |
112 | oTrackedSet = new OTrackedSet[Int](iSourceRecord)
113 | oTrackedSet.add(2)
114 | oTrackedMap = new OTrackedMap[Boolean](iSourceRecord)
115 | oTrackedMap.put(1, false)
116 |
117 | val oDoc2 = new ODocument()
118 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
119 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
120 |
121 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema),
122 | List(oDoc1, oDoc2))
123 |
124 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
125 | .createRelation(sqlContext, params)
126 | sqlContext.baseRelationToDataFrame(relation).collect()
127 | }
128 | }
129 |
130 | test("DefaultSource supports simple column filtering") {
131 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
132 | "user" -> "root",
133 | "password" -> "root",
134 | "class" -> "test_table",
135 | "clusters" -> "test_cluster")
136 |
137 | val iSourceRecord = new ODocument()
138 | iSourceRecord.field("id", 1, OType.INTEGER)
139 |
140 | var oTrackedList = new OTrackedList[Byte](iSourceRecord)
141 | oTrackedList.add(1.toByte)
142 | var oTrackedSet = new OTrackedSet[Boolean](iSourceRecord)
143 | oTrackedSet.add(true)
144 | var oTrackedMap = new OTrackedMap[String](iSourceRecord)
145 | oTrackedMap.put(1, "Hello")
146 |
147 | val oDoc1 = new ODocument()
148 | oDoc1.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST)
149 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
150 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
151 |
152 | oTrackedList = new OTrackedList[Byte](iSourceRecord)
153 | oTrackedList.add(2.toByte)
154 | oTrackedSet = new OTrackedSet[Boolean](iSourceRecord)
155 | oTrackedSet.add(false)
156 | oTrackedMap = new OTrackedMap[String](iSourceRecord)
157 | oTrackedMap.put(1, "World")
158 |
159 | val oDoc2 = new ODocument()
160 | oDoc2.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST)
161 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
162 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
163 |
164 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForEmbeddedUDTs),
165 | List(oDoc1, oDoc2))
166 |
167 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
168 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEmbeddedUDTs)
169 |
170 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
171 | .buildScan(Array("embeddedlist", "embeddedset"), Array.empty[Filter])
172 |
173 | val prunedExpectedValues = Array(
174 | Row(EmbeddedList(Array(1.toByte)), EmbeddedSet(Array(true))), Row(EmbeddedList(Array(2.toByte)), EmbeddedSet(Array(false)))
175 | )
176 |
177 | val result = rdd.collect()
178 | assert(result.length === prunedExpectedValues.length)
179 | assert(result === prunedExpectedValues)
180 | }
181 |
182 | test("DefaultSource supports user schema, pruned and filtered scans") {
183 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
184 | "user" -> "root",
185 | "password" -> "root",
186 | "class" -> "test_table",
187 | "clusters" -> "test_cluster")
188 |
189 | val iSourceRecord = new ODocument()
190 | iSourceRecord.field("id", 1, OType.INTEGER)
191 |
192 | var oTrackedList = new OTrackedList[Byte](iSourceRecord)
193 | oTrackedList.add(1.toByte)
194 | var oTrackedSet = new OTrackedSet[Boolean](iSourceRecord)
195 | oTrackedSet.add(true)
196 | var oTrackedMap = new OTrackedMap[String](iSourceRecord)
197 | oTrackedMap.put(1, "Hello")
198 |
199 | val oDoc1 = new ODocument()
200 | oDoc1.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST)
201 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
202 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
203 |
204 | oTrackedList = new OTrackedList[Byte](iSourceRecord)
205 | oTrackedList.add(1.toByte)
206 | oTrackedSet = new OTrackedSet[Boolean](iSourceRecord)
207 | oTrackedSet.add(false)
208 | oTrackedMap = new OTrackedMap[String](iSourceRecord)
209 | oTrackedMap.put(1, "World")
210 |
211 | val oDoc2 = new ODocument()
212 | oDoc2.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST)
213 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET)
214 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP)
215 |
216 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForEmbeddedUDTs),
217 | List(oDoc1, oDoc2))
218 |
219 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
220 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEmbeddedUDTs)
221 |
222 | val filters: Array[Filter] = Array(
223 | EqualTo("embeddedlist", oTrackedList),
224 | EqualTo("embeddedset", oTrackedSet)
225 | )
226 |
227 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
228 | .buildScan(Array("embeddedmap", "embeddedset"), filters)
229 |
230 | assert(rdd.collect().contains(Row(EmbeddedSet(Array(false)), EmbeddedMap(Map(1 -> "World")))))
231 | }
232 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/OrientDBLinkUDTsSourceSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import java.util
4 | import com.orientechnologies.orient.core.db.record.ridbag.ORidBag
5 | import com.orientechnologies.orient.core.db.record.{ORecordLazyList, ORecordLazyMap, ORecordLazySet}
6 | import com.orientechnologies.orient.core.id.ORecordId
7 | import com.orientechnologies.orient.core.metadata.schema.OType
8 | import com.orientechnologies.orient.core.record.ORecord
9 | import com.orientechnologies.orient.core.record.impl.ODocument
10 | import org.apache.spark.orientdb.udts._
11 | import org.apache.spark.SparkContext
12 | import org.apache.spark.orientdb.{QueryTest, TestUtils}
13 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan}
14 | import org.apache.spark.sql.types.{StructField, StructType}
15 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
16 | import org.mockito.Mockito
17 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
18 |
19 | class OrientDBLinkUDTsSourceSuite extends QueryTest
20 | with BeforeAndAfterAll
21 | with BeforeAndAfterEach {
22 | private var sc: SparkContext = _
23 | private var spark: SparkSession = _
24 | private var sqlContext: SQLContext = _
25 | private var mockOrientDBClient: OrientDBClientFactory = _
26 | private var expectedDataDf: DataFrame = _
27 |
28 | override protected def beforeAll(): Unit = {
29 | spark = SparkSession.builder().appName("OrientDBLinkUDTsSourceSuite")
30 | .master("local[*]")
31 | .getOrCreate()
32 |     sc = spark.sparkContext
33 | }
34 |
35 | override protected def afterAll(): Unit = {
36 | if (spark != null) {
37 | spark.close()
38 | }
39 | }
40 |
41 | override protected def beforeEach(): Unit = {
42 | sqlContext = spark.sqlContext
43 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory],
44 | Mockito.RETURNS_SMART_NULLS)
45 | expectedDataDf = sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedDataForLinkUDTs),
46 | TestUtils.testSchemaForLinkUDTs)
47 | }
48 |
49 | override protected def afterEach(): Unit = {
50 | sqlContext = null
51 | }
52 |
53 | test("Can load output of OrientDB queries") {
54 | val query = "select linkset, linkmap, linkbag, link from test_link_table"
55 |
56 | val querySchema = StructType(Seq(StructField("linkset", LinkSetType),
57 | StructField("linkmap", LinkMapType),
58 | StructField("linkbag", LinkBagType),
59 | StructField("link", LinkType)))
60 |
61 | {
62 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
63 | "user" -> "root",
64 | "password" -> "root",
65 | "class" -> "test_link_table",
66 | "clusters" -> "test_link_cluster")
67 |
68 | val iSourceRecord = new ODocument()
69 | iSourceRecord.field("id", 1, OType.INTEGER)
70 |
71 | var oDoc1 = new ODocument()
72 | oDoc1.field("int", 1, OType.INTEGER)
73 | var oDoc2 = new ODocument()
74 | oDoc2.field("boolean", true)
75 | var oRid1 = new ORecordId()
76 | oRid1.fromString("#1:1")
77 | var oRid2 = new ORecordId()
78 | oRid2.fromString("#2:2")
79 |
80 | var oRecordLazySet = new ORecordLazySet(iSourceRecord)
81 | oRecordLazySet.add(oDoc1)
82 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
83 | oRecordLazyMap.put("1", oDoc2)
84 | var oRidBag = new ORidBag()
85 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
86 |
87 | val oDoc3 = new ODocument()
88 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET)
89 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP)
90 | oDoc3.field("linkbag", oRidBag, OType.LINKBAG)
91 | oDoc3.field("link", oDoc1, OType.LINK)
92 |
93 | oDoc1 = new ODocument()
94 | oDoc1.field("int", 2, OType.INTEGER)
95 | oDoc2 = new ODocument()
96 | oDoc2.field("boolean", false)
97 | oRid1 = new ORecordId()
98 | oRid1.fromString("#3:3")
99 | oRid2 = new ORecordId()
100 | oRid2.fromString("#4:4")
101 |
102 | oRecordLazySet = new ORecordLazySet(iSourceRecord)
103 | oRecordLazySet.add(oDoc1)
104 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
105 | oRecordLazyMap.put(1, oDoc2)
106 | oRidBag = new ORidBag()
107 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
108 |
109 | val oDoc4 = new ODocument()
110 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET)
111 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP)
112 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG)
113 | oDoc4.field("link", oDoc1, OType.LINK)
114 |
115 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema),
116 | List(oDoc3, oDoc4))
117 |
118 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
119 | .createRelation(sqlContext, params)
120 | sqlContext.baseRelationToDataFrame(relation).collect()
121 | }
122 |
123 | {
124 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
125 | "user" -> "root",
126 | "password" -> "root",
127 | "class" -> "test_link_table",
128 | "query" -> query,
129 | "clusters" -> "test_link_cluster")
130 |
131 | val iSourceRecord = new ODocument()
132 | iSourceRecord.field("id", 1, OType.INTEGER)
133 |
134 | var oDoc1 = new ODocument()
135 | oDoc1.field("int", 1, OType.INTEGER)
136 | var oDoc2 = new ODocument()
137 | oDoc2.field("boolean", true)
138 | var oRid1 = new ORecordId()
139 | oRid1.fromString("#1:1")
140 | var oRid2 = new ORecordId()
141 | oRid2.fromString("#2:2")
142 |
143 | var oRecordLazySet = new ORecordLazySet(iSourceRecord)
144 | oRecordLazySet.add(oDoc1)
145 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
146 | oRecordLazyMap.put("1", oDoc2)
147 | var oRidBag = new ORidBag()
148 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
149 |
150 | val oDoc3 = new ODocument()
151 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET)
152 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP)
153 |
154 | oDoc1 = new ODocument()
155 | oDoc1.field("int", 2, OType.INTEGER)
156 | oDoc2 = new ODocument()
157 | oDoc2.field("boolean", false)
158 | oRid1 = new ORecordId()
159 | oRid1.fromString("#3:3")
160 | oRid2 = new ORecordId()
161 | oRid2.fromString("#4:4")
162 |
163 | oRecordLazySet = new ORecordLazySet(iSourceRecord)
164 | oRecordLazySet.add(oDoc1)
165 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
166 | oRecordLazyMap.put(1, oDoc2)
167 | oRidBag = new ORidBag()
168 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
169 |
170 | val oDoc4 = new ODocument()
171 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET)
172 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP)
173 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG)
174 | oDoc4.field("link", oDoc1, OType.LINK)
175 |
176 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema),
177 | List(oDoc3, oDoc4))
178 |
179 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
180 | .createRelation(sqlContext, params)
181 | sqlContext.baseRelationToDataFrame(relation).collect()
182 | }
183 | }
184 |
185 | test("DefaultSource supports simple column filtering") {
186 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
187 | "user" -> "root",
188 | "password" -> "root",
189 | "class" -> "test_link_table",
190 | "clusters" -> "test_link_cluster")
191 |
192 | val iSourceRecord = new ODocument()
193 | iSourceRecord.field("id", 1, OType.INTEGER)
194 |
195 | var oDoc0 = new ODocument(new ORecordId("#1:1"))
196 | oDoc0.field("byte", 1.toByte, OType.BYTE)
197 | var oDoc1 = new ODocument(new ORecordId("#2:2"))
198 | oDoc1.field("boolean", true, OType.BOOLEAN)
199 | var oDoc2 = new ODocument(new ORecordId("#4:4"))
200 | oDoc2.field("string", "Hello")
201 | var oRid1 = new ORecordId()
202 | oRid1.fromString("#1:1")
203 | var oRid2 = new ORecordId()
204 | oRid2.fromString("#2:2")
205 |
206 | var oRecordLazyList = new ORecordLazyList(iSourceRecord)
207 | oRecordLazyList.add(oDoc0)
208 | var oRecordLazySet = new ORecordLazySet(iSourceRecord)
209 | oRecordLazySet.add(oDoc1)
210 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
211 | oRecordLazyMap.put("1", oDoc2)
212 | var oRidBag = new ORidBag()
213 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
214 |
215 | val oDoc3 = new ODocument()
216 | oDoc3.field("linklist", oRecordLazyList, OType.LINKLIST)
217 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET)
218 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP)
219 | oDoc3.field("linkbag", oRidBag, OType.LINKBAG)
220 | oDoc3.field("link", oDoc1, OType.LINK)
221 |
222 | val expected1 = Row(LinkList(Array(oDoc0.asInstanceOf[ORecord])), LinkSet(Array(oDoc1.asInstanceOf[ORecord])))
223 |
224 | oDoc0 = new ODocument(new ORecordId("#1:1"))
225 | oDoc0.field("byte", 2.toByte, OType.BYTE)
226 | oDoc1 = new ODocument(new ORecordId("#2:2"))
227 | oDoc1.field("boolean", false, OType.BOOLEAN)
228 | oDoc2 = new ODocument(new ORecordId("#4:4"))
229 | oDoc2.field("string", "World")
230 | oRid1 = new ORecordId()
231 | oRid1.fromString("#3:3")
232 | oRid2 = new ORecordId()
233 | oRid2.fromString("#4:4")
234 |
235 | oRecordLazyList = new ORecordLazyList(iSourceRecord)
236 | oRecordLazyList.add(oDoc0)
237 | oRecordLazySet = new ORecordLazySet(iSourceRecord)
238 | oRecordLazySet.add(oDoc1)
239 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
240 | oRecordLazyMap.put(1, oDoc2)
241 | oRidBag = new ORidBag()
242 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
243 |
244 | val oDoc4 = new ODocument()
245 | oDoc4.field("linklist", oRecordLazyList, OType.LINKLIST)
246 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET)
247 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP)
248 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG)
249 | oDoc4.field("link", oDoc1, OType.LINK)
250 |
251 | val expected2 = Row(LinkList(Array(oDoc0.asInstanceOf[ORecord])), LinkSet(Array(oDoc1.asInstanceOf[ORecord])))
252 |
253 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForLinkUDTs),
254 | List(oDoc3, oDoc4))
255 |
256 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
257 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForLinkUDTs)
258 |
259 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
260 | .buildScan(Array("linklist", "linkset"), Array.empty[Filter])
261 |
262 | val prunedExpectedValues = Array(expected1, expected2)
263 |
264 | val result = rdd.collect()
265 | assert(result.length === prunedExpectedValues.length)
266 | assert(result === prunedExpectedValues)
267 | }
268 |
269 | test("DefaultSource supports user schema, pruned and filtered scans") {
270 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
271 | "user" -> "root",
272 | "password" -> "root",
273 | "class" -> "test_link_table",
274 | "clusters" -> "test_link_cluster")
275 |
276 | val iSourceRecord = new ODocument()
277 | iSourceRecord.field("id", 1, OType.INTEGER)
278 |
279 | var oDoc0 = new ODocument(new ORecordId("#1:1"))
280 | oDoc0.field("byte", 1.toByte, OType.BYTE)
281 | var oDoc1 = new ODocument(new ORecordId("#2:2"))
282 | oDoc1.field("boolean", true, OType.BOOLEAN)
283 | var oDoc2 = new ODocument(new ORecordId("#4:4"))
284 | oDoc2.field("string", "Hello")
285 | var oRid1 = new ORecordId()
286 | oRid1.fromString("#1:1")
287 | var oRid2 = new ORecordId()
288 | oRid2.fromString("#2:2")
289 |
290 | var oRecordLazyList = new ORecordLazyList(iSourceRecord)
291 | oRecordLazyList.add(oDoc0)
292 | var oRecordLazySet = new ORecordLazySet(iSourceRecord)
293 | oRecordLazySet.add(oDoc1)
294 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
295 | oRecordLazyMap.put("1", oDoc2)
296 | var oRidBag = new ORidBag()
297 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
298 |
299 | val oDoc3 = new ODocument()
300 | oDoc3.field("linklist", oRecordLazyList, OType.LINKLIST)
301 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET)
302 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP)
303 | oDoc3.field("linkbag", oRidBag, OType.LINKBAG)
304 | oDoc3.field("link", oDoc1, OType.LINK)
305 |
306 | oDoc0 = new ODocument(new ORecordId("#1:1"))
307 | oDoc0.field("byte", 2.toByte, OType.BYTE)
308 | oDoc1 = new ODocument(new ORecordId("#2:2"))
309 | oDoc1.field("boolean", false, OType.BOOLEAN)
310 | oDoc2 = new ODocument(new ORecordId("#4:4"))
311 | oDoc2.field("string", "World")
312 | oRid1 = new ORecordId()
313 | oRid1.fromString("#1:1")
314 | oRid2 = new ORecordId()
315 | oRid2.fromString("#2:2")
316 |
317 | oRecordLazyList = new ORecordLazyList(iSourceRecord)
318 | oRecordLazyList.add(oDoc0)
319 | oRecordLazySet = new ORecordLazySet(iSourceRecord)
320 | oRecordLazySet.add(oDoc1)
321 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord)
322 | oRecordLazyMap.put(1, oDoc2)
323 | oRidBag = new ORidBag()
324 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2))
325 |
326 | val oDoc4 = new ODocument()
327 | oDoc4.field("linklist", oRecordLazyList, OType.LINKLIST)
328 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET)
329 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP)
330 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG)
331 | oDoc4.field("link", oDoc1, OType.LINK)
332 |
333 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForLinkUDTs),
334 | List(oDoc3, oDoc4))
335 |
336 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
337 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForLinkUDTs)
338 |
339 | val filters: Array[Filter] = Array(
340 | EqualTo("linklist", oRecordLazyList),
341 | EqualTo("linkset", oRecordLazySet)
342 | )
343 |
344 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
345 | .buildScan(Array("linklist", "linkset", "linkbag", "link"), filters)
346 |
347 | assert(rdd.collect().contains(Row(LinkList(Array(oDoc0)), LinkSet(Array(oDoc1)), LinkBag(Array(oRid1, oRid2)), Link(oDoc1))))
348 | }
349 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/OrientDBSourceSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import com.orientechnologies.orient.core.metadata.schema.OType
4 | import com.orientechnologies.orient.core.record.impl.ODocument
5 | import org.apache.spark.{SparkConf, SparkContext}
6 | import org.apache.spark.orientdb.{QueryTest, TestUtils}
7 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
8 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan}
9 | import org.apache.spark.sql.types.{BooleanType, ByteType, StructField, StructType}
10 | import org.mockito.Mockito
11 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
12 |
13 | class OrientDBSourceSuite extends QueryTest
14 | with BeforeAndAfterAll
15 | with BeforeAndAfterEach {
16 | private var sc: SparkContext = _
17 | private var spark: SparkSession = _
18 | private var sqlContext: SQLContext = _
19 | private var mockOrientDBClient: OrientDBClientFactory = _
20 | private var expectedDataDf: DataFrame = _
21 |
22 | override def beforeAll(): Unit = {
23 |     spark = SparkSession.builder().appName("OrientDBSourceSuite")
24 | .master("local[*]")
25 | .getOrCreate()
26 |     sc = spark.sparkContext
27 | }
28 |
29 | override def afterAll(): Unit = {
30 | if (spark != null)
31 | spark.close()
32 | }
33 |
34 | override def beforeEach(): Unit = {
35 | sqlContext = spark.sqlContext
36 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory],
37 | Mockito.RETURNS_SMART_NULLS)
38 | expectedDataDf = sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData),
39 | TestUtils.testSchema)
40 | }
41 |
42 | override def afterEach(): Unit = {
43 | sqlContext = null
44 | }
45 |
46 | test("Can load output of OrientDB queries") {
47 | val query =
48 | """select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'"""
49 |
50 | val querySchema = StructType(Seq(StructField("testbyte", ByteType),
51 | StructField("testbool", BooleanType)))
52 |
53 | {
54 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
55 | "user" -> "root",
56 | "password" -> "root",
57 | "class" -> "test_table",
58 | "clusters" -> "test_cluster")
59 |
60 | val oDoc1 = new ODocument()
61 | oDoc1.field("testbyte", 1, OType.BYTE)
62 | oDoc1.field("testbool", true, OType.BOOLEAN)
63 |
64 | val oDoc2 = new ODocument()
65 | oDoc2.field("testbyte", 2, OType.BYTE)
66 | oDoc2.field("testbool", false, OType.BOOLEAN)
67 |
68 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema),
69 | List(oDoc1, oDoc2))
70 |
71 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
72 | .createRelation(sqlContext, params)
73 | sqlContext.baseRelationToDataFrame(relation).collect()
74 | }
75 |
76 | {
77 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
78 | "user" -> "root",
79 | "password" -> "root",
80 | "class" -> "test_table",
81 | "query" -> query,
82 | "clusters" -> "test_cluster")
83 |
84 | val oDoc1 = new ODocument()
85 | oDoc1.field("testbyte", 1, OType.BYTE)
86 | oDoc1.field("testbool", true, OType.BOOLEAN)
87 |
88 | val oDoc2 = new ODocument()
89 | oDoc2.field("testbyte", 2, OType.BYTE)
90 | oDoc2.field("testbool", false, OType.BOOLEAN)
91 |
92 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema),
93 | List(oDoc1, oDoc2))
94 |
95 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
96 | .createRelation(sqlContext, params)
97 | sqlContext.baseRelationToDataFrame(relation).collect()
98 | }
99 | }
100 |
101 | test("DefaultSource supports simple column filtering") {
102 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
103 | "user" -> "root",
104 | "password" -> "root",
105 | "class" -> "test_table",
106 | "clusters" -> "test_cluster")
107 |
108 | val oDoc1 = new ODocument()
109 | oDoc1.field("testbyte", 1, OType.BYTE)
110 | oDoc1.field("testbool", true, OType.BOOLEAN)
111 | oDoc1.field("teststring", "Hello", OType.STRING)
112 |
113 | val oDoc2 = new ODocument()
114 | oDoc2.field("testbyte", 2, OType.BYTE)
115 | oDoc2.field("testbool", false, OType.BOOLEAN)
116 | oDoc2.field("teststring", "World", OType.STRING)
117 |
118 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchema),
119 | List(oDoc1, oDoc2))
120 |
121 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
122 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchema)
123 |
124 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
125 | .buildScan(Array("testbyte", "testbool"), Array.empty[Filter])
126 |
127 | val prunedExpectedValues = Array(
128 | Row(1.toByte, true),
129 | Row(2.toByte, false)
130 | )
131 | assert(rdd.collect() === prunedExpectedValues)
132 | }
133 |
134 | test("DefaultSource supports user schema, pruned and filtered scans") {
135 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
136 | "user" -> "root",
137 | "password" -> "root",
138 | "class" -> "test_table",
139 | "clusters" -> "test_cluster")
140 |
141 | val oDoc1 = new ODocument()
142 | oDoc1.field("testbyte", 1, OType.BYTE)
143 | oDoc1.field("testbool", true, OType.BOOLEAN)
144 | oDoc1.field("teststring", "Hello", OType.STRING)
145 |
146 | val oDoc2 = new ODocument()
147 | oDoc2.field("testbyte", 2, OType.BYTE)
148 | oDoc2.field("testbool", false, OType.BOOLEAN)
149 | oDoc2.field("teststring", "World", OType.STRING)
150 |
151 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchema),
152 | List(oDoc1, oDoc2))
153 |
154 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient)
155 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchema)
156 |
157 | val filters: Array[Filter] = Array(
158 | EqualTo("testbool", true),
159 | EqualTo("teststring", "Hello")
160 | )
161 |
162 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
163 | .buildScan(Array("testbyte", "testbool"), filters)
164 |
165 | assert(rdd.collect().contains(Row(1, true)))
166 | }
167 |
168 | test("Cannot save when 'query' parameter is specified instead of 'class'") {
169 | val invalidParams = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
170 | "user" -> "root",
171 | "password" -> "root",
172 | "query" -> "select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'",
173 | "clusters" -> "test_cluster"
174 | )
175 |
176 | intercept[IllegalArgumentException] {
177 | expectedDataDf.write.format("org.apache.spark.orientdb.documents").options(invalidParams).save()
178 | }
179 | }
180 |
181 | test("DefaultSource has default constructor, required by Data Source API") {
182 | new DefaultSource()
183 | }
184 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/ParametersSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import org.scalatest.{FunSuite, Matchers}
4 |
5 | class ParametersSuite extends FunSuite with Matchers {
6 |
7 | test("Minimal valid parameter map is accepted") {
8 | val params = Map(
9 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
10 | "user" -> "root",
11 | "password" -> "root",
12 | "class" -> "test_class"
13 | )
14 |
15 | val mergedParams = Parameters.mergeParameters(params)
16 |
17 |     mergedParams.dbUrl.get shouldBe params("dburl")
18 |     mergedParams.credentials.get._1 shouldBe params("user")
19 |     mergedParams.credentials.get._2 shouldBe params("password")
20 |     mergedParams.className.get shouldBe params("class")
21 |
22 |     Parameters.DEFAULT_PARAMETERS.foreach {
23 | case (key, value) => mergedParams.parameters(key) shouldBe value
24 | }
25 | }
26 |
27 | test("Errors are thrown when mandatory parameters are not provided") {
28 | def checkMerge(params: Map[String, String]): Unit = {
29 | intercept[IllegalArgumentException] {
30 | Parameters.mergeParameters(params)
31 | }
32 | }
33 |
34 | val testURL = "remote:127.0.0.1:2424/GratefulDeadConcerts"
35 | checkMerge(Map("dburl" -> testURL, "class" -> "test_class"))
36 | checkMerge(Map("dburl" -> testURL, "user" -> "root", "password" -> "root"))
37 | checkMerge(Map("user" -> "root", "password" -> "root"))
38 | }
39 |
40 |   test("Must specify either the 'class' parameter, or both 'class' and 'query' parameters, " +
41 |     "or the 'query' parameter together with a user-defined schema") {
42 | intercept[IllegalArgumentException] {
43 | Parameters.mergeParameters(Map(
44 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
45 | "user" -> "root",
46 | "password" -> "root"
47 | ))
48 | }
49 |
50 | Parameters.mergeParameters(Map(
51 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
52 | "user" -> "root",
53 | "password" -> "root",
54 | "class" -> "test_class"
55 | ))
56 |
57 | Parameters.mergeParameters(Map(
58 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
59 | "user" -> "root",
60 | "password" -> "root",
61 | "class" -> "test_class",
62 | "query" -> "select * from test_class"
63 | ))
64 | }
65 |
66 |   test("Must specify 'dburl', 'user' and 'password' parameters") {
67 | intercept[IllegalArgumentException] {
68 | Parameters.mergeParameters(Map(
69 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
70 | "class" -> "test_class"
71 | ))
72 | }
73 |
74 | intercept[IllegalArgumentException] {
75 | Parameters.mergeParameters(Map(
76 | "user" -> "root",
77 | "password" -> "root",
78 | "class" -> "test_class"
79 | ))
80 | }
81 |
82 | Parameters.mergeParameters(Map(
83 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
84 | "user" -> "root",
85 | "password" -> "root",
86 | "class" -> "test_class"
87 | ))
88 | }
89 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/documents/TableNameSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.documents
2 |
3 | import org.scalatest.FunSuite
4 |
5 | class TableNameSuite extends FunSuite {
6 |
7 |   test("unescaped table name") {
8 | val tableName = new TableName("test_table")
9 | assert(tableName.unescapedTableName === "test_table")
10 | }
11 |
12 |   test("table name to string") {
13 | val tableName = new TableName("test_table")
14 | assert(tableName.toString === "test_table")
15 | }
16 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/graphs/MockEdge.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import java.util
4 |
5 | import com.orientechnologies.orient.core.db.record.OIdentifiable
6 | import com.tinkerpop.blueprints.impls.orient.{OrientBaseGraph, OrientEdge}
7 |
8 | class MockEdge(graph: OrientBaseGraph, record: OIdentifiable)
9 | extends OrientEdge(graph, record) {
10 |
11 | override def setProperty(key: String, value: Object): Unit = {
12 | this.getRecord().field(key, value)
13 | }
14 |
15 | override def getPropertyKeys: util.Set[String] = {
16 | val fieldNames = this.getRecord().fieldNames()
17 | val length = fieldNames.length
18 |
19 | var count = 0
20 | val result = new util.HashSet[String]()
21 | while (count < length) {
22 | result.add(fieldNames(count))
23 | count = count + 1
24 | }
25 | result
26 | }
27 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/graphs/MockOrientDBGraph.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import com.tinkerpop.blueprints.{Edge, Vertex}
4 | import com.tinkerpop.blueprints.impls.orient.OrientGraphNoTx
5 | import org.apache.spark.orientdb.graphs.Parameters.MergedParameters
6 | import org.apache.spark.sql.types.StructType
7 | import org.mockito.Matchers._
8 | import org.mockito.Mockito._
9 | import org.mockito.invocation.InvocationOnMock
10 | import org.mockito.stubbing.Answer
11 |
12 | class MockOrientDBGraph(existingTablesAndSchemas: Map[String, StructType],
13 | oVertices: List[Vertex] = null,
14 | oEdges: List[Edge] = null) {
15 | val vertexWrapper: OrientDBGraphVertexWrapper = spy(new OrientDBGraphVertexWrapper())
16 | val edgeWrapper: OrientDBGraphEdgeWrapper = spy(new OrientDBGraphEdgeWrapper())
17 |
18 | doAnswer(new Answer[OrientGraphNoTx] {
19 | override def answer(invocationOnMock: InvocationOnMock): OrientGraphNoTx = {
20 | mock(classOf[OrientGraphNoTx], RETURNS_SMART_NULLS)
21 | }
22 | }).when(vertexWrapper).getConnection(any(classOf[MergedParameters]))
23 |
24 | doAnswer(new Answer[OrientGraphNoTx] {
25 | override def answer(invocationOnMock: InvocationOnMock): OrientGraphNoTx = {
26 | mock(classOf[OrientGraphNoTx], RETURNS_SMART_NULLS)
27 | }
28 | }).when(edgeWrapper).getConnection(any(classOf[MergedParameters]))
29 |
30 | doAnswer(new Answer[Boolean] {
31 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
32 |         existingTablesAndSchemas.contains(invocationOnMock.getArguments()(0).asInstanceOf[String])
33 | }
34 | }).when(vertexWrapper).doesVertexTypeExists(any(classOf[String]))
35 |
36 | doAnswer(new Answer[Boolean] {
37 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
38 |         existingTablesAndSchemas.contains(invocationOnMock.getArguments()(0).asInstanceOf[String])
39 | }
40 | }).when(edgeWrapper).doesEdgeTypeExists(any(classOf[String]))
41 |
42 | doAnswer(new Answer[Boolean] {
43 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
44 | true
45 | }
46 | }).when(vertexWrapper).create(any(classOf[String]), any(classOf[Map[String, Object]]))
47 |
48 | doAnswer(new Answer[Boolean] {
49 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
50 | true
51 | }
52 | }).when(edgeWrapper).create(any(classOf[String]), any(classOf[Vertex]),
53 | any(classOf[Vertex]), any(classOf[Map[String, Object]]))
54 |
55 | doAnswer(new Answer[List[Vertex]] {
56 | override def answer(invocationOnMock: InvocationOnMock): List[Vertex] = {
57 | oVertices
58 | }
59 | }).when(vertexWrapper).read(any(classOf[String]), any(classOf[Array[String]]),
60 | any(classOf[String]), any(classOf[String]))
61 |
62 | doAnswer(new Answer[List[Edge]] {
63 | override def answer(invocationOnMock: InvocationOnMock): List[Edge] = {
64 | oEdges
65 | }
66 | }).when(edgeWrapper).read(any(classOf[String]), any(classOf[Array[String]]),
67 | any(classOf[String]), any(classOf[String]))
68 |
69 | doAnswer(new Answer[Boolean] {
70 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
71 | true
72 | }
73 | }).when(vertexWrapper).delete(any(classOf[String]),
74 | any(classOf[Map[String, Tuple2[String, String]]]))
75 |
76 | doAnswer(new Answer[Boolean] {
77 | override def answer(invocationOnMock: InvocationOnMock): Boolean = {
78 | true
79 | }
80 | }).when(edgeWrapper).delete(any(classOf[String]),
81 | any(classOf[Map[String, Tuple2[String, String]]]))
82 |
83 | doAnswer(new Answer[StructType] {
84 | override def answer(invocationOnMock: InvocationOnMock): StructType = {
85 | existingTablesAndSchemas
86 | .get(invocationOnMock.getArguments()(0).asInstanceOf[String]).get
87 | }
88 | }).when(vertexWrapper).resolveTable(any(classOf[String]))
89 |
90 | doAnswer(new Answer[StructType] {
91 | override def answer(invocationOnMock: InvocationOnMock): StructType = {
92 | existingTablesAndSchemas
93 | .get(invocationOnMock.getArguments()(0).asInstanceOf[String]).get
94 | }
95 | }).when(edgeWrapper).resolveTable(any(classOf[String]))
96 |
97 | doAnswer(new Answer[List[Vertex]] {
98 | override def answer(invocationOnMock: InvocationOnMock): List[Vertex] = {
99 | oVertices
100 | }
101 | }).when(vertexWrapper).genericQuery(any(classOf[String]))
102 |
103 | doAnswer(new Answer[List[Edge]] {
104 | override def answer(invocationOnMock: InvocationOnMock): List[Edge] = {
105 | oEdges
106 | }
107 | }).when(edgeWrapper).genericQuery(any(classOf[String]))
108 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/graphs/MockVertex.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import com.orientechnologies.orient.core.db.record.OIdentifiable
4 | import com.tinkerpop.blueprints.impls.orient.{OrientBaseGraph, OrientVertex}
5 |
6 | class MockVertex(graph: OrientBaseGraph, record: OIdentifiable)
7 | extends OrientVertex(graph, record) {
8 | override def setProperty(key: String, value: Object): Unit = {
9 | this.getRecord().field(key, value)
10 | }
11 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/graphs/OrientDBGraphSourceSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import com.orientechnologies.orient.core.record.impl.ODocument
4 | import com.tinkerpop.blueprints.impls.orient.OrientBaseGraph
5 | import org.apache.spark.{SparkConf, SparkContext}
6 | import org.apache.spark.orientdb.{QueryTest, TestUtils}
7 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan}
8 | import org.apache.spark.sql.types._
9 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
10 | import org.mockito.Mockito
11 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
12 |
13 | class OrientDBGraphSourceSuite extends QueryTest
14 | with BeforeAndAfterAll
15 | with BeforeAndAfterEach {
16 | private var sc: SparkContext = _
17 | private var spark: SparkSession = _
18 | private var sqlContext: SQLContext = _
19 | private var mockOrientDBClient: OrientDBClientFactory = _
20 | private var expectedDataDfVertices: DataFrame = _
21 | private var expectedDataDfEdges: DataFrame = _
22 |
23 | override def beforeAll(): Unit = {
24 |     spark = SparkSession.builder().appName("OrientDBGraphSourceSuite")
25 | .master("local[*]")
26 | .getOrCreate()
27 |     sc = spark.sparkContext
28 | }
29 |
30 | override def afterAll(): Unit = {
31 | if (spark != null) {
32 | spark.close()
33 | }
34 | }
35 |
36 | override def beforeEach(): Unit = {
37 | sqlContext = spark.sqlContext
38 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory],
39 | Mockito.RETURNS_SMART_NULLS)
40 | expectedDataDfVertices = sqlContext.createDataFrame(
41 | sc.parallelize(TestUtils.expectedDataForVertices),
42 | TestUtils.testSchemaForVertices)
43 | expectedDataDfEdges = sqlContext.createDataFrame(
44 | sc.parallelize(TestUtils.expectedDataForEdges),
45 | TestUtils.testSchemaForEdges)
46 | }
47 |
48 | override def afterEach(): Unit = {
49 | sqlContext = null
50 | }
51 |
52 | test("Can load output of OrientDB Graph queries on vertices") {
53 | val query =
54 | "select testbyte, testbool from test_vertex where teststring = '\\Unicode''s樂趣'"
55 |
56 | val querySchema = StructType(Seq(StructField("testbyte", ByteType, true),
57 | StructField("testbool", BooleanType, true)))
58 |
59 | {
60 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
61 | "user" -> "root",
62 | "password" -> "root",
63 | "vertextype" -> "test_vertex")
64 |
65 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
66 | new ODocument())
67 | oVertex1.setProperty("id", new Integer(1))
68 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte))
69 | oVertex1.setProperty("testbool", new java.lang.Boolean(true))
70 |
71 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
72 | new ODocument())
73 | oVertex2.setProperty("id", new Integer(2))
74 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte))
75 | oVertex2.setProperty("testbool", new java.lang.Boolean(false))
76 |
77 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> querySchema),
78 | List(oVertex1, oVertex2))
79 |
80 | val relation = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient)
81 | .createRelation(sqlContext, params)
82 | sqlContext.baseRelationToDataFrame(relation).collect()
83 | }
84 |
85 | {
86 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
87 | "user" -> "root",
88 | "password" -> "root",
89 | "vertextype" -> "test_vertex",
90 | "query" -> query)
91 |
92 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
93 | new ODocument())
94 | oVertex1.setProperty("id", new Integer(1))
95 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte))
96 | oVertex1.setProperty("testbool", new java.lang.Boolean(true))
97 |
98 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
99 | new ODocument())
100 | oVertex2.setProperty("id", new Integer(2))
101 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte))
102 | oVertex2.setProperty("testbool", new java.lang.Boolean(false))
103 |
104 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> querySchema),
105 | List(oVertex1, oVertex2))
106 |
107 | val relation = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient)
108 | .createRelation(sqlContext, params)
109 | sqlContext.baseRelationToDataFrame(relation).collect()
110 | }
111 | }
112 |
113 | test("Can load output of OrientDB Graph queries on edges") {
114 | val query =
115 | "select relationship from test_edge where relationship = 'enemy'"
116 |
117 | val querySchema = StructType(Seq(StructField("relationship", StringType)))
118 |
119 | {
120 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
121 | "user" -> "root",
122 | "password" -> "root",
123 | "edgetype" -> "test_edge")
124 |
125 | val oEdge1 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
126 | oEdge1.setProperty("src", new Integer(1))
127 | oEdge1.setProperty("dst", new Integer(2))
128 | oEdge1.setProperty("relationship", new String("enemy"))
129 |
130 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
131 | oEdge2.setProperty("src", new Integer(1))
132 | oEdge2.setProperty("dst", new Integer(2))
133 | oEdge2.setProperty("relationship", new String("friend"))
134 |
135 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> querySchema), null,
136 | List(oEdge1, oEdge2))
137 |
138 | val relation = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient)
139 | .createRelation(sqlContext, params)
140 | sqlContext.baseRelationToDataFrame(relation).collect()
141 | }
142 |
143 | {
144 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
145 | "user" -> "root",
146 | "password" -> "root",
147 | "edgetype" -> "test_edge",
148 | "query" -> query)
149 |
150 | val oEdge1 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
151 | oEdge1.setProperty("src", new Integer(1))
152 | oEdge1.setProperty("dst", new Integer(2))
153 | oEdge1.setProperty("relationship", new String("enemy"))
154 |
155 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
156 | oEdge2.setProperty("src", new Integer(1))
157 | oEdge2.setProperty("dst", new Integer(2))
158 | oEdge2.setProperty("relationship", new String("friend"))
159 |
160 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> querySchema), null,
161 | List(oEdge1, oEdge2))
162 |
163 | val relation = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient)
164 | .createRelation(sqlContext, params)
165 | sqlContext.baseRelationToDataFrame(relation).collect()
166 | }
167 | }
168 |
169 | test("DefaultSource supports simple column filtering for Vertices") {
170 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
171 | "user" -> "root",
172 | "password" -> "root",
173 | "vertextype" -> "test_vertex")
174 |
175 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
176 | new ODocument())
177 | oVertex1.setProperty("id", new Integer(1))
178 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte))
179 | oVertex1.setProperty("testbool", new java.lang.Boolean(true))
180 |
181 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
182 | new ODocument())
183 | oVertex2.setProperty("id", new Integer(2))
184 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte))
185 | oVertex2.setProperty("testbool", new java.lang.Boolean(false))
186 |
187 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> TestUtils.testSchemaForVertices),
188 | List(oVertex1, oVertex2))
189 |
190 | val source = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient)
191 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForVertices)
192 |
193 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
194 | .buildScan(Array("testbyte", "testbool"), Array.empty[Filter])
195 |
196 | val prunedExpectedValues = Array(
197 | Row(1.toByte, true),
198 | Row(2.toByte, false)
199 | )
200 | assert(rdd.collect() === prunedExpectedValues)
201 | }
202 |
203 | test("DefaultSource supports simple column filtering for Edges") {
204 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
205 | "user" -> "root",
206 | "password" -> "root",
207 | "edgetype" -> "test_edge")
208 |
209 | val oEdge1 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
210 | oEdge1.setProperty("src", new Integer(1))
211 | oEdge1.setProperty("dst", new Integer(2))
212 | oEdge1.setProperty("relationship", new String("enemy"))
213 |
214 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
215 | oEdge2.setProperty("src", new Integer(1))
216 | oEdge2.setProperty("dst", new Integer(2))
217 | oEdge2.setProperty("relationship", new String("friend"))
218 |
219 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> TestUtils.testSchemaForEdges),
220 | null, List(oEdge1, oEdge2))
221 |
222 | val source = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient)
223 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEdges)
224 |
225 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
226 | .buildScan(Array("relationship"), Array.empty[Filter])
227 |
228 | val prunedExpectedValues = Array(
229 | Row("enemy"), Row("friend")
230 | )
231 | assert(rdd.collect() === prunedExpectedValues)
232 | }
233 |
234 | test("DefaultSource supports user schema, pruned and filtered scans for Vertices") {
235 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
236 | "user" -> "root",
237 | "password" -> "root",
238 | "vertextype" -> "test_vertex")
239 |
240 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
241 | new ODocument())
242 | oVertex1.setProperty("id", new Integer(1))
243 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte))
244 | oVertex1.setProperty("testbool", new java.lang.Boolean(true))
245 |
246 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]),
247 | new ODocument())
248 | oVertex2.setProperty("id", new Integer(2))
249 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte))
250 | oVertex2.setProperty("testbool", new java.lang.Boolean(false))
251 |
252 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> TestUtils.testSchemaForVertices),
253 | List(oVertex1, oVertex2))
254 |
255 | val source = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient)
256 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForVertices)
257 |
258 | val filters: Array[Filter] = Array(
259 | EqualTo("testbool", true),
260 | EqualTo("teststring", "Hello")
261 | )
262 |
263 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
264 | .buildScan(Array("testbyte", "testbool"), filters)
265 |
266 | assert(rdd.collect().contains(Row(1, true)))
267 | }
268 |
269 | test("DefaultSource supports user schema, pruned and filtered scans for Edges") {
270 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
271 | "user" -> "root",
272 | "password" -> "root",
273 | "edgetype" -> "test_edge")
274 |
275 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument())
276 | oEdge2.setProperty("src", new Integer(1))
277 | oEdge2.setProperty("dst", new Integer(2))
278 | oEdge2.setProperty("relationship", new String("friend"))
279 |
280 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> TestUtils.testSchemaForEdges),
281 | null, List(oEdge2))
282 |
283 | val source = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient)
284 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEdges)
285 |
286 | val filters: Array[Filter] = Array(
287 | EqualTo("relationship", "friend")
288 | )
289 |
290 | val rdd = relation.asInstanceOf[PrunedFilteredScan]
291 | .buildScan(Array("relationship"), filters)
292 |
293 | val prunedExpectedValues = Array(
294 | Row("friend")
295 | )
296 | assert(rdd.collect() === prunedExpectedValues)
297 | }
298 |
299 | test("Cannot save when 'query' parameter is specified instead of 'vertextype' for Vertices") {
300 | val invalidParams = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
301 | "user" -> "root",
302 | "password" -> "root",
303 | "query" -> "select testbyte, testbool from test_vertex where teststring = '\\Unicode''s樂趣'")
304 |
305 | intercept[IllegalArgumentException] {
306 | expectedDataDfVertices.write.format("org.apache.spark.orientdb.graphs")
307 | .options(invalidParams).save()
308 | }
309 | }
310 |
311 | test("Cannot save when 'query' parameter is specified instead of 'edgetype' for Edges") {
312 | val invalidParams = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
313 | "user" -> "root",
314 | "password" -> "root",
315 | "query" -> "select relationship from test_edge where relationship = 'enemy'")
316 |
317 | intercept[IllegalArgumentException] {
318 | expectedDataDfEdges.write.format("org.apache.spark.orientdb.graphs")
319 | .options(invalidParams).save()
320 | }
321 | }
322 |
323 | test("DefaultSource has default constructor, required by Data Source API") {
324 | new DefaultSource()
325 | }
326 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/orientdb/graphs/ParametersSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.orientdb.graphs
2 |
3 | import org.scalatest.{FunSuite, Matchers}
4 |
5 | class ParametersSuite extends FunSuite with Matchers {
6 |
7 | test("Minimal valid parameter map is accepted for vertices") {
8 | val params = Map(
9 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
10 | "user" -> "root",
11 | "password" -> "root",
12 | "vertextype" -> "test_vertex"
13 | )
14 |
15 | val mergedParams = Parameters.mergeParameters(params)
16 |
17 | mergedParams.dbUrl.get shouldBe params("dburl")
18 | mergedParams.credentials.get._1 shouldBe params("user")
19 | mergedParams.credentials.get._2 shouldBe params("password")
20 | mergedParams.vertexType.get shouldBe params("vertextype")
21 |
22 | Parameters.DEFAULT_PARAMETERS.foreach {
23 | case (key, value) => mergedParams.parameters(key) shouldBe value
24 | }
25 | }
26 |
27 | test("Minimal valid parameter map is accepted for edges") {
28 | val params = Map(
29 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
30 | "user" -> "root",
31 | "password" -> "root",
32 | "edgetype" -> "test_edge"
33 | )
34 |
35 | val mergedParams = Parameters.mergeParameters(params)
36 |
37 | mergedParams.dbUrl.get shouldBe params("dburl")
38 | mergedParams.credentials.get._1 shouldBe params("user")
39 | mergedParams.credentials.get._2 shouldBe params("password")
40 | mergedParams.edgeType.get shouldBe params("edgetype")
41 |
42 | Parameters.DEFAULT_PARAMETERS.foreach {
43 | case (key, value) => mergedParams.parameters(key) shouldBe value
44 | }
45 | }
46 |
47 | test("Errors are thrown when mandatory parameters are not provided") {
48 | def checkMerge(params: Map[String, String]): Unit = {
49 | intercept[IllegalArgumentException] {
50 | Parameters.mergeParameters(params)
51 | }
52 | }
53 |
54 | val testURL = "remote:127.0.0.1:2424/GratefulDeadConcerts"
55 | checkMerge(Map("dburl" -> testURL, "user" -> "root", "password" -> "root"))
56 | checkMerge(Map("user" -> "root", "password" -> "root"))
57 | checkMerge(Map("dburl" -> testURL, "vertextype" -> "test_vertex"))
58 | checkMerge(Map("user" -> "root", "password" -> "root", "edgetype" -> "test_edge"))
59 | }
60 |
61 |   test("Must specify either the 'vertextype' parameter, " +
62 |     "or both 'vertextype' and 'query' parameters, " +
63 |     "or 'vertextype' and 'query' together with a user-defined schema") {
64 | intercept[IllegalArgumentException] {
65 | Parameters.mergeParameters(Map(
66 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
67 | "user" -> "root",
68 | "password" -> "root"
69 | ))
70 | }
71 |
72 | Parameters.mergeParameters(Map(
73 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
74 | "user" -> "root",
75 | "password" -> "root",
76 | "vertextype" -> "test_vertex"
77 | ))
78 |
79 | Parameters.mergeParameters(Map(
80 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
81 | "user" -> "root",
82 | "password" -> "root",
83 | "vertextype" -> "test_vertex",
84 | "query" -> "select * from test_vertex"
85 | ))
86 | }
87 |
88 |   test("Must specify either the 'edgetype' parameter, " +
89 |     "or both 'edgetype' and 'query' parameters, " +
90 |     "or 'edgetype' and 'query' together with a user-defined schema") {
91 | intercept[IllegalArgumentException] {
92 | Parameters.mergeParameters(Map(
93 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
94 | "user" -> "root",
95 | "password" -> "root"
96 | ))
97 | }
98 |
99 | Parameters.mergeParameters(Map(
100 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
101 | "user" -> "root",
102 | "password" -> "root",
103 | "edgetype" -> "test_edge"
104 | ))
105 |
106 | Parameters.mergeParameters(Map(
107 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
108 | "user" -> "root",
109 | "password" -> "root",
110 | "edgetype" -> "test_edge",
111 | "query" -> "select * from test_edge"
112 | ))
113 | }
114 |
115 |   test("Must specify 'dburl', 'user' and 'password' parameters") {
116 | intercept[IllegalArgumentException] {
117 | Parameters.mergeParameters(Map(
118 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts"
119 | ))
120 | }
121 |
122 | intercept[IllegalArgumentException] {
123 | Parameters.mergeParameters(Map(
124 | "user" -> "root",
125 | "password" -> "root"
126 | ))
127 | }
128 |
129 | intercept[IllegalArgumentException] {
130 | Parameters.mergeParameters(Map(
131 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts",
132 | "password" -> "root"
133 | ))
134 | }
135 | }
136 | }
--------------------------------------------------------------------------------