├── docker ├── jwtSecret ├── jwtHeader ├── server.pem ├── start_db.sh └── start_db_macos.sh ├── demo ├── docker │ ├── jwtSecret │ ├── .ivy2 │ │ ├── cache │ │ │ └── .gitignore │ │ └── jars │ │ │ └── .gitignore │ ├── jwtHeader │ ├── stop.sh │ ├── start_spark.sh │ ├── server.pem │ └── start_db.sh ├── python-demo │ ├── requirements.txt │ ├── utils.py │ ├── read_write_demo.py │ ├── schemas.py │ ├── read_demo.py │ ├── demo.py │ └── write_demo.py ├── src │ ├── test │ │ ├── scala │ │ │ └── DemoTest.scala │ │ └── resources │ │ │ └── log4j.xml │ └── main │ │ └── scala │ │ ├── ReadWriteDemo.scala │ │ ├── Schemas.scala │ │ ├── Demo.scala │ │ ├── ReadDemo.scala │ │ └── WriteDemo.scala └── README.md ├── python-integration-tests ├── integration │ ├── __init__.py │ ├── write │ │ ├── __init__.py │ │ └── test_savemode.py │ ├── utils.py │ ├── conftest.py │ ├── test_composite_filter.py │ ├── test_deserialization_cast.py │ ├── test_readwrite_datatype.py │ └── test_bad_records.py └── test-requirements.txt ├── integration-tests ├── src │ └── test │ │ ├── resources │ │ ├── allure.properties │ │ ├── cert.p12 │ │ └── log4j2.properties │ │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── sql │ │ └── arangodb │ │ ├── datasource │ │ ├── User.scala │ │ ├── AcquireHostListTest.scala │ │ ├── TestUtils.scala │ │ ├── write │ │ │ ├── WriteWithNullKey.scala │ │ │ ├── CreateCollectionTest.scala │ │ │ ├── WriteResiliencyTest.scala │ │ │ └── EdgeSchemaValidationTest.scala │ │ ├── ReadSmartEdgeCollectionTest.scala │ │ ├── StringFiltersTest.scala │ │ ├── CompositeFilterTest.scala │ │ ├── DeserializationCastTest.scala │ │ └── BadRecordsTest.scala │ │ ├── examples │ │ ├── PushdownExample.scala │ │ └── DataTypesExample.scala │ │ └── JacksonTest.scala └── pom.xml ├── .gitignore ├── arangodb-spark-datasource-3.4 ├── src │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ ├── org.apache.spark.sql.sources.DataSourceRegister │ │ │ ├── org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider │ │ │ └── org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider │ │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── sql │ │ └── arangodb │ │ └── datasource │ │ └── mapping │ │ ├── package.scala │ │ ├── ArangoGeneratorImpl.scala │ │ ├── ArangoParserImpl.scala │ │ └── json │ │ ├── JacksonUtils.scala │ │ └── CreateJacksonParser.scala └── pom.xml ├── arangodb-spark-datasource-3.5 ├── src │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ ├── org.apache.spark.sql.sources.DataSourceRegister │ │ │ ├── org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider │ │ │ └── org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider │ │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── sql │ │ └── arangodb │ │ └── datasource │ │ └── mapping │ │ ├── package.scala │ │ ├── ArangoGeneratorImpl.scala │ │ ├── ArangoParserImpl.scala │ │ └── json │ │ ├── JacksonUtils.scala │ │ └── CreateJacksonParser.scala └── pom.xml ├── bin ├── clean.sh └── test.sh ├── arangodb-spark-commons ├── src │ ├── main │ │ └── scala │ │ │ ├── org │ │ │ └── apache │ │ │ │ └── spark │ │ │ │ └── sql │ │ │ │ └── arangodb │ │ │ │ ├── commons │ │ │ │ ├── exceptions │ │ │ │ │ ├── DataWriteAbortException.scala │ │ │ │ │ ├── ArangoDBDataWriterException.scala │ │ │ │ │ └── ArangoDBMultiException.scala │ │ │ │ ├── utils │ │ │ │ │ └── PushDownCtx.scala │ │ │ │ ├── mapping │ │ │ │ │ ├── ArangoParserProvider.scala │ │ │ │ │ ├── MappingUtils.scala │ │ │ │ │ ├── ArangoGeneratorProvider.scala │ │ │ │ 
│ ├── ArangoParser.scala │ │ │ │ │ └── ArangoGenerator.scala │ │ │ │ ├── filter │ │ │ │ │ ├── FilterSupport.scala │ │ │ │ │ └── package.scala │ │ │ │ ├── ArangoUtils.scala │ │ │ │ ├── PushdownUtils.scala │ │ │ │ └── package.scala │ │ │ │ └── datasource │ │ │ │ ├── reader │ │ │ │ ├── ArangoCollectionPartition.scala │ │ │ │ ├── ArangoPartitionReaderFactory.scala │ │ │ │ ├── ArangoScan.scala │ │ │ │ ├── ArangoQueryReader.scala │ │ │ │ ├── ArangoCollectionPartitionReader.scala │ │ │ │ └── ArangoScanBuilder.scala │ │ │ │ ├── writer │ │ │ │ ├── ArangoDataWriterFactory.scala │ │ │ │ ├── ArangoBatchWriter.scala │ │ │ │ └── ArangoWriterBuilder.scala │ │ │ │ └── ArangoTable.scala │ │ │ └── com │ │ │ └── arangodb │ │ │ └── spark │ │ │ └── DefaultSource.scala │ └── test │ │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── sql │ │ └── arangodb │ │ └── commons │ │ ├── filter │ │ ├── IsNullTest.scala │ │ ├── PackageTest.scala │ │ ├── NotFilterTest.scala │ │ ├── StringEndsWithFilterTest.scala │ │ ├── StringContainsFilterTest.scala │ │ ├── StringStartsWithFilterTest.scala │ │ ├── OrFilterTest.scala │ │ └── AndFilterTest.scala │ │ ├── exceptions │ │ └── ExceptionsSerializationTest.scala │ │ └── ColumnsPruningTest.scala └── pom.xml ├── .circleci └── maven-release-settings.xml ├── README.md ├── dev-README.md └── ChangeLog.md /docker/jwtSecret: -------------------------------------------------------------------------------- 1 | Averysecretword 2 | -------------------------------------------------------------------------------- /demo/docker/jwtSecret: -------------------------------------------------------------------------------- 1 | Averysecretword 2 | -------------------------------------------------------------------------------- /python-integration-tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python-integration-tests/integration/write/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/docker/.ivy2/cache/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /demo/docker/.ivy2/jars/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /demo/python-demo/requirements.txt: -------------------------------------------------------------------------------- 1 | pyspark[pandas_on_spark]==3.5.7 2 | -------------------------------------------------------------------------------- /python-integration-tests/test-requirements.txt: -------------------------------------------------------------------------------- 1 | python-arango==7.3.4 2 | pytest==7.1.2 -------------------------------------------------------------------------------- /integration-tests/src/test/resources/allure.properties: -------------------------------------------------------------------------------- 1 | allure.results.directory=target/allure-results 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | /target/ 4 | 
**/target/ 5 | *.iml 6 | **/.idea/ 7 | .directory 8 | **/.flattened-pom.xml 9 | __pycache__/ -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | com.arangodb.spark.DefaultSource 2 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | com.arangodb.spark.DefaultSource 2 | -------------------------------------------------------------------------------- /integration-tests/src/test/resources/cert.p12: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arangodb/arangodb-spark-datasource/HEAD/integration-tests/src/test/resources/cert.p12 -------------------------------------------------------------------------------- /docker/jwtHeader: -------------------------------------------------------------------------------- 1 | Authorization: bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmFuZ29kYiIsInNlcnZlcl9pZCI6ImZvbyJ9.QmuhPHkmRPJuHGxsEqggHGRyVXikV44tb5YU_yWEvEM 2 | -------------------------------------------------------------------------------- /demo/docker/jwtHeader: -------------------------------------------------------------------------------- 1 | Authorization: bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmFuZ29kYiIsInNlcnZlcl9pZCI6ImZvbyJ9.QmuhPHkmRPJuHGxsEqggHGRyVXikV44tb5YU_yWEvEM 2 | -------------------------------------------------------------------------------- /demo/python-demo/utils.py: -------------------------------------------------------------------------------- 1 | def combine_dicts(list_of_dicts): 2 | whole_dict = {} 3 | for d in list_of_dicts: 4 | whole_dict.update(d) 5 | return whole_dict 6 | -------------------------------------------------------------------------------- /bin/clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mvn clean -Pspark-3.4 -Pscala-2.12 4 | mvn clean -Pspark-3.4 -Pscala-2.13 5 | mvn clean -Pspark-3.5 -Pscala-2.12 6 | mvn clean -Pspark-3.5 -Pscala-2.13 7 | -------------------------------------------------------------------------------- /python-integration-tests/integration/utils.py: -------------------------------------------------------------------------------- 1 | def combine_dicts(list_of_dicts): 2 | whole_dict = {} 3 | for d in list_of_dicts: 4 | whole_dict.update(d) 5 | return whole_dict -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider: -------------------------------------------------------------------------------- 1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoGeneratorProviderImpl -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider: -------------------------------------------------------------------------------- 1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoParserProviderImpl 2 | 
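The META-INF/services entries above and below register the Spark-version-specific mapping implementations against the SPI traits defined in arangodb-spark-commons, which resolves them at runtime through `java.util.ServiceLoader` (see `ArangoParserProvider` and `ArangoGeneratorProvider` further down). A minimal sketch of that lookup pattern, using a hypothetical `ExampleProvider` trait rather than the real interfaces:

```scala
import java.util.ServiceLoader

// Hypothetical SPI trait, standing in for ArangoParserProvider / ArangoGeneratorProvider.
trait ExampleProvider {
  def describe(): String
}

object ExampleProvider {
  // Picks the first implementation registered under
  // META-INF/services/<fully.qualified.ExampleProvider> on the classpath,
  // which is how each arangodb-spark-datasource-3.x module plugs in its own implementation.
  def apply(): ExampleProvider =
    ServiceLoader.load(classOf[ExampleProvider]).iterator().next()
}
```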
-------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider: -------------------------------------------------------------------------------- 1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoGeneratorProviderImpl -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider: -------------------------------------------------------------------------------- 1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoParserProviderImpl 2 | -------------------------------------------------------------------------------- /demo/docker/stop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker exec adb /app/arangodb stop 4 | sleep 1 5 | docker rm -f \ 6 | adb \ 7 | spark-master \ 8 | spark-worker-1 \ 9 | spark-worker-2 \ 10 | spark-worker-3 11 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/exceptions/DataWriteAbortException.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.exceptions 2 | 3 | class DataWriteAbortException(message: String) extends RuntimeException(message) 4 | -------------------------------------------------------------------------------- /demo/src/test/scala/DemoTest.scala: -------------------------------------------------------------------------------- 1 | import org.junit.jupiter.api.Test 2 | 3 | class DemoTest { 4 | 5 | @Test 6 | def testDemo(): Unit = { 7 | System.setProperty("importPath", "docker/import") 8 | Demo.main(Array.empty) 9 | } 10 | 11 | } 12 | -------------------------------------------------------------------------------- /bin/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # exit when any command fails 4 | set -e 5 | 6 | mvn clean -Pspark-3.4 -Pscala-2.12 7 | mvn test -Pspark-3.4 -Pscala-2.12 8 | 9 | mvn clean -Pspark-3.5 -Pscala-2.12 10 | mvn test -Pspark-3.5 -Pscala-2.12 11 | 12 | 13 | mvn clean -Pspark-3.4 -Pscala-2.13 14 | mvn test -Pspark-3.4 -Pscala-2.13 15 | 16 | mvn clean -Pspark-3.5 -Pscala-2.13 17 | mvn test -Pspark-3.5 -Pscala-2.13 18 | -------------------------------------------------------------------------------- /python-integration-tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from integration.test_basespark import arangodb_client, database_conn, spark, endpoints, single_endpoint, adb_hostname 3 | 4 | 5 | def pytest_addoption(parser): 6 | parser.addoption("--adb-datasource-jar", action="store", dest="datasource_jar_loc", required=True) 7 | parser.addoption("--adb-hostname", action="store", dest="adb_hostname", default="172.28.0.1") 8 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/User.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import java.sql.Date 4 | 5 | case class User( 6 | name: Name, 7 | birthday: Date, 8 | 
gender: String, 9 | likes: List[String] 10 | ) 11 | 12 | case class Name( 13 | first: String, 14 | last: String 15 | ) -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/utils/PushDownCtx.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.utils 2 | 3 | import org.apache.spark.sql.arangodb.commons.filter.PushableFilter 4 | import org.apache.spark.sql.types.StructType 5 | 6 | class PushDownCtx( 7 | // columns projection to return 8 | val requiredSchema: StructType, 9 | 10 | // filters to push down 11 | val filters: Array[PushableFilter] 12 | ) 13 | extends Serializable 14 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoParserProvider.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.mapping 2 | 3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 4 | import org.apache.spark.sql.types.DataType 5 | 6 | import java.util.ServiceLoader 7 | 8 | trait ArangoParserProvider { 9 | def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParser 10 | } 11 | 12 | object ArangoParserProvider { 13 | def apply(): ArangoParserProvider = ServiceLoader.load(classOf[ArangoParserProvider]).iterator().next() 14 | } 15 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/MappingUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.mapping 2 | 3 | import com.arangodb.jackson.dataformat.velocypack.VPackMapper 4 | import com.fasterxml.jackson.databind.ObjectMapper 5 | 6 | object MappingUtils { 7 | 8 | private val jsonMapper = new ObjectMapper() 9 | private val vpackMapper = new VPackMapper() 10 | 11 | def vpackToJson(in: Array[Byte]): String = jsonMapper.writeValueAsString(vpackMapper.readTree(in)) 12 | 13 | def jsonToVPack(in: String): Array[Byte] = vpackMapper.writeValueAsBytes(jsonMapper.readTree(in)) 14 | 15 | } 16 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.reader 2 | 3 | import org.apache.spark.sql.connector.read.InputPartition 4 | 5 | /** 6 | * Partition corresponding to an Arango collection shard 7 | * @param shardId collection shard id 8 | * @param endpoint db endpoint to use to query the partition 9 | */ 10 | class ArangoCollectionPartition(val shardId: String, val endpoint: String) extends InputPartition 11 | 12 | /** 13 | * Custom user queries will not be partitioned (eg. 
AQL traversals) 14 | */ 15 | object SingletonPartition extends InputPartition 16 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.writer 2 | 3 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 4 | import org.apache.spark.sql.catalyst.InternalRow 5 | import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} 6 | import org.apache.spark.sql.types.StructType 7 | 8 | class ArangoDataWriterFactory(schema: StructType, options: ArangoDBConf) extends DataWriterFactory { 9 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { 10 | new ArangoDataWriter(schema, options, partitionId) 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /demo/src/test/resources/log4j.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoGeneratorProvider.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.mapping 2 | 3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 4 | import org.apache.spark.sql.types.StructType 5 | 6 | import java.io.OutputStream 7 | import java.util.ServiceLoader 8 | 9 | trait ArangoGeneratorProvider { 10 | def of(contentType: ContentType, schema: StructType, outputStream: OutputStream, conf: ArangoDBConf): ArangoGenerator 11 | } 12 | 13 | object ArangoGeneratorProvider { 14 | def apply(): ArangoGeneratorProvider = ServiceLoader.load(classOf[ArangoGeneratorProvider]).iterator().next() 15 | } 16 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import com.fasterxml.jackson.core.JsonFactory 4 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 5 | import org.apache.spark.sql.arangodb.datasource.mapping.json.JSONOptions 6 | 7 | package object mapping { 8 | private[mapping] def createOptions(jsonFactory: JsonFactory, conf: ArangoDBConf) = 9 | new JSONOptions(Map.empty[String, String], "UTC") { 10 | override def buildJsonFactory(): JsonFactory = jsonFactory 11 | 12 | override val ignoreNullFields: Boolean = conf.mappingOptions.ignoreNullFields 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import com.fasterxml.jackson.core.JsonFactory 4 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 5 | import org.apache.spark.sql.arangodb.datasource.mapping.json.JSONOptions 6 | 7 | package object mapping { 8 | private[mapping] def 
createOptions(jsonFactory: JsonFactory, conf: ArangoDBConf) = 9 | new JSONOptions(Map.empty[String, String], "UTC") { 10 | override def buildJsonFactory(): JsonFactory = jsonFactory 11 | 12 | override val ignoreNullFields: Boolean = conf.mappingOptions.ignoreNullFields 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/filter/FilterSupport.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | 4 | sealed trait FilterSupport 5 | 6 | object FilterSupport { 7 | 8 | /** 9 | * the filter can be applied and does not need to be evaluated again after scanning 10 | */ 11 | case object FULL extends FilterSupport 12 | 13 | /** 14 | * the filter can be partially applied and it needs to be evaluated again after scanning 15 | */ 16 | case object PARTIAL extends FilterSupport 17 | 18 | /** 19 | * the filter cannot be applied and it needs to be evaluated again after scanning 20 | */ 21 | case object NONE extends FilterSupport 22 | } 23 | -------------------------------------------------------------------------------- /demo/python-demo/read_write_demo.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from pyspark.sql import SparkSession 4 | 5 | import read_demo 6 | import write_demo 7 | from schemas import movie_schema 8 | 9 | 10 | def read_write_demo(spark: SparkSession, opts: Dict[str, str]): 11 | print("-----------------------") 12 | print("--- READ-WRITE DEMO ---") 13 | print("-----------------------") 14 | 15 | print("Reading 'movies' collection and writing 'actionMovies' collection...") 16 | action_movies_df = read_demo.read_collection(spark, "movies", opts, movie_schema)\ 17 | .select("_key", "title", "releaseDate", "runtime", "description")\ 18 | .filter("genre = 'Action'") 19 | write_demo.save_df(action_movies_df.to_pandas_on_spark(), "actionMovies", opts) 20 | print("You can now view the actionMovies collection in ArangoDB!") -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.reader 2 | 3 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 4 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx 5 | import org.apache.spark.sql.catalyst.InternalRow 6 | import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} 7 | 8 | class ArangoPartitionReaderFactory(ctx: PushDownCtx, options: ArangoDBConf) extends PartitionReaderFactory { 9 | override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match { 10 | case p: ArangoCollectionPartition => new ArangoCollectionPartitionReader(p, ctx, options) 11 | case SingletonPartition => new ArangoQueryReader(ctx.requiredSchema, options) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoParser.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.mapping 2 | 3 | import 
com.fasterxml.jackson.core.{JsonFactory, JsonParser} 4 | import org.apache.spark.sql.catalyst.InternalRow 5 | import org.apache.spark.unsafe.types.UTF8String 6 | 7 | trait ArangoParser { 8 | 9 | /** 10 | * Parse the JSON or VPACK input to the set of [[InternalRow]]s. 11 | * 12 | * @param recordLiteral an optional function that will be used to generate 13 | * the corrupt record text instead of record.toString 14 | */ 15 | def parse[T]( 16 | record: T, 17 | createParser: (JsonFactory, T) => JsonParser, 18 | recordLiteral: T => UTF8String): Iterable[InternalRow] 19 | 20 | def parse(data: Array[Byte]): Iterable[InternalRow] 21 | 22 | } 23 | -------------------------------------------------------------------------------- /.circleci/maven-release-settings.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | 6 | central 7 | 8 | true 9 | 10 | 11 | ${env.GPG_KEYNAME} 12 | ${env.GPG_PASSPHRASE} 13 | 14 | 15 | 16 | 17 | 18 | 19 | central 20 | ${env.CENTRAL_USERNAME} 21 | ${env.CENTRAL_PASSWORD} 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/exceptions/ArangoDBDataWriterException.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.exceptions 2 | 3 | /** 4 | * Exception thrown after all writes attempts have failed. 5 | * It contains the exceptions thrown at each attempt. 6 | * 7 | * @param exceptions array of exceptions thrown at each attempt 8 | */ 9 | class ArangoDBDataWriterException(val exceptions: Array[Exception]) 10 | extends RuntimeException(s"Failed ${exceptions.length} times: ${ArangoDBDataWriterException.toMessage(exceptions)}") { 11 | val attempts: Int = exceptions.length 12 | 13 | override def getCause: Throwable = exceptions(0) 14 | } 15 | 16 | private object ArangoDBDataWriterException { 17 | 18 | // creates exception message 19 | private def toMessage(exceptions: Array[Exception]): String = exceptions 20 | .zipWithIndex 21 | .map(it => s"""Attempt #${it._2 + 1}: ${it._1}""") 22 | .mkString("[\n\t", ",\n\t", "\n]") 23 | 24 | } 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![ArangoDB-Logo](https://arangodb.com/wp-content/uploads/2023/09/ArangoDB-dark-logo-2022.png) 2 | 3 | # ArangoDB Datasource for Apache Spark 4 | [![Maven Central](https://maven-badges.herokuapp.com/maven-central/com.arangodb/arangodb-spark-datasource-3.5_2.12/badge.svg)](https://maven-badges.herokuapp.com/maven-central/com.arangodb/arangodb-spark-datasource-3.5_2.12) 5 | [![Actions Status](https://github.com/arangodb/arangodb-spark-datasource/workflows/Java%20CI/badge.svg)](https://github.com/arangodb/arangodb-spark-datasource/actions) 6 | [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=arangodb_arangodb-spark-datasource&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=arangodb_arangodb-spark-datasource) 7 | 8 | The official [ArangoDB](https://www.arangodb.com/) Datasource connector for Apache Spark. 
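A minimal read example (option keys as used in the bundled demo and integration tests; endpoint, credentials, and collection name below are placeholders):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ArangoDBSparkExample")
  .master("local[*]")
  .getOrCreate()

// Load an ArangoDB collection as a DataFrame.
// If no schema is passed explicitly, it is inferred by sampling the collection.
val usersDF = spark.read
  .format("com.arangodb.spark")
  .option("endpoints", "localhost:8529") // placeholder endpoint
  .option("password", "test")            // placeholder credentials
  .option("table", "users")              // ArangoDB collection to read
  .load()

usersDF.show()
```

Filters and column projections on the resulting DataFrame are pushed down to AQL where possible; see `PushdownExample.scala` in `integration-tests`.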
9 | 10 | ## Learn more 11 | - [ChangeLog](ChangeLog.md) 12 | - [Demo](./demo) 13 | - [Documentation](https://www.arangodb.com/docs/stable/drivers/spark-connector-new.html) 14 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/AcquireHostListTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import com.arangodb.spark.DefaultSource 4 | import org.apache.spark.sql.SparkSession 5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 6 | import org.junit.jupiter.api.{Disabled, Test} 7 | 8 | @Disabled("manual test only") 9 | class AcquireHostListTest { 10 | 11 | private val spark: SparkSession = SparkSession.builder() 12 | .appName("ArangoDBSparkTest") 13 | .master("local[*]") 14 | .config("spark.driver.host", "127.0.0.1") 15 | .getOrCreate() 16 | 17 | @Test 18 | def read(): Unit = { 19 | spark.read 20 | .format(classOf[DefaultSource].getName) 21 | .options(Map( 22 | ArangoDBConf.COLLECTION -> "_fishbowl", 23 | ArangoDBConf.ENDPOINTS -> "172.28.0.1:8529", 24 | ArangoDBConf.ACQUIRE_HOST_LIST -> "true", 25 | ArangoDBConf.PASSWORD -> "test" 26 | )) 27 | .load() 28 | .show() 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoGenerator.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.mapping 2 | 3 | import org.apache.spark.sql.catalyst.InternalRow 4 | import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} 5 | 6 | trait ArangoGenerator { 7 | 8 | def close(): Unit 9 | 10 | def flush(): Unit 11 | 12 | def writeStartArray(): Unit 13 | 14 | def writeEndArray(): Unit 15 | 16 | /** 17 | * Transforms a single `InternalRow` to JSON or VPACK object. 
18 | * 19 | * @param row The row to convert 20 | */ 21 | def write(row: InternalRow): Unit 22 | 23 | /** 24 | * Transforms multiple `InternalRow`s or `MapData`s to JSON or VPACK array 25 | * 26 | * @param array The array of rows or maps to convert 27 | */ 28 | def write(array: ArrayData): Unit 29 | 30 | /** 31 | * Transforms a single `MapData` to JSON or VPACK object 32 | * 33 | * @param map a map to convert 34 | */ 35 | def write(map: MapData): Unit 36 | 37 | def writeLineEnding(): Unit 38 | } 39 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/IsNullTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.{IsNull, IsNotNull} 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class IsNullTest { 9 | private val schema = StructType(Array( 10 | StructField("a", StructType(Array( 11 | StructField("b", StringType) 12 | ))) 13 | )) 14 | 15 | @Test 16 | def isNull(): Unit = { 17 | val isNullFilter = PushableFilter(IsNull("a.b"), schema) 18 | assertThat(isNullFilter.support()).isEqualTo(FilterSupport.FULL) 19 | assertThat(isNullFilter.aql("d")).isEqualTo("""`d`.`a`.`b` == null""") 20 | } 21 | 22 | @Test 23 | def isNotNull(): Unit = { 24 | val isNotNullFilter = PushableFilter(IsNotNull("a.b"), schema) 25 | assertThat(isNotNullFilter.support()).isEqualTo(FilterSupport.FULL) 26 | assertThat(isNotNullFilter.aql("d")).isEqualTo("""`d`.`a`.`b` != null""") 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/exceptions/ExceptionsSerializationTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.exceptions 2 | 3 | import com.arangodb.entity.ErrorEntity 4 | import com.fasterxml.jackson.databind.ObjectMapper 5 | import org.junit.jupiter.api.Test 6 | 7 | import java.io.{ByteArrayOutputStream, ObjectOutputStream} 8 | 9 | class ExceptionsSerializationTest { 10 | 11 | @Test 12 | def arangoDBMultiException(): Unit = { 13 | val mapper = new ObjectMapper() 14 | val errors = Stream.range(1, 10000).map(_ => 15 | ( 16 | mapper.readValue( 17 | """{ 18 | |"errorMessage": "errorMessage", 19 | |"errorNum": 1234 20 | |}""".stripMargin, classOf[ErrorEntity]), 21 | "record" 22 | ) 23 | ) 24 | val e = new ArangoDBMultiException(errors.toArray) 25 | val objectOutputStream = new ObjectOutputStream(new ByteArrayOutputStream()) 26 | objectOutputStream.writeObject(e) 27 | objectOutputStream.flush() 28 | objectOutputStream.close() 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/TestUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | object TestUtils { 4 | def isAtLeastVersion(version: String, otherMajor: Int, otherMinor: Int, otherPatch: Int): Boolean = compareVersion(version, otherMajor, otherMinor, otherPatch) >= 0 5 | 6 | def isLessThanVersion(version: String, otherMajor: Int, otherMinor: Int, otherPatch: Int): Boolean = compareVersion(version, 
otherMajor, otherMinor, otherPatch) < 0 7 | 8 | private def compareVersion(version: String, otherMajor: Int, otherMinor: Int, otherPatch: Int): Int = { 9 | val parts = version.split("-")(0).split("\\.") 10 | val major = parts(0).toInt 11 | val minor = parts(1).toInt 12 | val patch = parts(2).toInt 13 | val majorComparison = Integer.compare(major, otherMajor) 14 | if (majorComparison != 0) { 15 | return majorComparison 16 | } 17 | val minorComparison = Integer.compare(minor, otherMinor) 18 | if (minorComparison != 0) { 19 | return minorComparison 20 | } 21 | Integer.compare(patch, otherPatch) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons 2 | 3 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 4 | import org.apache.spark.sql.{Encoders, SparkSession} 5 | 6 | /** 7 | * @author Michele Rastelli 8 | */ 9 | object ArangoUtils { 10 | 11 | def inferSchema(options: ArangoDBConf): StructType = { 12 | val client = ArangoClient(options) 13 | val sampleEntries = options.readOptions.readMode match { 14 | case ReadMode.Query => client.readQuerySample() 15 | case ReadMode.Collection => client.readCollectionSample() 16 | } 17 | client.shutdown() 18 | 19 | val spark = SparkSession.getActiveSession.get 20 | val schema = spark 21 | .read 22 | .json(spark.createDataset(sampleEntries)(Encoders.STRING)) 23 | .schema 24 | 25 | if (options.readOptions.columnNameOfCorruptRecord.isEmpty) { 26 | schema 27 | } else { 28 | schema.add(StructField(options.readOptions.columnNameOfCorruptRecord, StringType, nullable = true)) 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/ColumnsPruningTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons 2 | 3 | import org.apache.spark.sql.types._ 4 | import org.assertj.core.api.Assertions.assertThat 5 | import org.junit.jupiter.api.Test 6 | 7 | class ColumnsPruningTest { 8 | 9 | @Test 10 | def generateAqlReturnClause(): Unit = { 11 | val schema = StructType(Array( 12 | StructField(""""birthday"""", DateType), 13 | StructField("gender", StringType), 14 | StructField("likes", ArrayType(StringType)), 15 | StructField("name", StructType(Array( 16 | StructField("first", StringType), 17 | StructField("last", StringType) 18 | ))) 19 | )) 20 | 21 | val res = PushdownUtils.generateColumnsFilter(schema, "d") 22 | assertThat(res).isEqualTo( 23 | """ 24 | |{ 25 | | `"birthday"`: `d`.`"birthday"`, 26 | | `gender`: `d`.`gender`, 27 | | `likes`: `d`.`likes`, 28 | | `name`: { 29 | | `first`: `d`.`name`.`first`, 30 | | `last`: `d`.`name`.`last` 31 | | } 32 | |} 33 | |""".stripMargin.replaceAll("\\s", "")) 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /integration-tests/src/test/resources/log4j2.properties: -------------------------------------------------------------------------------- 1 | status=warn 2 | 3 | appender.console.type=Console 4 | appender.console.name=console 5 | appender.console.layout.type=PatternLayout 6 | appender.console.layout.pattern=%d{HH:mm:ss.SSS} %-5p %c{10}:%L - %m%n 7 | 8 | rootLogger.level=info 9 
| rootLogger.appenderRef.stdout.ref=console 10 | 11 | # Settings to quiet third party logs that are too verbose 12 | logger.jetty.name = org.sparkproject.jetty 13 | logger.jetty.level = warn 14 | logger.jetty2.name = org.sparkproject.jetty.util.component.AbstractLifeCycle 15 | logger.jetty2.level = error 16 | logger.repl1.name = org.apache.spark.repl.SparkIMain$exprTyper 17 | logger.repl1.level = info 18 | logger.repl2.name = org.apache.spark.repl.SparkILoop$SparkILoopInterpreter 19 | logger.repl2.level = info 20 | 21 | ## --- 22 | 23 | #logger.writer.name=org.apache.spark.sql.arangodb.datasource.writer.ArangoDataWriter 24 | #logger.writer.level=debug 25 | 26 | #logger.driver.name=com.arangodb 27 | #logger.driver.level=debug 28 | 29 | #logger.netty.name=com.arangodb.shaded.netty 30 | #logger.netty.level=debug 31 | 32 | #logger.communication.name=com.arangodb.internal.net.Communication 33 | #logger.communication.level=debug 34 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/PackageTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.assertj.core.api.Assertions.assertThat 4 | import org.junit.jupiter.params.ParameterizedTest 5 | import org.junit.jupiter.params.provider.{Arguments, MethodSource} 6 | 7 | import java.util.stream 8 | 9 | class PackageTest { 10 | @ParameterizedTest 11 | @MethodSource(Array("provideSplitAttributeName")) 12 | def splitAttributeName(attribute: String, expected: Array[String]): Unit = { 13 | assertThat(splitAttributeNameParts(attribute)).isEqualTo(expected) 14 | } 15 | } 16 | 17 | object PackageTest { 18 | def provideSplitAttributeName: stream.Stream[Arguments] = 19 | stream.Stream.of( 20 | Arguments.of("a.b.c.d.e.f", Array("a", "b", "c", "d", "e", "f")), 21 | Arguments.of("a.`b`.c.d.e.f", Array("a", "b", "c", "d", "e", "f")), 22 | Arguments.of("a.`b.c`.d.e.f", Array("a", "b.c", "d", "e", "f")), 23 | Arguments.of("a.b.`c.d`.e.f", Array("a", "b", "c.d", "e", "f")), 24 | Arguments.of("a.b.`.c.d.`.e.f", Array("a", "b", ".c.d.", "e", "f")), 25 | Arguments.of("a.b.`.`.e.f", Array("a", "b", ".", "e", "f")) 26 | ) 27 | } -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/exceptions/ArangoDBMultiException.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.exceptions 2 | 3 | import com.arangodb.entity.ErrorEntity 4 | 5 | // Due to https://github.com/scala/bug/issues/10679 scala Stream serialization could generate StackOverflowError. 
To 6 | // avoid it we use: 7 | // val errors: Array[ErrorEntity] 8 | // instead of: 9 | // val errors: Iterable[ErrorEntity] 10 | 11 | /** 12 | * @param errors array of tuples with: 13 | * _1 : the error entity 14 | * _2 : the stringified record causing the error 15 | */ 16 | class ArangoDBMultiException(val errors: Array[(ErrorEntity, String)]) 17 | extends RuntimeException(ArangoDBMultiException.toMessage(errors)) 18 | 19 | private object ArangoDBMultiException { 20 | 21 | // creates exception message keeping only 5 errors for each error type 22 | private def toMessage(errors: Array[(ErrorEntity, String)]): String = errors 23 | .groupBy(_._1.getErrorNum) 24 | .mapValues(_.take(5)) 25 | .values 26 | .flatten 27 | .map(it => s"""Error: ${it._1.getErrorNum} - ${it._1.getErrorMessage} for record: ${it._2}""") 28 | .mkString("[\n\t", ",\n\t", "\n]") 29 | } 30 | -------------------------------------------------------------------------------- /demo/src/main/scala/ReadWriteDemo.scala: -------------------------------------------------------------------------------- 1 | import Schemas.movieSchema 2 | 3 | object ReadWriteDemo { 4 | 5 | def readWriteDemo(): Unit = { 6 | println("-----------------------") 7 | println("--- READ-WRITE DEMO ---") 8 | println("-----------------------") 9 | 10 | println("Reading 'movies' collection and writing 'actionMovies' collection...") 11 | val actionMoviesDF = ReadDemo.readTable("movies", movieSchema) 12 | .select("_key", "title", "releaseDate", "runtime", "description") 13 | .filter("genre = 'Action'") 14 | WriteDemo.saveDF(actionMoviesDF, "actionMovies") 15 | /* 16 | Filters and projection pushdowns are applied in this case. 17 | 18 | In the console an info message log like the following will be printed: 19 | > INFO ArangoScanBuilder:57 - Filters fully applied in AQL: 20 | > IsNotNull(genre) 21 | > EqualTo(genre,Action) 22 | 23 | Also the generated AQL query will be printed with log level debug: 24 | > DEBUG ArangoClient:61 - Executing AQL query: 25 | > FOR d IN @@col FILTER `d`.`genre` != null AND `d`.`genre` == "Action" RETURN {`_key`:`d`.`_key`,`description`:`d`.`description`,`releaseDate`:`d`.`releaseDate`,`runtime`:`d`.`runtime`,`title`:`d`.`title`} 26 | > with params: Map(@col -> movies) 27 | */ 28 | 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.writer 2 | 3 | import org.apache.spark.sql.SaveMode 4 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} 5 | import org.apache.spark.sql.arangodb.commons.exceptions.DataWriteAbortException 6 | import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage} 7 | import org.apache.spark.sql.types.StructType 8 | 9 | class ArangoBatchWriter(schema: StructType, options: ArangoDBConf, mode: SaveMode) extends BatchWrite { 10 | 11 | override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = 12 | new ArangoDataWriterFactory(schema, options) 13 | 14 | override def commit(messages: Array[WriterCommitMessage]): Unit = { 15 | // nothing to do here 16 | } 17 | 18 | override def abort(messages: Array[WriterCommitMessage]): Unit = { 19 | val client = ArangoClient(options) 20 | mode match { 21 | case SaveMode.Append => throw new 
DataWriteAbortException( 22 | "Cannot abort with SaveMode.Append: the underlying data source may require manual cleanup.") 23 | case SaveMode.Overwrite => client.truncate() 24 | case SaveMode.ErrorIfExists => ??? 25 | case SaveMode.Ignore => ??? 26 | } 27 | client.shutdown() 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /dev-README.md: -------------------------------------------------------------------------------- 1 | # dev-README 2 | 3 | ## GH Actions 4 | Check results [here](https://github.com/arangodb/arangodb-spark-datasource/actions). 5 | 6 | ## SonarCloud 7 | Check results [here](https://sonarcloud.io/project/overview?id=arangodb_arangodb-spark-datasource). 8 | 9 | ## check dependencies updates 10 | ```shell 11 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} versions:display-dependency-updates 12 | ``` 13 | 14 | ## analysis tools 15 | 16 | ### scalastyle 17 | ```shell 18 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} process-sources 19 | ``` 20 | Reports: 21 | - [arangodb-spark-commons](arangodb-spark-commons/target/scalastyle-output.xml) 22 | - [arangodb-spark-datasource-3.4](arangodb-spark-datasource-3.4/target/scalastyle-output.xml) 23 | - [arangodb-spark-datasource-3.5](arangodb-spark-datasource-3.5/target/scalastyle-output.xml) 24 | 25 | ### scapegoat 26 | ```shell 27 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} test-compile 28 | ``` 29 | Reports: 30 | - [arangodb-spark-commons](arangodb-spark-commons/target/scapegoat/scapegoat.html) 31 | - [arangodb-spark-datasource-3.4](arangodb-spark-datasource-3.4/target/scapegoat/scapegoat.html) 32 | - [arangodb-spark-datasource-3.5](arangodb-spark-datasource-3.5/target/scapegoat/scapegoat.html) 33 | 34 | ### JaCoCo 35 | ```shell 36 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} verify 37 | ``` 38 | Report: 39 | - [integration-tests](integration-tests/target/site/jacoco-aggregate/index.html) 40 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.reader 2 | 3 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ReadMode} 4 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx 5 | import org.apache.spark.sql.catalyst.expressions.ExprUtils 6 | import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} 7 | import org.apache.spark.sql.types.StructType 8 | 9 | class ArangoScan(ctx: PushDownCtx, options: ArangoDBConf) extends Scan with Batch { 10 | ExprUtils.verifyColumnNameOfCorruptRecord(ctx.requiredSchema, options.readOptions.columnNameOfCorruptRecord) 11 | 12 | override def readSchema(): StructType = ctx.requiredSchema 13 | 14 | override def toBatch: Batch = this 15 | 16 | override def planInputPartitions(): Array[InputPartition] = options.readOptions.readMode match { 17 | case ReadMode.Query => Array(SingletonPartition) 18 | case ReadMode.Collection => planCollectionPartitions() 19 | } 20 | 21 | override def createReaderFactory(): PartitionReaderFactory = new ArangoPartitionReaderFactory(ctx, options) 22 | 23 | private def planCollectionPartitions(): Array[InputPartition] = 24 | ArangoClient.getCollectionShardIds(options) 25 | .zip(Stream.continually(options.driverOptions.endpoints).flatten) 26 | .map(it => new 
ArangoCollectionPartition(it._1, it._2)) 27 | 28 | } 29 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/WriteWithNullKey.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.write 2 | 3 | import com.arangodb.ArangoCollection 4 | import org.apache.spark.sql.SaveMode 5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 6 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest 7 | import org.assertj.core.api.Assertions.assertThat 8 | import org.junit.jupiter.params.ParameterizedTest 9 | import org.junit.jupiter.params.provider.MethodSource 10 | 11 | class WriteWithNullKey extends BaseSparkTest { 12 | private val collectionName = "writeWithNullKey" 13 | private val collection: ArangoCollection = db.collection(collectionName) 14 | 15 | import spark.implicits._ 16 | 17 | @ParameterizedTest 18 | @MethodSource(Array("provideProtocolAndContentType")) 19 | def writeWithNullKey(protocol: String, contentType: String): Unit = { 20 | Seq( 21 | ("Carlsen", "Magnus"), 22 | (null, "Fabiano") 23 | ) 24 | .toDF("_key", "name") 25 | .write 26 | .format(BaseSparkTest.arangoDatasource) 27 | .mode(SaveMode.Overwrite) 28 | .options(options + ( 29 | ArangoDBConf.COLLECTION -> collectionName, 30 | ArangoDBConf.PROTOCOL -> protocol, 31 | ArangoDBConf.CONTENT_TYPE -> contentType, 32 | ArangoDBConf.CONFIRM_TRUNCATE -> "true" 33 | )) 34 | .save() 35 | 36 | assertThat(collection.count().getCount).isEqualTo(2L) 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/NotFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.{And, EqualTo, Not} 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class NotFilterTest { 9 | private val schema = StructType(Array( 10 | StructField("integer", IntegerType), 11 | StructField("string", StringType), 12 | StructField("binary", BinaryType) 13 | )) 14 | 15 | // FilterSupport.FULL 16 | private val f1 = EqualTo("string", "str") 17 | private val pushF1 = PushableFilter(f1, schema: StructType) 18 | 19 | // FilterSupport.NONE 20 | private val f2 = EqualTo("binary", Array(Byte.MaxValue)) 21 | 22 | // FilterSupport.PARTIAL 23 | private val f3 = And(f1, f2) 24 | 25 | @Test 26 | def notFilterSupportFull(): Unit = { 27 | val notFilter = PushableFilter(Not(f1), schema) 28 | assertThat(notFilter.support()).isEqualTo(FilterSupport.FULL) 29 | assertThat(notFilter.aql("d")).isEqualTo(s"""NOT (${pushF1.aql("d")})""") 30 | } 31 | 32 | @Test 33 | def notFilterSupportNone(): Unit = { 34 | val notFilter = PushableFilter(Not(f2), schema) 35 | assertThat(notFilter.support()).isEqualTo(FilterSupport.NONE) 36 | } 37 | 38 | @Test 39 | def notFilterSupportPartial(): Unit = { 40 | val notFilter = PushableFilter(Not(f3), schema) 41 | assertThat(notFilter.support()).isEqualTo(FilterSupport.NONE) 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/examples/PushdownExample.scala: 
-------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.examples 2 | 3 | import com.arangodb.ArangoDB 4 | import org.apache.spark.sql.{Dataset, Encoders, SparkSession} 5 | 6 | object PushdownExample { 7 | 8 | final case class User(name: String, age: Int) 9 | 10 | def main(args: Array[String]): Unit = { 11 | prepareDB() 12 | 13 | val spark: SparkSession = SparkSession.builder() 14 | .appName("ArangoDBSparkTest") 15 | .master("local[*, 3]") 16 | .config("spark.driver.host", "127.0.0.1") 17 | .getOrCreate() 18 | 19 | import spark.implicits._ 20 | val ds: Dataset[User] = spark.read 21 | .format("com.arangodb.spark") 22 | .option("password", "test") 23 | .option("endpoints", "172.28.0.1:8529") 24 | .option("table", "users") 25 | .schema(Encoders.product[User].schema) 26 | .load() 27 | .as[User] 28 | 29 | ds 30 | .select("name") 31 | .filter("age >= 18 AND age < 22") 32 | .show 33 | 34 | /* 35 | Generated query: 36 | FOR d IN @@col FILTER `d`.`age` >= 18 AND `d`.`age` < 22 RETURN {`name`:`d`.`name`} 37 | with params: Map(@col -> users) 38 | */ 39 | } 40 | 41 | private def prepareDB(): Unit = { 42 | val arangoDB = new ArangoDB.Builder() 43 | .host("172.28.0.1", 8529) 44 | .password("test") 45 | .build() 46 | 47 | val col = arangoDB.db().collection("users") 48 | if (!col.exists()) 49 | col.create() 50 | col.truncate() 51 | col.insertDocument(User("Alice", 10)) 52 | col.insertDocument(User("Bob", 20)) 53 | col.insertDocument(User("Eve", 30)) 54 | 55 | arangoDB.shutdown() 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/com/arangodb/spark/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package com.arangodb.spark 2 | 3 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} 4 | import org.apache.spark.sql.arangodb.datasource.ArangoTable 5 | import org.apache.spark.sql.connector.catalog.{Table, TableProvider} 6 | import org.apache.spark.sql.connector.expressions.Transform 7 | import org.apache.spark.sql.sources.DataSourceRegister 8 | import org.apache.spark.sql.types.StructType 9 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 10 | 11 | import java.util 12 | 13 | class DefaultSource extends TableProvider with DataSourceRegister { 14 | 15 | private def extractOptions(options: util.Map[String, String]): ArangoDBConf = { 16 | val opts: ArangoDBConf = ArangoDBConf(options) 17 | if (opts.driverOptions.acquireHostList) { 18 | val hosts = ArangoClient.acquireHostList(opts) 19 | opts.updated(ArangoDBConf.ENDPOINTS, hosts.mkString(",")) 20 | } else { 21 | opts 22 | } 23 | } 24 | 25 | override def inferSchema(options: CaseInsensitiveStringMap): StructType = getTable(options).schema() 26 | 27 | private def getTable(options: CaseInsensitiveStringMap): Table = 28 | getTable(None, options.asCaseSensitiveMap()) // scalastyle:ignore null 29 | 30 | override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = 31 | getTable(Option(schema), properties) 32 | 33 | override def supportsExternalMetadata(): Boolean = true 34 | 35 | override def shortName(): String = "arangodb" 36 | 37 | private def getTable(schema: Option[StructType], properties: util.Map[String, String]) = 38 | new ArangoTable(schema, extractOptions(properties)) 39 | 40 | } 41 | 
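Because `DefaultSource` implements `DataSourceRegister` with short name `"arangodb"`, callers can use that instead of the fully qualified class name, and when `ACQUIRE_HOST_LIST` is enabled `extractOptions` replaces the configured endpoints with the host list acquired from the server before the table is created. A hedged sketch (configuration keys are the `ArangoDBConf` constants used in `AcquireHostListTest`; endpoint and password are placeholders):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.arangodb.commons.ArangoDBConf

// Read through the registered short name; with ACQUIRE_HOST_LIST enabled the
// endpoint list is expanded to the hosts acquired from the deployment.
def readCollection(spark: SparkSession, collection: String) =
  spark.read
    .format("arangodb") // resolves to com.arangodb.spark.DefaultSource via DataSourceRegister
    .options(Map(
      ArangoDBConf.COLLECTION -> collection,
      ArangoDBConf.ENDPOINTS -> "localhost:8529", // placeholder single endpoint
      ArangoDBConf.ACQUIRE_HOST_LIST -> "true",
      ArangoDBConf.PASSWORD -> "test"             // placeholder credentials
    ))
    .load()
```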
-------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ArangoUtils} 4 | import org.apache.spark.sql.arangodb.datasource.reader.ArangoScanBuilder 5 | import org.apache.spark.sql.arangodb.datasource.writer.ArangoWriterBuilder 6 | import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} 7 | import org.apache.spark.sql.connector.read.ScanBuilder 8 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} 9 | import org.apache.spark.sql.types.StructType 10 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 11 | 12 | import java.util 13 | import scala.collection.JavaConverters.setAsJavaSetConverter 14 | 15 | class ArangoTable(private var schemaOpt: Option[StructType], options: ArangoDBConf) extends Table with SupportsRead with SupportsWrite { 16 | private lazy val tableSchema = schemaOpt.getOrElse(ArangoUtils.inferSchema(options)) 17 | 18 | override def name(): String = this.getClass.toString 19 | 20 | override def schema(): StructType = tableSchema 21 | 22 | override def capabilities(): util.Set[TableCapability] = Set( 23 | TableCapability.BATCH_READ, 24 | TableCapability.BATCH_WRITE, 25 | // TableCapability.STREAMING_WRITE, 26 | TableCapability.ACCEPT_ANY_SCHEMA, 27 | TableCapability.TRUNCATE 28 | // TableCapability.OVERWRITE_BY_FILTER, 29 | // TableCapability.OVERWRITE_DYNAMIC, 30 | ).asJava 31 | 32 | override def newScanBuilder(scanOptions: CaseInsensitiveStringMap): ScanBuilder = 33 | new ArangoScanBuilder(options.updated(ArangoDBConf(scanOptions)), schema()) 34 | 35 | override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = 36 | new ArangoWriterBuilder(info.schema(), options.updated(ArangoDBConf(info.options()))) 37 | } 38 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/PushdownUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons 2 | 3 | import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter} 4 | import org.apache.spark.sql.sources.Filter 5 | import org.apache.spark.sql.types.{StructField, StructType} 6 | 7 | import scala.annotation.tailrec 8 | 9 | // FIXME: use documentVariable instead of "d" 10 | object PushdownUtils { 11 | 12 | private[commons] def generateColumnsFilter(schema: StructType, documentVariable: String): String = 13 | doGenerateColumnsFilter(schema, s"`$documentVariable`.") 14 | 15 | private def doGenerateColumnsFilter(schema: StructType, ctx: String): String = 16 | s"""{${schema.fields.map(generateFieldFilter(_, ctx)).mkString(",")}}""" 17 | 18 | private def generateFieldFilter(field: StructField, ctx: String): String = { 19 | val fieldName = s"`${field.name}`" 20 | val value = s"$ctx$fieldName" 21 | s"$fieldName:" + (field.dataType match { 22 | case s: StructType => doGenerateColumnsFilter(s, s"$value.") 23 | case _ => value 24 | }) 25 | } 26 | 27 | def generateFilterClause(filters: Array[PushableFilter]): String = filters match { 28 | case Array() => "" 29 | case _ => "FILTER " + filters 30 | .filter(_.support != FilterSupport.NONE) 31 | 
.map(_.aql("d")) 32 | .mkString(" AND ") 33 | } 34 | 35 | def generateRowFilters(filters: Array[Filter], schema: StructType, documentVariable: String = "d"): Array[PushableFilter] = 36 | filters.map(PushableFilter(_, schema)) 37 | 38 | @tailrec 39 | private[commons] def getStructField(fieldNameParts: Array[String], fieldSchema: StructField): StructField = fieldNameParts match { 40 | case Array() => fieldSchema 41 | case _ => getStructField(fieldNameParts.tail, fieldSchema.dataType.asInstanceOf[StructType](fieldNameParts.head)) 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /demo/src/main/scala/Schemas.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.sql.types.{DateType, IntegerType, StringType, StructField, StructType, TimestampType} 2 | 3 | object Schemas { 4 | val movieSchema: StructType = StructType(Array( 5 | StructField("_id", StringType, nullable = false), 6 | StructField("_key", StringType, nullable = false), 7 | StructField("description", StringType), 8 | StructField("genre", StringType), 9 | StructField("homepage", StringType), 10 | StructField("imageUrl", StringType), 11 | StructField("imdbId", StringType), 12 | StructField("language", StringType), 13 | StructField("lastModified", TimestampType), 14 | StructField("releaseDate", DateType), 15 | StructField("runtime", IntegerType), 16 | StructField("studio", StringType), 17 | StructField("tagline", StringType), 18 | StructField("title", StringType), 19 | StructField("trailer", StringType) 20 | )) 21 | 22 | val personSchema: StructType = StructType(Array( 23 | StructField("_id", StringType, nullable = false), 24 | StructField("_key", StringType, nullable = false), 25 | StructField("biography", StringType), 26 | StructField("birthday", DateType), 27 | StructField("birthplace", StringType), 28 | StructField("lastModified", TimestampType), 29 | StructField("name", StringType), 30 | StructField("profileImageUrl", StringType) 31 | )) 32 | 33 | val actsInSchema: StructType = StructType(Array( 34 | StructField("_id", StringType, nullable = false), 35 | StructField("_key", StringType, nullable = false), 36 | StructField("_from", StringType, nullable = false), 37 | StructField("_to", StringType, nullable = false), 38 | StructField("name", StringType) 39 | )) 40 | 41 | val directedSchema: StructType = StructType(Array( 42 | StructField("_id", StringType, nullable = false), 43 | StructField("_key", StringType, nullable = false), 44 | StructField("_from", StringType, nullable = false), 45 | StructField("_to", StringType, nullable = false) 46 | )) 47 | 48 | } 49 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/StringEndsWithFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.StringEndsWith 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class StringEndsWithFilterTest { 9 | private val schema = StructType(Array( 10 | // atomic types 11 | StructField("bool", BooleanType), 12 | StructField("double", DoubleType), 13 | StructField("float", FloatType), 14 | StructField("integer", IntegerType), 15 | StructField("date", DateType), 16 | StructField("timestamp", TimestampType), 17 
| StructField("short", ShortType), 18 | StructField("byte", ByteType), 19 | StructField("string", StringType), 20 | 21 | // complex types 22 | StructField("array", ArrayType(StringType)), 23 | StructField("null", NullType), 24 | StructField("struct", StructType(Array( 25 | StructField("a", StringType), 26 | StructField("b", IntegerType) 27 | ))) 28 | )) 29 | 30 | @Test 31 | def stringEndsWithStringFilter(): Unit = { 32 | val field = "string" 33 | val value = "str" 34 | val filter = PushableFilter(StringEndsWith(field, value), schema: StructType) 35 | assertThat(filter.support()).isEqualTo(FilterSupport.FULL) 36 | assertThat(filter.aql("d")).isEqualTo(s"""STARTS_WITH(REVERSE(`d`.`$field`), REVERSE("$value"))""") 37 | } 38 | 39 | @Test 40 | def stringEndsWithFilterTimestamp(): Unit = { 41 | val field = "timestamp" 42 | val value = "2001-01-02T15:30:45.678111Z" 43 | val filter = PushableFilter(StringEndsWith(field, value), schema: StructType) 44 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE) 45 | } 46 | 47 | @Test 48 | def stringEndsWithFilterDate(): Unit = { 49 | val field = "date" 50 | val value = "2001-01-02" 51 | val filter = PushableFilter(StringEndsWith(field, value), schema: StructType) 52 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE) 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.mapping 2 | 3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder 4 | import com.fasterxml.jackson.core.JsonFactoryBuilder 5 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 6 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider} 7 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonGenerator} 8 | import org.apache.spark.sql.types.{DataType, StructType} 9 | 10 | import java.io.OutputStream 11 | 12 | abstract sealed class ArangoGeneratorImpl( 13 | schema: DataType, 14 | writer: OutputStream, 15 | options: JSONOptions) 16 | extends JacksonGenerator( 17 | schema, 18 | options.buildJsonFactory().createGenerator(writer), 19 | options) with ArangoGenerator 20 | 21 | class ArangoGeneratorProviderImpl extends ArangoGeneratorProvider { 22 | override def of( 23 | contentType: ContentType, 24 | schema: StructType, 25 | outputStream: OutputStream, 26 | conf: ArangoDBConf 27 | ): ArangoGeneratorImpl = contentType match { 28 | case ContentType.JSON => new JsonArangoGenerator(schema, outputStream, conf) 29 | case ContentType.VPACK => new VPackArangoGenerator(schema, outputStream, conf) 30 | case _ => throw new IllegalArgumentException 31 | } 32 | } 33 | 34 | class JsonArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf) 35 | extends ArangoGeneratorImpl( 36 | schema, 37 | outputStream, 38 | createOptions(new JsonFactoryBuilder().build(), conf) 39 | ) 40 | 41 | class VPackArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf) 42 | extends ArangoGeneratorImpl( 43 | schema, 44 | outputStream, 45 | createOptions(new VPackFactoryBuilder().build(), conf) 46 | ) 47 | -------------------------------------------------------------------------------- 
/arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.mapping 2 | 3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder 4 | import com.fasterxml.jackson.core.JsonFactoryBuilder 5 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 6 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider} 7 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonGenerator} 8 | import org.apache.spark.sql.types.{DataType, StructType} 9 | 10 | import java.io.OutputStream 11 | 12 | abstract sealed class ArangoGeneratorImpl( 13 | schema: DataType, 14 | writer: OutputStream, 15 | options: JSONOptions) 16 | extends JacksonGenerator( 17 | schema, 18 | options.buildJsonFactory().createGenerator(writer), 19 | options) with ArangoGenerator 20 | 21 | class ArangoGeneratorProviderImpl extends ArangoGeneratorProvider { 22 | override def of( 23 | contentType: ContentType, 24 | schema: StructType, 25 | outputStream: OutputStream, 26 | conf: ArangoDBConf 27 | ): ArangoGeneratorImpl = contentType match { 28 | case ContentType.JSON => new JsonArangoGenerator(schema, outputStream, conf) 29 | case ContentType.VPACK => new VPackArangoGenerator(schema, outputStream, conf) 30 | case _ => throw new IllegalArgumentException 31 | } 32 | } 33 | 34 | class JsonArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf) 35 | extends ArangoGeneratorImpl( 36 | schema, 37 | outputStream, 38 | createOptions(new JsonFactoryBuilder().build(), conf) 39 | ) 40 | 41 | class VPackArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf) 42 | extends ArangoGeneratorImpl( 43 | schema, 44 | outputStream, 45 | createOptions(new VPackFactoryBuilder().build(), conf) 46 | ) 47 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/StringContainsFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.StringContains 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class StringContainsFilterTest { 9 | private val schema = StructType(Array( 10 | // atomic types 11 | StructField("bool", BooleanType), 12 | StructField("double", DoubleType), 13 | StructField("float", FloatType), 14 | StructField("integer", IntegerType), 15 | StructField("date", DateType), 16 | StructField("timestamp", TimestampType), 17 | StructField("short", ShortType), 18 | StructField("byte", ByteType), 19 | StructField("string", StringType), 20 | 21 | // complex types 22 | StructField("array", ArrayType(StringType)), 23 | StructField("intMap", MapType(StringType, IntegerType)), 24 | StructField("null", NullType), 25 | StructField("struct", StructType(Array( 26 | StructField("a", StringType), 27 | StructField("b", IntegerType) 28 | ))) 29 | )) 30 | 31 | @Test 32 | def stringContainsStringFilter(): Unit = { 33 | val field = "string" 34 | val value = "str" 35 | val filter = PushableFilter(StringContains(field, value), schema: StructType) 36 | 
assertThat(filter.support()).isEqualTo(FilterSupport.FULL) 37 | assertThat(filter.aql("d")).isEqualTo(s"""CONTAINS(`d`.`$field`, "$value")""") 38 | } 39 | 40 | @Test 41 | def stringContainsFilterTimestamp(): Unit = { 42 | val field = "timestamp" 43 | val value = "2001-01-02T15:30:45.678111Z" 44 | val filter = PushableFilter(StringContains(field, value), schema: StructType) 45 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE) 46 | } 47 | 48 | @Test 49 | def stringContainsFilterDate(): Unit = { 50 | val field = "date" 51 | val value = "2001-01-02" 52 | val filter = PushableFilter(StringContains(field, value), schema: StructType) 53 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE) 54 | } 55 | 56 | } 57 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/StringStartsWithFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.StringStartsWith 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class StringStartsWithFilterTest { 9 | private val schema = StructType(Array( 10 | // atomic types 11 | StructField("bool", BooleanType), 12 | StructField("double", DoubleType), 13 | StructField("float", FloatType), 14 | StructField("integer", IntegerType), 15 | StructField("date", DateType), 16 | StructField("timestamp", TimestampType), 17 | StructField("short", ShortType), 18 | StructField("byte", ByteType), 19 | StructField("string", StringType), 20 | 21 | // complex types 22 | StructField("array", ArrayType(StringType)), 23 | StructField("intMap", MapType(StringType, IntegerType)), 24 | StructField("null", NullType), 25 | StructField("struct", StructType(Array( 26 | StructField("a", StringType), 27 | StructField("b", IntegerType) 28 | ))) 29 | )) 30 | 31 | @Test 32 | def stringStartsWithStringFilter(): Unit = { 33 | val field = "string" 34 | val value = "str" 35 | val filter = PushableFilter(StringStartsWith(field, value), schema: StructType) 36 | assertThat(filter.support()).isEqualTo(FilterSupport.FULL) 37 | assertThat(filter.aql("d")).isEqualTo(s"""STARTS_WITH(`d`.`$field`, "$value")""") 38 | } 39 | 40 | @Test 41 | def stringStartsWithFilterTimestamp(): Unit = { 42 | val field = "timestamp" 43 | val value = "2001-01-02T15:30:45.678111Z" 44 | val filter = PushableFilter(StringStartsWith(field, value), schema: StructType) 45 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE) 46 | } 47 | 48 | @Test 49 | def stringStartsWithFilterDate(): Unit = { 50 | val field = "date" 51 | val value = "2001-01-02" 52 | val filter = PushableFilter(StringStartsWith(field, value), schema: StructType) 53 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE) 54 | } 55 | 56 | } 57 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/examples/DataTypesExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.examples 2 | 3 | import org.apache.spark.sql.{Dataset, Encoders, SparkSession} 4 | 5 | import java.sql.{Date, Timestamp} 6 | import java.time.{LocalDate, LocalDateTime} 7 | 8 | object DataTypesExample { 9 | 10 | final case class Order( 11 | userId: String, 12 | price: 
Double, 13 | shipped: Boolean, 14 | totItems: Int, 15 | creationDate: Date, 16 | lastModifiedTs: Timestamp, 17 | itemIds: List[String], 18 | qty: Map[String, Int] 19 | ) 20 | 21 | def main(args: Array[String]): Unit = { 22 | val spark: SparkSession = SparkSession.builder() 23 | .appName("ArangoDBSparkTest") 24 | .master("local[*, 3]") 25 | .config("spark.driver.host", "127.0.0.1") 26 | .getOrCreate() 27 | 28 | import spark.implicits._ 29 | 30 | val o = Order( 31 | userId = "Mike", 32 | price = 9.99, 33 | shipped = true, 34 | totItems = 2, 35 | creationDate = Date.valueOf(LocalDate.now()), 36 | lastModifiedTs = Timestamp.valueOf(LocalDateTime.now()), 37 | itemIds = List("itm1", "itm2"), 38 | qty = Map("itm1" -> 1, "itm2" -> 3) 39 | ) 40 | 41 | val ds = Seq(o).toDS() 42 | 43 | ds.show() 44 | ds.printSchema() 45 | 46 | ds 47 | .write 48 | .mode("overwrite") 49 | .format("com.arangodb.spark") 50 | .option("password", "test") 51 | .option("endpoints", "172.28.0.1:8529") 52 | .option("table", "orders") 53 | .option("confirmTruncate", "true") 54 | .save() 55 | 56 | val readDS: Dataset[Order] = spark.read 57 | .format("com.arangodb.spark") 58 | .option("password", "test") 59 | .option("endpoints", "172.28.0.1:8529") 60 | .option("table", "orders") 61 | .schema(Encoders.product[Order].schema) 62 | .load() 63 | .as[Order] 64 | 65 | readDS.show() 66 | readDS.printSchema() 67 | 68 | assert(readDS.collect().head == o) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /demo/docker/start_spark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker network create arangodb --subnet 172.28.0.0/16 4 | 5 | docker run -d --network arangodb --ip 172.28.10.1 --name spark-master -h spark-master \ 6 | -e SPARK_MODE=master \ 7 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \ 8 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \ 9 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ 10 | -e SPARK_SSL_ENABLED=no \ 11 | -v $(pwd)/docker/import:/import \ 12 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \ 13 | docker.io/bitnamilegacy/spark:3.5.6 14 | 15 | docker run -d --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 \ 16 | -e SPARK_MODE=worker \ 17 | -e SPARK_MASTER_URL=spark://spark-master:7077 \ 18 | -e SPARK_WORKER_MEMORY=1G \ 19 | -e SPARK_WORKER_CORES=1 \ 20 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \ 21 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \ 22 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ 23 | -e SPARK_SSL_ENABLED=no \ 24 | -v $(pwd)/docker/import:/import \ 25 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \ 26 | docker.io/bitnamilegacy/spark:3.5.6 27 | 28 | docker run -d --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 \ 29 | -e SPARK_MODE=worker \ 30 | -e SPARK_MASTER_URL=spark://spark-master:7077 \ 31 | -e SPARK_WORKER_MEMORY=1G \ 32 | -e SPARK_WORKER_CORES=1 \ 33 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \ 34 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \ 35 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ 36 | -e SPARK_SSL_ENABLED=no \ 37 | -v $(pwd)/docker/import:/import \ 38 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \ 39 | docker.io/bitnamilegacy/spark:3.5.6 40 | 41 | docker run -d --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 \ 42 | -e SPARK_MODE=worker \ 43 | -e SPARK_MASTER_URL=spark://spark-master:7077 \ 44 | -e SPARK_WORKER_MEMORY=1G \ 45 | -e SPARK_WORKER_CORES=1 \ 46 | -e 
SPARK_RPC_AUTHENTICATION_ENABLED=no \ 47 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \ 48 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ 49 | -e SPARK_SSL_ENABLED=no \ 50 | -v $(pwd)/docker/import:/import \ 51 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \ 52 | docker.io/bitnamilegacy/spark:3.5.6 53 | -------------------------------------------------------------------------------- /demo/python-demo/schemas.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DateType, IntegerType 2 | 3 | movie_schema: StructType = StructType([ 4 | StructField("_id", StringType(), nullable=False), 5 | StructField("_key", StringType(), nullable=False), 6 | StructField("description", StringType()), 7 | StructField("genre", StringType()), 8 | StructField("homepage", StringType()), 9 | StructField("imageUrl", StringType()), 10 | StructField("imdbId", StringType()), 11 | StructField("language", StringType()), 12 | StructField("lastModified", TimestampType()), 13 | StructField("releaseDate", DateType()), 14 | StructField("runtime", IntegerType()), 15 | StructField("studio", StringType()), 16 | StructField("tagline", StringType()), 17 | StructField("title", StringType()), 18 | StructField("trailer", StringType()) 19 | ]) 20 | person_schema: StructType = StructType([ 21 | StructField("_id", StringType(), nullable=False), 22 | StructField("_key", StringType(), nullable=False), 23 | StructField("biography", StringType()), 24 | StructField("birthday", DateType()), 25 | StructField("birthplace", StringType()), 26 | StructField("lastModified", TimestampType()), 27 | StructField("name", StringType()), 28 | StructField("profileImageUrl", StringType()) 29 | ]) 30 | edges_schema: StructType = StructType([ 31 | StructField("_key", StringType(), nullable=False), 32 | StructField("_from", StringType(), nullable=False), 33 | StructField("_to", StringType(), nullable=False), 34 | StructField("$label", StringType()), 35 | StructField("name", StringType()), 36 | StructField("type", StringType()), 37 | ]) 38 | acts_in_schema: StructType = StructType([ 39 | StructField("_id", StringType(), nullable=False), 40 | StructField("_key", StringType(), nullable=False), 41 | StructField("_from", StringType(), nullable=False), 42 | StructField("_to", StringType(), nullable=False), 43 | StructField("name", StringType()) 44 | ]) 45 | directed_schema: StructType = StructType([ 46 | StructField("_id", StringType(), nullable=False), 47 | StructField("_key", StringType(), nullable=False), 48 | StructField("_from", StringType(), nullable=False), 49 | StructField("_to", StringType(), nullable=False) 50 | ]) 51 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.mapping 2 | 3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder 4 | import com.fasterxml.jackson.core.json.JsonReadFeature 5 | import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder} 6 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 7 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoParser, ArangoParserProvider, MappingUtils} 8 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, 
JacksonParser} 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.types.DataType 11 | import org.apache.spark.unsafe.types.UTF8String 12 | 13 | abstract sealed class ArangoParserImpl( 14 | schema: DataType, 15 | options: JSONOptions, 16 | recordLiteral: Array[Byte] => UTF8String) 17 | extends JacksonParser(schema, options) with ArangoParser { 18 | override def parse(data: Array[Byte]): Iterable[InternalRow] = super.parse( 19 | data, 20 | (jsonFactory: JsonFactory, record: Array[Byte]) => jsonFactory.createParser(record), 21 | recordLiteral 22 | ) 23 | } 24 | 25 | class ArangoParserProviderImpl extends ArangoParserProvider { 26 | override def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParserImpl = contentType match { 27 | case ContentType.JSON => new JsonArangoParser(schema, conf) 28 | case ContentType.VPACK => new VPackArangoParser(schema, conf) 29 | case _ => throw new IllegalArgumentException 30 | } 31 | } 32 | 33 | class JsonArangoParser(schema: DataType, conf: ArangoDBConf) 34 | extends ArangoParserImpl( 35 | schema, 36 | createOptions(new JsonFactoryBuilder() 37 | .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, true) 38 | .build(), conf), 39 | (bytes: Array[Byte]) => UTF8String.fromBytes(bytes) 40 | ) 41 | 42 | class VPackArangoParser(schema: DataType, conf: ArangoDBConf) 43 | extends ArangoParserImpl( 44 | schema, 45 | createOptions(new VPackFactoryBuilder().build(), conf), 46 | (bytes: Array[Byte]) => UTF8String.fromString(MappingUtils.vpackToJson(bytes)) 47 | ) 48 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.mapping 2 | 3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder 4 | import com.fasterxml.jackson.core.json.JsonReadFeature 5 | import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder} 6 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 7 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoParser, ArangoParserProvider, MappingUtils} 8 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonParser} 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.types.DataType 11 | import org.apache.spark.unsafe.types.UTF8String 12 | 13 | abstract sealed class ArangoParserImpl( 14 | schema: DataType, 15 | options: JSONOptions, 16 | recordLiteral: Array[Byte] => UTF8String) 17 | extends JacksonParser(schema, options) with ArangoParser { 18 | override def parse(data: Array[Byte]): Iterable[InternalRow] = super.parse( 19 | data, 20 | (jsonFactory: JsonFactory, record: Array[Byte]) => jsonFactory.createParser(record), 21 | recordLiteral 22 | ) 23 | } 24 | 25 | class ArangoParserProviderImpl extends ArangoParserProvider { 26 | override def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParserImpl = contentType match { 27 | case ContentType.JSON => new JsonArangoParser(schema, conf) 28 | case ContentType.VPACK => new VPackArangoParser(schema, conf) 29 | case _ => throw new IllegalArgumentException 30 | } 31 | } 32 | 33 | class JsonArangoParser(schema: DataType, conf: ArangoDBConf) 34 | extends ArangoParserImpl( 35 | schema, 36 | createOptions(new JsonFactoryBuilder() 37 | 
.configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, true) 38 | .build(), conf), 39 | (bytes: Array[Byte]) => UTF8String.fromBytes(bytes) 40 | ) 41 | 42 | class VPackArangoParser(schema: DataType, conf: ArangoDBConf) 43 | extends ArangoParserImpl( 44 | schema, 45 | createOptions(new VPackFactoryBuilder().build(), conf), 46 | (bytes: Array[Byte]) => UTF8String.fromString(MappingUtils.vpackToJson(bytes)) 47 | ) 48 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/OrFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.{And, Or, EqualTo} 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class OrFilterTest { 9 | private val schema = StructType(Array( 10 | StructField("integer", IntegerType), 11 | StructField("string", StringType), 12 | StructField("binary", BinaryType) 13 | )) 14 | 15 | // FilterSupport.FULL 16 | private val f1 = EqualTo("string", "str") 17 | private val pushF1 = PushableFilter(f1, schema) 18 | 19 | // FilterSupport.NONE 20 | private val f2 = EqualTo("binary", Array(Byte.MaxValue)) 21 | 22 | // FilterSupport.PARTIAL 23 | private val f3 = And(f1, f2) 24 | 25 | @Test 26 | def orFilterSupportFullFull(): Unit = { 27 | val orFilter = PushableFilter(Or(f1, f1), schema) 28 | assertThat(orFilter.support()).isEqualTo(FilterSupport.FULL) 29 | assertThat(orFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} OR ${pushF1.aql("d")})""") 30 | } 31 | 32 | @Test 33 | def orFilterSupportFullNone(): Unit = { 34 | val orFilter = PushableFilter(Or(f1, f2), schema) 35 | assertThat(orFilter.support()).isEqualTo(FilterSupport.NONE) 36 | } 37 | 38 | @Test 39 | def orFilterSupportFullPartial(): Unit = { 40 | val orFilter = PushableFilter(Or(f1, f3), schema) 41 | assertThat(orFilter.support()).isEqualTo(FilterSupport.PARTIAL) 42 | assertThat(orFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} OR (${pushF1.aql("d")}))""") 43 | } 44 | 45 | @Test 46 | def orFilterSupportPartialPartial(): Unit = { 47 | val orFilter = PushableFilter(Or(f3, f3), schema) 48 | assertThat(orFilter.support()).isEqualTo(FilterSupport.PARTIAL) 49 | assertThat(orFilter.aql("d")).isEqualTo(s"""((${pushF1.aql("d")}) OR (${pushF1.aql("d")}))""") 50 | } 51 | 52 | @Test 53 | def orFilterSupportPartialNone(): Unit = { 54 | val orFilter = PushableFilter(Or(f3, f2), schema) 55 | assertThat(orFilter.support()).isEqualTo(FilterSupport.NONE) 56 | } 57 | 58 | @Test 59 | def orFilterSupportNoneNone(): Unit = { 60 | val orFilter = PushableFilter(Or(f2, f2), schema) 61 | assertThat(orFilter.support()).isEqualTo(FilterSupport.NONE) 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /arangodb-spark-commons/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | arangodb-spark-datasource 7 | com.arangodb 8 | 1.9.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version} 13 | 14 | arangodb-spark-commons 15 | ArangoDB Spark Datasource Commons 16 | https://github.com/arangodb/arangodb-spark-datasource 17 | 18 | 19 | 20 | Michele Rastelli 21 | https://github.com/rashtao 22 | 23 | 24 | 25 | 26 | https://github.com/arangodb/arangodb-spark-datasource 27 | 
28 | 29 | 30 | false 31 | ../integration-tests/target/site/jacoco-aggregate/jacoco.xml 32 | false 33 | 34 | 35 | 36 | 37 | 38 | org.jacoco 39 | jacoco-maven-plugin 40 | 41 | 42 | report-aggregate 43 | verify 44 | 45 | report-aggregate 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | org.apache.maven.plugins 57 | maven-surefire-report-plugin 58 | 3.3.0 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.reader 2 | 3 | import com.arangodb.entity.CursorWarning 4 | import org.apache.spark.internal.Logging 5 | import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider 6 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} 7 | import org.apache.spark.sql.catalyst.InternalRow 8 | import org.apache.spark.sql.catalyst.util.FailureSafeParser 9 | import org.apache.spark.sql.connector.read.PartitionReader 10 | import org.apache.spark.sql.types._ 11 | 12 | import scala.annotation.tailrec 13 | import scala.collection.JavaConverters.iterableAsScalaIterableConverter 14 | 15 | 16 | class ArangoQueryReader(schema: StructType, options: ArangoDBConf) extends PartitionReader[InternalRow] with Logging { 17 | 18 | private val actualSchema = StructType(schema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) 19 | private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options) 20 | private val safeParser = new FailureSafeParser[Array[Byte]]( 21 | parser.parse, 22 | options.readOptions.parseMode, 23 | schema, 24 | options.readOptions.columnNameOfCorruptRecord) 25 | private val client = ArangoClient(options) 26 | private val iterator = client.readQuery() 27 | 28 | var rowIterator: Iterator[InternalRow] = _ 29 | 30 | // warnings of non stream AQL cursors are all returned along with the first batch 31 | if (!options.readOptions.stream) logWarns() 32 | 33 | @tailrec 34 | final override def next: Boolean = 35 | if (iterator.hasNext) { 36 | val current = iterator.next() 37 | rowIterator = safeParser.parse(current.get) 38 | if (rowIterator.hasNext) { 39 | true 40 | } else { 41 | next 42 | } 43 | } else { 44 | // FIXME: https://arangodb.atlassian.net/browse/BTS-671 45 | // stream AQL cursors' warnings are only returned along with the final batch 46 | if (options.readOptions.stream) logWarns() 47 | false 48 | } 49 | 50 | override def get: InternalRow = rowIterator.next() 51 | 52 | override def close(): Unit = { 53 | iterator.close() 54 | client.shutdown() 55 | } 56 | 57 | private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) => 58 | logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}") 59 | )) 60 | 61 | } 62 | 63 | 64 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/AndFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons.filter 2 | 3 | import org.apache.spark.sql.sources.{EqualTo, And} 4 | import org.apache.spark.sql.types._ 5 | import org.assertj.core.api.Assertions.assertThat 6 | import org.junit.jupiter.api.Test 7 | 8 | class AndFilterTest { 9 | private val 
schema = StructType(Array( 10 | StructField("integer", IntegerType), 11 | StructField("string", StringType), 12 | StructField("binary", BinaryType) 13 | )) 14 | 15 | // FilterSupport.FULL 16 | private val f1 = EqualTo("string", "str") 17 | private val pushF1 = PushableFilter(f1, schema) 18 | 19 | // FilterSupport.NONE 20 | private val f2 = EqualTo("binary", Array(Byte.MaxValue)) 21 | 22 | // FilterSupport.PARTIAL 23 | private val f3 = And(f1, f2) 24 | 25 | @Test 26 | def andFilterSupportFullFull(): Unit = { 27 | val andFilter = PushableFilter(And(f1, f1), schema) 28 | assertThat(andFilter.support()).isEqualTo(FilterSupport.FULL) 29 | assertThat(andFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} AND ${pushF1.aql("d")})""") 30 | } 31 | 32 | @Test 33 | def andFilterSupportFullNone(): Unit = { 34 | val andFilter = PushableFilter(And(f1, f2), schema) 35 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL) 36 | assertThat(andFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")})""") 37 | } 38 | 39 | @Test 40 | def andFilterSupportFullPartial(): Unit = { 41 | val andFilter = PushableFilter(And(f1, f3), schema) 42 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL) 43 | assertThat(andFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} AND (${pushF1.aql("d")}))""") 44 | } 45 | 46 | @Test 47 | def andFilterSupportPartialPartial(): Unit = { 48 | val andFilter = PushableFilter(And(f3, f3), schema) 49 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL) 50 | assertThat(andFilter.aql("d")).isEqualTo(s"""((${pushF1.aql("d")}) AND (${pushF1.aql("d")}))""") 51 | } 52 | 53 | @Test 54 | def andFilterSupportPartialNone(): Unit = { 55 | val andFilter = PushableFilter(And(f3, f2), schema) 56 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL) 57 | assertThat(andFilter.aql("d")).isEqualTo(s"""((${pushF1.aql("d")}))""") 58 | } 59 | 60 | @Test 61 | def andFilterSupportNoneNone(): Unit = { 62 | val andFilter = PushableFilter(And(f2, f2), schema) 63 | assertThat(andFilter.support()).isEqualTo(FilterSupport.NONE) 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/filter/package.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.commons 2 | 3 | import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema 4 | import org.apache.spark.sql.types._ 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | package object filter { 9 | 10 | private[filter] def splitAttributeNameParts(attribute: String): Array[String] = { 11 | val parts = new ArrayBuffer[String]() 12 | var sb = new StringBuilder() 13 | var inEscapedBlock = false 14 | for (c <- attribute.toCharArray) { 15 | if (c == '`') inEscapedBlock = !inEscapedBlock 16 | if (c == '.' 
&& !inEscapedBlock) { 17 | parts += sb.toString() 18 | sb = new StringBuilder() 19 | } else if (c != '`') { 20 | sb.append(c) 21 | } 22 | } 23 | parts += sb.toString() 24 | parts.toArray 25 | } 26 | 27 | private[filter] def isTypeAqlCompatible(t: AbstractDataType): Boolean = t match { 28 | // atomic types 29 | case _: 30 | StringType 31 | | BooleanType 32 | | FloatType 33 | | DoubleType 34 | | IntegerType 35 | | LongType 36 | | ShortType 37 | | ByteType 38 | => true 39 | case _: 40 | DateType 41 | | TimestampType 42 | => false 43 | // complex types 44 | case _: NullType => true 45 | case at: ArrayType => isTypeAqlCompatible(at.elementType) 46 | case st: StructType => st.forall(f => isTypeAqlCompatible(f.dataType)) 47 | case mt: MapType => mt.keyType == StringType && isTypeAqlCompatible(mt.valueType) 48 | case _ => false 49 | } 50 | 51 | private[filter] def getValue(t: AbstractDataType, v: Any): String = t match { 52 | case NullType => "null" 53 | case _: TimestampType | DateType | StringType => s""""$v"""" 54 | case _: BooleanType | FloatType | DoubleType | IntegerType | LongType | ShortType | ByteType => v.toString 55 | case at: ArrayType => s"""[${v.asInstanceOf[Traversable[Any]].map(getValue(at.elementType, _)).mkString(",")}]""" 56 | case _: StructType => 57 | val row = v.asInstanceOf[GenericRowWithSchema] 58 | val parts = row.values.zip(row.schema).map(sf => 59 | s""""${sf._2.name}":${getValue(sf._2.dataType, sf._1)}""" 60 | ) 61 | s"{${parts.mkString(",")}}" 62 | case mt: MapType => 63 | v.asInstanceOf[Map[String, Any]].map(it => { 64 | s"""${getValue(mt.keyType, it._1)}:${getValue(mt.valueType, it._2)}""" 65 | }).mkString("{", ",", "}") 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/JacksonTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb 2 | 3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} 4 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGeneratorProvider, ArangoParserProvider, MappingUtils} 5 | import org.apache.spark.sql.types._ 6 | import org.assertj.core.api.Assertions.assertThat 7 | import org.junit.jupiter.api.Test 8 | 9 | import java.io.ByteArrayOutputStream 10 | import java.nio.charset.StandardCharsets 11 | 12 | /** 13 | * @author Michele Rastelli 14 | */ 15 | class JacksonTest { 16 | private val jsonString = 17 | """ 18 | |{ 19 | | "birthday": "1964-01-02", 20 | | "gender": "female", 21 | | "likes": [ 22 | | "swimming" 23 | | ], 24 | | "name": { 25 | | "first": "Roseline", 26 | | "last": "Jucean" 27 | | }, 28 | | "nullString": null, 29 | | "nullField": null, 30 | | "mapField": { 31 | | "foo": 1, 32 | | "bar": 2 33 | | } 34 | |} 35 | |""".stripMargin.replaceAll("\\s", "") 36 | 37 | private val jsonBytes = jsonString.getBytes(StandardCharsets.UTF_8) 38 | private val vpackBytes = MappingUtils.jsonToVPack(jsonString) 39 | 40 | private val schema: StructType = new StructType( 41 | Array( 42 | StructField("birthday", DateType), 43 | StructField("gender", StringType), 44 | StructField("likes", ArrayType(StringType)), 45 | StructField("name", StructType( 46 | Array( 47 | StructField("first", StringType), 48 | StructField("last", StringType) 49 | ) 50 | )), 51 | StructField("nullString", StringType, nullable = true), 52 | StructField("nullField", NullType), 53 | StructField("mapField", MapType(StringType, IntegerType)) 54 | ) 55 | ) 56 
| 57 | @Test 58 | def jsonRoudTrip(): Unit = { 59 | roundTrip(ContentType.JSON, jsonBytes) 60 | } 61 | 62 | @Test 63 | def vpackRoudTrip(): Unit = { 64 | roundTrip(ContentType.VPACK, vpackBytes) 65 | } 66 | 67 | private def roundTrip(contentType: ContentType, data: Array[Byte]): Unit = { 68 | val parser = ArangoParserProvider().of(contentType, schema, ArangoDBConf()) 69 | val parsed = parser.parse(data) 70 | val output = new ByteArrayOutputStream() 71 | val generator = ArangoGeneratorProvider().of(contentType, schema, output, ArangoDBConf()) 72 | generator.write(parsed.head) 73 | generator.close() 74 | assertThat(output.toByteArray).isEqualTo(data) 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /demo/src/main/scala/Demo.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.sql.SparkSession 2 | 3 | object Demo { 4 | val importPath: String = System.getProperty("importPath", "/demo/docker/import") 5 | val password: String = System.getProperty("password", "test") 6 | val endpoints: String = System.getProperty("endpoints", "172.28.0.1:8529,172.28.0.1:8539,172.28.0.1:8549") 7 | val sslEnabled: String = System.getProperty("ssl.enabled", "true") 8 | val sslCertValue: String = System.getProperty("ssl.cert.value", "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlekNDQW1PZ0F3SUJBZ0lFZURDelh6QU5CZ2txaGtpRzl3MEJBUXNGQURCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdIaGNOTWpBeE1UQXhNVGcxTVRFNVdoY05NekF4TURNd01UZzFNVEU1V2pCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdnZ0VpTUEwR0NTcUdTSWIzRFFFQkFRVUFBNElCRHdBd2dnRUtBb0lCQVFDMVdpRG5kNCt1Q21NRzUzOVoKTlpCOE53STBSWkYzc1VTUUdQeDNsa3FhRlRaVkV6TVpMNzZIWXZkYzlRZzdkaWZ5S3lRMDlSTFNwTUFMWDlldQpTc2VEN2JaR25mUUg1MkJuS2NUMDllUTN3aDdhVlE1c04yb215Z2RITEM3WDl1c250eEFmdjdOem12ZG9nTlhvCkpReVkvaFNaZmY3UklxV0g4Tm5BVUtranFPZTZCZjVMRGJ4SEtFU21yRkJ4T0NPbmhjcHZaV2V0d3BpUmRKVlAKd1VuNVA4MkNBWnpmaUJmbUJabkI3RDBsKy82Q3Y0ak11SDI2dUFJY2l4blZla0JRemwxUmd3Y3p1aVpmMk1HTwo2NHZETU1KSldFOUNsWkYxdVF1UXJ3WEY2cXdodVAxSG5raWk2d05iVHRQV2xHU2txZXV0cjAwNCtIemJmOEtuClJZNFBBZ01CQUFHaklUQWZNQjBHQTFVZERnUVdCQlRCcnY5QXd5bnQzQzVJYmFDTnlPVzV2NEROa1RBTkJna3EKaGtpRzl3MEJBUXNGQUFPQ0FRRUFJbTlyUHZEa1lwbXpwU0loUjNWWEc5WTcxZ3hSRHJxa0VlTHNNb0V5cUdudwovengxYkRDTmVHZzJQbmNMbFc2elRJaXBFQm9vaXhJRTlVN0t4SGdaeEJ5MEV0NkVFV3ZJVW1ucjZGNEYrZGJUCkQwNTBHSGxjWjdlT2VxWVRQWWVRQzUwMkcxRm80dGROaTRsRFA5TDlYWnBmN1ExUWltUkgycWFMUzAzWkZaYTIKdFk3YWgvUlFxWkw4RGt4eDgvemMyNXNnVEhWcHhvSzg1M2dsQlZCcy9FTk1peUdKV21BWFFheWV3WTNFUHQvOQp3R3dWNEttVTNkUERsZVFlWFNVR1BVSVNlUXhGankrakN3MjFwWXZpV1ZKVE5CQTlsNW55M0doRW1jbk9UL2dRCkhDdlZSTHlHTE1iYU1aNEpyUHdiK2FBdEJncmdlaUs0eGVTTU12cmJodz09Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K") 9 | 10 | val spark: SparkSession = SparkSession.builder 11 | .appName("arangodb-demo") 12 | .master("local[*, 3]") 13 | .getOrCreate 14 | 15 | val options: Map[String, String] = Map( 16 | "password" -> password, 17 | "endpoints" -> endpoints, 18 | "ssl.enabled" -> sslEnabled, 19 | "ssl.cert.value" -> sslCertValue, 20 | "ssl.verifyHost" -> "false" 21 | ) 22 | 23 | def main(args: Array[String]): Unit = { 24 | WriteDemo.writeDemo() 25 | ReadDemo.readDemo() 26 | ReadWriteDemo.readWriteDemo() 27 | spark.stop 28 | 
} 29 | 30 | } 31 | -------------------------------------------------------------------------------- /demo/python-demo/read_demo.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import pyspark.sql 4 | from pyspark.sql import SparkSession 5 | from pyspark.sql.types import StructType, StructField, StringType 6 | 7 | from schemas import movie_schema 8 | from utils import combine_dicts 9 | 10 | 11 | def read_demo(spark: SparkSession, base_opts: Dict[str, str]): 12 | movies_df = read_collection(spark, "movies", base_opts, movie_schema) 13 | 14 | print("Read table: history movies or documentaries about 'World War' released from 2000-01-01") 15 | # We can get to what we want in 2 different ways: 16 | # First, the PySpark dataframe way... 17 | movies_df \ 18 | .select("title", "releaseDate", "genre", "description") \ 19 | .filter("genre IN ('History', 'Documentary') AND description LIKE '%World War%' AND releaseDate > '2000'") \ 20 | .show() 21 | 22 | # Second, in the Pandas on Spark way... 23 | movies_pd_df = movies_df.to_pandas_on_spark() 24 | subset = movies_pd_df[["title", "releaseDate", "genre", "description"]] 25 | recent_ww_movies = subset[subset["genre"].isin(["History", "Documentary"])\ 26 | & (subset["releaseDate"] >= '2000')\ 27 | & subset["description"].str.contains("World War")] 28 | print(recent_ww_movies) 29 | 30 | print("Read query: actors of movies directed by Clint Eastwood with related movie title and interpreted role") 31 | read_aql_query( 32 | spark, 33 | """WITH movies, persons 34 | FOR v, e, p IN 2 ANY "persons/1062" OUTBOUND directed, INBOUND actedIn 35 | RETURN {movie: p.vertices[1].title, name: v.name, role: p.edges[1].name} 36 | """, 37 | base_opts, 38 | StructType([ 39 | StructField("movie", StringType()), 40 | StructField("name", StringType()), 41 | StructField("role", StringType()) 42 | ]) 43 | ).show(20, 200) 44 | 45 | 46 | def read_collection(spark: SparkSession, collection_name: str, base_opts: Dict[str, str], schema: StructType) -> pyspark.sql.DataFrame: 47 | arangodb_datasource_options = combine_dicts([base_opts, {"table": collection_name}]) 48 | 49 | return spark.read \ 50 | .format("com.arangodb.spark") \ 51 | .options(**arangodb_datasource_options) \ 52 | .schema(schema) \ 53 | .load() 54 | 55 | 56 | def read_aql_query(spark: SparkSession, query: str, base_opts: Dict[str, str], schema: StructType) -> pyspark.sql.DataFrame: 57 | arangodb_datasource_options = combine_dicts([base_opts, {"query": query}]) 58 | 59 | return spark.read \ 60 | .format("com.arangodb.spark") \ 61 | .options(**arangodb_datasource_options) \ 62 | .schema(schema) \ 63 | .load() 64 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadSmartEdgeCollectionTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import com.arangodb.entity.EdgeDefinition 4 | import com.arangodb.model.GraphCreateOptions 5 | import org.apache.spark.sql.DataFrame 6 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 7 | import org.assertj.core.api.Assertions.assertThat 8 | import org.junit.jupiter.api.Assumptions.assumeTrue 9 | import org.junit.jupiter.api.BeforeAll 10 | import org.junit.jupiter.params.ParameterizedTest 11 | import org.junit.jupiter.params.provider.MethodSource 12 | 13 | import java.util 14 | import 
scala.collection.JavaConverters.asJavaIterableConverter 15 | import scala.collection.immutable 16 | 17 | class ReadSmartEdgeCollectionTest extends BaseSparkTest { 18 | 19 | @ParameterizedTest 20 | @MethodSource(Array("provideProtocolAndContentType")) 21 | def readSmartEdgeCollection(protocol: String, contentType: String): Unit = { 22 | val df: DataFrame = spark.read 23 | .format(BaseSparkTest.arangoDatasource) 24 | .options(options + ( 25 | ArangoDBConf.COLLECTION -> ReadSmartEdgeCollectionTest.name, 26 | ArangoDBConf.PROTOCOL -> protocol, 27 | ArangoDBConf.CONTENT_TYPE -> contentType 28 | )) 29 | .load() 30 | 31 | 32 | import spark.implicits._ 33 | val read = df 34 | .as[Edge] 35 | .collect() 36 | 37 | assertThat(read.map(_.name)).containsAll(ReadSmartEdgeCollectionTest.data.map(d => d("name")).asJava) 38 | } 39 | 40 | } 41 | 42 | object ReadSmartEdgeCollectionTest { 43 | val name = "smartEdgeCol" 44 | val from = s"from-$name" 45 | val to = s"from-$name" 46 | 47 | val data: immutable.Seq[Map[String, String]] = (1 to 10) 48 | .map(x => Map( 49 | "name" -> s"name-$x", 50 | "_from" -> s"$from/a:$x", 51 | "_to" -> s"$to/b:$x" 52 | )) 53 | 54 | @BeforeAll 55 | def init(): Unit = { 56 | assumeTrue(BaseSparkTest.isCluster && BaseSparkTest.isEnterprise) 57 | 58 | if (BaseSparkTest.db.graph(name).exists()) { 59 | BaseSparkTest.db.graph(name).drop(true) 60 | } 61 | 62 | val ed = new EdgeDefinition() 63 | .collection(name) 64 | .from(from) 65 | .to(to) 66 | val opts = new GraphCreateOptions() 67 | .numberOfShards(2) 68 | .isSmart(true) 69 | .smartGraphAttribute("name") 70 | BaseSparkTest.db.createGraph(name, List(ed).asJava.asInstanceOf[util.Collection[EdgeDefinition]], opts) 71 | BaseSparkTest.db.collection(name).insertDocuments(data.asJava.asInstanceOf[util.Collection[Any]]) 72 | } 73 | } 74 | 75 | case class Edge( 76 | name: String, 77 | _from: String, 78 | _to: String 79 | ) 80 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | // scalastyle:off 19 | 20 | package org.apache.spark.sql.arangodb.datasource.mapping.json 21 | 22 | import com.fasterxml.jackson.core.{JsonParser, JsonToken} 23 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult 24 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} 25 | import org.apache.spark.sql.errors.QueryErrorsBase 26 | import org.apache.spark.sql.types._ 27 | 28 | object JacksonUtils extends QueryErrorsBase { 29 | /** 30 | * Advance the parser until a null or a specific token is found 31 | */ 32 | def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = { 33 | parser.nextToken() match { 34 | case null => false 35 | case x => x != stopOn 36 | } 37 | } 38 | 39 | def verifyType(name: String, dataType: DataType): TypeCheckResult = { 40 | dataType match { 41 | case NullType | _: AtomicType | CalendarIntervalType => TypeCheckSuccess 42 | 43 | case st: StructType => 44 | st.foldLeft(TypeCheckSuccess: TypeCheckResult) { case (currResult, field) => 45 | if (currResult.isFailure) currResult else verifyType(field.name, field.dataType) 46 | } 47 | 48 | case at: ArrayType => verifyType(name, at.elementType) 49 | 50 | // For MapType, its keys are treated as a string (i.e. calling `toString`) basically when 51 | // generating JSON, so we only care if the values are valid for JSON. 52 | case mt: MapType => verifyType(name, mt.valueType) 53 | 54 | case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) 55 | 56 | case _ => 57 | DataTypeMismatch( 58 | errorSubClass = "CANNOT_CONVERT_TO_JSON", 59 | messageParameters = Map( 60 | "name" -> toSQLId(name), 61 | "type" -> toSQLType(dataType))) 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | // scalastyle:off 19 | 20 | package org.apache.spark.sql.arangodb.datasource.mapping.json 21 | 22 | import com.fasterxml.jackson.core.{JsonParser, JsonToken} 23 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult 24 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} 25 | import org.apache.spark.sql.errors.QueryErrorsBase 26 | import org.apache.spark.sql.types._ 27 | 28 | object JacksonUtils extends QueryErrorsBase { 29 | /** 30 | * Advance the parser until a null or a specific token is found 31 | */ 32 | def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = { 33 | parser.nextToken() match { 34 | case null => false 35 | case x => x != stopOn 36 | } 37 | } 38 | 39 | def verifyType(name: String, dataType: DataType): TypeCheckResult = { 40 | dataType match { 41 | case NullType | _: AtomicType | CalendarIntervalType => TypeCheckSuccess 42 | 43 | case st: StructType => 44 | st.foldLeft(TypeCheckSuccess: TypeCheckResult) { case (currResult, field) => 45 | if (currResult.isFailure) currResult else verifyType(field.name, field.dataType) 46 | } 47 | 48 | case at: ArrayType => verifyType(name, at.elementType) 49 | 50 | // For MapType, its keys are treated as a string (i.e. calling `toString`) basically when 51 | // generating JSON, so we only care if the values are valid for JSON. 52 | case mt: MapType => verifyType(name, mt.valueType) 53 | 54 | case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) 55 | 56 | case _ => 57 | DataTypeMismatch( 58 | errorSubClass = "CANNOT_CONVERT_TO_JSON", 59 | messageParameters = Map( 60 | "name" -> toSQLId(name), 61 | "type" -> toSQLType(dataType))) 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | arangodb-spark-datasource 7 | com.arangodb 8 | 1.9.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | arangodb-spark-datasource-3.4_${scala.compat.version} 13 | 14 | arangodb-spark-datasource-3.4 15 | ArangoDB Datasource for Apache Spark 3.4 16 | https://github.com/arangodb/arangodb-spark-datasource 17 | 18 | 19 | 20 | Michele Rastelli 21 | https://github.com/rashtao 22 | 23 | 24 | 25 | 26 | https://github.com/arangodb/arangodb-spark-datasource 27 | 28 | 29 | 30 | false 31 | ../integration-tests/target/site/jacoco-aggregate/jacoco.xml 32 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/* 33 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/* 34 | false 35 | 36 | 37 | 38 | 39 | com.arangodb 40 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version} 41 | ${project.version} 42 | compile 43 | 44 | 45 | 46 | 47 | 48 | 49 | maven-assembly-plugin 50 | 51 | 52 | jar-with-dependencies 53 | 54 | 55 | 56 | 57 | package 58 | 59 | single 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | arangodb-spark-datasource 7 | com.arangodb 8 | 1.9.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | arangodb-spark-datasource-3.5_${scala.compat.version} 13 | 14 | arangodb-spark-datasource-3.5 15 | ArangoDB Datasource for Apache Spark 3.5 16 | https://github.com/arangodb/arangodb-spark-datasource 17 | 18 | 19 | 20 | Michele Rastelli 21 | https://github.com/rashtao 22 | 23 | 24 | 25 
| 26 | https://github.com/arangodb/arangodb-spark-datasource 27 | 28 | 29 | 30 | false 31 | ../integration-tests/target/site/jacoco-aggregate/jacoco.xml 32 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/* 33 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/* 34 | false 35 | 36 | 37 | 38 | 39 | com.arangodb 40 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version} 41 | ${project.version} 42 | compile 43 | 44 | 45 | 46 | 47 | 48 | 49 | maven-assembly-plugin 50 | 51 | 52 | jar-with-dependencies 53 | 54 | 55 | 56 | 57 | package 58 | 59 | single 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /demo/src/main/scala/ReadDemo.scala: -------------------------------------------------------------------------------- 1 | import Schemas.movieSchema 2 | import org.apache.spark.sql.DataFrame 3 | import org.apache.spark.sql.types._ 4 | 5 | object ReadDemo { 6 | 7 | def readDemo(): Unit = { 8 | println("-----------------") 9 | println("--- READ DEMO ---") 10 | println("-----------------") 11 | 12 | val moviesDF = readTable("movies", movieSchema) 13 | 14 | println("Read table: history movies or documentaries about 'World War' released from 2000-01-01") 15 | moviesDF 16 | .select("title", "releaseDate", "genre", "description") 17 | .filter("genre IN ('History', 'Documentary') AND description LIKE '%World War%' AND releaseDate > '2000'") 18 | .show(20, 200) 19 | /* 20 | Filters and projection pushdowns are applied in this case. 21 | 22 | In the console an info message log like the following will be printed: 23 | > INFO ArangoScanBuilder:57 - Filters fully applied in AQL: 24 | > IsNotNull(description) 25 | > IsNotNull(releaseDate) 26 | > In(genre, [History,Documentary]) 27 | > StringContains(description,World War) 28 | > GreaterThan(releaseDate,2000-01-01) 29 | 30 | Also the generated AQL query will be printed with log level debug: 31 | > DEBUG ArangoClient:61 - Executing AQL query: 32 | > FOR d IN @@col FILTER `d`.`description` != null AND `d`.`releaseDate` != null AND LENGTH(["History","Documentary"][* FILTER `d`.`genre` == CURRENT]) > 0 AND CONTAINS(`d`.`description`, "World War") AND DATE_TIMESTAMP(`d`.`releaseDate`) > DATE_TIMESTAMP("2000-01-01") RETURN {`description`:`d`.`description`,`genre`:`d`.`genre`,`releaseDate`:`d`.`releaseDate`,`title`:`d`.`title`} 33 | > with params: Map(@col -> movies) 34 | */ 35 | 36 | println("Read query: actors of movies directed by Clint Eastwood with related movie title and interpreted role") 37 | readQuery( 38 | """WITH movies, persons 39 | |FOR v, e, p IN 2 ANY "persons/1062" OUTBOUND directed, INBOUND actedIn 40 | | RETURN {movie: p.vertices[1].title, name: v.name, role: p.edges[1].name} 41 | |""".stripMargin, 42 | schema = StructType(Array( 43 | StructField("movie", StringType), 44 | StructField("name", StringType), 45 | StructField("role", StringType) 46 | )) 47 | ).show(20, 200) 48 | } 49 | 50 | def readTable(tableName: String, schema: StructType): DataFrame = { 51 | Demo.spark.read 52 | .format("com.arangodb.spark") 53 | .options(Demo.options + ("table" -> tableName)) 54 | .schema(schema) 55 | .load 56 | } 57 | 58 | def readQuery(query: String, schema: StructType): DataFrame = { 59 | Demo.spark.read 60 | .format("com.arangodb.spark") 61 | .options(Demo.options + ("query" -> query)) 62 | .schema(schema) 63 | .load 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- 
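The comments in ReadDemo above document pushdown through driver log messages. A complementary check, sketched here against the same demo setup (Demo.spark, Demo.options and Schemas.movieSchema as defined elsewhere in this repository), is to print the query plans: explain(true) shows which predicates Spark still evaluates itself after the datasource has taken its share, and, depending on the Spark version, the physical scan node also lists the pushed filters.

import Schemas.movieSchema

object ExplainPushdownSketch {

  def main(args: Array[String]): Unit = {
    // Same read as ReadDemo.readTable("movies", movieSchema).
    val moviesDF = Demo.spark.read
      .format("com.arangodb.spark")
      .options(Demo.options + ("table" -> "movies"))
      .schema(movieSchema)
      .load()

    // Prints the parsed, analyzed, optimized and physical plans.
    moviesDF
      .filter("genre IN ('History', 'Documentary') AND releaseDate > '2000'")
      .select("title", "releaseDate", "genre")
      .explain(true)

    Demo.spark.stop()
  }
}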
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.reader 2 | 3 | import com.arangodb.entity.CursorWarning 4 | import org.apache.spark.internal.Logging 5 | import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider 6 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx 7 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} 8 | import org.apache.spark.sql.catalyst.InternalRow 9 | import org.apache.spark.sql.catalyst.util.FailureSafeParser 10 | import org.apache.spark.sql.connector.read.PartitionReader 11 | import org.apache.spark.sql.types.StructType 12 | 13 | import scala.annotation.tailrec 14 | import scala.collection.JavaConverters.iterableAsScalaIterableConverter 15 | 16 | 17 | class ArangoCollectionPartitionReader(inputPartition: ArangoCollectionPartition, ctx: PushDownCtx, opts: ArangoDBConf) 18 | extends PartitionReader[InternalRow] with Logging { 19 | 20 | // override endpoints with partition endpoint 21 | private val options = opts.updated(ArangoDBConf.ENDPOINTS, inputPartition.endpoint) 22 | private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) 23 | private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options) 24 | private val safeParser = new FailureSafeParser[Array[Byte]]( 25 | parser.parse, 26 | options.readOptions.parseMode, 27 | ctx.requiredSchema, 28 | options.readOptions.columnNameOfCorruptRecord) 29 | private val client = ArangoClient(options) 30 | private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema) 31 | 32 | var rowIterator: Iterator[InternalRow] = _ 33 | 34 | // warnings of non stream AQL cursors are all returned along with the first batch 35 | if (!options.readOptions.stream) logWarns() 36 | 37 | @tailrec 38 | final override def next: Boolean = 39 | if (iterator.hasNext) { 40 | val current = iterator.next() 41 | rowIterator = safeParser.parse(current.get) 42 | if (rowIterator.hasNext) { 43 | true 44 | } else { 45 | next 46 | } 47 | } else { 48 | // FIXME: https://arangodb.atlassian.net/browse/BTS-671 49 | // stream AQL cursors' warnings are only returned along with the final batch 50 | if (options.readOptions.stream) logWarns() 51 | false 52 | } 53 | 54 | override def get: InternalRow = rowIterator.next() 55 | 56 | override def close(): Unit = { 57 | iterator.close() 58 | client.shutdown() 59 | } 60 | 61 | private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) => 62 | logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}") 63 | )) 64 | 65 | } 66 | -------------------------------------------------------------------------------- /ChangeLog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres 4 | to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 
5 | 6 | ## [Unreleased] 7 | 8 | ## [1.8.0] - 2024-09-23 9 | 10 | - updated Java Driver to version `7.9.0` 11 | - added support to load SSL Trust Store from file (DE-502, #61) 12 | - dropped support for Spark 3.2 13 | 14 | ## [1.7.0] - 2024-07-04 15 | 16 | - added support for Spark 3.5 (#58) 17 | - added support to protocol `http2` (DE-596, #55) 18 | - added configuration `ssl.verifyHost` to disable TLS hostname verification (DE-790, #54) 19 | - updated `arangodb-java-driver` to version `7.7.1` 20 | - dropped support for Spark 3.1 21 | 22 | ## [1.6.0] - 2024-03-20 23 | 24 | - support to query `ttl` (#50) 25 | - updated Java Driver to version `7.5.1` 26 | - dropped support for Spark 2.4 27 | 28 | ## [1.5.1] - 2023-09-25 29 | 30 | - fixed reading from smart edge collections (#47) 31 | 32 | ## [1.5.0] - 2023-05-31 33 | 34 | - support for Spark 3.4 (#45) 35 | - support for Spark 3.3 (#44) 36 | - updated Java Driver to version `7.0` (#35) 37 | 38 | ## [1.4.3] - 2023-03-28 39 | 40 | - added previous attempts exceptions in `ArangoDBDataWriterException` 41 | 42 | ## [1.4.2] - 2023-03-16 43 | 44 | - added debug header `x-arango-spark-request-id` 45 | 46 | ## [1.4.1] - 2022-12-15 47 | 48 | - fixed filters pushdown for read mode `query` (#37) 49 | 50 | ## [1.4.0] - 2022-05-24 51 | 52 | - support for Spark 3.2 (#31) 53 | - support for `null` at root level of a JSON array in Spark 3.1 mapping (SPARK-36379) 54 | - updated dependencies 55 | - flush write buffer on byte threshold `byteBatchSize` (#30) 56 | - remove `null` `_key` field during serialization (#29) 57 | 58 | ## [1.3.0] - 2022-04-27 59 | 60 | - added `ignoreNullFields` config param (#28) 61 | 62 | ## [1.2.0] - 2022-03-18 63 | 64 | - use `overwriteMode=ignore` if save mode is other than `Append` (#26) 65 | - require non-nullable string fields `_from` and `_to` to write to edge collections (#25) 66 | - configurable backoff retry delay for write requests (`retry.minDelay` and `retry.maxDelay`), disabled by default (#24) 67 | - retry only if schema has non-nullable field `_key` (#23) 68 | - retry on connection exceptions 69 | - added `retry.maxAttempts` config param (#20) 70 | - increased default timeout to 5 minutes 71 | - reject writing decimal types with json content type (#18) 72 | - report records causing write errors (#17) 73 | - improved logging about connections and write tasks 74 | 75 | ## [1.1.1] - 2022-02-28 76 | 77 | - retry timeout exception in truncate requests (#16) 78 | - fixed exception serialization bug (#15) 79 | 80 | ## [1.1.0] - 2022-02-23 81 | 82 | - added driver timeout configuration option (#12) 83 | - updated dependency `com.arangodb:arangodb-java-driver:6.16.1` 84 | 85 | ## [1.0.0] - 2021-12-11 86 | 87 | - Initial Release 88 | -------------------------------------------------------------------------------- /demo/src/main/scala/WriteDemo.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull 2 | import org.apache.spark.sql.functions._ 3 | import org.apache.spark.sql.types._ 4 | import org.apache.spark.sql.{Column, DataFrame} 5 | 6 | object WriteDemo { 7 | 8 | val saveOptions: Map[String, String] = Demo.options ++ Map( 9 | "table.shards" -> "9", 10 | "confirmTruncate" -> "true", 11 | "overwriteMode" -> "replace" 12 | ) 13 | 14 | def writeDemo(): Unit = { 15 | println("------------------") 16 | println("--- WRITE DEMO ---") 17 | println("------------------") 18 | 19 | println("Reading JSON files...") 20 | val nodesDF = 
Demo.spark.read.json(Demo.importPath + "/nodes.jsonl") 21 | .withColumn("_key", new Column(AssertNotNull(col("_key").expr))) 22 | .withColumn("releaseDate", unixTsToSparkDate(col("releaseDate"))) 23 | .withColumn("birthday", unixTsToSparkDate(col("birthday"))) 24 | .withColumn("lastModified", unixTsToSparkTs(col("lastModified"))) 25 | .persist() 26 | val edgesDF = Demo.spark.read.json(Demo.importPath + "/edges.jsonl") 27 | .withColumn("_key", new Column(AssertNotNull(col("_key").expr))) 28 | .withColumn("_from", new Column(AssertNotNull(concat(lit("persons/"), col("_from")).expr))) 29 | .withColumn("_to", new Column(AssertNotNull(concat(lit("movies/"), col("_to")).expr))) 30 | .persist() 31 | 32 | val personsDF = nodesDF 33 | .select(Schemas.personSchema.fieldNames.filter(_ != "_id").map(col): _*) 34 | .where("type = 'Person'") 35 | val moviesDF = nodesDF 36 | .select(Schemas.movieSchema.fieldNames.filter(_ != "_id").map(col): _*) 37 | .where("type = 'Movie'") 38 | val directedDF = edgesDF 39 | .select(Schemas.directedSchema.fieldNames.filter(_ != "_id").map(col): _*) 40 | .where("`$label` = 'DIRECTED'") 41 | val actedInDF = edgesDF 42 | .select(Schemas.actsInSchema.fieldNames.filter(_ != "_id").map(col): _*) 43 | .where("`$label` = 'ACTS_IN'") 44 | 45 | println("Writing 'persons' collection...") 46 | saveDF(personsDF, "persons") 47 | 48 | println("Writing 'movies' collection...") 49 | saveDF(moviesDF, "movies") 50 | 51 | println("Writing 'directed' edge collection...") 52 | saveDF(directedDF, "directed", "edge") 53 | 54 | println("Writing 'actedIn' edge collection...") 55 | saveDF(actedInDF, "actedIn", "edge") 56 | } 57 | 58 | def unixTsToSparkTs(c: Column): Column = (c.cast(LongType) / 1000).cast(TimestampType) 59 | 60 | def unixTsToSparkDate(c: Column): Column = unixTsToSparkTs(c).cast(DateType) 61 | 62 | def saveDF(df: DataFrame, tableName: String, tableType: String = "document"): Unit = 63 | df 64 | .write 65 | .mode("overwrite") 66 | .format("com.arangodb.spark") 67 | .options(saveOptions ++ Map( 68 | "table" -> tableName, 69 | "table.type" -> tableType 70 | )) 71 | .save() 72 | 73 | } 74 | -------------------------------------------------------------------------------- /python-integration-tests/integration/test_composite_filter.py: -------------------------------------------------------------------------------- 1 | import arango.database 2 | import pyspark.sql 3 | import pytest 4 | from pyspark.sql import SparkSession 5 | from pyspark.sql.functions import col 6 | from pyspark.sql.types import StructType, StructField, IntegerType, StringType, BooleanType 7 | 8 | from integration import test_basespark 9 | 10 | data = [ 11 | { 12 | "integer": 1, 13 | "string": "one", 14 | "bool": True 15 | }, 16 | { 17 | "integer": 2, 18 | "string": "two", 19 | "bool": True 20 | } 21 | ] 22 | 23 | schema = StructType([ 24 | StructField("integer", IntegerType(), nullable=False), 25 | StructField("string", StringType(), nullable=False), 26 | StructField("bool", BooleanType(), nullable=False), 27 | ]) 28 | 29 | table_name = "compositeFilter" 30 | 31 | 32 | @pytest.fixture(scope="module") 33 | def composite_df(database_conn: arango.database.StandardDatabase, spark: SparkSession) -> pyspark.sql.DataFrame: 34 | df = test_basespark.create_df(database_conn, spark, table_name, data, schema) 35 | yield df 36 | test_basespark.drop_table(database_conn, table_name) 37 | 38 | 39 | def test_or_filter(spark: SparkSession, composite_df: pyspark.sql.DataFrame): 40 | field_name = "integer" 41 | value = 
data[0][field_name] 42 | res = composite_df.filter((col(field_name) == 0) | (col(field_name) == 1)).collect() 43 | 44 | assert len(res) == 1 45 | assert res[0].asDict()[field_name] == value 46 | 47 | sql_res = spark.sql(f""" 48 | SELECT * FROM {table_name} 49 | WHERE {field_name} = 0 OR {field_name} = 1 50 | """).collect() 51 | 52 | assert len(sql_res) == 1 53 | assert sql_res[0].asDict()[field_name] == value 54 | 55 | 56 | def test_not_filter(spark: SparkSession, composite_df: pyspark.sql.DataFrame): 57 | field_name = "integer" 58 | value = data[0][field_name] 59 | res = composite_df.filter(~(col(field_name) == 2)).collect() 60 | 61 | assert len(res) == 1 62 | assert res[0].asDict()[field_name] == value 63 | 64 | sql_res = spark.sql(f""" 65 | SELECT * FROM {table_name} 66 | WHERE NOT {field_name} = 2 67 | """).collect() 68 | 69 | assert len(sql_res) == 1 70 | assert sql_res[0].asDict()[field_name] == value 71 | 72 | 73 | def test_or_and_filter(spark: SparkSession, composite_df: pyspark.sql.DataFrame): 74 | field_name_1 = "integer" 75 | value_1 = data[0][field_name_1] 76 | 77 | field_name_2 = "string" 78 | value_2 = data[0][field_name_2] 79 | 80 | res = composite_df.filter((col("bool") == False) | ((col(field_name_1) == value_1) & (col(field_name_2) == value_2))).collect() 81 | 82 | assert len(res) == 1 83 | assert res[0].asDict()[field_name_1] == value_1 84 | 85 | sql_res = spark.sql(f""" 86 | SELECT * FROM {table_name} 87 | WHERE bool = false OR ({field_name_1} = {value_1} AND {field_name_2} = "{value_2}") 88 | """).collect() 89 | 90 | assert len(sql_res) == 1 91 | assert sql_res[0].asDict()[field_name_1] == value_1 92 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.reader 2 | 3 | import org.apache.spark.internal.Logging 4 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ReadMode} 5 | import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter} 6 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx 7 | import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} 8 | import org.apache.spark.sql.sources.Filter 9 | import org.apache.spark.sql.types.StructType 10 | 11 | class ArangoScanBuilder(options: ArangoDBConf, tableSchema: StructType) extends ScanBuilder 12 | with SupportsPushDownFilters 13 | with SupportsPushDownRequiredColumns 14 | with Logging { 15 | 16 | private var readSchema: StructType = _ 17 | 18 | // fully or partially applied filters 19 | private var appliedPushableFilters: Array[PushableFilter] = Array() 20 | private var appliedSparkFilters: Array[Filter] = Array() 21 | 22 | override def build(): Scan = new ArangoScan(new PushDownCtx(readSchema, appliedPushableFilters), options) 23 | 24 | override def pushFilters(filters: Array[Filter]): Array[Filter] = { 25 | options.readOptions.readMode match { 26 | case ReadMode.Collection => pushFiltersReadModeCollection(filters) 27 | case ReadMode.Query => filters 28 | } 29 | } 30 | 31 | private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = { 32 | // filters related to columnNameOfCorruptRecord are not pushed down 33 | val isCorruptRecordFilter = (f: Filter) => 
f.references.contains(options.readOptions.columnNameOfCorruptRecord) 34 | val ignoredFilters = filters.filter(isCorruptRecordFilter) 35 | val filtersBySupport = filters 36 | .filterNot(isCorruptRecordFilter) 37 | .map(f => (f, PushableFilter(f, tableSchema))) 38 | .groupBy(_._2.support()) 39 | 40 | val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array()) 41 | val partialSupp = filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array()) 42 | val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters 43 | 44 | val appliedFilters = fullSupp ++ partialSupp 45 | appliedPushableFilters = appliedFilters.map(_._2) 46 | appliedSparkFilters = appliedFilters.map(_._1) 47 | 48 | if (fullSupp.nonEmpty) { 49 | logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}") 50 | } 51 | if (partialSupp.nonEmpty) { 52 | logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}") 53 | } 54 | if (noneSupp.nonEmpty) { 55 | logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}") 56 | } 57 | 58 | partialSupp.map(_._1) ++ noneSupp 59 | } 60 | 61 | override def pushedFilters(): Array[Filter] = appliedSparkFilters 62 | 63 | override def pruneColumns(requiredSchema: StructType): Unit = { 64 | this.readSchema = requiredSchema 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/CreateCollectionTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.write 2 | 3 | import com.arangodb.ArangoCollection 4 | import org.apache.spark.sql.{Row, SaveMode} 5 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, CollectionType} 6 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest 7 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 8 | import org.assertj.core.api.Assertions.assertThat 9 | import org.junit.jupiter.api.BeforeEach 10 | import org.junit.jupiter.params.ParameterizedTest 11 | import org.junit.jupiter.params.provider.MethodSource 12 | 13 | import scala.collection.JavaConverters._ 14 | 15 | 16 | class CreateCollectionTest extends BaseSparkTest { 17 | 18 | private val collectionName = "chessPlayersCreateCollection" 19 | private val collection: ArangoCollection = db.collection(collectionName) 20 | 21 | private val rows = Seq( 22 | Row("a/1", "b/1"), 23 | Row("a/2", "b/2"), 24 | Row("a/3", "b/3"), 25 | Row("a/4", "b/4"), 26 | Row("a/5", "b/5"), 27 | Row("a/6", "b/6") 28 | ) 29 | 30 | private val df = spark.createDataFrame(rows.asJava, StructType(Array( 31 | StructField("_from", StringType, nullable = false), 32 | StructField("_to", StringType, nullable = false) 33 | ))).repartition(3) 34 | 35 | @BeforeEach 36 | def beforeEach(): Unit = { 37 | if (collection.exists()) { 38 | collection.drop() 39 | } 40 | } 41 | 42 | @ParameterizedTest 43 | @MethodSource(Array("provideProtocolAndContentType")) 44 | def saveModeAppend(protocol: String, contentType: String): Unit = { 45 | df.write 46 | .format(BaseSparkTest.arangoDatasource) 47 | .mode(SaveMode.Append) 48 | .options(options + ( 49 | ArangoDBConf.COLLECTION -> collectionName, 50 | ArangoDBConf.PROTOCOL -> protocol, 51 | ArangoDBConf.CONTENT_TYPE -> contentType, 52 | ArangoDBConf.NUMBER_OF_SHARDS -> "5", 53 | ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name 54 | )) 55 | .save() 56 | 57 | if 
(isCluster) { 58 | assertThat(collection.getProperties.getNumberOfShards).isEqualTo(5) 59 | } 60 | assertThat(collection.getProperties.getType.getType).isEqualTo(com.arangodb.entity.CollectionType.EDGES.getType) 61 | } 62 | 63 | @ParameterizedTest 64 | @MethodSource(Array("provideProtocolAndContentType")) 65 | def saveModeOverwrite(protocol: String, contentType: String): Unit = { 66 | df.write 67 | .format(BaseSparkTest.arangoDatasource) 68 | .mode(SaveMode.Overwrite) 69 | .options(options + ( 70 | ArangoDBConf.COLLECTION -> collectionName, 71 | ArangoDBConf.PROTOCOL -> protocol, 72 | ArangoDBConf.CONTENT_TYPE -> contentType, 73 | ArangoDBConf.CONFIRM_TRUNCATE -> "true", 74 | ArangoDBConf.NUMBER_OF_SHARDS -> "5", 75 | ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name 76 | )) 77 | .save() 78 | 79 | if (isCluster) { 80 | assertThat(collection.getProperties.getNumberOfShards).isEqualTo(5) 81 | } 82 | assertThat(collection.getProperties.getType.getType).isEqualTo(com.arangodb.entity.CollectionType.EDGES.getType) 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * DISCLAIMER 3 | * 4 | * Copyright 2016 ArangoDB GmbH, Cologne, Germany 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | * 18 | * Copyright holder is ArangoDB GmbH, Cologne, Germany 19 | */ 20 | 21 | package org.apache.spark.sql.arangodb.commons 22 | 23 | import com.arangodb.entity 24 | 25 | sealed trait ReadMode 26 | 27 | object ReadMode { 28 | /** 29 | * Read from an Arango collection. The scan will be partitioned according to the collection shards. 30 | */ 31 | case object Collection extends ReadMode 32 | 33 | /** 34 | * Read executing a user query, without partitioning. 
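   * The query string is supplied via the `query` read option (see `ReadDemo.readQuery` above).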
35 | */ 36 | case object Query extends ReadMode 37 | } 38 | 39 | sealed trait ContentType { 40 | val name: String 41 | } 42 | 43 | object ContentType { 44 | case object JSON extends ContentType { 45 | override val name: String = "json" 46 | } 47 | 48 | case object VPACK extends ContentType { 49 | override val name: String = "vpack" 50 | } 51 | 52 | def apply(value: String): ContentType = value match { 53 | case JSON.name => JSON 54 | case VPACK.name => VPACK 55 | case _ => throw new IllegalArgumentException(s"${ArangoDBConf.CONTENT_TYPE}: $value") 56 | } 57 | } 58 | 59 | sealed trait Protocol { 60 | val name: String 61 | } 62 | 63 | object Protocol { 64 | case object VST extends Protocol { 65 | override val name: String = "vst" 66 | } 67 | 68 | case object HTTP extends Protocol { 69 | override val name: String = "http" 70 | } 71 | 72 | case object HTTP2 extends Protocol { 73 | override val name: String = "http2" 74 | } 75 | 76 | def apply(value: String): Protocol = value match { 77 | case VST.name => VST 78 | case HTTP.name => HTTP 79 | case HTTP2.name => HTTP2 80 | case _ => throw new IllegalArgumentException(s"${ArangoDBConf.PROTOCOL}: $value") 81 | } 82 | } 83 | 84 | sealed trait CollectionType { 85 | val name: String 86 | 87 | def get(): entity.CollectionType 88 | } 89 | 90 | object CollectionType { 91 | case object DOCUMENT extends CollectionType { 92 | override val name: String = "document" 93 | 94 | override def get(): entity.CollectionType = entity.CollectionType.DOCUMENT 95 | } 96 | 97 | case object EDGE extends CollectionType { 98 | override val name: String = "edge" 99 | 100 | override def get(): entity.CollectionType = entity.CollectionType.EDGES 101 | } 102 | 103 | def apply(value: String): CollectionType = value match { 104 | case DOCUMENT.name => DOCUMENT 105 | case EDGE.name => EDGE 106 | case _ => throw new IllegalArgumentException(s"${ArangoDBConf.COLLECTION_TYPE}: $value") 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/StringFiltersTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import org.apache.spark.sql.DataFrame 4 | import org.apache.spark.sql.functions.col 5 | import org.apache.spark.sql.types._ 6 | import org.assertj.core.api.Assertions.assertThat 7 | import org.junit.jupiter.api.{AfterAll, BeforeAll, Test} 8 | 9 | import scala.collection.JavaConverters._ 10 | 11 | class StringFiltersTest extends BaseSparkTest { 12 | private val df = StringFiltersTest.df 13 | 14 | @Test 15 | def startsWith(): Unit = { 16 | val fieldName = "string" 17 | val value = StringFiltersTest.data.head(fieldName).asInstanceOf[String] 18 | val res = df.filter(col(fieldName).startsWith("Lorem")).collect() 19 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames)) 20 | assertThat(res).hasSize(1) 21 | assertThat(res.head(fieldName)).isEqualTo(value) 22 | val sqlRes = spark.sql( 23 | s""" 24 | |SELECT * FROM stringFilters 25 | |WHERE $fieldName LIKE "Lorem%" 26 | |""".stripMargin).collect() 27 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames)) 28 | assertThat(sqlRes).hasSize(1) 29 | assertThat(sqlRes.head(fieldName)).isEqualTo(value) 30 | } 31 | 32 | @Test 33 | def endsWith(): Unit = { 34 | val fieldName = "string" 35 | val value = StringFiltersTest.data.head(fieldName).asInstanceOf[String] 36 | val res = 
df.filter(col(fieldName).endsWith("amet")).collect() 37 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames)) 38 | assertThat(res).hasSize(1) 39 | assertThat(res.head(fieldName)).isEqualTo(value) 40 | val sqlRes = spark.sql( 41 | s""" 42 | |SELECT * FROM stringFilters 43 | |WHERE $fieldName LIKE "%amet" 44 | |""".stripMargin).collect() 45 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames)) 46 | assertThat(sqlRes).hasSize(1) 47 | assertThat(sqlRes.head(fieldName)).isEqualTo(value) 48 | } 49 | 50 | @Test 51 | def contains(): Unit = { 52 | val fieldName = "string" 53 | val value = StringFiltersTest.data.head(fieldName).asInstanceOf[String] 54 | val res = df.filter(col(fieldName).contains("dolor")).collect() 55 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames)) 56 | assertThat(res).hasSize(1) 57 | assertThat(res.head(fieldName)).isEqualTo(value) 58 | val sqlRes = spark.sql( 59 | s""" 60 | |SELECT * FROM stringFilters 61 | |WHERE $fieldName LIKE "%dolor%" 62 | |""".stripMargin).collect() 63 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames)) 64 | assertThat(sqlRes).hasSize(1) 65 | assertThat(sqlRes.head(fieldName)).isEqualTo(value) 66 | } 67 | 68 | } 69 | 70 | object StringFiltersTest { 71 | private var df: DataFrame = _ 72 | private val data: Seq[Map[String, Any]] = Seq( 73 | Map( 74 | "string" -> "Lorem ipsum dolor sit amet" 75 | ), 76 | Map( 77 | "string" -> "consectetur adipiscing elit" 78 | ) 79 | ) 80 | 81 | private val schema = StructType(Array( 82 | StructField("string", StringType, nullable = false) 83 | )) 84 | 85 | @BeforeAll 86 | def init(): Unit = { 87 | df = BaseSparkTest.createDF("stringFilters", data, schema) 88 | } 89 | 90 | @AfterAll 91 | def cleanup(): Unit = { 92 | BaseSparkTest.dropTable("stringFilters") 93 | } 94 | } -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # ArangoDB Spark Datasource Demo 2 | 3 | This demo is composed of 3 parts: 4 | 5 | - `WriteDemo`: reads the input json files as Spark Dataframes, applies conversions to map the data to Spark data types 6 | and writes the records into ArangoDB collections 7 | - `ReadDemo`: reads the ArangoDB collections created above as Spark Dataframes, specifying columns selection and records 8 | filters predicates or custom AQL queries 9 | - `ReadWriteDemo`: reads the ArangoDB collections created above as Spark Dataframes, applies projections and filtering, 10 | writes to a new ArangoDB collection 11 | 12 | There are demos available written in Scala & Python (using PySpark) as outlined below. 13 | 14 | ## Requirements 15 | 16 | This demo requires: 17 | 18 | - JDK 8, 11 or 17 19 | - `maven` 20 | - `docker` 21 | 22 | For the python demo, you will also need 23 | - `python` 24 | 25 | ## Prepare the environment 26 | 27 | Set environment variables: 28 | 29 | ```shell 30 | export ARANGO_SPARK_VERSION=1.9.0-SNAPSHOT 31 | ``` 32 | 33 | Start ArangoDB cluster with docker: 34 | 35 | ```shell 36 | SSL=true STARTER_MODE=cluster ./docker/start_db.sh 37 | ``` 38 | 39 | The deployed cluster will be accessible at [https://172.28.0.1:8529](http://172.28.0.1:8529) with username `root` and 40 | password `test`. 41 | 42 | Start Spark cluster: 43 | 44 | ```shell 45 | ./docker/start_spark.sh 46 | ``` 47 | 48 | ## Install locally 49 | 50 | NB: this is only needed for SNAPSHOT versions. 
51 | 52 | ```shell 53 | mvn -f ../pom.xml install -Dmaven.test.skip=true -Dgpg.skip=true -Dmaven.javadoc.skip=true -Pscala-2.12 -Pspark-3.5 54 | ``` 55 | 56 | ## Run embedded 57 | 58 | Test the Spark application in embedded mode: 59 | 60 | ```shell 61 | mvn \ 62 | -Pscala-2.12 -Pspark-3.5 \ 63 | test 64 | ``` 65 | 66 | Test the Spark application against ArangoDB Oasis deployment: 67 | 68 | ```shell 69 | mvn \ 70 | -Pscala-2.12 -Pspark-3.5 \ 71 | -Dpassword= \ 72 | -Dendpoints= \ 73 | -Dssl.cert.value= \ 74 | test 75 | ``` 76 | 77 | ## Submit to Spark cluster 78 | 79 | Package the application: 80 | 81 | ```shell 82 | mvn package -Dmaven.test.skip=true -Pscala-2.12 -Pspark-3.5 83 | ``` 84 | 85 | Submit demo program: 86 | 87 | ```shell 88 | docker run -it --rm \ 89 | -v $(pwd):/demo \ 90 | -v $(pwd)/docker/.ivy2:/opt/bitnami/spark/.ivy2 \ 91 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \ 92 | --network arangodb \ 93 | docker.io/bitnamilegacy/spark:3.5.6 \ 94 | ./bin/spark-submit --master spark://spark-master:7077 \ 95 | --packages="com.arangodb:arangodb-spark-datasource-3.5_2.12:$ARANGO_SPARK_VERSION" \ 96 | --class Demo /demo/target/demo-$ARANGO_SPARK_VERSION.jar 97 | ``` 98 | 99 | ## Python(PySpark) Demo 100 | 101 | This demo requires the same environment setup as outlined above. 102 | Additionally, the python requirements will need to be installed as follows: 103 | ```shell 104 | pip install -r ./python-demo/requirements.txt 105 | ``` 106 | 107 | To run the PySpark demo, run 108 | ```shell 109 | python ./python-demo/demo.py \ 110 | --ssl-enabled=true \ 111 | --endpoints=172.28.0.1:8529,172.28.0.1:8539,172.28.0.1:8549 112 | ``` 113 | 114 | To run it against an Oasis deployment, run 115 | ```shell 116 | python ./python-demo/demo.py \ 117 | --password= \ 118 | --endpoints= \ 119 | --ssl-enabled=true \ 120 | --ssl-cert-value= 121 | ``` -------------------------------------------------------------------------------- /docker/server.pem: -------------------------------------------------------------------------------- 1 | Bag Attributes 2 | friendlyName: arangotest 3 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34 4 | Key Attributes: 5 | -----BEGIN PRIVATE KEY----- 6 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC1WiDnd4+uCmMG 7 | 539ZNZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMAL 8 | X9euSseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7Nzmvdo 9 | gNXoJQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiR 10 | dJVPwUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf 11 | 2MGO64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzb 12 | f8KnRY4PAgMBAAECggEAKi1d/bdW2TldMpvgiFTm15zLjHCpllbKBWFqRj3T9+X7 13 | Duo6Nh9ehopD0YDDe2DNhYr3DsH4sLjUWVDfDpAhutMsU1wlBzmOuC+EuRv/CeDB 14 | 4DFr+0sgCwlti+YAtwWcR05SF7A0Ai0GYW2lUipbtbFSBSjCfM08BlPDsPCRhdM8 15 | DhBn3S45aP7oC8BdhG/etg+DfXW+/nyNwEcMCYG97bzXNjzYpCQjo/bTHdh2UPYM 16 | 4WEAqFzZ5jir8LVS3v7GqpqPmk6FnHJOJpfpOSZoPqnfpIw7SVlNsXHvDaHGcgYZ 17 | Xec7rLQlBuv4RZU7OlGJpK2Ng5kvS9q3nfqqn7YIMQKBgQDqSsYnE+k6CnrSpa2W 18 | B9W/+PChITgkA46XBUUjAueJ7yVZQQEOzl0VI6RoVBp3t66eO8uM9omO8/ogHXku 19 | Ei9UUIIfH4BsSP7G5A06UC/FgReDxwBfbRuS+lupnmc348vPDkFlJZ4hDgWflNev 20 | 7tpUbljSAqUea1VhdBy146V4qwKBgQDGJ6iL1+A9uUM+1UklOAPpPhTQ8ZQDRCj7 21 | 7IMVcbzWYvCMuVNXzOWuiz+VYr3IGCJZIbxbFDOHxGF4XKJnk0vm1qhQQME0PtAF 22 | i1jIfsxpj8KKJl9Uad+XLQCYRV8mIZlhsd/ErRJuz6FyqevKH3nFIb0ggF3x2d06 23 | odTHuj4ILQKBgCUsI/BDSne4/e+59aaeK52/w33tJVkhb1gqr+N0LIRH+ycEF0Tg 24 | 
HQijlQwwe9qOvBfC6PK+kuipcP/zbSyQGg5Ij7ycZOXJVxL7T9X2rv2pE7AGvNpn 25 | Fz7klfJ9fWbyr310h4+ivkoETYQaO3ZgcSeAMntvi/8djHhf0cZSDgjtAoGBAKvQ 26 | TUNcHjJGxfjgRLkB1dpSmwgEv7sJSaQOkiZw5TTauwq50nsJzYlHcg1cfYPW8Ulp 27 | iAFNBdVNwNn1MFgwjpqMO4rCawObBxIXnhbSYvmQzjStSvFNj7JsMdzWIcdVUMI1 28 | 0fmdu6LbY3ihvzIVkqcMNwnMZCjFKB6jnXTElu7NAoGAS0gNPD/bfzWAhZBBYp9/ 29 | SLGOvjHKrSVWGwDiqdAGuh6xg+1C3F+XpiITP6d3Wv3PCJ/Gia5isQPSMaXG+xTt 30 | 6huBgFlksHqr0tsQA9dcgGW7BDr5VhRq5/WinaLhGGy1R+i2zbDmQXgHbCO+RH/s 31 | bD9F4LZ3RoXmGHLW0IUggPw= 32 | -----END PRIVATE KEY----- 33 | Bag Attributes 34 | friendlyName: arangotest 35 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34 36 | subject=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost 37 | 38 | issuer=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost 39 | 40 | -----BEGIN CERTIFICATE----- 41 | MIIDezCCAmOgAwIBAgIEeDCzXzANBgkqhkiG9w0BAQsFADBuMRAwDgYDVQQGEwdV 42 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD 43 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv 44 | c3QwHhcNMjAxMTAxMTg1MTE5WhcNMzAxMDMwMTg1MTE5WjBuMRAwDgYDVQQGEwdV 45 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD 46 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv 47 | c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1WiDnd4+uCmMG539Z 48 | NZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMALX9eu 49 | SseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7NzmvdogNXo 50 | JQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiRdJVP 51 | wUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf2MGO 52 | 64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzbf8Kn 53 | RY4PAgMBAAGjITAfMB0GA1UdDgQWBBTBrv9Awynt3C5IbaCNyOW5v4DNkTANBgkq 54 | hkiG9w0BAQsFAAOCAQEAIm9rPvDkYpmzpSIhR3VXG9Y71gxRDrqkEeLsMoEyqGnw 55 | /zx1bDCNeGg2PncLlW6zTIipEBooixIE9U7KxHgZxBy0Et6EEWvIUmnr6F4F+dbT 56 | D050GHlcZ7eOeqYTPYeQC502G1Fo4tdNi4lDP9L9XZpf7Q1QimRH2qaLS03ZFZa2 57 | tY7ah/RQqZL8Dkxx8/zc25sgTHVpxoK853glBVBs/ENMiyGJWmAXQayewY3EPt/9 58 | wGwV4KmU3dPDleQeXSUGPUISeQxFjy+jCw21pYviWVJTNBA9l5ny3GhEmcnOT/gQ 59 | HCvVRLyGLMbaMZ4JrPwb+aAtBgrgeiK4xeSMMvrbhw== 60 | -----END CERTIFICATE----- 61 | -------------------------------------------------------------------------------- /demo/docker/server.pem: -------------------------------------------------------------------------------- 1 | Bag Attributes 2 | friendlyName: arangotest 3 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34 4 | Key Attributes: 5 | -----BEGIN PRIVATE KEY----- 6 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC1WiDnd4+uCmMG 7 | 539ZNZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMAL 8 | X9euSseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7Nzmvdo 9 | gNXoJQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiR 10 | dJVPwUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf 11 | 2MGO64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzb 12 | f8KnRY4PAgMBAAECggEAKi1d/bdW2TldMpvgiFTm15zLjHCpllbKBWFqRj3T9+X7 13 | Duo6Nh9ehopD0YDDe2DNhYr3DsH4sLjUWVDfDpAhutMsU1wlBzmOuC+EuRv/CeDB 14 | 4DFr+0sgCwlti+YAtwWcR05SF7A0Ai0GYW2lUipbtbFSBSjCfM08BlPDsPCRhdM8 15 | DhBn3S45aP7oC8BdhG/etg+DfXW+/nyNwEcMCYG97bzXNjzYpCQjo/bTHdh2UPYM 16 | 4WEAqFzZ5jir8LVS3v7GqpqPmk6FnHJOJpfpOSZoPqnfpIw7SVlNsXHvDaHGcgYZ 17 | Xec7rLQlBuv4RZU7OlGJpK2Ng5kvS9q3nfqqn7YIMQKBgQDqSsYnE+k6CnrSpa2W 18 | B9W/+PChITgkA46XBUUjAueJ7yVZQQEOzl0VI6RoVBp3t66eO8uM9omO8/ogHXku 19 | 
Ei9UUIIfH4BsSP7G5A06UC/FgReDxwBfbRuS+lupnmc348vPDkFlJZ4hDgWflNev 20 | 7tpUbljSAqUea1VhdBy146V4qwKBgQDGJ6iL1+A9uUM+1UklOAPpPhTQ8ZQDRCj7 21 | 7IMVcbzWYvCMuVNXzOWuiz+VYr3IGCJZIbxbFDOHxGF4XKJnk0vm1qhQQME0PtAF 22 | i1jIfsxpj8KKJl9Uad+XLQCYRV8mIZlhsd/ErRJuz6FyqevKH3nFIb0ggF3x2d06 23 | odTHuj4ILQKBgCUsI/BDSne4/e+59aaeK52/w33tJVkhb1gqr+N0LIRH+ycEF0Tg 24 | HQijlQwwe9qOvBfC6PK+kuipcP/zbSyQGg5Ij7ycZOXJVxL7T9X2rv2pE7AGvNpn 25 | Fz7klfJ9fWbyr310h4+ivkoETYQaO3ZgcSeAMntvi/8djHhf0cZSDgjtAoGBAKvQ 26 | TUNcHjJGxfjgRLkB1dpSmwgEv7sJSaQOkiZw5TTauwq50nsJzYlHcg1cfYPW8Ulp 27 | iAFNBdVNwNn1MFgwjpqMO4rCawObBxIXnhbSYvmQzjStSvFNj7JsMdzWIcdVUMI1 28 | 0fmdu6LbY3ihvzIVkqcMNwnMZCjFKB6jnXTElu7NAoGAS0gNPD/bfzWAhZBBYp9/ 29 | SLGOvjHKrSVWGwDiqdAGuh6xg+1C3F+XpiITP6d3Wv3PCJ/Gia5isQPSMaXG+xTt 30 | 6huBgFlksHqr0tsQA9dcgGW7BDr5VhRq5/WinaLhGGy1R+i2zbDmQXgHbCO+RH/s 31 | bD9F4LZ3RoXmGHLW0IUggPw= 32 | -----END PRIVATE KEY----- 33 | Bag Attributes 34 | friendlyName: arangotest 35 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34 36 | subject=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost 37 | 38 | issuer=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost 39 | 40 | -----BEGIN CERTIFICATE----- 41 | MIIDezCCAmOgAwIBAgIEeDCzXzANBgkqhkiG9w0BAQsFADBuMRAwDgYDVQQGEwdV 42 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD 43 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv 44 | c3QwHhcNMjAxMTAxMTg1MTE5WhcNMzAxMDMwMTg1MTE5WjBuMRAwDgYDVQQGEwdV 45 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD 46 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv 47 | c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1WiDnd4+uCmMG539Z 48 | NZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMALX9eu 49 | SseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7NzmvdogNXo 50 | JQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiRdJVP 51 | wUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf2MGO 52 | 64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzbf8Kn 53 | RY4PAgMBAAGjITAfMB0GA1UdDgQWBBTBrv9Awynt3C5IbaCNyOW5v4DNkTANBgkq 54 | hkiG9w0BAQsFAAOCAQEAIm9rPvDkYpmzpSIhR3VXG9Y71gxRDrqkEeLsMoEyqGnw 55 | /zx1bDCNeGg2PncLlW6zTIipEBooixIE9U7KxHgZxBy0Et6EEWvIUmnr6F4F+dbT 56 | D050GHlcZ7eOeqYTPYeQC502G1Fo4tdNi4lDP9L9XZpf7Q1QimRH2qaLS03ZFZa2 57 | tY7ah/RQqZL8Dkxx8/zc25sgTHVpxoK853glBVBs/ENMiyGJWmAXQayewY3EPt/9 58 | wGwV4KmU3dPDleQeXSUGPUISeQxFjy+jCw21pYviWVJTNBA9l5ny3GhEmcnOT/gQ 59 | HCvVRLyGLMbaMZ4JrPwb+aAtBgrgeiK4xeSMMvrbhw== 60 | -----END CERTIFICATE----- 61 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/WriteResiliencyTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.write 2 | 3 | import com.arangodb.ArangoCollection 4 | import com.arangodb.model.OverwriteMode 5 | import org.apache.spark.sql.SaveMode 6 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 7 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest 8 | import org.assertj.core.api.Assertions.assertThat 9 | import org.junit.jupiter.api.{BeforeEach, Disabled} 10 | import org.junit.jupiter.params.ParameterizedTest 11 | import org.junit.jupiter.params.provider.MethodSource 12 | 13 | 14 | class WriteResiliencyTest extends BaseSparkTest { 15 | 16 | private val collectionName = "chessPlayersResiliency" 17 | private val 
collection: ArangoCollection = db.collection(collectionName) 18 | 19 | import spark.implicits._ 20 | 21 | private val df = Seq( 22 | ("Carlsen", "Magnus"), 23 | ("Caruana", "Fabiano"), 24 | ("Ding", "Liren"), 25 | ("Nepomniachtchi", "Ian"), 26 | ("Aronian", "Levon"), 27 | ("Grischuk", "Alexander"), 28 | ("Giri", "Anish"), 29 | ("Mamedyarov", "Shakhriyar"), 30 | ("So", "Wesley"), 31 | ("Radjabov", "Teimour") 32 | ).toDF("_key", "name") 33 | .repartition(6) 34 | 35 | @BeforeEach 36 | def beforeEach(): Unit = { 37 | if (collection.exists()) { 38 | collection.truncate() 39 | } else { 40 | collection.create() 41 | } 42 | } 43 | 44 | @Disabled("manual test only") 45 | @ParameterizedTest 46 | @MethodSource(Array("provideProtocolAndContentType")) 47 | def retryOnTimeout(protocol: String, contentType: String): Unit = { 48 | df.write 49 | .format(BaseSparkTest.arangoDatasource) 50 | .mode(SaveMode.Append) 51 | .options(options + ( 52 | ArangoDBConf.TIMEOUT -> "1", 53 | ArangoDBConf.ENDPOINTS -> BaseSparkTest.endpoints, 54 | ArangoDBConf.COLLECTION -> collectionName, 55 | ArangoDBConf.PROTOCOL -> protocol, 56 | ArangoDBConf.CONTENT_TYPE -> contentType, 57 | ArangoDBConf.CONFIRM_TRUNCATE -> "true", 58 | ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue 59 | )) 60 | .save() 61 | 62 | assertThat(collection.count().getCount).isEqualTo(10L) 63 | } 64 | 65 | @ParameterizedTest 66 | @MethodSource(Array("provideProtocolAndContentType")) 67 | def retryOnWrongHost(protocol: String, contentType: String): Unit = { 68 | retryOnBadHost(BaseSparkTest.endpoints + ",127.0.0.1:111", protocol, contentType) 69 | } 70 | 71 | @ParameterizedTest 72 | @MethodSource(Array("provideProtocolAndContentType")) 73 | def retryOnUnknownHost(protocol: String, contentType: String): Unit = { 74 | retryOnBadHost(BaseSparkTest.endpoints + ",wrongHost:8529", protocol, contentType) 75 | } 76 | 77 | private def retryOnBadHost(endpoints: String, protocol: String, contentType: String): Unit = { 78 | df.write 79 | .format(BaseSparkTest.arangoDatasource) 80 | .mode(SaveMode.Append) 81 | .options(options + ( 82 | ArangoDBConf.ENDPOINTS -> endpoints, 83 | ArangoDBConf.COLLECTION -> collectionName, 84 | ArangoDBConf.PROTOCOL -> protocol, 85 | ArangoDBConf.CONTENT_TYPE -> contentType, 86 | ArangoDBConf.CONFIRM_TRUNCATE -> "true", 87 | ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue, 88 | ArangoDBConf.MAX_ATTEMPTS -> "4", 89 | ArangoDBConf.MIN_RETRY_DELAY -> "20", 90 | ArangoDBConf.MAX_RETRY_DELAY -> "40" 91 | )) 92 | .save() 93 | 94 | assertThat(collection.count().getCount).isEqualTo(10L) 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /integration-tests/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | arangodb-spark-datasource 7 | com.arangodb 8 | 1.9.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | integration-tests 13 | 14 | 15 | target/site/jacoco-aggregate/jacoco.xml 16 | 17 | 18 | 19 | 20 | com.arangodb 21 | arangodb-spark-datasource-${spark.compat.version}_${scala.compat.version} 22 | ${project.version} 23 | test 24 | 25 | 26 | com.arangodb 27 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version} 28 | ${project.version} 29 | test 30 | 31 | 32 | com.arangodb 33 | jackson-serde-json 34 | 7.9.0 35 | test 36 | 37 | 38 | com.fasterxml.jackson.core 39 | jackson-core 40 | 41 | 42 | com.fasterxml.jackson.core 43 | jackson-databind 44 | 45 | 46 | com.fasterxml.jackson.core 47 | jackson-annotations 48 | 49 | 
50 | 51 | 52 | com.arangodb 53 | jackson-serde-vpack 54 | 7.9.0 55 | test 56 | 57 | 58 | com.arangodb 59 | jackson-dataformat-velocypack 60 | 61 | 62 | 63 | 64 | io.qameta.allure 65 | allure-junit5 66 | 2.30.0 67 | test 68 | 69 | 70 | 71 | 72 | 73 | 74 | org.jacoco 75 | jacoco-maven-plugin 76 | 77 | 78 | report-aggregate 79 | verify 80 | 81 | report-aggregate 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /demo/python-demo/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from argparse import ArgumentParser 4 | from typing import Dict 5 | 6 | from pyspark.sql import SparkSession 7 | 8 | from read_write_demo import read_write_demo 9 | from read_demo import read_demo 10 | from write_demo import write_demo 11 | 12 | 13 | def create_spark_session() -> SparkSession: 14 | # Here we can initialize the spark session, and in doing so, 15 | # include the ArangoDB Spark DataSource package 16 | arango_spark_version = os.environ["ARANGO_SPARK_VERSION"] 17 | 18 | spark = SparkSession.builder \ 19 | .appName("ArangoDBPySparkDataTypesExample") \ 20 | .master("local[*]") \ 21 | .config("spark.jars.packages", f"com.arangodb:arangodb-spark-datasource-3.5_2.12:{arango_spark_version}") \ 22 | .getOrCreate() 23 | 24 | return spark 25 | 26 | 27 | def create_base_arangodb_datasource_opts(password: str, endpoints: str, ssl_enabled: str, ssl_cert_value: str) -> Dict[str, str]: 28 | return { 29 | "password": password, 30 | "endpoints": endpoints, 31 | "ssl.enabled": ssl_enabled, 32 | "ssl.cert.value": ssl_cert_value, 33 | "ssl.verifyHost": "false" 34 | } 35 | 36 | 37 | def main(): 38 | parser = ArgumentParser() 39 | parser.add_argument("--import-path", default=None) 40 | parser.add_argument("--password", default="test") 41 | parser.add_argument("--endpoints", default="localhost:8529") 42 | parser.add_argument("--ssl-enabled", default="false") 43 | parser.add_argument("--ssl-cert-value", 
default="LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlekNDQW1PZ0F3SUJBZ0lFZURDelh6QU5CZ2txaGtpRzl3MEJBUXNGQURCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdIaGNOTWpBeE1UQXhNVGcxTVRFNVdoY05NekF4TURNd01UZzFNVEU1V2pCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdnZ0VpTUEwR0NTcUdTSWIzRFFFQkFRVUFBNElCRHdBd2dnRUtBb0lCQVFDMVdpRG5kNCt1Q21NRzUzOVoKTlpCOE53STBSWkYzc1VTUUdQeDNsa3FhRlRaVkV6TVpMNzZIWXZkYzlRZzdkaWZ5S3lRMDlSTFNwTUFMWDlldQpTc2VEN2JaR25mUUg1MkJuS2NUMDllUTN3aDdhVlE1c04yb215Z2RITEM3WDl1c250eEFmdjdOem12ZG9nTlhvCkpReVkvaFNaZmY3UklxV0g4Tm5BVUtranFPZTZCZjVMRGJ4SEtFU21yRkJ4T0NPbmhjcHZaV2V0d3BpUmRKVlAKd1VuNVA4MkNBWnpmaUJmbUJabkI3RDBsKy82Q3Y0ak11SDI2dUFJY2l4blZla0JRemwxUmd3Y3p1aVpmMk1HTwo2NHZETU1KSldFOUNsWkYxdVF1UXJ3WEY2cXdodVAxSG5raWk2d05iVHRQV2xHU2txZXV0cjAwNCtIemJmOEtuClJZNFBBZ01CQUFHaklUQWZNQjBHQTFVZERnUVdCQlRCcnY5QXd5bnQzQzVJYmFDTnlPVzV2NEROa1RBTkJna3EKaGtpRzl3MEJBUXNGQUFPQ0FRRUFJbTlyUHZEa1lwbXpwU0loUjNWWEc5WTcxZ3hSRHJxa0VlTHNNb0V5cUdudwovengxYkRDTmVHZzJQbmNMbFc2elRJaXBFQm9vaXhJRTlVN0t4SGdaeEJ5MEV0NkVFV3ZJVW1ucjZGNEYrZGJUCkQwNTBHSGxjWjdlT2VxWVRQWWVRQzUwMkcxRm80dGROaTRsRFA5TDlYWnBmN1ExUWltUkgycWFMUzAzWkZaYTIKdFk3YWgvUlFxWkw4RGt4eDgvemMyNXNnVEhWcHhvSzg1M2dsQlZCcy9FTk1peUdKV21BWFFheWV3WTNFUHQvOQp3R3dWNEttVTNkUERsZVFlWFNVR1BVSVNlUXhGankrakN3MjFwWXZpV1ZKVE5CQTlsNW55M0doRW1jbk9UL2dRCkhDdlZSTHlHTE1iYU1aNEpyUHdiK2FBdEJncmdlaUs0eGVTTU12cmJodz09Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K") 44 | args = parser.parse_args() 45 | 46 | if args.import_path is None: 47 | args.import_path = pathlib.Path(__file__).resolve().parent.parent / "docker" / "import" 48 | 49 | spark = create_spark_session() 50 | base_opts = create_base_arangodb_datasource_opts(args.password, args.endpoints, args.ssl_enabled, args.ssl_cert_value) 51 | write_demo(spark, base_opts, args.import_path) 52 | read_demo(spark, base_opts) 53 | read_write_demo(spark, base_opts) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/EdgeSchemaValidationTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.write 2 | 3 | import com.arangodb.ArangoCollection 4 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, CollectionType} 5 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest 6 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 7 | import org.apache.spark.sql.{DataFrame, Row, SaveMode} 8 | import org.assertj.core.api.Assertions.{assertThat, catchThrowable} 9 | import org.assertj.core.api.ThrowableAssert.ThrowingCallable 10 | import org.junit.jupiter.api.BeforeEach 11 | import org.junit.jupiter.params.ParameterizedTest 12 | import org.junit.jupiter.params.provider.MethodSource 13 | 14 | import scala.collection.JavaConverters._ 15 | 16 | class EdgeSchemaValidationTest extends BaseSparkTest { 17 | 18 | private val collectionName = "edgeSchemaValidationTest" 19 | private val collection: ArangoCollection = db.collection(collectionName) 20 | 21 | private val rows = Seq( 22 | Row("k1", "from/from", "to/to"), 23 | Row("k2", "from/from", "to/to"), 24 | 
Row("k3", "from/from", "to/to") 25 | ) 26 | 27 | private val df = spark.createDataFrame(rows.asJava, StructType(Array( 28 | StructField("_key", StringType, nullable = false), 29 | StructField("_from", StringType, nullable = false), 30 | StructField("_to", StringType, nullable = false) 31 | ))) 32 | 33 | @BeforeEach 34 | def beforeEach(): Unit = { 35 | if (collection.exists()) { 36 | collection.drop() 37 | } 38 | } 39 | 40 | @ParameterizedTest 41 | @MethodSource(Array("provideProtocolAndContentType")) 42 | def write(protocol: String, contentType: String): Unit = { 43 | doWrite(df, protocol, contentType) 44 | assertThat(collection.count().getCount).isEqualTo(3L) 45 | } 46 | 47 | @ParameterizedTest 48 | @MethodSource(Array("provideProtocolAndContentType")) 49 | def dfWithNullableFromFieldShouldFail(protocol: String, contentType: String): Unit = { 50 | val nullableFromSchema = StructType(df.schema.map(p => 51 | if (p.name == "_from") StructField(p.name, p.dataType) 52 | else p 53 | )) 54 | val dfWithNullableFrom = spark.createDataFrame(df.rdd, nullableFromSchema) 55 | val thrown = catchThrowable(new ThrowingCallable() { 56 | override def call(): Unit = doWrite(dfWithNullableFrom, protocol, contentType) 57 | }) 58 | 59 | assertThat(thrown).isInstanceOf(classOf[IllegalArgumentException]) 60 | assertThat(thrown).hasMessageContaining("_from") 61 | } 62 | 63 | @ParameterizedTest 64 | @MethodSource(Array("provideProtocolAndContentType")) 65 | def dfWithNullableToFieldShouldFail(protocol: String, contentType: String): Unit = { 66 | val nullableFromSchema = StructType(df.schema.map(p => 67 | if (p.name == "_to") StructField(p.name, p.dataType) 68 | else p 69 | )) 70 | val dfWithNullableFrom = spark.createDataFrame(df.rdd, nullableFromSchema) 71 | val thrown = catchThrowable(new ThrowingCallable() { 72 | override def call(): Unit = doWrite(dfWithNullableFrom, protocol, contentType) 73 | }) 74 | 75 | assertThat(thrown).isInstanceOf(classOf[IllegalArgumentException]) 76 | assertThat(thrown).hasMessageContaining("_to") 77 | } 78 | 79 | private def doWrite(testDF: DataFrame, protocol: String, contentType: String): Unit = { 80 | testDF.write 81 | .format(BaseSparkTest.arangoDatasource) 82 | .mode(SaveMode.Append) 83 | .options(options + ( 84 | ArangoDBConf.COLLECTION -> collectionName, 85 | ArangoDBConf.PROTOCOL -> protocol, 86 | ArangoDBConf.CONTENT_TYPE -> contentType, 87 | ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name 88 | )) 89 | .save() 90 | } 91 | 92 | } 93 | -------------------------------------------------------------------------------- /docker/start_db.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Configuration environment variables: 4 | # STARTER_MODE: (single|cluster|activefailover), default single 5 | # DOCKER_IMAGE: ArangoDB docker image, default docker.io/arangodb/enterprise:latest 6 | # STARTER_DOCKER_IMAGE: ArangoDB Starter docker image, default docker.io/arangodb/arangodb-starter:latest 7 | # SSL: (true|false), default false 8 | # ARANGO_LICENSE_KEY: only required for ArangoDB Enterprise 9 | 10 | # EXAMPLE: 11 | # STARTER_MODE=cluster SSL=true ./start_db.sh 12 | 13 | STARTER_MODE=${STARTER_MODE:=single} 14 | DOCKER_IMAGE=${DOCKER_IMAGE:=docker.io/arangodb/enterprise:latest} 15 | STARTER_DOCKER_IMAGE=${STARTER_DOCKER_IMAGE:=docker.io/arangodb/arangodb-starter:latest} 16 | SSL=${SSL:=false} 17 | COMPRESSION=${COMPRESSION:=false} 18 | 19 | GW=172.28.0.1 20 | docker network create arangodb --subnet 172.28.0.0/16 21 | 22 
| # exit when any command fails 23 | set -e 24 | 25 | docker pull $STARTER_DOCKER_IMAGE 26 | docker pull $DOCKER_IMAGE 27 | 28 | LOCATION=$(pwd)/$(dirname "$0") 29 | AUTHORIZATION_HEADER=$(cat "$LOCATION"/jwtHeader) 30 | 31 | STARTER_ARGS= 32 | SCHEME=http 33 | ARANGOSH_SCHEME=http+tcp 34 | COORDINATORS=("$GW:8529" "$GW:8539" "$GW:8549") 35 | 36 | if [ "$STARTER_MODE" == "single" ]; then 37 | COORDINATORS=("$GW:8529") 38 | fi 39 | 40 | if [ "$SSL" == "true" ]; then 41 | STARTER_ARGS="$STARTER_ARGS --ssl.keyfile=/data/server.pem" 42 | SCHEME=https 43 | ARANGOSH_SCHEME=http+ssl 44 | fi 45 | 46 | if [ "$COMPRESSION" == "true" ]; then 47 | STARTER_ARGS="${STARTER_ARGS} --all.http.compress-response-threshold=1" 48 | fi 49 | 50 | # data volume 51 | docker create -v /data --name arangodb-data alpine:3 /bin/true 52 | docker cp "$LOCATION"/jwtSecret arangodb-data:/data 53 | docker cp "$LOCATION"/server.pem arangodb-data:/data 54 | 55 | docker run -d \ 56 | --name=adb \ 57 | -p 8528:8528 \ 58 | --volumes-from arangodb-data \ 59 | -v /var/run/docker.sock:/var/run/docker.sock \ 60 | -e ARANGO_LICENSE_KEY="$ARANGO_LICENSE_KEY" \ 61 | $STARTER_DOCKER_IMAGE \ 62 | $STARTER_ARGS \ 63 | --docker.net-mode=default \ 64 | --docker.container=adb \ 65 | --auth.jwt-secret=/data/jwtSecret \ 66 | --starter.address="${GW}" \ 67 | --docker.image="${DOCKER_IMAGE}" \ 68 | --starter.local --starter.mode=${STARTER_MODE} --all.log.level=debug --all.log.output=+ --log.verbose \ 69 | --all.server.descriptors-minimum=1024 --all.javascript.allow-admin-execute=true 70 | 71 | 72 | wait_server() { 73 | # shellcheck disable=SC2091 74 | until $(curl --output /dev/null --insecure --fail --silent --head -i -H "$AUTHORIZATION_HEADER" "$SCHEME://$1/_api/version"); do 75 | printf '.' 76 | sleep 1 77 | done 78 | } 79 | 80 | echo "Waiting..." 81 | 82 | for a in ${COORDINATORS[*]} ; do 83 | wait_server "$a" 84 | done 85 | 86 | set +e 87 | for a in ${COORDINATORS[*]} ; do 88 | echo "" 89 | echo "Setting username and password..." 90 | docker run --rm ${DOCKER_IMAGE} arangosh --server.endpoint="$ARANGOSH_SCHEME://$a" --server.authentication=false --javascript.execute-string='require("org/arangodb/users").update("root", "test")' 91 | done 92 | set -e 93 | 94 | for a in ${COORDINATORS[*]} ; do 95 | echo "" 96 | echo "Requesting endpoint version..." 
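  # --insecure: when SSL=true the deployment serves the self-signed certificate from server.pem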
97 | curl -u root:test --insecure --fail "$SCHEME://$a/_api/version" 98 | done 99 | 100 | echo "" 101 | echo "" 102 | echo "Done, your deployment is reachable at: " 103 | for a in ${COORDINATORS[*]} ; do 104 | echo "$SCHEME://$a" 105 | echo "" 106 | done 107 | 108 | if [ "$STARTER_MODE" == "activefailover" ]; then 109 | LEADER=$("$LOCATION"/find_active_endpoint.sh) 110 | echo "Leader: $SCHEME://$LEADER" 111 | echo "" 112 | fi 113 | -------------------------------------------------------------------------------- /demo/docker/start_db.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Configuration environment variables: 4 | # STARTER_MODE: (single|cluster|activefailover), default single 5 | # DOCKER_IMAGE: ArangoDB docker image, default docker.io/arangodb/enterprise:latest 6 | # STARTER_DOCKER_IMAGE: ArangoDB Starter docker image, default docker.io/arangodb/arangodb-starter:latest 7 | # SSL: (true|false), default false 8 | # ARANGO_LICENSE_KEY: only required for ArangoDB Enterprise 9 | 10 | # EXAMPLE: 11 | # STARTER_MODE=cluster SSL=true ./start_db.sh 12 | 13 | STARTER_MODE=${STARTER_MODE:=single} 14 | DOCKER_IMAGE=${DOCKER_IMAGE:=docker.io/arangodb/enterprise:latest} 15 | STARTER_DOCKER_IMAGE=${STARTER_DOCKER_IMAGE:=docker.io/arangodb/arangodb-starter:latest} 16 | SSL=${SSL:=false} 17 | COMPRESSION=${COMPRESSION:=false} 18 | 19 | GW=172.28.0.1 20 | docker network create arangodb --subnet 172.28.0.0/16 21 | 22 | # exit when any command fails 23 | set -e 24 | 25 | docker pull $STARTER_DOCKER_IMAGE 26 | docker pull $DOCKER_IMAGE 27 | 28 | LOCATION=$(pwd)/$(dirname "$0") 29 | AUTHORIZATION_HEADER=$(cat "$LOCATION"/jwtHeader) 30 | 31 | STARTER_ARGS= 32 | SCHEME=http 33 | ARANGOSH_SCHEME=http+tcp 34 | COORDINATORS=("$GW:8529" "$GW:8539" "$GW:8549") 35 | 36 | if [ "$STARTER_MODE" == "single" ]; then 37 | COORDINATORS=("$GW:8529") 38 | fi 39 | 40 | if [ "$SSL" == "true" ]; then 41 | STARTER_ARGS="$STARTER_ARGS --ssl.keyfile=/data/server.pem" 42 | SCHEME=https 43 | ARANGOSH_SCHEME=http+ssl 44 | fi 45 | 46 | if [ "$COMPRESSION" == "true" ]; then 47 | STARTER_ARGS="${STARTER_ARGS} --all.http.compress-response-threshold=1" 48 | fi 49 | 50 | # data volume 51 | docker create -v /data --name arangodb-data alpine:3 /bin/true 52 | docker cp "$LOCATION"/jwtSecret arangodb-data:/data 53 | docker cp "$LOCATION"/server.pem arangodb-data:/data 54 | 55 | docker run -d \ 56 | --name=adb \ 57 | -p 8528:8528 \ 58 | --volumes-from arangodb-data \ 59 | -v /var/run/docker.sock:/var/run/docker.sock \ 60 | -e ARANGO_LICENSE_KEY="$ARANGO_LICENSE_KEY" \ 61 | $STARTER_DOCKER_IMAGE \ 62 | $STARTER_ARGS \ 63 | --docker.net-mode=default \ 64 | --docker.container=adb \ 65 | --auth.jwt-secret=/data/jwtSecret \ 66 | --starter.address="${GW}" \ 67 | --docker.image="${DOCKER_IMAGE}" \ 68 | --starter.local --starter.mode=${STARTER_MODE} --all.log.level=debug --all.log.output=+ --log.verbose \ 69 | --all.server.descriptors-minimum=1024 --all.javascript.allow-admin-execute=true 70 | 71 | 72 | wait_server() { 73 | # shellcheck disable=SC2091 74 | until $(curl --output /dev/null --insecure --fail --silent --head -i -H "$AUTHORIZATION_HEADER" "$SCHEME://$1/_api/version"); do 75 | printf '.' 76 | sleep 1 77 | done 78 | } 79 | 80 | echo "Waiting..." 81 | 82 | for a in ${COORDINATORS[*]} ; do 83 | wait_server "$a" 84 | done 85 | 86 | set +e 87 | for a in ${COORDINATORS[*]} ; do 88 | echo "" 89 | echo "Setting username and password..." 
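  # set the root password to "test" via arangosh (the credentials referenced in the demo README and in the curl checks below)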
90 | docker run --rm ${DOCKER_IMAGE} arangosh --server.endpoint="$ARANGOSH_SCHEME://$a" --server.authentication=false --javascript.execute-string='require("org/arangodb/users").update("root", "test")' 91 | done 92 | set -e 93 | 94 | for a in ${COORDINATORS[*]} ; do 95 | echo "" 96 | echo "Requesting endpoint version..." 97 | curl -u root:test --insecure --fail "$SCHEME://$a/_api/version" 98 | done 99 | 100 | echo "" 101 | echo "" 102 | echo "Done, your deployment is reachable at: " 103 | for a in ${COORDINATORS[*]} ; do 104 | echo "$SCHEME://$a" 105 | echo "" 106 | done 107 | 108 | if [ "$STARTER_MODE" == "activefailover" ]; then 109 | LEADER=$("$LOCATION"/find_active_endpoint.sh) 110 | echo "Leader: $SCHEME://$LEADER" 111 | echo "" 112 | fi 113 | -------------------------------------------------------------------------------- /demo/python-demo/write_demo.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pathlib 3 | from typing import Dict 4 | 5 | from pyspark import pandas as ps 6 | from pyspark.sql import SparkSession, functions as f 7 | from pyspark.sql.types import StructType 8 | 9 | from utils import combine_dicts 10 | from schemas import person_schema, movie_schema, directed_schema, acts_in_schema 11 | 12 | 13 | def save_df(ps_df, table_name: str, options: Dict[str, str], table_type: str = None) -> None: 14 | if not table_type: 15 | table_type = "document" 16 | 17 | all_opts = combine_dicts([options, { 18 | "table.shards": "9", 19 | "confirmTruncate": "true", 20 | "overwriteMode": "replace", 21 | "table": table_name, 22 | "table.type": table_type 23 | }]) 24 | 25 | ps_df.to_spark()\ 26 | .write\ 27 | .mode("overwrite")\ 28 | .format("com.arangodb.spark")\ 29 | .options(**all_opts)\ 30 | .save() 31 | 32 | 33 | def write_demo(spark: SparkSession, save_opts: Dict[str, str], import_path_str: str): 34 | import_path = pathlib.Path(import_path_str) 35 | 36 | print("Read Nodes from JSONL using Pandas on Spark API") 37 | nodes_pd_df = ps.read_json(str(import_path / "nodes.jsonl")) 38 | nodes_pd_df = nodes_pd_df[nodes_pd_df["_key"].notnull()] 39 | nodes_pd_df["releaseDate"] = ps.to_datetime(nodes_pd_df["releaseDate"], unit="ms") 40 | nodes_pd_df["birthday"] = ps.to_datetime(nodes_pd_df["birthday"], unit="ms") 41 | 42 | def convert_to_timestamp(to_modify, column): 43 | tz_aware_datetime = datetime.datetime.utcfromtimestamp( 44 | int(to_modify[column])/1000 45 | ).replace(tzinfo=datetime.timezone.utc).astimezone(tz=None) 46 | tz_naive = tz_aware_datetime.replace(tzinfo=None) 47 | to_modify[column] = tz_naive 48 | return to_modify 49 | 50 | nodes_pd_df = nodes_pd_df.apply(convert_to_timestamp, axis=1, args=("lastModified",)) 51 | 52 | nodes_df = nodes_pd_df.to_spark() 53 | nodes_pd_df = nodes_df\ 54 | .withColumn("releaseDate", f.to_date(nodes_df["releaseDate"])) \ 55 | .withColumn("birthday", f.to_date(nodes_df["birthday"])) \ 56 | .to_pandas_on_spark() 57 | 58 | print("Read Edges from JSONL using PySpark API") 59 | edges_df = spark.read.json(str(import_path / "edges.jsonl")) 60 | # apply the schema to change nullability of _key, _from, and _to columns in schema 61 | edges_pd_df = edges_df.to_pandas_on_spark() 62 | edges_pd_df["_from"] = "persons/" + edges_pd_df["_from"] 63 | edges_pd_df["_to"] = "movies/" + edges_pd_df["_to"] 64 | 65 | print("Create the collection dfs") 66 | persons_df = nodes_pd_df[nodes_pd_df["type"] == "Person"][person_schema.fieldNames()[1:]] 67 | movies_df = nodes_pd_df[nodes_pd_df["type"] == 
"Movie"][movie_schema.fieldNames()[1:]] 68 | directed_df = edges_pd_df[edges_pd_df["$label"] == "DIRECTED"][directed_schema.fieldNames()[1:]] 69 | acted_in_df = edges_pd_df[edges_pd_df["$label"] == "ACTS_IN"][acts_in_schema.fieldNames()[1:]] 70 | 71 | # _from and _to need to be set with nullable=False in the schema in order for it to work 72 | directed_df = spark.createDataFrame(directed_df.to_spark().rdd, StructType( 73 | directed_schema.fields[1:])).to_pandas_on_spark() 74 | acted_in_df = spark.createDataFrame(acted_in_df.to_spark().rdd, StructType( 75 | acts_in_schema.fields[1:])).to_pandas_on_spark() 76 | 77 | print("writing the persons collection") 78 | save_df(persons_df, "persons", save_opts) 79 | print("writing the movies collection") 80 | save_df(movies_df, "movies", save_opts) 81 | print("writing the 'directed' edge collection") 82 | save_df(directed_df, "directed", save_opts, "edge") 83 | print("writing the 'actedIn' collection") 84 | save_df(acted_in_df, "actedIn", save_opts, "edge") 85 | -------------------------------------------------------------------------------- /arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource.writer 2 | 3 | import com.arangodb.entity.CollectionType 4 | import com.arangodb.model.OverwriteMode 5 | import org.apache.spark.internal.Logging 6 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ContentType} 7 | import org.apache.spark.sql.connector.write.{BatchWrite, SupportsTruncate, WriteBuilder} 8 | import org.apache.spark.sql.types.{DecimalType, StringType, StructType} 9 | import org.apache.spark.sql.{AnalysisException, SaveMode} 10 | 11 | class ArangoWriterBuilder(schema: StructType, options: ArangoDBConf) 12 | extends WriteBuilder with SupportsTruncate with Logging { 13 | 14 | private var mode: SaveMode = SaveMode.Append 15 | validateConfig() 16 | 17 | override def buildForBatch(): BatchWrite = { 18 | val client = ArangoClient(options) 19 | if (!client.collectionExists()) { 20 | client.createCollection() 21 | } 22 | client.shutdown() 23 | 24 | val updatedOptions = options.updated(ArangoDBConf.OVERWRITE_MODE, mode match { 25 | case SaveMode.Append => options.writeOptions.overwriteMode.getValue 26 | case _ => OverwriteMode.ignore.getValue 27 | }) 28 | 29 | logSummary(updatedOptions) 30 | new ArangoBatchWriter(schema, updatedOptions, mode) 31 | } 32 | 33 | override def truncate(): WriteBuilder = { 34 | mode = SaveMode.Overwrite 35 | if (options.writeOptions.confirmTruncate) { 36 | val client = ArangoClient(options) 37 | if (client.collectionExists()) { 38 | client.truncate() 39 | } else { 40 | client.createCollection() 41 | } 42 | client.shutdown() 43 | this 44 | } else { 45 | throw new AnalysisException( 46 | "You are attempting to use overwrite mode which will truncate this collection prior to inserting data. If " + 47 | "you just want to change data already in the collection set save mode 'append' and " + 48 | s"'overwrite.mode=(replace|update)'. 
To actually truncate set '${ArangoDBConf.CONFIRM_TRUNCATE}=true'.") 49 | } 50 | } 51 | 52 | private def validateConfig(): Unit = { 53 | if (options.driverOptions.contentType == ContentType.JSON && hasDecimalTypeFields) { 54 | throw new UnsupportedOperationException("Cannot write DecimalType when using contentType=json") 55 | } 56 | 57 | if (options.writeOptions.collectionType == CollectionType.EDGES && 58 | !schema.exists(p => p.name == "_from" && p.dataType == StringType && !p.nullable) 59 | ) { 60 | throw new IllegalArgumentException("Writing edge collection requires non nullable string field named _from.") 61 | } 62 | 63 | if (options.writeOptions.collectionType == CollectionType.EDGES && 64 | !schema.exists(p => p.name == "_to" && p.dataType == StringType && !p.nullable) 65 | ) { 66 | throw new IllegalArgumentException("Writing edge collection requires non nullable string field named _to.") 67 | } 68 | } 69 | 70 | private def hasDecimalTypeFields: Boolean = 71 | schema.existsRecursively { 72 | case _: DecimalType => true 73 | case _ => false 74 | } 75 | 76 | private def logSummary(updatedOptions: ArangoDBConf): Unit = { 77 | val canRetry = ArangoDataWriter.canRetry(schema, updatedOptions) 78 | 79 | logInfo(s"Using save mode: $mode") 80 | logInfo(s"Using write configuration: ${updatedOptions.writeOptions}") 81 | logInfo(s"Using mapping configuration: ${updatedOptions.mappingOptions}") 82 | logInfo(s"Can retry: $canRetry") 83 | 84 | if (!canRetry) { 85 | logWarning( 86 | """The provided configuration does not allow idempotent requests: write failures will not be retried and lead 87 | |to task failure. Speculative task executions could fail or write incorrect data.""" 88 | .stripMargin.replaceAll("\n", "") 89 | ) 90 | } 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/CompositeFilterTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import org.apache.spark.sql.DataFrame 4 | import org.apache.spark.sql.functions.{col, not} 5 | import org.apache.spark.sql.types._ 6 | import org.assertj.core.api.Assertions.assertThat 7 | import org.junit.jupiter.api.{AfterAll, BeforeAll, Test} 8 | 9 | import scala.collection.JavaConverters._ 10 | 11 | class CompositeFilterTest extends BaseSparkTest { 12 | private val df = CompositeFilterTest.df 13 | 14 | @Test 15 | def orFilter(): Unit = { 16 | val fieldName = "integer" 17 | val value = CompositeFilterTest.data.head(fieldName) 18 | val res = df.filter(col(fieldName).equalTo(0) or col(fieldName).equalTo(1)).collect() 19 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames)) 20 | assertThat(res).hasSize(1) 21 | assertThat(res.head(fieldName)).isEqualTo(value) 22 | val sqlRes = spark.sql( 23 | s""" 24 | |SELECT * FROM compositeFilter 25 | |WHERE $fieldName = 0 OR $fieldName = 1 26 | |""".stripMargin).collect() 27 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames)) 28 | assertThat(sqlRes).hasSize(1) 29 | assertThat(sqlRes.head(fieldName)).isEqualTo(value) 30 | } 31 | 32 | @Test 33 | def notFilter(): Unit = { 34 | val fieldName = "integer" 35 | val value = CompositeFilterTest.data.head(fieldName) 36 | val res = df.filter(not(col(fieldName).equalTo(2))).collect() 37 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames)) 38 | assertThat(res).hasSize(1) 39 | 
assertThat(res.head(fieldName)).isEqualTo(value) 40 | val sqlRes = spark.sql( 41 | s""" 42 | |SELECT * FROM compositeFilter 43 | |WHERE NOT ($fieldName = 2) 44 | |""".stripMargin).collect() 45 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames)) 46 | assertThat(sqlRes).hasSize(1) 47 | assertThat(sqlRes.head(fieldName)).isEqualTo(value) 48 | } 49 | 50 | @Test 51 | def orAndFilter(): Unit = { 52 | val fieldName1 = "integer" 53 | val value1 = CompositeFilterTest.data.head(fieldName1) 54 | 55 | val fieldName2 = "string" 56 | val value2 = CompositeFilterTest.data.head(fieldName2) 57 | 58 | val res = df.filter(col("bool").equalTo(false) or (col(fieldName1).equalTo(value1) and col(fieldName2).equalTo(value2))).collect() 59 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames)) 60 | assertThat(res).hasSize(1) 61 | assertThat(res.head(fieldName1)).isEqualTo(value1) 62 | val sqlRes = spark.sql( 63 | s""" 64 | |SELECT * FROM compositeFilter 65 | |WHERE bool = false OR ($fieldName1 = $value1 AND $fieldName2 = "$value2") 66 | |""".stripMargin).collect() 67 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames)) 68 | assertThat(sqlRes).hasSize(1) 69 | assertThat(sqlRes.head(fieldName1)).isEqualTo(value1) 70 | } 71 | 72 | } 73 | 74 | object CompositeFilterTest { 75 | private var df: DataFrame = _ 76 | private val data: Seq[Map[String, Any]] = Seq( 77 | Map( 78 | "integer" -> 1, 79 | "string" -> "one", 80 | "bool" -> true 81 | ), 82 | Map( 83 | "integer" -> 2, 84 | "string" -> "two", 85 | "bool" -> true 86 | ) 87 | ) 88 | 89 | private val schema = StructType(Array( 90 | // atomic types 91 | StructField("integer", IntegerType, nullable = false), 92 | StructField("string", StringType, nullable = false), 93 | StructField("bool", BooleanType, nullable = false) 94 | )) 95 | 96 | @BeforeAll 97 | def init(): Unit = { 98 | df = BaseSparkTest.createDF("compositeFilter", data, schema) 99 | } 100 | 101 | @AfterAll 102 | def cleanup(): Unit = { 103 | BaseSparkTest.dropTable("compositeFilter") 104 | } 105 | } -------------------------------------------------------------------------------- /docker/start_db_macos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Configuration environment variables: 4 | # STARTER_MODE: (single|cluster|activefailover), default single 5 | # DOCKER_IMAGE: ArangoDB docker image, default docker.io/arangodb/enterprise:latest 6 | # SSL: (true|false), default false 7 | # DATABASE_EXTENDED_NAMES: (true|false), default false 8 | # ARANGO_LICENSE_KEY: only required for ArangoDB Enterprise 9 | 10 | # EXAMPLE: 11 | # STARTER_MODE=cluster SSL=true ./start_db.sh 12 | 13 | STARTER_MODE=${STARTER_MODE:=single} 14 | DOCKER_IMAGE=${DOCKER_IMAGE:=docker.io/arangodb/enterprise:latest} 15 | SSL=${SSL:=false} 16 | DATABASE_EXTENDED_NAMES=${DATABASE_EXTENDED_NAMES:=false} 17 | 18 | STARTER_DOCKER_IMAGE=docker.io/arangodb/arangodb-starter:latest 19 | GW=172.28.0.1 20 | LOCALGW=localhost 21 | docker network create arangodb --subnet 172.28.0.0/16 22 | 23 | # exit when any command fails 24 | set -e 25 | 26 | docker pull $STARTER_DOCKER_IMAGE 27 | docker pull $DOCKER_IMAGE 28 | 29 | LOCATION=$(pwd)/$(dirname "$0") 30 | 31 | echo "Averysecretword" > "$LOCATION"/jwtSecret 32 | docker run --rm -v "$LOCATION"/jwtSecret:/jwtSecret "$STARTER_DOCKER_IMAGE" auth header --auth.jwt-secret /jwtSecret > "$LOCATION"/jwtHeader 33 | AUTHORIZATION_HEADER=$(cat "$LOCATION"/jwtHeader) 34 | 35 | STARTER_ARGS= 36 | SCHEME=http 37 | 
ARANGOSH_SCHEME=http+tcp 38 | COORDINATORS=("$LOCALGW:8529" "$LOCALGW:8539" "$LOCALGW:8549") 39 | COORDINATORSINTERNAL=("$GW:8529" "$GW:8539" "$GW:8549") 40 | 41 | if [ "$STARTER_MODE" == "single" ]; then 42 | COORDINATORS=("$LOCALGW:8529") 43 | COORDINATORSINTERNAL=("$GW:8529") 44 | fi 45 | 46 | if [ "$SSL" == "true" ]; then 47 | STARTER_ARGS="$STARTER_ARGS --ssl.keyfile=server.pem" 48 | SCHEME=https 49 | ARANGOSH_SCHEME=http+ssl 50 | fi 51 | 52 | if [ "$DATABASE_EXTENDED_NAMES" == "true" ]; then 53 | STARTER_ARGS="${STARTER_ARGS} --all.database.extended-names-databases=true" 54 | fi 55 | 56 | if [ "$USE_MOUNTED_DATA" == "true" ]; then 57 | STARTER_ARGS="${STARTER_ARGS} --starter.data-dir=/data" 58 | MOUNT_DATA="-v $LOCATION/data:/data" 59 | fi 60 | 61 | docker run -d \ 62 | --name=adb \ 63 | -p 8528:8528 \ 64 | -v "$LOCATION"/server.pem:/server.pem \ 65 | -v "$LOCATION"/jwtSecret:/jwtSecret \ 66 | $MOUNT_DATA \ 67 | -v /var/run/docker.sock:/var/run/docker.sock \ 68 | -e ARANGO_LICENSE_KEY="$ARANGO_LICENSE_KEY" \ 69 | $STARTER_DOCKER_IMAGE \ 70 | $STARTER_ARGS \ 71 | --docker.net-mode=default \ 72 | --docker.container=adb \ 73 | --auth.jwt-secret=/jwtSecret \ 74 | --starter.address="${GW}" \ 75 | --docker.image="${DOCKER_IMAGE}" \ 76 | --starter.local --starter.mode=${STARTER_MODE} --all.log.level=debug --all.log.output=+ --log.verbose 77 | 78 | 79 | wait_server() { 80 | # shellcheck disable=SC2091 81 | until $(curl --output /dev/null --insecure --fail --silent --head -i -H "$AUTHORIZATION_HEADER" "$SCHEME://$1/_api/version"); do 82 | printf '.' 83 | sleep 1 84 | done 85 | } 86 | 87 | echo "Waiting..." 88 | 89 | for a in ${COORDINATORS[*]} ; do 90 | wait_server "$a" 91 | done 92 | 93 | set +e 94 | ITER=0 95 | for a in ${COORDINATORS[*]} ; do 96 | echo "" 97 | echo "Setting username and password..." 98 | docker run --rm ${DOCKER_IMAGE} arangosh --server.endpoint="$ARANGOSH_SCHEME://${COORDINATORSINTERNAL[ITER]}" --server.authentication=false --javascript.execute-string='require("org/arangodb/users").update("root", "test")' 99 | ITER=$(expr $ITER + 1) 100 | done 101 | set -e 102 | 103 | for a in ${COORDINATORS[*]} ; do 104 | echo "" 105 | echo "Requesting endpoint version..." 
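# Query /_api/version on this coordinator with the root:test credentials set
# above; --insecure skips TLS certificate verification (needed when SSL=true)
# and --fail makes curl return non-zero, so `set -e` aborts the script if the
# endpoint is not reachable.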
106 | curl -u root:test --insecure --fail "$SCHEME://$a/_api/version" 107 | done 108 | 109 | echo "" 110 | echo "" 111 | echo "Done, your deployment is reachable at: " 112 | for a in ${COORDINATORS[*]} ; do 113 | echo "$SCHEME://$a" 114 | echo "" 115 | done 116 | 117 | if [ "$STARTER_MODE" == "activefailover" ]; then 118 | LEADER=$("$LOCATION"/find_active_endpoint.sh) 119 | echo "Leader: $SCHEME://$LEADER" 120 | echo "" 121 | fi 122 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import org.apache.spark.SPARK_VERSION 4 | import org.apache.spark.sql.DataFrame 5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 6 | import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StringType, StructField, StructType} 7 | import org.assertj.core.api.Assertions.assertThat 8 | import org.junit.jupiter.api.Assumptions.assumeTrue 9 | import org.junit.jupiter.params.ParameterizedTest 10 | import org.junit.jupiter.params.provider.ValueSource 11 | 12 | class DeserializationCastTest extends BaseSparkTest { 13 | private val collectionName = "deserializationCast" 14 | 15 | @ParameterizedTest 16 | @ValueSource(strings = Array("vpack", "json")) 17 | def numberIntToStringCast(contentType: String): Unit = doTestImplicitCast( 18 | StructType(Array(StructField("a", StringType))), 19 | Seq(Map("a" -> 1)), 20 | Seq("""{"a":1}"""), 21 | contentType 22 | ) 23 | 24 | @ParameterizedTest 25 | @ValueSource(strings = Array("vpack", "json")) 26 | def numberDecToStringCast(contentType: String): Unit = doTestImplicitCast( 27 | StructType(Array(StructField("a", StringType))), 28 | Seq(Map("a" -> 1.1)), 29 | Seq("""{"a":1.1}"""), 30 | contentType 31 | ) 32 | 33 | @ParameterizedTest 34 | @ValueSource(strings = Array("vpack", "json")) 35 | def boolToStringCast(contentType: String): Unit = doTestImplicitCast( 36 | StructType(Array(StructField("a", StringType))), 37 | Seq(Map("a" -> true)), 38 | Seq("""{"a":true}"""), 39 | contentType 40 | ) 41 | 42 | @ParameterizedTest 43 | @ValueSource(strings = Array("vpack", "json")) 44 | def objectToStringCast(contentType: String): Unit = doTestImplicitCast( 45 | StructType(Array(StructField("a", StringType))), 46 | Seq(Map("a" -> Map("b" -> "c"))), 47 | Seq("""{"a":{"b":"c"}}"""), 48 | contentType 49 | ) 50 | 51 | @ParameterizedTest 52 | @ValueSource(strings = Array("vpack", "json")) 53 | def arrayToStringCast(contentType: String): Unit = doTestImplicitCast( 54 | StructType(Array(StructField("a", StringType))), 55 | Seq(Map("a" -> Array(1, 2))), 56 | Seq("""{"a":[1,2]}"""), 57 | contentType 58 | ) 59 | 60 | @ParameterizedTest 61 | @ValueSource(strings = Array("vpack", "json")) 62 | def nullToIntegerCast(contentType: String): Unit = { 63 | doTestImplicitCast( 64 | StructType(Array(StructField("a", IntegerType, nullable = false))), 65 | Seq(Map("a" -> null)), 66 | Seq("""{"a":0}"""), 67 | contentType 68 | ) 69 | } 70 | 71 | @ParameterizedTest 72 | @ValueSource(strings = Array("vpack", "json")) 73 | def nullToDoubleCast(contentType: String): Unit = { 74 | doTestImplicitCast( 75 | StructType(Array(StructField("a", DoubleType, nullable = false))), 76 | Seq(Map("a" -> null)), 77 | Seq("""{"a":0.0}"""), 78 | contentType 79 | ) 80 | } 81 | 82 | @ParameterizedTest 83 | @ValueSource(strings = 
Array("vpack", "json")) 84 | def nullAsBoolean(contentType: String): Unit = { 85 | doTestImplicitCast( 86 | StructType(Array(StructField("a", BooleanType, nullable = false))), 87 | Seq(Map("a" -> null)), 88 | Seq("""{"a":false}"""), 89 | contentType 90 | ) 91 | } 92 | 93 | private def doTestImplicitCast( 94 | schema: StructType, 95 | data: Iterable[Map[String, Any]], 96 | jsonData: Seq[String], 97 | contentType: String 98 | ) = { 99 | 100 | /** 101 | * FIXME: many vpack tests fail 102 | */ 103 | assumeTrue(contentType != "vpack") 104 | 105 | import spark.implicits._ 106 | val dfFromJson: DataFrame = spark.read.schema(schema).json(jsonData.toDS) 107 | dfFromJson.show() 108 | val df = BaseSparkTest.createDF(collectionName, data, schema, Map(ArangoDBConf.CONTENT_TYPE -> contentType)) 109 | assertThat(df.collect()).isEqualTo(dfFromJson.collect()) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | // scalastyle:off 19 | 20 | package org.apache.spark.sql.arangodb.datasource.mapping.json 21 | 22 | import com.fasterxml.jackson.core.{JsonFactory, JsonParser} 23 | import org.apache.hadoop.io.Text 24 | import org.apache.spark.sql.catalyst.InternalRow 25 | import org.apache.spark.unsafe.types.UTF8String 26 | import sun.nio.cs.StreamDecoder 27 | 28 | import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} 29 | import java.nio.channels.Channels 30 | import java.nio.charset.{Charset, StandardCharsets} 31 | 32 | private[sql] object CreateJacksonParser extends Serializable { 33 | def string(jsonFactory: JsonFactory, record: String): JsonParser = { 34 | jsonFactory.createParser(record) 35 | } 36 | 37 | def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = { 38 | val bb = record.getByteBuffer 39 | assert(bb.hasArray) 40 | 41 | val bain = new ByteArrayInputStream( 42 | bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) 43 | 44 | jsonFactory.createParser(new InputStreamReader(bain, StandardCharsets.UTF_8)) 45 | } 46 | 47 | def text(jsonFactory: JsonFactory, record: Text): JsonParser = { 48 | jsonFactory.createParser(record.getBytes, 0, record.getLength) 49 | } 50 | 51 | // Jackson parsers can be ranked according to their performance: 52 | // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser 53 | // but it doesn't allow to set encoding explicitly. 
Actual encoding is detected automatically 54 | // by checking leading bytes of the array. 55 | // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected 56 | // automatically by analyzing first bytes of the input stream. 57 | // 3. Reader based parser. This is the slowest parser used here but it allows to create 58 | // a reader with specific encoding. 59 | // The method creates a reader for an array with given encoding and sets size of internal 60 | // decoding buffer according to size of input array. 61 | private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { 62 | val bais = new ByteArrayInputStream(in, 0, length) 63 | val byteChannel = Channels.newChannel(bais) 64 | val decodingBufferSize = Math.min(length, 8192) 65 | val decoder = Charset.forName(enc).newDecoder() 66 | 67 | StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) 68 | } 69 | 70 | def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { 71 | val sd = getStreamDecoder(enc, record.getBytes, record.getLength) 72 | jsonFactory.createParser(sd) 73 | } 74 | 75 | def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { 76 | jsonFactory.createParser(is) 77 | } 78 | 79 | def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { 80 | jsonFactory.createParser(new InputStreamReader(is, enc)) 81 | } 82 | 83 | def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { 84 | val ba = row.getBinary(0) 85 | 86 | jsonFactory.createParser(ba, 0, ba.length) 87 | } 88 | 89 | def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { 90 | val binary = row.getBinary(0) 91 | val sd = getStreamDecoder(enc, binary, binary.length) 92 | 93 | jsonFactory.createParser(sd) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | // scalastyle:off 19 | 20 | package org.apache.spark.sql.arangodb.datasource.mapping.json 21 | 22 | import com.fasterxml.jackson.core.{JsonFactory, JsonParser} 23 | import org.apache.hadoop.io.Text 24 | import org.apache.spark.sql.catalyst.InternalRow 25 | import org.apache.spark.unsafe.types.UTF8String 26 | import sun.nio.cs.StreamDecoder 27 | 28 | import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} 29 | import java.nio.channels.Channels 30 | import java.nio.charset.{Charset, StandardCharsets} 31 | 32 | private[sql] object CreateJacksonParser extends Serializable { 33 | def string(jsonFactory: JsonFactory, record: String): JsonParser = { 34 | jsonFactory.createParser(record) 35 | } 36 | 37 | def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = { 38 | val bb = record.getByteBuffer 39 | assert(bb.hasArray) 40 | 41 | val bain = new ByteArrayInputStream( 42 | bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) 43 | 44 | jsonFactory.createParser(new InputStreamReader(bain, StandardCharsets.UTF_8)) 45 | } 46 | 47 | def text(jsonFactory: JsonFactory, record: Text): JsonParser = { 48 | jsonFactory.createParser(record.getBytes, 0, record.getLength) 49 | } 50 | 51 | // Jackson parsers can be ranked according to their performance: 52 | // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser 53 | // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically 54 | // by checking leading bytes of the array. 55 | // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected 56 | // automatically by analyzing first bytes of the input stream. 57 | // 3. Reader based parser. This is the slowest parser used here but it allows to create 58 | // a reader with specific encoding. 59 | // The method creates a reader for an array with given encoding and sets size of internal 60 | // decoding buffer according to size of input array. 
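  // Note: the internal decoding buffer is sized to the input but capped at
  // 8192 bytes (Math.min(length, 8192)), so larger payloads are decoded
  // incrementally through the CharsetDecoder.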
61 | private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { 62 | val bais = new ByteArrayInputStream(in, 0, length) 63 | val byteChannel = Channels.newChannel(bais) 64 | val decodingBufferSize = Math.min(length, 8192) 65 | val decoder = Charset.forName(enc).newDecoder() 66 | 67 | StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) 68 | } 69 | 70 | def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { 71 | val sd = getStreamDecoder(enc, record.getBytes, record.getLength) 72 | jsonFactory.createParser(sd) 73 | } 74 | 75 | def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { 76 | jsonFactory.createParser(is) 77 | } 78 | 79 | def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { 80 | jsonFactory.createParser(new InputStreamReader(is, enc)) 81 | } 82 | 83 | def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { 84 | val ba = row.getBinary(0) 85 | 86 | jsonFactory.createParser(ba, 0, ba.length) 87 | } 88 | 89 | def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { 90 | val binary = row.getBinary(0) 91 | val sd = getStreamDecoder(enc, binary, binary.length) 92 | 93 | jsonFactory.createParser(sd) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /python-integration-tests/integration/test_deserialization_cast.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | 3 | import arango.database 4 | import pytest 5 | from pyspark.sql import SparkSession 6 | from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, BooleanType 7 | 8 | from integration import test_basespark 9 | 10 | COLLECTION_NAME = "deserializationCast" 11 | content_types = ["vpack", "json"] 12 | 13 | 14 | def check_implicit_cast(db: arango.database.StandardDatabase, spark: SparkSession, schema: StructType, data: List[Dict[str, Any]], json_data: List[str], content_type: str): 15 | # FIXME: many vpack tests are failing 16 | if content_type == "vpack": 17 | pytest.xfail("Too many vpack tests fail") 18 | 19 | df_from_json = spark.read.schema(schema).json(spark.sparkContext.parallelize(json_data)) 20 | df_from_json.show() 21 | 22 | df = test_basespark.create_df(db, spark, COLLECTION_NAME, data, schema, {"contentType": content_type}) 23 | assert df.collect() == df_from_json.collect() 24 | 25 | 26 | @pytest.mark.parametrize("content_type", content_types) 27 | def test_number_int_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 28 | check_implicit_cast( 29 | database_conn, 30 | spark, 31 | StructType([StructField("a", StringType())]), 32 | [{"a": 1}], 33 | ['{"a":1}'], 34 | content_type 35 | ) 36 | 37 | 38 | @pytest.mark.parametrize("content_type", content_types) 39 | def test_number_dec_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 40 | check_implicit_cast( 41 | database_conn, 42 | spark, 43 | StructType([StructField("a", StringType())]), 44 | [{"a": 1.1}], 45 | ['{"a":1.1}'], 46 | content_type 47 | ) 48 | 49 | 50 | @pytest.mark.parametrize("content_type", content_types) 51 | def test_bool_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 52 | check_implicit_cast( 53 | database_conn, 54 | spark, 55 | StructType([StructField("a", 
StringType())]), 56 | [{"a": True}], 57 | ['{"a":true}'], 58 | content_type 59 | ) 60 | 61 | 62 | @pytest.mark.parametrize("content_type", content_types) 63 | def test_object_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 64 | check_implicit_cast( 65 | database_conn, 66 | spark, 67 | StructType([StructField("a", StringType())]), 68 | [{"a": {"b": "c"}}], 69 | ['{"a":{"b":"c"}}'], 70 | content_type 71 | ) 72 | 73 | 74 | @pytest.mark.parametrize("content_type", content_types) 75 | def test_array_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 76 | check_implicit_cast( 77 | database_conn, 78 | spark, 79 | StructType([StructField("a", StringType())]), 80 | [{"a": [1, 2]}], 81 | ['{"a":[1,2]}'], 82 | content_type 83 | ) 84 | 85 | 86 | @pytest.mark.parametrize("content_type", content_types) 87 | def test_null_to_integer_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 88 | check_implicit_cast( 89 | database_conn, 90 | spark, 91 | StructType([StructField("a", IntegerType())]), 92 | [{"a": None}], 93 | ['{"a":null}'], 94 | content_type 95 | ) 96 | 97 | 98 | @pytest.mark.parametrize("content_type", content_types) 99 | def test_null_to_double_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 100 | check_implicit_cast( 101 | database_conn, 102 | spark, 103 | StructType([StructField("a", DoubleType())]), 104 | [{"a": None}], 105 | ['{"a":null}'], 106 | content_type 107 | ) 108 | 109 | 110 | @pytest.mark.parametrize("content_type", content_types) 111 | def test_null_to_boolean_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 112 | check_implicit_cast( 113 | database_conn, 114 | spark, 115 | StructType([StructField("a", BooleanType())]), 116 | [{"a": None}], 117 | ['{"a":null}'], 118 | content_type 119 | ) 120 | -------------------------------------------------------------------------------- /python-integration-tests/integration/test_readwrite_datatype.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import sys 3 | from datetime import datetime, date 4 | from decimal import Context 5 | 6 | import arango.database 7 | import pyspark.sql 8 | import pytest 9 | from py4j.protocol import Py4JJavaError 10 | from pyspark.sql import SparkSession 11 | from pyspark.sql.types import StructType, StructField, BooleanType, DoubleType, FloatType, IntegerType, LongType, \ 12 | DateType, TimestampType, ShortType, ByteType, StringType, ArrayType, Row, MapType, DecimalType 13 | 14 | from integration import test_basespark 15 | from integration.utils import combine_dicts 16 | 17 | COLLECTION_NAME = "datatypes" 18 | 19 | data = [ 20 | [ 21 | False, 22 | 1.1, 23 | 0.09375, 24 | 1, 25 | 1, 26 | date.fromisoformat("2021-01-01"), 27 | datetime.fromisoformat("2021-01-01 01:01:01.111").astimezone(), 28 | 1, 29 | 1, 30 | "one", 31 | [1, 1, 1], 32 | [["a", "b", "c"], ["d", "e", "f"]], 33 | {"a": 1, "b": 1}, 34 | Row("a1", 1) 35 | ], 36 | [ 37 | True, 38 | 2.2, 39 | 2.2, 40 | 2, 41 | 2, 42 | date.fromisoformat("2022-02-02"), 43 | datetime.fromisoformat("2022-02-02 02:02:02.222").astimezone(), 44 | 2, 45 | 2, 46 | "two", 47 | [2, 2, 2], 48 | [["a", "b", "c"], ["d", "e", "f"]], 49 | {"a": 2, "b": 2}, 50 | Row("a1", 2) 51 | ] 52 | ] 53 | 54 | struct_fields = [ 55 | StructField("bool", BooleanType(), nullable=False), 56 | 
StructField("double", DoubleType(), nullable=False), 57 | StructField("float", FloatType(), nullable=False), 58 | StructField("integer", IntegerType(), nullable=False), 59 | StructField("long", LongType(), nullable=False), 60 | StructField("date", DateType(), nullable=False), 61 | StructField("timestamp", TimestampType(), nullable=False), 62 | StructField("short", ShortType(), nullable=False), 63 | StructField("byte", ByteType(), nullable=False), 64 | StructField("string", StringType(), nullable=False), 65 | StructField("intArray", ArrayType(IntegerType()), nullable=False), 66 | StructField("stringArrayArray", ArrayType(ArrayType(StringType())), nullable=False), 67 | StructField("intMap", MapType(StringType(), IntegerType()), nullable=False), 68 | StructField("struct", StructType([ 69 | StructField("a", StringType()), 70 | StructField("b", IntegerType()) 71 | ])) 72 | ] 73 | 74 | schema = StructType(struct_fields) 75 | 76 | 77 | @pytest.mark.parametrize("protocol,content_type", test_basespark.protocol_and_content_type) 78 | def test_round_trip_read_write(spark: SparkSession, protocol: str, content_type: str): 79 | df = spark.createDataFrame(spark.sparkContext.parallelize([Row(*x) for x in data]), schema) 80 | round_trip_readwrite(spark, df, protocol, content_type) 81 | 82 | 83 | @pytest.mark.parametrize("protocol,content_type", test_basespark.protocol_and_content_type) 84 | def test_round_trip_read_write_decimal_type(spark: SparkSession, protocol: str, content_type: str): 85 | if content_type != "vpack": 86 | pytest.xfail("vpack does not support round trip decimal types") 87 | 88 | schema_with_decimal = StructType(copy.deepcopy(struct_fields)) 89 | schema_with_decimal.add(StructField("decimal", DecimalType(38, 18), nullable=False)) 90 | df = spark.createDataFrame(spark.sparkContext.parallelize([Row(*x, Context(prec=38).create_decimal("2.22222222")) for x in data]), schema_with_decimal) 91 | round_trip_readwrite(spark, df, protocol, content_type) 92 | 93 | 94 | def write_df(df: pyspark.sql.DataFrame, protocol: str, content_type: str): 95 | all_opts = combine_dicts([ 96 | test_basespark.options, 97 | { 98 | "table": COLLECTION_NAME, 99 | "protocol": protocol, 100 | "contentType": content_type, 101 | "overwriteMode": "replace", 102 | "confirmTruncate": "true" 103 | } 104 | ]) 105 | df.write\ 106 | .format(test_basespark.arango_datasource_name)\ 107 | .mode("overwrite")\ 108 | .options(**all_opts)\ 109 | .save() 110 | 111 | 112 | def round_trip_readwrite(spark: SparkSession, df: pyspark.sql.DataFrame, protocol: str, content_type: str): 113 | initial = df.collect() 114 | 115 | all_opts = combine_dicts([test_basespark.options, {"table": COLLECTION_NAME}]) 116 | 117 | write_df(df, protocol, content_type) 118 | read = spark.read.format(test_basespark.arango_datasource_name)\ 119 | .options(**all_opts)\ 120 | .schema(df.schema)\ 121 | .load()\ 122 | .collect() 123 | 124 | assert initial.sort() == read.sort() 125 | -------------------------------------------------------------------------------- /python-integration-tests/integration/write/test_savemode.py: -------------------------------------------------------------------------------- 1 | import arango 2 | import arango.database 3 | import arango.collection 4 | import pytest 5 | import pyspark.sql 6 | from py4j.protocol import Py4JJavaError 7 | from pyspark.sql import SparkSession 8 | from pyspark.sql.utils import AnalysisException 9 | 10 | from integration.test_basespark import protocol_and_content_type, options, arango_datasource_name 11 | from 
integration.utils import combine_dicts 12 | 13 | 14 | COLLECTION_NAME = "chessPlayersSaveMode" 15 | 16 | data = [ 17 | ("Carlsen", "Magnus"), 18 | ("Caruana", "Fabiano"), 19 | ("Ding", "Liren"), 20 | ("Nepomniachtchi", "Ian"), 21 | ("Aronian", "Levon"), 22 | ("Grischuk", "Alexander"), 23 | ("Giri", "Anish"), 24 | ("Mamedyarov", "Shakhriyar"), 25 | ("So", "Wesley"), 26 | ("Radjabov", "Teimour") 27 | ] 28 | 29 | 30 | @pytest.fixture(scope="function") 31 | def chess_collection(database_conn: arango.database.StandardDatabase) -> arango.collection.StandardCollection: 32 | if database_conn.has_collection(COLLECTION_NAME): 33 | database_conn.delete_collection(COLLECTION_NAME) 34 | yield database_conn.collection(COLLECTION_NAME) 35 | 36 | 37 | @pytest.fixture 38 | def chess_df(spark: SparkSession) -> pyspark.sql.DataFrame: 39 | df = spark.createDataFrame(data, schema=["surname", "name"]) 40 | return df 41 | 42 | 43 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type) 44 | def test_savemode_append(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, protocol: str, content_type: str): 45 | all_opts = combine_dicts([options, { 46 | "table": COLLECTION_NAME, 47 | "protocol": protocol, 48 | "contentType": content_type 49 | }]) 50 | 51 | chess_df.write\ 52 | .format(arango_datasource_name)\ 53 | .mode("Append")\ 54 | .options(**all_opts)\ 55 | .save() 56 | 57 | assert chess_collection.count() == 10 58 | 59 | 60 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type) 61 | def test_savemode_append_with_existing_collection(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, database_conn: arango.database.StandardDatabase, protocol: str, content_type: str): 62 | database_conn.create_collection(COLLECTION_NAME) 63 | chess_collection.insert({}) 64 | 65 | all_opts = combine_dicts([options, { 66 | "table": COLLECTION_NAME, 67 | "protocol": protocol, 68 | "contentType": content_type 69 | }]) 70 | 71 | chess_df.write \ 72 | .format(arango_datasource_name) \ 73 | .mode("Append") \ 74 | .options(**all_opts) \ 75 | .save() 76 | 77 | assert chess_collection.count() == 11 78 | 79 | 80 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type) 81 | def test_savemode_overwrite_should_throw_whenused_alone(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, protocol: str, content_type: str): 82 | all_opts = combine_dicts([options, { 83 | "table": COLLECTION_NAME, 84 | "protocol": protocol, 85 | "contentType": content_type 86 | }]) 87 | 88 | with pytest.raises(AnalysisException) as e: 89 | chess_df.write \ 90 | .format(arango_datasource_name) \ 91 | .mode("Overwrite") \ 92 | .options(**all_opts) \ 93 | .save() 94 | 95 | e.match("confirmTruncate") 96 | 97 | 98 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type) 99 | def test_savemode_overwrite(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, protocol: str, content_type: str): 100 | all_opts = combine_dicts([options, { 101 | "table": COLLECTION_NAME, 102 | "protocol": protocol, 103 | "contentType": content_type, 104 | "confirmTruncate": "true" 105 | }]) 106 | 107 | chess_df.write \ 108 | .format(arango_datasource_name) \ 109 | .mode("Overwrite") \ 110 | .options(**all_opts) \ 111 | .save() 112 | 113 | assert chess_collection.count() == 10 114 | 115 | 116 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type) 
117 | def test_savemode_overwrite_with_existing_collection(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, database_conn: arango.database.StandardDatabase, protocol: str, content_type: str): 118 | database_conn.create_collection(COLLECTION_NAME) 119 | chess_collection.insert({}) 120 | 121 | all_opts = combine_dicts([options, { 122 | "table": COLLECTION_NAME, 123 | "protocol": protocol, 124 | "contentType": content_type, 125 | "confirmTruncate": "true" 126 | }]) 127 | 128 | chess_df.write \ 129 | .format(arango_datasource_name) \ 130 | .mode("Overwrite") \ 131 | .options(**all_opts) \ 132 | .save() 133 | 134 | assert chess_collection.count() == 10 135 | -------------------------------------------------------------------------------- /integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BadRecordsTest.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.arangodb.datasource 2 | 3 | import org.apache.spark.{SPARK_VERSION, SparkException} 4 | import org.apache.spark.sql.DataFrame 5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf 6 | import org.apache.spark.sql.catalyst.util.{BadRecordException, DropMalformedMode, FailFastMode, ParseMode} 7 | import org.apache.spark.sql.types._ 8 | import org.assertj.core.api.Assertions.{assertThat, catchThrowable} 9 | import org.assertj.core.api.ThrowableAssert.ThrowingCallable 10 | import org.junit.jupiter.params.ParameterizedTest 11 | import org.junit.jupiter.params.provider.ValueSource 12 | 13 | class BadRecordsTest extends BaseSparkTest { 14 | private val collectionName = "deserializationCast" 15 | 16 | @ParameterizedTest 17 | @ValueSource(strings = Array("vpack", "json")) 18 | def stringAsInteger(contentType: String): Unit = testBadRecord( 19 | StructType(Array(StructField("a", IntegerType))), 20 | Seq(Map("a" -> "1")), 21 | Seq("""{"a":"1"}"""), 22 | contentType 23 | ) 24 | 25 | @ParameterizedTest 26 | @ValueSource(strings = Array("vpack", "json")) 27 | def booleanAsInteger(contentType: String): Unit = testBadRecord( 28 | StructType(Array(StructField("a", IntegerType))), 29 | Seq(Map("a" -> true)), 30 | Seq("""{"a":true}"""), 31 | contentType 32 | ) 33 | 34 | @ParameterizedTest 35 | @ValueSource(strings = Array("vpack", "json")) 36 | def stringAsDouble(contentType: String): Unit = testBadRecord( 37 | StructType(Array(StructField("a", DoubleType))), 38 | Seq(Map("a" -> "1")), 39 | Seq("""{"a":"1"}"""), 40 | contentType 41 | ) 42 | 43 | @ParameterizedTest 44 | @ValueSource(strings = Array("vpack", "json")) 45 | def booleanAsDouble(contentType: String): Unit = testBadRecord( 46 | StructType(Array(StructField("a", DoubleType))), 47 | Seq(Map("a" -> true)), 48 | Seq("""{"a":true}"""), 49 | contentType 50 | ) 51 | 52 | @ParameterizedTest 53 | @ValueSource(strings = Array("vpack", "json")) 54 | def stringAsBoolean(contentType: String): Unit = testBadRecord( 55 | StructType(Array(StructField("a", BooleanType))), 56 | Seq(Map("a" -> "true")), 57 | Seq("""{"a":"true"}"""), 58 | contentType 59 | ) 60 | 61 | @ParameterizedTest 62 | @ValueSource(strings = Array("vpack", "json")) 63 | def numberAsBoolean(contentType: String): Unit = testBadRecord( 64 | StructType(Array(StructField("a", BooleanType))), 65 | Seq(Map("a" -> 1)), 66 | Seq("""{"a":1}"""), 67 | contentType 68 | ) 69 | 70 | private def testBadRecord( 71 | schema: StructType, 72 | data: Iterable[Map[String, Any]], 73 | jsonData: Seq[String], 74 | contentType: String 75 | 
) = { 76 | // PERMISSIVE 77 | doTestBadRecord(schema, data, jsonData, Map(ArangoDBConf.CONTENT_TYPE -> contentType)) 78 | 79 | // PERMISSIVE with columnNameOfCorruptRecord 80 | doTestBadRecord( 81 | schema.add(StructField("corruptRecord", StringType)), 82 | data, 83 | jsonData, 84 | Map( 85 | ArangoDBConf.CONTENT_TYPE -> contentType, 86 | ArangoDBConf.COLUMN_NAME_OF_CORRUPT_RECORD -> "corruptRecord" 87 | ) 88 | ) 89 | 90 | // DROPMALFORMED 91 | doTestBadRecord(schema, data, jsonData, 92 | Map( 93 | ArangoDBConf.CONTENT_TYPE -> contentType, 94 | ArangoDBConf.PARSE_MODE -> DropMalformedMode.name 95 | ) 96 | ) 97 | 98 | // FAILFAST 99 | val df = BaseSparkTest.createDF(collectionName, data, schema, Map( 100 | ArangoDBConf.CONTENT_TYPE -> contentType, 101 | ArangoDBConf.PARSE_MODE -> FailFastMode.name 102 | )) 103 | val thrown = catchThrowable(new ThrowingCallable() { 104 | override def call(): Unit = df.collect() 105 | }) 106 | 107 | assertThat(thrown.getCause).isInstanceOf(classOf[SparkException]) 108 | assertThat(thrown.getCause).hasMessageContaining("Malformed record") 109 | assertThat(thrown.getCause).hasCauseInstanceOf(classOf[BadRecordException]) 110 | } 111 | 112 | private def doTestBadRecord( 113 | schema: StructType, 114 | data: Iterable[Map[String, Any]], 115 | jsonData: Seq[String], 116 | opts: Map[String, String] = Map.empty 117 | ) = { 118 | import spark.implicits._ 119 | val dfFromJson: DataFrame = spark.read.schema(schema).options(opts).json(jsonData.toDS) 120 | dfFromJson.show() 121 | 122 | val tableDF = BaseSparkTest.createDF(collectionName, data, schema, opts) 123 | assertThat(tableDF.collect()).isEqualTo(dfFromJson.collect()) 124 | 125 | val queryDF = BaseSparkTest.createQueryDF(s"RETURN ${jsonData.head}", schema, opts) 126 | assertThat(queryDF.collect()).isEqualTo(dfFromJson.collect()) 127 | } 128 | 129 | } 130 | -------------------------------------------------------------------------------- /python-integration-tests/integration/test_bad_records.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import arango.database 4 | import pytest 5 | from py4j.protocol import Py4JJavaError 6 | from pyspark.sql import SparkSession 7 | from pyspark.sql.types import StructType, StringType, StructField, IntegerType, DoubleType, BooleanType 8 | 9 | from integration import test_basespark 10 | from integration.test_basespark import options, arango_datasource_name 11 | from integration.utils import combine_dicts 12 | 13 | 14 | content_types = ["vpack", "json"] 15 | COLLECTION_NAME = "deserializationCast" 16 | 17 | 18 | def do_test_bad_record(db: arango.database.StandardDatabase, spark: SparkSession, schema: StructType, data: List[Dict[str, Any]], json_data: List[str], opts: Dict[str, str]): 19 | df_from_json = spark.read.schema(schema).options(**opts).json(spark.sparkContext.parallelize(json_data)) 20 | df_from_json.show() 21 | 22 | table_df = test_basespark.create_df(db, spark, COLLECTION_NAME, data, schema, opts) 23 | assert table_df.collect() == df_from_json.collect() 24 | 25 | query_df = test_basespark.create_query_df(spark, f"RETURN {json_data[0]}", schema, opts) 26 | assert query_df.collect() == df_from_json.collect() 27 | 28 | 29 | def check_bad_record(db: arango.database.StandardDatabase, spark: SparkSession, schema: StructType, data: List[Dict[str, Any]], json_data: List[str], content_type: str): 30 | # Permissive 31 | do_test_bad_record(db, spark, schema, data, json_data, {"contentType": content_type}) 
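    # The remaining variants below check that the connector matches Spark's own
    # JSON reader behaviour: PERMISSIVE with columnNameOfCorruptRecord keeps the
    # raw record in that extra column, DROPMALFORMED drops bad rows, and
    # FAILFAST surfaces a SparkException wrapping a BadRecordException.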
32 | 33 | # Permissive with column name of corrupt record 34 | do_test_bad_record( 35 | db, 36 | spark, 37 | schema.add(StructField("corruptRecord", StringType())), 38 | data, 39 | json_data, 40 | { 41 | "contentType": content_type, 42 | "columnNameOfCorruptRecord": "corruptRecord" 43 | } 44 | ) 45 | 46 | # Dropmalformed 47 | do_test_bad_record(db, spark, schema, data, json_data, 48 | { 49 | "contentType": content_type, 50 | "mode": "DROPMALFORMED" 51 | }) 52 | 53 | # Failfast 54 | df = test_basespark.create_df(db, spark, COLLECTION_NAME, data, schema, 55 | { 56 | "contentType": content_type, 57 | "mode": "FAILFAST" 58 | }) 59 | with pytest.raises(Py4JJavaError) as e: 60 | df.collect() 61 | 62 | e.match("SparkException") 63 | e.match("Malformed record") 64 | e.match("BadRecordException") 65 | 66 | 67 | @pytest.mark.parametrize("content_type", content_types) 68 | def test_string_as_integer(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 69 | check_bad_record( 70 | database_conn, 71 | spark, 72 | StructType([StructField("a", IntegerType())]), 73 | [{"a": "1"}], 74 | ['{"a":"1"}'], 75 | content_type 76 | ) 77 | 78 | 79 | @pytest.mark.parametrize("content_type", content_types) 80 | def test_boolean_as_integer(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 81 | check_bad_record( 82 | database_conn, 83 | spark, 84 | StructType([StructField("a", IntegerType())]), 85 | [{"a": True}], 86 | ['{"a":true}'], 87 | content_type 88 | ) 89 | 90 | 91 | @pytest.mark.parametrize("content_type", content_types) 92 | def test_string_as_double(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 93 | check_bad_record( 94 | database_conn, 95 | spark, 96 | StructType([StructField("a", DoubleType())]), 97 | [{"a": "1"}], 98 | ['{"a":"1"}'], 99 | content_type 100 | ) 101 | 102 | 103 | @pytest.mark.parametrize("content_type", content_types) 104 | def test_boolean_as_double(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 105 | check_bad_record( 106 | database_conn, 107 | spark, 108 | StructType([StructField("a", DoubleType())]), 109 | [{"a": True}], 110 | ['{"a":true}'], 111 | content_type 112 | ) 113 | 114 | 115 | @pytest.mark.parametrize("content_type", content_types) 116 | def test_string_as_boolean(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 117 | check_bad_record( 118 | database_conn, 119 | spark, 120 | StructType([StructField("a", BooleanType())]), 121 | [{"a": "true"}], 122 | ['{"a":"true"}'], 123 | content_type 124 | ) 125 | 126 | 127 | @pytest.mark.parametrize("content_type", content_types) 128 | def test_number_as_boolean(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str): 129 | check_bad_record( 130 | database_conn, 131 | spark, 132 | StructType([StructField("a", BooleanType())]), 133 | [{"a": 1}], 134 | ['{"a":1}'], 135 | content_type 136 | ) --------------------------------------------------------------------------------
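For reference, a minimal sketch of the write-then-read pattern exercised by demo/python-demo/write_demo.py and the save-mode/datatype tests above, targeting a deployment started with one of the start_db.sh scripts. Only the com.arangodb.spark format name and the table, overwriteMode, confirmTruncate options appear verbatim in the files above; the connection option names (endpoints, user, password) are assumptions standing in for what the demo passes via save_opts and the tests via test_basespark.options, so adjust them to your setup.

# Hedged sketch -- connection option names below are assumptions, not taken
# from this listing.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("arangodb-spark-sketch").getOrCreate()

conn_opts = {
    "endpoints": "localhost:8529",  # coordinator exposed by start_db.sh (assumed option name)
    "user": "root",
    "password": "test",             # password set by the start scripts above
}

df = spark.createDataFrame([("Carlsen", "Magnus"), ("Caruana", "Fabiano")],
                           schema=["surname", "name"])

# Write, truncating any existing collection (same flags used in the save-mode tests).
df.write \
    .format("com.arangodb.spark") \
    .mode("overwrite") \
    .options(**conn_opts, table="chessPlayers", overwriteMode="replace", confirmTruncate="true") \
    .save()

# Read the collection back with an explicit schema, as the datatype tests do.
spark.read \
    .format("com.arangodb.spark") \
    .options(**conn_opts, table="chessPlayers") \
    .schema(df.schema) \
    .load() \
    .show()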