├── docker
│ ├── jwtSecret
│ ├── jwtHeader
│ ├── server.pem
│ ├── start_db.sh
│ └── start_db_macos.sh
├── demo
│ ├── docker
│ │ ├── jwtSecret
│ │ ├── .ivy2
│ │ │ ├── cache
│ │ │ │ └── .gitignore
│ │ │ └── jars
│ │ │   └── .gitignore
│ │ ├── jwtHeader
│ │ ├── stop.sh
│ │ ├── start_spark.sh
│ │ ├── server.pem
│ │ └── start_db.sh
│ ├── python-demo
│ │ ├── requirements.txt
│ │ ├── utils.py
│ │ ├── read_write_demo.py
│ │ ├── schemas.py
│ │ ├── read_demo.py
│ │ ├── demo.py
│ │ └── write_demo.py
│ ├── src
│ │ ├── test
│ │ │ ├── scala
│ │ │ │ └── DemoTest.scala
│ │ │ └── resources
│ │ │   └── log4j.xml
│ │ └── main
│ │   └── scala
│ │     ├── ReadWriteDemo.scala
│ │     ├── Schemas.scala
│ │     ├── Demo.scala
│ │     ├── ReadDemo.scala
│ │     └── WriteDemo.scala
│ └── README.md
├── python-integration-tests
│ ├── integration
│ │ ├── __init__.py
│ │ ├── write
│ │ │ ├── __init__.py
│ │ │ └── test_savemode.py
│ │ ├── utils.py
│ │ ├── conftest.py
│ │ ├── test_composite_filter.py
│ │ ├── test_deserialization_cast.py
│ │ ├── test_readwrite_datatype.py
│ │ └── test_bad_records.py
│ └── test-requirements.txt
├── integration-tests
│ ├── src
│ │ └── test
│ │   ├── resources
│ │   │ ├── allure.properties
│ │   │ ├── cert.p12
│ │   │ └── log4j2.properties
│ │   └── scala
│ │     └── org
│ │       └── apache
│ │         └── spark
│ │           └── sql
│ │             └── arangodb
│ │               ├── datasource
│ │               │ ├── User.scala
│ │               │ ├── AcquireHostListTest.scala
│ │               │ ├── TestUtils.scala
│ │               │ ├── write
│ │               │ │ ├── WriteWithNullKey.scala
│ │               │ │ ├── CreateCollectionTest.scala
│ │               │ │ ├── WriteResiliencyTest.scala
│ │               │ │ └── EdgeSchemaValidationTest.scala
│ │               │ ├── ReadSmartEdgeCollectionTest.scala
│ │               │ ├── StringFiltersTest.scala
│ │               │ ├── CompositeFilterTest.scala
│ │               │ ├── DeserializationCastTest.scala
│ │               │ └── BadRecordsTest.scala
│ │               ├── examples
│ │               │ ├── PushdownExample.scala
│ │               │ └── DataTypesExample.scala
│ │               └── JacksonTest.scala
│ └── pom.xml
├── .gitignore
├── arangodb-spark-datasource-3.4
│ ├── src
│ │ └── main
│ │   ├── resources
│ │   │ └── META-INF
│ │   │   └── services
│ │   │     ├── org.apache.spark.sql.sources.DataSourceRegister
│ │   │     ├── org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider
│ │   │     └── org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
│ │   └── scala
│ │     └── org
│ │       └── apache
│ │         └── spark
│ │           └── sql
│ │             └── arangodb
│ │               └── datasource
│ │                 └── mapping
│ │                   ├── package.scala
│ │                   ├── ArangoGeneratorImpl.scala
│ │                   ├── ArangoParserImpl.scala
│ │                   └── json
│ │                     ├── JacksonUtils.scala
│ │                     └── CreateJacksonParser.scala
│ └── pom.xml
├── arangodb-spark-datasource-3.5
│ ├── src
│ │ └── main
│ │   ├── resources
│ │   │ └── META-INF
│ │   │   └── services
│ │   │     ├── org.apache.spark.sql.sources.DataSourceRegister
│ │   │     ├── org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider
│ │   │     └── org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
│ │   └── scala
│ │     └── org
│ │       └── apache
│ │         └── spark
│ │           └── sql
│ │             └── arangodb
│ │               └── datasource
│ │                 └── mapping
│ │                   ├── package.scala
│ │                   ├── ArangoGeneratorImpl.scala
│ │                   ├── ArangoParserImpl.scala
│ │                   └── json
│ │                     ├── JacksonUtils.scala
│ │                     └── CreateJacksonParser.scala
│ └── pom.xml
├── bin
│ ├── clean.sh
│ └── test.sh
├── arangodb-spark-commons
│ ├── src
│ │ ├── main
│ │ │ └── scala
│ │ │   ├── org
│ │ │   │ └── apache
│ │ │   │   └── spark
│ │ │   │     └── sql
│ │ │   │       └── arangodb
│ │ │   │         ├── commons
│ │ │   │         │ ├── exceptions
│ │ │   │         │ │ ├── DataWriteAbortException.scala
│ │ │   │         │ │ ├── ArangoDBDataWriterException.scala
│ │ │   │         │ │ └── ArangoDBMultiException.scala
│ │ │   │         │ ├── utils
│ │ │   │         │ │ └── PushDownCtx.scala
│ │ │   │         │ ├── mapping
│ │ │   │         │ │ ├── ArangoParserProvider.scala
│ │ │   │         │ │ ├── MappingUtils.scala
│ │ │   │         │ │ ├── ArangoGeneratorProvider.scala
│ │ │   │         │ │ ├── ArangoParser.scala
│ │ │   │         │ │ └── ArangoGenerator.scala
│ │ │   │         │ ├── filter
│ │ │   │         │ │ ├── FilterSupport.scala
│ │ │   │         │ │ └── package.scala
│ │ │   │         │ ├── ArangoUtils.scala
│ │ │   │         │ ├── PushdownUtils.scala
│ │ │   │         │ └── package.scala
│ │ │   │         └── datasource
│ │ │   │           ├── reader
│ │ │   │           │ ├── ArangoCollectionPartition.scala
│ │ │   │           │ ├── ArangoPartitionReaderFactory.scala
│ │ │   │           │ ├── ArangoScan.scala
│ │ │   │           │ ├── ArangoQueryReader.scala
│ │ │   │           │ ├── ArangoCollectionPartitionReader.scala
│ │ │   │           │ └── ArangoScanBuilder.scala
│ │ │   │           ├── writer
│ │ │   │           │ ├── ArangoDataWriterFactory.scala
│ │ │   │           │ ├── ArangoBatchWriter.scala
│ │ │   │           │ └── ArangoWriterBuilder.scala
│ │ │   │           └── ArangoTable.scala
│ │ │   └── com
│ │ │     └── arangodb
│ │ │       └── spark
│ │ │         └── DefaultSource.scala
│ │ └── test
│ │   └── scala
│ │     └── org
│ │       └── apache
│ │         └── spark
│ │           └── sql
│ │             └── arangodb
│ │               └── commons
│ │                 ├── filter
│ │                 │ ├── IsNullTest.scala
│ │                 │ ├── PackageTest.scala
│ │                 │ ├── NotFilterTest.scala
│ │                 │ ├── StringEndsWithFilterTest.scala
│ │                 │ ├── StringContainsFilterTest.scala
│ │                 │ ├── StringStartsWithFilterTest.scala
│ │                 │ ├── OrFilterTest.scala
│ │                 │ └── AndFilterTest.scala
│ │                 ├── exceptions
│ │                 │ └── ExceptionsSerializationTest.scala
│ │                 └── ColumnsPruningTest.scala
│ └── pom.xml
├── .circleci
│ └── maven-release-settings.xml
├── README.md
├── dev-README.md
└── ChangeLog.md
/docker/jwtSecret:
--------------------------------------------------------------------------------
1 | Averysecretword
2 |
--------------------------------------------------------------------------------
/demo/docker/jwtSecret:
--------------------------------------------------------------------------------
1 | Averysecretword
2 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/write/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demo/docker/.ivy2/cache/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/demo/docker/.ivy2/jars/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/demo/python-demo/requirements.txt:
--------------------------------------------------------------------------------
1 | pyspark[pandas_on_spark]==3.5.7
2 |
--------------------------------------------------------------------------------
/python-integration-tests/test-requirements.txt:
--------------------------------------------------------------------------------
1 | python-arango==7.3.4
2 | pytest==7.1.2
--------------------------------------------------------------------------------
/integration-tests/src/test/resources/allure.properties:
--------------------------------------------------------------------------------
1 | allure.results.directory=target/allure-results
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.class
2 | *.log
3 | /target/
4 | **/target/
5 | *.iml
6 | **/.idea/
7 | .directory
8 | **/.flattened-pom.xml
9 | __pycache__/
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
--------------------------------------------------------------------------------
1 | com.arangodb.spark.DefaultSource
2 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
--------------------------------------------------------------------------------
1 | com.arangodb.spark.DefaultSource
2 |
--------------------------------------------------------------------------------
/integration-tests/src/test/resources/cert.p12:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arangodb/arangodb-spark-datasource/HEAD/integration-tests/src/test/resources/cert.p12
--------------------------------------------------------------------------------
/docker/jwtHeader:
--------------------------------------------------------------------------------
1 | Authorization: bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmFuZ29kYiIsInNlcnZlcl9pZCI6ImZvbyJ9.QmuhPHkmRPJuHGxsEqggHGRyVXikV44tb5YU_yWEvEM
2 |
--------------------------------------------------------------------------------
/demo/docker/jwtHeader:
--------------------------------------------------------------------------------
1 | Authorization: bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmFuZ29kYiIsInNlcnZlcl9pZCI6ImZvbyJ9.QmuhPHkmRPJuHGxsEqggHGRyVXikV44tb5YU_yWEvEM
2 |
--------------------------------------------------------------------------------
/demo/python-demo/utils.py:
--------------------------------------------------------------------------------
1 | def combine_dicts(list_of_dicts):
2 | whole_dict = {}
3 | for d in list_of_dicts:
4 | whole_dict.update(d)
5 | return whole_dict
6 |
--------------------------------------------------------------------------------
/bin/clean.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mvn clean -Pspark-3.4 -Pscala-2.12
4 | mvn clean -Pspark-3.4 -Pscala-2.13
5 | mvn clean -Pspark-3.5 -Pscala-2.12
6 | mvn clean -Pspark-3.5 -Pscala-2.13
7 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/utils.py:
--------------------------------------------------------------------------------
1 | def combine_dicts(list_of_dicts):
2 | whole_dict = {}
3 | for d in list_of_dicts:
4 | whole_dict.update(d)
5 | return whole_dict
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider:
--------------------------------------------------------------------------------
1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoGeneratorProviderImpl
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider:
--------------------------------------------------------------------------------
1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoParserProviderImpl
2 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider:
--------------------------------------------------------------------------------
1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoGeneratorProviderImpl
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider:
--------------------------------------------------------------------------------
1 | org.apache.spark.sql.arangodb.datasource.mapping.ArangoParserProviderImpl
2 |
--------------------------------------------------------------------------------
/demo/docker/stop.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker exec adb /app/arangodb stop
4 | sleep 1
5 | docker rm -f \
6 | adb \
7 | spark-master \
8 | spark-worker-1 \
9 | spark-worker-2 \
10 | spark-worker-3
11 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/exceptions/DataWriteAbortException.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.exceptions
2 |
3 | class DataWriteAbortException(message: String) extends RuntimeException(message)
4 |
--------------------------------------------------------------------------------
/demo/src/test/scala/DemoTest.scala:
--------------------------------------------------------------------------------
1 | import org.junit.jupiter.api.Test
2 |
3 | class DemoTest {
4 |
5 | @Test
6 | def testDemo(): Unit = {
7 | System.setProperty("importPath", "docker/import")
8 | Demo.main(Array.empty)
9 | }
10 |
11 | }
12 |
--------------------------------------------------------------------------------
/bin/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # exit when any command fails
4 | set -e
5 |
6 | mvn clean -Pspark-3.4 -Pscala-2.12
7 | mvn test -Pspark-3.4 -Pscala-2.12
8 |
9 | mvn clean -Pspark-3.5 -Pscala-2.12
10 | mvn test -Pspark-3.5 -Pscala-2.12
11 |
12 |
13 | mvn clean -Pspark-3.4 -Pscala-2.13
14 | mvn test -Pspark-3.4 -Pscala-2.13
15 |
16 | mvn clean -Pspark-3.5 -Pscala-2.13
17 | mvn test -Pspark-3.5 -Pscala-2.13
18 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from integration.test_basespark import arangodb_client, database_conn, spark, endpoints, single_endpoint, adb_hostname
3 |
4 |
5 | def pytest_addoption(parser):
6 | parser.addoption("--adb-datasource-jar", action="store", dest="datasource_jar_loc", required=True)
7 | parser.addoption("--adb-hostname", action="store", dest="adb_hostname", default="172.28.0.1")
8 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/User.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import java.sql.Date
4 |
5 | case class User(
6 | name: Name,
7 | birthday: Date,
8 | gender: String,
9 | likes: List[String]
10 | )
11 |
12 | case class Name(
13 | first: String,
14 | last: String
15 | )
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/utils/PushDownCtx.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.utils
2 |
3 | import org.apache.spark.sql.arangodb.commons.filter.PushableFilter
4 | import org.apache.spark.sql.types.StructType
5 |
6 | class PushDownCtx(
7 | // columns projection to return
8 | val requiredSchema: StructType,
9 |
10 | // filters to push down
11 | val filters: Array[PushableFilter]
12 | )
13 | extends Serializable
14 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoParserProvider.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.mapping
2 |
3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
4 | import org.apache.spark.sql.types.DataType
5 |
6 | import java.util.ServiceLoader
7 |
8 | trait ArangoParserProvider {
9 | def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParser
10 | }
11 |
12 | object ArangoParserProvider {
13 | def apply(): ArangoParserProvider = ServiceLoader.load(classOf[ArangoParserProvider]).iterator().next()
14 | }
15 |
--------------------------------------------------------------------------------
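The companion object above resolves the concrete parser implementation through `java.util.ServiceLoader`, using the registration files under `META-INF/services` shipped by the Spark-version-specific modules (listed earlier in this dump). A minimal sketch of how a caller obtains a provider:

```scala
import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider

// ServiceLoader scans META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
// and apply() returns the first implementation found on the classpath (e.g. ArangoParserProviderImpl).
val parserProvider: ArangoParserProvider = ArangoParserProvider()
```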
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/MappingUtils.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.mapping
2 |
3 | import com.arangodb.jackson.dataformat.velocypack.VPackMapper
4 | import com.fasterxml.jackson.databind.ObjectMapper
5 |
6 | object MappingUtils {
7 |
8 | private val jsonMapper = new ObjectMapper()
9 | private val vpackMapper = new VPackMapper()
10 |
11 | def vpackToJson(in: Array[Byte]): String = jsonMapper.writeValueAsString(vpackMapper.readTree(in))
12 |
13 | def jsonToVPack(in: String): Array[Byte] = vpackMapper.writeValueAsBytes(jsonMapper.readTree(in))
14 |
15 | }
16 |
--------------------------------------------------------------------------------
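As a quick illustration of the two helpers above, the following round-trip converts a JSON string to VelocyPack bytes and back (the JSON literal is just an example value):

```scala
import org.apache.spark.sql.arangodb.commons.mapping.MappingUtils

// JSON -> VelocyPack bytes, then back to a JSON string representing the same document.
val vpack: Array[Byte] = MappingUtils.jsonToVPack("""{"name":"Alice","age":10}""")
val json: String = MappingUtils.vpackToJson(vpack)
```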
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.reader
2 |
3 | import org.apache.spark.sql.connector.read.InputPartition
4 |
5 | /**
6 | * Partition corresponding to an Arango collection shard
7 | * @param shardId collection shard id
8 | * @param endpoint db endpoint to use to query the partition
9 | */
10 | class ArangoCollectionPartition(val shardId: String, val endpoint: String) extends InputPartition
11 |
12 | /**
13 | * Custom user queries will not be partitioned (eg. AQL traversals)
14 | */
15 | object SingletonPartition extends InputPartition
16 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.writer
2 |
3 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
4 | import org.apache.spark.sql.catalyst.InternalRow
5 | import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
6 | import org.apache.spark.sql.types.StructType
7 |
8 | class ArangoDataWriterFactory(schema: StructType, options: ArangoDBConf) extends DataWriterFactory {
9 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
10 | new ArangoDataWriter(schema, options, partitionId)
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/demo/src/test/resources/log4j.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoGeneratorProvider.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.mapping
2 |
3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
4 | import org.apache.spark.sql.types.StructType
5 |
6 | import java.io.OutputStream
7 | import java.util.ServiceLoader
8 |
9 | trait ArangoGeneratorProvider {
10 | def of(contentType: ContentType, schema: StructType, outputStream: OutputStream, conf: ArangoDBConf): ArangoGenerator
11 | }
12 |
13 | object ArangoGeneratorProvider {
14 | def apply(): ArangoGeneratorProvider = ServiceLoader.load(classOf[ArangoGeneratorProvider]).iterator().next()
15 | }
16 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import com.fasterxml.jackson.core.JsonFactory
4 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
5 | import org.apache.spark.sql.arangodb.datasource.mapping.json.JSONOptions
6 |
7 | package object mapping {
8 | private[mapping] def createOptions(jsonFactory: JsonFactory, conf: ArangoDBConf) =
9 | new JSONOptions(Map.empty[String, String], "UTC") {
10 | override def buildJsonFactory(): JsonFactory = jsonFactory
11 |
12 | override val ignoreNullFields: Boolean = conf.mappingOptions.ignoreNullFields
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import com.fasterxml.jackson.core.JsonFactory
4 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
5 | import org.apache.spark.sql.arangodb.datasource.mapping.json.JSONOptions
6 |
7 | package object mapping {
8 | private[mapping] def createOptions(jsonFactory: JsonFactory, conf: ArangoDBConf) =
9 | new JSONOptions(Map.empty[String, String], "UTC") {
10 | override def buildJsonFactory(): JsonFactory = jsonFactory
11 |
12 | override val ignoreNullFields: Boolean = conf.mappingOptions.ignoreNullFields
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/filter/FilterSupport.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 |
4 | sealed trait FilterSupport
5 |
6 | object FilterSupport {
7 |
8 | /**
9 | * the filter can be applied and does not need to be evaluated again after scanning
10 | */
11 | case object FULL extends FilterSupport
12 |
13 | /**
14 | * the filter can be partially applied and it needs to be evaluated again after scanning
15 | */
16 | case object PARTIAL extends FilterSupport
17 |
18 | /**
19 | * the filter cannot be applied and it needs to be evaluated again after scanning
20 | */
21 | case object NONE extends FilterSupport
22 | }
23 |
--------------------------------------------------------------------------------
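A minimal sketch of how these three levels are typically consumed: filters with `FULL` support can be pushed down and dropped from post-scan evaluation, while the rest must be re-applied by Spark. The helper name `splitBySupport` is hypothetical; `PushableFilter` is the factory exercised by the filter tests later in this dump.

```scala
import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Hypothetical helper: the first array holds filters that need no re-evaluation after scanning,
// the second holds filters Spark must still apply to the returned rows.
def splitBySupport(filters: Array[Filter], schema: StructType): (Array[Filter], Array[Filter]) =
  filters.partition(f => PushableFilter(f, schema).support() == FilterSupport.FULL)
```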
/demo/python-demo/read_write_demo.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from pyspark.sql import SparkSession
4 |
5 | import read_demo
6 | import write_demo
7 | from schemas import movie_schema
8 |
9 |
10 | def read_write_demo(spark: SparkSession, opts: Dict[str, str]):
11 | print("-----------------------")
12 | print("--- READ-WRITE DEMO ---")
13 | print("-----------------------")
14 |
15 | print("Reading 'movies' collection and writing 'actionMovies' collection...")
16 | action_movies_df = read_demo.read_collection(spark, "movies", opts, movie_schema)\
17 | .select("_key", "title", "releaseDate", "runtime", "description")\
18 | .filter("genre = 'Action'")
19 | write_demo.save_df(action_movies_df.to_pandas_on_spark(), "actionMovies", opts)
20 | print("You can now view the actionMovies collection in ArangoDB!")
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.reader
2 |
3 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
4 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
5 | import org.apache.spark.sql.catalyst.InternalRow
6 | import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
7 |
8 | class ArangoPartitionReaderFactory(ctx: PushDownCtx, options: ArangoDBConf) extends PartitionReaderFactory {
9 | override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match {
10 | case p: ArangoCollectionPartition => new ArangoCollectionPartitionReader(p, ctx, options)
11 | case SingletonPartition => new ArangoQueryReader(ctx.requiredSchema, options)
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoParser.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.mapping
2 |
3 | import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
4 | import org.apache.spark.sql.catalyst.InternalRow
5 | import org.apache.spark.unsafe.types.UTF8String
6 |
7 | trait ArangoParser {
8 |
9 | /**
10 |    * Parses the JSON or VPACK input into a set of [[InternalRow]]s.
11 | *
12 | * @param recordLiteral an optional function that will be used to generate
13 | * the corrupt record text instead of record.toString
14 | */
15 | def parse[T](
16 | record: T,
17 | createParser: (JsonFactory, T) => JsonParser,
18 | recordLiteral: T => UTF8String): Iterable[InternalRow]
19 |
20 | def parse(data: Array[Byte]): Iterable[InternalRow]
21 |
22 | }
23 |
--------------------------------------------------------------------------------
/.circleci/maven-release-settings.xml:
--------------------------------------------------------------------------------
1 | <settings xmlns="http://maven.apache.org/SETTINGS/1.0.0">
2 |     <profiles>
3 |         <profile>
4 |             <id>central</id>
5 |             <activation>
6 |                 <activeByDefault>true</activeByDefault>
7 |             </activation>
8 |             <properties>
9 |                 <gpg.keyname>${env.GPG_KEYNAME}</gpg.keyname>
10 |                 <gpg.passphrase>${env.GPG_PASSPHRASE}</gpg.passphrase>
11 |             </properties>
12 |         </profile>
13 |     </profiles>
14 |     <servers>
15 |         <server>
16 |             <id>central</id>
17 |             <username>${env.CENTRAL_USERNAME}</username>
18 |             <password>${env.CENTRAL_PASSWORD}</password>
19 |         </server>
20 |     </servers>
21 | </settings>
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/exceptions/ArangoDBDataWriterException.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.exceptions
2 |
3 | /**
4 |  * Exception thrown after all write attempts have failed.
5 | * It contains the exceptions thrown at each attempt.
6 | *
7 | * @param exceptions array of exceptions thrown at each attempt
8 | */
9 | class ArangoDBDataWriterException(val exceptions: Array[Exception])
10 | extends RuntimeException(s"Failed ${exceptions.length} times: ${ArangoDBDataWriterException.toMessage(exceptions)}") {
11 | val attempts: Int = exceptions.length
12 |
13 | override def getCause: Throwable = exceptions(0)
14 | }
15 |
16 | private object ArangoDBDataWriterException {
17 |
18 | // creates exception message
19 | private def toMessage(exceptions: Array[Exception]): String = exceptions
20 | .zipWithIndex
21 | .map(it => s"""Attempt #${it._2 + 1}: ${it._1}""")
22 | .mkString("[\n\t", ",\n\t", "\n]")
23 |
24 | }
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # ArangoDB Datasource for Apache Spark
4 | [](https://maven-badges.herokuapp.com/maven-central/com.arangodb/arangodb-spark-datasource-3.5_2.12)
5 | [](https://github.com/arangodb/arangodb-spark-datasource/actions)
6 | [](https://sonarcloud.io/summary/new_code?id=arangodb_arangodb-spark-datasource)
7 |
8 | The official [ArangoDB](https://www.arangodb.com/) Datasource connector for Apache Spark.
9 |
10 | ## Learn more
11 | - [ChangeLog](ChangeLog.md)
12 | - [Demo](./demo)
13 | - [Documentation](https://www.arangodb.com/docs/stable/drivers/spark-connector-new.html)
14 |
--------------------------------------------------------------------------------
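A minimal Scala read example (the endpoint, credentials, and collection name below are placeholders; the option names match those used in the examples and integration tests in this repository):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("arangodb-spark-quickstart")
  .master("local[*]")
  .getOrCreate()

// Placeholder connection values; adjust them to your deployment.
val users = spark.read
  .format("com.arangodb.spark")
  .option("endpoints", "172.28.0.1:8529")
  .option("password", "test")
  .option("table", "users")
  .load()

users.show()
```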
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/AcquireHostListTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import com.arangodb.spark.DefaultSource
4 | import org.apache.spark.sql.SparkSession
5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
6 | import org.junit.jupiter.api.{Disabled, Test}
7 |
8 | @Disabled("manual test only")
9 | class AcquireHostListTest {
10 |
11 | private val spark: SparkSession = SparkSession.builder()
12 | .appName("ArangoDBSparkTest")
13 | .master("local[*]")
14 | .config("spark.driver.host", "127.0.0.1")
15 | .getOrCreate()
16 |
17 | @Test
18 | def read(): Unit = {
19 | spark.read
20 | .format(classOf[DefaultSource].getName)
21 | .options(Map(
22 | ArangoDBConf.COLLECTION -> "_fishbowl",
23 | ArangoDBConf.ENDPOINTS -> "172.28.0.1:8529",
24 | ArangoDBConf.ACQUIRE_HOST_LIST -> "true",
25 | ArangoDBConf.PASSWORD -> "test"
26 | ))
27 | .load()
28 | .show()
29 | }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/mapping/ArangoGenerator.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.mapping
2 |
3 | import org.apache.spark.sql.catalyst.InternalRow
4 | import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
5 |
6 | trait ArangoGenerator {
7 |
8 | def close(): Unit
9 |
10 | def flush(): Unit
11 |
12 | def writeStartArray(): Unit
13 |
14 | def writeEndArray(): Unit
15 |
16 | /**
17 |    * Transforms a single `InternalRow` into a JSON or VPACK object.
18 | *
19 | * @param row The row to convert
20 | */
21 | def write(row: InternalRow): Unit
22 |
23 | /**
24 |    * Transforms multiple `InternalRow`s or `MapData`s into a JSON or VPACK array.
25 | *
26 | * @param array The array of rows or maps to convert
27 | */
28 | def write(array: ArrayData): Unit
29 |
30 | /**
31 |    * Transforms a single `MapData` into a JSON or VPACK object.
32 | *
33 | * @param map a map to convert
34 | */
35 | def write(map: MapData): Unit
36 |
37 | def writeLineEnding(): Unit
38 | }
39 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/IsNullTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.{IsNull, IsNotNull}
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class IsNullTest {
9 | private val schema = StructType(Array(
10 | StructField("a", StructType(Array(
11 | StructField("b", StringType)
12 | )))
13 | ))
14 |
15 | @Test
16 | def isNull(): Unit = {
17 | val isNullFilter = PushableFilter(IsNull("a.b"), schema)
18 | assertThat(isNullFilter.support()).isEqualTo(FilterSupport.FULL)
19 | assertThat(isNullFilter.aql("d")).isEqualTo("""`d`.`a`.`b` == null""")
20 | }
21 |
22 | @Test
23 | def isNotNull(): Unit = {
24 | val isNotNullFilter = PushableFilter(IsNotNull("a.b"), schema)
25 | assertThat(isNotNullFilter.support()).isEqualTo(FilterSupport.FULL)
26 | assertThat(isNotNullFilter.aql("d")).isEqualTo("""`d`.`a`.`b` != null""")
27 | }
28 |
29 | }
30 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/exceptions/ExceptionsSerializationTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.exceptions
2 |
3 | import com.arangodb.entity.ErrorEntity
4 | import com.fasterxml.jackson.databind.ObjectMapper
5 | import org.junit.jupiter.api.Test
6 |
7 | import java.io.{ByteArrayOutputStream, ObjectOutputStream}
8 |
9 | class ExceptionsSerializationTest {
10 |
11 | @Test
12 | def arangoDBMultiException(): Unit = {
13 | val mapper = new ObjectMapper()
14 | val errors = Stream.range(1, 10000).map(_ =>
15 | (
16 | mapper.readValue(
17 | """{
18 | |"errorMessage": "errorMessage",
19 | |"errorNum": 1234
20 | |}""".stripMargin, classOf[ErrorEntity]),
21 | "record"
22 | )
23 | )
24 | val e = new ArangoDBMultiException(errors.toArray)
25 | val objectOutputStream = new ObjectOutputStream(new ByteArrayOutputStream())
26 | objectOutputStream.writeObject(e)
27 | objectOutputStream.flush()
28 | objectOutputStream.close()
29 | }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/TestUtils.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | object TestUtils {
4 | def isAtLeastVersion(version: String, otherMajor: Int, otherMinor: Int, otherPatch: Int): Boolean = compareVersion(version, otherMajor, otherMinor, otherPatch) >= 0
5 |
6 | def isLessThanVersion(version: String, otherMajor: Int, otherMinor: Int, otherPatch: Int): Boolean = compareVersion(version, otherMajor, otherMinor, otherPatch) < 0
7 |
8 | private def compareVersion(version: String, otherMajor: Int, otherMinor: Int, otherPatch: Int): Int = {
9 | val parts = version.split("-")(0).split("\\.")
10 | val major = parts(0).toInt
11 | val minor = parts(1).toInt
12 | val patch = parts(2).toInt
13 | val majorComparison = Integer.compare(major, otherMajor)
14 | if (majorComparison != 0) {
15 | return majorComparison
16 | }
17 | val minorComparison = Integer.compare(minor, otherMinor)
18 | if (minorComparison != 0) {
19 | return minorComparison
20 | }
21 | Integer.compare(patch, otherPatch)
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
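For example, pre-release suffixes are ignored because the version string is split on "-" before the numeric parts are compared:

```scala
import org.apache.spark.sql.arangodb.datasource.TestUtils

// "3.11.2-nightly" is compared as 3.11.2, so both checks hold.
assert(TestUtils.isAtLeastVersion("3.11.2-nightly", 3, 10, 0))
assert(TestUtils.isLessThanVersion("3.11.2-nightly", 3, 12, 0))
```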
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoUtils.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons
2 |
3 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
4 | import org.apache.spark.sql.{Encoders, SparkSession}
5 |
6 | /**
7 | * @author Michele Rastelli
8 | */
9 | object ArangoUtils {
10 |
11 | def inferSchema(options: ArangoDBConf): StructType = {
12 | val client = ArangoClient(options)
13 | val sampleEntries = options.readOptions.readMode match {
14 | case ReadMode.Query => client.readQuerySample()
15 | case ReadMode.Collection => client.readCollectionSample()
16 | }
17 | client.shutdown()
18 |
19 | val spark = SparkSession.getActiveSession.get
20 | val schema = spark
21 | .read
22 | .json(spark.createDataset(sampleEntries)(Encoders.STRING))
23 | .schema
24 |
25 | if (options.readOptions.columnNameOfCorruptRecord.isEmpty) {
26 | schema
27 | } else {
28 | schema.add(StructField(options.readOptions.columnNameOfCorruptRecord, StringType, nullable = true))
29 | }
30 | }
31 |
32 | }
33 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/ColumnsPruningTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons
2 |
3 | import org.apache.spark.sql.types._
4 | import org.assertj.core.api.Assertions.assertThat
5 | import org.junit.jupiter.api.Test
6 |
7 | class ColumnsPruningTest {
8 |
9 | @Test
10 | def generateAqlReturnClause(): Unit = {
11 | val schema = StructType(Array(
12 | StructField(""""birthday"""", DateType),
13 | StructField("gender", StringType),
14 | StructField("likes", ArrayType(StringType)),
15 | StructField("name", StructType(Array(
16 | StructField("first", StringType),
17 | StructField("last", StringType)
18 | )))
19 | ))
20 |
21 | val res = PushdownUtils.generateColumnsFilter(schema, "d")
22 | assertThat(res).isEqualTo(
23 | """
24 | |{
25 | | `"birthday"`: `d`.`"birthday"`,
26 | | `gender`: `d`.`gender`,
27 | | `likes`: `d`.`likes`,
28 | | `name`: {
29 | | `first`: `d`.`name`.`first`,
30 | | `last`: `d`.`name`.`last`
31 | | }
32 | |}
33 | |""".stripMargin.replaceAll("\\s", ""))
34 | }
35 |
36 | }
37 |
--------------------------------------------------------------------------------
/integration-tests/src/test/resources/log4j2.properties:
--------------------------------------------------------------------------------
1 | status=warn
2 |
3 | appender.console.type=Console
4 | appender.console.name=console
5 | appender.console.layout.type=PatternLayout
6 | appender.console.layout.pattern=%d{HH:mm:ss.SSS} %-5p %c{10}:%L - %m%n
7 |
8 | rootLogger.level=info
9 | rootLogger.appenderRef.stdout.ref=console
10 |
11 | # Settings to quiet third party logs that are too verbose
12 | logger.jetty.name = org.sparkproject.jetty
13 | logger.jetty.level = warn
14 | logger.jetty2.name = org.sparkproject.jetty.util.component.AbstractLifeCycle
15 | logger.jetty2.level = error
16 | logger.repl1.name = org.apache.spark.repl.SparkIMain$exprTyper
17 | logger.repl1.level = info
18 | logger.repl2.name = org.apache.spark.repl.SparkILoop$SparkILoopInterpreter
19 | logger.repl2.level = info
20 |
21 | ## ---
22 |
23 | #logger.writer.name=org.apache.spark.sql.arangodb.datasource.writer.ArangoDataWriter
24 | #logger.writer.level=debug
25 |
26 | #logger.driver.name=com.arangodb
27 | #logger.driver.level=debug
28 |
29 | #logger.netty.name=com.arangodb.shaded.netty
30 | #logger.netty.level=debug
31 |
32 | #logger.communication.name=com.arangodb.internal.net.Communication
33 | #logger.communication.level=debug
34 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/PackageTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.assertj.core.api.Assertions.assertThat
4 | import org.junit.jupiter.params.ParameterizedTest
5 | import org.junit.jupiter.params.provider.{Arguments, MethodSource}
6 |
7 | import java.util.stream
8 |
9 | class PackageTest {
10 | @ParameterizedTest
11 | @MethodSource(Array("provideSplitAttributeName"))
12 | def splitAttributeName(attribute: String, expected: Array[String]): Unit = {
13 | assertThat(splitAttributeNameParts(attribute)).isEqualTo(expected)
14 | }
15 | }
16 |
17 | object PackageTest {
18 | def provideSplitAttributeName: stream.Stream[Arguments] =
19 | stream.Stream.of(
20 | Arguments.of("a.b.c.d.e.f", Array("a", "b", "c", "d", "e", "f")),
21 | Arguments.of("a.`b`.c.d.e.f", Array("a", "b", "c", "d", "e", "f")),
22 | Arguments.of("a.`b.c`.d.e.f", Array("a", "b.c", "d", "e", "f")),
23 | Arguments.of("a.b.`c.d`.e.f", Array("a", "b", "c.d", "e", "f")),
24 | Arguments.of("a.b.`.c.d.`.e.f", Array("a", "b", ".c.d.", "e", "f")),
25 | Arguments.of("a.b.`.`.e.f", Array("a", "b", ".", "e", "f"))
26 | )
27 | }
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/exceptions/ArangoDBMultiException.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.exceptions
2 |
3 | import com.arangodb.entity.ErrorEntity
4 |
5 | // Due to https://github.com/scala/bug/issues/10679, Scala Stream serialization could generate a StackOverflowError. To
6 | // avoid it, we use:
7 | // val errors: Array[ErrorEntity]
8 | // instead of:
9 | // val errors: Iterable[ErrorEntity]
10 |
11 | /**
12 | * @param errors array of tuples with:
13 | * _1 : the error entity
14 | * _2 : the stringified record causing the error
15 | */
16 | class ArangoDBMultiException(val errors: Array[(ErrorEntity, String)])
17 | extends RuntimeException(ArangoDBMultiException.toMessage(errors))
18 |
19 | private object ArangoDBMultiException {
20 |
21 | // creates exception message keeping only 5 errors for each error type
22 | private def toMessage(errors: Array[(ErrorEntity, String)]): String = errors
23 | .groupBy(_._1.getErrorNum)
24 | .mapValues(_.take(5))
25 | .values
26 | .flatten
27 | .map(it => s"""Error: ${it._1.getErrorNum} - ${it._1.getErrorMessage} for record: ${it._2}""")
28 | .mkString("[\n\t", ",\n\t", "\n]")
29 | }
30 |
--------------------------------------------------------------------------------
/demo/src/main/scala/ReadWriteDemo.scala:
--------------------------------------------------------------------------------
1 | import Schemas.movieSchema
2 |
3 | object ReadWriteDemo {
4 |
5 | def readWriteDemo(): Unit = {
6 | println("-----------------------")
7 | println("--- READ-WRITE DEMO ---")
8 | println("-----------------------")
9 |
10 | println("Reading 'movies' collection and writing 'actionMovies' collection...")
11 | val actionMoviesDF = ReadDemo.readTable("movies", movieSchema)
12 | .select("_key", "title", "releaseDate", "runtime", "description")
13 | .filter("genre = 'Action'")
14 | WriteDemo.saveDF(actionMoviesDF, "actionMovies")
15 | /*
16 | Filters and projection pushdowns are applied in this case.
17 |
18 |     An info log message like the following will be printed in the console:
19 | > INFO ArangoScanBuilder:57 - Filters fully applied in AQL:
20 | > IsNotNull(genre)
21 | > EqualTo(genre,Action)
22 |
23 |     The generated AQL query will also be printed at debug log level:
24 | > DEBUG ArangoClient:61 - Executing AQL query:
25 | > FOR d IN @@col FILTER `d`.`genre` != null AND `d`.`genre` == "Action" RETURN {`_key`:`d`.`_key`,`description`:`d`.`description`,`releaseDate`:`d`.`releaseDate`,`runtime`:`d`.`runtime`,`title`:`d`.`title`}
26 | > with params: Map(@col -> movies)
27 | */
28 |
29 | }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.writer
2 |
3 | import org.apache.spark.sql.SaveMode
4 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
5 | import org.apache.spark.sql.arangodb.commons.exceptions.DataWriteAbortException
6 | import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
7 | import org.apache.spark.sql.types.StructType
8 |
9 | class ArangoBatchWriter(schema: StructType, options: ArangoDBConf, mode: SaveMode) extends BatchWrite {
10 |
11 | override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
12 | new ArangoDataWriterFactory(schema, options)
13 |
14 | override def commit(messages: Array[WriterCommitMessage]): Unit = {
15 | // nothing to do here
16 | }
17 |
18 | override def abort(messages: Array[WriterCommitMessage]): Unit = {
19 | val client = ArangoClient(options)
20 | mode match {
21 | case SaveMode.Append => throw new DataWriteAbortException(
22 | "Cannot abort with SaveMode.Append: the underlying data source may require manual cleanup.")
23 | case SaveMode.Overwrite => client.truncate()
24 | case SaveMode.ErrorIfExists => ???
25 | case SaveMode.Ignore => ???
26 | }
27 | client.shutdown()
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/dev-README.md:
--------------------------------------------------------------------------------
1 | # dev-README
2 |
3 | ## GH Actions
4 | Check results [here](https://github.com/arangodb/arangodb-spark-datasource/actions).
5 |
6 | ## SonarCloud
7 | Check results [here](https://sonarcloud.io/project/overview?id=arangodb_arangodb-spark-datasource).
8 |
9 | ## check dependency updates
10 | ```shell
11 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} versions:display-dependency-updates
12 | ```
13 |
14 | ## analysis tools
15 |
16 | ### scalastyle
17 | ```shell
18 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} process-sources
19 | ```
20 | Reports:
21 | - [arangodb-spark-commons](arangodb-spark-commons/target/scalastyle-output.xml)
22 | - [arangodb-spark-datasource-3.4](arangodb-spark-datasource-3.4/target/scalastyle-output.xml)
23 | - [arangodb-spark-datasource-3.5](arangodb-spark-datasource-3.5/target/scalastyle-output.xml)
24 |
25 | ### scapegoat
26 | ```shell
27 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} test-compile
28 | ```
29 | Reports:
30 | - [arangodb-spark-commons](arangodb-spark-commons/target/scapegoat/scapegoat.html)
31 | - [arangodb-spark-datasource-3.4](arangodb-spark-datasource-3.4/target/scapegoat/scapegoat.html)
32 | - [arangodb-spark-datasource-3.5](arangodb-spark-datasource-3.5/target/scapegoat/scapegoat.html)
33 |
34 | ### JaCoCo
35 | ```shell
36 | mvn -Pspark-${sparkVersion} -Pscala-${scalaVersion} verify
37 | ```
38 | Report:
39 | - [integration-tests](integration-tests/target/site/jacoco-aggregate/index.html)
40 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.reader
2 |
3 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ReadMode}
4 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
5 | import org.apache.spark.sql.catalyst.expressions.ExprUtils
6 | import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
7 | import org.apache.spark.sql.types.StructType
8 |
9 | class ArangoScan(ctx: PushDownCtx, options: ArangoDBConf) extends Scan with Batch {
10 | ExprUtils.verifyColumnNameOfCorruptRecord(ctx.requiredSchema, options.readOptions.columnNameOfCorruptRecord)
11 |
12 | override def readSchema(): StructType = ctx.requiredSchema
13 |
14 | override def toBatch: Batch = this
15 |
16 | override def planInputPartitions(): Array[InputPartition] = options.readOptions.readMode match {
17 | case ReadMode.Query => Array(SingletonPartition)
18 | case ReadMode.Collection => planCollectionPartitions()
19 | }
20 |
21 | override def createReaderFactory(): PartitionReaderFactory = new ArangoPartitionReaderFactory(ctx, options)
22 |
23 | private def planCollectionPartitions(): Array[InputPartition] =
24 | ArangoClient.getCollectionShardIds(options)
25 | .zip(Stream.continually(options.driverOptions.endpoints).flatten)
26 | .map(it => new ArangoCollectionPartition(it._1, it._2))
27 |
28 | }
29 |
--------------------------------------------------------------------------------
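The round-robin pairing done by `planCollectionPartitions` is easiest to see with illustrative values (the shard ids and endpoints below are made up):

```scala
// Shards are assigned to endpoints round-robin; with more shards than endpoints, the endpoints repeat.
val shardIds  = Array("s1001", "s1002", "s1003")
val endpoints = Seq("host1:8529", "host2:8529")
val assignment = shardIds.zip(Stream.continually(endpoints).flatten)
// Array((s1001,host1:8529), (s1002,host2:8529), (s1003,host1:8529))
```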
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/WriteWithNullKey.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.write
2 |
3 | import com.arangodb.ArangoCollection
4 | import org.apache.spark.sql.SaveMode
5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
6 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest
7 | import org.assertj.core.api.Assertions.assertThat
8 | import org.junit.jupiter.params.ParameterizedTest
9 | import org.junit.jupiter.params.provider.MethodSource
10 |
11 | class WriteWithNullKey extends BaseSparkTest {
12 | private val collectionName = "writeWithNullKey"
13 | private val collection: ArangoCollection = db.collection(collectionName)
14 |
15 | import spark.implicits._
16 |
17 | @ParameterizedTest
18 | @MethodSource(Array("provideProtocolAndContentType"))
19 | def writeWithNullKey(protocol: String, contentType: String): Unit = {
20 | Seq(
21 | ("Carlsen", "Magnus"),
22 | (null, "Fabiano")
23 | )
24 | .toDF("_key", "name")
25 | .write
26 | .format(BaseSparkTest.arangoDatasource)
27 | .mode(SaveMode.Overwrite)
28 | .options(options + (
29 | ArangoDBConf.COLLECTION -> collectionName,
30 | ArangoDBConf.PROTOCOL -> protocol,
31 | ArangoDBConf.CONTENT_TYPE -> contentType,
32 | ArangoDBConf.CONFIRM_TRUNCATE -> "true"
33 | ))
34 | .save()
35 |
36 | assertThat(collection.count().getCount).isEqualTo(2L)
37 | }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/NotFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.{And, EqualTo, Not}
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class NotFilterTest {
9 | private val schema = StructType(Array(
10 | StructField("integer", IntegerType),
11 | StructField("string", StringType),
12 | StructField("binary", BinaryType)
13 | ))
14 |
15 | // FilterSupport.FULL
16 | private val f1 = EqualTo("string", "str")
17 | private val pushF1 = PushableFilter(f1, schema: StructType)
18 |
19 | // FilterSupport.NONE
20 | private val f2 = EqualTo("binary", Array(Byte.MaxValue))
21 |
22 | // FilterSupport.PARTIAL
23 | private val f3 = And(f1, f2)
24 |
25 | @Test
26 | def notFilterSupportFull(): Unit = {
27 | val notFilter = PushableFilter(Not(f1), schema)
28 | assertThat(notFilter.support()).isEqualTo(FilterSupport.FULL)
29 | assertThat(notFilter.aql("d")).isEqualTo(s"""NOT (${pushF1.aql("d")})""")
30 | }
31 |
32 | @Test
33 | def notFilterSupportNone(): Unit = {
34 | val notFilter = PushableFilter(Not(f2), schema)
35 | assertThat(notFilter.support()).isEqualTo(FilterSupport.NONE)
36 | }
37 |
38 | @Test
39 | def notFilterSupportPartial(): Unit = {
40 | val notFilter = PushableFilter(Not(f3), schema)
41 | assertThat(notFilter.support()).isEqualTo(FilterSupport.NONE)
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/examples/PushdownExample.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.examples
2 |
3 | import com.arangodb.ArangoDB
4 | import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
5 |
6 | object PushdownExample {
7 |
8 | final case class User(name: String, age: Int)
9 |
10 | def main(args: Array[String]): Unit = {
11 | prepareDB()
12 |
13 | val spark: SparkSession = SparkSession.builder()
14 | .appName("ArangoDBSparkTest")
15 | .master("local[*, 3]")
16 | .config("spark.driver.host", "127.0.0.1")
17 | .getOrCreate()
18 |
19 | import spark.implicits._
20 | val ds: Dataset[User] = spark.read
21 | .format("com.arangodb.spark")
22 | .option("password", "test")
23 | .option("endpoints", "172.28.0.1:8529")
24 | .option("table", "users")
25 | .schema(Encoders.product[User].schema)
26 | .load()
27 | .as[User]
28 |
29 | ds
30 | .select("name")
31 | .filter("age >= 18 AND age < 22")
32 | .show
33 |
34 | /*
35 | Generated query:
36 | FOR d IN @@col FILTER `d`.`age` >= 18 AND `d`.`age` < 22 RETURN {`name`:`d`.`name`}
37 | with params: Map(@col -> users)
38 | */
39 | }
40 |
41 | private def prepareDB(): Unit = {
42 | val arangoDB = new ArangoDB.Builder()
43 | .host("172.28.0.1", 8529)
44 | .password("test")
45 | .build()
46 |
47 | val col = arangoDB.db().collection("users")
48 | if (!col.exists())
49 | col.create()
50 | col.truncate()
51 | col.insertDocument(User("Alice", 10))
52 | col.insertDocument(User("Bob", 20))
53 | col.insertDocument(User("Eve", 30))
54 |
55 | arangoDB.shutdown()
56 | }
57 |
58 | }
59 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/com/arangodb/spark/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | package com.arangodb.spark
2 |
3 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
4 | import org.apache.spark.sql.arangodb.datasource.ArangoTable
5 | import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
6 | import org.apache.spark.sql.connector.expressions.Transform
7 | import org.apache.spark.sql.sources.DataSourceRegister
8 | import org.apache.spark.sql.types.StructType
9 | import org.apache.spark.sql.util.CaseInsensitiveStringMap
10 |
11 | import java.util
12 |
13 | class DefaultSource extends TableProvider with DataSourceRegister {
14 |
15 | private def extractOptions(options: util.Map[String, String]): ArangoDBConf = {
16 | val opts: ArangoDBConf = ArangoDBConf(options)
17 | if (opts.driverOptions.acquireHostList) {
18 | val hosts = ArangoClient.acquireHostList(opts)
19 | opts.updated(ArangoDBConf.ENDPOINTS, hosts.mkString(","))
20 | } else {
21 | opts
22 | }
23 | }
24 |
25 | override def inferSchema(options: CaseInsensitiveStringMap): StructType = getTable(options).schema()
26 |
27 | private def getTable(options: CaseInsensitiveStringMap): Table =
28 | getTable(None, options.asCaseSensitiveMap()) // scalastyle:ignore null
29 |
30 | override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table =
31 | getTable(Option(schema), properties)
32 |
33 | override def supportsExternalMetadata(): Boolean = true
34 |
35 | override def shortName(): String = "arangodb"
36 |
37 | private def getTable(schema: Option[StructType], properties: util.Map[String, String]) =
38 | new ArangoTable(schema, extractOptions(properties))
39 |
40 | }
41 |
--------------------------------------------------------------------------------
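Because this class is registered in `META-INF/services/org.apache.spark.sql.sources.DataSourceRegister` (see the resource files earlier in this dump), the short name returned by `shortName()` can be used instead of the fully qualified class name. A sketch with placeholder connection options:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Equivalent to .format("com.arangodb.spark"): the short name resolves through DataSourceRegister.
val users = spark.read
  .format("arangodb")
  .option("endpoints", "172.28.0.1:8529") // placeholder endpoint
  .option("password", "test")             // placeholder credentials
  .option("table", "users")               // placeholder collection
  .load()
```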
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ArangoUtils}
4 | import org.apache.spark.sql.arangodb.datasource.reader.ArangoScanBuilder
5 | import org.apache.spark.sql.arangodb.datasource.writer.ArangoWriterBuilder
6 | import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
7 | import org.apache.spark.sql.connector.read.ScanBuilder
8 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
9 | import org.apache.spark.sql.types.StructType
10 | import org.apache.spark.sql.util.CaseInsensitiveStringMap
11 |
12 | import java.util
13 | import scala.collection.JavaConverters.setAsJavaSetConverter
14 |
15 | class ArangoTable(private var schemaOpt: Option[StructType], options: ArangoDBConf) extends Table with SupportsRead with SupportsWrite {
16 | private lazy val tableSchema = schemaOpt.getOrElse(ArangoUtils.inferSchema(options))
17 |
18 | override def name(): String = this.getClass.toString
19 |
20 | override def schema(): StructType = tableSchema
21 |
22 | override def capabilities(): util.Set[TableCapability] = Set(
23 | TableCapability.BATCH_READ,
24 | TableCapability.BATCH_WRITE,
25 | // TableCapability.STREAMING_WRITE,
26 | TableCapability.ACCEPT_ANY_SCHEMA,
27 | TableCapability.TRUNCATE
28 | // TableCapability.OVERWRITE_BY_FILTER,
29 | // TableCapability.OVERWRITE_DYNAMIC,
30 | ).asJava
31 |
32 | override def newScanBuilder(scanOptions: CaseInsensitiveStringMap): ScanBuilder =
33 | new ArangoScanBuilder(options.updated(ArangoDBConf(scanOptions)), schema())
34 |
35 | override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
36 | new ArangoWriterBuilder(info.schema(), options.updated(ArangoDBConf(info.options())))
37 | }
38 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/PushdownUtils.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons
2 |
3 | import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
4 | import org.apache.spark.sql.sources.Filter
5 | import org.apache.spark.sql.types.{StructField, StructType}
6 |
7 | import scala.annotation.tailrec
8 |
9 | // FIXME: use documentVariable instead of "d"
10 | object PushdownUtils {
11 |
12 | private[commons] def generateColumnsFilter(schema: StructType, documentVariable: String): String =
13 | doGenerateColumnsFilter(schema, s"`$documentVariable`.")
14 |
15 | private def doGenerateColumnsFilter(schema: StructType, ctx: String): String =
16 | s"""{${schema.fields.map(generateFieldFilter(_, ctx)).mkString(",")}}"""
17 |
18 | private def generateFieldFilter(field: StructField, ctx: String): String = {
19 | val fieldName = s"`${field.name}`"
20 | val value = s"$ctx$fieldName"
21 | s"$fieldName:" + (field.dataType match {
22 | case s: StructType => doGenerateColumnsFilter(s, s"$value.")
23 | case _ => value
24 | })
25 | }
26 |
27 | def generateFilterClause(filters: Array[PushableFilter]): String = filters match {
28 | case Array() => ""
29 | case _ => "FILTER " + filters
30 | .filter(_.support != FilterSupport.NONE)
31 | .map(_.aql("d"))
32 | .mkString(" AND ")
33 | }
34 |
35 | def generateRowFilters(filters: Array[Filter], schema: StructType, documentVariable: String = "d"): Array[PushableFilter] =
36 | filters.map(PushableFilter(_, schema))
37 |
38 | @tailrec
39 | private[commons] def getStructField(fieldNameParts: Array[String], fieldSchema: StructField): StructField = fieldNameParts match {
40 | case Array() => fieldSchema
41 | case _ => getStructField(fieldNameParts.tail, fieldSchema.dataType.asInstanceOf[StructType](fieldNameParts.head))
42 | }
43 |
44 | }
45 |
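A small, hypothetical sketch (not part of the repository) of what generateColumnsFilter produces for a nested schema. It is declared in the same package so it can reach the private[commons] method; the expected output is traced from the code above and has the same shape as the RETURN clause shown in the ReadDemo comments later in this dump.

package org.apache.spark.sql.arangodb.commons

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object PushdownUtilsSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Array(
      StructField("title", StringType),
      StructField("name", StructType(Array(
        StructField("first", StringType),
        StructField("last", StringType)
      ))),
      StructField("runtime", IntegerType)
    ))

    // Nested structs are rendered recursively, each leaf prefixed with the document variable:
    // {`title`:`d`.`title`,`name`:{`first`:`d`.`name`.`first`,`last`:`d`.`name`.`last`},`runtime`:`d`.`runtime`}
    println(PushdownUtils.generateColumnsFilter(schema, "d"))
  }
}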
--------------------------------------------------------------------------------
/demo/src/main/scala/Schemas.scala:
--------------------------------------------------------------------------------
1 | import org.apache.spark.sql.types.{DateType, IntegerType, StringType, StructField, StructType, TimestampType}
2 |
3 | object Schemas {
4 | val movieSchema: StructType = StructType(Array(
5 | StructField("_id", StringType, nullable = false),
6 | StructField("_key", StringType, nullable = false),
7 | StructField("description", StringType),
8 | StructField("genre", StringType),
9 | StructField("homepage", StringType),
10 | StructField("imageUrl", StringType),
11 | StructField("imdbId", StringType),
12 | StructField("language", StringType),
13 | StructField("lastModified", TimestampType),
14 | StructField("releaseDate", DateType),
15 | StructField("runtime", IntegerType),
16 | StructField("studio", StringType),
17 | StructField("tagline", StringType),
18 | StructField("title", StringType),
19 | StructField("trailer", StringType)
20 | ))
21 |
22 | val personSchema: StructType = StructType(Array(
23 | StructField("_id", StringType, nullable = false),
24 | StructField("_key", StringType, nullable = false),
25 | StructField("biography", StringType),
26 | StructField("birthday", DateType),
27 | StructField("birthplace", StringType),
28 | StructField("lastModified", TimestampType),
29 | StructField("name", StringType),
30 | StructField("profileImageUrl", StringType)
31 | ))
32 |
33 | val actsInSchema: StructType = StructType(Array(
34 | StructField("_id", StringType, nullable = false),
35 | StructField("_key", StringType, nullable = false),
36 | StructField("_from", StringType, nullable = false),
37 | StructField("_to", StringType, nullable = false),
38 | StructField("name", StringType)
39 | ))
40 |
41 | val directedSchema: StructType = StructType(Array(
42 | StructField("_id", StringType, nullable = false),
43 | StructField("_key", StringType, nullable = false),
44 | StructField("_from", StringType, nullable = false),
45 | StructField("_to", StringType, nullable = false)
46 | ))
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/StringEndsWithFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.StringEndsWith
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class StringEndsWithFilterTest {
9 | private val schema = StructType(Array(
10 | // atomic types
11 | StructField("bool", BooleanType),
12 | StructField("double", DoubleType),
13 | StructField("float", FloatType),
14 | StructField("integer", IntegerType),
15 | StructField("date", DateType),
16 | StructField("timestamp", TimestampType),
17 | StructField("short", ShortType),
18 | StructField("byte", ByteType),
19 | StructField("string", StringType),
20 |
21 | // complex types
22 | StructField("array", ArrayType(StringType)),
23 | StructField("null", NullType),
24 | StructField("struct", StructType(Array(
25 | StructField("a", StringType),
26 | StructField("b", IntegerType)
27 | )))
28 | ))
29 |
30 | @Test
31 | def stringEndsWithStringFilter(): Unit = {
32 | val field = "string"
33 | val value = "str"
34 | val filter = PushableFilter(StringEndsWith(field, value), schema: StructType)
35 | assertThat(filter.support()).isEqualTo(FilterSupport.FULL)
36 | assertThat(filter.aql("d")).isEqualTo(s"""STARTS_WITH(REVERSE(`d`.`$field`), REVERSE("$value"))""")
37 | }
38 |
39 | @Test
40 | def stringEndsWithFilterTimestamp(): Unit = {
41 | val field = "timestamp"
42 | val value = "2001-01-02T15:30:45.678111Z"
43 | val filter = PushableFilter(StringEndsWith(field, value), schema: StructType)
44 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE)
45 | }
46 |
47 | @Test
48 | def stringEndsWithFilterDate(): Unit = {
49 | val field = "date"
50 | val value = "2001-01-02"
51 | val filter = PushableFilter(StringEndsWith(field, value), schema: StructType)
52 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE)
53 | }
54 |
55 | }
56 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.mapping
2 |
3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder
4 | import com.fasterxml.jackson.core.JsonFactoryBuilder
5 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
6 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider}
7 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonGenerator}
8 | import org.apache.spark.sql.types.{DataType, StructType}
9 |
10 | import java.io.OutputStream
11 |
12 | abstract sealed class ArangoGeneratorImpl(
13 | schema: DataType,
14 | writer: OutputStream,
15 | options: JSONOptions)
16 | extends JacksonGenerator(
17 | schema,
18 | options.buildJsonFactory().createGenerator(writer),
19 | options) with ArangoGenerator
20 |
21 | class ArangoGeneratorProviderImpl extends ArangoGeneratorProvider {
22 | override def of(
23 | contentType: ContentType,
24 | schema: StructType,
25 | outputStream: OutputStream,
26 | conf: ArangoDBConf
27 | ): ArangoGeneratorImpl = contentType match {
28 | case ContentType.JSON => new JsonArangoGenerator(schema, outputStream, conf)
29 | case ContentType.VPACK => new VPackArangoGenerator(schema, outputStream, conf)
30 | case _ => throw new IllegalArgumentException
31 | }
32 | }
33 |
34 | class JsonArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf)
35 | extends ArangoGeneratorImpl(
36 | schema,
37 | outputStream,
38 | createOptions(new JsonFactoryBuilder().build(), conf)
39 | )
40 |
41 | class VPackArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf)
42 | extends ArangoGeneratorImpl(
43 | schema,
44 | outputStream,
45 | createOptions(new VPackFactoryBuilder().build(), conf)
46 | )
47 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.mapping
2 |
3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder
4 | import com.fasterxml.jackson.core.JsonFactoryBuilder
5 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
6 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider}
7 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonGenerator}
8 | import org.apache.spark.sql.types.{DataType, StructType}
9 |
10 | import java.io.OutputStream
11 |
12 | abstract sealed class ArangoGeneratorImpl(
13 | schema: DataType,
14 | writer: OutputStream,
15 | options: JSONOptions)
16 | extends JacksonGenerator(
17 | schema,
18 | options.buildJsonFactory().createGenerator(writer),
19 | options) with ArangoGenerator
20 |
21 | class ArangoGeneratorProviderImpl extends ArangoGeneratorProvider {
22 | override def of(
23 | contentType: ContentType,
24 | schema: StructType,
25 | outputStream: OutputStream,
26 | conf: ArangoDBConf
27 | ): ArangoGeneratorImpl = contentType match {
28 | case ContentType.JSON => new JsonArangoGenerator(schema, outputStream, conf)
29 | case ContentType.VPACK => new VPackArangoGenerator(schema, outputStream, conf)
30 | case _ => throw new IllegalArgumentException
31 | }
32 | }
33 |
34 | class JsonArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf)
35 | extends ArangoGeneratorImpl(
36 | schema,
37 | outputStream,
38 | createOptions(new JsonFactoryBuilder().build(), conf)
39 | )
40 |
41 | class VPackArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf)
42 | extends ArangoGeneratorImpl(
43 | schema,
44 | outputStream,
45 | createOptions(new VPackFactoryBuilder().build(), conf)
46 | )
47 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/StringContainsFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.StringContains
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class StringContainsFilterTest {
9 | private val schema = StructType(Array(
10 | // atomic types
11 | StructField("bool", BooleanType),
12 | StructField("double", DoubleType),
13 | StructField("float", FloatType),
14 | StructField("integer", IntegerType),
15 | StructField("date", DateType),
16 | StructField("timestamp", TimestampType),
17 | StructField("short", ShortType),
18 | StructField("byte", ByteType),
19 | StructField("string", StringType),
20 |
21 | // complex types
22 | StructField("array", ArrayType(StringType)),
23 | StructField("intMap", MapType(StringType, IntegerType)),
24 | StructField("null", NullType),
25 | StructField("struct", StructType(Array(
26 | StructField("a", StringType),
27 | StructField("b", IntegerType)
28 | )))
29 | ))
30 |
31 | @Test
32 | def stringContainsStringFilter(): Unit = {
33 | val field = "string"
34 | val value = "str"
35 | val filter = PushableFilter(StringContains(field, value), schema: StructType)
36 | assertThat(filter.support()).isEqualTo(FilterSupport.FULL)
37 | assertThat(filter.aql("d")).isEqualTo(s"""CONTAINS(`d`.`$field`, "$value")""")
38 | }
39 |
40 | @Test
41 | def stringContainsFilterTimestamp(): Unit = {
42 | val field = "timestamp"
43 | val value = "2001-01-02T15:30:45.678111Z"
44 | val filter = PushableFilter(StringContains(field, value), schema: StructType)
45 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE)
46 | }
47 |
48 | @Test
49 | def stringContainsFilterDate(): Unit = {
50 | val field = "date"
51 | val value = "2001-01-02"
52 | val filter = PushableFilter(StringContains(field, value), schema: StructType)
53 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE)
54 | }
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/StringStartsWithFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.StringStartsWith
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class StringStartsWithFilterTest {
9 | private val schema = StructType(Array(
10 | // atomic types
11 | StructField("bool", BooleanType),
12 | StructField("double", DoubleType),
13 | StructField("float", FloatType),
14 | StructField("integer", IntegerType),
15 | StructField("date", DateType),
16 | StructField("timestamp", TimestampType),
17 | StructField("short", ShortType),
18 | StructField("byte", ByteType),
19 | StructField("string", StringType),
20 |
21 | // complex types
22 | StructField("array", ArrayType(StringType)),
23 | StructField("intMap", MapType(StringType, IntegerType)),
24 | StructField("null", NullType),
25 | StructField("struct", StructType(Array(
26 | StructField("a", StringType),
27 | StructField("b", IntegerType)
28 | )))
29 | ))
30 |
31 | @Test
32 | def stringStartsWithStringFilter(): Unit = {
33 | val field = "string"
34 | val value = "str"
35 | val filter = PushableFilter(StringStartsWith(field, value), schema: StructType)
36 | assertThat(filter.support()).isEqualTo(FilterSupport.FULL)
37 | assertThat(filter.aql("d")).isEqualTo(s"""STARTS_WITH(`d`.`$field`, "$value")""")
38 | }
39 |
40 | @Test
41 | def stringStartsWithFilterTimestamp(): Unit = {
42 | val field = "timestamp"
43 | val value = "2001-01-02T15:30:45.678111Z"
44 | val filter = PushableFilter(StringStartsWith(field, value), schema: StructType)
45 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE)
46 | }
47 |
48 | @Test
49 | def stringStartsWithFilterDate(): Unit = {
50 | val field = "date"
51 | val value = "2001-01-02"
52 | val filter = PushableFilter(StringStartsWith(field, value), schema: StructType)
53 | assertThat(filter.support()).isEqualTo(FilterSupport.NONE)
54 | }
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/examples/DataTypesExample.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.examples
2 |
3 | import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
4 |
5 | import java.sql.{Date, Timestamp}
6 | import java.time.{LocalDate, LocalDateTime}
7 |
8 | object DataTypesExample {
9 |
10 | final case class Order(
11 | userId: String,
12 | price: Double,
13 | shipped: Boolean,
14 | totItems: Int,
15 | creationDate: Date,
16 | lastModifiedTs: Timestamp,
17 | itemIds: List[String],
18 | qty: Map[String, Int]
19 | )
20 |
21 | def main(args: Array[String]): Unit = {
22 | val spark: SparkSession = SparkSession.builder()
23 | .appName("ArangoDBSparkTest")
24 | .master("local[*, 3]")
25 | .config("spark.driver.host", "127.0.0.1")
26 | .getOrCreate()
27 |
28 | import spark.implicits._
29 |
30 | val o = Order(
31 | userId = "Mike",
32 | price = 9.99,
33 | shipped = true,
34 | totItems = 2,
35 | creationDate = Date.valueOf(LocalDate.now()),
36 | lastModifiedTs = Timestamp.valueOf(LocalDateTime.now()),
37 | itemIds = List("itm1", "itm2"),
38 | qty = Map("itm1" -> 1, "itm2" -> 3)
39 | )
40 |
41 | val ds = Seq(o).toDS()
42 |
43 | ds.show()
44 | ds.printSchema()
45 |
46 | ds
47 | .write
48 | .mode("overwrite")
49 | .format("com.arangodb.spark")
50 | .option("password", "test")
51 | .option("endpoints", "172.28.0.1:8529")
52 | .option("table", "orders")
53 | .option("confirmTruncate", "true")
54 | .save()
55 |
56 | val readDS: Dataset[Order] = spark.read
57 | .format("com.arangodb.spark")
58 | .option("password", "test")
59 | .option("endpoints", "172.28.0.1:8529")
60 | .option("table", "orders")
61 | .schema(Encoders.product[Order].schema)
62 | .load()
63 | .as[Order]
64 |
65 | readDS.show()
66 | readDS.printSchema()
67 |
68 | assert(readDS.collect().head == o)
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/demo/docker/start_spark.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker network create arangodb --subnet 172.28.0.0/16
4 |
5 | docker run -d --network arangodb --ip 172.28.10.1 --name spark-master -h spark-master \
6 | -e SPARK_MODE=master \
7 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \
8 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \
9 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
10 | -e SPARK_SSL_ENABLED=no \
11 | -v $(pwd)/docker/import:/import \
12 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \
13 | docker.io/bitnamilegacy/spark:3.5.6
14 |
15 | docker run -d --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 \
16 | -e SPARK_MODE=worker \
17 | -e SPARK_MASTER_URL=spark://spark-master:7077 \
18 | -e SPARK_WORKER_MEMORY=1G \
19 | -e SPARK_WORKER_CORES=1 \
20 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \
21 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \
22 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
23 | -e SPARK_SSL_ENABLED=no \
24 | -v $(pwd)/docker/import:/import \
25 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \
26 | docker.io/bitnamilegacy/spark:3.5.6
27 |
28 | docker run -d --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 \
29 | -e SPARK_MODE=worker \
30 | -e SPARK_MASTER_URL=spark://spark-master:7077 \
31 | -e SPARK_WORKER_MEMORY=1G \
32 | -e SPARK_WORKER_CORES=1 \
33 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \
34 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \
35 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
36 | -e SPARK_SSL_ENABLED=no \
37 | -v $(pwd)/docker/import:/import \
38 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \
39 | docker.io/bitnamilegacy/spark:3.5.6
40 |
41 | docker run -d --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 \
42 | -e SPARK_MODE=worker \
43 | -e SPARK_MASTER_URL=spark://spark-master:7077 \
44 | -e SPARK_WORKER_MEMORY=1G \
45 | -e SPARK_WORKER_CORES=1 \
46 | -e SPARK_RPC_AUTHENTICATION_ENABLED=no \
47 | -e SPARK_RPC_ENCRYPTION_ENABLED=no \
48 | -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \
49 | -e SPARK_SSL_ENABLED=no \
50 | -v $(pwd)/docker/import:/import \
51 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \
52 | docker.io/bitnamilegacy/spark:3.5.6
53 |
--------------------------------------------------------------------------------
/demo/python-demo/schemas.py:
--------------------------------------------------------------------------------
1 | from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DateType, IntegerType
2 |
3 | movie_schema: StructType = StructType([
4 | StructField("_id", StringType(), nullable=False),
5 | StructField("_key", StringType(), nullable=False),
6 | StructField("description", StringType()),
7 | StructField("genre", StringType()),
8 | StructField("homepage", StringType()),
9 | StructField("imageUrl", StringType()),
10 | StructField("imdbId", StringType()),
11 | StructField("language", StringType()),
12 | StructField("lastModified", TimestampType()),
13 | StructField("releaseDate", DateType()),
14 | StructField("runtime", IntegerType()),
15 | StructField("studio", StringType()),
16 | StructField("tagline", StringType()),
17 | StructField("title", StringType()),
18 | StructField("trailer", StringType())
19 | ])
20 | person_schema: StructType = StructType([
21 | StructField("_id", StringType(), nullable=False),
22 | StructField("_key", StringType(), nullable=False),
23 | StructField("biography", StringType()),
24 | StructField("birthday", DateType()),
25 | StructField("birthplace", StringType()),
26 | StructField("lastModified", TimestampType()),
27 | StructField("name", StringType()),
28 | StructField("profileImageUrl", StringType())
29 | ])
30 | edges_schema: StructType = StructType([
31 | StructField("_key", StringType(), nullable=False),
32 | StructField("_from", StringType(), nullable=False),
33 | StructField("_to", StringType(), nullable=False),
34 | StructField("$label", StringType()),
35 | StructField("name", StringType()),
36 | StructField("type", StringType()),
37 | ])
38 | acts_in_schema: StructType = StructType([
39 | StructField("_id", StringType(), nullable=False),
40 | StructField("_key", StringType(), nullable=False),
41 | StructField("_from", StringType(), nullable=False),
42 | StructField("_to", StringType(), nullable=False),
43 | StructField("name", StringType())
44 | ])
45 | directed_schema: StructType = StructType([
46 | StructField("_id", StringType(), nullable=False),
47 | StructField("_key", StringType(), nullable=False),
48 | StructField("_from", StringType(), nullable=False),
49 | StructField("_to", StringType(), nullable=False)
50 | ])
51 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.mapping
2 |
3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder
4 | import com.fasterxml.jackson.core.json.JsonReadFeature
5 | import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder}
6 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
7 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoParser, ArangoParserProvider, MappingUtils}
8 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonParser}
9 | import org.apache.spark.sql.catalyst.InternalRow
10 | import org.apache.spark.sql.types.DataType
11 | import org.apache.spark.unsafe.types.UTF8String
12 |
13 | abstract sealed class ArangoParserImpl(
14 | schema: DataType,
15 | options: JSONOptions,
16 | recordLiteral: Array[Byte] => UTF8String)
17 | extends JacksonParser(schema, options) with ArangoParser {
18 | override def parse(data: Array[Byte]): Iterable[InternalRow] = super.parse(
19 | data,
20 | (jsonFactory: JsonFactory, record: Array[Byte]) => jsonFactory.createParser(record),
21 | recordLiteral
22 | )
23 | }
24 |
25 | class ArangoParserProviderImpl extends ArangoParserProvider {
26 | override def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParserImpl = contentType match {
27 | case ContentType.JSON => new JsonArangoParser(schema, conf)
28 | case ContentType.VPACK => new VPackArangoParser(schema, conf)
29 | case _ => throw new IllegalArgumentException
30 | }
31 | }
32 |
33 | class JsonArangoParser(schema: DataType, conf: ArangoDBConf)
34 | extends ArangoParserImpl(
35 | schema,
36 | createOptions(new JsonFactoryBuilder()
37 | .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, true)
38 | .build(), conf),
39 | (bytes: Array[Byte]) => UTF8String.fromBytes(bytes)
40 | )
41 |
42 | class VPackArangoParser(schema: DataType, conf: ArangoDBConf)
43 | extends ArangoParserImpl(
44 | schema,
45 | createOptions(new VPackFactoryBuilder().build(), conf),
46 | (bytes: Array[Byte]) => UTF8String.fromString(MappingUtils.vpackToJson(bytes))
47 | )
48 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.mapping
2 |
3 | import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder
4 | import com.fasterxml.jackson.core.json.JsonReadFeature
5 | import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder}
6 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
7 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoParser, ArangoParserProvider, MappingUtils}
8 | import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonParser}
9 | import org.apache.spark.sql.catalyst.InternalRow
10 | import org.apache.spark.sql.types.DataType
11 | import org.apache.spark.unsafe.types.UTF8String
12 |
13 | abstract sealed class ArangoParserImpl(
14 | schema: DataType,
15 | options: JSONOptions,
16 | recordLiteral: Array[Byte] => UTF8String)
17 | extends JacksonParser(schema, options) with ArangoParser {
18 | override def parse(data: Array[Byte]): Iterable[InternalRow] = super.parse(
19 | data,
20 | (jsonFactory: JsonFactory, record: Array[Byte]) => jsonFactory.createParser(record),
21 | recordLiteral
22 | )
23 | }
24 |
25 | class ArangoParserProviderImpl extends ArangoParserProvider {
26 | override def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParserImpl = contentType match {
27 | case ContentType.JSON => new JsonArangoParser(schema, conf)
28 | case ContentType.VPACK => new VPackArangoParser(schema, conf)
29 | case _ => throw new IllegalArgumentException
30 | }
31 | }
32 |
33 | class JsonArangoParser(schema: DataType, conf: ArangoDBConf)
34 | extends ArangoParserImpl(
35 | schema,
36 | createOptions(new JsonFactoryBuilder()
37 | .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, true)
38 | .build(), conf),
39 | (bytes: Array[Byte]) => UTF8String.fromBytes(bytes)
40 | )
41 |
42 | class VPackArangoParser(schema: DataType, conf: ArangoDBConf)
43 | extends ArangoParserImpl(
44 | schema,
45 | createOptions(new VPackFactoryBuilder().build(), conf),
46 | (bytes: Array[Byte]) => UTF8String.fromString(MappingUtils.vpackToJson(bytes))
47 | )
48 |
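The 3.4 and 3.5 variants of ArangoParserImpl differ only in the Spark version they are compiled against; in both, the provider selects a JSON or VPACK parser from the configured content type. Below is a minimal, hypothetical sketch (not part of the repository) of that selection, using the same provider API exercised by JacksonTest in the integration tests.

import java.nio.charset.StandardCharsets

import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object ParserSelectionSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Array(StructField("name", StringType)))

    // ContentType.JSON yields a JsonArangoParser, ContentType.VPACK a VPackArangoParser
    val parser = ArangoParserProvider().of(ContentType.JSON, schema, ArangoDBConf())
    val rows = parser.parse("""{"name":"Roseline"}""".getBytes(StandardCharsets.UTF_8))

    // Each InternalRow holds the parsed fields in schema order
    rows.foreach(row => println(row.getUTF8String(0)))
  }
}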
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/OrFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.{And, Or, EqualTo}
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class OrFilterTest {
9 | private val schema = StructType(Array(
10 | StructField("integer", IntegerType),
11 | StructField("string", StringType),
12 | StructField("binary", BinaryType)
13 | ))
14 |
15 | // FilterSupport.FULL
16 | private val f1 = EqualTo("string", "str")
17 | private val pushF1 = PushableFilter(f1, schema)
18 |
19 | // FilterSupport.NONE
20 | private val f2 = EqualTo("binary", Array(Byte.MaxValue))
21 |
22 | // FilterSupport.PARTIAL
23 | private val f3 = And(f1, f2)
24 |
25 | @Test
26 | def orFilterSupportFullFull(): Unit = {
27 | val orFilter = PushableFilter(Or(f1, f1), schema)
28 | assertThat(orFilter.support()).isEqualTo(FilterSupport.FULL)
29 | assertThat(orFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} OR ${pushF1.aql("d")})""")
30 | }
31 |
32 | @Test
33 | def orFilterSupportFullNone(): Unit = {
34 | val orFilter = PushableFilter(Or(f1, f2), schema)
35 | assertThat(orFilter.support()).isEqualTo(FilterSupport.NONE)
36 | }
37 |
38 | @Test
39 | def orFilterSupportFullPartial(): Unit = {
40 | val orFilter = PushableFilter(Or(f1, f3), schema)
41 | assertThat(orFilter.support()).isEqualTo(FilterSupport.PARTIAL)
42 | assertThat(orFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} OR (${pushF1.aql("d")}))""")
43 | }
44 |
45 | @Test
46 | def orFilterSupportPartialPartial(): Unit = {
47 | val orFilter = PushableFilter(Or(f3, f3), schema)
48 | assertThat(orFilter.support()).isEqualTo(FilterSupport.PARTIAL)
49 | assertThat(orFilter.aql("d")).isEqualTo(s"""((${pushF1.aql("d")}) OR (${pushF1.aql("d")}))""")
50 | }
51 |
52 | @Test
53 | def orFilterSupportPartialNone(): Unit = {
54 | val orFilter = PushableFilter(Or(f3, f2), schema)
55 | assertThat(orFilter.support()).isEqualTo(FilterSupport.NONE)
56 | }
57 |
58 | @Test
59 | def orFilterSupportNoneNone(): Unit = {
60 | val orFilter = PushableFilter(Or(f2, f2), schema)
61 | assertThat(orFilter.support()).isEqualTo(FilterSupport.NONE)
62 | }
63 |
64 | }
65 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | arangodb-spark-datasource
7 | com.arangodb
8 | 1.9.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version}
13 |
14 | arangodb-spark-commons
15 | ArangoDB Spark Datasource Commons
16 | https://github.com/arangodb/arangodb-spark-datasource
17 |
18 |
19 |
20 | Michele Rastelli
21 | https://github.com/rashtao
22 |
23 |
24 |
25 |
26 | https://github.com/arangodb/arangodb-spark-datasource
27 |
28 |
29 |
30 | false
31 | ../integration-tests/target/site/jacoco-aggregate/jacoco.xml
32 | false
33 |
34 |
35 |
36 |
37 |
38 | org.jacoco
39 | jacoco-maven-plugin
40 |
41 |
42 | report-aggregate
43 | verify
44 |
45 | report-aggregate
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 | org.apache.maven.plugins
57 | maven-surefire-report-plugin
58 | 3.3.0
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.reader
2 |
3 | import com.arangodb.entity.CursorWarning
4 | import org.apache.spark.internal.Logging
5 | import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
6 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
7 | import org.apache.spark.sql.catalyst.InternalRow
8 | import org.apache.spark.sql.catalyst.util.FailureSafeParser
9 | import org.apache.spark.sql.connector.read.PartitionReader
10 | import org.apache.spark.sql.types._
11 |
12 | import scala.annotation.tailrec
13 | import scala.collection.JavaConverters.iterableAsScalaIterableConverter
14 |
15 |
16 | class ArangoQueryReader(schema: StructType, options: ArangoDBConf) extends PartitionReader[InternalRow] with Logging {
17 |
18 | private val actualSchema = StructType(schema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord))
19 | private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options)
20 | private val safeParser = new FailureSafeParser[Array[Byte]](
21 | parser.parse,
22 | options.readOptions.parseMode,
23 | schema,
24 | options.readOptions.columnNameOfCorruptRecord)
25 | private val client = ArangoClient(options)
26 | private val iterator = client.readQuery()
27 |
28 | var rowIterator: Iterator[InternalRow] = _
29 |
30 | // warnings of non-stream AQL cursors are all returned along with the first batch
31 | if (!options.readOptions.stream) logWarns()
32 |
33 | @tailrec
34 | final override def next: Boolean =
35 | if (iterator.hasNext) {
36 | val current = iterator.next()
37 | rowIterator = safeParser.parse(current.get)
38 | if (rowIterator.hasNext) {
39 | true
40 | } else {
41 | next
42 | }
43 | } else {
44 | // FIXME: https://arangodb.atlassian.net/browse/BTS-671
45 | // stream AQL cursors' warnings are only returned along with the final batch
46 | if (options.readOptions.stream) logWarns()
47 | false
48 | }
49 |
50 | override def get: InternalRow = rowIterator.next()
51 |
52 | override def close(): Unit = {
53 | iterator.close()
54 | client.shutdown()
55 | }
56 |
57 | private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) =>
58 | logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}")
59 | ))
60 |
61 | }
62 |
63 |
64 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/test/scala/org/apache/spark/sql/arangodb/commons/filter/AndFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons.filter
2 |
3 | import org.apache.spark.sql.sources.{EqualTo, And}
4 | import org.apache.spark.sql.types._
5 | import org.assertj.core.api.Assertions.assertThat
6 | import org.junit.jupiter.api.Test
7 |
8 | class AndFilterTest {
9 | private val schema = StructType(Array(
10 | StructField("integer", IntegerType),
11 | StructField("string", StringType),
12 | StructField("binary", BinaryType)
13 | ))
14 |
15 | // FilterSupport.FULL
16 | private val f1 = EqualTo("string", "str")
17 | private val pushF1 = PushableFilter(f1, schema)
18 |
19 | // FilterSupport.NONE
20 | private val f2 = EqualTo("binary", Array(Byte.MaxValue))
21 |
22 | // FilterSupport.PARTIAL
23 | private val f3 = And(f1, f2)
24 |
25 | @Test
26 | def andFilterSupportFullFull(): Unit = {
27 | val andFilter = PushableFilter(And(f1, f1), schema)
28 | assertThat(andFilter.support()).isEqualTo(FilterSupport.FULL)
29 | assertThat(andFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} AND ${pushF1.aql("d")})""")
30 | }
31 |
32 | @Test
33 | def andFilterSupportFullNone(): Unit = {
34 | val andFilter = PushableFilter(And(f1, f2), schema)
35 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL)
36 | assertThat(andFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")})""")
37 | }
38 |
39 | @Test
40 | def andFilterSupportFullPartial(): Unit = {
41 | val andFilter = PushableFilter(And(f1, f3), schema)
42 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL)
43 | assertThat(andFilter.aql("d")).isEqualTo(s"""(${pushF1.aql("d")} AND (${pushF1.aql("d")}))""")
44 | }
45 |
46 | @Test
47 | def andFilterSupportPartialPartial(): Unit = {
48 | val andFilter = PushableFilter(And(f3, f3), schema)
49 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL)
50 | assertThat(andFilter.aql("d")).isEqualTo(s"""((${pushF1.aql("d")}) AND (${pushF1.aql("d")}))""")
51 | }
52 |
53 | @Test
54 | def andFilterSupportPartialNone(): Unit = {
55 | val andFilter = PushableFilter(And(f3, f2), schema)
56 | assertThat(andFilter.support()).isEqualTo(FilterSupport.PARTIAL)
57 | assertThat(andFilter.aql("d")).isEqualTo(s"""((${pushF1.aql("d")}))""")
58 | }
59 |
60 | @Test
61 | def andFilterSupportNoneNone(): Unit = {
62 | val andFilter = PushableFilter(And(f2, f2), schema)
63 | assertThat(andFilter.support()).isEqualTo(FilterSupport.NONE)
64 | }
65 |
66 | }
67 |
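The AND/OR support rules verified in these tests determine what reaches AQL: PushdownUtils.generateFilterClause discards filters whose support is NONE and joins the remaining ones with AND, so a partially supported conjunction is pushed down only for its supported side while the remaining filtering is left to Spark. A hypothetical sketch (not part of the repository, placed in the same package as the tests) of that interplay:

package org.apache.spark.sql.arangodb.commons.filter

import org.apache.spark.sql.arangodb.commons.PushdownUtils
import org.apache.spark.sql.sources.{And, EqualTo, Filter}
import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}

object FilterClauseSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Array(
      StructField("string", StringType),
      StructField("binary", BinaryType)
    ))

    val supported = EqualTo("string", "str")                  // FilterSupport.FULL
    val unsupported = EqualTo("binary", Array(Byte.MaxValue)) // FilterSupport.NONE

    // And(FULL, NONE) is PARTIAL: only the supported side is rendered in the FILTER clause
    val pushable = PushdownUtils.generateRowFilters(Array[Filter](And(supported, unsupported)), schema)
    println(PushdownUtils.generateFilterClause(pushable))
  }
}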
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/filter/package.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.commons
2 |
3 | import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
4 | import org.apache.spark.sql.types._
5 |
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | package object filter {
9 |
10 | private[filter] def splitAttributeNameParts(attribute: String): Array[String] = {
11 | val parts = new ArrayBuffer[String]()
12 | var sb = new StringBuilder()
13 | var inEscapedBlock = false
14 | for (c <- attribute.toCharArray) {
15 | if (c == '`') inEscapedBlock = !inEscapedBlock
16 | if (c == '.' && !inEscapedBlock) {
17 | parts += sb.toString()
18 | sb = new StringBuilder()
19 | } else if (c != '`') {
20 | sb.append(c)
21 | }
22 | }
23 | parts += sb.toString()
24 | parts.toArray
25 | }
26 |
27 | private[filter] def isTypeAqlCompatible(t: AbstractDataType): Boolean = t match {
28 | // atomic types
29 | case _:
30 | StringType
31 | | BooleanType
32 | | FloatType
33 | | DoubleType
34 | | IntegerType
35 | | LongType
36 | | ShortType
37 | | ByteType
38 | => true
39 | case _:
40 | DateType
41 | | TimestampType
42 | => false
43 | // complex types
44 | case _: NullType => true
45 | case at: ArrayType => isTypeAqlCompatible(at.elementType)
46 | case st: StructType => st.forall(f => isTypeAqlCompatible(f.dataType))
47 | case mt: MapType => mt.keyType == StringType && isTypeAqlCompatible(mt.valueType)
48 | case _ => false
49 | }
50 |
51 | private[filter] def getValue(t: AbstractDataType, v: Any): String = t match {
52 | case NullType => "null"
53 | case _: TimestampType | DateType | StringType => s""""$v""""
54 | case _: BooleanType | FloatType | DoubleType | IntegerType | LongType | ShortType | ByteType => v.toString
55 | case at: ArrayType => s"""[${v.asInstanceOf[Traversable[Any]].map(getValue(at.elementType, _)).mkString(",")}]"""
56 | case _: StructType =>
57 | val row = v.asInstanceOf[GenericRowWithSchema]
58 | val parts = row.values.zip(row.schema).map(sf =>
59 | s""""${sf._2.name}":${getValue(sf._2.dataType, sf._1)}"""
60 | )
61 | s"{${parts.mkString(",")}}"
62 | case mt: MapType =>
63 | v.asInstanceOf[Map[String, Any]].map(it => {
64 | s"""${getValue(mt.keyType, it._1)}:${getValue(mt.valueType, it._2)}"""
65 | }).mkString("{", ",", "}")
66 | }
67 |
68 | }
69 |
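Two helpers in this package object are easy to misread, so here is a tiny hypothetical sketch (not part of the repository, declared in the same package since they are private[filter]): backtick-escaped attribute segments keep their inner dots, and getValue renders Scala values as AQL literals. The commented results are traced from the code above.

package org.apache.spark.sql.arangodb.commons.filter

import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

object FilterHelpersSketch {
  def main(args: Array[String]): Unit = {
    // Prints: person | full.name | first  (the escaped segment keeps its dot)
    println(splitAttributeNameParts("person.`full.name`.first").mkString(" | "))

    // AQL literal rendering
    println(getValue(StringType, "str"))                     // "str"
    println(getValue(IntegerType, 42))                       // 42
    println(getValue(ArrayType(StringType), Seq("a", "b")))  // ["a","b"]
  }
}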
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/JacksonTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb
2 |
3 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType}
4 | import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGeneratorProvider, ArangoParserProvider, MappingUtils}
5 | import org.apache.spark.sql.types._
6 | import org.assertj.core.api.Assertions.assertThat
7 | import org.junit.jupiter.api.Test
8 |
9 | import java.io.ByteArrayOutputStream
10 | import java.nio.charset.StandardCharsets
11 |
12 | /**
13 | * @author Michele Rastelli
14 | */
15 | class JacksonTest {
16 | private val jsonString =
17 | """
18 | |{
19 | | "birthday": "1964-01-02",
20 | | "gender": "female",
21 | | "likes": [
22 | | "swimming"
23 | | ],
24 | | "name": {
25 | | "first": "Roseline",
26 | | "last": "Jucean"
27 | | },
28 | | "nullString": null,
29 | | "nullField": null,
30 | | "mapField": {
31 | | "foo": 1,
32 | | "bar": 2
33 | | }
34 | |}
35 | |""".stripMargin.replaceAll("\\s", "")
36 |
37 | private val jsonBytes = jsonString.getBytes(StandardCharsets.UTF_8)
38 | private val vpackBytes = MappingUtils.jsonToVPack(jsonString)
39 |
40 | private val schema: StructType = new StructType(
41 | Array(
42 | StructField("birthday", DateType),
43 | StructField("gender", StringType),
44 | StructField("likes", ArrayType(StringType)),
45 | StructField("name", StructType(
46 | Array(
47 | StructField("first", StringType),
48 | StructField("last", StringType)
49 | )
50 | )),
51 | StructField("nullString", StringType, nullable = true),
52 | StructField("nullField", NullType),
53 | StructField("mapField", MapType(StringType, IntegerType))
54 | )
55 | )
56 |
57 | @Test
58 | def jsonRoundTrip(): Unit = {
59 | roundTrip(ContentType.JSON, jsonBytes)
60 | }
61 |
62 | @Test
63 | def vpackRoundTrip(): Unit = {
64 | roundTrip(ContentType.VPACK, vpackBytes)
65 | }
66 |
67 | private def roundTrip(contentType: ContentType, data: Array[Byte]): Unit = {
68 | val parser = ArangoParserProvider().of(contentType, schema, ArangoDBConf())
69 | val parsed = parser.parse(data)
70 | val output = new ByteArrayOutputStream()
71 | val generator = ArangoGeneratorProvider().of(contentType, schema, output, ArangoDBConf())
72 | generator.write(parsed.head)
73 | generator.close()
74 | assertThat(output.toByteArray).isEqualTo(data)
75 | }
76 |
77 | }
78 |
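JacksonTest prepares its VPack fixture with MappingUtils, the same helper the VPack parser above uses to render records. A minimal hypothetical sketch (not part of the repository) of that conversion in isolation, assuming the round trip of a compact single-field document reproduces the input:

import org.apache.spark.sql.arangodb.commons.mapping.MappingUtils

object VPackConversionSketch {
  def main(args: Array[String]): Unit = {
    val json = """{"name":"Roseline"}"""
    val vpack: Array[Byte] = MappingUtils.jsonToVPack(json)

    // Expected to print the compact JSON again, mirroring the symmetry relied upon in JacksonTest
    println(MappingUtils.vpackToJson(vpack))
  }
}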
--------------------------------------------------------------------------------
/demo/src/main/scala/Demo.scala:
--------------------------------------------------------------------------------
1 | import org.apache.spark.sql.SparkSession
2 |
3 | object Demo {
4 | val importPath: String = System.getProperty("importPath", "/demo/docker/import")
5 | val password: String = System.getProperty("password", "test")
6 | val endpoints: String = System.getProperty("endpoints", "172.28.0.1:8529,172.28.0.1:8539,172.28.0.1:8549")
7 | val sslEnabled: String = System.getProperty("ssl.enabled", "true")
8 | val sslCertValue: String = System.getProperty("ssl.cert.value", "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlekNDQW1PZ0F3SUJBZ0lFZURDelh6QU5CZ2txaGtpRzl3MEJBUXNGQURCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdIaGNOTWpBeE1UQXhNVGcxTVRFNVdoY05NekF4TURNd01UZzFNVEU1V2pCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdnZ0VpTUEwR0NTcUdTSWIzRFFFQkFRVUFBNElCRHdBd2dnRUtBb0lCQVFDMVdpRG5kNCt1Q21NRzUzOVoKTlpCOE53STBSWkYzc1VTUUdQeDNsa3FhRlRaVkV6TVpMNzZIWXZkYzlRZzdkaWZ5S3lRMDlSTFNwTUFMWDlldQpTc2VEN2JaR25mUUg1MkJuS2NUMDllUTN3aDdhVlE1c04yb215Z2RITEM3WDl1c250eEFmdjdOem12ZG9nTlhvCkpReVkvaFNaZmY3UklxV0g4Tm5BVUtranFPZTZCZjVMRGJ4SEtFU21yRkJ4T0NPbmhjcHZaV2V0d3BpUmRKVlAKd1VuNVA4MkNBWnpmaUJmbUJabkI3RDBsKy82Q3Y0ak11SDI2dUFJY2l4blZla0JRemwxUmd3Y3p1aVpmMk1HTwo2NHZETU1KSldFOUNsWkYxdVF1UXJ3WEY2cXdodVAxSG5raWk2d05iVHRQV2xHU2txZXV0cjAwNCtIemJmOEtuClJZNFBBZ01CQUFHaklUQWZNQjBHQTFVZERnUVdCQlRCcnY5QXd5bnQzQzVJYmFDTnlPVzV2NEROa1RBTkJna3EKaGtpRzl3MEJBUXNGQUFPQ0FRRUFJbTlyUHZEa1lwbXpwU0loUjNWWEc5WTcxZ3hSRHJxa0VlTHNNb0V5cUdudwovengxYkRDTmVHZzJQbmNMbFc2elRJaXBFQm9vaXhJRTlVN0t4SGdaeEJ5MEV0NkVFV3ZJVW1ucjZGNEYrZGJUCkQwNTBHSGxjWjdlT2VxWVRQWWVRQzUwMkcxRm80dGROaTRsRFA5TDlYWnBmN1ExUWltUkgycWFMUzAzWkZaYTIKdFk3YWgvUlFxWkw4RGt4eDgvemMyNXNnVEhWcHhvSzg1M2dsQlZCcy9FTk1peUdKV21BWFFheWV3WTNFUHQvOQp3R3dWNEttVTNkUERsZVFlWFNVR1BVSVNlUXhGankrakN3MjFwWXZpV1ZKVE5CQTlsNW55M0doRW1jbk9UL2dRCkhDdlZSTHlHTE1iYU1aNEpyUHdiK2FBdEJncmdlaUs0eGVTTU12cmJodz09Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K")
9 |
10 | val spark: SparkSession = SparkSession.builder
11 | .appName("arangodb-demo")
12 | .master("local[*, 3]")
13 | .getOrCreate
14 |
15 | val options: Map[String, String] = Map(
16 | "password" -> password,
17 | "endpoints" -> endpoints,
18 | "ssl.enabled" -> sslEnabled,
19 | "ssl.cert.value" -> sslCertValue,
20 | "ssl.verifyHost" -> "false"
21 | )
22 |
23 | def main(args: Array[String]): Unit = {
24 | WriteDemo.writeDemo()
25 | ReadDemo.readDemo()
26 | ReadWriteDemo.readWriteDemo()
27 | spark.stop
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/demo/python-demo/read_demo.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import pyspark.sql
4 | from pyspark.sql import SparkSession
5 | from pyspark.sql.types import StructType, StructField, StringType
6 |
7 | from schemas import movie_schema
8 | from utils import combine_dicts
9 |
10 |
11 | def read_demo(spark: SparkSession, base_opts: Dict[str, str]):
12 | movies_df = read_collection(spark, "movies", base_opts, movie_schema)
13 |
14 | print("Read table: history movies or documentaries about 'World War' released from 2000-01-01")
15 | # We can get to what we want in 2 different ways:
16 | # First, the PySpark dataframe way...
17 | movies_df \
18 | .select("title", "releaseDate", "genre", "description") \
19 | .filter("genre IN ('History', 'Documentary') AND description LIKE '%World War%' AND releaseDate > '2000'") \
20 | .show()
21 |
22 | # Second, the Pandas-on-Spark way...
23 | movies_pd_df = movies_df.to_pandas_on_spark()
24 | subset = movies_pd_df[["title", "releaseDate", "genre", "description"]]
25 | recent_ww_movies = subset[subset["genre"].isin(["History", "Documentary"])\
26 | & (subset["releaseDate"] >= '2000')\
27 | & subset["description"].str.contains("World War")]
28 | print(recent_ww_movies)
29 |
30 | print("Read query: actors of movies directed by Clint Eastwood with related movie title and interpreted role")
31 | read_aql_query(
32 | spark,
33 | """WITH movies, persons
34 | FOR v, e, p IN 2 ANY "persons/1062" OUTBOUND directed, INBOUND actedIn
35 | RETURN {movie: p.vertices[1].title, name: v.name, role: p.edges[1].name}
36 | """,
37 | base_opts,
38 | StructType([
39 | StructField("movie", StringType()),
40 | StructField("name", StringType()),
41 | StructField("role", StringType())
42 | ])
43 | ).show(20, 200)
44 |
45 |
46 | def read_collection(spark: SparkSession, collection_name: str, base_opts: Dict[str, str], schema: StructType) -> pyspark.sql.DataFrame:
47 | arangodb_datasource_options = combine_dicts([base_opts, {"table": collection_name}])
48 |
49 | return spark.read \
50 | .format("com.arangodb.spark") \
51 | .options(**arangodb_datasource_options) \
52 | .schema(schema) \
53 | .load()
54 |
55 |
56 | def read_aql_query(spark: SparkSession, query: str, base_opts: Dict[str, str], schema: StructType) -> pyspark.sql.DataFrame:
57 | arangodb_datasource_options = combine_dicts([base_opts, {"query": query}])
58 |
59 | return spark.read \
60 | .format("com.arangodb.spark") \
61 | .options(**arangodb_datasource_options) \
62 | .schema(schema) \
63 | .load()
64 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadSmartEdgeCollectionTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import com.arangodb.entity.EdgeDefinition
4 | import com.arangodb.model.GraphCreateOptions
5 | import org.apache.spark.sql.DataFrame
6 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
7 | import org.assertj.core.api.Assertions.assertThat
8 | import org.junit.jupiter.api.Assumptions.assumeTrue
9 | import org.junit.jupiter.api.BeforeAll
10 | import org.junit.jupiter.params.ParameterizedTest
11 | import org.junit.jupiter.params.provider.MethodSource
12 |
13 | import java.util
14 | import scala.collection.JavaConverters.asJavaIterableConverter
15 | import scala.collection.immutable
16 |
17 | class ReadSmartEdgeCollectionTest extends BaseSparkTest {
18 |
19 | @ParameterizedTest
20 | @MethodSource(Array("provideProtocolAndContentType"))
21 | def readSmartEdgeCollection(protocol: String, contentType: String): Unit = {
22 | val df: DataFrame = spark.read
23 | .format(BaseSparkTest.arangoDatasource)
24 | .options(options + (
25 | ArangoDBConf.COLLECTION -> ReadSmartEdgeCollectionTest.name,
26 | ArangoDBConf.PROTOCOL -> protocol,
27 | ArangoDBConf.CONTENT_TYPE -> contentType
28 | ))
29 | .load()
30 |
31 |
32 | import spark.implicits._
33 | val read = df
34 | .as[Edge]
35 | .collect()
36 |
37 | assertThat(read.map(_.name)).containsAll(ReadSmartEdgeCollectionTest.data.map(d => d("name")).asJava)
38 | }
39 |
40 | }
41 |
42 | object ReadSmartEdgeCollectionTest {
43 | val name = "smartEdgeCol"
44 | val from = s"from-$name"
45 | val to = s"to-$name"
46 |
47 | val data: immutable.Seq[Map[String, String]] = (1 to 10)
48 | .map(x => Map(
49 | "name" -> s"name-$x",
50 | "_from" -> s"$from/a:$x",
51 | "_to" -> s"$to/b:$x"
52 | ))
53 |
54 | @BeforeAll
55 | def init(): Unit = {
56 | assumeTrue(BaseSparkTest.isCluster && BaseSparkTest.isEnterprise)
57 |
58 | if (BaseSparkTest.db.graph(name).exists()) {
59 | BaseSparkTest.db.graph(name).drop(true)
60 | }
61 |
62 | val ed = new EdgeDefinition()
63 | .collection(name)
64 | .from(from)
65 | .to(to)
66 | val opts = new GraphCreateOptions()
67 | .numberOfShards(2)
68 | .isSmart(true)
69 | .smartGraphAttribute("name")
70 | BaseSparkTest.db.createGraph(name, List(ed).asJava.asInstanceOf[util.Collection[EdgeDefinition]], opts)
71 | BaseSparkTest.db.collection(name).insertDocuments(data.asJava.asInstanceOf[util.Collection[Any]])
72 | }
73 | }
74 |
75 | case class Edge(
76 | name: String,
77 | _from: String,
78 | _to: String
79 | )
80 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | // scalastyle:off
19 |
20 | package org.apache.spark.sql.arangodb.datasource.mapping.json
21 |
22 | import com.fasterxml.jackson.core.{JsonParser, JsonToken}
23 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
24 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
25 | import org.apache.spark.sql.errors.QueryErrorsBase
26 | import org.apache.spark.sql.types._
27 |
28 | object JacksonUtils extends QueryErrorsBase {
29 | /**
30 | * Advance the parser until a null or a specific token is found
31 | */
32 | def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
33 | parser.nextToken() match {
34 | case null => false
35 | case x => x != stopOn
36 | }
37 | }
38 |
39 | def verifyType(name: String, dataType: DataType): TypeCheckResult = {
40 | dataType match {
41 | case NullType | _: AtomicType | CalendarIntervalType => TypeCheckSuccess
42 |
43 | case st: StructType =>
44 | st.foldLeft(TypeCheckSuccess: TypeCheckResult) { case (currResult, field) =>
45 | if (currResult.isFailure) currResult else verifyType(field.name, field.dataType)
46 | }
47 |
48 | case at: ArrayType => verifyType(name, at.elementType)
49 |
50 | // For MapType, its keys are treated as a string (i.e. calling `toString`) basically when
51 | // generating JSON, so we only care if the values are valid for JSON.
52 | case mt: MapType => verifyType(name, mt.valueType)
53 |
54 | case udt: UserDefinedType[_] => verifyType(name, udt.sqlType)
55 |
56 | case _ =>
57 | DataTypeMismatch(
58 | errorSubClass = "CANNOT_CONVERT_TO_JSON",
59 | messageParameters = Map(
60 | "name" -> toSQLId(name),
61 | "type" -> toSQLType(dataType)))
62 | }
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | // scalastyle:off
19 |
20 | package org.apache.spark.sql.arangodb.datasource.mapping.json
21 |
22 | import com.fasterxml.jackson.core.{JsonParser, JsonToken}
23 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
24 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
25 | import org.apache.spark.sql.errors.QueryErrorsBase
26 | import org.apache.spark.sql.types._
27 |
28 | object JacksonUtils extends QueryErrorsBase {
29 | /**
30 | * Advance the parser until a null or a specific token is found
31 | */
32 | def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
33 | parser.nextToken() match {
34 | case null => false
35 | case x => x != stopOn
36 | }
37 | }
38 |
39 | def verifyType(name: String, dataType: DataType): TypeCheckResult = {
40 | dataType match {
41 | case NullType | _: AtomicType | CalendarIntervalType => TypeCheckSuccess
42 |
43 | case st: StructType =>
44 | st.foldLeft(TypeCheckSuccess: TypeCheckResult) { case (currResult, field) =>
45 | if (currResult.isFailure) currResult else verifyType(field.name, field.dataType)
46 | }
47 |
48 | case at: ArrayType => verifyType(name, at.elementType)
49 |
50 | // For MapType, its keys are treated as a string (i.e. calling `toString`) basically when
51 | // generating JSON, so we only care if the values are valid for JSON.
52 | case mt: MapType => verifyType(name, mt.valueType)
53 |
54 | case udt: UserDefinedType[_] => verifyType(name, udt.sqlType)
55 |
56 | case _ =>
57 | DataTypeMismatch(
58 | errorSubClass = "CANNOT_CONVERT_TO_JSON",
59 | messageParameters = Map(
60 | "name" -> toSQLId(name),
61 | "type" -> toSQLType(dataType)))
62 | }
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | arangodb-spark-datasource
7 | com.arangodb
8 | 1.9.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | arangodb-spark-datasource-3.4_${scala.compat.version}
13 |
14 | arangodb-spark-datasource-3.4
15 | ArangoDB Datasource for Apache Spark 3.4
16 | https://github.com/arangodb/arangodb-spark-datasource
17 |
18 |
19 |
20 | Michele Rastelli
21 | https://github.com/rashtao
22 |
23 |
24 |
25 |
26 | https://github.com/arangodb/arangodb-spark-datasource
27 |
28 |
29 |
30 | false
31 | ../integration-tests/target/site/jacoco-aggregate/jacoco.xml
32 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/*
33 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/*
34 | false
35 |
36 |
37 |
38 |
39 | com.arangodb
40 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version}
41 | ${project.version}
42 | compile
43 |
44 |
45 |
46 |
47 |
48 |
49 | maven-assembly-plugin
50 |
51 |
52 | jar-with-dependencies
53 |
54 |
55 |
56 |
57 | package
58 |
59 | single
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | arangodb-spark-datasource
7 | com.arangodb
8 | 1.9.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | arangodb-spark-datasource-3.5_${scala.compat.version}
13 |
14 | arangodb-spark-datasource-3.5
15 | ArangoDB Datasource for Apache Spark 3.5
16 | https://github.com/arangodb/arangodb-spark-datasource
17 |
18 |
19 |
20 | Michele Rastelli
21 | https://github.com/rashtao
22 |
23 |
24 |
25 |
26 | https://github.com/arangodb/arangodb-spark-datasource
27 |
28 |
29 |
30 | false
31 | ../integration-tests/target/site/jacoco-aggregate/jacoco.xml
32 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/*
33 | src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/*
34 | false
35 |
36 |
37 |
38 |
39 | com.arangodb
40 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version}
41 | ${project.version}
42 | compile
43 |
44 |
45 |
46 |
47 |
48 |
49 | maven-assembly-plugin
50 |
51 |
52 | jar-with-dependencies
53 |
54 |
55 |
56 |
57 | package
58 |
59 | single
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/demo/src/main/scala/ReadDemo.scala:
--------------------------------------------------------------------------------
1 | import Schemas.movieSchema
2 | import org.apache.spark.sql.DataFrame
3 | import org.apache.spark.sql.types._
4 |
5 | object ReadDemo {
6 |
7 | def readDemo(): Unit = {
8 | println("-----------------")
9 | println("--- READ DEMO ---")
10 | println("-----------------")
11 |
12 | val moviesDF = readTable("movies", movieSchema)
13 |
14 | println("Read table: history movies or documentaries about 'World War' released from 2000-01-01")
15 | moviesDF
16 | .select("title", "releaseDate", "genre", "description")
17 | .filter("genre IN ('History', 'Documentary') AND description LIKE '%World War%' AND releaseDate > '2000'")
18 | .show(20, 200)
19 | /*
20 | Filters and projection pushdowns are applied in this case.
21 |
22 | An info log message like the following will be printed to the console:
23 | > INFO ArangoScanBuilder:57 - Filters fully applied in AQL:
24 | > IsNotNull(description)
25 | > IsNotNull(releaseDate)
26 | > In(genre, [History,Documentary])
27 | > StringContains(description,World War)
28 | > GreaterThan(releaseDate,2000-01-01)
29 |
30 | The generated AQL query will also be logged at debug level:
31 | > DEBUG ArangoClient:61 - Executing AQL query:
32 | > FOR d IN @@col FILTER `d`.`description` != null AND `d`.`releaseDate` != null AND LENGTH(["History","Documentary"][* FILTER `d`.`genre` == CURRENT]) > 0 AND CONTAINS(`d`.`description`, "World War") AND DATE_TIMESTAMP(`d`.`releaseDate`) > DATE_TIMESTAMP("2000-01-01") RETURN {`description`:`d`.`description`,`genre`:`d`.`genre`,`releaseDate`:`d`.`releaseDate`,`title`:`d`.`title`}
33 | > with params: Map(@col -> movies)
34 | */
35 |
36 | println("Read query: actors of movies directed by Clint Eastwood with related movie title and interpreted role")
37 | readQuery(
38 | """WITH movies, persons
39 | |FOR v, e, p IN 2 ANY "persons/1062" OUTBOUND directed, INBOUND actedIn
40 | | RETURN {movie: p.vertices[1].title, name: v.name, role: p.edges[1].name}
41 | |""".stripMargin,
42 | schema = StructType(Array(
43 | StructField("movie", StringType),
44 | StructField("name", StringType),
45 | StructField("role", StringType)
46 | ))
47 | ).show(20, 200)
48 | }
49 |
50 | def readTable(tableName: String, schema: StructType): DataFrame = {
51 | Demo.spark.read
52 | .format("com.arangodb.spark")
53 | .options(Demo.options + ("table" -> tableName))
54 | .schema(schema)
55 | .load
56 | }
57 |
58 | def readQuery(query: String, schema: StructType): DataFrame = {
59 | Demo.spark.read
60 | .format("com.arangodb.spark")
61 | .options(Demo.options + ("query" -> query))
62 | .schema(schema)
63 | .load
64 | }
65 |
66 | }
67 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.reader
2 |
3 | import com.arangodb.entity.CursorWarning
4 | import org.apache.spark.internal.Logging
5 | import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider
6 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
7 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf}
8 | import org.apache.spark.sql.catalyst.InternalRow
9 | import org.apache.spark.sql.catalyst.util.FailureSafeParser
10 | import org.apache.spark.sql.connector.read.PartitionReader
11 | import org.apache.spark.sql.types.StructType
12 |
13 | import scala.annotation.tailrec
14 | import scala.collection.JavaConverters.iterableAsScalaIterableConverter
15 |
16 |
17 | class ArangoCollectionPartitionReader(inputPartition: ArangoCollectionPartition, ctx: PushDownCtx, opts: ArangoDBConf)
18 | extends PartitionReader[InternalRow] with Logging {
19 |
20 | // override endpoints with partition endpoint
21 | private val options = opts.updated(ArangoDBConf.ENDPOINTS, inputPartition.endpoint)
22 | private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord))
23 | private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options)
24 | private val safeParser = new FailureSafeParser[Array[Byte]](
25 | parser.parse,
26 | options.readOptions.parseMode,
27 | ctx.requiredSchema,
28 | options.readOptions.columnNameOfCorruptRecord)
29 | private val client = ArangoClient(options)
30 | private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema)
31 |
32 | var rowIterator: Iterator[InternalRow] = _
33 |
34 | // warnings of non-stream AQL cursors are all returned along with the first batch
35 | if (!options.readOptions.stream) logWarns()
36 |
37 | @tailrec
38 | final override def next: Boolean =
39 | if (iterator.hasNext) {
40 | val current = iterator.next()
41 | rowIterator = safeParser.parse(current.get)
42 | if (rowIterator.hasNext) {
43 | true
44 | } else {
45 | next
46 | }
47 | } else {
48 | // FIXME: https://arangodb.atlassian.net/browse/BTS-671
49 | // stream AQL cursors' warnings are only returned along with the final batch
50 | if (options.readOptions.stream) logWarns()
51 | false
52 | }
53 |
54 | override def get: InternalRow = rowIterator.next()
55 |
56 | override def close(): Unit = {
57 | iterator.close()
58 | client.shutdown()
59 | }
60 |
61 | private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) =>
62 | logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}")
63 | ))
64 |
65 | }
66 |
--------------------------------------------------------------------------------
/ChangeLog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres
4 | to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
5 |
6 | ## [Unreleased]
7 |
8 | ## [1.8.0] - 2024-09-23
9 |
10 | - updated Java Driver to version `7.9.0`
11 | - added support to load SSL Trust Store from file (DE-502, #61)
12 | - dropped support for Spark 3.2
13 |
14 | ## [1.7.0] - 2024-07-04
15 |
16 | - added support for Spark 3.5 (#58)
17 | - added support for protocol `http2` (DE-596, #55)
18 | - added configuration `ssl.verifyHost` to disable TLS hostname verification (DE-790, #54)
19 | - updated `arangodb-java-driver` to version `7.7.1`
20 | - dropped support for Spark 3.1
21 |
22 | ## [1.6.0] - 2024-03-20
23 |
24 | - added support for query `ttl` (#50)
25 | - updated Java Driver to version `7.5.1`
26 | - dropped support for Spark 2.4
27 |
28 | ## [1.5.1] - 2023-09-25
29 |
30 | - fixed reading from smart edge collections (#47)
31 |
32 | ## [1.5.0] - 2023-05-31
33 |
34 | - support for Spark 3.4 (#45)
35 | - support for Spark 3.3 (#44)
36 | - updated Java Driver to version `7.0` (#35)
37 |
38 | ## [1.4.3] - 2023-03-28
39 |
40 | - added previous attempts exceptions in `ArangoDBDataWriterException`
41 |
42 | ## [1.4.2] - 2023-03-16
43 |
44 | - added debug header `x-arango-spark-request-id`
45 |
46 | ## [1.4.1] - 2022-12-15
47 |
48 | - fixed filters pushdown for read mode `query` (#37)
49 |
50 | ## [1.4.0] - 2022-05-24
51 |
52 | - support for Spark 3.2 (#31)
53 | - support for `null` at root level of a JSON array in Spark 3.1 mapping (SPARK-36379)
54 | - updated dependencies
55 | - flush write buffer on byte threshold `byteBatchSize` (#30)
56 | - remove `null` `_key` field during serialization (#29)
57 |
58 | ## [1.3.0] - 2022-04-27
59 |
60 | - added `ignoreNullFields` config param (#28)
61 |
62 | ## [1.2.0] - 2022-03-18
63 |
64 | - use `overwriteMode=ignore` if save mode is other than `Append` (#26)
65 | - require non-nullable string fields `_from` and `_to` to write to edge collections (#25)
66 | - configurable backoff retry delay for write requests (`retry.minDelay` and `retry.maxDelay`), disabled by default (#24)
67 | - retry only if schema has non-nullable field `_key` (#23)
68 | - retry on connection exceptions
69 | - added `retry.maxAttempts` config param (#20)
70 | - increased default timeout to 5 minutes
71 | - reject writing decimal types with json content type (#18)
72 | - report records causing write errors (#17)
73 | - improved logging about connections and write tasks
74 |
75 | ## [1.1.1] - 2022-02-28
76 |
77 | - retry timeout exception in truncate requests (#16)
78 | - fixed exception serialization bug (#15)
79 |
80 | ## [1.1.0] - 2022-02-23
81 |
82 | - added driver timeout configuration option (#12)
83 | - updated dependency `com.arangodb:arangodb-java-driver:6.16.1`
84 |
85 | ## [1.0.0] - 2021-12-11
86 |
87 | - Initial Release
88 |
--------------------------------------------------------------------------------
/demo/src/main/scala/WriteDemo.scala:
--------------------------------------------------------------------------------
1 | import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
2 | import org.apache.spark.sql.functions._
3 | import org.apache.spark.sql.types._
4 | import org.apache.spark.sql.{Column, DataFrame}
5 |
6 | object WriteDemo {
7 |
8 | val saveOptions: Map[String, String] = Demo.options ++ Map(
9 | "table.shards" -> "9",
10 | "confirmTruncate" -> "true",
11 | "overwriteMode" -> "replace"
12 | )
13 |
14 | def writeDemo(): Unit = {
15 | println("------------------")
16 | println("--- WRITE DEMO ---")
17 | println("------------------")
18 |
19 | println("Reading JSON files...")
20 | val nodesDF = Demo.spark.read.json(Demo.importPath + "/nodes.jsonl")
21 | .withColumn("_key", new Column(AssertNotNull(col("_key").expr)))
22 | .withColumn("releaseDate", unixTsToSparkDate(col("releaseDate")))
23 | .withColumn("birthday", unixTsToSparkDate(col("birthday")))
24 | .withColumn("lastModified", unixTsToSparkTs(col("lastModified")))
25 | .persist()
26 | val edgesDF = Demo.spark.read.json(Demo.importPath + "/edges.jsonl")
27 | .withColumn("_key", new Column(AssertNotNull(col("_key").expr)))
28 | .withColumn("_from", new Column(AssertNotNull(concat(lit("persons/"), col("_from")).expr)))
29 | .withColumn("_to", new Column(AssertNotNull(concat(lit("movies/"), col("_to")).expr)))
30 | .persist()
31 |
32 | val personsDF = nodesDF
33 | .select(Schemas.personSchema.fieldNames.filter(_ != "_id").map(col): _*)
34 | .where("type = 'Person'")
35 | val moviesDF = nodesDF
36 | .select(Schemas.movieSchema.fieldNames.filter(_ != "_id").map(col): _*)
37 | .where("type = 'Movie'")
38 | val directedDF = edgesDF
39 | .select(Schemas.directedSchema.fieldNames.filter(_ != "_id").map(col): _*)
40 | .where("`$label` = 'DIRECTED'")
41 | val actedInDF = edgesDF
42 | .select(Schemas.actsInSchema.fieldNames.filter(_ != "_id").map(col): _*)
43 | .where("`$label` = 'ACTS_IN'")
44 |
45 | println("Writing 'persons' collection...")
46 | saveDF(personsDF, "persons")
47 |
48 | println("Writing 'movies' collection...")
49 | saveDF(moviesDF, "movies")
50 |
51 | println("Writing 'directed' edge collection...")
52 | saveDF(directedDF, "directed", "edge")
53 |
54 | println("Writing 'actedIn' edge collection...")
55 | saveDF(actedInDF, "actedIn", "edge")
56 | }
57 |
58 | def unixTsToSparkTs(c: Column): Column = (c.cast(LongType) / 1000).cast(TimestampType)
59 |
60 | def unixTsToSparkDate(c: Column): Column = unixTsToSparkTs(c).cast(DateType)
61 |
62 | def saveDF(df: DataFrame, tableName: String, tableType: String = "document"): Unit =
63 | df
64 | .write
65 | .mode("overwrite")
66 | .format("com.arangodb.spark")
67 | .options(saveOptions ++ Map(
68 | "table" -> tableName,
69 | "table.type" -> tableType
70 | ))
71 | .save()
72 |
73 | }
74 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/test_composite_filter.py:
--------------------------------------------------------------------------------
1 | import arango.database
2 | import pyspark.sql
3 | import pytest
4 | from pyspark.sql import SparkSession
5 | from pyspark.sql.functions import col
6 | from pyspark.sql.types import StructType, StructField, IntegerType, StringType, BooleanType
7 |
8 | from integration import test_basespark
9 |
10 | data = [
11 | {
12 | "integer": 1,
13 | "string": "one",
14 | "bool": True
15 | },
16 | {
17 | "integer": 2,
18 | "string": "two",
19 | "bool": True
20 | }
21 | ]
22 |
23 | schema = StructType([
24 | StructField("integer", IntegerType(), nullable=False),
25 | StructField("string", StringType(), nullable=False),
26 | StructField("bool", BooleanType(), nullable=False),
27 | ])
28 |
29 | table_name = "compositeFilter"
30 |
31 |
32 | @pytest.fixture(scope="module")
33 | def composite_df(database_conn: arango.database.StandardDatabase, spark: SparkSession) -> pyspark.sql.DataFrame:
34 | df = test_basespark.create_df(database_conn, spark, table_name, data, schema)
35 | yield df
36 | test_basespark.drop_table(database_conn, table_name)
37 |
38 |
39 | def test_or_filter(spark: SparkSession, composite_df: pyspark.sql.DataFrame):
40 | field_name = "integer"
41 | value = data[0][field_name]
42 | res = composite_df.filter((col(field_name) == 0) | (col(field_name) == 1)).collect()
43 |
44 | assert len(res) == 1
45 | assert res[0].asDict()[field_name] == value
46 |
47 | sql_res = spark.sql(f"""
48 | SELECT * FROM {table_name}
49 | WHERE {field_name} = 0 OR {field_name} = 1
50 | """).collect()
51 |
52 | assert len(sql_res) == 1
53 | assert sql_res[0].asDict()[field_name] == value
54 |
55 |
56 | def test_not_filter(spark: SparkSession, composite_df: pyspark.sql.DataFrame):
57 | field_name = "integer"
58 | value = data[0][field_name]
59 | res = composite_df.filter(~(col(field_name) == 2)).collect()
60 |
61 | assert len(res) == 1
62 | assert res[0].asDict()[field_name] == value
63 |
64 | sql_res = spark.sql(f"""
65 | SELECT * FROM {table_name}
66 | WHERE NOT {field_name} = 2
67 | """).collect()
68 |
69 | assert len(sql_res) == 1
70 | assert sql_res[0].asDict()[field_name] == value
71 |
72 |
73 | def test_or_and_filter(spark: SparkSession, composite_df: pyspark.sql.DataFrame):
74 | field_name_1 = "integer"
75 | value_1 = data[0][field_name_1]
76 |
77 | field_name_2 = "string"
78 | value_2 = data[0][field_name_2]
79 |
80 | res = composite_df.filter((col("bool") == False) | ((col(field_name_1) == value_1) & (col(field_name_2) == value_2))).collect()
81 |
82 | assert len(res) == 1
83 | assert res[0].asDict()[field_name_1] == value_1
84 |
85 | sql_res = spark.sql(f"""
86 | SELECT * FROM {table_name}
87 | WHERE bool = false OR ({field_name_1} = {value_1} AND {field_name_2} = "{value_2}")
88 | """).collect()
89 |
90 | assert len(sql_res) == 1
91 | assert sql_res[0].asDict()[field_name_1] == value_1
92 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.reader
2 |
3 | import org.apache.spark.internal.Logging
4 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ReadMode}
5 | import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
6 | import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
7 | import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
8 | import org.apache.spark.sql.sources.Filter
9 | import org.apache.spark.sql.types.StructType
10 |
11 | class ArangoScanBuilder(options: ArangoDBConf, tableSchema: StructType) extends ScanBuilder
12 | with SupportsPushDownFilters
13 | with SupportsPushDownRequiredColumns
14 | with Logging {
15 |
16 | private var readSchema: StructType = _
17 |
18 | // fully or partially applied filters
19 | private var appliedPushableFilters: Array[PushableFilter] = Array()
20 | private var appliedSparkFilters: Array[Filter] = Array()
21 |
22 | override def build(): Scan = new ArangoScan(new PushDownCtx(readSchema, appliedPushableFilters), options)
23 |
24 | override def pushFilters(filters: Array[Filter]): Array[Filter] = {
25 | options.readOptions.readMode match {
26 | case ReadMode.Collection => pushFiltersReadModeCollection(filters)
27 | case ReadMode.Query => filters
28 | }
29 | }
30 |
31 | private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = {
32 | // filters related to columnNameOfCorruptRecord are not pushed down
33 | val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord)
34 | val ignoredFilters = filters.filter(isCorruptRecordFilter)
35 | val filtersBySupport = filters
36 | .filterNot(isCorruptRecordFilter)
37 | .map(f => (f, PushableFilter(f, tableSchema)))
38 | .groupBy(_._2.support())
39 |
40 | val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array())
41 | val partialSupp = filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array())
42 | val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters
43 |
44 | val appliedFilters = fullSupp ++ partialSupp
45 | appliedPushableFilters = appliedFilters.map(_._2)
46 | appliedSparkFilters = appliedFilters.map(_._1)
47 |
48 | if (fullSupp.nonEmpty) {
49 | logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}")
50 | }
51 | if (partialSupp.nonEmpty) {
52 | logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}")
53 | }
54 | if (noneSupp.nonEmpty) {
55 | logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}")
56 | }
57 |
58 | partialSupp.map(_._1) ++ noneSupp
59 | }
60 |
61 | override def pushedFilters(): Array[Filter] = appliedSparkFilters
62 |
63 | override def pruneColumns(requiredSchema: StructType): Unit = {
64 | this.readSchema = requiredSchema
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/CreateCollectionTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.write
2 |
3 | import com.arangodb.ArangoCollection
4 | import org.apache.spark.sql.{Row, SaveMode}
5 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, CollectionType}
6 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest
7 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
8 | import org.assertj.core.api.Assertions.assertThat
9 | import org.junit.jupiter.api.BeforeEach
10 | import org.junit.jupiter.params.ParameterizedTest
11 | import org.junit.jupiter.params.provider.MethodSource
12 |
13 | import scala.collection.JavaConverters._
14 |
15 |
16 | class CreateCollectionTest extends BaseSparkTest {
17 |
18 | private val collectionName = "chessPlayersCreateCollection"
19 | private val collection: ArangoCollection = db.collection(collectionName)
20 |
21 | private val rows = Seq(
22 | Row("a/1", "b/1"),
23 | Row("a/2", "b/2"),
24 | Row("a/3", "b/3"),
25 | Row("a/4", "b/4"),
26 | Row("a/5", "b/5"),
27 | Row("a/6", "b/6")
28 | )
29 |
30 | private val df = spark.createDataFrame(rows.asJava, StructType(Array(
31 | StructField("_from", StringType, nullable = false),
32 | StructField("_to", StringType, nullable = false)
33 | ))).repartition(3)
34 |
35 | @BeforeEach
36 | def beforeEach(): Unit = {
37 | if (collection.exists()) {
38 | collection.drop()
39 | }
40 | }
41 |
42 | @ParameterizedTest
43 | @MethodSource(Array("provideProtocolAndContentType"))
44 | def saveModeAppend(protocol: String, contentType: String): Unit = {
45 | df.write
46 | .format(BaseSparkTest.arangoDatasource)
47 | .mode(SaveMode.Append)
48 | .options(options + (
49 | ArangoDBConf.COLLECTION -> collectionName,
50 | ArangoDBConf.PROTOCOL -> protocol,
51 | ArangoDBConf.CONTENT_TYPE -> contentType,
52 | ArangoDBConf.NUMBER_OF_SHARDS -> "5",
53 | ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name
54 | ))
55 | .save()
56 |
57 | if (isCluster) {
58 | assertThat(collection.getProperties.getNumberOfShards).isEqualTo(5)
59 | }
60 | assertThat(collection.getProperties.getType.getType).isEqualTo(com.arangodb.entity.CollectionType.EDGES.getType)
61 | }
62 |
63 | @ParameterizedTest
64 | @MethodSource(Array("provideProtocolAndContentType"))
65 | def saveModeOverwrite(protocol: String, contentType: String): Unit = {
66 | df.write
67 | .format(BaseSparkTest.arangoDatasource)
68 | .mode(SaveMode.Overwrite)
69 | .options(options + (
70 | ArangoDBConf.COLLECTION -> collectionName,
71 | ArangoDBConf.PROTOCOL -> protocol,
72 | ArangoDBConf.CONTENT_TYPE -> contentType,
73 | ArangoDBConf.CONFIRM_TRUNCATE -> "true",
74 | ArangoDBConf.NUMBER_OF_SHARDS -> "5",
75 | ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name
76 | ))
77 | .save()
78 |
79 | if (isCluster) {
80 | assertThat(collection.getProperties.getNumberOfShards).isEqualTo(5)
81 | }
82 | assertThat(collection.getProperties.getType.getType).isEqualTo(com.arangodb.entity.CollectionType.EDGES.getType)
83 | }
84 |
85 | }
86 |
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * DISCLAIMER
3 | *
4 | * Copyright 2016 ArangoDB GmbH, Cologne, Germany
5 | *
6 | * Licensed under the Apache License, Version 2.0 (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | *
18 | * Copyright holder is ArangoDB GmbH, Cologne, Germany
19 | */
20 |
21 | package org.apache.spark.sql.arangodb.commons
22 |
23 | import com.arangodb.entity
24 |
25 | sealed trait ReadMode
26 |
27 | object ReadMode {
28 | /**
29 | * Read from an Arango collection. The scan will be partitioned according to the collection shards.
30 | */
31 | case object Collection extends ReadMode
32 |
33 | /**
34 | * Read executing a user query, without partitioning.
35 | */
36 | case object Query extends ReadMode
37 | }
38 |
39 | sealed trait ContentType {
40 | val name: String
41 | }
42 |
43 | object ContentType {
44 | case object JSON extends ContentType {
45 | override val name: String = "json"
46 | }
47 |
48 | case object VPACK extends ContentType {
49 | override val name: String = "vpack"
50 | }
51 |
52 | def apply(value: String): ContentType = value match {
53 | case JSON.name => JSON
54 | case VPACK.name => VPACK
55 | case _ => throw new IllegalArgumentException(s"${ArangoDBConf.CONTENT_TYPE}: $value")
56 | }
57 | }
58 |
59 | sealed trait Protocol {
60 | val name: String
61 | }
62 |
63 | object Protocol {
64 | case object VST extends Protocol {
65 | override val name: String = "vst"
66 | }
67 |
68 | case object HTTP extends Protocol {
69 | override val name: String = "http"
70 | }
71 |
72 | case object HTTP2 extends Protocol {
73 | override val name: String = "http2"
74 | }
75 |
76 | def apply(value: String): Protocol = value match {
77 | case VST.name => VST
78 | case HTTP.name => HTTP
79 | case HTTP2.name => HTTP2
80 | case _ => throw new IllegalArgumentException(s"${ArangoDBConf.PROTOCOL}: $value")
81 | }
82 | }
83 |
84 | sealed trait CollectionType {
85 | val name: String
86 |
87 | def get(): entity.CollectionType
88 | }
89 |
90 | object CollectionType {
91 | case object DOCUMENT extends CollectionType {
92 | override val name: String = "document"
93 |
94 | override def get(): entity.CollectionType = entity.CollectionType.DOCUMENT
95 | }
96 |
97 | case object EDGE extends CollectionType {
98 | override val name: String = "edge"
99 |
100 | override def get(): entity.CollectionType = entity.CollectionType.EDGES
101 | }
102 |
103 | def apply(value: String): CollectionType = value match {
104 | case DOCUMENT.name => DOCUMENT
105 | case EDGE.name => EDGE
106 | case _ => throw new IllegalArgumentException(s"${ArangoDBConf.COLLECTION_TYPE}: $value")
107 | }
108 | }
109 |
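As a usage note (a minimal sketch, not code from the repository): the `apply` factories above map option strings to the corresponding case objects and fail fast on unknown values, while the read mode described in the ScalaDoc is effectively selected by whether a `table` or a `query` option is supplied (see `ReadDemo.readTable` and `ReadDemo.readQuery` in the demo).

```scala
import org.apache.spark.sql.arangodb.commons.{CollectionType, ContentType, Protocol}

// Parse option values into the case objects defined above.
val contentType: ContentType = ContentType("json")            // ContentType.JSON
val protocol: Protocol = Protocol("http2")                    // Protocol.HTTP2
val collectionType: CollectionType = CollectionType("edge")   // CollectionType.EDGE

// Unknown values are rejected with an IllegalArgumentException that names the
// offending configuration key and value, e.g.:
// ContentType("xml")   // throws IllegalArgumentException
```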
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/StringFiltersTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import org.apache.spark.sql.DataFrame
4 | import org.apache.spark.sql.functions.col
5 | import org.apache.spark.sql.types._
6 | import org.assertj.core.api.Assertions.assertThat
7 | import org.junit.jupiter.api.{AfterAll, BeforeAll, Test}
8 |
9 | import scala.collection.JavaConverters._
10 |
11 | class StringFiltersTest extends BaseSparkTest {
12 | private val df = StringFiltersTest.df
13 |
14 | @Test
15 | def startsWith(): Unit = {
16 | val fieldName = "string"
17 | val value = StringFiltersTest.data.head(fieldName).asInstanceOf[String]
18 | val res = df.filter(col(fieldName).startsWith("Lorem")).collect()
19 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames))
20 | assertThat(res).hasSize(1)
21 | assertThat(res.head(fieldName)).isEqualTo(value)
22 | val sqlRes = spark.sql(
23 | s"""
24 | |SELECT * FROM stringFilters
25 | |WHERE $fieldName LIKE "Lorem%"
26 | |""".stripMargin).collect()
27 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames))
28 | assertThat(sqlRes).hasSize(1)
29 | assertThat(sqlRes.head(fieldName)).isEqualTo(value)
30 | }
31 |
32 | @Test
33 | def endsWith(): Unit = {
34 | val fieldName = "string"
35 | val value = StringFiltersTest.data.head(fieldName).asInstanceOf[String]
36 | val res = df.filter(col(fieldName).endsWith("amet")).collect()
37 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames))
38 | assertThat(res).hasSize(1)
39 | assertThat(res.head(fieldName)).isEqualTo(value)
40 | val sqlRes = spark.sql(
41 | s"""
42 | |SELECT * FROM stringFilters
43 | |WHERE $fieldName LIKE "%amet"
44 | |""".stripMargin).collect()
45 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames))
46 | assertThat(sqlRes).hasSize(1)
47 | assertThat(sqlRes.head(fieldName)).isEqualTo(value)
48 | }
49 |
50 | @Test
51 | def contains(): Unit = {
52 | val fieldName = "string"
53 | val value = StringFiltersTest.data.head(fieldName).asInstanceOf[String]
54 | val res = df.filter(col(fieldName).contains("dolor")).collect()
55 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames))
56 | assertThat(res).hasSize(1)
57 | assertThat(res.head(fieldName)).isEqualTo(value)
58 | val sqlRes = spark.sql(
59 | s"""
60 | |SELECT * FROM stringFilters
61 | |WHERE $fieldName LIKE "%dolor%"
62 | |""".stripMargin).collect()
63 | .map(_.getValuesMap[Any](StringFiltersTest.schema.fieldNames))
64 | assertThat(sqlRes).hasSize(1)
65 | assertThat(sqlRes.head(fieldName)).isEqualTo(value)
66 | }
67 |
68 | }
69 |
70 | object StringFiltersTest {
71 | private var df: DataFrame = _
72 | private val data: Seq[Map[String, Any]] = Seq(
73 | Map(
74 | "string" -> "Lorem ipsum dolor sit amet"
75 | ),
76 | Map(
77 | "string" -> "consectetur adipiscing elit"
78 | )
79 | )
80 |
81 | private val schema = StructType(Array(
82 | StructField("string", StringType, nullable = false)
83 | ))
84 |
85 | @BeforeAll
86 | def init(): Unit = {
87 | df = BaseSparkTest.createDF("stringFilters", data, schema)
88 | }
89 |
90 | @AfterAll
91 | def cleanup(): Unit = {
92 | BaseSparkTest.dropTable("stringFilters")
93 | }
94 | }
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
1 | # ArangoDB Spark Datasource Demo
2 |
3 | This demo is composed of three parts:
4 |
5 | - `WriteDemo`: reads the input JSON files as Spark DataFrames, applies conversions to map the data to Spark data types,
6 |   and writes the records into ArangoDB collections
7 | - `ReadDemo`: reads the ArangoDB collections created above as Spark DataFrames, specifying column selections and record
8 |   filter predicates or custom AQL queries
9 | - `ReadWriteDemo`: reads the ArangoDB collections created above as Spark DataFrames, applies projections and filtering,
10 |   and writes the result to a new ArangoDB collection
11 |
12 | Demos are available in both Scala and Python (using PySpark), as outlined below.
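
All three parts go through the ArangoDB Spark Datasource in the same way. As a quick orientation, here is a minimal Scala sketch of the shared read and write calls; `spark`, `options` (endpoints, password, SSL settings, ...) and the schema are placeholders defined elsewhere in the demo (see `Demo.scala` and `Schemas.scala`):

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.StructType

// Read an existing ArangoDB collection as a DataFrame.
def readCollection(spark: SparkSession, options: Map[String, String], schema: StructType): DataFrame =
  spark.read
    .format("com.arangodb.spark")              // ArangoDB Spark Datasource
    .options(options + ("table" -> "movies"))  // collection to read
    .schema(schema)
    .load()

// Write a DataFrame into an ArangoDB collection.
def writeCollection(df: DataFrame, options: Map[String, String]): Unit =
  df.write
    .format("com.arangodb.spark")
    .mode("overwrite")                         // truncates, so confirmTruncate must be set
    .options(options ++ Map("table" -> "movies", "confirmTruncate" -> "true"))
    .save()
```

The Python demos issue the equivalent calls through PySpark's `DataFrameReader` and `DataFrameWriter` (see `python-demo/`).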
13 |
14 | ## Requirements
15 |
16 | This demo requires:
17 |
18 | - JDK 8, 11 or 17
19 | - `maven`
20 | - `docker`
21 |
22 | For the Python demo, you will also need:
23 | - `python`
24 |
25 | ## Prepare the environment
26 |
27 | Set environment variables:
28 |
29 | ```shell
30 | export ARANGO_SPARK_VERSION=1.9.0-SNAPSHOT
31 | ```
32 |
33 | Start ArangoDB cluster with docker:
34 |
35 | ```shell
36 | SSL=true STARTER_MODE=cluster ./docker/start_db.sh
37 | ```
38 |
39 | The deployed cluster will be accessible at [https://172.28.0.1:8529](https://172.28.0.1:8529) with username `root` and
40 | password `test`.
41 |
42 | Start Spark cluster:
43 |
44 | ```shell
45 | ./docker/start_spark.sh
46 | ```
47 |
48 | ## Install locally
49 |
50 | NB: this is only needed for SNAPSHOT versions.
51 |
52 | ```shell
53 | mvn -f ../pom.xml install -Dmaven.test.skip=true -Dgpg.skip=true -Dmaven.javadoc.skip=true -Pscala-2.12 -Pspark-3.5
54 | ```
55 |
56 | ## Run embedded
57 |
58 | Test the Spark application in embedded mode:
59 |
60 | ```shell
61 | mvn \
62 | -Pscala-2.12 -Pspark-3.5 \
63 | test
64 | ```
65 |
66 | Test the Spark application against an ArangoDB Oasis deployment:
67 |
68 | ```shell
69 | mvn \
70 | -Pscala-2.12 -Pspark-3.5 \
71 | -Dpassword= \
72 | -Dendpoints= \
73 | -Dssl.cert.value= \
74 | test
75 | ```
76 |
77 | ## Submit to Spark cluster
78 |
79 | Package the application:
80 |
81 | ```shell
82 | mvn package -Dmaven.test.skip=true -Pscala-2.12 -Pspark-3.5
83 | ```
84 |
85 | Submit demo program:
86 |
87 | ```shell
88 | docker run -it --rm \
89 | -v $(pwd):/demo \
90 | -v $(pwd)/docker/.ivy2:/opt/bitnami/spark/.ivy2 \
91 | -v $HOME/.m2/repository:/opt/bitnami/spark/.m2/repository \
92 | --network arangodb \
93 | docker.io/bitnamilegacy/spark:3.5.6 \
94 | ./bin/spark-submit --master spark://spark-master:7077 \
95 | --packages="com.arangodb:arangodb-spark-datasource-3.5_2.12:$ARANGO_SPARK_VERSION" \
96 | --class Demo /demo/target/demo-$ARANGO_SPARK_VERSION.jar
97 | ```
98 |
99 | ## Python (PySpark) Demo
100 |
101 | This demo requires the same environment setup as outlined above.
102 | Additionally, the Python requirements need to be installed:
103 | ```shell
104 | pip install -r ./python-demo/requirements.txt
105 | ```
106 |
107 | To run the PySpark demo, run
108 | ```shell
109 | python ./python-demo/demo.py \
110 | --ssl-enabled=true \
111 | --endpoints=172.28.0.1:8529,172.28.0.1:8539,172.28.0.1:8549
112 | ```
113 |
114 | To run it against an Oasis deployment, run
115 | ```shell
116 | python ./python-demo/demo.py \
117 | --password= \
118 | --endpoints= \
119 | --ssl-enabled=true \
120 | --ssl-cert-value=
121 | ```
--------------------------------------------------------------------------------
/docker/server.pem:
--------------------------------------------------------------------------------
1 | Bag Attributes
2 | friendlyName: arangotest
3 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34
4 | Key Attributes:
5 | -----BEGIN PRIVATE KEY-----
6 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC1WiDnd4+uCmMG
7 | 539ZNZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMAL
8 | X9euSseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7Nzmvdo
9 | gNXoJQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiR
10 | dJVPwUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf
11 | 2MGO64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzb
12 | f8KnRY4PAgMBAAECggEAKi1d/bdW2TldMpvgiFTm15zLjHCpllbKBWFqRj3T9+X7
13 | Duo6Nh9ehopD0YDDe2DNhYr3DsH4sLjUWVDfDpAhutMsU1wlBzmOuC+EuRv/CeDB
14 | 4DFr+0sgCwlti+YAtwWcR05SF7A0Ai0GYW2lUipbtbFSBSjCfM08BlPDsPCRhdM8
15 | DhBn3S45aP7oC8BdhG/etg+DfXW+/nyNwEcMCYG97bzXNjzYpCQjo/bTHdh2UPYM
16 | 4WEAqFzZ5jir8LVS3v7GqpqPmk6FnHJOJpfpOSZoPqnfpIw7SVlNsXHvDaHGcgYZ
17 | Xec7rLQlBuv4RZU7OlGJpK2Ng5kvS9q3nfqqn7YIMQKBgQDqSsYnE+k6CnrSpa2W
18 | B9W/+PChITgkA46XBUUjAueJ7yVZQQEOzl0VI6RoVBp3t66eO8uM9omO8/ogHXku
19 | Ei9UUIIfH4BsSP7G5A06UC/FgReDxwBfbRuS+lupnmc348vPDkFlJZ4hDgWflNev
20 | 7tpUbljSAqUea1VhdBy146V4qwKBgQDGJ6iL1+A9uUM+1UklOAPpPhTQ8ZQDRCj7
21 | 7IMVcbzWYvCMuVNXzOWuiz+VYr3IGCJZIbxbFDOHxGF4XKJnk0vm1qhQQME0PtAF
22 | i1jIfsxpj8KKJl9Uad+XLQCYRV8mIZlhsd/ErRJuz6FyqevKH3nFIb0ggF3x2d06
23 | odTHuj4ILQKBgCUsI/BDSne4/e+59aaeK52/w33tJVkhb1gqr+N0LIRH+ycEF0Tg
24 | HQijlQwwe9qOvBfC6PK+kuipcP/zbSyQGg5Ij7ycZOXJVxL7T9X2rv2pE7AGvNpn
25 | Fz7klfJ9fWbyr310h4+ivkoETYQaO3ZgcSeAMntvi/8djHhf0cZSDgjtAoGBAKvQ
26 | TUNcHjJGxfjgRLkB1dpSmwgEv7sJSaQOkiZw5TTauwq50nsJzYlHcg1cfYPW8Ulp
27 | iAFNBdVNwNn1MFgwjpqMO4rCawObBxIXnhbSYvmQzjStSvFNj7JsMdzWIcdVUMI1
28 | 0fmdu6LbY3ihvzIVkqcMNwnMZCjFKB6jnXTElu7NAoGAS0gNPD/bfzWAhZBBYp9/
29 | SLGOvjHKrSVWGwDiqdAGuh6xg+1C3F+XpiITP6d3Wv3PCJ/Gia5isQPSMaXG+xTt
30 | 6huBgFlksHqr0tsQA9dcgGW7BDr5VhRq5/WinaLhGGy1R+i2zbDmQXgHbCO+RH/s
31 | bD9F4LZ3RoXmGHLW0IUggPw=
32 | -----END PRIVATE KEY-----
33 | Bag Attributes
34 | friendlyName: arangotest
35 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34
36 | subject=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost
37 |
38 | issuer=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost
39 |
40 | -----BEGIN CERTIFICATE-----
41 | MIIDezCCAmOgAwIBAgIEeDCzXzANBgkqhkiG9w0BAQsFADBuMRAwDgYDVQQGEwdV
42 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD
43 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv
44 | c3QwHhcNMjAxMTAxMTg1MTE5WhcNMzAxMDMwMTg1MTE5WjBuMRAwDgYDVQQGEwdV
45 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD
46 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv
47 | c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1WiDnd4+uCmMG539Z
48 | NZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMALX9eu
49 | SseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7NzmvdogNXo
50 | JQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiRdJVP
51 | wUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf2MGO
52 | 64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzbf8Kn
53 | RY4PAgMBAAGjITAfMB0GA1UdDgQWBBTBrv9Awynt3C5IbaCNyOW5v4DNkTANBgkq
54 | hkiG9w0BAQsFAAOCAQEAIm9rPvDkYpmzpSIhR3VXG9Y71gxRDrqkEeLsMoEyqGnw
55 | /zx1bDCNeGg2PncLlW6zTIipEBooixIE9U7KxHgZxBy0Et6EEWvIUmnr6F4F+dbT
56 | D050GHlcZ7eOeqYTPYeQC502G1Fo4tdNi4lDP9L9XZpf7Q1QimRH2qaLS03ZFZa2
57 | tY7ah/RQqZL8Dkxx8/zc25sgTHVpxoK853glBVBs/ENMiyGJWmAXQayewY3EPt/9
58 | wGwV4KmU3dPDleQeXSUGPUISeQxFjy+jCw21pYviWVJTNBA9l5ny3GhEmcnOT/gQ
59 | HCvVRLyGLMbaMZ4JrPwb+aAtBgrgeiK4xeSMMvrbhw==
60 | -----END CERTIFICATE-----
61 |
--------------------------------------------------------------------------------
/demo/docker/server.pem:
--------------------------------------------------------------------------------
1 | Bag Attributes
2 | friendlyName: arangotest
3 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34
4 | Key Attributes:
5 | -----BEGIN PRIVATE KEY-----
6 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC1WiDnd4+uCmMG
7 | 539ZNZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMAL
8 | X9euSseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7Nzmvdo
9 | gNXoJQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiR
10 | dJVPwUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf
11 | 2MGO64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzb
12 | f8KnRY4PAgMBAAECggEAKi1d/bdW2TldMpvgiFTm15zLjHCpllbKBWFqRj3T9+X7
13 | Duo6Nh9ehopD0YDDe2DNhYr3DsH4sLjUWVDfDpAhutMsU1wlBzmOuC+EuRv/CeDB
14 | 4DFr+0sgCwlti+YAtwWcR05SF7A0Ai0GYW2lUipbtbFSBSjCfM08BlPDsPCRhdM8
15 | DhBn3S45aP7oC8BdhG/etg+DfXW+/nyNwEcMCYG97bzXNjzYpCQjo/bTHdh2UPYM
16 | 4WEAqFzZ5jir8LVS3v7GqpqPmk6FnHJOJpfpOSZoPqnfpIw7SVlNsXHvDaHGcgYZ
17 | Xec7rLQlBuv4RZU7OlGJpK2Ng5kvS9q3nfqqn7YIMQKBgQDqSsYnE+k6CnrSpa2W
18 | B9W/+PChITgkA46XBUUjAueJ7yVZQQEOzl0VI6RoVBp3t66eO8uM9omO8/ogHXku
19 | Ei9UUIIfH4BsSP7G5A06UC/FgReDxwBfbRuS+lupnmc348vPDkFlJZ4hDgWflNev
20 | 7tpUbljSAqUea1VhdBy146V4qwKBgQDGJ6iL1+A9uUM+1UklOAPpPhTQ8ZQDRCj7
21 | 7IMVcbzWYvCMuVNXzOWuiz+VYr3IGCJZIbxbFDOHxGF4XKJnk0vm1qhQQME0PtAF
22 | i1jIfsxpj8KKJl9Uad+XLQCYRV8mIZlhsd/ErRJuz6FyqevKH3nFIb0ggF3x2d06
23 | odTHuj4ILQKBgCUsI/BDSne4/e+59aaeK52/w33tJVkhb1gqr+N0LIRH+ycEF0Tg
24 | HQijlQwwe9qOvBfC6PK+kuipcP/zbSyQGg5Ij7ycZOXJVxL7T9X2rv2pE7AGvNpn
25 | Fz7klfJ9fWbyr310h4+ivkoETYQaO3ZgcSeAMntvi/8djHhf0cZSDgjtAoGBAKvQ
26 | TUNcHjJGxfjgRLkB1dpSmwgEv7sJSaQOkiZw5TTauwq50nsJzYlHcg1cfYPW8Ulp
27 | iAFNBdVNwNn1MFgwjpqMO4rCawObBxIXnhbSYvmQzjStSvFNj7JsMdzWIcdVUMI1
28 | 0fmdu6LbY3ihvzIVkqcMNwnMZCjFKB6jnXTElu7NAoGAS0gNPD/bfzWAhZBBYp9/
29 | SLGOvjHKrSVWGwDiqdAGuh6xg+1C3F+XpiITP6d3Wv3PCJ/Gia5isQPSMaXG+xTt
30 | 6huBgFlksHqr0tsQA9dcgGW7BDr5VhRq5/WinaLhGGy1R+i2zbDmQXgHbCO+RH/s
31 | bD9F4LZ3RoXmGHLW0IUggPw=
32 | -----END PRIVATE KEY-----
33 | Bag Attributes
34 | friendlyName: arangotest
35 | localKeyID: 54 69 6D 65 20 31 36 30 34 32 35 36 36 37 39 38 35 34
36 | subject=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost
37 |
38 | issuer=C = Unknown, ST = Unknown, L = Unknown, O = Unknown, OU = Unknown, CN = localhost
39 |
40 | -----BEGIN CERTIFICATE-----
41 | MIIDezCCAmOgAwIBAgIEeDCzXzANBgkqhkiG9w0BAQsFADBuMRAwDgYDVQQGEwdV
42 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD
43 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv
44 | c3QwHhcNMjAxMTAxMTg1MTE5WhcNMzAxMDMwMTg1MTE5WjBuMRAwDgYDVQQGEwdV
45 | bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD
46 | VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRIwEAYDVQQDEwlsb2NhbGhv
47 | c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1WiDnd4+uCmMG539Z
48 | NZB8NwI0RZF3sUSQGPx3lkqaFTZVEzMZL76HYvdc9Qg7difyKyQ09RLSpMALX9eu
49 | SseD7bZGnfQH52BnKcT09eQ3wh7aVQ5sN2omygdHLC7X9usntxAfv7NzmvdogNXo
50 | JQyY/hSZff7RIqWH8NnAUKkjqOe6Bf5LDbxHKESmrFBxOCOnhcpvZWetwpiRdJVP
51 | wUn5P82CAZzfiBfmBZnB7D0l+/6Cv4jMuH26uAIcixnVekBQzl1RgwczuiZf2MGO
52 | 64vDMMJJWE9ClZF1uQuQrwXF6qwhuP1Hnkii6wNbTtPWlGSkqeutr004+Hzbf8Kn
53 | RY4PAgMBAAGjITAfMB0GA1UdDgQWBBTBrv9Awynt3C5IbaCNyOW5v4DNkTANBgkq
54 | hkiG9w0BAQsFAAOCAQEAIm9rPvDkYpmzpSIhR3VXG9Y71gxRDrqkEeLsMoEyqGnw
55 | /zx1bDCNeGg2PncLlW6zTIipEBooixIE9U7KxHgZxBy0Et6EEWvIUmnr6F4F+dbT
56 | D050GHlcZ7eOeqYTPYeQC502G1Fo4tdNi4lDP9L9XZpf7Q1QimRH2qaLS03ZFZa2
57 | tY7ah/RQqZL8Dkxx8/zc25sgTHVpxoK853glBVBs/ENMiyGJWmAXQayewY3EPt/9
58 | wGwV4KmU3dPDleQeXSUGPUISeQxFjy+jCw21pYviWVJTNBA9l5ny3GhEmcnOT/gQ
59 | HCvVRLyGLMbaMZ4JrPwb+aAtBgrgeiK4xeSMMvrbhw==
60 | -----END CERTIFICATE-----
61 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/WriteResiliencyTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.write
2 |
3 | import com.arangodb.ArangoCollection
4 | import com.arangodb.model.OverwriteMode
5 | import org.apache.spark.sql.SaveMode
6 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
7 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest
8 | import org.assertj.core.api.Assertions.assertThat
9 | import org.junit.jupiter.api.{BeforeEach, Disabled}
10 | import org.junit.jupiter.params.ParameterizedTest
11 | import org.junit.jupiter.params.provider.MethodSource
12 |
13 |
14 | class WriteResiliencyTest extends BaseSparkTest {
15 |
16 | private val collectionName = "chessPlayersResiliency"
17 | private val collection: ArangoCollection = db.collection(collectionName)
18 |
19 | import spark.implicits._
20 |
21 | private val df = Seq(
22 | ("Carlsen", "Magnus"),
23 | ("Caruana", "Fabiano"),
24 | ("Ding", "Liren"),
25 | ("Nepomniachtchi", "Ian"),
26 | ("Aronian", "Levon"),
27 | ("Grischuk", "Alexander"),
28 | ("Giri", "Anish"),
29 | ("Mamedyarov", "Shakhriyar"),
30 | ("So", "Wesley"),
31 | ("Radjabov", "Teimour")
32 | ).toDF("_key", "name")
33 | .repartition(6)
34 |
35 | @BeforeEach
36 | def beforeEach(): Unit = {
37 | if (collection.exists()) {
38 | collection.truncate()
39 | } else {
40 | collection.create()
41 | }
42 | }
43 |
44 | @Disabled("manual test only")
45 | @ParameterizedTest
46 | @MethodSource(Array("provideProtocolAndContentType"))
47 | def retryOnTimeout(protocol: String, contentType: String): Unit = {
48 | df.write
49 | .format(BaseSparkTest.arangoDatasource)
50 | .mode(SaveMode.Append)
51 | .options(options + (
52 | ArangoDBConf.TIMEOUT -> "1",
53 | ArangoDBConf.ENDPOINTS -> BaseSparkTest.endpoints,
54 | ArangoDBConf.COLLECTION -> collectionName,
55 | ArangoDBConf.PROTOCOL -> protocol,
56 | ArangoDBConf.CONTENT_TYPE -> contentType,
57 | ArangoDBConf.CONFIRM_TRUNCATE -> "true",
58 | ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue
59 | ))
60 | .save()
61 |
62 | assertThat(collection.count().getCount).isEqualTo(10L)
63 | }
64 |
65 | @ParameterizedTest
66 | @MethodSource(Array("provideProtocolAndContentType"))
67 | def retryOnWrongHost(protocol: String, contentType: String): Unit = {
68 | retryOnBadHost(BaseSparkTest.endpoints + ",127.0.0.1:111", protocol, contentType)
69 | }
70 |
71 | @ParameterizedTest
72 | @MethodSource(Array("provideProtocolAndContentType"))
73 | def retryOnUnknownHost(protocol: String, contentType: String): Unit = {
74 | retryOnBadHost(BaseSparkTest.endpoints + ",wrongHost:8529", protocol, contentType)
75 | }
76 |
77 | private def retryOnBadHost(endpoints: String, protocol: String, contentType: String): Unit = {
78 | df.write
79 | .format(BaseSparkTest.arangoDatasource)
80 | .mode(SaveMode.Append)
81 | .options(options + (
82 | ArangoDBConf.ENDPOINTS -> endpoints,
83 | ArangoDBConf.COLLECTION -> collectionName,
84 | ArangoDBConf.PROTOCOL -> protocol,
85 | ArangoDBConf.CONTENT_TYPE -> contentType,
86 | ArangoDBConf.CONFIRM_TRUNCATE -> "true",
87 | ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue,
88 | ArangoDBConf.MAX_ATTEMPTS -> "4",
89 | ArangoDBConf.MIN_RETRY_DELAY -> "20",
90 | ArangoDBConf.MAX_RETRY_DELAY -> "40"
91 | ))
92 | .save()
93 |
94 | assertThat(collection.count().getCount).isEqualTo(10L)
95 | }
96 |
97 | }
98 |
--------------------------------------------------------------------------------
/integration-tests/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | arangodb-spark-datasource
7 | com.arangodb
8 | 1.9.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | integration-tests
13 |
14 |
15 | target/site/jacoco-aggregate/jacoco.xml
16 |
17 |
18 |
19 |
20 | com.arangodb
21 | arangodb-spark-datasource-${spark.compat.version}_${scala.compat.version}
22 | ${project.version}
23 | test
24 |
25 |
26 | com.arangodb
27 | arangodb-spark-commons-${spark.compat.version}_${scala.compat.version}
28 | ${project.version}
29 | test
30 |
31 |
32 | com.arangodb
33 | jackson-serde-json
34 | 7.9.0
35 | test
36 |
37 |
38 | com.fasterxml.jackson.core
39 | jackson-core
40 |
41 |
42 | com.fasterxml.jackson.core
43 | jackson-databind
44 |
45 |
46 | com.fasterxml.jackson.core
47 | jackson-annotations
48 |
49 |
50 |
51 |
52 | com.arangodb
53 | jackson-serde-vpack
54 | 7.9.0
55 | test
56 |
57 |
58 | com.arangodb
59 | jackson-dataformat-velocypack
60 |
61 |
62 |
63 |
64 | io.qameta.allure
65 | allure-junit5
66 | 2.30.0
67 | test
68 |
69 |
70 |
71 |
72 |
73 |
74 | org.jacoco
75 | jacoco-maven-plugin
76 |
77 |
78 | report-aggregate
79 | verify
80 |
81 | report-aggregate
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/demo/python-demo/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from argparse import ArgumentParser
4 | from typing import Dict
5 |
6 | from pyspark.sql import SparkSession
7 |
8 | from read_write_demo import read_write_demo
9 | from read_demo import read_demo
10 | from write_demo import write_demo
11 |
12 |
13 | def create_spark_session() -> SparkSession:
14 | # Here we can initialize the spark session, and in doing so,
15 | # include the ArangoDB Spark DataSource package
16 | arango_spark_version = os.environ["ARANGO_SPARK_VERSION"]
17 |
18 | spark = SparkSession.builder \
19 | .appName("ArangoDBPySparkDataTypesExample") \
20 | .master("local[*]") \
21 | .config("spark.jars.packages", f"com.arangodb:arangodb-spark-datasource-3.5_2.12:{arango_spark_version}") \
22 | .getOrCreate()
23 |
24 | return spark
25 |
26 |
27 | def create_base_arangodb_datasource_opts(password: str, endpoints: str, ssl_enabled: str, ssl_cert_value: str) -> Dict[str, str]:
28 | return {
29 | "password": password,
30 | "endpoints": endpoints,
31 | "ssl.enabled": ssl_enabled,
32 | "ssl.cert.value": ssl_cert_value,
33 | "ssl.verifyHost": "false"
34 | }
35 |
36 |
37 | def main():
38 | parser = ArgumentParser()
39 | parser.add_argument("--import-path", default=None)
40 | parser.add_argument("--password", default="test")
41 | parser.add_argument("--endpoints", default="localhost:8529")
42 | parser.add_argument("--ssl-enabled", default="false")
43 | parser.add_argument("--ssl-cert-value", default="LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlekNDQW1PZ0F3SUJBZ0lFZURDelh6QU5CZ2txaGtpRzl3MEJBUXNGQURCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdIaGNOTWpBeE1UQXhNVGcxTVRFNVdoY05NekF4TURNd01UZzFNVEU1V2pCdU1SQXdEZ1lEVlFRR0V3ZFYKYm10dWIzZHVNUkF3RGdZRFZRUUlFd2RWYm10dWIzZHVNUkF3RGdZRFZRUUhFd2RWYm10dWIzZHVNUkF3RGdZRApWUVFLRXdkVmJtdHViM2R1TVJBd0RnWURWUVFMRXdkVmJtdHViM2R1TVJJd0VBWURWUVFERXdsc2IyTmhiR2h2CmMzUXdnZ0VpTUEwR0NTcUdTSWIzRFFFQkFRVUFBNElCRHdBd2dnRUtBb0lCQVFDMVdpRG5kNCt1Q21NRzUzOVoKTlpCOE53STBSWkYzc1VTUUdQeDNsa3FhRlRaVkV6TVpMNzZIWXZkYzlRZzdkaWZ5S3lRMDlSTFNwTUFMWDlldQpTc2VEN2JaR25mUUg1MkJuS2NUMDllUTN3aDdhVlE1c04yb215Z2RITEM3WDl1c250eEFmdjdOem12ZG9nTlhvCkpReVkvaFNaZmY3UklxV0g4Tm5BVUtranFPZTZCZjVMRGJ4SEtFU21yRkJ4T0NPbmhjcHZaV2V0d3BpUmRKVlAKd1VuNVA4MkNBWnpmaUJmbUJabkI3RDBsKy82Q3Y0ak11SDI2dUFJY2l4blZla0JRemwxUmd3Y3p1aVpmMk1HTwo2NHZETU1KSldFOUNsWkYxdVF1UXJ3WEY2cXdodVAxSG5raWk2d05iVHRQV2xHU2txZXV0cjAwNCtIemJmOEtuClJZNFBBZ01CQUFHaklUQWZNQjBHQTFVZERnUVdCQlRCcnY5QXd5bnQzQzVJYmFDTnlPVzV2NEROa1RBTkJna3EKaGtpRzl3MEJBUXNGQUFPQ0FRRUFJbTlyUHZEa1lwbXpwU0loUjNWWEc5WTcxZ3hSRHJxa0VlTHNNb0V5cUdudwovengxYkRDTmVHZzJQbmNMbFc2elRJaXBFQm9vaXhJRTlVN0t4SGdaeEJ5MEV0NkVFV3ZJVW1ucjZGNEYrZGJUCkQwNTBHSGxjWjdlT2VxWVRQWWVRQzUwMkcxRm80dGROaTRsRFA5TDlYWnBmN1ExUWltUkgycWFMUzAzWkZaYTIKdFk3YWgvUlFxWkw4RGt4eDgvemMyNXNnVEhWcHhvSzg1M2dsQlZCcy9FTk1peUdKV21BWFFheWV3WTNFUHQvOQp3R3dWNEttVTNkUERsZVFlWFNVR1BVSVNlUXhGankrakN3MjFwWXZpV1ZKVE5CQTlsNW55M0doRW1jbk9UL2dRCkhDdlZSTHlHTE1iYU1aNEpyUHdiK2FBdEJncmdlaUs0eGVTTU12cmJodz09Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K")
44 | args = parser.parse_args()
45 |
46 | if args.import_path is None:
47 | args.import_path = pathlib.Path(__file__).resolve().parent.parent / "docker" / "import"
48 |
49 | spark = create_spark_session()
50 | base_opts = create_base_arangodb_datasource_opts(args.password, args.endpoints, args.ssl_enabled, args.ssl_cert_value)
51 | write_demo(spark, base_opts, args.import_path)
52 | read_demo(spark, base_opts)
53 | read_write_demo(spark, base_opts)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/EdgeSchemaValidationTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.write
2 |
3 | import com.arangodb.ArangoCollection
4 | import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, CollectionType}
5 | import org.apache.spark.sql.arangodb.datasource.BaseSparkTest
6 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
7 | import org.apache.spark.sql.{DataFrame, Row, SaveMode}
8 | import org.assertj.core.api.Assertions.{assertThat, catchThrowable}
9 | import org.assertj.core.api.ThrowableAssert.ThrowingCallable
10 | import org.junit.jupiter.api.BeforeEach
11 | import org.junit.jupiter.params.ParameterizedTest
12 | import org.junit.jupiter.params.provider.MethodSource
13 |
14 | import scala.collection.JavaConverters._
15 |
16 | class EdgeSchemaValidationTest extends BaseSparkTest {
17 |
18 | private val collectionName = "edgeSchemaValidationTest"
19 | private val collection: ArangoCollection = db.collection(collectionName)
20 |
21 | private val rows = Seq(
22 | Row("k1", "from/from", "to/to"),
23 | Row("k2", "from/from", "to/to"),
24 | Row("k3", "from/from", "to/to")
25 | )
26 |
27 | private val df = spark.createDataFrame(rows.asJava, StructType(Array(
28 | StructField("_key", StringType, nullable = false),
29 | StructField("_from", StringType, nullable = false),
30 | StructField("_to", StringType, nullable = false)
31 | )))
32 |
33 | @BeforeEach
34 | def beforeEach(): Unit = {
35 | if (collection.exists()) {
36 | collection.drop()
37 | }
38 | }
39 |
40 | @ParameterizedTest
41 | @MethodSource(Array("provideProtocolAndContentType"))
42 | def write(protocol: String, contentType: String): Unit = {
43 | doWrite(df, protocol, contentType)
44 | assertThat(collection.count().getCount).isEqualTo(3L)
45 | }
46 |
47 | @ParameterizedTest
48 | @MethodSource(Array("provideProtocolAndContentType"))
49 | def dfWithNullableFromFieldShouldFail(protocol: String, contentType: String): Unit = {
50 | val nullableFromSchema = StructType(df.schema.map(p =>
51 | if (p.name == "_from") StructField(p.name, p.dataType)
52 | else p
53 | ))
54 | val dfWithNullableFrom = spark.createDataFrame(df.rdd, nullableFromSchema)
55 | val thrown = catchThrowable(new ThrowingCallable() {
56 | override def call(): Unit = doWrite(dfWithNullableFrom, protocol, contentType)
57 | })
58 |
59 | assertThat(thrown).isInstanceOf(classOf[IllegalArgumentException])
60 | assertThat(thrown).hasMessageContaining("_from")
61 | }
62 |
63 | @ParameterizedTest
64 | @MethodSource(Array("provideProtocolAndContentType"))
65 | def dfWithNullableToFieldShouldFail(protocol: String, contentType: String): Unit = {
66 | val nullableFromSchema = StructType(df.schema.map(p =>
67 | if (p.name == "_to") StructField(p.name, p.dataType)
68 | else p
69 | ))
70 | val dfWithNullableFrom = spark.createDataFrame(df.rdd, nullableFromSchema)
71 | val thrown = catchThrowable(new ThrowingCallable() {
72 | override def call(): Unit = doWrite(dfWithNullableFrom, protocol, contentType)
73 | })
74 |
75 | assertThat(thrown).isInstanceOf(classOf[IllegalArgumentException])
76 | assertThat(thrown).hasMessageContaining("_to")
77 | }
78 |
79 | private def doWrite(testDF: DataFrame, protocol: String, contentType: String): Unit = {
80 | testDF.write
81 | .format(BaseSparkTest.arangoDatasource)
82 | .mode(SaveMode.Append)
83 | .options(options + (
84 | ArangoDBConf.COLLECTION -> collectionName,
85 | ArangoDBConf.PROTOCOL -> protocol,
86 | ArangoDBConf.CONTENT_TYPE -> contentType,
87 | ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name
88 | ))
89 | .save()
90 | }
91 |
92 | }
93 |
--------------------------------------------------------------------------------
/docker/start_db.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Configuration environment variables:
4 | # STARTER_MODE: (single|cluster|activefailover), default single
5 | # DOCKER_IMAGE: ArangoDB docker image, default docker.io/arangodb/enterprise:latest
6 | # STARTER_DOCKER_IMAGE: ArangoDB Starter docker image, default docker.io/arangodb/arangodb-starter:latest
7 | # SSL: (true|false), default false
8 | # ARANGO_LICENSE_KEY: only required for ArangoDB Enterprise
9 |
10 | # EXAMPLE:
11 | # STARTER_MODE=cluster SSL=true ./start_db.sh
12 |
13 | STARTER_MODE=${STARTER_MODE:=single}
14 | DOCKER_IMAGE=${DOCKER_IMAGE:=docker.io/arangodb/enterprise:latest}
15 | STARTER_DOCKER_IMAGE=${STARTER_DOCKER_IMAGE:=docker.io/arangodb/arangodb-starter:latest}
16 | SSL=${SSL:=false}
17 | COMPRESSION=${COMPRESSION:=false}
18 |
19 | GW=172.28.0.1
20 | docker network create arangodb --subnet 172.28.0.0/16
21 |
22 | # exit when any command fails
23 | set -e
24 |
25 | docker pull $STARTER_DOCKER_IMAGE
26 | docker pull $DOCKER_IMAGE
27 |
28 | LOCATION=$(pwd)/$(dirname "$0")
29 | AUTHORIZATION_HEADER=$(cat "$LOCATION"/jwtHeader)
30 |
31 | STARTER_ARGS=
32 | SCHEME=http
33 | ARANGOSH_SCHEME=http+tcp
34 | COORDINATORS=("$GW:8529" "$GW:8539" "$GW:8549")
35 |
36 | if [ "$STARTER_MODE" == "single" ]; then
37 | COORDINATORS=("$GW:8529")
38 | fi
39 |
40 | if [ "$SSL" == "true" ]; then
41 | STARTER_ARGS="$STARTER_ARGS --ssl.keyfile=/data/server.pem"
42 | SCHEME=https
43 | ARANGOSH_SCHEME=http+ssl
44 | fi
45 |
46 | if [ "$COMPRESSION" == "true" ]; then
47 | STARTER_ARGS="${STARTER_ARGS} --all.http.compress-response-threshold=1"
48 | fi
49 |
50 | # data volume
51 | docker create -v /data --name arangodb-data alpine:3 /bin/true
52 | docker cp "$LOCATION"/jwtSecret arangodb-data:/data
53 | docker cp "$LOCATION"/server.pem arangodb-data:/data
54 |
55 | docker run -d \
56 | --name=adb \
57 | -p 8528:8528 \
58 | --volumes-from arangodb-data \
59 | -v /var/run/docker.sock:/var/run/docker.sock \
60 | -e ARANGO_LICENSE_KEY="$ARANGO_LICENSE_KEY" \
61 | $STARTER_DOCKER_IMAGE \
62 | $STARTER_ARGS \
63 | --docker.net-mode=default \
64 | --docker.container=adb \
65 | --auth.jwt-secret=/data/jwtSecret \
66 | --starter.address="${GW}" \
67 | --docker.image="${DOCKER_IMAGE}" \
68 | --starter.local --starter.mode=${STARTER_MODE} --all.log.level=debug --all.log.output=+ --log.verbose \
69 | --all.server.descriptors-minimum=1024 --all.javascript.allow-admin-execute=true
70 |
71 |
72 | wait_server() {
73 | # shellcheck disable=SC2091
74 | until $(curl --output /dev/null --insecure --fail --silent --head -i -H "$AUTHORIZATION_HEADER" "$SCHEME://$1/_api/version"); do
75 | printf '.'
76 | sleep 1
77 | done
78 | }
79 |
80 | echo "Waiting..."
81 |
82 | for a in ${COORDINATORS[*]} ; do
83 | wait_server "$a"
84 | done
85 |
86 | set +e
87 | for a in ${COORDINATORS[*]} ; do
88 | echo ""
89 | echo "Setting username and password..."
90 | docker run --rm ${DOCKER_IMAGE} arangosh --server.endpoint="$ARANGOSH_SCHEME://$a" --server.authentication=false --javascript.execute-string='require("org/arangodb/users").update("root", "test")'
91 | done
92 | set -e
93 |
94 | for a in ${COORDINATORS[*]} ; do
95 | echo ""
96 | echo "Requesting endpoint version..."
97 | curl -u root:test --insecure --fail "$SCHEME://$a/_api/version"
98 | done
99 |
100 | echo ""
101 | echo ""
102 | echo "Done, your deployment is reachable at: "
103 | for a in ${COORDINATORS[*]} ; do
104 | echo "$SCHEME://$a"
105 | echo ""
106 | done
107 |
108 | if [ "$STARTER_MODE" == "activefailover" ]; then
109 | LEADER=$("$LOCATION"/find_active_endpoint.sh)
110 | echo "Leader: $SCHEME://$LEADER"
111 | echo ""
112 | fi
113 |
--------------------------------------------------------------------------------
/demo/docker/start_db.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Configuration environment variables:
4 | # STARTER_MODE: (single|cluster|activefailover), default single
5 | # DOCKER_IMAGE: ArangoDB docker image, default docker.io/arangodb/enterprise:latest
6 | # STARTER_DOCKER_IMAGE: ArangoDB Starter docker image, default docker.io/arangodb/arangodb-starter:latest
7 | # SSL: (true|false), default false
8 | # ARANGO_LICENSE_KEY: only required for ArangoDB Enterprise
9 |
10 | # EXAMPLE:
11 | # STARTER_MODE=cluster SSL=true ./start_db.sh
12 |
13 | STARTER_MODE=${STARTER_MODE:=single}
14 | DOCKER_IMAGE=${DOCKER_IMAGE:=docker.io/arangodb/enterprise:latest}
15 | STARTER_DOCKER_IMAGE=${STARTER_DOCKER_IMAGE:=docker.io/arangodb/arangodb-starter:latest}
16 | SSL=${SSL:=false}
17 | COMPRESSION=${COMPRESSION:=false}
18 |
19 | GW=172.28.0.1
20 | docker network create arangodb --subnet 172.28.0.0/16
21 |
22 | # exit when any command fails
23 | set -e
24 |
25 | docker pull $STARTER_DOCKER_IMAGE
26 | docker pull $DOCKER_IMAGE
27 |
28 | LOCATION=$(pwd)/$(dirname "$0")
29 | AUTHORIZATION_HEADER=$(cat "$LOCATION"/jwtHeader)
30 |
31 | STARTER_ARGS=
32 | SCHEME=http
33 | ARANGOSH_SCHEME=http+tcp
34 | COORDINATORS=("$GW:8529" "$GW:8539" "$GW:8549")
35 |
36 | if [ "$STARTER_MODE" == "single" ]; then
37 | COORDINATORS=("$GW:8529")
38 | fi
39 |
40 | if [ "$SSL" == "true" ]; then
41 | STARTER_ARGS="$STARTER_ARGS --ssl.keyfile=/data/server.pem"
42 | SCHEME=https
43 | ARANGOSH_SCHEME=http+ssl
44 | fi
45 |
46 | if [ "$COMPRESSION" == "true" ]; then
47 | STARTER_ARGS="${STARTER_ARGS} --all.http.compress-response-threshold=1"
48 | fi
49 |
50 | # data volume
51 | docker create -v /data --name arangodb-data alpine:3 /bin/true
52 | docker cp "$LOCATION"/jwtSecret arangodb-data:/data
53 | docker cp "$LOCATION"/server.pem arangodb-data:/data
54 |
55 | docker run -d \
56 | --name=adb \
57 | -p 8528:8528 \
58 | --volumes-from arangodb-data \
59 | -v /var/run/docker.sock:/var/run/docker.sock \
60 | -e ARANGO_LICENSE_KEY="$ARANGO_LICENSE_KEY" \
61 | $STARTER_DOCKER_IMAGE \
62 | $STARTER_ARGS \
63 | --docker.net-mode=default \
64 | --docker.container=adb \
65 | --auth.jwt-secret=/data/jwtSecret \
66 | --starter.address="${GW}" \
67 | --docker.image="${DOCKER_IMAGE}" \
68 | --starter.local --starter.mode=${STARTER_MODE} --all.log.level=debug --all.log.output=+ --log.verbose \
69 | --all.server.descriptors-minimum=1024 --all.javascript.allow-admin-execute=true
70 |
71 |
72 | wait_server() {
73 | # shellcheck disable=SC2091
74 | until $(curl --output /dev/null --insecure --fail --silent --head -i -H "$AUTHORIZATION_HEADER" "$SCHEME://$1/_api/version"); do
75 | printf '.'
76 | sleep 1
77 | done
78 | }
79 |
80 | echo "Waiting..."
81 |
82 | for a in ${COORDINATORS[*]} ; do
83 | wait_server "$a"
84 | done
85 |
86 | set +e
87 | for a in ${COORDINATORS[*]} ; do
88 | echo ""
89 | echo "Setting username and password..."
90 | docker run --rm ${DOCKER_IMAGE} arangosh --server.endpoint="$ARANGOSH_SCHEME://$a" --server.authentication=false --javascript.execute-string='require("org/arangodb/users").update("root", "test")'
91 | done
92 | set -e
93 |
94 | for a in ${COORDINATORS[*]} ; do
95 | echo ""
96 | echo "Requesting endpoint version..."
97 | curl -u root:test --insecure --fail "$SCHEME://$a/_api/version"
98 | done
99 |
100 | echo ""
101 | echo ""
102 | echo "Done, your deployment is reachable at: "
103 | for a in ${COORDINATORS[*]} ; do
104 | echo "$SCHEME://$a"
105 | echo ""
106 | done
107 |
108 | if [ "$STARTER_MODE" == "activefailover" ]; then
109 | LEADER=$("$LOCATION"/find_active_endpoint.sh)
110 | echo "Leader: $SCHEME://$LEADER"
111 | echo ""
112 | fi
113 |
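The script above documents the deployment defaults (root password "test", coordinators on the 172.28.0.1 gateway, optional SSL). Below is a minimal PySpark sketch for pointing the datasource at this local deployment; the option names "endpoints", "user" and "password" are assumptions not defined in this script, and the datasource package must already be on the Spark classpath.

# Hypothetical connection sketch for the deployment started by this script.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("arangodb-local-demo").getOrCreate()

df = (spark.read
      .format("com.arangodb.spark")            # datasource name used in the demo code
      .option("endpoints", "172.28.0.1:8529")  # assumed option name; gateway address from this script
      .option("user", "root")
      .option("password", "test")              # password set by the arangosh call above
      .option("table", "persons")              # any existing collection
      .load())
df.show()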
--------------------------------------------------------------------------------
/demo/python-demo/write_demo.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import pathlib
3 | from typing import Dict
4 |
5 | from pyspark import pandas as ps
6 | from pyspark.sql import SparkSession, functions as f
7 | from pyspark.sql.types import StructType
8 |
9 | from utils import combine_dicts
10 | from schemas import person_schema, movie_schema, directed_schema, acts_in_schema
11 |
12 |
13 | def save_df(ps_df, table_name: str, options: Dict[str, str], table_type: str = None) -> None:
14 | if not table_type:
15 | table_type = "document"
16 |
17 | all_opts = combine_dicts([options, {
18 | "table.shards": "9",
19 | "confirmTruncate": "true",
20 | "overwriteMode": "replace",
21 | "table": table_name,
22 | "table.type": table_type
23 | }])
24 |
25 | ps_df.to_spark()\
26 | .write\
27 | .mode("overwrite")\
28 | .format("com.arangodb.spark")\
29 | .options(**all_opts)\
30 | .save()
31 |
32 |
33 | def write_demo(spark: SparkSession, save_opts: Dict[str, str], import_path_str: str):
34 | import_path = pathlib.Path(import_path_str)
35 |
36 | print("Read Nodes from JSONL using Pandas on Spark API")
37 | nodes_pd_df = ps.read_json(str(import_path / "nodes.jsonl"))
38 | nodes_pd_df = nodes_pd_df[nodes_pd_df["_key"].notnull()]
39 | nodes_pd_df["releaseDate"] = ps.to_datetime(nodes_pd_df["releaseDate"], unit="ms")
40 | nodes_pd_df["birthday"] = ps.to_datetime(nodes_pd_df["birthday"], unit="ms")
41 |
42 | def convert_to_timestamp(to_modify, column):
43 | tz_aware_datetime = datetime.datetime.utcfromtimestamp(
44 | int(to_modify[column])/1000
45 | ).replace(tzinfo=datetime.timezone.utc).astimezone(tz=None)
46 | tz_naive = tz_aware_datetime.replace(tzinfo=None)
47 | to_modify[column] = tz_naive
48 | return to_modify
49 |
50 | nodes_pd_df = nodes_pd_df.apply(convert_to_timestamp, axis=1, args=("lastModified",))
51 |
52 | nodes_df = nodes_pd_df.to_spark()
53 | nodes_pd_df = nodes_df\
54 | .withColumn("releaseDate", f.to_date(nodes_df["releaseDate"])) \
55 | .withColumn("birthday", f.to_date(nodes_df["birthday"])) \
56 | .to_pandas_on_spark()
57 |
58 | print("Read Edges from JSONL using PySpark API")
59 | edges_df = spark.read.json(str(import_path / "edges.jsonl"))
60 |     # prefix _from and _to with their collection names; the edge schema with non-nullable _from/_to is applied further below
61 | edges_pd_df = edges_df.to_pandas_on_spark()
62 | edges_pd_df["_from"] = "persons/" + edges_pd_df["_from"]
63 | edges_pd_df["_to"] = "movies/" + edges_pd_df["_to"]
64 |
65 | print("Create the collection dfs")
66 | persons_df = nodes_pd_df[nodes_pd_df["type"] == "Person"][person_schema.fieldNames()[1:]]
67 | movies_df = nodes_pd_df[nodes_pd_df["type"] == "Movie"][movie_schema.fieldNames()[1:]]
68 | directed_df = edges_pd_df[edges_pd_df["$label"] == "DIRECTED"][directed_schema.fieldNames()[1:]]
69 | acted_in_df = edges_pd_df[edges_pd_df["$label"] == "ACTS_IN"][acts_in_schema.fieldNames()[1:]]
70 |
71 | # _from and _to need to be set with nullable=False in the schema in order for it to work
72 | directed_df = spark.createDataFrame(directed_df.to_spark().rdd, StructType(
73 | directed_schema.fields[1:])).to_pandas_on_spark()
74 | acted_in_df = spark.createDataFrame(acted_in_df.to_spark().rdd, StructType(
75 | acts_in_schema.fields[1:])).to_pandas_on_spark()
76 |
77 | print("writing the persons collection")
78 | save_df(persons_df, "persons", save_opts)
79 | print("writing the movies collection")
80 | save_df(movies_df, "movies", save_opts)
81 | print("writing the 'directed' edge collection")
82 | save_df(directed_df, "directed", save_opts, "edge")
83 | print("writing the 'actedIn' collection")
84 |     print("writing the 'actedIn' edge collection")
85 |
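For reference, reading back one of the collections written above uses the same format name with a "table" option; a minimal sketch, assuming connection options (endpoints, credentials) are supplied the same way as save_opts in this demo.

# Hypothetical read-back of the "movies" collection written by write_demo;
# conn_opts stands in for the same connection options passed via save_opts.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
conn_opts = {"table": "movies"}

movies_readback = (spark.read
                   .format("com.arangodb.spark")
                   .options(**conn_opts)
                   .load())
movies_readback.printSchema()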
--------------------------------------------------------------------------------
/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource.writer
2 |
3 | import com.arangodb.entity.CollectionType
4 | import com.arangodb.model.OverwriteMode
5 | import org.apache.spark.internal.Logging
6 | import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ContentType}
7 | import org.apache.spark.sql.connector.write.{BatchWrite, SupportsTruncate, WriteBuilder}
8 | import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
9 | import org.apache.spark.sql.{AnalysisException, SaveMode}
10 |
11 | class ArangoWriterBuilder(schema: StructType, options: ArangoDBConf)
12 | extends WriteBuilder with SupportsTruncate with Logging {
13 |
14 | private var mode: SaveMode = SaveMode.Append
15 | validateConfig()
16 |
17 | override def buildForBatch(): BatchWrite = {
18 | val client = ArangoClient(options)
19 | if (!client.collectionExists()) {
20 | client.createCollection()
21 | }
22 | client.shutdown()
23 |
24 | val updatedOptions = options.updated(ArangoDBConf.OVERWRITE_MODE, mode match {
25 | case SaveMode.Append => options.writeOptions.overwriteMode.getValue
26 | case _ => OverwriteMode.ignore.getValue
27 | })
28 |
29 | logSummary(updatedOptions)
30 | new ArangoBatchWriter(schema, updatedOptions, mode)
31 | }
32 |
33 | override def truncate(): WriteBuilder = {
34 | mode = SaveMode.Overwrite
35 | if (options.writeOptions.confirmTruncate) {
36 | val client = ArangoClient(options)
37 | if (client.collectionExists()) {
38 | client.truncate()
39 | } else {
40 | client.createCollection()
41 | }
42 | client.shutdown()
43 | this
44 | } else {
45 | throw new AnalysisException(
46 | "You are attempting to use overwrite mode which will truncate this collection prior to inserting data. If " +
47 |           "you just want to change data already in the collection, set save mode 'append' and " +
48 | s"'overwrite.mode=(replace|update)'. To actually truncate set '${ArangoDBConf.CONFIRM_TRUNCATE}=true'.")
49 | }
50 | }
51 |
52 | private def validateConfig(): Unit = {
53 | if (options.driverOptions.contentType == ContentType.JSON && hasDecimalTypeFields) {
54 | throw new UnsupportedOperationException("Cannot write DecimalType when using contentType=json")
55 | }
56 |
57 | if (options.writeOptions.collectionType == CollectionType.EDGES &&
58 | !schema.exists(p => p.name == "_from" && p.dataType == StringType && !p.nullable)
59 | ) {
60 | throw new IllegalArgumentException("Writing edge collection requires non nullable string field named _from.")
61 | }
62 |
63 | if (options.writeOptions.collectionType == CollectionType.EDGES &&
64 | !schema.exists(p => p.name == "_to" && p.dataType == StringType && !p.nullable)
65 | ) {
66 | throw new IllegalArgumentException("Writing edge collection requires non nullable string field named _to.")
67 | }
68 | }
69 |
70 | private def hasDecimalTypeFields: Boolean =
71 | schema.existsRecursively {
72 | case _: DecimalType => true
73 | case _ => false
74 | }
75 |
76 | private def logSummary(updatedOptions: ArangoDBConf): Unit = {
77 | val canRetry = ArangoDataWriter.canRetry(schema, updatedOptions)
78 |
79 | logInfo(s"Using save mode: $mode")
80 | logInfo(s"Using write configuration: ${updatedOptions.writeOptions}")
81 | logInfo(s"Using mapping configuration: ${updatedOptions.mappingOptions}")
82 | logInfo(s"Can retry: $canRetry")
83 |
84 | if (!canRetry) {
85 | logWarning(
86 | """The provided configuration does not allow idempotent requests: write failures will not be retried and lead
87 | |to task failure. Speculative task executions could fail or write incorrect data."""
88 |           .stripMargin.replaceAll("\n", " ")
89 | )
90 | }
91 | }
92 |
93 | }
94 |
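From the user side, the two write paths this builder accepts look roughly as follows in PySpark; a sketch using option keys that appear elsewhere in this repository (table, overwriteMode, confirmTruncate), with connection options (endpoints, credentials) omitted for brevity.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("1", "Bobby")], ["_key", "name"])

# SaveMode.Append: buildForBatch() creates the collection if missing and never truncates it;
# overwriteMode decides how conflicting documents are handled.
(df.write.format("com.arangodb.spark")
   .mode("append")
   .option("table", "persons")
   .option("overwriteMode", "replace")
   .save())

# SaveMode.Overwrite: truncate() requires an explicit confirmTruncate=true,
# otherwise the AnalysisException with the message above is raised.
(df.write.format("com.arangodb.spark")
   .mode("overwrite")
   .option("table", "persons")
   .option("confirmTruncate", "true")
   .save())

Edge collections additionally need non-nullable StringType _from and _to fields in the written schema, as enforced by validateConfig().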
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/CompositeFilterTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import org.apache.spark.sql.DataFrame
4 | import org.apache.spark.sql.functions.{col, not}
5 | import org.apache.spark.sql.types._
6 | import org.assertj.core.api.Assertions.assertThat
7 | import org.junit.jupiter.api.{AfterAll, BeforeAll, Test}
8 |
9 | import scala.collection.JavaConverters._
10 |
11 | class CompositeFilterTest extends BaseSparkTest {
12 | private val df = CompositeFilterTest.df
13 |
14 | @Test
15 | def orFilter(): Unit = {
16 | val fieldName = "integer"
17 | val value = CompositeFilterTest.data.head(fieldName)
18 | val res = df.filter(col(fieldName).equalTo(0) or col(fieldName).equalTo(1)).collect()
19 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames))
20 | assertThat(res).hasSize(1)
21 | assertThat(res.head(fieldName)).isEqualTo(value)
22 | val sqlRes = spark.sql(
23 | s"""
24 | |SELECT * FROM compositeFilter
25 | |WHERE $fieldName = 0 OR $fieldName = 1
26 | |""".stripMargin).collect()
27 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames))
28 | assertThat(sqlRes).hasSize(1)
29 | assertThat(sqlRes.head(fieldName)).isEqualTo(value)
30 | }
31 |
32 | @Test
33 | def notFilter(): Unit = {
34 | val fieldName = "integer"
35 | val value = CompositeFilterTest.data.head(fieldName)
36 | val res = df.filter(not(col(fieldName).equalTo(2))).collect()
37 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames))
38 | assertThat(res).hasSize(1)
39 | assertThat(res.head(fieldName)).isEqualTo(value)
40 | val sqlRes = spark.sql(
41 | s"""
42 | |SELECT * FROM compositeFilter
43 | |WHERE NOT ($fieldName = 2)
44 | |""".stripMargin).collect()
45 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames))
46 | assertThat(sqlRes).hasSize(1)
47 | assertThat(sqlRes.head(fieldName)).isEqualTo(value)
48 | }
49 |
50 | @Test
51 | def orAndFilter(): Unit = {
52 | val fieldName1 = "integer"
53 | val value1 = CompositeFilterTest.data.head(fieldName1)
54 |
55 | val fieldName2 = "string"
56 | val value2 = CompositeFilterTest.data.head(fieldName2)
57 |
58 | val res = df.filter(col("bool").equalTo(false) or (col(fieldName1).equalTo(value1) and col(fieldName2).equalTo(value2))).collect()
59 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames))
60 | assertThat(res).hasSize(1)
61 | assertThat(res.head(fieldName1)).isEqualTo(value1)
62 | val sqlRes = spark.sql(
63 | s"""
64 | |SELECT * FROM compositeFilter
65 | |WHERE bool = false OR ($fieldName1 = $value1 AND $fieldName2 = "$value2")
66 | |""".stripMargin).collect()
67 | .map(_.getValuesMap[Any](CompositeFilterTest.schema.fieldNames))
68 | assertThat(sqlRes).hasSize(1)
69 | assertThat(sqlRes.head(fieldName1)).isEqualTo(value1)
70 | }
71 |
72 | }
73 |
74 | object CompositeFilterTest {
75 | private var df: DataFrame = _
76 | private val data: Seq[Map[String, Any]] = Seq(
77 | Map(
78 | "integer" -> 1,
79 | "string" -> "one",
80 | "bool" -> true
81 | ),
82 | Map(
83 | "integer" -> 2,
84 | "string" -> "two",
85 | "bool" -> true
86 | )
87 | )
88 |
89 | private val schema = StructType(Array(
90 | // atomic types
91 | StructField("integer", IntegerType, nullable = false),
92 | StructField("string", StringType, nullable = false),
93 | StructField("bool", BooleanType, nullable = false)
94 | ))
95 |
96 | @BeforeAll
97 | def init(): Unit = {
98 | df = BaseSparkTest.createDF("compositeFilter", data, schema)
99 | }
100 |
101 | @AfterAll
102 | def cleanup(): Unit = {
103 | BaseSparkTest.dropTable("compositeFilter")
104 | }
105 | }
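The same composite predicates can be written from PySpark; a brief sketch of the OR, NOT and OR-with-nested-AND combinations this test exercises, shown here against an in-memory DataFrame with the same three columns rather than the ArangoDB-backed one.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "one", True), (2, "two", True)], ["integer", "string", "bool"])

# OR of two equality predicates (orFilter)
df.filter((col("integer") == 0) | (col("integer") == 1)).show()

# negated equality predicate (notFilter)
df.filter(~(col("integer") == 2)).show()

# OR combined with a nested AND (orAndFilter)
df.filter((col("bool") == False) | ((col("integer") == 1) & (col("string") == "one"))).show()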
--------------------------------------------------------------------------------
/docker/start_db_macos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Configuration environment variables:
4 | # STARTER_MODE: (single|cluster|activefailover), default single
5 | # DOCKER_IMAGE: ArangoDB docker image, default docker.io/arangodb/enterprise:latest
6 | # SSL: (true|false), default false
7 | # DATABASE_EXTENDED_NAMES: (true|false), default false
8 | # ARANGO_LICENSE_KEY: only required for ArangoDB Enterprise
9 |
10 | # EXAMPLE:
11 | # STARTER_MODE=cluster SSL=true ./start_db_macos.sh
12 |
13 | STARTER_MODE=${STARTER_MODE:=single}
14 | DOCKER_IMAGE=${DOCKER_IMAGE:=docker.io/arangodb/enterprise:latest}
15 | SSL=${SSL:=false}
16 | DATABASE_EXTENDED_NAMES=${DATABASE_EXTENDED_NAMES:=false}
17 |
18 | STARTER_DOCKER_IMAGE=docker.io/arangodb/arangodb-starter:latest
19 | GW=172.28.0.1
20 | LOCALGW=localhost
21 | docker network create arangodb --subnet 172.28.0.0/16
22 |
23 | # exit when any command fails
24 | set -e
25 |
26 | docker pull $STARTER_DOCKER_IMAGE
27 | docker pull $DOCKER_IMAGE
28 |
29 | LOCATION=$(pwd)/$(dirname "$0")
30 |
31 | echo "Averysecretword" > "$LOCATION"/jwtSecret
32 | docker run --rm -v "$LOCATION"/jwtSecret:/jwtSecret "$STARTER_DOCKER_IMAGE" auth header --auth.jwt-secret /jwtSecret > "$LOCATION"/jwtHeader
33 | AUTHORIZATION_HEADER=$(cat "$LOCATION"/jwtHeader)
34 |
35 | STARTER_ARGS=
36 | SCHEME=http
37 | ARANGOSH_SCHEME=http+tcp
38 | COORDINATORS=("$LOCALGW:8529" "$LOCALGW:8539" "$LOCALGW:8549")
39 | COORDINATORSINTERNAL=("$GW:8529" "$GW:8539" "$GW:8549")
40 |
41 | if [ "$STARTER_MODE" == "single" ]; then
42 | COORDINATORS=("$LOCALGW:8529")
43 | COORDINATORSINTERNAL=("$GW:8529")
44 | fi
45 |
46 | if [ "$SSL" == "true" ]; then
47 | STARTER_ARGS="$STARTER_ARGS --ssl.keyfile=server.pem"
48 | SCHEME=https
49 | ARANGOSH_SCHEME=http+ssl
50 | fi
51 |
52 | if [ "$DATABASE_EXTENDED_NAMES" == "true" ]; then
53 | STARTER_ARGS="${STARTER_ARGS} --all.database.extended-names-databases=true"
54 | fi
55 |
56 | if [ "$USE_MOUNTED_DATA" == "true" ]; then
57 | STARTER_ARGS="${STARTER_ARGS} --starter.data-dir=/data"
58 | MOUNT_DATA="-v $LOCATION/data:/data"
59 | fi
60 |
61 | docker run -d \
62 | --name=adb \
63 | -p 8528:8528 \
64 | -v "$LOCATION"/server.pem:/server.pem \
65 | -v "$LOCATION"/jwtSecret:/jwtSecret \
66 | $MOUNT_DATA \
67 | -v /var/run/docker.sock:/var/run/docker.sock \
68 | -e ARANGO_LICENSE_KEY="$ARANGO_LICENSE_KEY" \
69 | $STARTER_DOCKER_IMAGE \
70 | $STARTER_ARGS \
71 | --docker.net-mode=default \
72 | --docker.container=adb \
73 | --auth.jwt-secret=/jwtSecret \
74 | --starter.address="${GW}" \
75 | --docker.image="${DOCKER_IMAGE}" \
76 | --starter.local --starter.mode=${STARTER_MODE} --all.log.level=debug --all.log.output=+ --log.verbose
77 |
78 |
79 | wait_server() {
80 | # shellcheck disable=SC2091
81 | until $(curl --output /dev/null --insecure --fail --silent --head -i -H "$AUTHORIZATION_HEADER" "$SCHEME://$1/_api/version"); do
82 | printf '.'
83 | sleep 1
84 | done
85 | }
86 |
87 | echo "Waiting..."
88 |
89 | for a in ${COORDINATORS[*]} ; do
90 | wait_server "$a"
91 | done
92 |
93 | set +e
94 | ITER=0
95 | for a in ${COORDINATORS[*]} ; do
96 | echo ""
97 | echo "Setting username and password..."
98 | docker run --rm ${DOCKER_IMAGE} arangosh --server.endpoint="$ARANGOSH_SCHEME://${COORDINATORSINTERNAL[ITER]}" --server.authentication=false --javascript.execute-string='require("org/arangodb/users").update("root", "test")'
99 |     ITER=$((ITER + 1))
100 | done
101 | set -e
102 |
103 | for a in ${COORDINATORS[*]} ; do
104 | echo ""
105 | echo "Requesting endpoint version..."
106 | curl -u root:test --insecure --fail "$SCHEME://$a/_api/version"
107 | done
108 |
109 | echo ""
110 | echo ""
111 | echo "Done, your deployment is reachable at: "
112 | for a in ${COORDINATORS[*]} ; do
113 | echo "$SCHEME://$a"
114 | echo ""
115 | done
116 |
117 | if [ "$STARTER_MODE" == "activefailover" ]; then
118 | LEADER=$("$LOCATION"/find_active_endpoint.sh)
119 | echo "Leader: $SCHEME://$LEADER"
120 | echo ""
121 | fi
122 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import org.apache.spark.SPARK_VERSION
4 | import org.apache.spark.sql.DataFrame
5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
6 | import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StringType, StructField, StructType}
7 | import org.assertj.core.api.Assertions.assertThat
8 | import org.junit.jupiter.api.Assumptions.assumeTrue
9 | import org.junit.jupiter.params.ParameterizedTest
10 | import org.junit.jupiter.params.provider.ValueSource
11 |
12 | class DeserializationCastTest extends BaseSparkTest {
13 | private val collectionName = "deserializationCast"
14 |
15 | @ParameterizedTest
16 | @ValueSource(strings = Array("vpack", "json"))
17 | def numberIntToStringCast(contentType: String): Unit = doTestImplicitCast(
18 | StructType(Array(StructField("a", StringType))),
19 | Seq(Map("a" -> 1)),
20 | Seq("""{"a":1}"""),
21 | contentType
22 | )
23 |
24 | @ParameterizedTest
25 | @ValueSource(strings = Array("vpack", "json"))
26 | def numberDecToStringCast(contentType: String): Unit = doTestImplicitCast(
27 | StructType(Array(StructField("a", StringType))),
28 | Seq(Map("a" -> 1.1)),
29 | Seq("""{"a":1.1}"""),
30 | contentType
31 | )
32 |
33 | @ParameterizedTest
34 | @ValueSource(strings = Array("vpack", "json"))
35 | def boolToStringCast(contentType: String): Unit = doTestImplicitCast(
36 | StructType(Array(StructField("a", StringType))),
37 | Seq(Map("a" -> true)),
38 | Seq("""{"a":true}"""),
39 | contentType
40 | )
41 |
42 | @ParameterizedTest
43 | @ValueSource(strings = Array("vpack", "json"))
44 | def objectToStringCast(contentType: String): Unit = doTestImplicitCast(
45 | StructType(Array(StructField("a", StringType))),
46 | Seq(Map("a" -> Map("b" -> "c"))),
47 | Seq("""{"a":{"b":"c"}}"""),
48 | contentType
49 | )
50 |
51 | @ParameterizedTest
52 | @ValueSource(strings = Array("vpack", "json"))
53 | def arrayToStringCast(contentType: String): Unit = doTestImplicitCast(
54 | StructType(Array(StructField("a", StringType))),
55 | Seq(Map("a" -> Array(1, 2))),
56 | Seq("""{"a":[1,2]}"""),
57 | contentType
58 | )
59 |
60 | @ParameterizedTest
61 | @ValueSource(strings = Array("vpack", "json"))
62 | def nullToIntegerCast(contentType: String): Unit = {
63 | doTestImplicitCast(
64 | StructType(Array(StructField("a", IntegerType, nullable = false))),
65 | Seq(Map("a" -> null)),
66 | Seq("""{"a":0}"""),
67 | contentType
68 | )
69 | }
70 |
71 | @ParameterizedTest
72 | @ValueSource(strings = Array("vpack", "json"))
73 | def nullToDoubleCast(contentType: String): Unit = {
74 | doTestImplicitCast(
75 | StructType(Array(StructField("a", DoubleType, nullable = false))),
76 | Seq(Map("a" -> null)),
77 | Seq("""{"a":0.0}"""),
78 | contentType
79 | )
80 | }
81 |
82 | @ParameterizedTest
83 | @ValueSource(strings = Array("vpack", "json"))
84 | def nullAsBoolean(contentType: String): Unit = {
85 | doTestImplicitCast(
86 | StructType(Array(StructField("a", BooleanType, nullable = false))),
87 | Seq(Map("a" -> null)),
88 | Seq("""{"a":false}"""),
89 | contentType
90 | )
91 | }
92 |
93 | private def doTestImplicitCast(
94 | schema: StructType,
95 | data: Iterable[Map[String, Any]],
96 | jsonData: Seq[String],
97 | contentType: String
98 | ) = {
99 |
100 | /**
101 | * FIXME: many vpack tests fail
102 | */
103 | assumeTrue(contentType != "vpack")
104 |
105 | import spark.implicits._
106 | val dfFromJson: DataFrame = spark.read.schema(schema).json(jsonData.toDS)
107 | dfFromJson.show()
108 | val df = BaseSparkTest.createDF(collectionName, data, schema, Map(ArangoDBConf.CONTENT_TYPE -> contentType))
109 | assertThat(df.collect()).isEqualTo(dfFromJson.collect())
110 | }
111 | }
112 |
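The casts above are compared against Spark's own JSON reader with the same declared schema; a small stand-alone PySpark illustration of that reference behaviour, independent of ArangoDB.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.getOrCreate()

# Declaring "a" as StringType while the JSON value is the number 1 yields the string "1",
# which is the reference result the numberIntToStringCast case compares against.
schema = StructType([StructField("a", StringType())])
jdf = spark.read.schema(schema).json(spark.sparkContext.parallelize(['{"a":1}']))
jdf.show()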
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | // scalastyle:off
19 |
20 | package org.apache.spark.sql.arangodb.datasource.mapping.json
21 |
22 | import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
23 | import org.apache.hadoop.io.Text
24 | import org.apache.spark.sql.catalyst.InternalRow
25 | import org.apache.spark.unsafe.types.UTF8String
26 | import sun.nio.cs.StreamDecoder
27 |
28 | import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
29 | import java.nio.channels.Channels
30 | import java.nio.charset.{Charset, StandardCharsets}
31 |
32 | private[sql] object CreateJacksonParser extends Serializable {
33 | def string(jsonFactory: JsonFactory, record: String): JsonParser = {
34 | jsonFactory.createParser(record)
35 | }
36 |
37 | def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
38 | val bb = record.getByteBuffer
39 | assert(bb.hasArray)
40 |
41 | val bain = new ByteArrayInputStream(
42 | bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
43 |
44 | jsonFactory.createParser(new InputStreamReader(bain, StandardCharsets.UTF_8))
45 | }
46 |
47 | def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
48 | jsonFactory.createParser(record.getBytes, 0, record.getLength)
49 | }
50 |
51 |   // Jackson parsers can be ranked according to their performance:
52 |   // 1. Array based with the actual encoding UTF-8 in the array. This is the fastest parser,
53 |   //    but it doesn't allow setting the encoding explicitly. The actual encoding is detected
54 |   //    automatically by checking the leading bytes of the array.
55 |   // 2. InputStream based with the actual encoding UTF-8 in the stream. The encoding is detected
56 |   //    automatically by analyzing the first bytes of the input stream.
57 |   // 3. Reader based parser. This is the slowest parser used here, but it allows creating
58 |   //    a reader with a specific encoding.
59 |   // The method creates a reader for an array with the given encoding and sets the size of the
60 |   //    internal decoding buffer according to the size of the input array.
61 | private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = {
62 | val bais = new ByteArrayInputStream(in, 0, length)
63 | val byteChannel = Channels.newChannel(bais)
64 | val decodingBufferSize = Math.min(length, 8192)
65 | val decoder = Charset.forName(enc).newDecoder()
66 |
67 | StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize)
68 | }
69 |
70 | def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = {
71 | val sd = getStreamDecoder(enc, record.getBytes, record.getLength)
72 | jsonFactory.createParser(sd)
73 | }
74 |
75 | def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = {
76 | jsonFactory.createParser(is)
77 | }
78 |
79 | def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
80 | jsonFactory.createParser(new InputStreamReader(is, enc))
81 | }
82 |
83 | def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
84 | val ba = row.getBinary(0)
85 |
86 | jsonFactory.createParser(ba, 0, ba.length)
87 | }
88 |
89 | def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
90 | val binary = row.getBinary(0)
91 | val sd = getStreamDecoder(enc, binary, binary.length)
92 |
93 | jsonFactory.createParser(sd)
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/arangodb-spark-datasource-3.5/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | // scalastyle:off
19 |
20 | package org.apache.spark.sql.arangodb.datasource.mapping.json
21 |
22 | import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
23 | import org.apache.hadoop.io.Text
24 | import org.apache.spark.sql.catalyst.InternalRow
25 | import org.apache.spark.unsafe.types.UTF8String
26 | import sun.nio.cs.StreamDecoder
27 |
28 | import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
29 | import java.nio.channels.Channels
30 | import java.nio.charset.{Charset, StandardCharsets}
31 |
32 | private[sql] object CreateJacksonParser extends Serializable {
33 | def string(jsonFactory: JsonFactory, record: String): JsonParser = {
34 | jsonFactory.createParser(record)
35 | }
36 |
37 | def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
38 | val bb = record.getByteBuffer
39 | assert(bb.hasArray)
40 |
41 | val bain = new ByteArrayInputStream(
42 | bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
43 |
44 | jsonFactory.createParser(new InputStreamReader(bain, StandardCharsets.UTF_8))
45 | }
46 |
47 | def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
48 | jsonFactory.createParser(record.getBytes, 0, record.getLength)
49 | }
50 |
51 |   // Jackson parsers can be ranked according to their performance:
52 |   // 1. Array based with the actual encoding UTF-8 in the array. This is the fastest parser,
53 |   //    but it doesn't allow setting the encoding explicitly. The actual encoding is detected
54 |   //    automatically by checking the leading bytes of the array.
55 |   // 2. InputStream based with the actual encoding UTF-8 in the stream. The encoding is detected
56 |   //    automatically by analyzing the first bytes of the input stream.
57 |   // 3. Reader based parser. This is the slowest parser used here, but it allows creating
58 |   //    a reader with a specific encoding.
59 |   // The method creates a reader for an array with the given encoding and sets the size of the
60 |   //    internal decoding buffer according to the size of the input array.
61 | private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = {
62 | val bais = new ByteArrayInputStream(in, 0, length)
63 | val byteChannel = Channels.newChannel(bais)
64 | val decodingBufferSize = Math.min(length, 8192)
65 | val decoder = Charset.forName(enc).newDecoder()
66 |
67 | StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize)
68 | }
69 |
70 | def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = {
71 | val sd = getStreamDecoder(enc, record.getBytes, record.getLength)
72 | jsonFactory.createParser(sd)
73 | }
74 |
75 | def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = {
76 | jsonFactory.createParser(is)
77 | }
78 |
79 | def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
80 | jsonFactory.createParser(new InputStreamReader(is, enc))
81 | }
82 |
83 | def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
84 | val ba = row.getBinary(0)
85 |
86 | jsonFactory.createParser(ba, 0, ba.length)
87 | }
88 |
89 | def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
90 | val binary = row.getBinary(0)
91 | val sd = getStreamDecoder(enc, binary, binary.length)
92 |
93 | jsonFactory.createParser(sd)
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/test_deserialization_cast.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Any
2 |
3 | import arango.database
4 | import pytest
5 | from pyspark.sql import SparkSession
6 | from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, BooleanType
7 |
8 | from integration import test_basespark
9 |
10 | COLLECTION_NAME = "deserializationCast"
11 | content_types = ["vpack", "json"]
12 |
13 |
14 | def check_implicit_cast(db: arango.database.StandardDatabase, spark: SparkSession, schema: StructType, data: List[Dict[str, Any]], json_data: List[str], content_type: str):
15 | # FIXME: many vpack tests are failing
16 | if content_type == "vpack":
17 | pytest.xfail("Too many vpack tests fail")
18 |
19 | df_from_json = spark.read.schema(schema).json(spark.sparkContext.parallelize(json_data))
20 | df_from_json.show()
21 |
22 | df = test_basespark.create_df(db, spark, COLLECTION_NAME, data, schema, {"contentType": content_type})
23 | assert df.collect() == df_from_json.collect()
24 |
25 |
26 | @pytest.mark.parametrize("content_type", content_types)
27 | def test_number_int_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
28 | check_implicit_cast(
29 | database_conn,
30 | spark,
31 | StructType([StructField("a", StringType())]),
32 | [{"a": 1}],
33 | ['{"a":1}'],
34 | content_type
35 | )
36 |
37 |
38 | @pytest.mark.parametrize("content_type", content_types)
39 | def test_number_dec_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
40 | check_implicit_cast(
41 | database_conn,
42 | spark,
43 | StructType([StructField("a", StringType())]),
44 | [{"a": 1.1}],
45 | ['{"a":1.1}'],
46 | content_type
47 | )
48 |
49 |
50 | @pytest.mark.parametrize("content_type", content_types)
51 | def test_bool_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
52 | check_implicit_cast(
53 | database_conn,
54 | spark,
55 | StructType([StructField("a", StringType())]),
56 | [{"a": True}],
57 | ['{"a":true}'],
58 | content_type
59 | )
60 |
61 |
62 | @pytest.mark.parametrize("content_type", content_types)
63 | def test_object_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
64 | check_implicit_cast(
65 | database_conn,
66 | spark,
67 | StructType([StructField("a", StringType())]),
68 | [{"a": {"b": "c"}}],
69 | ['{"a":{"b":"c"}}'],
70 | content_type
71 | )
72 |
73 |
74 | @pytest.mark.parametrize("content_type", content_types)
75 | def test_array_to_string_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
76 | check_implicit_cast(
77 | database_conn,
78 | spark,
79 | StructType([StructField("a", StringType())]),
80 | [{"a": [1, 2]}],
81 | ['{"a":[1,2]}'],
82 | content_type
83 | )
84 |
85 |
86 | @pytest.mark.parametrize("content_type", content_types)
87 | def test_null_to_integer_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
88 | check_implicit_cast(
89 | database_conn,
90 | spark,
91 | StructType([StructField("a", IntegerType())]),
92 | [{"a": None}],
93 | ['{"a":null}'],
94 | content_type
95 | )
96 |
97 |
98 | @pytest.mark.parametrize("content_type", content_types)
99 | def test_null_to_double_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
100 | check_implicit_cast(
101 | database_conn,
102 | spark,
103 | StructType([StructField("a", DoubleType())]),
104 | [{"a": None}],
105 | ['{"a":null}'],
106 | content_type
107 | )
108 |
109 |
110 | @pytest.mark.parametrize("content_type", content_types)
111 | def test_null_to_boolean_cast(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
112 | check_implicit_cast(
113 | database_conn,
114 | spark,
115 | StructType([StructField("a", BooleanType())]),
116 | [{"a": None}],
117 | ['{"a":null}'],
118 | content_type
119 | )
120 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/test_readwrite_datatype.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import sys
3 | from datetime import datetime, date
4 | from decimal import Context
5 |
6 | import arango.database
7 | import pyspark.sql
8 | import pytest
9 | from py4j.protocol import Py4JJavaError
10 | from pyspark.sql import SparkSession
11 | from pyspark.sql.types import StructType, StructField, BooleanType, DoubleType, FloatType, IntegerType, LongType, \
12 | DateType, TimestampType, ShortType, ByteType, StringType, ArrayType, Row, MapType, DecimalType
13 |
14 | from integration import test_basespark
15 | from integration.utils import combine_dicts
16 |
17 | COLLECTION_NAME = "datatypes"
18 |
19 | data = [
20 | [
21 | False,
22 | 1.1,
23 | 0.09375,
24 | 1,
25 | 1,
26 | date.fromisoformat("2021-01-01"),
27 | datetime.fromisoformat("2021-01-01 01:01:01.111").astimezone(),
28 | 1,
29 | 1,
30 | "one",
31 | [1, 1, 1],
32 | [["a", "b", "c"], ["d", "e", "f"]],
33 | {"a": 1, "b": 1},
34 | Row("a1", 1)
35 | ],
36 | [
37 | True,
38 | 2.2,
39 | 2.2,
40 | 2,
41 | 2,
42 | date.fromisoformat("2022-02-02"),
43 | datetime.fromisoformat("2022-02-02 02:02:02.222").astimezone(),
44 | 2,
45 | 2,
46 | "two",
47 | [2, 2, 2],
48 | [["a", "b", "c"], ["d", "e", "f"]],
49 | {"a": 2, "b": 2},
50 | Row("a1", 2)
51 | ]
52 | ]
53 |
54 | struct_fields = [
55 | StructField("bool", BooleanType(), nullable=False),
56 | StructField("double", DoubleType(), nullable=False),
57 | StructField("float", FloatType(), nullable=False),
58 | StructField("integer", IntegerType(), nullable=False),
59 | StructField("long", LongType(), nullable=False),
60 | StructField("date", DateType(), nullable=False),
61 | StructField("timestamp", TimestampType(), nullable=False),
62 | StructField("short", ShortType(), nullable=False),
63 | StructField("byte", ByteType(), nullable=False),
64 | StructField("string", StringType(), nullable=False),
65 | StructField("intArray", ArrayType(IntegerType()), nullable=False),
66 | StructField("stringArrayArray", ArrayType(ArrayType(StringType())), nullable=False),
67 | StructField("intMap", MapType(StringType(), IntegerType()), nullable=False),
68 | StructField("struct", StructType([
69 | StructField("a", StringType()),
70 | StructField("b", IntegerType())
71 | ]))
72 | ]
73 |
74 | schema = StructType(struct_fields)
75 |
76 |
77 | @pytest.mark.parametrize("protocol,content_type", test_basespark.protocol_and_content_type)
78 | def test_round_trip_read_write(spark: SparkSession, protocol: str, content_type: str):
79 | df = spark.createDataFrame(spark.sparkContext.parallelize([Row(*x) for x in data]), schema)
80 | round_trip_readwrite(spark, df, protocol, content_type)
81 |
82 |
83 | @pytest.mark.parametrize("protocol,content_type", test_basespark.protocol_and_content_type)
84 | def test_round_trip_read_write_decimal_type(spark: SparkSession, protocol: str, content_type: str):
85 | if content_type != "vpack":
86 |         pytest.xfail("DecimalType round trip is not supported with contentType=json")
87 |
88 | schema_with_decimal = StructType(copy.deepcopy(struct_fields))
89 | schema_with_decimal.add(StructField("decimal", DecimalType(38, 18), nullable=False))
90 | df = spark.createDataFrame(spark.sparkContext.parallelize([Row(*x, Context(prec=38).create_decimal("2.22222222")) for x in data]), schema_with_decimal)
91 | round_trip_readwrite(spark, df, protocol, content_type)
92 |
93 |
94 | def write_df(df: pyspark.sql.DataFrame, protocol: str, content_type: str):
95 | all_opts = combine_dicts([
96 | test_basespark.options,
97 | {
98 | "table": COLLECTION_NAME,
99 | "protocol": protocol,
100 | "contentType": content_type,
101 | "overwriteMode": "replace",
102 | "confirmTruncate": "true"
103 | }
104 | ])
105 | df.write\
106 | .format(test_basespark.arango_datasource_name)\
107 | .mode("overwrite")\
108 | .options(**all_opts)\
109 | .save()
110 |
111 |
112 | def round_trip_readwrite(spark: SparkSession, df: pyspark.sql.DataFrame, protocol: str, content_type: str):
113 | initial = df.collect()
114 |
115 | all_opts = combine_dicts([test_basespark.options, {"table": COLLECTION_NAME}])
116 |
117 | write_df(df, protocol, content_type)
118 | read = spark.read.format(test_basespark.arango_datasource_name)\
119 | .options(**all_opts)\
120 | .schema(df.schema)\
121 | .load()\
122 | .collect()
123 |
124 |     assert sorted(initial) == sorted(read)
125 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/write/test_savemode.py:
--------------------------------------------------------------------------------
1 | import arango
2 | import arango.database
3 | import arango.collection
4 | import pytest
5 | import pyspark.sql
6 | from py4j.protocol import Py4JJavaError
7 | from pyspark.sql import SparkSession
8 | from pyspark.sql.utils import AnalysisException
9 |
10 | from integration.test_basespark import protocol_and_content_type, options, arango_datasource_name
11 | from integration.utils import combine_dicts
12 |
13 |
14 | COLLECTION_NAME = "chessPlayersSaveMode"
15 |
16 | data = [
17 | ("Carlsen", "Magnus"),
18 | ("Caruana", "Fabiano"),
19 | ("Ding", "Liren"),
20 | ("Nepomniachtchi", "Ian"),
21 | ("Aronian", "Levon"),
22 | ("Grischuk", "Alexander"),
23 | ("Giri", "Anish"),
24 | ("Mamedyarov", "Shakhriyar"),
25 | ("So", "Wesley"),
26 | ("Radjabov", "Teimour")
27 | ]
28 |
29 |
30 | @pytest.fixture(scope="function")
31 | def chess_collection(database_conn: arango.database.StandardDatabase) -> arango.collection.StandardCollection:
32 | if database_conn.has_collection(COLLECTION_NAME):
33 | database_conn.delete_collection(COLLECTION_NAME)
34 | yield database_conn.collection(COLLECTION_NAME)
35 |
36 |
37 | @pytest.fixture
38 | def chess_df(spark: SparkSession) -> pyspark.sql.DataFrame:
39 | df = spark.createDataFrame(data, schema=["surname", "name"])
40 | return df
41 |
42 |
43 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type)
44 | def test_savemode_append(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, protocol: str, content_type: str):
45 | all_opts = combine_dicts([options, {
46 | "table": COLLECTION_NAME,
47 | "protocol": protocol,
48 | "contentType": content_type
49 | }])
50 |
51 | chess_df.write\
52 | .format(arango_datasource_name)\
53 | .mode("Append")\
54 | .options(**all_opts)\
55 | .save()
56 |
57 | assert chess_collection.count() == 10
58 |
59 |
60 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type)
61 | def test_savemode_append_with_existing_collection(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, database_conn: arango.database.StandardDatabase, protocol: str, content_type: str):
62 | database_conn.create_collection(COLLECTION_NAME)
63 | chess_collection.insert({})
64 |
65 | all_opts = combine_dicts([options, {
66 | "table": COLLECTION_NAME,
67 | "protocol": protocol,
68 | "contentType": content_type
69 | }])
70 |
71 | chess_df.write \
72 | .format(arango_datasource_name) \
73 | .mode("Append") \
74 | .options(**all_opts) \
75 | .save()
76 |
77 | assert chess_collection.count() == 11
78 |
79 |
80 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type)
81 | def test_savemode_overwrite_should_throw_when_used_alone(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, protocol: str, content_type: str):
82 | all_opts = combine_dicts([options, {
83 | "table": COLLECTION_NAME,
84 | "protocol": protocol,
85 | "contentType": content_type
86 | }])
87 |
88 | with pytest.raises(AnalysisException) as e:
89 | chess_df.write \
90 | .format(arango_datasource_name) \
91 | .mode("Overwrite") \
92 | .options(**all_opts) \
93 | .save()
94 |
95 | e.match("confirmTruncate")
96 |
97 |
98 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type)
99 | def test_savemode_overwrite(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, protocol: str, content_type: str):
100 | all_opts = combine_dicts([options, {
101 | "table": COLLECTION_NAME,
102 | "protocol": protocol,
103 | "contentType": content_type,
104 | "confirmTruncate": "true"
105 | }])
106 |
107 | chess_df.write \
108 | .format(arango_datasource_name) \
109 | .mode("Overwrite") \
110 | .options(**all_opts) \
111 | .save()
112 |
113 | assert chess_collection.count() == 10
114 |
115 |
116 | @pytest.mark.parametrize("protocol,content_type", protocol_and_content_type)
117 | def test_savemode_overwrite_with_existing_collection(chess_df: pyspark.sql.DataFrame, chess_collection: arango.collection.StandardCollection, database_conn: arango.database.StandardDatabase, protocol: str, content_type: str):
118 | database_conn.create_collection(COLLECTION_NAME)
119 | chess_collection.insert({})
120 |
121 | all_opts = combine_dicts([options, {
122 | "table": COLLECTION_NAME,
123 | "protocol": protocol,
124 | "contentType": content_type,
125 | "confirmTruncate": "true"
126 | }])
127 |
128 | chess_df.write \
129 | .format(arango_datasource_name) \
130 | .mode("Overwrite") \
131 | .options(**all_opts) \
132 | .save()
133 |
134 | assert chess_collection.count() == 10
135 |
--------------------------------------------------------------------------------
/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/BadRecordsTest.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql.arangodb.datasource
2 |
3 | import org.apache.spark.{SPARK_VERSION, SparkException}
4 | import org.apache.spark.sql.DataFrame
5 | import org.apache.spark.sql.arangodb.commons.ArangoDBConf
6 | import org.apache.spark.sql.catalyst.util.{BadRecordException, DropMalformedMode, FailFastMode, ParseMode}
7 | import org.apache.spark.sql.types._
8 | import org.assertj.core.api.Assertions.{assertThat, catchThrowable}
9 | import org.assertj.core.api.ThrowableAssert.ThrowingCallable
10 | import org.junit.jupiter.params.ParameterizedTest
11 | import org.junit.jupiter.params.provider.ValueSource
12 |
13 | class BadRecordsTest extends BaseSparkTest {
14 | private val collectionName = "deserializationCast"
15 |
16 | @ParameterizedTest
17 | @ValueSource(strings = Array("vpack", "json"))
18 | def stringAsInteger(contentType: String): Unit = testBadRecord(
19 | StructType(Array(StructField("a", IntegerType))),
20 | Seq(Map("a" -> "1")),
21 | Seq("""{"a":"1"}"""),
22 | contentType
23 | )
24 |
25 | @ParameterizedTest
26 | @ValueSource(strings = Array("vpack", "json"))
27 | def booleanAsInteger(contentType: String): Unit = testBadRecord(
28 | StructType(Array(StructField("a", IntegerType))),
29 | Seq(Map("a" -> true)),
30 | Seq("""{"a":true}"""),
31 | contentType
32 | )
33 |
34 | @ParameterizedTest
35 | @ValueSource(strings = Array("vpack", "json"))
36 | def stringAsDouble(contentType: String): Unit = testBadRecord(
37 | StructType(Array(StructField("a", DoubleType))),
38 | Seq(Map("a" -> "1")),
39 | Seq("""{"a":"1"}"""),
40 | contentType
41 | )
42 |
43 | @ParameterizedTest
44 | @ValueSource(strings = Array("vpack", "json"))
45 | def booleanAsDouble(contentType: String): Unit = testBadRecord(
46 | StructType(Array(StructField("a", DoubleType))),
47 | Seq(Map("a" -> true)),
48 | Seq("""{"a":true}"""),
49 | contentType
50 | )
51 |
52 | @ParameterizedTest
53 | @ValueSource(strings = Array("vpack", "json"))
54 | def stringAsBoolean(contentType: String): Unit = testBadRecord(
55 | StructType(Array(StructField("a", BooleanType))),
56 | Seq(Map("a" -> "true")),
57 | Seq("""{"a":"true"}"""),
58 | contentType
59 | )
60 |
61 | @ParameterizedTest
62 | @ValueSource(strings = Array("vpack", "json"))
63 | def numberAsBoolean(contentType: String): Unit = testBadRecord(
64 | StructType(Array(StructField("a", BooleanType))),
65 | Seq(Map("a" -> 1)),
66 | Seq("""{"a":1}"""),
67 | contentType
68 | )
69 |
70 | private def testBadRecord(
71 | schema: StructType,
72 | data: Iterable[Map[String, Any]],
73 | jsonData: Seq[String],
74 | contentType: String
75 | ) = {
76 | // PERMISSIVE
77 | doTestBadRecord(schema, data, jsonData, Map(ArangoDBConf.CONTENT_TYPE -> contentType))
78 |
79 | // PERMISSIVE with columnNameOfCorruptRecord
80 | doTestBadRecord(
81 | schema.add(StructField("corruptRecord", StringType)),
82 | data,
83 | jsonData,
84 | Map(
85 | ArangoDBConf.CONTENT_TYPE -> contentType,
86 | ArangoDBConf.COLUMN_NAME_OF_CORRUPT_RECORD -> "corruptRecord"
87 | )
88 | )
89 |
90 | // DROPMALFORMED
91 | doTestBadRecord(schema, data, jsonData,
92 | Map(
93 | ArangoDBConf.CONTENT_TYPE -> contentType,
94 | ArangoDBConf.PARSE_MODE -> DropMalformedMode.name
95 | )
96 | )
97 |
98 | // FAILFAST
99 | val df = BaseSparkTest.createDF(collectionName, data, schema, Map(
100 | ArangoDBConf.CONTENT_TYPE -> contentType,
101 | ArangoDBConf.PARSE_MODE -> FailFastMode.name
102 | ))
103 | val thrown = catchThrowable(new ThrowingCallable() {
104 | override def call(): Unit = df.collect()
105 | })
106 |
107 | assertThat(thrown.getCause).isInstanceOf(classOf[SparkException])
108 | assertThat(thrown.getCause).hasMessageContaining("Malformed record")
109 | assertThat(thrown.getCause).hasCauseInstanceOf(classOf[BadRecordException])
110 | }
111 |
112 | private def doTestBadRecord(
113 | schema: StructType,
114 | data: Iterable[Map[String, Any]],
115 | jsonData: Seq[String],
116 | opts: Map[String, String] = Map.empty
117 | ) = {
118 | import spark.implicits._
119 | val dfFromJson: DataFrame = spark.read.schema(schema).options(opts).json(jsonData.toDS)
120 | dfFromJson.show()
121 |
122 | val tableDF = BaseSparkTest.createDF(collectionName, data, schema, opts)
123 | assertThat(tableDF.collect()).isEqualTo(dfFromJson.collect())
124 |
125 | val queryDF = BaseSparkTest.createQueryDF(s"RETURN ${jsonData.head}", schema, opts)
126 | assertThat(queryDF.collect()).isEqualTo(dfFromJson.collect())
127 | }
128 |
129 | }
130 |
--------------------------------------------------------------------------------
/python-integration-tests/integration/test_bad_records.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | import arango.database
4 | import pytest
5 | from py4j.protocol import Py4JJavaError
6 | from pyspark.sql import SparkSession
7 | from pyspark.sql.types import StructType, StringType, StructField, IntegerType, DoubleType, BooleanType
8 |
9 | from integration import test_basespark
10 | from integration.test_basespark import options, arango_datasource_name
11 | from integration.utils import combine_dicts
12 |
13 |
14 | content_types = ["vpack", "json"]
15 | COLLECTION_NAME = "deserializationCast"
16 |
17 |
18 | def do_test_bad_record(db: arango.database.StandardDatabase, spark: SparkSession, schema: StructType, data: List[Dict[str, Any]], json_data: List[str], opts: Dict[str, str]):
19 | df_from_json = spark.read.schema(schema).options(**opts).json(spark.sparkContext.parallelize(json_data))
20 | df_from_json.show()
21 |
22 | table_df = test_basespark.create_df(db, spark, COLLECTION_NAME, data, schema, opts)
23 | assert table_df.collect() == df_from_json.collect()
24 |
25 | query_df = test_basespark.create_query_df(spark, f"RETURN {json_data[0]}", schema, opts)
26 | assert query_df.collect() == df_from_json.collect()
27 |
28 |
29 | def check_bad_record(db: arango.database.StandardDatabase, spark: SparkSession, schema: StructType, data: List[Dict[str, Any]], json_data: List[str], content_type: str):
30 | # Permissive
31 | do_test_bad_record(db, spark, schema, data, json_data, {"contentType": content_type})
32 |
33 | # Permissive with column name of corrupt record
34 | do_test_bad_record(
35 | db,
36 | spark,
37 | schema.add(StructField("corruptRecord", StringType())),
38 | data,
39 | json_data,
40 | {
41 | "contentType": content_type,
42 | "columnNameOfCorruptRecord": "corruptRecord"
43 | }
44 | )
45 |
46 | # Dropmalformed
47 | do_test_bad_record(db, spark, schema, data, json_data,
48 | {
49 | "contentType": content_type,
50 | "mode": "DROPMALFORMED"
51 | })
52 |
53 | # Failfast
54 | df = test_basespark.create_df(db, spark, COLLECTION_NAME, data, schema,
55 | {
56 | "contentType": content_type,
57 | "mode": "FAILFAST"
58 | })
59 | with pytest.raises(Py4JJavaError) as e:
60 | df.collect()
61 |
62 | e.match("SparkException")
63 | e.match("Malformed record")
64 | e.match("BadRecordException")
65 |
66 |
67 | @pytest.mark.parametrize("content_type", content_types)
68 | def test_string_as_integer(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
69 | check_bad_record(
70 | database_conn,
71 | spark,
72 | StructType([StructField("a", IntegerType())]),
73 | [{"a": "1"}],
74 | ['{"a":"1"}'],
75 | content_type
76 | )
77 |
78 |
79 | @pytest.mark.parametrize("content_type", content_types)
80 | def test_boolean_as_integer(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
81 | check_bad_record(
82 | database_conn,
83 | spark,
84 | StructType([StructField("a", IntegerType())]),
85 | [{"a": True}],
86 | ['{"a":true}'],
87 | content_type
88 | )
89 |
90 |
91 | @pytest.mark.parametrize("content_type", content_types)
92 | def test_string_as_double(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
93 | check_bad_record(
94 | database_conn,
95 | spark,
96 | StructType([StructField("a", DoubleType())]),
97 | [{"a": "1"}],
98 | ['{"a":"1"}'],
99 | content_type
100 | )
101 |
102 |
103 | @pytest.mark.parametrize("content_type", content_types)
104 | def test_boolean_as_double(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
105 | check_bad_record(
106 | database_conn,
107 | spark,
108 | StructType([StructField("a", DoubleType())]),
109 | [{"a": True}],
110 | ['{"a":true}'],
111 | content_type
112 | )
113 |
114 |
115 | @pytest.mark.parametrize("content_type", content_types)
116 | def test_string_as_boolean(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
117 | check_bad_record(
118 | database_conn,
119 | spark,
120 | StructType([StructField("a", BooleanType())]),
121 | [{"a": "true"}],
122 | ['{"a":"true"}'],
123 | content_type
124 | )
125 |
126 |
127 | @pytest.mark.parametrize("content_type", content_types)
128 | def test_number_as_boolean(database_conn: arango.database.StandardDatabase, spark: SparkSession, content_type: str):
129 | check_bad_record(
130 | database_conn,
131 | spark,
132 | StructType([StructField("a", BooleanType())]),
133 | [{"a": 1}],
134 | ['{"a":1}'],
135 | content_type
136 | )
--------------------------------------------------------------------------------