├── .gitignore ├── .travis.yml ├── Documentation.md ├── LICENSE ├── README.md ├── pom.xml └── src ├── it └── scala │ └── org │ └── apache │ └── spark │ └── orientdb │ ├── documents │ ├── IntegrationSuiteBase.scala │ └── OrientDBIntegrationSuite.scala │ └── graphs │ ├── IntegrationSuiteBase.scala │ └── OrientDBGraphIntegrationSuite.scala ├── main ├── examples │ └── org │ │ └── apache │ │ └── spark │ │ └── orientdb │ │ ├── documents │ │ └── DataFrameTest.scala │ │ └── graphs │ │ └── GraphFrameTest.scala └── scala │ └── org │ └── apache │ └── spark │ └── orientdb │ ├── documents │ ├── Conversions.scala │ ├── DefaultSource.scala │ ├── FilterPushdown.scala │ ├── OrientDBClientFactory.scala │ ├── OrientDBCredentials.scala │ ├── OrientDBDocumentWrapper.scala │ ├── OrientDBRelation.scala │ ├── OrientDBWriter.scala │ ├── Parameters.scala │ └── TableName.scala │ ├── graphs │ ├── DefaultSource.scala │ ├── OrientDBClientFactory.scala │ ├── OrientDBCredentials.scala │ ├── OrientDBGraphWrapper.scala │ ├── OrientDBRelation.scala │ ├── OrientDBWriter.scala │ └── Parameters.scala │ └── udts │ ├── EmbeddedListType.scala │ ├── EmbeddedMapType.scala │ ├── EmbeddedSetType.scala │ ├── LinkBagType.scala │ ├── LinkListType.scala │ ├── LinkMapType.scala │ ├── LinkSetType.scala │ └── LinkType.scala └── test └── scala └── org └── apache └── spark └── orientdb ├── QueryTest.scala ├── TestUtils.scala ├── documents ├── ConversionsSuite.scala ├── FilterPushdownSuite.scala ├── MockOrientDBDocument.scala ├── OrientDBEmbeddedUDTsSourceSuite.scala ├── OrientDBLinkUDTsSourceSuite.scala ├── OrientDBSourceSuite.scala ├── ParametersSuite.scala └── TableNameSuite.scala └── graphs ├── MockEdge.scala ├── MockOrientDBGraph.scala ├── MockVertex.scala ├── OrientDBEmbeddedUDTsSourceSuite.scala ├── OrientDBGraphSourceSuite.scala ├── OrientDBLinkUDTsSourceSuite.scala └── ParametersSuite.scala /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | .idea/ 3 | *.iml 4 | *.log -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java -------------------------------------------------------------------------------- /Documentation.md: -------------------------------------------------------------------------------- 1 | Apache Spark datasource for OrientDB 2 | ==================================== 3 | 4 | ## Introduction 5 | 6 | This documentation discusses the topic of how to connect Apache Spark to OrientDB. [Apache Spark](http://spark.apache.org/) is the widely popular engine for large-scale data processing. 7 | Here, we will discuss how to use the [**Apache Spark datasource for OrientDB**](https://github.com/sbcd90/spark-orientdb) to leverage Spark's capabilities while using OrientDB as the datastore. 8 | 9 | ## Installation Guide 10 | 11 | To use the Apache Spark datasource for OrientDB inside a Spark application, the following steps need to be performed. 12 | 13 | - Add the repository location to `pom.xml`. 14 | 15 | ``` 16 | 17 | bintray 18 | bintray 19 | https://dl.bintray.com/sbcd90/org.apache.spark/ 20 | 21 | ``` 22 | 23 | - Add the datasource as a maven dependency in `pom.xml`. 24 | 25 | ``` 26 | 27 | org.apache.spark 28 | spark-orientdb-{spark.version}_2.10 29 | 1.3 30 | 31 | ``` 32 | 33 | 34 | ## Configuration 35 | 36 | The datasource is supported for Apache Spark version 1.6+ & OrientDB 2.2.0+. 
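The Maven snippets in the Installation Guide above lost their XML tags in this rendering. A reconstructed sketch follows; the coordinates and repository URL are taken from those snippets, while the surrounding `<repositories>`/`<dependencies>` wrappers are the usual Maven structure assumed here.

```
<repositories>
  <repository>
    <id>bintray</id>
    <name>bintray</name>
    <url>https://dl.bintray.com/sbcd90/org.apache.spark/</url>
  </repository>
</repositories>

<dependencies>
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-orientdb-{spark.version}_2.10</artifactId>
    <version>1.3</version>
  </dependency>
</dependencies>
```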
37 | 38 | ## API Reference 39 | 40 | The **Apache Spark datasource for OrientDB** allows users to leverage the Apache Spark datasource api s for reading and writing data from OrientDB. The datasource loads a collection of OrientDB documents into a dataframe & an OrientDB graph containing a set of vertices & edges into a Graphframe. 41 | 42 | The complete api reference for using the **Apache Spark datasource for OrientDB** is provided below. 43 | 44 | ### OrientDB Documents 45 | 46 | #### Write api: 47 | 48 | ``` 49 | import org.apache.spark.sql.SQLContext 50 | 51 | val sqlContext = new SQLContext(sc) 52 | sqlContext.createDataFrame(sc.parallelize(Array(1, 2, 3, 4, 5)), 53 | StructType(Seq(StructField("id", IntegerType))) 54 | .write 55 | .format("org.apache.spark.orientdb.documents") 56 | .option("dburl", ORIENTDB_CONNECTION_URL) 57 | .option("user", ORIENTDB_USER).option("password", ORIENTDB_PASSWORD) 58 | .option("class", test_class) 59 | .mode(SaveMode.Overwrite) 60 | .save() 61 | ``` 62 | 63 | #### Read api: 64 | 65 | ``` 66 | import org.apache.spark.sql.SQLContext 67 | 68 | val sqlContext = new SQLContext(sc) 69 | val loadedDf = sqlContext.read 70 | .format("org.apache.spark.orientdb.documents") 71 | .option("dburl", ORIENTDB_CONNECTION_URL) 72 | .option("user", ORIENTDB_USER) 73 | .option("password", ORIENTDB_PASSWORD) 74 | .option("class", test_class) 75 | .option("query", s"select * from $test_table where teststring = 'asdf'") 76 | .load() 77 | ``` 78 | 79 | #### Query using OrientDB SQL: 80 | 81 | ``` 82 | import org.apache.spark.sql.SQLContext 83 | 84 | val sqlContext = new SQLContext(sc) 85 | val loadedDf = sqlContext.read 86 | .format("org.apache.spark.orientdb.documents") 87 | .option("dburl", ORIENTDB_CONNECTION_URL) 88 | .option("user", ORIENTDB_USER) 89 | .option("password", ORIENTDB_PASSWORD) 90 | .option("class", test_class) 91 | .option("query", s"select * from $test_table where teststring = 'asdf'") 92 | .load() 93 | ``` 94 | 95 | ### OrientDB Graphs: 96 | 97 | #### Create Vertex api: 98 | 99 | ``` 100 | import org.apache.spark.sql.SQLContext 101 | 102 | val sqlContext = new SQLContext(sc) 103 | sqlContext.createDataFrame(sc.parallelize(Array(1, 2, 3, 4, 5)), 104 | StructType(Seq(StructField("id", IntegerType))) 105 | .write 106 | .format("org.apache.spark.orientdb.graphs") 107 | .option("dburl", ORIENTDB_CONNECTION_URL) 108 | .option("user", ORIENTDB_USER) 109 | .option("password", ORIENTDB_PASSWORD) 110 | .option("vertextype", test_vertex_type2) 111 | .mode(SaveMode.Overwrite) 112 | .save() 113 | ``` 114 | 115 | #### Create Edge api: 116 | 117 | ``` 118 | import org.apache.spark.sql.SQLContext 119 | 120 | val sqlContext = new SQLContext(sc) 121 | sqlContext.createDataFrame( 122 | sc.parallelize(Seq( 123 | Row(1, 2, "friends"), 124 | Row(2, 3, "enemy"), 125 | Row(3, 4, "friends"), 126 | Row(4, 1, "enemy") 127 | )), 128 | StructType(Seq( 129 | StructField("src", IntegerType), 130 | StructField("dst", IntegerType), 131 | StructField("relationship", StringType) 132 | ))) 133 | .write 134 | .format("org.apache.spark.orientdb.graphs") 135 | .option("dburl", ORIENTDB_CONNECTION_URL) 136 | .option("user", ORIENTDB_USER) 137 | .option("password", ORIENTDB_PASSWORD) 138 | .option("vertextype", test_vertex_type2) 139 | .option("edgetype", test_edge_type2) 140 | .mode(SaveMode.Overwrite) 141 | .save() 142 | ``` 143 | 144 | #### Read Vertex api: 145 | 146 | ``` 147 | import org.apache.spark.sql.SQLContext 148 | 149 | val sqlContext = new SQLContext(sc) 150 | val loadedDf = 
sqlContext.read 151 | .format("org.apache.spark.orientdb.graphs") 152 | .option("dburl", ORIENTDB_CONNECTION_URL) 153 | .option("user", ORIENTDB_USER) 154 | .option("password", ORIENTDB_PASSWORD) 155 | .option("vertextype", test_vertex_type2) 156 | .load() 157 | ``` 158 | 159 | #### Read edge api: 160 | 161 | ``` 162 | import org.apache.spark.sql.SQLContext 163 | 164 | val sqlContext = new SQLContext(sc) 165 | val loadedDf = sqlContext.read 166 | .format("org.apache.spark.orientdb.graphs") 167 | .option("dburl", ORIENTDB_CONNECTION_URL) 168 | .option("user", ORIENTDB_USER) 169 | .option("password", ORIENTDB_PASSWORD) 170 | .option("edgetype", test_edge_type2) 171 | .load() 172 | ``` 173 | 174 | #### Query using OrientDB Graph SQL: 175 | 176 | ``` 177 | import org.apache.spark.sql.SQLContext 178 | 179 | val sqlContext = new SQLContext(sc) 180 | val loadedVerticesDf = sqlContext.read 181 | .format("org.apache.spark.orientdb.graphs") 182 | .option("dburl", ORIENTDB_CONNECTION_URL) 183 | .option("user", ORIENTDB_USER) 184 | .option("password", ORIENTDB_PASSWORD) 185 | .option("vertextype", test_vertex_type2) 186 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'") 187 | .load() 188 | 189 | val loadedEdgesDf = sqlContext.read 190 | .format("org.apache.spark.orientdb.graphs") 191 | .option("dburl", ORIENTDB_CONNECTION_URL) 192 | .option("user", ORIENTDB_USER) 193 | .option("password", ORIENTDB_PASSWORD) 194 | .option("edgetype", test_edge_type2) 195 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'") 196 | .load() 197 | ``` 198 | 199 | ### Integration with GraphFrames 200 | 201 | ``` 202 | import org.apache.spark.sql.SQLContext 203 | 204 | val sqlContext = new SQLContext(sc) 205 | val loadedVerticesDf = sqlContext.read 206 | .format("org.apache.spark.orientdb.graphs") 207 | .option("dburl", ORIENTDB_CONNECTION_URL) 208 | .option("user", ORIENTDB_USER) 209 | .option("password", ORIENTDB_PASSWORD) 210 | .option("vertextype", test_vertex_type2) 211 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'") 212 | .load() 213 | 214 | val loadedEdgesDf = sqlContext.read 215 | .format("org.apache.spark.orientdb.graphs") 216 | .option("dburl", ORIENTDB_CONNECTION_URL) 217 | .option("user", ORIENTDB_USER) 218 | .option("password", ORIENTDB_PASSWORD) 219 | .option("edgetype", test_edge_type2) 220 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'") 221 | .load() 222 | 223 | val g = GraphFrame(loadedVerticesDf, loadedEdgesDf) 224 | ``` 225 | 226 | ## Examples 227 | 228 | ### An example Spark application with OrientDB Documents 229 | 230 | ``` 231 | package org.apache.spark.orientdb.documents 232 | 233 | import org.apache.spark.sql.{SQLContext, SaveMode} 234 | import org.apache.spark.{SparkConf, SparkContext} 235 | 236 | object DataFrameTest extends App { 237 | val conf = new SparkConf().setAppName("DataFrameTest").setMaster("local[*]") 238 | val sc = new SparkContext(conf) 239 | val sqlContext = new SQLContext(sc) 240 | 241 | import sqlContext.implicits._ 242 | val df = sc.parallelize(Array(1, 2, 3, 4, 5)).toDF("id") 243 | 244 | df.write.format("org.apache.spark.orientdb.documents") 245 | .option("dburl", "") 246 | .option("user", "****") 247 | .option("password", "****") 248 | .option("class", "test_class") 249 | .mode(SaveMode.Overwrite) 250 | .save() 251 | 252 | val resultDf = sqlContext.read 253 | .format("org.apache.spark.orientdb.documents") 254 | .option("dburl", "") 255 | 
.option("user", "****") 256 | .option("password", "****") 257 | .option("class", "test_class") 258 | .load() 259 | 260 | resultDf.show() 261 | } 262 | ``` 263 | 264 | ### An example Spark application with OrientDB Graph 265 | 266 | ``` 267 | package org.apache.spark.orientdb.graphs 268 | 269 | import org.apache.spark.{SparkConf, SparkContext} 270 | import org.apache.spark.sql.{Row, SQLContext, SaveMode} 271 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 272 | import org.graphframes.GraphFrame 273 | 274 | object GraphFrameTest extends App { 275 | val conf = new SparkConf().setAppName("MainApplication").setMaster("local[*]") 276 | val sc = new SparkContext(conf) 277 | sc.setLogLevel("WARN") 278 | val sqlContext = new SQLContext(sc) 279 | 280 | import sqlContext.implicits._ 281 | val df = sc.parallelize(Array(1, 2, 3, 4, 5)).toDF("id") 282 | 283 | df.write.format("org.apache.spark.orientdb.graphs") 284 | .option("dburl", "") 285 | .option("user", "****") 286 | .option("password", "****") 287 | .option("vertextype", "v104") 288 | .mode(SaveMode.Overwrite) 289 | .save() 290 | 291 | val vertices = sqlContext.read 292 | .format("org.apache.spark.orientdb.graphs") 293 | .option("dburl", "") 294 | .option("user", "****") 295 | .option("password", "****") 296 | .option("vertextype", "v104") 297 | .load() 298 | 299 | var inVertex: Integer = null 300 | var outVertex: Integer = null 301 | vertices.collect().foreach(row => { 302 | if (inVertex == null) { 303 | inVertex = row.getAs[Integer]("id") 304 | } 305 | if (outVertex == null) { 306 | outVertex = row.getAs[Integer]("id") 307 | } 308 | }) 309 | 310 | val df1 = sqlContext.createDataFrame(sc.parallelize(Seq(Row("friends", "1", "2"), 311 | Row("enemies", "2", "3"), Row("friends", "3", "1"))), 312 | StructType(List(StructField("relationship", StringType), StructField("src", StringType), 313 | StructField("dst", StringType)))) 314 | 315 | df1.write.format("org.apache.spark.orientdb.graphs") 316 | .option("dburl", "") 317 | .option("user", "****") 318 | .option("password", "****") 319 | .option("vertextype", "v104") 320 | .option("edgetype", "e104") 321 | .mode(SaveMode.Overwrite) 322 | .save() 323 | 324 | val edges = sqlContext.read 325 | .format("org.apache.spark.orientdb.graphs") 326 | .option("dburl", "") 327 | .option("user", "****") 328 | .option("password", "****") 329 | .option("edgetype", "e104") 330 | .load() 331 | 332 | edges.show() 333 | 334 | val g = GraphFrame(vertices, edges) 335 | g.inDegrees.show() 336 | println(g.edges.filter("relationship = 'friends'").count()) 337 | } 338 | ``` 339 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2016 Subhobrata Dey 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | spark-orientdb 2 | ============== 3 | [![Build Status](https://travis-ci.org/sbcd90/spark-orientdb.svg?branch=master)](https://travis-ci.org/sbcd90/spark-orientdb) [ ![Download](https://api.bintray.com/packages/sbcd90/org.apache.spark/spark-orientdb-1.6.2_2.10/images/download.svg) ](https://bintray.com/sbcd90/org.apache.spark/spark-orientdb-1.6.2_2.10/_latestVersion) 4 | 5 | Apache Spark datasource for OrientDB 6 | 7 | OrientDB documentation 8 | ====================== 9 | 10 | Here is the latest documentation on [OrientDB](http://orientdb.com/orientdb/) 11 | 12 | Compatibility 13 | ============= 14 | 15 | `Spark`: 1.6+ 16 | `OrientDB`: 2.2.0+ 17 | 18 | Getting Started 19 | =============== 20 | 21 | - Add the repository 22 | 23 | ``` 24 | 25 | bintray 26 | bintray 27 | https://dl.bintray.com/sbcd90/org.apache.spark/ 28 | 29 | ``` 30 | 31 | ### For Spark 1.6 32 | 33 | - Add the datasource as a maven dependency 34 | 35 | ``` 36 | 37 | org.apache.spark 38 | spark-orientdb-1.6.2_2.10 39 | 1.3 40 | 41 | ``` 42 | 43 | ### For Spark 2.0 44 | 45 | - Add the datasource as a maven dependency 46 | 47 | ``` 48 | 49 | org.apache.spark 50 | spark-orientdb-2.0.0_2.10 51 | 1.4 52 | 53 | ``` 54 | 55 | ### For Spark 2.1 56 | 57 | ``` 58 | 59 | org.apache.spark 60 | spark-orientdb-2.1.1_2.11 61 | 1.4 62 | 63 | ``` 64 | 65 | ### For Spark 2.2 66 | 67 | ``` 68 | 69 | org.apache.spark 70 | spark-orientdb-2.2.1_2.11 71 | 1.4 72 | 73 | ``` 74 | 75 | Scala api 76 | ========= 77 | 78 | ### OrientDB Documents 79 | 80 | #### Write api: 81 | 82 | ``` 83 | import org.apache.spark.sql.SQLContext 84 | 85 | val sqlContext = new SQLContext(sc) 86 | sqlContext.createDataFrame(sc.parallelize(Array(1, 2, 3, 4, 5)), 87 | StructType(Seq(StructField("id", IntegerType))) 88 | .write 89 | .format("org.apache.spark.orientdb.documents") 90 | .option("dburl", ORIENTDB_CONNECTION_URL) 91 | .option("user", ORIENTDB_USER).option("password", ORIENTDB_PASSWORD) 92 | .option("class", test_table) 93 | .mode(SaveMode.Overwrite) 94 | .save() 95 | ``` 96 | 97 | #### Read api: 98 | 99 | ``` 100 | import org.apache.spark.sql.SQLContext 101 | 102 | val sqlContext = new SQLContext(sc) 103 | val loadedDf = sqlContext.read 104 | .format("org.apache.spark.orientdb.documents") 105 | .option("dburl", ORIENTDB_CONNECTION_URL) 106 | .option("user", ORIENTDB_USER) 107 | .option("password", ORIENTDB_PASSWORD) 108 | .option("class", test_table) 109 | .option("query", s"select * from $test_table where teststring = 'asdf'") 110 | .load() 111 | ``` 112 | 113 | #### Query using OrientDB SQL: 114 | 115 | ``` 116 | import org.apache.spark.sql.SQLContext 117 | 118 | val sqlContext = new SQLContext(sc) 119 | val loadedDf = sqlContext.read 120 | .format("org.apache.spark.orientdb.documents") 121 | .option("dburl", ORIENTDB_CONNECTION_URL) 122 | .option("user", ORIENTDB_USER) 123 | .option("password", ORIENTDB_PASSWORD) 124 | .option("class", test_table) 125 | .option("query", s"select * from $test_table where teststring = 'asdf'") 126 | .load() 127 | ``` 128 | 129 | #### Support for Embedded Types( Since Spark 2.1 release): 130 | 131 | ``` 132 | val testSchemaForEmbeddedUDTs: StructType = { 133 | StructType(Seq( 134 | StructField("embeddedlist", EmbeddedListType), 135 | StructField("embeddedset", EmbeddedSetType), 136 | StructField("embeddedmap", EmbeddedMapType) 137 | )) 138 | } 139 | ``` 
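A minimal write sketch using these UDT columns is shown below. It reuses the Write api pattern from above; `sc`, the `ORIENTDB_*` values and `test_table` are assumed to be defined as in the earlier examples, and the row contents are illustrative only. The block after this one shows the kind of data such rows hold.

```
import org.apache.spark.orientdb.udts.{EmbeddedList, EmbeddedMap, EmbeddedSet}
import org.apache.spark.sql.{Row, SQLContext, SaveMode}

val sqlContext = new SQLContext(sc)

// One row matching testSchemaForEmbeddedUDTs defined above
val rows = sc.parallelize(Seq(
  Row(EmbeddedList(Array(1, "two", true)),
      EmbeddedSet(Array(1, 2, 3)),
      EmbeddedMap(Map("answer" -> 42)))
))

sqlContext.createDataFrame(rows, testSchemaForEmbeddedUDTs)
  .write
  .format("org.apache.spark.orientdb.documents")
  .option("dburl", ORIENTDB_CONNECTION_URL)
  .option("user", ORIENTDB_USER)
  .option("password", ORIENTDB_PASSWORD)
  .option("class", test_table)
  .mode(SaveMode.Overwrite)
  .save()
```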
140 | 141 | ``` 142 | val expectedDataForEmbeddedUDTs: Seq[Row] = Seq( 143 | Row(EmbeddedList(Array(1, 1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498, 144 | 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣", 145 | TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1))), 146 | EmbeddedSet(Array(1, 1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498, 147 | 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣", 148 | TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1))), 149 | EmbeddedMap(Map(1 -> 1, 2 -> 1.toByte, 3 -> true, 4 -> TestUtils.toDate(2015, 6, 1), 5 -> 1234152.12312498, 150 | 6 -> 1.0f, 7 -> 42, 8 -> 1239012341823719L, 9 -> 23.toShort, 10 -> "Unicode's樂趣", 11 -> TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1)))) 151 | ) 152 | ``` 153 | 154 | #### Support for Link Types( Since Spark 2.1 release): 155 | 156 | ``` 157 | val testSchemaForLinkUDTs: StructType = { 158 | StructType(Seq( 159 | StructField("linklist", LinkListType), 160 | StructField("linkset", LinkSetType), 161 | StructField("linkmap", LinkMapType), 162 | StructField("linkbag", LinkBagType) 163 | )) 164 | } 165 | ``` 166 | 167 | ``` 168 | val expectedDataForLinkUDTs: Seq[Row] = Seq( 169 | Row(LinkList(Array(oDocument1)), LinkSet(Array(oDocument1)), LinkMap(Map("1" -> oDocument1)), LinkBag(Array(oRid1))), 170 | Row(LinkList(Array(oDocument2)), LinkSet(Array(oDocument2)), LinkMap(Map("1" -> oDocument2)), LinkBag(Array(oRid2))), 171 | Row(LinkList(Array(oDocument3)), LinkSet(Array(oDocument3)), LinkMap(Map("1" -> oDocument3)), LinkBag(Array(oRid3))), 172 | Row(LinkList(Array(oDocument4)), LinkSet(Array(oDocument4)), LinkMap(Map("1" -> oDocument4)), LinkBag(Array(oRid4))), 173 | Row(LinkList(Array(oDocument5)), LinkSet(Array(oDocument5)), LinkMap(Map("1" -> oDocument5)), LinkBag(Array(oRid5))) 174 | ) 175 | ``` 176 | 177 | ### OrientDB Graphs: 178 | 179 | #### Create Vertex api: 180 | 181 | ``` 182 | import org.apache.spark.sql.SQLContext 183 | 184 | val sqlContext = new SQLContext(sc) 185 | sqlContext.createDataFrame(sc.parallelize(Array(1, 2, 3, 4, 5)), 186 | StructType(Seq(StructField("id", IntegerType))) 187 | .write 188 | .format("org.apache.spark.orientdb.graphs") 189 | .option("dburl", ORIENTDB_CONNECTION_URL) 190 | .option("user", ORIENTDB_USER) 191 | .option("password", ORIENTDB_PASSWORD) 192 | .option("vertextype", test_vertex_type2) 193 | .mode(SaveMode.Overwrite) 194 | .save() 195 | ``` 196 | 197 | #### Create Edge api: 198 | 199 | ``` 200 | import org.apache.spark.sql.SQLContext 201 | 202 | val sqlContext = new SQLContext(sc) 203 | sqlContext.createDataFrame( 204 | sc.parallelize(Seq( 205 | Row(1, 2, "friends"), 206 | Row(2, 3, "enemy"), 207 | Row(3, 4, "friends"), 208 | Row(4, 1, "enemy") 209 | )), 210 | StructType(Seq( 211 | StructField("src", IntegerType), 212 | StructField("dst", IntegerType), 213 | StructField("relationship", StringType) 214 | ))) 215 | .write 216 | .format("org.apache.spark.orientdb.graphs") 217 | .option("dburl", ORIENTDB_CONNECTION_URL) 218 | .option("user", ORIENTDB_USER) 219 | .option("password", ORIENTDB_PASSWORD) 220 | .option("vertextype", test_vertex_type2) 221 | .option("edgetype", test_edge_type2) 222 | .mode(SaveMode.Overwrite) 223 | .save() 224 | ``` 225 | 226 | #### Read Vertex api: 227 | 228 | ``` 229 | import org.apache.spark.sql.SQLContext 230 | 231 | val sqlContext = new SQLContext(sc) 232 | val loadedDf = sqlContext.read 233 | .format("org.apache.spark.orientdb.graphs") 234 | .option("dburl", ORIENTDB_CONNECTION_URL) 235 | .option("user", ORIENTDB_USER) 
236 | .option("password", ORIENTDB_PASSWORD) 237 | .option("vertextype", test_vertex_type2) 238 | .load() 239 | ``` 240 | 241 | #### Read edge api: 242 | 243 | ``` 244 | import org.apache.spark.sql.SQLContext 245 | 246 | val sqlContext = new SQLContext(sc) 247 | val loadedDf = sqlContext.read 248 | .format("org.apache.spark.orientdb.graphs") 249 | .option("dburl", ORIENTDB_CONNECTION_URL) 250 | .option("user", ORIENTDB_USER) 251 | .option("password", ORIENTDB_PASSWORD) 252 | .option("edgetype", test_edge_type2) 253 | .load() 254 | ``` 255 | 256 | #### Query using OrientDB Graph SQL: 257 | 258 | ``` 259 | import org.apache.spark.sql.SQLContext 260 | 261 | val sqlContext = new SQLContext(sc) 262 | val loadedVerticesDf = sqlContext.read 263 | .format("org.apache.spark.orientdb.graphs") 264 | .option("dburl", ORIENTDB_CONNECTION_URL) 265 | .option("user", ORIENTDB_USER) 266 | .option("password", ORIENTDB_PASSWORD) 267 | .option("vertextype", test_vertex_type2) 268 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'") 269 | .load() 270 | 271 | val loadedEdgesDf = sqlContext.read 272 | .format("org.apache.spark.orientdb.graphs") 273 | .option("dburl", ORIENTDB_CONNECTION_URL) 274 | .option("user", ORIENTDB_USER) 275 | .option("password", ORIENTDB_PASSWORD) 276 | .option("edgetype", test_edge_type2) 277 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'") 278 | .load() 279 | ``` 280 | 281 | #### Support for embedded types & link types( Since Spark 2.1 release) 282 | 283 | The Spark UDTs are available for OrientDB Graph datasource as well. 284 | Usage is very similar to the ones documented for OrientDB Document datasource. 285 | Examples can be found in Integration tests. 286 | 287 | ### Integration with GraphFrames 288 | 289 | ``` 290 | import org.apache.spark.sql.SQLContext 291 | 292 | val sqlContext = new SQLContext(sc) 293 | val loadedVerticesDf = sqlContext.read 294 | .format("org.apache.spark.orientdb.graphs") 295 | .option("dburl", ORIENTDB_CONNECTION_URL) 296 | .option("user", ORIENTDB_USER) 297 | .option("password", ORIENTDB_PASSWORD) 298 | .option("vertextype", test_vertex_type2) 299 | .option("query", s"select * from $test_vertex_type2 where teststring = 'asdf'") 300 | .load() 301 | 302 | val loadedEdgesDf = sqlContext.read 303 | .format("org.apache.spark.orientdb.graphs") 304 | .option("dburl", ORIENTDB_CONNECTION_URL) 305 | .option("user", ORIENTDB_USER) 306 | .option("password", ORIENTDB_PASSWORD) 307 | .option("edgetype", test_edge_type2) 308 | .option("query", s"select * from $test_edge_type2 where relationship = 'friends'") 309 | .load() 310 | 311 | val g = GraphFrame(loadedVerticesDf, loadedEdgesDf) 312 | ``` 313 | 314 | A full example can be found in directory `src/main/examples` -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | org.apache.spark 8 | spark-orientdb-${spark.version}_${scala.binary.version} 9 | 1.4 10 | 11 | 12 | 2.12.12 13 | 2.12 14 | 3.1.1 15 | 3.0.13 16 | 1.8 17 | 1.8 18 | 19 | 20 | 21 | 22 | SparkPackagesRepo 23 | SparkPackagesRepo 24 | https://dl.bintray.com/spark-packages/maven 25 | 26 | 27 | bintray 28 | bintray.com 29 | https://repos.spark-packages.org 30 | 31 | 32 | 33 | 34 | 35 | org.scala-lang 36 | scala-compiler 37 | ${scala.version} 38 | 39 | 40 | com.orientechnologies 41 | orientdb-graphdb 42 | ${orientdb.graphdb.version} 43 | 44 | 45 | 
org.scalatest 46 | scalatest_${scala.binary.version} 47 | 3.0.0 48 | test 49 | 50 | 51 | org.mockito 52 | mockito-all 53 | 1.10.19 54 | test 55 | 56 | 57 | org.apache.spark 58 | spark-core_${scala.binary.version} 59 | ${spark.version} 60 | 61 | 62 | org.apache.spark 63 | spark-sql_${scala.binary.version} 64 | ${spark.version} 65 | 66 | 67 | org.apache.spark 68 | spark-hive_${scala.binary.version} 69 | ${spark.version} 70 | 71 | 72 | org.apache.spark 73 | spark-graphx_${scala.binary.version} 74 | ${spark.version} 75 | 76 | 77 | 78 | graphframes 79 | graphframes 80 | 0.8.1-spark3.0-s_${scala.binary.version} 81 | 82 | 83 | 84 | 85 | 86 | 87 | bintray-sbcd90-spark-orientdb 88 | 89 | https://api.bintray.com/maven/sbcd90/org.apache.spark/spark-orientdb-${spark.version}_${scala.binary.version}/;publish=1 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | org.apache.maven.plugins 98 | maven-compiler-plugin 99 | 3.5.1 100 | 101 | 102 | org.scala-tools 103 | maven-scala-plugin 104 | 105 | 106 | 107 | compile 108 | testCompile 109 | 110 | 111 | 112 | 113 | src/main/scala 114 | 115 | -Xms64m 116 | -Xmx1024m 117 | 118 | 119 | 120 | 121 | 122 | org.apache.maven.plugins 123 | maven-surefire-plugin 124 | 2.19.1 125 | 126 | true 127 | 128 | 129 | 130 | 131 | org.scalatest 132 | scalatest-maven-plugin 133 | 1.0 134 | 135 | ${project.build.directory}/surefire-reports 136 | . 137 | TestSuite.txt 138 | 139 | 140 | 141 | test 142 | 143 | test 144 | 145 | 146 | 147 | 148 | 149 | org.apache.maven.plugins 150 | maven-jar-plugin 151 | 2.3.2 152 | 153 | spark-orientdb-${spark.version}_${scala.binary.version} 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /src/it/scala/org/apache/spark/orientdb/documents/IntegrationSuiteBase.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.db.document.ODatabaseDocument 4 | import org.apache.spark.SparkContext 5 | import org.apache.spark.orientdb.QueryTest 6 | import org.apache.spark.sql.types.StructType 7 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} 8 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} 9 | 10 | trait IntegrationSuiteBase 11 | extends QueryTest 12 | with Matchers 13 | with BeforeAndAfterAll 14 | with BeforeAndAfterEach { 15 | 16 | protected def loadConfigFromEnv(envVarName: String): String = { 17 | Option(System.getenv(envVarName)).getOrElse{ 18 | fail(s"Must set $envVarName environment variable") 19 | } 20 | } 21 | 22 | protected val ORIENTDB_CONNECTION_URL = loadConfigFromEnv("ORIENTDB_CONN_URL") 23 | protected val ORIENTDB_USER = loadConfigFromEnv("ORIENTDB_USER") 24 | protected val ORIENTDB_PASSWORD = loadConfigFromEnv("ORIENTDB_PASSWORD") 25 | 26 | protected var sc: SparkContext = _ 27 | protected var sqlContext: SQLContext = _ 28 | protected var connection: ODatabaseDocument = _ 29 | 30 | protected val orientDBWrapper = DefaultOrientDBDocumentWrapper 31 | 32 | override def beforeAll(): Unit = { 33 | super.beforeAll() 34 | sc = new SparkContext("local", "OrientDBSourceSuite") 35 | 36 | val parameters = Map("dburl" -> ORIENTDB_CONNECTION_URL, 37 | "user" -> ORIENTDB_USER, 38 | "password" -> ORIENTDB_PASSWORD, 39 | "class" -> "dummy") 40 | connection = orientDBWrapper.getConnection(Parameters.mergeParameters(parameters)) 41 | } 42 | 43 | override def afterAll(): Unit = { 44 | try { 45 | connection.close() 46 | } finally { 47 | try { 
48 | sc.stop() 49 | } finally { 50 | super.afterAll() 51 | } 52 | } 53 | } 54 | 55 | override protected def beforeEach(): Unit = { 56 | super.beforeEach() 57 | sqlContext = new SQLContext(sc) 58 | } 59 | 60 | def testRoundtripSaveAndLoad(className: String, 61 | df: DataFrame, 62 | expectedSchemaAfterLoad: Option[StructType] = None, 63 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = { 64 | try { 65 | df.write 66 | .format("org.apache.spark.orientdb.documents") 67 | .option("dburl", ORIENTDB_CONNECTION_URL) 68 | .option("user", ORIENTDB_USER) 69 | .option("password", ORIENTDB_PASSWORD) 70 | .option("class", className) 71 | .mode(saveMode) 72 | .save() 73 | 74 | if (!orientDBWrapper.doesClassExists(className)) { 75 | Thread.sleep(1000) 76 | assert(orientDBWrapper.doesClassExists(className)) 77 | } 78 | 79 | val loadedDf = sqlContext.read.format("org.apache.spark.orientdb.documents") 80 | .option("dburl", ORIENTDB_CONNECTION_URL) 81 | .option("user", ORIENTDB_USER) 82 | .option("password", ORIENTDB_PASSWORD) 83 | .option("class", className) 84 | .load() 85 | assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema)) 86 | checkAnswer(loadedDf, df.collect()) 87 | } finally { 88 | orientDBWrapper.delete(null, className, null) 89 | val schema = connection.getMetadata.getSchema 90 | schema.dropClass(className) 91 | } 92 | } 93 | } -------------------------------------------------------------------------------- /src/it/scala/org/apache/spark/orientdb/graphs/IntegrationSuiteBase.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import com.tinkerpop.blueprints.impls.orient.OrientGraphNoTx 4 | import org.apache.spark.SparkContext 5 | import org.apache.spark.orientdb.QueryTest 6 | import org.apache.spark.sql.types.StructType 7 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} 8 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} 9 | 10 | trait IntegrationSuiteBase 11 | extends QueryTest 12 | with Matchers 13 | with BeforeAndAfterAll 14 | with BeforeAndAfterEach { 15 | 16 | protected def loadConfigFromEnv(envVarName: String): String = { 17 | Option(System.getenv(envVarName)).getOrElse{ 18 | fail(s"Must set $envVarName environment variable") 19 | } 20 | } 21 | 22 | protected val ORIENTDB_CONNECTION_URL = loadConfigFromEnv("ORIENTDB_CONN_URL") 23 | protected val ORIENTDB_USER = loadConfigFromEnv("ORIENTDB_USER") 24 | protected val ORIENTDB_PASSWORD = loadConfigFromEnv("ORIENTDB_PASSWORD") 25 | 26 | protected var sc: SparkContext = _ 27 | protected var sqlContext: SQLContext = _ 28 | protected var vertex_connection: OrientGraphNoTx = _ 29 | protected var edge_connection: OrientGraphNoTx = _ 30 | 31 | protected val orientDBGraphVertexWrapper = DefaultOrientDBGraphVertexWrapper 32 | protected val orientDBGraphEdgeWrapper = DefaultOrientDBGraphEdgeWrapper 33 | 34 | override def beforeAll(): Unit = { 35 | super.beforeAll() 36 | sc = new SparkContext("local", "OrientDBSourceSuite") 37 | 38 | val parameters = Map[String, String]( 39 | "dburl" -> ORIENTDB_CONNECTION_URL, 40 | "user" -> ORIENTDB_USER, 41 | "password" -> ORIENTDB_PASSWORD) 42 | 43 | val vertexParams: Map[String, String] = parameters ++ Map[String, String]("vertextype" -> "dummy_vertex") 44 | val edgeParams = parameters ++ Map[String, String]("edgetype" -> "dummy_edge") 45 | 46 | vertex_connection = orientDBGraphVertexWrapper 47 | .getConnection(Parameters.mergeParameters(vertexParams)) 48 | 
edge_connection = orientDBGraphEdgeWrapper 49 | .getConnection(Parameters.mergeParameters(edgeParams)) 50 | } 51 | 52 | override def afterAll(): Unit = { 53 | try { 54 | orientDBGraphVertexWrapper.close() 55 | orientDBGraphEdgeWrapper.close() 56 | } finally { 57 | try { 58 | sc.stop() 59 | } finally { 60 | super.afterAll() 61 | } 62 | } 63 | } 64 | 65 | override protected def beforeEach(): Unit = { 66 | super.beforeEach() 67 | sqlContext = new SQLContext(sc) 68 | } 69 | 70 | def testRoundtripSaveAndLoadForVertices(vertexType: String, 71 | df: DataFrame, 72 | expectedSchemaAfterLoad: Option[StructType] = None, 73 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = { 74 | try { 75 | df.write 76 | .format("org.apache.spark.orientdb.graphs") 77 | .option("dburl", ORIENTDB_CONNECTION_URL) 78 | .option("user", ORIENTDB_USER) 79 | .option("password", ORIENTDB_PASSWORD) 80 | .option("vertextype", vertexType) 81 | .mode(saveMode) 82 | .save() 83 | 84 | if (!orientDBGraphVertexWrapper.doesVertexTypeExists(vertexType)) { 85 | Thread.sleep(1000) 86 | assert(orientDBGraphVertexWrapper.doesVertexTypeExists(vertexType)) 87 | } 88 | 89 | val loadedDf = sqlContext.read 90 | .format("org.apache.spark.orientdb.graphs") 91 | .option("dburl", ORIENTDB_CONNECTION_URL) 92 | .option("user", ORIENTDB_USER) 93 | .option("password", ORIENTDB_PASSWORD) 94 | .option("vertextype", vertexType) 95 | .load() 96 | 97 | loadedDf.schema.fields.foreach(field => assert(df.schema.fields.contains(field))) 98 | checkAnswer(loadedDf, df.collect()) 99 | } finally { 100 | orientDBGraphVertexWrapper.delete(vertexType, null) 101 | vertex_connection.dropVertexType(vertexType) 102 | } 103 | } 104 | 105 | def testRoundtripSaveAndLoadForEdges( vertexType: String, 106 | edgeType: String, 107 | vertexDf: DataFrame, 108 | df: DataFrame, 109 | expectedSchemaAfterLoad: Option[StructType] = None, 110 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = { 111 | try { 112 | vertexDf.write 113 | .format("org.apache.spark.orientdb.graphs") 114 | .option("dburl", ORIENTDB_CONNECTION_URL) 115 | .option("user", ORIENTDB_USER) 116 | .option("password", ORIENTDB_PASSWORD) 117 | .option("vertextype", vertexType) 118 | .mode(saveMode) 119 | .save() 120 | 121 | df.write 122 | .format("org.apache.spark.orientdb.graphs") 123 | .option("dburl", ORIENTDB_CONNECTION_URL) 124 | .option("user", ORIENTDB_USER) 125 | .option("password", ORIENTDB_PASSWORD) 126 | .option("vertextype", vertexType) 127 | .option("edgetype", edgeType) 128 | .mode(saveMode) 129 | .save() 130 | 131 | if (!orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType)) { 132 | Thread.sleep(1000) 133 | assert(orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType)) 134 | } 135 | 136 | val loadedDf = sqlContext.read 137 | .format("org.apache.spark.orientdb.graphs") 138 | .option("dburl", ORIENTDB_CONNECTION_URL) 139 | .option("user", ORIENTDB_USER) 140 | .option("password", ORIENTDB_PASSWORD) 141 | .option("edgetype", edgeType) 142 | .load() 143 | 144 | loadedDf.schema.fields.foreach(field => assert(df.schema.fields.contains(field))) 145 | assert(loadedDf.count() === df.collect().size) 146 | } finally { 147 | try { 148 | orientDBGraphEdgeWrapper.delete(edgeType, null) 149 | orientDBGraphVertexWrapper.delete(vertexType, null) 150 | } finally { 151 | edge_connection.dropEdgeType(edgeType) 152 | vertex_connection.dropVertexType(vertexType) 153 | } 154 | } 155 | } 156 | } -------------------------------------------------------------------------------- 
/src/main/examples/org/apache/spark/orientdb/documents/DataFrameTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import org.apache.spark.sql.{SQLContext, SaveMode, SparkSession} 4 | import org.apache.spark.{SparkConf, SparkContext} 5 | 6 | object DataFrameTest extends App { 7 | val spark = SparkSession.builder().appName("DataFrameTest").master("local[*]").getOrCreate() 8 | 9 | import spark.implicits._ 10 | val df = spark.sparkContext.parallelize(Array(1, 2, 3, 4, 5)).toDF("id") 11 | 12 | df.write.format("org.apache.spark.orientdb.documents") 13 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") 14 | .option("user", "root") 15 | .option("password", "root") 16 | .option("class", "test_class") 17 | .mode(SaveMode.Overwrite) 18 | .save() 19 | 20 | val resultDf = spark.sqlContext.read 21 | .format("org.apache.spark.orientdb.documents") 22 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") 23 | .option("user", "root") 24 | .option("password", "root") 25 | .option("class", "test_class") 26 | .load() 27 | 28 | resultDf.show() 29 | } -------------------------------------------------------------------------------- /src/main/examples/org/apache/spark/orientdb/graphs/GraphFrameTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import org.apache.spark.{SparkConf, SparkContext} 4 | import org.apache.spark.sql.{Row, SQLContext, SaveMode, SparkSession} 5 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 6 | import org.graphframes.GraphFrame 7 | 8 | object GraphFrameTest extends App { 9 | val spark = SparkSession.builder().appName("MainApplication").master("local[*]").getOrCreate() 10 | 11 | val sc = spark.sparkContext 12 | sc.setLogLevel("WARN") 13 | val sqlContext = spark.sqlContext 14 | 15 | import sqlContext.implicits._ 16 | val df = sc.parallelize(Array(1, 2, 3, 4, 5)).toDF("id") 17 | 18 | df.write.format("org.apache.spark.orientdb.graphs") 19 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") 20 | .option("user", "root") 21 | .option("password", "root") 22 | .option("vertextype", "v104") 23 | .mode(SaveMode.Overwrite) 24 | .save() 25 | 26 | val vertices = sqlContext.read 27 | .format("org.apache.spark.orientdb.graphs") 28 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") 29 | .option("user", "root") 30 | .option("password", "root") 31 | .option("vertextype", "v104") 32 | .load() 33 | 34 | var inVertex: Integer = null 35 | var outVertex: Integer = null 36 | vertices.collect().foreach(row => { 37 | if (inVertex == null) { 38 | inVertex = row.getAs[Integer]("id") 39 | } 40 | if (outVertex == null) { 41 | outVertex = row.getAs[Integer]("id") 42 | } 43 | }) 44 | 45 | val df1 = sqlContext.createDataFrame(sc.parallelize(Seq(Row("friends", "1", "2"), 46 | Row("enemies", "2", "3"), Row("friends", "3", "1"))), 47 | StructType(List(StructField("relationship", StringType), StructField("src", StringType), 48 | StructField("dst", StringType)))) 49 | 50 | df1.write.format("org.apache.spark.orientdb.graphs") 51 | .option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") 52 | .option("user", "root") 53 | .option("password", "root") 54 | .option("vertextype", "v104") 55 | .option("edgetype", "e104") 56 | .mode(SaveMode.Overwrite) 57 | .save() 58 | 59 | val edges = sqlContext.read 60 | .format("org.apache.spark.orientdb.graphs") 61 | 
.option("dburl", "remote:127.0.0.1:2424/GratefulDeadConcerts") 62 | .option("user", "root") 63 | .option("password", "root") 64 | .option("edgetype", "e104") 65 | .load() 66 | 67 | edges.show() 68 | 69 | val g = GraphFrame(vertices, edges) 70 | g.inDegrees.show() 71 | println(g.edges.filter("relationship = 'friends'").count()) 72 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/Conversions.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import java.sql.{Date, Timestamp} 4 | import java.text.SimpleDateFormat 5 | import java.util 6 | import java.util.{Locale, Map} 7 | import java.util.function.Consumer 8 | 9 | import com.orientechnologies.orient.core.db.record._ 10 | import com.orientechnologies.orient.core.db.record.ridbag.ORidBag 11 | import com.orientechnologies.orient.core.id.ORecordId 12 | import com.orientechnologies.orient.core.metadata.schema.OType 13 | import com.orientechnologies.orient.core.record.ORecord 14 | import com.orientechnologies.orient.core.record.impl.ODocument 15 | import com.tinkerpop.blueprints.{Edge, Vertex} 16 | import org.apache.spark.orientdb.udts._ 17 | import org.apache.spark.sql.Row 18 | import org.apache.spark.sql.types._ 19 | 20 | import scala.collection.JavaConversions._ 21 | import scala.collection.mutable 22 | 23 | private[orientdb] object Conversions { 24 | def sparkDTtoOrientDBDT(dataType: DataType): OType = { 25 | dataType match { 26 | case ByteType => OType.BYTE 27 | case ShortType => OType.SHORT 28 | case IntegerType => OType.INTEGER 29 | case LongType => OType.LONG 30 | case FloatType => OType.FLOAT 31 | case DoubleType => OType.DOUBLE 32 | case _: DecimalType => OType.DECIMAL 33 | case StringType => OType.STRING 34 | case BinaryType => OType.BINARY 35 | case BooleanType => OType.BOOLEAN 36 | case DateType => OType.DATE 37 | case TimestampType => OType.DATETIME 38 | case _: EmbeddedListType => OType.EMBEDDEDLIST 39 | case _: EmbeddedSetType => OType.EMBEDDEDSET 40 | case _: EmbeddedMapType => OType.EMBEDDEDMAP 41 | case _: LinkListType => OType.LINKLIST 42 | case _: LinkSetType => OType.LINKSET 43 | case _: LinkMapType => OType.LINKMAP 44 | case _: LinkBagType => OType.LINKBAG 45 | case _: LinkType => OType.LINK 46 | case other => throw new UnsupportedOperationException(s"Unexpected DataType $dataType") 47 | } 48 | } 49 | 50 | /* def orientDBDTtoSparkDT(dataType: OType): DataType = { 51 | dataType match { 52 | case OType.BYTE => ByteType 53 | case OType.SHORT => ShortType 54 | case OType.INTEGER => IntegerType 55 | case OType.LONG => LongType 56 | case OType.FLOAT => FloatType 57 | case OType.DOUBLE => DoubleType 58 | case OType.DECIMAL => 59 | DecimalType(DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE) 60 | case OType.STRING => StringType 61 | case OType.BINARY => BinaryType 62 | case OType.BOOLEAN => BooleanType 63 | case OType.DATE => DateType 64 | case OType.DATETIME => TimestampType 65 | case other => throw new UnsupportedOperationException(s"Unexpected DataType $dataType") 66 | } 67 | } */ 68 | 69 | def convertRowsToODocuments(row: Row): ODocument = { 70 | val oDocument = new ODocument() 71 | row.schema.fields.foreach(field => { 72 | oDocument.field(field.name, getField(row, field), 73 | Conversions.sparkDTtoOrientDBDT(field.dataType)) 74 | }) 75 | oDocument 76 | } 77 | 78 | def orientDBDTtoSparkDT(dataType: DataType, field: AnyRef) = { 79 | if (field == null) 
80 | field 81 | else { 82 | val dateFormat = new SimpleDateFormat("E MMM dd HH:mm:ss Z yyyy", Locale.ENGLISH) 83 | dataType match { 84 | case ByteType => java.lang.Byte.valueOf(field.toString) 85 | case ShortType => field.toString.toShort 86 | case IntegerType => field.toString.toInt 87 | case LongType => field.toString.toLong 88 | case FloatType => field.toString.toFloat 89 | case DoubleType => field.toString.toDouble 90 | case _: DecimalType => field.asInstanceOf[java.math.BigDecimal] 91 | case StringType => field 92 | case BinaryType => field 93 | case BooleanType => field.toString.toBoolean 94 | case DateType => new Date(dateFormat.parse(field.toString).getTime) 95 | case TimestampType => new Timestamp(dateFormat.parse(field.toString).getTime) 96 | case _: EmbeddedListType => 97 | var elements = Array[Any]() 98 | field.asInstanceOf[util.ArrayList[Any]].forEach(new Consumer[Any] { 99 | override def accept(t: Any): Unit = elements :+= t 100 | }) 101 | new EmbeddedList(elements) 102 | case _: EmbeddedSetType => 103 | var elements = Array[Any]() 104 | field.asInstanceOf[OTrackedSet[Any]].forEach(new Consumer[Any] { 105 | override def accept(t: Any): Unit = elements :+= t 106 | }) 107 | new EmbeddedSet(elements) 108 | case _: EmbeddedMapType => 109 | var elements = mutable.Map[Any, Any]() 110 | field.asInstanceOf[OTrackedMap[Any]].entrySet().forEach(new Consumer[Map.Entry[AnyRef, Any]] { 111 | override def accept(t: Map.Entry[AnyRef, Any]): Unit = { 112 | elements.put(t.getKey, t.getValue) 113 | } 114 | }) 115 | new EmbeddedMap(elements.toMap) 116 | case _: LinkListType => 117 | var elements = Array[ORecord]() 118 | field.asInstanceOf[ORecordLazyList].forEach(new Consumer[OIdentifiable] { 119 | override def accept(t: OIdentifiable): Unit = t match { 120 | case recordId: ORecordId => 121 | elements:+= new ODocument(recordId).asInstanceOf[ORecord] 122 | case record: ORecord => 123 | elements :+= record 124 | } 125 | }) 126 | new LinkList(elements) 127 | case _: LinkSetType => 128 | var elements = Array[ORecord]() 129 | field.asInstanceOf[ORecordLazySet].forEach(new Consumer[OIdentifiable] { 130 | override def accept(t: OIdentifiable): Unit = t match { 131 | case recordId: ORecordId => 132 | elements:+= new ODocument(recordId).asInstanceOf[ORecord] 133 | case record: ORecord => 134 | elements :+= record 135 | } 136 | }) 137 | new LinkSet(elements) 138 | case _: LinkMapType => 139 | val elements = mutable.Map[String, ORecord]() 140 | field.asInstanceOf[ORecordLazyMap].entrySet().forEach(new Consumer[Map.Entry[AnyRef, OIdentifiable]] { 141 | override def accept(t: Map.Entry[AnyRef, OIdentifiable]): Unit = { 142 | elements.put(t.getKey.toString, t.getValue match { 143 | case recordId: ORecordId => 144 | new ODocument(recordId).asInstanceOf[ORecord] 145 | case record: ORecord => 146 | record 147 | }) 148 | } 149 | }) 150 | new LinkMap(elements.toMap) 151 | case _: LinkBagType => 152 | var elements = Array[ORecordId]() 153 | field.asInstanceOf[ORidBag].rawIterator().forEachRemaining(new Consumer[OIdentifiable] { 154 | override def accept(t: OIdentifiable): Unit = t match { 155 | case recordId: ORecordId => 156 | elements :+= recordId 157 | case other => sys.error(s"LinkBag cannot be of type ${other.getClass}") 158 | } 159 | }) 160 | new LinkBag(elements) 161 | case _: LinkType => 162 | val element = field.asInstanceOf[OIdentifiable].getRecord[ORecord]() 163 | new Link(element) 164 | case other => throw new UnsupportedOperationException(s"Unexpected DataType $dataType") 165 | } 166 | } 167 | } 168 | 
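  /**
    * Extracts the value of the `count`-th schema field from the given Row, converting Spark
    * UDT wrappers to the corresponding OrientDB values via getField, so it can be stored as
    * a property on a graph vertex or edge.
    */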
169 | def convertRowToGraph(row: Row, count: Int): AnyRef = { 170 | getField(row, row.schema.fields(count)) 171 | } 172 | 173 | /* def orientDBDTtoSparkDT(dataType: OType, field: String) = { 174 | if (field == null) 175 | field 176 | else { 177 | val dateFormat = new SimpleDateFormat("E MMM dd HH:mm:ss Z yyyy") 178 | dataType match { 179 | case OType.BYTE => field.toByte 180 | case OType.SHORT => field.toShort 181 | case OType.INTEGER => field.toInt 182 | case OType.LONG => field.toLong 183 | case OType.FLOAT => field.toFloat 184 | case OType.DOUBLE => field.toDouble 185 | case OType.DECIMAL => field.asInstanceOf[java.math.BigDecimal] 186 | case OType.STRING => field 187 | case OType.BINARY => field 188 | case OType.BOOLEAN => field.toBoolean 189 | case OType.DATE => new Date(dateFormat.parse(field).getTime) 190 | case OType.DATETIME => new Timestamp(dateFormat.parse(field).getTime) 191 | case _ => throw new UnsupportedOperationException(s"Unexpected DataType $dataType") 192 | } 193 | } 194 | } */ 195 | 196 | def convertODocumentsToRows(oDocument: ODocument, schema: StructType): Row = { 197 | val converted: scala.collection.mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null) 198 | val fieldNames = oDocument.fieldNames() 199 | val fieldValues = oDocument.fieldValues() 200 | 201 | var i = 0 202 | while (i < schema.length) { 203 | if (fieldNames.contains(schema.fields(i).name)) { 204 | val idx = fieldNames.indexOf(schema.fields(i).name) 205 | val value = fieldValues(idx) 206 | 207 | converted(i) = orientDBDTtoSparkDT(schema.fields(i).dataType, value) 208 | } else { 209 | converted(i) = null 210 | } 211 | i = i + 1 212 | } 213 | 214 | Row.fromSeq(converted) 215 | } 216 | 217 | def convertVerticesToRows(vertex: Vertex, schema: StructType): Row = { 218 | val converted: scala.collection.mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null) 219 | val fieldNames = vertex.getPropertyKeys 220 | 221 | var i = 0 222 | while (i < schema.length) { 223 | if (fieldNames.contains(schema.fields(i).name)) { 224 | val value = vertex.getProperty[Object](schema.fields(i).name) 225 | 226 | converted(i) = orientDBDTtoSparkDT(schema.fields(i).dataType, value) 227 | } else { 228 | converted(i) = null 229 | } 230 | i = i + 1 231 | } 232 | 233 | Row.fromSeq(converted) 234 | } 235 | 236 | def convertEdgesToRows(edge: Edge, schema: StructType): Row = { 237 | val converted: scala.collection.mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null) 238 | val fieldNames = edge.getPropertyKeys 239 | 240 | var i = 0 241 | while (i < schema.length) { 242 | if (fieldNames.contains(schema.fields(i).name)) { 243 | val value = edge.getProperty[Object](schema.fields(i).name) 244 | 245 | converted(i) = orientDBDTtoSparkDT(schema.fields(i).dataType, value) 246 | } else { 247 | converted(i) = null 248 | } 249 | i = i + 1 250 | } 251 | 252 | Row.fromSeq(converted) 253 | } 254 | 255 | private def getField(row: Row, field: StructField) = field.dataType.typeName match { 256 | case "embeddedlist" => row.getAs[field.dataType.type](field.name).asInstanceOf[EmbeddedList].elements 257 | case "embeddedset" => row.getAs[field.dataType.type](field.name).asInstanceOf[EmbeddedSet].elements 258 | case "embeddedmap" => mapAsJavaMap(row.getAs[field.dataType.type](field.name).asInstanceOf[EmbeddedMap].elements) 259 | case "linklist" => row.getAs[field.dataType.type](field.name).asInstanceOf[LinkList].elements 260 | case "linkset" => 
row.getAs[field.dataType.type](field.name).asInstanceOf[LinkSet].elements 261 | case "linkmap" => mapAsJavaMap(row.getAs[field.dataType.type](field.name).asInstanceOf[LinkMap].elements) 262 | case "linkbag" => 263 | val oRidBag = new ORidBag() 264 | oRidBag.addAll(row.getAs[field.dataType.type ](field.name).asInstanceOf[LinkBag].elements.toSeq) 265 | oRidBag 266 | case "link" => 267 | row.getAs[field.dataType.type](field.name).asInstanceOf[Link].element 268 | case _ => row.getAs[field.dataType.type](field.name) 269 | } 270 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider} 4 | import org.apache.spark.sql.types.StructType 5 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} 6 | import org.slf4j.LoggerFactory 7 | 8 | class DefaultSource(orientDBWrapper: OrientDBDocumentWrapper, 9 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory) 10 | extends RelationProvider 11 | with SchemaRelationProvider 12 | with CreatableRelationProvider { 13 | 14 | private val log = LoggerFactory.getLogger(getClass) 15 | 16 | def this() = this(DefaultOrientDBDocumentWrapper, orientDBCredentials => new OrientDBClientFactory(orientDBCredentials)) 17 | 18 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { 19 | val params = Parameters.mergeParameters(parameters) 20 | 21 | if (params.query.isDefined && params.className.isEmpty) { 22 | throw new IllegalArgumentException("Along with the 'query' parameter you must specify either 'class' parameter or"+ 23 | " user-defined schema") 24 | } 25 | 26 | OrientDBRelation(orientDBWrapper, orientDBClientFactory, params, None)(sqlContext) 27 | } 28 | 29 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], 30 | schema: StructType): BaseRelation = { 31 | val params = Parameters.mergeParameters(parameters) 32 | OrientDBRelation(orientDBWrapper, orientDBClientFactory, params, Some(schema))(sqlContext) 33 | } 34 | 35 | override def createRelation(sqlContext: SQLContext, mode: SaveMode, 36 | parameters: Map[String, String], data: DataFrame): BaseRelation = { 37 | val params = Parameters.mergeParameters(parameters) 38 | val classname = params.className.getOrElse{ 39 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Class " + 40 | "name with the 'classname' parameter") 41 | } 42 | 43 | val clusters = params.clusterNames.getOrElse{ 44 | log.warn("Orient DB cluster name not specified. Using default Cluster Id for the class specified") 45 | } 46 | 47 | def tableExists: Boolean = { 48 | val connection = orientDBWrapper.getConnection(params) 49 | try { 50 | orientDBWrapper.doesClassExists(classname) 51 | } finally { 52 | connection.close() 53 | } 54 | } 55 | 56 | val (doSave, dropExisting) = mode match { 57 | case SaveMode.Append => (true, false) 58 | case SaveMode.Overwrite => (true, true) 59 | case SaveMode.ErrorIfExists => 60 | if (tableExists) { 61 | sys.error(s"Class $classname already exists! (SaveMode is set to ErrorIfExists)") 62 | } else { 63 | (true, false) 64 | } 65 | case SaveMode.Ignore => 66 | if (tableExists) { 67 | log.info(s"Class $classname already exists. 
Ignoring save request.") 68 | (false, false) 69 | } else { 70 | (true, false) 71 | } 72 | } 73 | 74 | if (doSave) { 75 | val updatedParams = parameters.updated("overwrite", dropExisting.toString) 76 | new OrientDBWriter(orientDBWrapper, orientDBClientFactory) 77 | .saveToOrientDB(data, mode, Parameters.mergeParameters(updatedParams)) 78 | } 79 | createRelation(sqlContext, parameters) 80 | } 81 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/FilterPushdown.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import java.sql.{Date, Timestamp} 4 | 5 | import org.apache.spark.sql.sources._ 6 | import org.apache.spark.sql.types._ 7 | 8 | private[orientdb] object FilterPushdown { 9 | /** 10 | * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no 11 | * condition will be added to the WHERE clause. If none of the filters can be pushed down then 12 | * an empty string will be returned. 13 | * 14 | * @param schema the schema of the table being queried 15 | * @param filters an array of filters, the conjunction of which is the filter condition for the 16 | * scan. 17 | */ 18 | def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = { 19 | val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ") 20 | if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions 21 | } 22 | 23 | /** 24 | * Attempt to convert the given filter into a SQL expression. Returns None if the expression 25 | * could not be converted. 26 | */ 27 | def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = { 28 | def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = { 29 | getTypeForAttribute(schema, attr).map { dataType => 30 | val sqlEscapedValue: String = dataType match { 31 | case StringType => s"'${value.toString.replace("'", "\\'\\'")}'" 32 | case DateType => s"'${value.asInstanceOf[Date]}'" 33 | case TimestampType => s"'${value.asInstanceOf[Timestamp]}'" 34 | case _ => value.toString 35 | } 36 | s"""$attr $comparisonOp $sqlEscapedValue""" 37 | } 38 | } 39 | 40 | filter match { 41 | case EqualTo(attr, value) => buildComparison(attr, value, "=") 42 | case LessThan(attr, value) => buildComparison(attr, value, "<") 43 | case GreaterThan(attr, value) => buildComparison(attr, value, ">") 44 | case LessThanOrEqual(attr, value) => buildComparison(attr, value, "<=") 45 | case GreaterThanOrEqual(attr, value) => buildComparison(attr, value, ">=") 46 | case IsNotNull(attr) => 47 | getTypeForAttribute(schema, attr).map(dataType => s"""$attr IS NOT NULL""") 48 | case IsNull(attr) => 49 | getTypeForAttribute(schema, attr).map(dataType => s"""$attr IS NULL""") 50 | case _ => None 51 | } 52 | } 53 | 54 | /** 55 | * Use the given schema to look up the attribute's data type. Returns None if the attribute could 56 | * not be resolved. 
57 | */ 58 | private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = { 59 | if (schema.fieldNames.contains(attribute)) { 60 | Some(schema(attribute).dataType) 61 | } else { 62 | None 63 | } 64 | } 65 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/OrientDBClientFactory.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.db.OPartitionedDatabasePool 4 | import com.orientechnologies.orient.core.db.document.ODatabaseDocument 5 | 6 | class OrientDBClientFactory(orientDBCredentials: OrientDBCredentials) extends Serializable { 7 | private val db: OPartitionedDatabasePool = 8 | new OPartitionedDatabasePool(orientDBCredentials.dbUrl, orientDBCredentials.username, 9 | orientDBCredentials.password) 10 | private var connection: ODatabaseDocument = _ 11 | 12 | def getConnection(): ODatabaseDocument = { 13 | connection = db.acquire() 14 | connection 15 | } 16 | 17 | def closeConnection(): Unit = { 18 | connection.close() 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/OrientDBCredentials.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | trait OrientDBCredentials extends Serializable { 4 | var dbUrl: String = null 5 | var username: String = null 6 | var password: String = null 7 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/OrientDBDocumentWrapper.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import java.util 4 | import java.util.function.Consumer 5 | 6 | import com.orientechnologies.orient.core.db.document.ODatabaseDocument 7 | import com.orientechnologies.orient.core.metadata.schema.OType 8 | import com.orientechnologies.orient.core.record.OElement 9 | import com.orientechnologies.orient.core.record.impl.ODocument 10 | import com.orientechnologies.orient.core.sql.executor.OResult 11 | import com.orientechnologies.orient.core.sql.query.OSQLSynchQuery 12 | import com.orientechnologies.orient.core.tx.OTransaction.TXTYPE 13 | import org.apache.spark.orientdb.documents.Parameters.MergedParameters 14 | import org.apache.spark.orientdb.udts._ 15 | import org.apache.spark.sql.types._ 16 | 17 | import scala.collection.JavaConversions._ 18 | 19 | class OrientDBDocumentWrapper extends Serializable { 20 | private var connectionPool: Option[OrientDBClientFactory] = None 21 | private var connection: ODatabaseDocument = _ 22 | 23 | def openTransaction(): Unit = { 24 | connection.begin(TXTYPE.OPTIMISTIC) 25 | } 26 | 27 | def commitTransaction(): Unit = { 28 | connection.commit() 29 | } 30 | 31 | def rollbackTransaction(): Unit = { 32 | connection.rollback() 33 | } 34 | 35 | /** 36 | * Get instance of Database connection 37 | * @return Database connection instance 38 | */ 39 | def getConnection(params: MergedParameters): ODatabaseDocument = { 40 | try { 41 | if (connectionPool.isEmpty) { 42 | connectionPool = Some(new OrientDBClientFactory(new OrientDBCredentials { 43 | this.dbUrl = params.dbUrl.get 44 | this.username = params.credentials.get._1 45 | this.password = params.credentials.get._2 
46 | })) 47 | } 48 | connection = connectionPool.get.getConnection() 49 | connection 50 | } catch { 51 | case e: Exception => throw new RuntimeException(s"Connection Exception Occurred: ${e.getMessage}") 52 | } 53 | } 54 | 55 | /** 56 | * Check if class already exists in Orient DB 57 | * @param classname cluster name in OrientDB 58 | * @return true/false 59 | */ 60 | def doesClassExists(classname: String): Boolean = { 61 | val schema = connection.getMetadata.getSchema 62 | schema.existsClass(classname) 63 | } 64 | 65 | /** 66 | * Create API 67 | * @param cluster cluster name in OrientDB 68 | * @param classname class name in OrientDB 69 | * @param document document to be created 70 | */ 71 | def create(cluster: String, classname: String, document: ODocument): Boolean = { 72 | if (connection.save(document, cluster) != null) 73 | return true 74 | false 75 | } 76 | 77 | /** 78 | * Read API 79 | * @param clusters cluster name in OrientDB 80 | * @param classname class name in OrientDB 81 | * @param filters filters the no. of records retrieved 82 | * @return 83 | */ 84 | def read(clusters: List[String], classname: String, requiredColumns: Array[String], 85 | filters: String, query: String = null): List[ODocument] = { 86 | var documents: java.util.List[ODocument] = new util.ArrayList[ODocument]() 87 | val columns = requiredColumns.mkString(", ") 88 | 89 | if (query == null) { 90 | val oResultSet = connection.query(s"select $columns from $classname $filters") 91 | oResultSet.map(_.toElement).foreach { 92 | case oDocument: ODocument => 93 | documents.add(oDocument) 94 | case _ => 95 | throw new RuntimeException("Result is not of type ODocument") 96 | } 97 | } else { 98 | var queryStr = "" 99 | 100 | if (filters != "") { 101 | if (query.contains("WHERE ")) { 102 | val parts = query.split("WHERE ") 103 | 104 | if ( parts.size > 1) { 105 | val firstpart = parts(0) 106 | val secondpart = parts(1) 107 | 108 | queryStr = s"$firstpart $filters and $secondpart" 109 | } else { 110 | queryStr = s"$query $filters" 111 | } 112 | } else if (query.contains("where ")) { 113 | val parts = query.split("where ") 114 | 115 | if ( parts.size > 1) { 116 | val firstpart = parts(0) 117 | val secondpart = parts(1) 118 | 119 | queryStr = s"$firstpart $filters and $secondpart" 120 | } else { 121 | queryStr = s"$query $filters" 122 | } 123 | } else { 124 | queryStr = s"$query $filters" 125 | } 126 | } else queryStr = query 127 | val oResultSet = connection.query(queryStr) 128 | oResultSet.map(_.toElement).foreach { 129 | case oDocument: ODocument => 130 | documents.add(oDocument) 131 | case _ => 132 | throw new RuntimeException("Result is not of type ODocument") 133 | } 134 | } 135 | 136 | documents.toList 137 | } 138 | 139 | /** 140 | * Bulk Create API 141 | * @param cluster cluster name in OrientDB 142 | * @param classname class name in OrientDB 143 | * @param documents List of documents to be created 144 | */ 145 | def bulkCreate(cluster: String, classname: String, documents: List[ODocument]): Boolean = { 146 | try { 147 | openTransaction() 148 | documents.foreach(document => { 149 | if (!create(cluster, classname, document)) { 150 | rollbackTransaction() 151 | return false 152 | } 153 | 154 | }) 155 | commitTransaction() 156 | true 157 | } catch { 158 | case e: Exception => 159 | rollbackTransaction() 160 | throw new RuntimeException("An exception was thrown: " + e.getMessage) 161 | } 162 | } 163 | 164 | /** 165 | * Update API 166 | * @param cluster cluster name in OrientDB 167 | * @param classname class name in OrientDB 
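* @note filter values are always bound as quoted strings (see the "handle int s" comment below), so non-string filter fields may not match as expected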
168 | * @param record new record to be updated in the format (fieldName, (fieldValue, fieldType)) 169 | * @param filter filter which filters records to be updated in the format (fieldName, (filterOperator, fieldValue)) 170 | */ 171 | def update(cluster: String, classname: String, record: Map[String, Tuple2[String, OType]], 172 | filter: Map[String, Tuple2[String, String]]): Boolean = { 173 | var filterStr = "" 174 | 175 | val filterLength = filter.size 176 | var count = 1 177 | filter.foreach(p => { 178 | if (count == filterLength) 179 | // handle int s, only strings are handled 180 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "'" 181 | else 182 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "' and " 183 | count += 1 184 | }) 185 | 186 | val documentsToBeUpdated: java.util.List[ODocument] = new util.ArrayList[ODocument]() 187 | val oResultSet = connection.query("select * from " + classname + " where " + filterStr) 188 | 189 | oResultSet.map(_.toElement).foreach { 190 | case oDocument: ODocument => 191 | documentsToBeUpdated.add(oDocument) 192 | case _ => 193 | throw new RuntimeException("Result is not of type ODocument") 194 | } 195 | 196 | try { 197 | openTransaction() 198 | documentsToBeUpdated.foreach(document => { 199 | record.foreach(field => { 200 | document.field(field._1, field._2._1, field._2._2) 201 | }) 202 | if (!create(cluster, classname, document)) { 203 | rollbackTransaction() 204 | return false 205 | } 206 | }) 207 | commitTransaction() 208 | true 209 | } catch { 210 | case e: Exception => 211 | rollbackTransaction() 212 | throw new RuntimeException("An exception was thrown: " + e.getMessage) 213 | } 214 | } 215 | 216 | /** 217 | * Delete API 218 | * @param cluster cluster name in OrientDB 219 | * @param classname class name in OrientDB 220 | * @param filter filter which filters records to be deleted in the format (fieldName, (filterOperator, fieldValue)) 221 | */ 222 | def delete(cluster: String, classname: String, filter: Map[String, Tuple2[String, String]]): Boolean = { 223 | try { 224 | var filterStr = "" 225 | 226 | if (filter != null) { 227 | val filterLength = filter.size 228 | var count = 1 229 | filter.foreach(p => { 230 | if (count == filterLength) 231 | // handle int s, only strings are handled 232 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "'" 233 | else 234 | filterStr += p._1 + " " + p._2._1 + " '" + p._2._2 + "' and " 235 | count += 1 236 | }) 237 | } 238 | 239 | val documentsToBeDeleted: java.util.List[ODocument] = new util.ArrayList[ODocument]() 240 | if (filterStr != "") { 241 | val oResultSet = connection.query("select * from " + classname + " where " + filterStr) 242 | oResultSet.map(_.toElement).foreach { 243 | case oDocument: ODocument => 244 | documentsToBeDeleted.add(oDocument) 245 | case _ => 246 | throw new RuntimeException("Result is not of type ODocument") 247 | } 248 | } else { 249 | val oResultSet = connection.query("select * from " + classname) 250 | oResultSet.map(_.toElement).foreach { 251 | case oDocument: ODocument => 252 | documentsToBeDeleted.add(oDocument) 253 | case _ => 254 | throw new RuntimeException("Result is not of type ODocument") 255 | } 256 | } 257 | 258 | openTransaction() 259 | documentsToBeDeleted.foreach(document => { 260 | if (connection.delete(document) == null) { 261 | rollbackTransaction() 262 | return false 263 | } 264 | }) 265 | commitTransaction() 266 | true 267 | } catch { 268 | case e: Exception => 269 | rollbackTransaction() 270 | throw new RuntimeException("An exception was thrown: 
" + e.getMessage) 271 | } 272 | } 273 | 274 | /** 275 | * Resolve OrientDB class metadata 276 | * @param cluster cluster name in OrientDB 277 | * @param classname class name in OrientDB 278 | * @return 279 | */ 280 | def resolveTable(cluster: String, classname: String): StructType = { 281 | val oClass = connection.getMetadata.getSchema.getClass(classname) 282 | 283 | val properties = oClass.properties() 284 | val ncols = properties.size() 285 | val fields = new Array[StructField](ncols) 286 | val iterator = properties.iterator() 287 | var i = 0 288 | while (iterator.hasNext) { 289 | val property = iterator.next() 290 | val columnName = property.getName 291 | // there are no keys..hence for now every field is nullable 292 | val nullable = true 293 | val columnType = getCatalystType(property.getType) 294 | fields(i) = StructField(columnName, columnType, nullable) 295 | i = i + 1 296 | } 297 | new StructType(fields) 298 | } 299 | 300 | /** 301 | * execute generic query on OrientDB 302 | */ 303 | def genericQuery(query: String): List[ODocument] = { 304 | val documents: java.util.List[ODocument] = connection.query(new OSQLSynchQuery[ODocument](query)) 305 | documents.toList 306 | } 307 | 308 | private def getCatalystType(oType: OType): DataType = { 309 | val dataType = oType match { 310 | case OType.BOOLEAN => BooleanType 311 | case OType.INTEGER => IntegerType 312 | case OType.SHORT => ShortType 313 | case OType.LONG => LongType 314 | case OType.FLOAT => FloatType 315 | case OType.DOUBLE => DoubleType 316 | case OType.DATETIME => TimestampType 317 | case OType.STRING => StringType 318 | case OType.BINARY => BinaryType 319 | case OType.BYTE => ByteType 320 | case OType.DATE => DateType 321 | case OType.DECIMAL => DecimalType(38, 18) 322 | case OType.EMBEDDEDLIST => new EmbeddedListType 323 | case OType.EMBEDDEDSET => new EmbeddedSetType 324 | case OType.EMBEDDEDMAP => new EmbeddedMapType 325 | case OType.LINKLIST => new LinkListType 326 | case OType.LINKSET => new LinkSetType 327 | case OType.LINKMAP => new LinkMapType 328 | case OType.LINKBAG => new LinkBagType 329 | case OType.LINK => new LinkType 330 | case OType.ANY => null 331 | } 332 | dataType 333 | } 334 | } 335 | 336 | private[orientdb] object DefaultOrientDBDocumentWrapper extends OrientDBDocumentWrapper -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/OrientDBRelation.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.record.impl.ODocument 4 | import Parameters.MergedParameters 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan} 7 | import org.apache.spark.sql.types.StructType 8 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} 9 | import org.slf4j.LoggerFactory 10 | 11 | private[orientdb] case class OrientDBRelation( 12 | orientDBWrapper: OrientDBDocumentWrapper, 13 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory, 14 | params: MergedParameters, 15 | userSchema: Option[StructType] 16 | ) (@transient val sqlContext: SQLContext) 17 | extends BaseRelation 18 | with PrunedFilteredScan 19 | with InsertableRelation { 20 | private val log = LoggerFactory.getLogger(getClass) 21 | 22 | // any kind of assertion 23 | 24 | private val tableNameOrSubQuery = params.query.map(q => 
s"($q)").orElse(params.table.map(_.toString)).get 25 | 26 | override lazy val schema: StructType = { 27 | userSchema.getOrElse{ 28 | val tableName = params.table.map(_.toString).get 29 | val conn = orientDBWrapper.getConnection(params) 30 | 31 | try { 32 | orientDBWrapper.resolveTable(null, tableName) 33 | } finally { 34 | conn.close() 35 | } 36 | } 37 | } 38 | 39 | override def toString: String = s"""OrientDBRelation($tableNameOrSubQuery)""" 40 | 41 | override def insert(data: DataFrame, overwrite: Boolean): Unit = { 42 | val saveMode = if (overwrite) { 43 | SaveMode.Overwrite 44 | } else { 45 | SaveMode.Append 46 | } 47 | val writer = new OrientDBWriter(orientDBWrapper, orientDBClientFactory) 48 | writer.saveToOrientDB(data, saveMode, params) 49 | } 50 | 51 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { 52 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined) 53 | } 54 | 55 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 56 | if (requiredColumns.isEmpty) { 57 | val whereClause = FilterPushdown.buildWhereClause(schema, filters) 58 | var countQuery = s"select count(*) from $tableNameOrSubQuery $whereClause" 59 | 60 | if (params.query.isDefined) { 61 | countQuery = tableNameOrSubQuery.drop(1).dropRight(1) 62 | } 63 | 64 | log.info("count query") 65 | val connection = orientDBWrapper.getConnection(params) 66 | 67 | try { 68 | // todo use future 69 | val results = orientDBWrapper.genericQuery(countQuery) 70 | if (params.query.isEmpty && results.nonEmpty) { 71 | val numRows: Long = results.head.field("count") 72 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 73 | val emptyRow = Row.empty 74 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) 75 | } else if (params.query.isDefined) { 76 | val numRows: Long = results.length 77 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 78 | val emptyRow = Row.empty 79 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) 80 | } else { 81 | throw new IllegalStateException("Cannot read count from OrientDB") 82 | } 83 | } finally { 84 | connection.close() 85 | } 86 | } else { 87 | var classname: String = null 88 | var clusters: List[String] = null 89 | if (params.query.isEmpty) { 90 | classname = params.className match { 91 | case Some(className) => className 92 | case None => 93 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Class " + 94 | "name with the 'classname' parameter") 95 | } 96 | clusters = params.clusterNames match { 97 | case Some(clusterName) => clusterName 98 | case None => 99 | val connection = orientDBWrapper.getConnection(params) 100 | val schema = connection.getMetadata.getSchema 101 | val currClass = schema.getClass(classname) 102 | List(connection.getClusterNameById(currClass.getDefaultClusterId)) 103 | } 104 | } 105 | 106 | val filterStr = FilterPushdown.buildWhereClause(schema, filters) 107 | val connection = orientDBWrapper.getConnection(params) 108 | var oDocuments: List[ODocument] = List() 109 | try { 110 | // todo use Future 111 | if (params.query.isEmpty) { 112 | oDocuments = orientDBWrapper.read(clusters, classname, requiredColumns, filterStr, null) 113 | } else { 114 | oDocuments = orientDBWrapper 115 | .read(null, null, requiredColumns, filterStr, params.query.get) 116 | } 117 | } finally { 118 | connection.close() 119 | } 120 | 121 | if 
(params.query.isEmpty) { 122 | val prunedSchema = pruneSchema(schema, requiredColumns) 123 | sqlContext.sparkContext.makeRDD( 124 | oDocuments.map(oDocument => Conversions.convertODocumentsToRows(oDocument, prunedSchema)) 125 | ) 126 | } else { 127 | assert(oDocuments.nonEmpty) 128 | val prunedSchema = pruneSchema(schema, chooseRecordForSchema(oDocuments).fieldNames()) 129 | sqlContext.sparkContext.makeRDD( 130 | oDocuments.map(oDocument => Conversions.convertODocumentsToRows(oDocument, prunedSchema)) 131 | ) 132 | } 133 | } 134 | } 135 | 136 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { 137 | new StructType(schema.fields.filter(p => columns.contains(p.name))) 138 | } 139 | 140 | private def chooseRecordForSchema(oDocuments: List[ODocument]): ODocument = { 141 | var maxLen = -1 142 | var idx: ODocument = null 143 | oDocuments.foreach(oDocument => 144 | if (maxLen < oDocument.fieldNames().length) { 145 | idx = oDocument 146 | maxLen = oDocument.fieldNames().length 147 | }) 148 | idx 149 | } 150 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/OrientDBWriter.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.db.document.{ODatabaseDocument, ODatabaseDocumentPool, ODatabaseDocumentTxPooled} 4 | import Parameters.MergedParameters 5 | import org.apache.spark.sql.{DataFrame, Row, SaveMode} 6 | import org.slf4j.LoggerFactory 7 | 8 | private[orientdb] class OrientDBWriter(orientDBWrapper: OrientDBDocumentWrapper, 9 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory) 10 | extends Serializable { 11 | private val log = LoggerFactory.getLogger(getClass) 12 | 13 | private[orientdb] def createOrientDBClass(data: DataFrame, params: MergedParameters): Unit = { 14 | val dfSchema = data.schema 15 | val classname = params.className match { 16 | case Some(className) => className 17 | case None => throw new IllegalArgumentException("For save operations you must specify a OrientDB Class " + 18 | "name with the 'classname' parameter") 19 | } 20 | var clusters = params.clusterNames match { 21 | case Some(clusterNames) => clusterNames 22 | case None => null 23 | } 24 | 25 | val connector = orientDBWrapper.getConnection(params) 26 | val schema = connector.getMetadata.getSchema 27 | val createdClass = schema.createClass(classname) 28 | 29 | dfSchema.foreach(field => { 30 | if (params.linkedType.nonEmpty && params.linkedType.get.exists(linkType => linkType._1.equals(field.name))) { 31 | createdClass.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType), 32 | connector.getMetadata.getSchema.getClass(params.linkedType.get(field.name).split("-").last)) 33 | } else { 34 | createdClass.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType)) 35 | } 36 | }) 37 | 38 | if (clusters != null) { 39 | clusters.foreach(cluster => createdClass.addCluster(cluster)) 40 | } else { 41 | clusters = List(connector.getClusterNameById(createdClass.getDefaultClusterId)) 42 | } 43 | } 44 | 45 | private[orientdb] def dropOrientDBClass(params: MergedParameters): Unit = { 46 | val connection = orientDBWrapper.getConnection(params) 47 | 48 | // create class if not exists 49 | val classname = params.className 50 | if (classname.isEmpty) { 51 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Class 
" + 52 | "name with the 'classname' parameter") 53 | } 54 | var cluster = params.clusterNames 55 | 56 | // Todo use Future 57 | if (connection.getMetadata.getSchema.existsClass(classname.get)) { 58 | orientDBWrapper.delete(null, classname.get, null) 59 | 60 | val schema = connection.getMetadata.getSchema 61 | schema.dropClass(classname.get) 62 | } 63 | } 64 | 65 | private def doOrientDBLoad(connection: ODatabaseDocument, 66 | data: DataFrame, 67 | params: MergedParameters): Unit = { 68 | // create class if not exists 69 | val classname = params.className 70 | if (classname.isEmpty) { 71 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Class " + 72 | "name with the 'classname' parameter") 73 | } 74 | 75 | val schema = connection.getMetadata.getSchema 76 | 77 | if (!schema.existsClass(classname.get)) { 78 | createOrientDBClass(data, params) 79 | } 80 | 81 | var clusters = params.clusterNames 82 | if (clusters.isEmpty) { 83 | val schema = connection.getMetadata.getSchema 84 | val currClass = schema.getClass(classname.get) 85 | clusters = Some(List(connection.getClusterNameById(currClass.getDefaultClusterId))) 86 | } 87 | 88 | // Todo use future 89 | // load data into Orient DB 90 | try { 91 | data.foreachPartition((rows: Iterator[Row]) => { 92 | val ownerPool = new ODatabaseDocumentPool(params.dbUrl.get, 93 | params.credentials.get._1, params.credentials.get._2) 94 | val connection = new ODatabaseDocumentTxPooled(ownerPool, params.dbUrl.get, 95 | params.credentials.get._1, params.credentials.get._2) 96 | val rowsList = rows.toList 97 | val rowsPerCluster = rowsList.length % clusters.get.length match { 98 | case 0 => rowsList.length / clusters.get.length 99 | case _ => (rowsList.length / clusters.get.length) + 1 100 | } 101 | 102 | var countPerCluster = 0 103 | var clusterIdx = 0 104 | rowsList.foreach { row => 105 | connection.save(Conversions.convertRowsToODocuments(row), clusters.get(clusterIdx)) 106 | countPerCluster = countPerCluster + 1 107 | 108 | if (countPerCluster % rowsPerCluster == 0) { 109 | countPerCluster = 0 110 | clusterIdx = clusterIdx + 1 111 | } 112 | } 113 | 114 | connection.close() 115 | }) 116 | } catch { 117 | case e: Exception => 118 | throw new RuntimeException("An exception was thrown: " + e.getMessage) 119 | } 120 | } 121 | 122 | def saveToOrientDB(data: DataFrame, saveMode: SaveMode, params: MergedParameters): Unit = { 123 | val connection = orientDBWrapper.getConnection(params) 124 | try { 125 | if (saveMode == SaveMode.Overwrite) { 126 | dropOrientDBClass(params) 127 | } 128 | doOrientDBLoad(connection, data, params) 129 | } finally { 130 | connection.close() 131 | } 132 | } 133 | } 134 | 135 | object DefaultOrientDBWriter extends OrientDBWriter( 136 | DefaultOrientDBDocumentWrapper, 137 | orientDBCredemtials => new OrientDBClientFactory(orientDBCredemtials)) -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/Parameters.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | private[orientdb] object Parameters { 4 | val DEFAULT_PARAMETERS: Map[String, String] = Map( 5 | "overwrite" -> "false" 6 | ) 7 | 8 | def mergeParameters(userParameters: Map[String, String]): MergedParameters = { 9 | if (!userParameters.contains("dburl")) { 10 | throw new IllegalArgumentException("A Orient DB URL must be provided with 'dburl' parameter") 11 | } 12 | 13 | if 
(!userParameters.contains("class") && !userParameters.contains("query")) { 14 | throw new IllegalArgumentException("You must specify Orient DB class name with 'class' " + 15 | "parameter or a query with the 'query' parameter. If it is a 'query' you must define your own" + 16 | " schema or specify orientdb 'class' name.") 17 | } 18 | 19 | if (!userParameters.contains("user") || !userParameters.contains("password")) { 20 | throw new IllegalArgumentException("A Orient DB username & password must be provided" + 21 | " with 'user' & 'password' parameters respectively") 22 | } 23 | MergedParameters(DEFAULT_PARAMETERS ++ userParameters) 24 | } 25 | 26 | case class MergedParameters(parameters: Map[String, String]) { 27 | 28 | /** 29 | * OrientDB Table from where to load and write data. 30 | */ 31 | def table: Option[TableName] = parameters.get("class").map(_.trim).flatMap(dbTable => { 32 | /** 33 | * this case is going to be handled in query variable. 34 | */ 35 | if (dbTable.startsWith("(") && dbTable.endsWith(")")) 36 | None 37 | else 38 | Some(TableName(dbTable)) 39 | }) 40 | 41 | /** 42 | * The OrientDB query to be used when loading data 43 | */ 44 | def query: Option[String] = parameters.get("query").orElse({ 45 | parameters.get("class") 46 | .map(_.trim) 47 | .filter(t => t.startsWith("(") && t.endsWith(")")) 48 | .map(t => t.drop(1).dropRight(1)) 49 | }) 50 | 51 | /** 52 | * Username & Password for authenticating with OrientDB. 53 | */ 54 | def credentials: Option[(String, String)] = { 55 | for { 56 | username <- parameters.get("user") 57 | password <- parameters.get("password") 58 | } yield (username, password) 59 | } 60 | 61 | /** 62 | * A url in the format 63 | * 64 | * remote::/ 65 | */ 66 | def dbUrl: Option[String] = parameters.get("dburl") 67 | 68 | /** 69 | * class name in Orient DB 70 | */ 71 | def className: Option[String] = parameters.get("class") 72 | 73 | /** 74 | * cluster name in Orient DB 75 | */ 76 | def clusterNames: Option[List[String]] = 77 | parameters.get("clusters").map(_.split(",").toList) 78 | 79 | def linkedType: Option[Map[String, String]] = 80 | if (parameters.exists(paramPair => paramPair._2.contains("linkedType"))) 81 | Some(parameters.filter(paramPair => paramPair._2.contains("linkedType"))) 82 | else None 83 | } 84 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/documents/TableName.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | private[orientdb] case class TableName(var unescapedTableName: String) { 4 | /** 5 | * drop the quotes from the two ends 6 | */ 7 | if (unescapedTableName.startsWith("\"") && unescapedTableName.endsWith("\"")) 8 | unescapedTableName = unescapedTableName.drop(1).dropRight(1) 9 | 10 | private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"' 11 | def escapedTableName: String = unescapedTableName 12 | 13 | override def toString: String = s"$escapedTableName" 14 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/graphs/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} 4 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider} 5 | import 
org.apache.spark.sql.types.StructType 6 | import org.slf4j.LoggerFactory 7 | 8 | class DefaultSource( orientDBGraphVertexWrapper: OrientDBGraphVertexWrapper, 9 | orientDBGraphEdgeWrapper: OrientDBGraphEdgeWrapper, 10 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory) 11 | extends RelationProvider 12 | with SchemaRelationProvider 13 | with CreatableRelationProvider { 14 | private val log = LoggerFactory.getLogger(getClass) 15 | 16 | def this() = this(DefaultOrientDBGraphVertexWrapper, DefaultOrientDBGraphEdgeWrapper, 17 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials)) 18 | 19 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { 20 | val params = Parameters.mergeParameters(parameters) 21 | 22 | if (params.query.isDefined && (params.vertexType.isEmpty && params.edgeType.isEmpty)) { 23 | throw new IllegalArgumentException("Along with the 'query' parameter you must specify either 'vertextype' parameter or" + 24 | " 'edgetype' parameter or user-defined Schema") 25 | } 26 | 27 | if (params.vertexType.isDefined && params.edgeType.isEmpty) { 28 | OrientDBVertexRelation(orientDBGraphVertexWrapper, orientDBClientFactory, params, None)(sqlContext) 29 | } else { 30 | OrientDBEdgeRelation(orientDBGraphEdgeWrapper, orientDBClientFactory, params, None)(sqlContext) 31 | } 32 | } 33 | 34 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], 35 | schema: StructType): BaseRelation = { 36 | val params = Parameters.mergeParameters(parameters) 37 | 38 | if (params.vertexType.isDefined && params.edgeType.isEmpty) { 39 | OrientDBVertexRelation(orientDBGraphVertexWrapper, orientDBClientFactory, params, Some(schema))(sqlContext) 40 | } else { 41 | OrientDBEdgeRelation(orientDBGraphEdgeWrapper, orientDBClientFactory, params, Some(schema))(sqlContext) 42 | } 43 | } 44 | 45 | override def createRelation(sqlContext: SQLContext, mode: SaveMode, 46 | parameters: Map[String, String], data: DataFrame): BaseRelation = { 47 | val params = Parameters.mergeParameters(parameters) 48 | 49 | val vertexType = params.vertexType 50 | val edgeType = params.edgeType 51 | 52 | if (vertexType.isEmpty && edgeType.isEmpty) { 53 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Graph Vertex" + 54 | " or Edge type with the 'vertextype' & 'edgetype' parameter respectively") 55 | } 56 | 57 | 58 | def tableExists: Boolean = { 59 | if (vertexType.isDefined && edgeType.isEmpty) { 60 | try { 61 | orientDBGraphVertexWrapper.doesVertexTypeExists(vertexType.get) 62 | } finally { 63 | orientDBGraphEdgeWrapper.close() 64 | } 65 | } 66 | else { 67 | try { 68 | orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType.get) 69 | } finally { 70 | orientDBGraphEdgeWrapper.close() 71 | } 72 | } 73 | } 74 | 75 | if (vertexType.isDefined && edgeType.isEmpty) { 76 | val (doSave, dropExisting) = mode match { 77 | case SaveMode.Append => (true, false) 78 | case SaveMode.Overwrite => (true, true) 79 | case SaveMode.ErrorIfExists => 80 | if (tableExists) { 81 | sys.error(s"Vertex type $vertexType already exists! (SaveMode is set to ErrorIfExists)") 82 | } else { 83 | (true, false) 84 | } 85 | case SaveMode.Ignore => 86 | if (tableExists) { 87 | log.info(s"Vertex Type $vertexType already exists. 
Ignoring save requests.") 88 | (false, false) 89 | } else { 90 | (true, false) 91 | } 92 | } 93 | 94 | if (doSave) { 95 | val updatedParams = parameters.updated("overwrite", dropExisting.toString) 96 | new OrientDBVertexWriter(orientDBGraphVertexWrapper, orientDBClientFactory) 97 | .saveToOrientDB(data, mode, Parameters.mergeParameters(updatedParams)) 98 | } 99 | createRelation(sqlContext, parameters) 100 | } else { 101 | if (!parameters.contains("vertextype") && parameters.contains("edgetype")) { 102 | throw new IllegalArgumentException("You must specify the Orient DB Vertex type in the 'vertextype'" + 103 | " parameter along with Orient DB Edge type in the 'edgetype' parameter") 104 | } 105 | 106 | try { 107 | val connection = orientDBGraphEdgeWrapper.getConnection(params) 108 | orientDBGraphEdgeWrapper.doesEdgeTypeExists(edgeType.get) 109 | } finally { 110 | orientDBGraphEdgeWrapper.close() 111 | } 112 | 113 | val (doSave, dropExisting) = mode match { 114 | case SaveMode.Append => (true, false) 115 | case SaveMode.Overwrite => (true, true) 116 | case SaveMode.ErrorIfExists => 117 | if (tableExists) { 118 | sys.error(s"Edge Type $edgeType already exists! (SaveMode is set to ErrorIfExists)") 119 | } else { 120 | (true, false) 121 | } 122 | case SaveMode.Ignore => 123 | if (tableExists) { 124 | log.info(s"Edge Type $edgeType already exists. Ignoring save requests.") 125 | (false, false) 126 | } else { 127 | (true, false) 128 | } 129 | } 130 | 131 | if (doSave) { 132 | val updatedParams = parameters.updated("overwrite", dropExisting.toString) 133 | new OrientDBEdgeWriter(orientDBGraphEdgeWrapper, orientDBClientFactory) 134 | .saveToOrientDB(data, mode, Parameters.mergeParameters(updatedParams)) 135 | } 136 | createRelation(sqlContext, parameters) 137 | } 138 | } 139 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/graphs/OrientDBClientFactory.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import com.tinkerpop.blueprints.impls.orient.{OrientGraphFactory, OrientGraphNoTx} 4 | 5 | class OrientDBClientFactory(orientDBCredentials: OrientDBCredentials) extends Serializable { 6 | private val db: OrientGraphFactory = 7 | new OrientGraphFactory(orientDBCredentials.dbUrl, orientDBCredentials.username, 8 | orientDBCredentials.password) 9 | private var connection: OrientGraphNoTx = _ 10 | 11 | def getConnection(): OrientGraphNoTx = { 12 | connection = db.getNoTx 13 | connection 14 | } 15 | 16 | def closeConnection(): Unit = { 17 | db.close() 18 | } 19 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/graphs/OrientDBCredentials.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | trait OrientDBCredentials extends Serializable { 4 | var dbUrl: String = null 5 | var username: String = null 6 | var password: String = null 7 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/graphs/OrientDBRelation.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import com.tinkerpop.blueprints.{Edge, Vertex} 4 | import org.apache.spark.orientdb.documents.{Conversions, FilterPushdown} 5 | import 
org.apache.spark.orientdb.graphs.Parameters.MergedParameters 6 | import org.apache.spark.rdd.RDD 7 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} 8 | import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan} 9 | import org.apache.spark.sql.types.StructType 10 | import org.slf4j.LoggerFactory 11 | 12 | private[orientdb] case class OrientDBVertexRelation( 13 | orientDBVertexWrapper: OrientDBGraphVertexWrapper, 14 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory, 15 | params: MergedParameters, 16 | userSchema: Option[StructType] 17 | ) (@transient val sqlContext: SQLContext) 18 | extends BaseRelation 19 | with PrunedFilteredScan 20 | with InsertableRelation { 21 | private val log = LoggerFactory.getLogger(getClass) 22 | 23 | private val tableNameOrSubQuery = params.query.map(q => s"($q)").orElse(params.vertexType.map(_.toString)).get 24 | 25 | override lazy val schema: StructType = { 26 | userSchema.getOrElse{ 27 | val tableName = params.vertexType.map(_.toString).get 28 | val conn = orientDBVertexWrapper.getConnection(params) 29 | 30 | try { 31 | orientDBVertexWrapper.resolveTable(tableName) 32 | } finally { 33 | orientDBVertexWrapper.close() 34 | } 35 | } 36 | } 37 | 38 | override def toString: String = s"""OrientDBRelation($tableNameOrSubQuery)""" 39 | 40 | override def insert(data: DataFrame, overwrite: Boolean): Unit = { 41 | val saveMode = if (overwrite) { 42 | SaveMode.Overwrite 43 | } else { 44 | SaveMode.Append 45 | } 46 | val writer = new OrientDBVertexWriter(orientDBVertexWrapper, 47 | orientDBClientFactory) 48 | writer.saveToOrientDB(data, saveMode, params) 49 | } 50 | 51 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { 52 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined) 53 | } 54 | 55 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 56 | if (requiredColumns.isEmpty) { 57 | val whereClause = FilterPushdown.buildWhereClause(schema, filters) 58 | var countQuery = s"select count(*) from $tableNameOrSubQuery $whereClause" 59 | 60 | if (params.query.isDefined) { 61 | countQuery = tableNameOrSubQuery.drop(1).dropRight(1) 62 | } 63 | 64 | log.info("count query") 65 | orientDBVertexWrapper.getConnection(params) 66 | 67 | try { 68 | val results = orientDBVertexWrapper.genericQuery(countQuery) 69 | if (params.query.isEmpty && results.nonEmpty) { 70 | val numRows: Long = results.head.getProperty[Long]("count") 71 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 72 | val emptyRow = Row.empty 73 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) 74 | } else if (params.query.isDefined) { 75 | val numRows: Long = results.length 76 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 77 | val emptyRow = Row.empty 78 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) 79 | } else { 80 | throw new IllegalArgumentException("Cannot read count for OrientDB graph vertices") 81 | } 82 | } finally { 83 | orientDBVertexWrapper.close() 84 | } 85 | } else { 86 | var vertexTypeName: String = null 87 | if (params.query.isEmpty) { 88 | vertexTypeName = params.vertexType match { 89 | case Some(vertexType) => vertexType 90 | case None => 91 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Graph" + 92 | " Vertex type name with the 'vertextype' 
parameter") 93 | } 94 | } 95 | 96 | val filterStr = FilterPushdown.buildWhereClause(schema, filters) 97 | var oVertices: List[Vertex] = List() 98 | try { 99 | orientDBVertexWrapper.getConnection(params) 100 | if (params.query.isEmpty) { 101 | oVertices = orientDBVertexWrapper.read(vertexTypeName, requiredColumns, 102 | filterStr, null) 103 | } else { 104 | oVertices = orientDBVertexWrapper.read(null, requiredColumns, filterStr, 105 | params.query.get) 106 | } 107 | } finally { 108 | orientDBVertexWrapper.close() 109 | } 110 | 111 | if (params.query.isEmpty) { 112 | val prunedSchema = pruneSchema(schema, requiredColumns) 113 | sqlContext.sparkContext.makeRDD( 114 | oVertices.map(vertex => Conversions.convertVerticesToRows(vertex, prunedSchema)) 115 | ) 116 | } else { 117 | assert(oVertices.nonEmpty) 118 | val propKeysArray = new Array[String](chooseRecordForSchema(oVertices).getPropertyKeys.size()) 119 | val prunedSchema = pruneSchema(schema, chooseRecordForSchema(oVertices) 120 | .getPropertyKeys.toArray[String](propKeysArray)) 121 | sqlContext.sparkContext.makeRDD( 122 | oVertices.map(vertex => Conversions.convertVerticesToRows(vertex, prunedSchema)) 123 | ) 124 | } 125 | } 126 | } 127 | 128 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { 129 | new StructType(schema.fields.filter(p => columns.contains(p.name))) 130 | } 131 | 132 | private def chooseRecordForSchema(oVertices: List[Vertex]): Vertex = { 133 | var maxLen = -1 134 | var idx: Vertex = null 135 | oVertices.foreach(oVertex => { 136 | if (maxLen < oVertex.getPropertyKeys.size()) { 137 | idx = oVertex 138 | maxLen = oVertex.getPropertyKeys.size() 139 | } 140 | }) 141 | idx 142 | } 143 | } 144 | 145 | private[orientdb] case class OrientDBEdgeRelation( 146 | orientDBEdgeWrapper: OrientDBGraphEdgeWrapper, 147 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory, 148 | params: MergedParameters, 149 | userSchema: Option[StructType] 150 | ) (@transient val sqlContext: SQLContext) 151 | extends BaseRelation 152 | with PrunedFilteredScan 153 | with InsertableRelation { 154 | private val log = LoggerFactory.getLogger(getClass) 155 | 156 | private val tableNameOrSubQuery = params.query.map(q => s"($q)").orElse(params.edgeType.map(_.toString)).get 157 | 158 | override lazy val schema: StructType = { 159 | userSchema.getOrElse{ 160 | val tableName = params.edgeType.map(_.toString).get 161 | val conn = orientDBEdgeWrapper.getConnection(params) 162 | 163 | try { 164 | orientDBEdgeWrapper.resolveTable(tableName) 165 | } finally { 166 | orientDBEdgeWrapper.close() 167 | } 168 | } 169 | } 170 | 171 | override def toString: String = s"""OrientDBRelation($tableNameOrSubQuery)""" 172 | 173 | override def insert(data: DataFrame, overwrite: Boolean): Unit = { 174 | val saveMode = if (overwrite) { 175 | SaveMode.Overwrite 176 | } else { 177 | SaveMode.Append 178 | } 179 | val writer = new OrientDBEdgeWriter(orientDBEdgeWrapper, orientDBClientFactory) 180 | writer.saveToOrientDB(data, saveMode, params) 181 | } 182 | 183 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { 184 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined) 185 | } 186 | 187 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 188 | if (requiredColumns.isEmpty) { 189 | val whereClause = FilterPushdown.buildWhereClause(schema, filters) 190 | var countQuery = s"select count(*) from $tableNameOrSubQuery $whereClause" 191 | 192 | 
if (params.query.isDefined) { 193 | countQuery = tableNameOrSubQuery.drop(1).dropRight(1) 194 | } 195 | 196 | log.info("count query") 197 | orientDBEdgeWrapper.getConnection(params) 198 | 199 | try { 200 | val results = orientDBEdgeWrapper.genericQuery(countQuery) 201 | if (params.query.isEmpty && results.nonEmpty) { 202 | val numRows: Long = results.head.getProperty[Long]("count") 203 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 204 | val emptyRow = Row.empty 205 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) 206 | } else if (params.query.isDefined) { 207 | val numRows: Long = results.length 208 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 209 | val emptyRow = Row.empty 210 | sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) 211 | } else { 212 | throw new IllegalArgumentException("Cannot read count for OrientDB graph edges") 213 | } 214 | } finally { 215 | orientDBEdgeWrapper.close() 216 | } 217 | } else { 218 | var edgeTypeName: String = null 219 | if (params.query.isEmpty) { 220 | edgeTypeName = params.edgeType match { 221 | case Some(edgeType) => edgeType 222 | case None => throw new IllegalArgumentException("For save operations you must specify a OrientDB Graph" + 223 | " Edge type name with the 'edgetype' parameter") 224 | } 225 | } 226 | 227 | val filterStr = FilterPushdown.buildWhereClause(schema, filters) 228 | var oEdges: List[Edge] = List() 229 | 230 | try { 231 | if (params.query.isEmpty) { 232 | oEdges = orientDBEdgeWrapper.read(edgeTypeName, requiredColumns, filterStr, null) 233 | } else { 234 | oEdges = orientDBEdgeWrapper.read(null, requiredColumns, filterStr, params.query.get) 235 | } 236 | } finally { 237 | orientDBEdgeWrapper.close() 238 | } 239 | 240 | if (params.query.isEmpty) { 241 | val prunedSchema = pruneSchema(schema, requiredColumns) 242 | sqlContext.sparkContext.makeRDD( 243 | oEdges.map(edge => Conversions.convertEdgesToRows(edge, prunedSchema)) 244 | ) 245 | } else { 246 | assert(oEdges.nonEmpty) 247 | val propKeysArray = new Array[String](chooseRecordForSchema(oEdges).getPropertyKeys.size()) 248 | val prunedSchema = pruneSchema(schema, 249 | chooseRecordForSchema(oEdges).getPropertyKeys.toArray[String](propKeysArray)) 250 | sqlContext.sparkContext.makeRDD( 251 | oEdges.map(edge => Conversions.convertEdgesToRows(edge, prunedSchema)) 252 | ) 253 | } 254 | } 255 | } 256 | 257 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { 258 | new StructType(schema.fields.filter(p => columns.contains(p.name))) 259 | } 260 | 261 | private def chooseRecordForSchema(oEdges: List[Edge]): Edge = { 262 | var maxLen = -1 263 | var idx: Edge = null 264 | oEdges.foreach(oEdge => { 265 | if (maxLen < oEdge.getPropertyKeys.size()) { 266 | idx = oEdge 267 | maxLen = oEdge.getPropertyKeys.size() 268 | } 269 | }) 270 | idx 271 | } 272 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/graphs/OrientDBWriter.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import java.util 4 | 5 | import com.orientechnologies.orient.core.metadata.schema.OClass 6 | import com.orientechnologies.orient.core.sql.OCommandSQL 7 | import com.tinkerpop.blueprints.Vertex 8 | import com.tinkerpop.blueprints.impls.orient.{OrientGraphFactory, OrientGraphNoTx} 9 | import 
org.apache.spark.orientdb.documents.Conversions 10 | import org.apache.spark.orientdb.graphs.Parameters.MergedParameters 11 | import org.apache.spark.sql.{DataFrame, Row, SaveMode} 12 | import org.slf4j.LoggerFactory 13 | 14 | import scala.collection.JavaConversions._ 15 | 16 | private[orientdb] class OrientDBVertexWriter(orientDBWrapper: OrientDBGraphVertexWrapper, 17 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory) 18 | extends Serializable { 19 | 20 | private[orientdb] def createOrientDBVertex(data: DataFrame, params: MergedParameters): Unit = { 21 | val dfSchema = data.schema 22 | val vertexType = params.vertexType match { 23 | case Some(vertexTypeName) => vertexTypeName 24 | case None => throw new IllegalArgumentException("For save operations you must specify a OrientDB Vertex Type" + 25 | " with the 'vertextype' parameter") 26 | } 27 | var cluster = params.cluster match { 28 | case Some(clusterName) => clusterName 29 | case None => null 30 | } 31 | 32 | val connector = orientDBWrapper.getConnection(params) 33 | val createdVertexType = connector.createVertexType(vertexType) 34 | 35 | dfSchema.foreach(field => { 36 | if (params.linkedType.nonEmpty && params.linkedType.get.exists(linkType => linkType._1.equals(field.name))) { 37 | createdVertexType.createProperty(field.name, 38 | Conversions.sparkDTtoOrientDBDT(field.dataType), 39 | connector.getVertexType(params.linkedType.get(field.name).split("-").last)) 40 | } else { 41 | createdVertexType.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType)) 42 | } 43 | 44 | if (field.name == "id") { 45 | createdVertexType.createIndex(s"${vertexType}Idx", OClass.INDEX_TYPE.UNIQUE, field.name) 46 | } 47 | }) 48 | } 49 | 50 | private[orientdb] def dropOrientDBVertex(params: MergedParameters): Unit = { 51 | val connection = orientDBWrapper.getConnection(params) 52 | 53 | val vertexType = params.vertexType 54 | if (vertexType.isEmpty) { 55 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Vertex Type" + 56 | " with the 'vertextype' parameter") 57 | } 58 | 59 | if (connection.getVertexType(vertexType.get) != null) { 60 | orientDBWrapper.delete(vertexType.get, null) 61 | connection.dropVertexType(vertexType.get) 62 | } 63 | } 64 | 65 | private def doOrientDBVertexLoad(connection: OrientGraphNoTx, 66 | data: DataFrame, 67 | params: MergedParameters): Unit = { 68 | val vertexType = params.vertexType 69 | if (vertexType.isEmpty) { 70 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Vertex Type" + 71 | " with the 'vertextype' parameter") 72 | } 73 | 74 | if (connection.getVertexType(vertexType.get) == null) { 75 | createOrientDBVertex(data, params) 76 | } 77 | 78 | try { 79 | data.foreachPartition((rows: Iterator[Row]) => { 80 | val graphFactory = new OrientGraphFactory(params.dbUrl.get, 81 | params.credentials.get._1, 82 | params.credentials.get._2) 83 | val connection = graphFactory.getNoTx 84 | 85 | while (rows.hasNext) { 86 | val row = rows.next() 87 | 88 | val fields = row.schema.fields 89 | 90 | val fieldNames = fields.map(_.name) 91 | if (!fieldNames.contains("id")) { 92 | throw new IllegalArgumentException("'id' is a mandatory parameter " + 93 | "for creating a vertex") 94 | } 95 | val key = row.getAs[Object](fieldNames.indexOf("id")) 96 | val properties = new util.HashMap[String, Object]() 97 | properties.put("id", key) 98 | connection.setStandardElementConstraints(false) 99 | val createdVertex = 
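// the "class:<vertextype>" label creates the vertex directly in the target vertex class; the remaining row columns are copied onto it as typed properties in the loop below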
connection.addVertex(s"class:${params.vertexType.get}", properties) 100 | 101 | var count = 0 102 | while (count < fields.length) { 103 | val sparkType = fields(count).dataType 104 | val orientDBType = Conversions 105 | .sparkDTtoOrientDBDT(sparkType) 106 | createdVertex.setProperty(fields(count).name, 107 | Conversions.convertRowToGraph(row, count), orientDBType) 108 | 109 | count = count + 1 110 | } 111 | createdVertex.moveToClass(params.vertexType.get) 112 | } 113 | graphFactory.close() 114 | }) 115 | } catch { 116 | case e: Exception => 117 | throw new RuntimeException("An exception was thrown: " + e.getMessage) 118 | } 119 | } 120 | 121 | def saveToOrientDB(data: DataFrame, saveMode: SaveMode, params: MergedParameters): Unit = { 122 | val connection = orientDBWrapper.getConnection(params) 123 | try { 124 | if (saveMode == SaveMode.Overwrite) { 125 | dropOrientDBVertex(params) 126 | } 127 | doOrientDBVertexLoad(connection, data, params) 128 | } finally { 129 | orientDBWrapper.close() 130 | } 131 | } 132 | } 133 | 134 | object DefaultOrientDBVertexWriter extends OrientDBVertexWriter( 135 | DefaultOrientDBGraphVertexWrapper, 136 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials)) 137 | 138 | private[orientdb] class OrientDBEdgeWriter(orientDBWrapper: OrientDBGraphEdgeWrapper, 139 | orientDBClientFactory: OrientDBCredentials => OrientDBClientFactory) 140 | extends Serializable { 141 | private val log = LoggerFactory.getLogger(getClass) 142 | 143 | private[orientdb] def createOrientDBEdge(data: DataFrame, params: MergedParameters): Unit = { 144 | val dfSchema = data.schema 145 | val edgeType = params.edgeType match { 146 | case Some(edgeTypeName) => edgeTypeName 147 | case None => throw new IllegalArgumentException("For save operations you must specify a OrientDB Edge Type" + 148 | " with the 'edgetype' parameter") 149 | } 150 | 151 | val connector = orientDBWrapper.getConnection(params) 152 | val createdEdgeType = connector.createEdgeType(edgeType) 153 | if (!params.lightWeightEdge) { 154 | dfSchema.foreach(field => { 155 | if (params.linkedType.nonEmpty && params.linkedType.get.exists(linkType => linkType._1.equals(field.name))) { 156 | createdEdgeType.createProperty(field.name, 157 | Conversions.sparkDTtoOrientDBDT(field.dataType), 158 | connector.getEdgeType(params.linkedType.get(field.name).split("-").last)) 159 | } 160 | createdEdgeType.createProperty(field.name, Conversions.sparkDTtoOrientDBDT(field.dataType)) 161 | }) 162 | } 163 | } 164 | 165 | private[orientdb] def dropOrientDBEdge(params: MergedParameters): Unit = { 166 | val connection = orientDBWrapper.getConnection(params) 167 | 168 | val edgeType = params.edgeType 169 | if (edgeType.isEmpty) { 170 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Edge Type" + 171 | " with the 'edgetype' parameter") 172 | } 173 | 174 | if (connection.getEdgeType(edgeType.get) != null) { 175 | orientDBWrapper.delete(edgeType.get, null) 176 | connection.dropEdgeType(edgeType.get) 177 | } 178 | } 179 | 180 | private def doOrientDBEdgeLoad(connection: OrientGraphNoTx, 181 | data: DataFrame, 182 | params: MergedParameters): Unit = { 183 | val edgeType = params.edgeType 184 | if (edgeType.isEmpty) { 185 | throw new IllegalArgumentException("For save operations you must specify a OrientDB Edge Type" + 186 | " with the 'edgetype' parameter") 187 | } 188 | 189 | if (connection.getEdgeType(edgeType.get) == null) { 190 | createOrientDBEdge(data, params) 191 | } 192 | 193 | val (inVertexType, 
outVertexType) = params.vertexType match { 194 | case Some(vertexTypeNames) => 195 | val cols = vertexTypeNames.split(",") 196 | if (cols.length == 1) (cols(0), cols(0)) 197 | else if (cols.length == 2) (cols(0), cols(1)) 198 | else throw new IllegalArgumentException("More than 2 'vertextype' specified") 199 | case None => 200 | throw new IllegalArgumentException("Saving edges also require a vertex type specified by " + 201 | "'vertextype' parameter") 202 | } 203 | 204 | try { 205 | data.foreachPartition((rows: Iterator[Row]) => { 206 | val graphFactory = new OrientGraphFactory(params.dbUrl.get, 207 | params.credentials.get._1, 208 | params.credentials.get._2) 209 | val connection = graphFactory.getNoTx 210 | 211 | while (rows.hasNext) { 212 | val row = rows.next() 213 | 214 | val fields = row.schema.fields 215 | 216 | var inVertexName: String = null 217 | try { 218 | inVertexName = row.getAs[Object](fields.map(_.name).indexOf("src")).toString 219 | } catch { 220 | case e: Exception => throw new IllegalArgumentException("'src' is a mandatory parameter " + 221 | "for creating an edge") 222 | } 223 | 224 | var outVertexName: String = null 225 | try { 226 | outVertexName = row.getAs[Object](fields.map(_.name).indexOf("dst")).toString 227 | } catch { 228 | case e: Exception => throw new IllegalArgumentException("'dst' is a mandatory parameters " + 229 | "for creating an edge") 230 | } 231 | 232 | val inVertices: List[Vertex] = connection 233 | .command(new OCommandSQL(s"select * from $inVertexType where id = '$inVertexName'")) 234 | .execute() 235 | .asInstanceOf[java.lang.Iterable[Vertex]] 236 | .toList 237 | 238 | val inVertex: Option[Vertex] = 239 | if (inVertices.isEmpty && params.createVertexIfNotExist) { 240 | // log.info(s"in Vertex $inVertexName does not exist. Creating it...") 241 | val v = connection.addVertex(inVertexType, null) 242 | v.setProperty("id", inVertexName) 243 | Some(v) 244 | } else { 245 | inVertices.headOption 246 | } 247 | 248 | val outVertices: List[Vertex] = connection 249 | .command(new OCommandSQL(s"select * from $outVertexType where id = '$outVertexName'")) 250 | .execute() 251 | .asInstanceOf[java.lang.Iterable[Vertex]] 252 | .toList 253 | 254 | val outVertex: Option[Vertex] = 255 | if (outVertices.isEmpty && params.createVertexIfNotExist) { 256 | // log.info(s"out Vertex $outVertexName does not exist. 
Creating it...") 257 | val v = connection.addVertex(outVertexType, null) 258 | v.setProperty("id", outVertexName) 259 | Some(v) 260 | } else { 261 | outVertices.headOption 262 | } 263 | 264 | for { 265 | in <- inVertex 266 | out <- outVertex 267 | } { 268 | var id: String = null 269 | for (i <- fields.indices) { 270 | if (fields(i).name == "id") { 271 | id = Conversions.convertRowToGraph(row, i).toString 272 | } 273 | } 274 | val createdEdge = connection.addEdge(id, in, out, params.edgeType.get) 275 | for (i <- fields.indices) { 276 | val sparkType = fields(i).dataType 277 | val orientDBType = Conversions.sparkDTtoOrientDBDT(sparkType) 278 | createdEdge.setProperty(fields(i).name, Conversions.convertRowToGraph(row, i), orientDBType) 279 | } 280 | } 281 | 282 | graphFactory.close() 283 | } 284 | }) 285 | } catch { 286 | case e: Exception => 287 | throw new RuntimeException("An exception was thrown: " + e.getMessage) 288 | } 289 | } 290 | 291 | def saveToOrientDB(data: DataFrame, saveMode: SaveMode, params: MergedParameters): Unit = { 292 | val connection = orientDBWrapper.getConnection(params) 293 | 294 | try { 295 | if (saveMode == SaveMode.Overwrite) { 296 | dropOrientDBEdge(params) 297 | } 298 | doOrientDBEdgeLoad(connection, data, params) 299 | } finally { 300 | orientDBWrapper.close() 301 | } 302 | } 303 | } 304 | 305 | object DefaultOrientDBEdgeWriter extends OrientDBEdgeWriter( 306 | DefaultOrientDBGraphEdgeWrapper, 307 | orientDBCredentials => new OrientDBClientFactory(orientDBCredentials)) 308 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/graphs/Parameters.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | private[orientdb] object Parameters { 4 | val DEFAULT_PARAMETERS: Map[String, String] = Map( 5 | "overwrite" -> "false" 6 | ) 7 | 8 | def mergeParameters(userParameters: Map[String, String]): MergedParameters = { 9 | if (!userParameters.contains("dburl")) { 10 | throw new IllegalArgumentException("A Orient DB URL must be provided with 'dburl' parameter") 11 | } 12 | 13 | if (!userParameters.contains("vertextype") && !userParameters.contains("edgetype") && !userParameters.contains("query")) { 14 | throw new IllegalArgumentException("You must specify one of Orient DB Vertex type in the 'vertextype'" + 15 | " parameter or Orient DB Edge type in the 'edgetype' parameter or a user specified query using 'query' parameter") 16 | } 17 | 18 | if (!userParameters.contains("user") || !userParameters.contains("password")) { 19 | throw new IllegalArgumentException("You must specify both the OrientDB username in 'user' parameter &" + 20 | " OrientDB password in the 'password' parameter") 21 | } 22 | MergedParameters(DEFAULT_PARAMETERS ++ userParameters) 23 | } 24 | 25 | case class MergedParameters(parameters: Map[String, String]) { 26 | 27 | /** 28 | * The Orient DB Graph vertex Type to be used to load & write data 29 | */ 30 | def vertexType: Option[String] = parameters.get("vertextype").orElse(None) 31 | 32 | /** 33 | * The Orient DB Graph edge Type to be used to load & write data 34 | */ 35 | def edgeType: Option[String] = parameters.get("edgetype").orElse(None) 36 | 37 | def createVertexIfNotExist: Boolean = parameters.get("createVertexIfNotExist").isDefined && 38 | parameters("createVertexIfNotExist") == "true" 39 | 40 | def lightWeightEdge: Boolean = parameters.get("lightWeightEdge").isDefined && 41 | 
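// when set to "true" the edge writer skips createProperty calls for the edge type (see createOrientDBEdge), as OrientDB lightweight edges carry no schema properties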
parameters("lightWeightEdge") == "true" 42 | 43 | /** 44 | * The Orient DB Graph sql query to be used for loading data 45 | */ 46 | def query: Option[String] = parameters.get("query").orElse(None) 47 | 48 | /** 49 | * Username & Password for authentication with OrientDB 50 | */ 51 | def credentials: Option[(String, String)] = { 52 | for { 53 | username <- parameters.get("user") 54 | password <- parameters.get("password") 55 | } yield (username, password) 56 | } 57 | 58 | /** 59 | * A url in the format 60 | * remote::/ 61 | */ 62 | def dbUrl: Option[String] = parameters.get("dburl") 63 | 64 | /** 65 | * cluster name in Orient DB 66 | */ 67 | def cluster: Option[String] = parameters.get("cluster").orElse(None) 68 | 69 | /** 70 | * mention linked properties in the form "vertextype/edgetype" - "linkedType-linked vertextype/edgetype" 71 | */ 72 | def linkedType: Option[Map[String, String]] = 73 | if (parameters.exists(paramPair => paramPair._2.contains("linkedType"))) 74 | Some(parameters.filter(paramPair => paramPair._2.contains("linkedType"))) 75 | else None 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/EmbeddedListType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.unsafe.types.UTF8String 8 | 9 | @SQLUserDefinedType(udt = classOf[EmbeddedListType]) 10 | case class EmbeddedList(elements: Array[Any]) extends Serializable { 11 | override def hashCode(): Int = { 12 | var hashCode = 1 13 | val i = elements.iterator 14 | while (i.hasNext) { 15 | val obj = i.next() 16 | 17 | val elemValue = if (obj == null) 0 else obj.hashCode() 18 | hashCode = 31 * hashCode + elemValue 19 | } 20 | hashCode 21 | } 22 | 23 | override def equals(other: scala.Any): Boolean = other match { 24 | case that: EmbeddedList => that.elements.sameElements(this.elements) 25 | case _ => false 26 | } 27 | 28 | override def toString: String = elements.mkString(", ") 29 | } 30 | 31 | class EmbeddedListType extends UserDefinedType[EmbeddedList] { 32 | 33 | override def sqlType: DataType = ArrayType(StringType) 34 | 35 | override def serialize(obj: EmbeddedList): Any = { 36 | new GenericArrayData(obj.elements.map{elem => 37 | val out = new ByteArrayOutputStream() 38 | val os = new ObjectOutputStream(out) 39 | os.writeObject(elem) 40 | UTF8String.fromBytes(out.toByteArray) 41 | }) 42 | } 43 | 44 | override def deserialize(datum: Any): EmbeddedList = { 45 | datum match { 46 | case values: ArrayData => 47 | new EmbeddedList(values.toArray[UTF8String](StringType).map{ elem => 48 | val in = new ByteArrayInputStream(elem.getBytes) 49 | val is = new ObjectInputStream(in) 50 | is.readObject() 51 | }) 52 | case other => sys.error(s"Cannot deserialize $other") 53 | } 54 | } 55 | 56 | override def userClass: Class[EmbeddedList] = classOf[EmbeddedList] 57 | } 58 | 59 | object EmbeddedListType extends EmbeddedListType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/EmbeddedMapType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import 
java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import org.apache.spark.sql.catalyst.expressions.UnsafeMapData 6 | import org.apache.spark.sql.catalyst.util.ArrayBasedMapData 7 | import org.apache.spark.sql.types._ 8 | import org.apache.spark.unsafe.types.UTF8String 9 | 10 | @SQLUserDefinedType(udt = classOf[EmbeddedMapType]) 11 | case class EmbeddedMap(elements: Map[Any, Any]) extends Serializable { 12 | override def hashCode(): Int = 1 13 | 14 | override def equals(other: scala.Any): Boolean = other match { 15 | case that: EmbeddedMap => that.elements == this.elements 16 | case _ => false 17 | } 18 | 19 | override def toString: String = elements.mkString(", ") 20 | } 21 | 22 | class EmbeddedMapType extends UserDefinedType[EmbeddedMap] { 23 | 24 | override def sqlType: DataType = MapType(StringType, StringType) 25 | 26 | override def serialize(obj: EmbeddedMap): Any = { 27 | ArrayBasedMapData(obj.elements.keySet.map{ elem => 28 | val outKey = new ByteArrayOutputStream() 29 | val osKey = new ObjectOutputStream(outKey) 30 | osKey.writeObject(elem) 31 | UTF8String.fromBytes(outKey.toByteArray) 32 | }.toArray, 33 | obj.elements.values.map{ elem => 34 | val outValue = new ByteArrayOutputStream() 35 | val osValue = new ObjectOutputStream(outValue) 36 | osValue.writeObject(elem) 37 | UTF8String.fromBytes(outValue.toByteArray) 38 | }.toArray) 39 | } 40 | 41 | override def deserialize(datum: Any): EmbeddedMap = { 42 | datum match { 43 | case values: UnsafeMapData => 44 | new EmbeddedMap(values.keyArray().toArray[UTF8String](StringType).map{ elem => 45 | val in = new ByteArrayInputStream(elem.getBytes) 46 | val is = new ObjectInputStream(in) 47 | is.readObject() 48 | }.zip(values.valueArray().toArray[UTF8String](StringType).map{ elem => 49 | val in = new ByteArrayInputStream(elem.getBytes) 50 | val is = new ObjectInputStream(in) 51 | is.readObject() 52 | }).toMap) 53 | case other => sys.error(s"Cannot deserialize $other") 54 | } 55 | } 56 | 57 | override def userClass: Class[EmbeddedMap] = classOf[EmbeddedMap] 58 | } 59 | 60 | object EmbeddedMapType extends EmbeddedMapType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/EmbeddedSetType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.unsafe.types.UTF8String 8 | 9 | @SQLUserDefinedType(udt = classOf[EmbeddedSetType]) 10 | case class EmbeddedSet(elements: Array[Any]) extends Serializable { 11 | override def hashCode(): Int = { 12 | var hashCode = 1 13 | val i = elements.iterator 14 | while (i.hasNext) { 15 | val obj = i.next() 16 | 17 | val elemValue = if (obj == null) 0 else obj.hashCode() 18 | hashCode = 31 * hashCode + elemValue 19 | } 20 | hashCode 21 | } 22 | 23 | override def equals(other: scala.Any): Boolean = other match { 24 | case that: EmbeddedSet => that.elements.sameElements(this.elements) 25 | case _ => false 26 | } 27 | 28 | override def toString: String = elements.mkString(", ") 29 | } 30 | 31 | class EmbeddedSetType extends UserDefinedType[EmbeddedSet] { 32 | 33 | override def sqlType: DataType = ArrayType(StringType) 34 | 35 | override def serialize(obj: EmbeddedSet): 
Any = { 36 | new GenericArrayData(obj.elements.map{elem => 37 | val out = new ByteArrayOutputStream() 38 | val os = new ObjectOutputStream(out) 39 | os.writeObject(elem) 40 | UTF8String.fromBytes(out.toByteArray) 41 | }) 42 | } 43 | 44 | override def deserialize(datum: Any): EmbeddedSet = { 45 | datum match { 46 | case values: ArrayData => 47 | new EmbeddedSet(values.toArray[UTF8String](StringType).map{ elem => 48 | val in = new ByteArrayInputStream(elem.getBytes) 49 | val is = new ObjectInputStream(in) 50 | is.readObject() 51 | }) 52 | case other => sys.error(s"Cannot deserialize $other") 53 | } 54 | } 55 | 56 | override def userClass: Class[EmbeddedSet] = classOf[EmbeddedSet] 57 | } 58 | 59 | object EmbeddedSetType extends EmbeddedSetType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/LinkBagType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import com.orientechnologies.orient.core.id.ORecordId 6 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 7 | import org.apache.spark.sql.types._ 8 | import org.apache.spark.unsafe.types.UTF8String 9 | 10 | @SQLUserDefinedType(udt = classOf[LinkBagType]) 11 | case class LinkBag(elements: Array[_ <: ORecordId]) extends Serializable { 12 | override def hashCode(): Int = { 13 | val hashCode = 1 14 | val elemValue = if (elements == null) 0 else elements.hashCode() 15 | 16 | 31 * hashCode + elemValue 17 | } 18 | 19 | override def equals(other: scala.Any): Boolean = other match { 20 | case that: LinkBag => that.elements.sameElements(this.elements) 21 | case _ => false 22 | } 23 | 24 | override def toString: String = elements.toString 25 | } 26 | 27 | class LinkBagType extends UserDefinedType[LinkBag] { 28 | 29 | override def sqlType: DataType = ArrayType(StringType) 30 | 31 | override def serialize(obj: LinkBag): Any = { 32 | new GenericArrayData(obj.elements.map{ elem => 33 | val out = new ByteArrayOutputStream() 34 | val os = new ObjectOutputStream(out) 35 | os.writeObject(elem) 36 | UTF8String.fromBytes(out.toByteArray) 37 | }) 38 | } 39 | 40 | override def deserialize(datum: Any): LinkBag = { 41 | datum match { 42 | case values: ArrayData => 43 | new LinkBag(values.toArray[UTF8String](StringType).map{ elem => 44 | val in = new ByteArrayInputStream(elem.getBytes) 45 | val is = new ObjectInputStream(in) 46 | is.readObject().asInstanceOf[ORecordId] 47 | }) 48 | case other => sys.error(s"Cannot deserialize $other") 49 | } 50 | } 51 | 52 | override def userClass: Class[LinkBag] = classOf[LinkBag] 53 | } 54 | 55 | object LinkBagType extends LinkBagType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/LinkListType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import com.orientechnologies.orient.core.record.ORecord 6 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 7 | import org.apache.spark.sql.types._ 8 | import org.apache.spark.unsafe.types.UTF8String 9 | 10 | @SQLUserDefinedType(udt = classOf[LinkListType]) 11 | case class LinkList(elements: Array[_ <: 
ORecord]) extends Serializable { 12 | override def hashCode(): Int = { 13 | var hashCode = 1 14 | val i = elements.iterator 15 | while (i.hasNext) { 16 | val obj = i.next() 17 | 18 | val elemValue = if (obj == null) 0 else obj.hashCode() 19 | hashCode = 31 * hashCode + elemValue 20 | } 21 | hashCode 22 | } 23 | 24 | override def equals(other: scala.Any): Boolean = other match { 25 | case that: LinkList => that.elements.sameElements(this.elements) 26 | case _ => false 27 | } 28 | 29 | override def toString: String = elements.mkString(", ") 30 | } 31 | 32 | class LinkListType extends UserDefinedType[LinkList] { 33 | 34 | override def sqlType: DataType = ArrayType(StringType) 35 | 36 | override def serialize(obj: LinkList): Any = { 37 | new GenericArrayData(obj.elements.map{ elem => 38 | val out = new ByteArrayOutputStream() 39 | val os = new ObjectOutputStream(out) 40 | os.writeObject(elem) 41 | UTF8String.fromBytes(out.toByteArray) 42 | }) 43 | } 44 | 45 | override def deserialize(datum: Any): LinkList = { 46 | datum match { 47 | case values: ArrayData => 48 | new LinkList(values.toArray[UTF8String](StringType).map{ elem => 49 | val in = new ByteArrayInputStream(elem.getBytes) 50 | val is = new ObjectInputStream(in) 51 | is.readObject().asInstanceOf[ORecord] 52 | }) 53 | case other => sys.error(s"Cannot deserialize $other") 54 | } 55 | } 56 | 57 | override def userClass: Class[LinkList] = classOf[LinkList] 58 | } 59 | 60 | object LinkListType extends LinkListType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/LinkMapType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import com.orientechnologies.orient.core.record.ORecord 6 | import org.apache.spark.sql.catalyst.expressions.UnsafeMapData 7 | import org.apache.spark.sql.catalyst.util.ArrayBasedMapData 8 | import org.apache.spark.sql.types._ 9 | import org.apache.spark.unsafe.types.UTF8String 10 | 11 | @SQLUserDefinedType(udt = classOf[LinkMapType]) 12 | case class LinkMap(elements: Map[String, _ <: ORecord]) extends Serializable { 13 | override def hashCode(): Int = 1 14 | 15 | override def equals(other: scala.Any): Boolean = other match { 16 | case that: LinkMap => that.elements == this.elements 17 | case _ => false 18 | } 19 | 20 | override def toString: String = elements.mkString(", ") 21 | } 22 | 23 | class LinkMapType extends UserDefinedType[LinkMap] { 24 | 25 | override def sqlType: DataType = MapType(StringType, StringType) 26 | 27 | override def serialize(obj: LinkMap): Any = { 28 | ArrayBasedMapData(obj.elements.keySet.map{ elem => 29 | val outKey = new ByteArrayOutputStream() 30 | val osKey = new ObjectOutputStream(outKey) 31 | osKey.writeObject(elem) 32 | UTF8String.fromBytes(outKey.toByteArray) 33 | }.toArray, 34 | obj.elements.values.map{ elem => 35 | val outValue = new ByteArrayOutputStream() 36 | val osValue = new ObjectOutputStream(outValue) 37 | osValue.writeObject(elem) 38 | UTF8String.fromBytes(outValue.toByteArray) 39 | }.toArray) 40 | } 41 | 42 | override def deserialize(datum: Any): LinkMap = { 43 | datum match { 44 | case values: UnsafeMapData => 45 | new LinkMap(values.keyArray().toArray[UTF8String](StringType).map { elem => 46 | val in = new ByteArrayInputStream(elem.getBytes) 47 | val is = new ObjectInputStream(in) 48 | 
is.readObject().toString 49 | }.zip(values.valueArray().toArray[UTF8String](StringType).map { elem => 50 | val in = new ByteArrayInputStream(elem.getBytes) 51 | val is = new ObjectInputStream(in) 52 | is.readObject().asInstanceOf[ORecord] 53 | }).toMap) 54 | case other => sys.error(s"Cannot deserialize $other") 55 | } 56 | } 57 | 58 | override def userClass: Class[LinkMap] = classOf[LinkMap] 59 | } 60 | 61 | object LinkMapType extends LinkMapType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/LinkSetType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import com.orientechnologies.orient.core.record.ORecord 6 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 7 | import org.apache.spark.sql.types._ 8 | import org.apache.spark.unsafe.types.UTF8String 9 | 10 | @SQLUserDefinedType(udt = classOf[LinkSetType]) 11 | case class LinkSet(elements: Array[_ <: ORecord]) extends Serializable { 12 | override def hashCode(): Int = { 13 | var hashCode = 1 14 | val i = elements.iterator 15 | while (i.hasNext) { 16 | val obj = i.next() 17 | 18 | val elemValue = if (obj == null) 0 else obj.hashCode() 19 | hashCode = 31 * hashCode + elemValue 20 | } 21 | hashCode 22 | } 23 | 24 | override def equals(other: scala.Any): Boolean = other match { 25 | case that: LinkSet => that.elements.sameElements(this.elements) 26 | case _ => false 27 | } 28 | 29 | override def toString: String = elements.mkString(", ") 30 | } 31 | 32 | class LinkSetType extends UserDefinedType[LinkSet] { 33 | 34 | override def sqlType: DataType = ArrayType(StringType) 35 | 36 | override def serialize(obj: LinkSet): Any = { 37 | new GenericArrayData(obj.elements.map{elem => 38 | val out = new ByteArrayOutputStream() 39 | val os = new ObjectOutputStream(out) 40 | os.writeObject(elem) 41 | UTF8String.fromBytes(out.toByteArray) 42 | }) 43 | } 44 | 45 | override def deserialize(datum: Any): LinkSet = { 46 | datum match { 47 | case values: ArrayData => 48 | new LinkSet(values.toArray[UTF8String](StringType).map{ elem => 49 | val in = new ByteArrayInputStream(elem.getBytes) 50 | val is = new ObjectInputStream(in) 51 | is.readObject().asInstanceOf[ORecord] 52 | }) 53 | case other => sys.error(s"Cannot deserialize $other") 54 | } 55 | } 56 | 57 | override def userClass: Class[LinkSet] = classOf[LinkSet] 58 | } 59 | 60 | object LinkSetType extends LinkSetType -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/orientdb/udts/LinkType.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.udts 2 | 3 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} 4 | 5 | import com.orientechnologies.orient.core.db.record.OIdentifiable 6 | import com.orientechnologies.orient.core.id.ORecordId 7 | import com.orientechnologies.orient.core.record.ORecord 8 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 9 | import org.apache.spark.sql.types._ 10 | import org.apache.spark.unsafe.types.UTF8String 11 | 12 | @SQLUserDefinedType(udt = classOf[LinkType]) 13 | case class Link(element: OIdentifiable) extends Serializable { 14 | override def hashCode(): Int = { 15 | var hashCode 
= 1 16 | 17 | val elemValue = if (element == null) 0 else element.hashCode() 18 | hashCode = 31 * hashCode + elemValue 19 | hashCode 20 | } 21 | 22 | override def equals(other: scala.Any): Boolean = other match { 23 | case that: Link => that.element.equals(this.element) 24 | case _ => false 25 | } 26 | 27 | override def toString: String = element.toString 28 | } 29 | 30 | class LinkType extends UserDefinedType[Link] { 31 | 32 | override def sqlType: DataType = ArrayType(StringType) 33 | 34 | override def serialize(obj: Link): Any = { 35 | val out = new ByteArrayOutputStream() 36 | val os = new ObjectOutputStream(out) 37 | if (obj.element.isInstanceOf[ORecord]) 38 | os.writeObject(obj.element.asInstanceOf[ORecord]) 39 | else 40 | os.writeObject(obj.element.asInstanceOf[ORecordId]) 41 | new GenericArrayData(Array(UTF8String.fromBytes(out.toByteArray))) 42 | } 43 | 44 | override def deserialize(datum: Any): Link = { 45 | datum match { 46 | case values: ArrayData => 47 | new Link(values.toArray[UTF8String](StringType).map { elem => 48 | val in = new ByteArrayInputStream(elem.getBytes) 49 | val is = new ObjectInputStream(in) 50 | val data = is.readObject() 51 | if (data.isInstanceOf[ORecord]) { 52 | data.asInstanceOf[ORecord] 53 | } else { 54 | data.asInstanceOf[ORecordId] 55 | } 56 | }.head) 57 | case other => sys.error(s"Cannot deserialize $other") 58 | } 59 | } 60 | 61 | override def userClass: Class[Link] = classOf[Link] 62 | } 63 | 64 | object LinkType extends LinkType 65 | 66 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/QueryTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb 2 | 3 | import org.apache.spark.sql.catalyst.plans.logical 4 | import org.apache.spark.sql.{DataFrame, Row} 5 | import org.scalatest.FunSuite 6 | 7 | trait QueryTest extends FunSuite { 8 | 9 | def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { 10 | val isSorted = df.queryExecution.logical.collect {case s: logical.Sort => s}.nonEmpty 11 | def prepareAnswer(answer: Seq[Row]): Seq[Row] = { 12 | val converted: Seq[Row] = answer.map { s => 13 | Row.fromSeq(s.toSeq.map { 14 | case d: java.math.BigDecimal => BigDecimal(d) 15 | case b: Array[Byte] => b.toSeq 16 | case o => o 17 | }) 18 | } 19 | if (!isSorted) converted.sortBy(_.toString()) else converted 20 | } 21 | 22 | val sparkAnswer = try df.collect().toSeq catch { 23 | case e: Exception => 24 | val errorMessage = 25 | s""" 26 | |Exception thrown while executing query: 27 | |${df.queryExecution} 28 | |== Exception == 29 | |$e 30 | |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} 31 | """.stripMargin 32 | fail(errorMessage) 33 | } 34 | 35 | if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { 36 | val errorMessage = 37 | s""" 38 | |Results do not match for query: 39 | |${df.queryExecution} 40 | |== Results == 41 | |${sideBySide( 42 | s"== Correct Answer - ${expectedAnswer.size} ==" +: 43 | prepareAnswer(expectedAnswer).map(_.toString()), 44 | s"== Spark Answer - ${sparkAnswer.size} ==" +: 45 | prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} 46 | """.stripMargin 47 | fail(errorMessage) 48 | } 49 | } 50 | 51 | private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { 52 | val maxLeftSize = left.map(_.length).max 53 | val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") 54 | val rightPadded = right ++ 
Seq.fill(math.max(left.size - right.size, 0))("") 55 | 56 | leftPadded.zip(rightPadded).map { 57 | case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r 58 | } 59 | } 60 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/ConversionsSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.metadata.schema.OType 4 | import com.orientechnologies.orient.core.record.impl.ODocument 5 | import org.apache.spark.sql.{Row, SparkSession} 6 | import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} 7 | import org.scalatest.FunSuite 8 | 9 | class ConversionsSuite extends FunSuite { 10 | 11 | test("Spark datatype to OrientDB datatype test") { 12 | val orientDBType = Conversions.sparkDTtoOrientDBDT(StringType) 13 | assert(orientDBType === OType.STRING) 14 | } 15 | 16 | test("Convert Spark Row to Orient DB ODocument") { 17 | val expectedData = new ODocument() 18 | expectedData.field("key", 1, OType.INTEGER) 19 | expectedData.field("value", "Spark datasource for Orient DB", OType.STRING) 20 | 21 | val spark = SparkSession.builder().appName("ConversionsSuite") 22 | .master("local[*]") 23 | .getOrCreate() 24 | 25 | val rows = spark.sqlContext.createDataFrame(spark.sparkContext.parallelize(Seq(Row(1, "Spark datasource for Orient DB"))), 26 | StructType(Array(StructField("key", IntegerType, true), 27 | StructField("value", StringType, true)))).collect() 28 | 29 | val actualData = Conversions.convertRowsToODocuments(rows(0)) 30 | assert(expectedData.field[Int]("key") == actualData.field[Int]("key")) 31 | assert(expectedData.field[String]("value") == actualData.field[String]("value")) 32 | spark.close() 33 | } 34 | 35 | test("Convert OrientDB ODocument to Spark Row") { 36 | val oDocument = new ODocument() 37 | oDocument.field("key", 1, OType.INTEGER) 38 | oDocument.field("value", "Orient DB ODocument to Spark Row", OType.STRING) 39 | 40 | val schema = StructType(Array(StructField("key", IntegerType), 41 | StructField("value", StringType))) 42 | 43 | val expectedData = Row(1, "Orient DB ODocument to Spark Row") 44 | val actualData = Conversions.convertODocumentsToRows(oDocument, schema) 45 | 46 | assert(expectedData === actualData) 47 | } 48 | 49 | test("Return field of correct type") { 50 | val field = Conversions.orientDBDTtoSparkDT(IntegerType, "1") 51 | assert(field.isInstanceOf[Int]) 52 | } 53 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/FilterPushdownSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import org.apache.spark.sql.sources._ 4 | import org.apache.spark.sql.types._ 5 | import org.scalatest.FunSuite 6 | 7 | class FilterPushdownSuite extends FunSuite { 8 | 9 | test("buildWhereClause with empty list of filters") { 10 | assert(FilterPushdown.buildWhereClause(StructType(Nil), Seq.empty) === "") 11 | } 12 | 13 | test("buildWhereClause with no filters that can be pushed down") { 14 | assert(FilterPushdown.buildWhereClause(StructType(Nil), Seq(AlwaysTrue, AlwaysTrue)) === "") 15 | } 16 | 17 | test("buildWhereClause with some filters that cannot be pushed down") { 18 | val whereClause = 
FilterPushdown.buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), AlwaysTrue)) 19 | assert(whereClause === "WHERE test_int = 1") 20 | } 21 | 22 | test("buildWhereClause with multiple filters") { 23 | val filters = Seq( 24 | EqualTo("test_bool", true), 25 | EqualTo("test_string", "Unicode是樂趣"), 26 | GreaterThan("test_double", 1000.0), 27 | LessThan("test_double", Double.MaxValue), 28 | GreaterThanOrEqual("test_float", 1.0f), 29 | LessThanOrEqual("test_int", 43), 30 | IsNotNull("test_int"), 31 | IsNull("test_int") 32 | ) 33 | 34 | val whereClause = FilterPushdown.buildWhereClause(testSchema, filters) 35 | 36 | val expectedWhereClause = 37 | """ 38 | |WHERE test_bool = true 39 | |AND test_string = 'Unicode是樂趣' 40 | |AND test_double > 1000.0 41 | |AND test_double < 1.7976931348623157E308 42 | |AND test_float >= 1.0 43 | |AND test_int <= 43 44 | |AND test_int IS NOT NULL 45 | |AND test_int IS NULL 46 | """.stripMargin.linesIterator.mkString(" ").trim 47 | 48 | assert(whereClause === expectedWhereClause) 49 | } 50 | 51 | private val testSchema: StructType = StructType(Seq( 52 | StructField("test_byte", ByteType), 53 | StructField("test_bool", BooleanType), 54 | StructField("test_date", DateType), 55 | StructField("test_double", DoubleType), 56 | StructField("test_float", FloatType), 57 | StructField("test_int", IntegerType), 58 | StructField("test_long", LongType), 59 | StructField("test_short", ShortType), 60 | StructField("test_string", StringType), 61 | StructField("test_timestamp", TimestampType) 62 | )) 63 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/MockOrientDBDocument.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.db.document.ODatabaseDocumentTx 4 | import com.orientechnologies.orient.core.record.impl.ODocument 5 | import org.apache.spark.orientdb.documents.Parameters.MergedParameters 6 | import org.apache.spark.sql.types.StructType 7 | import org.mockito.Matchers._ 8 | import org.mockito.Mockito._ 9 | import org.mockito.invocation.InvocationOnMock 10 | import org.mockito.stubbing.Answer 11 | 12 | class MockOrientDBDocument(existingTablesAndSchemas: Map[String, StructType], 13 | oDocuments: List[ODocument]) { 14 | val documentWrapper: OrientDBDocumentWrapper = spy(new OrientDBDocumentWrapper()) 15 | 16 | doAnswer(new Answer[ODatabaseDocumentTx] { 17 | override def answer(invocationOnMock: InvocationOnMock): ODatabaseDocumentTx = { 18 | mock(classOf[ODatabaseDocumentTx], RETURNS_SMART_NULLS) 19 | } 20 | }).when(documentWrapper).getConnection(any(classOf[MergedParameters])) 21 | 22 | doAnswer(new Answer[Boolean] { 23 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 24 | existingTablesAndSchemas.contains(invocationOnMock.getArguments()(1).asInstanceOf[String]) 25 | } 26 | }).when(documentWrapper).doesClassExists(any(classOf[String])) 27 | 28 | doAnswer(new Answer[Boolean] { 29 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 30 | true 31 | } 32 | }).when(documentWrapper).create(any(classOf[String]), any(classOf[String]), 33 | any(classOf[ODocument])) 34 | 35 | doAnswer(new Answer[List[ODocument]] { 36 | override def answer(invocationOnMock: InvocationOnMock): List[ODocument] = { 37 | oDocuments 38 | } 39 | }).when(documentWrapper).read(any(classOf[List[String]]), any(classOf[String]), any(classOf[Array[String]]), 40 
| any(classOf[String]), any(classOf[String])) 41 | 42 | doAnswer(new Answer[Boolean] { 43 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 44 | true 45 | } 46 | }).when(documentWrapper).delete(any(classOf[String]), any(classOf[String]), 47 | any(classOf[Map[String, Tuple2[String, String]]])) 48 | 49 | doAnswer(new Answer[StructType] { 50 | override def answer(invocationOnMock: InvocationOnMock): StructType = { 51 | existingTablesAndSchemas.get(invocationOnMock.getArguments()(1).asInstanceOf[String]).get 52 | } 53 | }).when(documentWrapper).resolveTable(any(classOf[String]), any(classOf[String])) 54 | 55 | doAnswer(new Answer[List[ODocument]] { 56 | override def answer(invocationOnMock: InvocationOnMock): List[ODocument] = { 57 | oDocuments 58 | } 59 | }).when(documentWrapper).genericQuery(any(classOf[String])) 60 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/OrientDBEmbeddedUDTsSourceSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import com.orientechnologies.orient.core.db.record.{OTrackedList, OTrackedMap, OTrackedSet} 4 | import com.orientechnologies.orient.core.metadata.schema.OType 5 | import com.orientechnologies.orient.core.record.impl.ODocument 6 | import org.apache.spark.orientdb.udts._ 7 | import org.apache.spark.{SparkContext} 8 | import org.apache.spark.orientdb.{QueryTest, TestUtils} 9 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan} 10 | import org.apache.spark.sql.types.{StructField, StructType} 11 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession} 12 | import org.mockito.Mockito 13 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 14 | 15 | class OrientDBEmbeddedUDTsSourceSuite extends QueryTest 16 | with BeforeAndAfterAll 17 | with BeforeAndAfterEach { 18 | private var sc: SparkContext = _ 19 | private var spark: SparkSession = _ 20 | private var sqlContext: SQLContext = _ 21 | private var mockOrientDBClient: OrientDBClientFactory = _ 22 | private var expectedDataDf: DataFrame = _ 23 | 24 | override protected def beforeAll(): Unit = { 25 | spark = SparkSession.builder().appName("OrientDBLinkUDTsSourceSuite") 26 | .master("local[*]") 27 | .config("", "true") 28 | .getOrCreate() 29 | sc = spark.sparkContext 30 | } 31 | 32 | override protected def afterAll(): Unit = { 33 | if (spark != null) { 34 | spark.stop() 35 | } 36 | } 37 | 38 | override protected def beforeEach(): Unit = { 39 | sqlContext = spark.sqlContext 40 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory], 41 | Mockito.RETURNS_SMART_NULLS) 42 | expectedDataDf = sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedDataForEmbeddedUDTs), 43 | TestUtils.testSchemaForEmbeddedUDTs) 44 | } 45 | 46 | override protected def afterEach(): Unit = { 47 | sqlContext = null 48 | } 49 | 50 | test("Can load output of OrientDB queries") { 51 | val query = "select embeddedset, embeddedmap from test_table" 52 | 53 | val querySchema = StructType(Seq(StructField("embeddedset", EmbeddedSetType), 54 | StructField("embeddedmap", EmbeddedMapType))) 55 | 56 | { 57 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 58 | "user" -> "root", 59 | "password" -> "root", 60 | "class" -> "test_table", 61 | "clusters" -> "test_cluster") 62 | 63 | val iSourceRecord = new ODocument() 64 | iSourceRecord.field("id", 1, OType.INTEGER) 
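// Editor's note (descriptive comment, not part of the original source): the tracked collections below are owned by iSourceRecord so the mock ODocuments can expose EMBEDDEDSET and EMBEDDEDMAP fields for the relation to read back.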
65 | 66 | var oTrackedSet = new OTrackedSet[Int](iSourceRecord) 67 | oTrackedSet.add(1) 68 | var oTrackedMap = new OTrackedMap[Boolean](iSourceRecord) 69 | oTrackedMap.put(1, true) 70 | 71 | val oDoc1 = new ODocument() 72 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 73 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 74 | 75 | oTrackedSet = new OTrackedSet[Int](iSourceRecord) 76 | oTrackedSet.add(2) 77 | oTrackedMap = new OTrackedMap[Boolean](iSourceRecord) 78 | oTrackedMap.put(1, false) 79 | 80 | val oDoc2 = new ODocument() 81 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 82 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 83 | 84 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema), 85 | List(oDoc1, oDoc2)) 86 | 87 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 88 | .createRelation(sqlContext, params) 89 | sqlContext.baseRelationToDataFrame(relation).collect() 90 | } 91 | 92 | { 93 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 94 | "user" -> "root", 95 | "password" -> "root", 96 | "class" -> "test_table", 97 | "query" -> query, 98 | "clusters" -> "test_cluster") 99 | 100 | val iSourceRecord = new ODocument() 101 | iSourceRecord.field("id", 1, OType.INTEGER) 102 | 103 | var oTrackedSet = new OTrackedSet[Int](iSourceRecord) 104 | oTrackedSet.add(1) 105 | var oTrackedMap = new OTrackedMap[Boolean](iSourceRecord) 106 | oTrackedMap.put(1, true) 107 | 108 | val oDoc1 = new ODocument() 109 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 110 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 111 | 112 | oTrackedSet = new OTrackedSet[Int](iSourceRecord) 113 | oTrackedSet.add(2) 114 | oTrackedMap = new OTrackedMap[Boolean](iSourceRecord) 115 | oTrackedMap.put(1, false) 116 | 117 | val oDoc2 = new ODocument() 118 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 119 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 120 | 121 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema), 122 | List(oDoc1, oDoc2)) 123 | 124 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 125 | .createRelation(sqlContext, params) 126 | sqlContext.baseRelationToDataFrame(relation).collect() 127 | } 128 | } 129 | 130 | test("DefaultSource supports simple column filtering") { 131 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 132 | "user" -> "root", 133 | "password" -> "root", 134 | "class" -> "test_table", 135 | "clusters" -> "test_cluster") 136 | 137 | val iSourceRecord = new ODocument() 138 | iSourceRecord.field("id", 1, OType.INTEGER) 139 | 140 | var oTrackedList = new OTrackedList[Byte](iSourceRecord) 141 | oTrackedList.add(1.toByte) 142 | var oTrackedSet = new OTrackedSet[Boolean](iSourceRecord) 143 | oTrackedSet.add(true) 144 | var oTrackedMap = new OTrackedMap[String](iSourceRecord) 145 | oTrackedMap.put(1, "Hello") 146 | 147 | val oDoc1 = new ODocument() 148 | oDoc1.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST) 149 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 150 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 151 | 152 | oTrackedList = new OTrackedList[Byte](iSourceRecord) 153 | oTrackedList.add(2.toByte) 154 | oTrackedSet = new OTrackedSet[Boolean](iSourceRecord) 155 | oTrackedSet.add(false) 156 | oTrackedMap = new OTrackedMap[String](iSourceRecord) 157 
| oTrackedMap.put(1, "World") 158 | 159 | val oDoc2 = new ODocument() 160 | oDoc2.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST) 161 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 162 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 163 | 164 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForEmbeddedUDTs), 165 | List(oDoc1, oDoc2)) 166 | 167 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 168 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEmbeddedUDTs) 169 | 170 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 171 | .buildScan(Array("embeddedlist", "embeddedset"), Array.empty[Filter]) 172 | 173 | val prunedExpectedValues = Array( 174 | Row(EmbeddedList(Array(1.toByte)), EmbeddedSet(Array(true))), Row(EmbeddedList(Array(2.toByte)), EmbeddedSet(Array(false))) 175 | ) 176 | 177 | val result = rdd.collect() 178 | assert(result.length === prunedExpectedValues.length) 179 | assert(result === prunedExpectedValues) 180 | } 181 | 182 | test("DefaultSource supports user schema, pruned and filtered scans") { 183 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 184 | "user" -> "root", 185 | "password" -> "root", 186 | "class" -> "test_table", 187 | "clusters" -> "test_cluster") 188 | 189 | val iSourceRecord = new ODocument() 190 | iSourceRecord.field("id", 1, OType.INTEGER) 191 | 192 | var oTrackedList = new OTrackedList[Byte](iSourceRecord) 193 | oTrackedList.add(1.toByte) 194 | var oTrackedSet = new OTrackedSet[Boolean](iSourceRecord) 195 | oTrackedSet.add(true) 196 | var oTrackedMap = new OTrackedMap[String](iSourceRecord) 197 | oTrackedMap.put(1, "Hello") 198 | 199 | val oDoc1 = new ODocument() 200 | oDoc1.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST) 201 | oDoc1.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 202 | oDoc1.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 203 | 204 | oTrackedList = new OTrackedList[Byte](iSourceRecord) 205 | oTrackedList.add(1.toByte) 206 | oTrackedSet = new OTrackedSet[Boolean](iSourceRecord) 207 | oTrackedSet.add(false) 208 | oTrackedMap = new OTrackedMap[String](iSourceRecord) 209 | oTrackedMap.put(1, "World") 210 | 211 | val oDoc2 = new ODocument() 212 | oDoc2.field("embeddedlist", oTrackedList, OType.EMBEDDEDLIST) 213 | oDoc2.field("embeddedset", oTrackedSet, OType.EMBEDDEDSET) 214 | oDoc2.field("embeddedmap", oTrackedMap, OType.EMBEDDEDMAP) 215 | 216 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForEmbeddedUDTs), 217 | List(oDoc1, oDoc2)) 218 | 219 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 220 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEmbeddedUDTs) 221 | 222 | val filters: Array[Filter] = Array( 223 | EqualTo("embeddedlist", oTrackedList), 224 | EqualTo("embeddedset", oTrackedSet) 225 | ) 226 | 227 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 228 | .buildScan(Array("embeddedmap", "embeddedset"), filters) 229 | 230 | assert(rdd.collect().contains(Row(EmbeddedSet(Array(false)), EmbeddedMap(Map(1 -> "World"))))) 231 | } 232 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/OrientDBLinkUDTsSourceSuite.scala: -------------------------------------------------------------------------------- 1 | package 
org.apache.spark.orientdb.documents 2 | 3 | import java.util 4 | import com.orientechnologies.orient.core.db.record.ridbag.ORidBag 5 | import com.orientechnologies.orient.core.db.record.{ORecordLazyList, ORecordLazyMap, ORecordLazySet} 6 | import com.orientechnologies.orient.core.id.ORecordId 7 | import com.orientechnologies.orient.core.metadata.schema.OType 8 | import com.orientechnologies.orient.core.record.ORecord 9 | import com.orientechnologies.orient.core.record.impl.ODocument 10 | import org.apache.spark.orientdb.udts._ 11 | import org.apache.spark.{SparkConf, SparkContext} 12 | import org.apache.spark.orientdb.{QueryTest, TestUtils} 13 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan} 14 | import org.apache.spark.sql.types.{StructField, StructType} 15 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession} 16 | import org.mockito.Mockito 17 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 18 | 19 | class OrientDBLinkUDTsSourceSuite extends QueryTest 20 | with BeforeAndAfterAll 21 | with BeforeAndAfterEach { 22 | private var sc: SparkContext = _ 23 | private var spark: SparkSession = _ 24 | private var sqlContext: SQLContext = _ 25 | private var mockOrientDBClient: OrientDBClientFactory = _ 26 | private var expectedDataDf: DataFrame = _ 27 | 28 | override protected def beforeAll(): Unit = { 29 | spark = SparkSession.builder().appName("OrientDBLinkUDTsSourceSuite") 30 | .master("local[*]") 31 | .getOrCreate() 32 | sc = spark.sparkContext; 33 | } 34 | 35 | override protected def afterAll(): Unit = { 36 | if (spark != null) { 37 | spark.close() 38 | } 39 | } 40 | 41 | override protected def beforeEach(): Unit = { 42 | sqlContext = spark.sqlContext 43 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory], 44 | Mockito.RETURNS_SMART_NULLS) 45 | expectedDataDf = sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedDataForLinkUDTs), 46 | TestUtils.testSchemaForLinkUDTs) 47 | } 48 | 49 | override protected def afterEach(): Unit = { 50 | sqlContext = null 51 | } 52 | 53 | test("Can load output of OrientDB queries") { 54 | val query = "select linkset, linkmap, linkbag, link from test_link_table" 55 | 56 | val querySchema = StructType(Seq(StructField("linkset", LinkSetType), 57 | StructField("linkmap", LinkMapType), 58 | StructField("linkbag", LinkBagType), 59 | StructField("link", LinkType))) 60 | 61 | { 62 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 63 | "user" -> "root", 64 | "password" -> "root", 65 | "class" -> "test_link_table", 66 | "clusters" -> "test_link_cluster") 67 | 68 | val iSourceRecord = new ODocument() 69 | iSourceRecord.field("id", 1, OType.INTEGER) 70 | 71 | var oDoc1 = new ODocument() 72 | oDoc1.field("int", 1, OType.INTEGER) 73 | var oDoc2 = new ODocument() 74 | oDoc2.field("boolean", true) 75 | var oRid1 = new ORecordId() 76 | oRid1.fromString("#1:1") 77 | var oRid2 = new ORecordId() 78 | oRid2.fromString("#2:2") 79 | 80 | var oRecordLazySet = new ORecordLazySet(iSourceRecord) 81 | oRecordLazySet.add(oDoc1) 82 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 83 | oRecordLazyMap.put("1", oDoc2) 84 | var oRidBag = new ORidBag() 85 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 86 | 87 | val oDoc3 = new ODocument() 88 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET) 89 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP) 90 | oDoc3.field("linkbag", oRidBag, OType.LINKBAG) 91 | oDoc3.field("link", oDoc1, OType.LINK) 92 | 93 | oDoc1 = new 
ODocument() 94 | oDoc1.field("int", 2, OType.INTEGER) 95 | oDoc2 = new ODocument() 96 | oDoc2.field("boolean", false) 97 | oRid1 = new ORecordId() 98 | oRid1.fromString("#3:3") 99 | oRid2 = new ORecordId() 100 | oRid2.fromString("#4:4") 101 | 102 | oRecordLazySet = new ORecordLazySet(iSourceRecord) 103 | oRecordLazySet.add(oDoc1) 104 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 105 | oRecordLazyMap.put(1, oDoc2) 106 | oRidBag = new ORidBag() 107 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 108 | 109 | val oDoc4 = new ODocument() 110 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET) 111 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP) 112 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG) 113 | oDoc4.field("link", oDoc1, OType.LINK) 114 | 115 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema), 116 | List(oDoc3, oDoc4)) 117 | 118 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 119 | .createRelation(sqlContext, params) 120 | sqlContext.baseRelationToDataFrame(relation).collect() 121 | } 122 | 123 | { 124 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 125 | "user" -> "root", 126 | "password" -> "root", 127 | "class" -> "test_link_table", 128 | "query" -> query, 129 | "clusters" -> "test_link_cluster") 130 | 131 | val iSourceRecord = new ODocument() 132 | iSourceRecord.field("id", 1, OType.INTEGER) 133 | 134 | var oDoc1 = new ODocument() 135 | oDoc1.field("int", 1, OType.INTEGER) 136 | var oDoc2 = new ODocument() 137 | oDoc2.field("boolean", true) 138 | var oRid1 = new ORecordId() 139 | oRid1.fromString("#1:1") 140 | var oRid2 = new ORecordId() 141 | oRid2.fromString("#2:2") 142 | 143 | var oRecordLazySet = new ORecordLazySet(iSourceRecord) 144 | oRecordLazySet.add(oDoc1) 145 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 146 | oRecordLazyMap.put("1", oDoc2) 147 | var oRidBag = new ORidBag() 148 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 149 | 150 | val oDoc3 = new ODocument() 151 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET) 152 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP) 153 | 154 | oDoc1 = new ODocument() 155 | oDoc1.field("int", 2, OType.INTEGER) 156 | oDoc2 = new ODocument() 157 | oDoc2.field("boolean", false) 158 | oRid1 = new ORecordId() 159 | oRid1.fromString("#3:3") 160 | oRid2 = new ORecordId() 161 | oRid2.fromString("#4:4") 162 | 163 | oRecordLazySet = new ORecordLazySet(iSourceRecord) 164 | oRecordLazySet.add(oDoc1) 165 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 166 | oRecordLazyMap.put(1, oDoc2) 167 | oRidBag = new ORidBag() 168 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 169 | 170 | val oDoc4 = new ODocument() 171 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET) 172 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP) 173 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG) 174 | oDoc4.field("link", oDoc1, OType.LINK) 175 | 176 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema), 177 | List(oDoc3, oDoc4)) 178 | 179 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 180 | .createRelation(sqlContext, params) 181 | sqlContext.baseRelationToDataFrame(relation).collect() 182 | } 183 | } 184 | 185 | test("DefaultSource supports simple column filtering") { 186 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 187 | "user" -> "root", 188 | "password" -> "root", 189 | "class" -> 
"test_link_table", 190 | "clusters" -> "test_link_cluster") 191 | 192 | val iSourceRecord = new ODocument() 193 | iSourceRecord.field("id", 1, OType.INTEGER) 194 | 195 | var oDoc0 = new ODocument(new ORecordId("#1:1")) 196 | oDoc0.field("byte", 1.toByte, OType.BYTE) 197 | var oDoc1 = new ODocument(new ORecordId("#2:2")) 198 | oDoc1.field("boolean", true, OType.BOOLEAN) 199 | var oDoc2 = new ODocument(new ORecordId("#4:4")) 200 | oDoc2.field("string", "Hello") 201 | var oRid1 = new ORecordId() 202 | oRid1.fromString("#1:1") 203 | var oRid2 = new ORecordId() 204 | oRid2.fromString("#2:2") 205 | 206 | var oRecordLazyList = new ORecordLazyList(iSourceRecord) 207 | oRecordLazyList.add(oDoc0) 208 | var oRecordLazySet = new ORecordLazySet(iSourceRecord) 209 | oRecordLazySet.add(oDoc1) 210 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 211 | oRecordLazyMap.put("1", oDoc2) 212 | var oRidBag = new ORidBag() 213 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 214 | 215 | val oDoc3 = new ODocument() 216 | oDoc3.field("linklist", oRecordLazyList, OType.LINKLIST) 217 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET) 218 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP) 219 | oDoc3.field("linkbag", oRidBag, OType.LINKBAG) 220 | oDoc3.field("link", oDoc1, OType.LINK) 221 | 222 | val expected1 = Row(LinkList(Array(oDoc0.asInstanceOf[ORecord])), LinkSet(Array(oDoc1.asInstanceOf[ORecord]))) 223 | 224 | oDoc0 = new ODocument(new ORecordId("#1:1")) 225 | oDoc0.field("byte", 2.toByte, OType.BYTE) 226 | oDoc1 = new ODocument(new ORecordId("#2:2")) 227 | oDoc1.field("boolean", false, OType.BOOLEAN) 228 | oDoc2 = new ODocument(new ORecordId("#4:4")) 229 | oDoc2.field("string", "World") 230 | oRid1 = new ORecordId() 231 | oRid1.fromString("#3:3") 232 | oRid2 = new ORecordId() 233 | oRid2.fromString("#4:4") 234 | 235 | oRecordLazyList = new ORecordLazyList(iSourceRecord) 236 | oRecordLazyList.add(oDoc0) 237 | oRecordLazySet = new ORecordLazySet(iSourceRecord) 238 | oRecordLazySet.add(oDoc1) 239 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 240 | oRecordLazyMap.put(1, oDoc2) 241 | oRidBag = new ORidBag() 242 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 243 | 244 | val oDoc4 = new ODocument() 245 | oDoc4.field("linklist", oRecordLazyList, OType.LINKLIST) 246 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET) 247 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP) 248 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG) 249 | oDoc4.field("link", oDoc1, OType.LINK) 250 | 251 | val expected2 = Row(LinkList(Array(oDoc0.asInstanceOf[ORecord])), LinkSet(Array(oDoc1.asInstanceOf[ORecord]))) 252 | 253 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForLinkUDTs), 254 | List(oDoc3, oDoc4)) 255 | 256 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 257 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForLinkUDTs) 258 | 259 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 260 | .buildScan(Array("linklist", "linkset"), Array.empty[Filter]) 261 | 262 | val prunedExpectedValues = Array(expected1, expected2) 263 | 264 | val result = rdd.collect() 265 | assert(result.length === prunedExpectedValues.length) 266 | assert(result === prunedExpectedValues) 267 | } 268 | 269 | test("DefaultSource supports user schema, pruned and filtered scans") { 270 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 271 | "user" -> "root", 272 | "password" -> 
"root", 273 | "class" -> "test_link_table", 274 | "clusters" -> "test_link_cluster") 275 | 276 | val iSourceRecord = new ODocument() 277 | iSourceRecord.field("id", 1, OType.INTEGER) 278 | 279 | var oDoc0 = new ODocument(new ORecordId("#1:1")) 280 | oDoc0.field("byte", 1.toByte, OType.BYTE) 281 | var oDoc1 = new ODocument(new ORecordId("#2:2")) 282 | oDoc1.field("boolean", true, OType.BOOLEAN) 283 | var oDoc2 = new ODocument(new ORecordId("#4:4")) 284 | oDoc2.field("string", "Hello") 285 | var oRid1 = new ORecordId() 286 | oRid1.fromString("#1:1") 287 | var oRid2 = new ORecordId() 288 | oRid2.fromString("#2:2") 289 | 290 | var oRecordLazyList = new ORecordLazyList(iSourceRecord) 291 | oRecordLazyList.add(oDoc0) 292 | var oRecordLazySet = new ORecordLazySet(iSourceRecord) 293 | oRecordLazySet.add(oDoc1) 294 | var oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 295 | oRecordLazyMap.put("1", oDoc2) 296 | var oRidBag = new ORidBag() 297 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 298 | 299 | val oDoc3 = new ODocument() 300 | oDoc3.field("linklist", oRecordLazyList, OType.LINKLIST) 301 | oDoc3.field("linkset", oRecordLazySet, OType.LINKSET) 302 | oDoc3.field("linkmap", oRecordLazyMap, OType.LINKMAP) 303 | oDoc3.field("linkbag", oRidBag, OType.LINKBAG) 304 | oDoc3.field("link", oDoc1, OType.LINK) 305 | 306 | oDoc0 = new ODocument(new ORecordId("#1:1")) 307 | oDoc0.field("byte", 2.toByte, OType.BYTE) 308 | oDoc1 = new ODocument(new ORecordId("#2:2")) 309 | oDoc1.field("boolean", false, OType.BOOLEAN) 310 | oDoc2 = new ODocument(new ORecordId("#4:4")) 311 | oDoc2.field("string", "World") 312 | oRid1 = new ORecordId() 313 | oRid1.fromString("#1:1") 314 | oRid2 = new ORecordId() 315 | oRid2.fromString("#2:2") 316 | 317 | oRecordLazyList = new ORecordLazyList(iSourceRecord) 318 | oRecordLazyList.add(oDoc0) 319 | oRecordLazySet = new ORecordLazySet(iSourceRecord) 320 | oRecordLazySet.add(oDoc1) 321 | oRecordLazyMap = new ORecordLazyMap(iSourceRecord) 322 | oRecordLazyMap.put(1, oDoc2) 323 | oRidBag = new ORidBag() 324 | oRidBag.addAll(util.Arrays.asList(oRid1, oRid2)) 325 | 326 | val oDoc4 = new ODocument() 327 | oDoc4.field("linklist", oRecordLazyList, OType.LINKLIST) 328 | oDoc4.field("linkset", oRecordLazySet, OType.LINKSET) 329 | oDoc4.field("linkmap", oRecordLazyMap, OType.LINKMAP) 330 | oDoc4.field("linkbag", oRidBag, OType.LINKBAG) 331 | oDoc4.field("link", oDoc1, OType.LINK) 332 | 333 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchemaForLinkUDTs), 334 | List(oDoc3, oDoc4)) 335 | 336 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 337 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForLinkUDTs) 338 | 339 | val filters: Array[Filter] = Array( 340 | EqualTo("linklist", oRecordLazyList), 341 | EqualTo("linkset", oRecordLazySet) 342 | ) 343 | 344 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 345 | .buildScan(Array("linklist", "linkset", "linkbag", "link"), filters) 346 | 347 | assert(rdd.collect().contains(Row(LinkList(Array(oDoc0)), LinkSet(Array(oDoc1)), LinkBag(Array(oRid1, oRid2)), Link(oDoc1)))) 348 | } 349 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/OrientDBSourceSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import 
com.orientechnologies.orient.core.metadata.schema.OType 4 | import com.orientechnologies.orient.core.record.impl.ODocument 5 | import org.apache.spark.{SparkConf, SparkContext} 6 | import org.apache.spark.orientdb.{QueryTest, TestUtils} 7 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession} 8 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan} 9 | import org.apache.spark.sql.types.{BooleanType, ByteType, StructField, StructType} 10 | import org.mockito.Mockito 11 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 12 | 13 | class OrientDBSourceSuite extends QueryTest 14 | with BeforeAndAfterAll 15 | with BeforeAndAfterEach { 16 | private var sc: SparkContext = _ 17 | private var spark: SparkSession = _ 18 | private var sqlContext: SQLContext = _ 19 | private var mockOrientDBClient: OrientDBClientFactory = _ 20 | private var expectedDataDf: DataFrame = _ 21 | 22 | override def beforeAll(): Unit = { 23 | spark = SparkSession.builder().appName("OrientDBLinkUDTsSourceSuite") 24 | .master("local[*]") 25 | .getOrCreate() 26 | sc = spark.sparkContext; 27 | } 28 | 29 | override def afterAll(): Unit = { 30 | if (spark != null) 31 | spark.close() 32 | } 33 | 34 | override def beforeEach(): Unit = { 35 | sqlContext = spark.sqlContext 36 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory], 37 | Mockito.RETURNS_SMART_NULLS) 38 | expectedDataDf = sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), 39 | TestUtils.testSchema) 40 | } 41 | 42 | override def afterEach(): Unit = { 43 | sqlContext = null 44 | } 45 | 46 | test("Can load output of OrientDB queries") { 47 | val query = 48 | """select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'""" 49 | 50 | val querySchema = StructType(Seq(StructField("testbyte", ByteType), 51 | StructField("testbool", BooleanType))) 52 | 53 | { 54 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 55 | "user" -> "root", 56 | "password" -> "root", 57 | "class" -> "test_table", 58 | "clusters" -> "test_cluster") 59 | 60 | val oDoc1 = new ODocument() 61 | oDoc1.field("testbyte", 1, OType.BYTE) 62 | oDoc1.field("testbool", true, OType.BOOLEAN) 63 | 64 | val oDoc2 = new ODocument() 65 | oDoc2.field("testbyte", 2, OType.BYTE) 66 | oDoc2.field("testbool", false, OType.BOOLEAN) 67 | 68 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema), 69 | List(oDoc1, oDoc2)) 70 | 71 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 72 | .createRelation(sqlContext, params) 73 | sqlContext.baseRelationToDataFrame(relation).collect() 74 | } 75 | 76 | { 77 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 78 | "user" -> "root", 79 | "password" -> "root", 80 | "class" -> "test_table", 81 | "query" -> query, 82 | "clusters" -> "test_cluster") 83 | 84 | val oDoc1 = new ODocument() 85 | oDoc1.field("testbyte", 1, OType.BYTE) 86 | oDoc1.field("testbool", true, OType.BOOLEAN) 87 | 88 | val oDoc2 = new ODocument() 89 | oDoc2.field("testbyte", 2, OType.BYTE) 90 | oDoc2.field("testbool", false, OType.BOOLEAN) 91 | 92 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> querySchema), 93 | List(oDoc1, oDoc2)) 94 | 95 | val relation = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 96 | .createRelation(sqlContext, params) 97 | sqlContext.baseRelationToDataFrame(relation).collect() 98 | } 99 | } 100 | 101 | 
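// Editor's note (descriptive comment, not part of the original source): the next test checks that a PrunedFilteredScan over the mocked documents returns only the requested columns (testbyte, testbool).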
test("DefaultSource supports simple column filtering") { 102 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 103 | "user" -> "root", 104 | "password" -> "root", 105 | "class" -> "test_table", 106 | "clusters" -> "test_cluster") 107 | 108 | val oDoc1 = new ODocument() 109 | oDoc1.field("testbyte", 1, OType.BYTE) 110 | oDoc1.field("testbool", true, OType.BOOLEAN) 111 | oDoc1.field("teststring", "Hello", OType.STRING) 112 | 113 | val oDoc2 = new ODocument() 114 | oDoc2.field("testbyte", 2, OType.BYTE) 115 | oDoc2.field("testbool", false, OType.BOOLEAN) 116 | oDoc2.field("teststring", "World", OType.STRING) 117 | 118 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchema), 119 | List(oDoc1, oDoc2)) 120 | 121 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 122 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchema) 123 | 124 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 125 | .buildScan(Array("testbyte", "testbool"), Array.empty[Filter]) 126 | 127 | val prunedExpectedValues = Array( 128 | Row(1.toByte, true), 129 | Row(2.toByte, false) 130 | ) 131 | assert(rdd.collect() === prunedExpectedValues) 132 | } 133 | 134 | test("DefaultSource supports user schema, pruned and filtered scans") { 135 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 136 | "user" -> "root", 137 | "password" -> "root", 138 | "class" -> "test_table", 139 | "clusters" -> "test_cluster") 140 | 141 | val oDoc1 = new ODocument() 142 | oDoc1.field("testbyte", 1, OType.BYTE) 143 | oDoc1.field("testbool", true, OType.BOOLEAN) 144 | oDoc1.field("teststring", "Hello", OType.STRING) 145 | 146 | val oDoc2 = new ODocument() 147 | oDoc2.field("testbyte", 2, OType.BYTE) 148 | oDoc2.field("testbool", false, OType.BOOLEAN) 149 | oDoc2.field("teststring", "World", OType.STRING) 150 | 151 | val mockOrientDBDocument = new MockOrientDBDocument(Map(params("class") -> TestUtils.testSchema), 152 | List(oDoc1, oDoc2)) 153 | 154 | val source = new DefaultSource(mockOrientDBDocument.documentWrapper, _ => mockOrientDBClient) 155 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchema) 156 | 157 | val filters: Array[Filter] = Array( 158 | EqualTo("testbool", true), 159 | EqualTo("teststring", "Hello") 160 | ) 161 | 162 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 163 | .buildScan(Array("testbyte", "testbool"), filters) 164 | 165 | assert(rdd.collect().contains(Row(1, true))) 166 | } 167 | 168 | test("Cannot save when 'query' parameter is specified instead of 'class'") { 169 | val invalidParams = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 170 | "user" -> "root", 171 | "password" -> "root", 172 | "query" -> "select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'", 173 | "clusters" -> "test_cluster" 174 | ) 175 | 176 | intercept[IllegalArgumentException] { 177 | expectedDataDf.write.format("org.apache.spark.orientdb.documents").options(invalidParams).save() 178 | } 179 | } 180 | 181 | test("DefaultSource has default constructor, required by Data Source API") { 182 | new DefaultSource() 183 | } 184 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/ParametersSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import 
org.scalatest.{FunSuite, Matchers} 4 | 5 | class ParametersSuite extends FunSuite with Matchers { 6 | 7 | test("Minimal valid parameter map is accepted") { 8 | val params = Map( 9 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 10 | "user" -> "root", 11 | "password" -> "root", 12 | "class" -> "test_class" 13 | ) 14 | 15 | val mergedParams = Parameters.mergeParameters(params) 16 | 17 | mergedParams.dbUrl shouldBe params.get("dburl") 18 | mergedParams.credentials.get._1 shouldBe params.get("user").get 19 | mergedParams.credentials.get._2 shouldBe params.get("password").get 20 | mergedParams.className.get shouldBe params.get("class").get 21 | 22 | Parameters.DEFAULT_PARAMETERS foreach { 23 | case (key, value) => mergedParams.parameters(key) shouldBe value 24 | } 25 | } 26 | 27 | test("Errors are thrown when mandatory parameters are not provided") { 28 | def checkMerge(params: Map[String, String]): Unit = { 29 | intercept[IllegalArgumentException] { 30 | Parameters.mergeParameters(params) 31 | } 32 | } 33 | 34 | val testURL = "remote:127.0.0.1:2424/GratefulDeadConcerts" 35 | checkMerge(Map("dburl" -> testURL, "class" -> "test_class")) 36 | checkMerge(Map("dburl" -> testURL, "user" -> "root", "password" -> "root")) 37 | checkMerge(Map("user" -> "root", "password" -> "root")) 38 | } 39 | 40 | test("Must specify either 'class' param, or, 'class' and 'query' parameter, or 'query' parameter and " + 41 | "user-defined schema") { 42 | intercept[IllegalArgumentException] { 43 | Parameters.mergeParameters(Map( 44 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 45 | "user" -> "root", 46 | "password" -> "root" 47 | )) 48 | } 49 | 50 | Parameters.mergeParameters(Map( 51 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 52 | "user" -> "root", 53 | "password" -> "root", 54 | "class" -> "test_class" 55 | )) 56 | 57 | Parameters.mergeParameters(Map( 58 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 59 | "user" -> "root", 60 | "password" -> "root", 61 | "class" -> "test_class", 62 | "query" -> "select * from test_class" 63 | )) 64 | } 65 | 66 | test("Must specify 'url' parameter and 'user' and 'password' parameters") { 67 | intercept[IllegalArgumentException] { 68 | Parameters.mergeParameters(Map( 69 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 70 | "class" -> "test_class" 71 | )) 72 | } 73 | 74 | intercept[IllegalArgumentException] { 75 | Parameters.mergeParameters(Map( 76 | "user" -> "root", 77 | "password" -> "root", 78 | "class" -> "test_class" 79 | )) 80 | } 81 | 82 | Parameters.mergeParameters(Map( 83 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 84 | "user" -> "root", 85 | "password" -> "root", 86 | "class" -> "test_class" 87 | )) 88 | } 89 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/documents/TableNameSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.documents 2 | 3 | import org.scalatest.FunSuite 4 | 5 | class TableNameSuite extends FunSuite { 6 | 7 | test("escaped table name") { 8 | val tableName = new TableName("test_table") 9 | assert(tableName.unescapedTableName === "test_table") 10 | } 11 | 12 | test("table Name to String") { 13 | val tableName = new TableName("test_table") 14 | assert(tableName.toString === "test_table") 15 | } 16 | } -------------------------------------------------------------------------------- 
/src/test/scala/org/apache/spark/orientdb/graphs/MockEdge.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import java.util 4 | 5 | import com.orientechnologies.orient.core.db.record.OIdentifiable 6 | import com.tinkerpop.blueprints.impls.orient.{OrientBaseGraph, OrientEdge} 7 | 8 | class MockEdge(graph: OrientBaseGraph, record: OIdentifiable) 9 | extends OrientEdge(graph, record) { 10 | 11 | override def setProperty(key: String, value: Object): Unit = { 12 | this.getRecord().field(key, value) 13 | } 14 | 15 | override def getPropertyKeys: util.Set[String] = { 16 | val fieldNames = this.getRecord().fieldNames() 17 | val length = fieldNames.length 18 | 19 | var count = 0 20 | val result = new util.HashSet[String]() 21 | while (count < length) { 22 | result.add(fieldNames(count)) 23 | count = count + 1 24 | } 25 | result 26 | } 27 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/graphs/MockOrientDBGraph.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import com.tinkerpop.blueprints.{Edge, Vertex} 4 | import com.tinkerpop.blueprints.impls.orient.OrientGraphNoTx 5 | import org.apache.spark.orientdb.graphs.Parameters.MergedParameters 6 | import org.apache.spark.sql.types.StructType 7 | import org.mockito.Matchers._ 8 | import org.mockito.Mockito._ 9 | import org.mockito.invocation.InvocationOnMock 10 | import org.mockito.stubbing.Answer 11 | 12 | class MockOrientDBGraph(existingTablesAndSchemas: Map[String, StructType], 13 | oVertices: List[Vertex] = null, 14 | oEdges: List[Edge] = null) { 15 | val vertexWrapper: OrientDBGraphVertexWrapper = spy(new OrientDBGraphVertexWrapper()) 16 | val edgeWrapper: OrientDBGraphEdgeWrapper = spy(new OrientDBGraphEdgeWrapper()) 17 | 18 | doAnswer(new Answer[OrientGraphNoTx] { 19 | override def answer(invocationOnMock: InvocationOnMock): OrientGraphNoTx = { 20 | mock(classOf[OrientGraphNoTx], RETURNS_SMART_NULLS) 21 | } 22 | }).when(vertexWrapper).getConnection(any(classOf[MergedParameters])) 23 | 24 | doAnswer(new Answer[OrientGraphNoTx] { 25 | override def answer(invocationOnMock: InvocationOnMock): OrientGraphNoTx = { 26 | mock(classOf[OrientGraphNoTx], RETURNS_SMART_NULLS) 27 | } 28 | }).when(edgeWrapper).getConnection(any(classOf[MergedParameters])) 29 | 30 | doAnswer(new Answer[Boolean] { 31 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 32 | existingTablesAndSchemas.contains(invocationOnMock.getArguments()(0).asInstanceOf[String]) // doesVertexTypeExists is stubbed with a single String argument, so the type name is at index 0 33 | } 34 | }).when(vertexWrapper).doesVertexTypeExists(any(classOf[String])) 35 | 36 | doAnswer(new Answer[Boolean] { 37 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 38 | existingTablesAndSchemas.contains(invocationOnMock.getArguments()(0).asInstanceOf[String]) // doesEdgeTypeExists likewise takes a single String argument 39 | } 40 | }).when(edgeWrapper).doesEdgeTypeExists(any(classOf[String])) 41 | 42 | doAnswer(new Answer[Boolean] { 43 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 44 | true 45 | } 46 | }).when(vertexWrapper).create(any(classOf[String]), any(classOf[Map[String, Object]])) 47 | 48 | doAnswer(new Answer[Boolean] { 49 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 50 | true 51 | } 52 | }).when(edgeWrapper).create(any(classOf[String]), any(classOf[Vertex]), 53 | any(classOf[Vertex]), any(classOf[Map[String, Object]]))
54 | 55 | doAnswer(new Answer[List[Vertex]] { 56 | override def answer(invocationOnMock: InvocationOnMock): List[Vertex] = { 57 | oVertices 58 | } 59 | }).when(vertexWrapper).read(any(classOf[String]), any(classOf[Array[String]]), 60 | any(classOf[String]), any(classOf[String])) 61 | 62 | doAnswer(new Answer[List[Edge]] { 63 | override def answer(invocationOnMock: InvocationOnMock): List[Edge] = { 64 | oEdges 65 | } 66 | }).when(edgeWrapper).read(any(classOf[String]), any(classOf[Array[String]]), 67 | any(classOf[String]), any(classOf[String])) 68 | 69 | doAnswer(new Answer[Boolean] { 70 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 71 | true 72 | } 73 | }).when(vertexWrapper).delete(any(classOf[String]), 74 | any(classOf[Map[String, Tuple2[String, String]]])) 75 | 76 | doAnswer(new Answer[Boolean] { 77 | override def answer(invocationOnMock: InvocationOnMock): Boolean = { 78 | true 79 | } 80 | }).when(edgeWrapper).delete(any(classOf[String]), 81 | any(classOf[Map[String, Tuple2[String, String]]])) 82 | 83 | doAnswer(new Answer[StructType] { 84 | override def answer(invocationOnMock: InvocationOnMock): StructType = { 85 | existingTablesAndSchemas 86 | .get(invocationOnMock.getArguments()(0).asInstanceOf[String]).get 87 | } 88 | }).when(vertexWrapper).resolveTable(any(classOf[String])) 89 | 90 | doAnswer(new Answer[StructType] { 91 | override def answer(invocationOnMock: InvocationOnMock): StructType = { 92 | existingTablesAndSchemas 93 | .get(invocationOnMock.getArguments()(0).asInstanceOf[String]).get 94 | } 95 | }).when(edgeWrapper).resolveTable(any(classOf[String])) 96 | 97 | doAnswer(new Answer[List[Vertex]] { 98 | override def answer(invocationOnMock: InvocationOnMock): List[Vertex] = { 99 | oVertices 100 | } 101 | }).when(vertexWrapper).genericQuery(any(classOf[String])) 102 | 103 | doAnswer(new Answer[List[Edge]] { 104 | override def answer(invocationOnMock: InvocationOnMock): List[Edge] = { 105 | oEdges 106 | } 107 | }).when(edgeWrapper).genericQuery(any(classOf[String])) 108 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/graphs/MockVertex.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import com.orientechnologies.orient.core.db.record.OIdentifiable 4 | import com.tinkerpop.blueprints.impls.orient.{OrientBaseGraph, OrientVertex} 5 | 6 | class MockVertex(graph: OrientBaseGraph, record: OIdentifiable) 7 | extends OrientVertex(graph, record) { 8 | override def setProperty(key: String, value: Object): Unit = { 9 | this.getRecord().field(key, value) 10 | } 11 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/graphs/OrientDBGraphSourceSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import com.orientechnologies.orient.core.record.impl.ODocument 4 | import com.tinkerpop.blueprints.impls.orient.OrientBaseGraph 5 | import org.apache.spark.{SparkConf, SparkContext} 6 | import org.apache.spark.orientdb.{QueryTest, TestUtils} 7 | import org.apache.spark.sql.sources.{EqualTo, Filter, PrunedFilteredScan} 8 | import org.apache.spark.sql.types._ 9 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession} 10 | import org.mockito.Mockito 11 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 12 | 13 | class 
OrientDBGraphSourceSuite extends QueryTest 14 | with BeforeAndAfterAll 15 | with BeforeAndAfterEach { 16 | private var sc: SparkContext = _ 17 | private var spark: SparkSession = _ 18 | private var sqlContext: SQLContext = _ 19 | private var mockOrientDBClient: OrientDBClientFactory = _ 20 | private var expectedDataDfVertices: DataFrame = _ 21 | private var expectedDataDfEdges: DataFrame = _ 22 | 23 | override def beforeAll(): Unit = { 24 | spark = SparkSession.builder().appName("OrientDBSourceSuite") 25 | .master("local[*]") 26 | .getOrCreate() 27 | sc = spark.sparkContext; 28 | } 29 | 30 | override def afterAll(): Unit = { 31 | if (spark != null) { 32 | spark.close() 33 | } 34 | } 35 | 36 | override def beforeEach(): Unit = { 37 | sqlContext = spark.sqlContext 38 | mockOrientDBClient = Mockito.mock(classOf[OrientDBClientFactory], 39 | Mockito.RETURNS_SMART_NULLS) 40 | expectedDataDfVertices = sqlContext.createDataFrame( 41 | sc.parallelize(TestUtils.expectedDataForVertices), 42 | TestUtils.testSchemaForVertices) 43 | expectedDataDfEdges = sqlContext.createDataFrame( 44 | sc.parallelize(TestUtils.expectedDataForEdges), 45 | TestUtils.testSchemaForEdges) 46 | } 47 | 48 | override def afterEach(): Unit = { 49 | sqlContext = null 50 | } 51 | 52 | test("Can load output of OrientDB Graph queries on vertices") { 53 | val query = 54 | "select testbyte, testbool from test_vertex where teststring = '\\Unicode''s樂趣'" 55 | 56 | val querySchema = StructType(Seq(StructField("testbyte", ByteType, true), 57 | StructField("testbool", BooleanType, true))) 58 | 59 | { 60 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 61 | "user" -> "root", 62 | "password" -> "root", 63 | "vertextype" -> "test_vertex") 64 | 65 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 66 | new ODocument()) 67 | oVertex1.setProperty("id", new Integer(1)) 68 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte)) 69 | oVertex1.setProperty("testbool", new java.lang.Boolean(true)) 70 | 71 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 72 | new ODocument()) 73 | oVertex2.setProperty("id", new Integer(2)) 74 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte)) 75 | oVertex2.setProperty("testbool", new java.lang.Boolean(false)) 76 | 77 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> querySchema), 78 | List(oVertex1, oVertex2)) 79 | 80 | val relation = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient) 81 | .createRelation(sqlContext, params) 82 | sqlContext.baseRelationToDataFrame(relation).collect() 83 | } 84 | 85 | { 86 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 87 | "user" -> "root", 88 | "password" -> "root", 89 | "vertextype" -> "test_vertex", 90 | "query" -> query) 91 | 92 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 93 | new ODocument()) 94 | oVertex1.setProperty("id", new Integer(1)) 95 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte)) 96 | oVertex1.setProperty("testbool", new java.lang.Boolean(true)) 97 | 98 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 99 | new ODocument()) 100 | oVertex2.setProperty("id", new Integer(2)) 101 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte)) 102 | oVertex2.setProperty("testbool", new java.lang.Boolean(false)) 103 | 104 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> querySchema), 105 | List(oVertex1, 
oVertex2)) 106 | 107 | val relation = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient) 108 | .createRelation(sqlContext, params) 109 | sqlContext.baseRelationToDataFrame(relation).collect() 110 | } 111 | } 112 | 113 | test("Can load output of OrientDB Graph queries on edges") { 114 | val query = 115 | "select relationship from test_edge where relationship = 'enemy'" 116 | 117 | val querySchema = StructType(Seq(StructField("relationship", StringType))) 118 | 119 | { 120 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 121 | "user" -> "root", 122 | "password" -> "root", 123 | "edgetype" -> "test_edge") 124 | 125 | val oEdge1 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 126 | oEdge1.setProperty("src", new Integer(1)) 127 | oEdge1.setProperty("dst", new Integer(2)) 128 | oEdge1.setProperty("relationship", new String("enemy")) 129 | 130 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 131 | oEdge2.setProperty("src", new Integer(1)) 132 | oEdge2.setProperty("dst", new Integer(2)) 133 | oEdge2.setProperty("relationship", new String("friend")) 134 | 135 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> querySchema), null, 136 | List(oEdge1, oEdge2)) 137 | 138 | val relation = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient) 139 | .createRelation(sqlContext, params) 140 | sqlContext.baseRelationToDataFrame(relation).collect() 141 | } 142 | 143 | { 144 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 145 | "user" -> "root", 146 | "password" -> "root", 147 | "edgetype" -> "test_edge", 148 | "query" -> query) 149 | 150 | val oEdge1 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 151 | oEdge1.setProperty("src", new Integer(1)) 152 | oEdge1.setProperty("dst", new Integer(2)) 153 | oEdge1.setProperty("relationship", new String("enemy")) 154 | 155 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 156 | oEdge2.setProperty("src", new Integer(1)) 157 | oEdge2.setProperty("dst", new Integer(2)) 158 | oEdge2.setProperty("relationship", new String("friend")) 159 | 160 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> querySchema), null, 161 | List(oEdge1, oEdge2)) 162 | 163 | val relation = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient) 164 | .createRelation(sqlContext, params) 165 | sqlContext.baseRelationToDataFrame(relation).collect() 166 | } 167 | } 168 | 169 | test("DefaultSource supports simple column filtering for Vertices") { 170 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 171 | "user" -> "root", 172 | "password" -> "root", 173 | "vertextype" -> "test_vertex") 174 | 175 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 176 | new ODocument()) 177 | oVertex1.setProperty("id", new Integer(1)) 178 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte)) 179 | oVertex1.setProperty("testbool", new java.lang.Boolean(true)) 180 | 181 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 182 | new ODocument()) 183 | oVertex2.setProperty("id", new Integer(2)) 184 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte)) 185 | oVertex2.setProperty("testbool", new java.lang.Boolean(false)) 186 | 187 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> TestUtils.testSchemaForVertices), 188 | 
List(oVertex1, oVertex2)) 189 | 190 | val source = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient) 191 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForVertices) 192 | 193 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 194 | .buildScan(Array("testbyte", "testbool"), Array.empty[Filter]) 195 | 196 | val prunedExpectedValues = Array( 197 | Row(1.toByte, true), 198 | Row(2.toByte, false) 199 | ) 200 | assert(rdd.collect() === prunedExpectedValues) 201 | } 202 | 203 | test("DefaultSource supports simple column filtering for Edges") { 204 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 205 | "user" -> "root", 206 | "password" -> "root", 207 | "edgetype" -> "test_edge") 208 | 209 | val oEdge1 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 210 | oEdge1.setProperty("src", new Integer(1)) 211 | oEdge1.setProperty("dst", new Integer(2)) 212 | oEdge1.setProperty("relationship", new String("enemy")) 213 | 214 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 215 | oEdge2.setProperty("src", new Integer(1)) 216 | oEdge2.setProperty("dst", new Integer(2)) 217 | oEdge2.setProperty("relationship", new String("friend")) 218 | 219 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> TestUtils.testSchemaForEdges), 220 | null, List(oEdge1, oEdge2)) 221 | 222 | val source = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient) 223 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEdges) 224 | 225 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 226 | .buildScan(Array("relationship"), Array.empty[Filter]) 227 | 228 | val prunedExpectedValues = Array( 229 | Row("enemy"), Row("friend") 230 | ) 231 | assert(rdd.collect() === prunedExpectedValues) 232 | } 233 | 234 | test("DefaultSource supports user schema, pruned and filtered scans for Vertices") { 235 | val params = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 236 | "user" -> "root", 237 | "password" -> "root", 238 | "vertextype" -> "test_vertex") 239 | 240 | val oVertex1 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 241 | new ODocument()) 242 | oVertex1.setProperty("id", new Integer(1)) 243 | oVertex1.setProperty("testbyte", new java.lang.Byte(1.toByte)) 244 | oVertex1.setProperty("testbool", new java.lang.Boolean(true)) 245 | 246 | val oVertex2 = new MockVertex(Mockito.mock(classOf[OrientBaseGraph]), 247 | new ODocument()) 248 | oVertex2.setProperty("id", new Integer(2)) 249 | oVertex2.setProperty("testbyte", new java.lang.Byte(2.toByte)) 250 | oVertex2.setProperty("testbool", new java.lang.Boolean(false)) 251 | 252 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("vertextype") -> TestUtils.testSchemaForVertices), 253 | List(oVertex1, oVertex2)) 254 | 255 | val source = new DefaultSource(mockOrientDBGraph.vertexWrapper, null, _ => mockOrientDBClient) 256 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForVertices) 257 | 258 | val filters: Array[Filter] = Array( 259 | EqualTo("testbool", true), 260 | EqualTo("teststring", "Hello") 261 | ) 262 | 263 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 264 | .buildScan(Array("testbyte", "testbool"), filters) 265 | 266 | assert(rdd.collect().contains(Row(1, true))) 267 | } 268 | 269 | test("DefaultSource supports user schema, pruned and filtered scans for Edges") { 270 | val params = Map("dburl" -> 
"remote:127.0.0.1:2424/GratefulDeadConcerts", 271 | "user" -> "root", 272 | "password" -> "root", 273 | "edgetype" -> "test_edge") 274 | 275 | val oEdge2 = new MockEdge(Mockito.mock(classOf[OrientBaseGraph]), new ODocument()) 276 | oEdge2.setProperty("src", new Integer(1)) 277 | oEdge2.setProperty("dst", new Integer(2)) 278 | oEdge2.setProperty("relationship", new String("friend")) 279 | 280 | val mockOrientDBGraph = new MockOrientDBGraph(Map(params("edgetype") -> TestUtils.testSchemaForEdges), 281 | null, List(oEdge2)) 282 | 283 | val source = new DefaultSource(null, mockOrientDBGraph.edgeWrapper, _ => mockOrientDBClient) 284 | val relation = source.createRelation(sqlContext, params, TestUtils.testSchemaForEdges) 285 | 286 | val filters: Array[Filter] = Array( 287 | EqualTo("relationship", "friend") 288 | ) 289 | 290 | val rdd = relation.asInstanceOf[PrunedFilteredScan] 291 | .buildScan(Array("relationship"), filters) 292 | 293 | val prunedExpectedValues = Array( 294 | Row("friend") 295 | ) 296 | assert(rdd.collect() === prunedExpectedValues) 297 | } 298 | 299 | test("Cannot save when 'query' parameter is specified instead of 'vertextype' for Vertices") { 300 | val invalidParams = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 301 | "user" -> "root", 302 | "password" -> "root", 303 | "query" -> "select testbyte, testbool from test_vertex where teststring = '\\Unicode''s樂趣'") 304 | 305 | intercept[IllegalArgumentException] { 306 | expectedDataDfVertices.write.format("org.apache.spark.orientdb.graphs") 307 | .options(invalidParams).save() 308 | } 309 | } 310 | 311 | test("Cannot save when 'query' parameter is specified instead of 'edgetype' for Edges") { 312 | val invalidParams = Map("dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 313 | "user" -> "root", 314 | "password" -> "root", 315 | "query" -> "select relationship from test_edge where relationship = 'enemy'") 316 | 317 | intercept[IllegalArgumentException] { 318 | expectedDataDfEdges.write.format("org.apache.spark.orientdb.graphs") 319 | .options(invalidParams).save() 320 | } 321 | } 322 | 323 | test("DefaultSource has default constructor, required by Data Source API") { 324 | new DefaultSource() 325 | } 326 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/orientdb/graphs/ParametersSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.orientdb.graphs 2 | 3 | import org.scalatest.{FunSuite, Matchers} 4 | 5 | class ParametersSuite extends FunSuite with Matchers { 6 | 7 | test("Minimal valid parameter map is accepted for vertices") { 8 | val params = Map( 9 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 10 | "user" -> "root", 11 | "password" -> "root", 12 | "vertextype" -> "test_vertex" 13 | ) 14 | 15 | val mergedParams = Parameters.mergeParameters(params) 16 | 17 | mergedParams.dbUrl.get shouldBe params("dburl") 18 | mergedParams.credentials.get._1 shouldBe params("user") 19 | mergedParams.credentials.get._2 shouldBe params("password") 20 | mergedParams.vertexType.get shouldBe params("vertextype") 21 | 22 | Parameters.DEFAULT_PARAMETERS.foreach { 23 | case (key, value) => mergedParams.parameters(key) shouldBe value 24 | } 25 | } 26 | 27 | test("Minimal valid parameter map is accepted for edges") { 28 | val params = Map( 29 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 30 | "user" -> "root", 31 | "password" -> "root", 32 | "edgetype" -> "test_edge" 33 | ) 
34 | 35 | val mergedParams = Parameters.mergeParameters(params) 36 | 37 | mergedParams.dbUrl.get shouldBe params("dburl") 38 | mergedParams.credentials.get._1 shouldBe params("user") 39 | mergedParams.credentials.get._2 shouldBe params("password") 40 | mergedParams.edgeType.get shouldBe params("edgetype") 41 | 42 | Parameters.DEFAULT_PARAMETERS.foreach { 43 | case (key, value) => mergedParams.parameters(key) shouldBe value 44 | } 45 | } 46 | 47 | test("Errors are thrown when mandatory parameters are not provided") { 48 | def checkMerge(params: Map[String, String]): Unit = { 49 | intercept[IllegalArgumentException] { 50 | Parameters.mergeParameters(params) 51 | } 52 | } 53 | 54 | val testURL = "remote:127.0.0.1:2424/GratefulDeadConcerts" 55 | checkMerge(Map("dburl" -> testURL, "user" -> "root", "password" -> "root")) 56 | checkMerge(Map("user" -> "root", "password" -> "root")) 57 | checkMerge(Map("dburl" -> testURL, "vertextype" -> "test_vertex")) 58 | checkMerge(Map("user" -> "root", "password" -> "root", "edgetype" -> "test_edge")) 59 | } 60 | 61 | test("Must specify either 'vertextype' param, " + 62 | "or, 'vertextype' and 'query' parameter, " + 63 | "or 'vertextype' and 'query' parameter and user-defined schema") { 64 | intercept[IllegalArgumentException] { 65 | Parameters.mergeParameters(Map( 66 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 67 | "user" -> "root", 68 | "password" -> "root" 69 | )) 70 | } 71 | 72 | Parameters.mergeParameters(Map( 73 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 74 | "user" -> "root", 75 | "password" -> "root", 76 | "vertextype" -> "test_vertex" 77 | )) 78 | 79 | Parameters.mergeParameters(Map( 80 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 81 | "user" -> "root", 82 | "password" -> "root", 83 | "vertextype" -> "test_vertex", 84 | "query" -> "select * from test_vertex" 85 | )) 86 | } 87 | 88 | test("Must specify either 'edgetype' param, " + 89 | "or, 'edgetype' and 'query' parameter, " + 90 | "or 'edgetype' and 'query' parameter and user-defined schema") { 91 | intercept[IllegalArgumentException] { 92 | Parameters.mergeParameters(Map( 93 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 94 | "user" -> "root", 95 | "password" -> "root" 96 | )) 97 | } 98 | 99 | Parameters.mergeParameters(Map( 100 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 101 | "user" -> "root", 102 | "password" -> "root", 103 | "edgetype" -> "test_edge" 104 | )) 105 | 106 | Parameters.mergeParameters(Map( 107 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 108 | "user" -> "root", 109 | "password" -> "root", 110 | "edgetype" -> "test_edge", 111 | "query" -> "select * from test_edge" 112 | )) 113 | } 114 | 115 | test("Must specify 'url' parameter and 'user' and 'password' parameters") { 116 | intercept[IllegalArgumentException] { 117 | Parameters.mergeParameters(Map( 118 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts" 119 | )) 120 | } 121 | 122 | intercept[IllegalArgumentException] { 123 | Parameters.mergeParameters(Map( 124 | "user" -> "root", 125 | "password" -> "root" 126 | )) 127 | } 128 | 129 | intercept[IllegalArgumentException] { 130 | Parameters.mergeParameters(Map( 131 | "dburl" -> "remote:127.0.0.1:2424/GratefulDeadConcerts", 132 | "password" -> "root" 133 | )) 134 | } 135 | } 136 | } --------------------------------------------------------------------------------