├── project
│   ├── build.properties
│   └── plugins.sbt
├── src
│   ├── main
│   │   ├── resources
│   │   │   └── META-INF
│   │   │       └── services
│   │   │           └── org.apache.spark.sql.sources.DataSourceRegister
│   │   ├── scala
│   │   │   └── com
│   │   │       └── audienceproject
│   │   │           └── spark
│   │   │               └── dynamodb
│   │   │                   ├── attribute.scala
│   │   │                   ├── datasource
│   │   │                   │   ├── OutputPartitioning.scala
│   │   │                   │   ├── ScanPartition.scala
│   │   │                   │   ├── DynamoDataDeleteWriter.scala
│   │   │                   │   ├── DynamoWriteBuilder.scala
│   │   │                   │   ├── DefaultSource.scala
│   │   │                   │   ├── DynamoDataUpdateWriter.scala
│   │   │                   │   ├── DynamoBatchReader.scala
│   │   │                   │   ├── DynamoDataWriter.scala
│   │   │                   │   ├── DynamoWriterFactory.scala
│   │   │                   │   ├── DynamoScanBuilder.scala
│   │   │                   │   ├── DynamoReaderFactory.scala
│   │   │                   │   ├── DynamoTable.scala
│   │   │                   │   └── TypeConversion.scala
│   │   │                   ├── connector
│   │   │                   │   ├── KeySchema.scala
│   │   │                   │   ├── DynamoWritable.scala
│   │   │                   │   ├── ColumnSchema.scala
│   │   │                   │   ├── TableIndexConnector.scala
│   │   │                   │   ├── FilterPushdown.scala
│   │   │                   │   ├── DynamoConnector.scala
│   │   │                   │   └── TableConnector.scala
│   │   │                   ├── reflect
│   │   │                   │   └── SchemaAnalysis.scala
│   │   │                   ├── implicits.scala
│   │   │                   └── catalyst
│   │   │                       └── JavaConverter.scala
│   │   └── java
│   │       └── com
│   │           └── audienceproject
│   │               └── shaded
│   │                   └── google
│   │                       └── common
│   │                           ├── base
│   │                           │   ├── Ticker.java
│   │                           │   └── Preconditions.java
│   │                           └── util
│   │                               └── concurrent
│   │                                   ├── Uninterruptibles.java
│   │                                   └── RateLimiter.java
│   └── test
│       ├── scala
│       │   └── com
│       │       └── audienceproject
│       │           └── spark
│       │               └── dynamodb
│       │                   ├── structs
│       │                   │   ├── TestFruit.scala
│       │                   │   └── TestFruitWithProperties.scala
│       │                   ├── NullValuesTest.scala
│       │                   ├── NullBooleanTest.scala
│       │                   ├── FilterPushdownTest.scala
│       │                   ├── DefaultSourceTest.scala
│       │                   ├── RegionTest.scala
│       │                   ├── AbstractInMemoryTest.scala
│       │                   ├── WriteRelationTest.scala
│       │                   └── NestedDataStructuresTest.scala
│       └── resources
│           └── log4j2.xml
├── .editorconfig
├── .gitignore
├── wercker.yml
├── README.md
└── LICENSE
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.2.6
2 |
--------------------------------------------------------------------------------
/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
--------------------------------------------------------------------------------
1 | com.audienceproject.spark.dynamodb.datasource.DefaultSource
2 |
--------------------------------------------------------------------------------
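
This service file registers com.audienceproject.spark.dynamodb.datasource.DefaultSource with Spark's DataSourceRegister loader, so the connector can be addressed by a short format name as well as by its fully qualified package. A rough sketch, assuming the short name registered by DefaultSource (whose source is not included in this dump) is "dynamodb":

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()

    // The fully qualified format name always works:
    val byClass = spark.read
      .format("com.audienceproject.spark.dynamodb.datasource")
      .option("tableName", "TestFruit")
      .load()

    // The short name relies on the service registration above (assumed to be "dynamodb"):
    val byShortName = spark.read
      .format("dynamodb")
      .option("tableName", "TestFruit")
      .load()
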
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 | [*]
3 | end_of_line = lf
4 | insert_final_newline = true
5 | charset = utf-8
6 | indent_style = space
7 | indent_size = 4
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target/
2 | /bin/
3 | *.class
4 | *.log
5 | .classpath
6 | .idea
7 | .wercker
8 | project/target
9 | project/project
10 | lib_managed*/
11 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | logLevel := Level.Warn
2 |
3 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.0")
4 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4")
5 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
6 |
--------------------------------------------------------------------------------
/src/test/scala/com/audienceproject/spark/dynamodb/structs/TestFruit.scala:
--------------------------------------------------------------------------------
1 | package com.audienceproject.spark.dynamodb.structs
2 |
3 | import com.audienceproject.spark.dynamodb.attribute
4 |
5 | case class TestFruit(@attribute("name") primaryKey: String,
6 | color: String,
7 | weightKg: Double)
8 |
--------------------------------------------------------------------------------
/src/test/scala/com/audienceproject/spark/dynamodb/structs/TestFruitWithProperties.scala:
--------------------------------------------------------------------------------
1 | package com.audienceproject.spark.dynamodb.structs
2 |
3 | case class TestFruitProperties(freshness: String,
4 | eco: Boolean,
5 | price: Double)
6 |
7 | case class TestFruitWithProperties(name: String,
8 | color: String,
9 | weight: Double,
10 | properties: TestFruitProperties)
11 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/attribute.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb
22 |
23 | import scala.annotation.StaticAnnotation
24 |
25 | final case class attribute(name: String) extends StaticAnnotation
26 |
--------------------------------------------------------------------------------
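
The attribute annotation maps a case class field onto a differently named DynamoDB attribute; SchemaAnalysis (further down in this dump) picks it up via reflection. A minimal sketch in the spirit of the TestFruit struct above:

    import com.audienceproject.spark.dynamodb.attribute

    // The field is called `primaryKey` in Scala, but is read from and written to
    // the DynamoDB attribute named "name".
    case class Fruit(@attribute("name") primaryKey: String,
                     color: String,
                     weightKg: Double)
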
/wercker.yml:
--------------------------------------------------------------------------------
1 | box:
2 | id: audienceproject/jvm
3 | username: $DOCKERHUB_ACCOUNT
4 | password: $DOCKERHUB_PASSWORD
5 | tag: latest
6 |
7 | build:
8 | steps:
9 | - script:
10 | name: Compile
11 | code: sbt clean compile
12 | - audienceproject/aws-cli-assume-role@1.0.2:
13 | aws-access-key-id: $AWS_ACCESS_KEY
14 | aws-secret-access-key: $AWS_SECRET_KEY
15 | role-arn: arn:aws:iam::$AWS_ACCOUNT_ID:role/build-$WERCKER_GIT_REPOSITORY
16 | - script:
17 | name: Test
18 | code: sbt clean compile test
19 | - script:
20 | name: Clean again
21 | code: sbt clean
22 |
23 | publish-snapshot:
24 | steps:
25 | - audienceproject/sbt-to-maven-central@2.0.0:
26 | user: $NEXUS_USER
27 | password: $NEXUS_PASSWORD
28 | private-key: $NEXUS_PK
29 | passphrase: $NEXUS_PASSPHRASE
30 |
31 | publish-release:
32 | steps:
33 | - audienceproject/sbt-to-maven-central@2.0.0:
34 | user: $NEXUS_USER
35 | password: $NEXUS_PASSWORD
36 | private-key: $NEXUS_PK
37 | passphrase: $NEXUS_PASSPHRASE
38 | destination: RELEASE
39 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/OutputPartitioning.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import org.apache.spark.sql.connector.read.partitioning.{Distribution, Partitioning}
24 |
25 | class OutputPartitioning(override val numPartitions: Int) extends Partitioning {
26 |
27 | override def satisfy(distribution: Distribution): Boolean = false
28 |
29 | }
30 |
--------------------------------------------------------------------------------
/src/test/resources/log4j2.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/main/java/com/audienceproject/shaded/google/common/base/Ticker.java:
--------------------------------------------------------------------------------
28 | * Warning: this interface can only be used to measure elapsed time, not wall time.
29 | *
30 | * @author Kevin Bourrillion
31 | * @since 10.0
32 | * (mostly source-compatible since 9.0)
34 | */
35 | public abstract class Ticker {
36 | /**
37 | * Constructor for use by subclasses.
38 | */
39 | protected Ticker() {}
40 |
41 | /**
42 | * Returns the number of nanoseconds elapsed since this ticker's fixed
43 | * point of reference.
44 | */
45 | public abstract long read();
46 |
47 | /**
48 | * A ticker that reads the current time using {@link System#nanoTime}.
49 | *
50 | * @since 10.0
51 | */
52 | public static Ticker systemTicker() {
53 | return SYSTEM_TICKER;
54 | }
55 |
56 | private static final Ticker SYSTEM_TICKER = new Ticker() {
57 | @Override
58 | public long read() {
59 | return System.nanoTime();
60 | }
61 | };
62 | }
63 |
64 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchReader.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import com.audienceproject.spark.dynamodb.connector.DynamoConnector
24 | import org.apache.spark.sql.connector.read._
25 | import org.apache.spark.sql.connector.read.partitioning.Partitioning
26 | import org.apache.spark.sql.sources.Filter
27 | import org.apache.spark.sql.types.StructType
28 |
29 | class DynamoBatchReader(connector: DynamoConnector,
30 | filters: Array[Filter],
31 | schema: StructType)
32 | extends Scan with Batch with SupportsReportPartitioning {
33 |
34 | override def readSchema(): StructType = schema
35 |
36 | override def toBatch: Batch = this
37 |
38 | override def planInputPartitions(): Array[InputPartition] = {
39 | val requiredColumns = schema.map(_.name)
40 | Array.tabulate(connector.totalSegments)(new ScanPartition(_, requiredColumns, filters))
41 | }
42 |
43 | override def createReaderFactory(): PartitionReaderFactory =
44 | new DynamoReaderFactory(connector, schema)
45 |
46 | override val outputPartitioning: Partitioning = new OutputPartitioning(connector.totalSegments)
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/src/test/scala/com/audienceproject/spark/dynamodb/FilterPushdownTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb
22 |
23 | import com.audienceproject.spark.dynamodb.implicits._
24 |
25 | class FilterPushdownTest extends AbstractInMemoryTest {
26 |
27 | test("Count of red fruit is 2 (`EqualTo` filter)") {
28 | import spark.implicits._
29 | val fruitCount = spark.read.dynamodb("TestFruit").where($"color" === "red").count()
30 | assert(fruitCount === 2)
31 | }
32 |
33 | test("Count of yellow and green fruit is 4 (`In` filter)") {
34 | import spark.implicits._
35 | val fruitCount = spark.read.dynamodb("TestFruit")
36 | .where($"color" isin("yellow", "green"))
37 | .count()
38 | assert(fruitCount === 4)
39 | }
40 |
41 | test("Count of 0.01 weight fruit is 4 (`In` filter)") {
42 | import spark.implicits._
43 | val fruitCount = spark.read.dynamodb("TestFruit")
44 | .where($"weightKg" isin 0.01)
45 | .count()
46 | assert(fruitCount === 3)
47 | }
48 |
49 | test("Only 'banana' starts with a 'b' and is >0.01 kg (`StringStartsWith`, `GreaterThan`, `And` filters)") {
50 | import spark.implicits._
51 | val fruit = spark.read.dynamodb("TestFruit")
52 | .where(($"name" startsWith "b") && ($"weightKg" > 0.01))
53 | .collectAsList().get(0)
54 | assert(fruit.getAs[String]("name") === "banana")
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataWriter.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import com.amazonaws.services.dynamodbv2.document.DynamoDB
24 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter
25 | import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, TableConnector}
26 | import org.apache.spark.sql.catalyst.InternalRow
27 | import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
28 |
29 | import scala.collection.mutable.ArrayBuffer
30 |
31 | class DynamoDataWriter(batchSize: Int,
32 | columnSchema: ColumnSchema,
33 | connector: TableConnector,
34 | client: DynamoDB)
35 | extends DataWriter[InternalRow] {
36 |
37 | protected val buffer: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow](batchSize)
38 | protected val rateLimiter: RateLimiter = RateLimiter.create(connector.writeLimit)
39 |
40 | override def write(record: InternalRow): Unit = {
41 | buffer += record.copy()
42 | if (buffer.size == batchSize) {
43 | flush()
44 | }
45 | }
46 |
47 | override def commit(): WriterCommitMessage = {
48 | flush()
49 | new WriterCommitMessage {}
50 | }
51 |
52 | override def abort(): Unit = {}
53 |
54 | override def close(): Unit = client.shutdown()
55 |
56 | protected def flush(): Unit = {
57 | if (buffer.nonEmpty) {
58 | connector.putItems(columnSchema, buffer)(client, rateLimiter)
59 | buffer.clear()
60 | }
61 | }
62 |
63 | }
64 |
--------------------------------------------------------------------------------
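
The writer buffers rows and flushes whenever batchSize rows have accumulated, throttled by a RateLimiter sized from the connector's write limit; DynamoWriterFactory below defaults the batch size to 25, which is also the DynamoDB BatchWriteItem maximum. A hedged usage sketch (table name and data are hypothetical), using the implicits wrapper defined later in this dump:

    import com.audienceproject.spark.dynamodb.implicits._
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    // Columns are expected to match the target table's attributes.
    val fruitDf = Seq(("lemon", "yellow", 0.1), ("orange", "orange", 0.2))
      .toDF("name", "color", "weightKg")

    fruitDf.write
      .option("writebatchsize", "25") // optional; 25 is already the factory default
      .dynamodb("SomeTable")
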
/src/main/scala/com/audienceproject/spark/dynamodb/connector/ColumnSchema.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.connector
22 |
23 | import org.apache.spark.sql.types.{DataType, StructType}
24 |
25 | private[dynamodb] class ColumnSchema(keySchema: KeySchema,
26 | sparkSchema: StructType) {
27 |
28 | type Attr = (String, Int, DataType)
29 |
30 | private val columnNames = sparkSchema.map(_.name)
31 |
32 | private val keyIndices = keySchema match {
33 | case KeySchema(hashKey, None) =>
34 | val hashKeyIndex = columnNames.indexOf(hashKey)
35 | val hashKeyType = sparkSchema(hashKey).dataType
36 | Left(hashKey, hashKeyIndex, hashKeyType)
37 | case KeySchema(hashKey, Some(rangeKey)) =>
38 | val hashKeyIndex = columnNames.indexOf(hashKey)
39 | val rangeKeyIndex = columnNames.indexOf(rangeKey)
40 | val hashKeyType = sparkSchema(hashKey).dataType
41 | val rangeKeyType = sparkSchema(rangeKey).dataType
42 | Right((hashKey, hashKeyIndex, hashKeyType), (rangeKey, rangeKeyIndex, rangeKeyType))
43 | }
44 |
45 | private val attributeIndices = columnNames.zipWithIndex.filterNot({
46 | case (name, _) => keySchema match {
47 | case KeySchema(hashKey, None) => name == hashKey
48 | case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey
49 | }
50 | }).map({
51 | case (name, index) => (name, index, sparkSchema(name).dataType)
52 | })
53 |
54 | def keys(): Either[Attr, (Attr, Attr)] = keyIndices
55 |
56 | def attributes(): Seq[Attr] = attributeIndices
57 |
58 | }
59 |
--------------------------------------------------------------------------------
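
keys() yields Left for a hash-key-only table and Right for a composite key, each entry being (column name, index in the Spark schema, Spark data type); everything else comes back from attributes(). A small sketch, assuming it is compiled inside the com.audienceproject.spark.dynamodb package (the class is private[dynamodb]) and that KeySchema has the (hashKeyName, rangeKeyName) constructor implied by its usage elsewhere in this dump:

    package com.audienceproject.spark.dynamodb

    import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, KeySchema}
    import org.apache.spark.sql.types._

    object ColumnSchemaSketch {

      val sparkSchema: StructType = StructType(Seq(
        StructField("name", StringType),
        StructField("color", StringType),
        StructField("weightKg", DoubleType)))

      // KeySchema("name", None) models a hash-key-only table.
      val columnSchema = new ColumnSchema(KeySchema("name", None), sparkSchema)

      val keys = columnSchema.keys()        // Left(("name", 0, StringType))
      val attrs = columnSchema.attributes() // Seq(("color", 1, StringType), ("weightKg", 2, DoubleType))
    }
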
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoWriterFactory.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, TableConnector}
24 | import org.apache.spark.sql.catalyst.InternalRow
25 | import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
26 | import org.apache.spark.sql.types.StructType
27 |
28 | class DynamoWriterFactory(connector: TableConnector,
29 | parameters: Map[String, String],
30 | schema: StructType)
31 | extends DataWriterFactory {
32 |
33 | private val batchSize = parameters.getOrElse("writebatchsize", "25").toInt
34 | private val update = parameters.getOrElse("update", "false").toBoolean
35 | private val delete = parameters.getOrElse("delete", "false").toBoolean
36 |
37 | private val region = parameters.get("region")
38 | private val roleArn = parameters.get("rolearn")
39 | private val providerClassName = parameters.get("providerclassname")
40 |
41 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
42 | val columnSchema = new ColumnSchema(connector.keySchema, schema)
43 | val client = connector.getDynamoDB(region, roleArn, providerClassName)
44 | if (update) {
45 | assert(!delete, "Please provide exactly one of 'update' or 'delete' options.")
46 | new DynamoDataUpdateWriter(columnSchema, connector, client)
47 | } else if (delete) {
48 | new DynamoDataDeleteWriter(batchSize, columnSchema, connector, client)
49 | } else {
50 | new DynamoDataWriter(batchSize, columnSchema, connector, client)
51 | }
52 | }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
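
At most one of the update and delete options may be set: update routes rows through DynamoDataUpdateWriter (no batch size argument), delete through DynamoDataDeleteWriter, and the default path through DynamoDataWriter, the latter two in batches of writebatchsize. A usage sketch with hypothetical data and table name:

    import com.audienceproject.spark.dynamodb.implicits._
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    val updatesDf = Seq(("apple", "green")).toDF("name", "color")
    val deletesDf = Seq("plum").toDF("name")

    // Routed to DynamoDataUpdateWriter:
    updatesDf.write.option("update", "true").dynamodb("SomeTable")

    // Routed to DynamoDataDeleteWriter, in batches of `writebatchsize` (default 25):
    deletesDf.write.option("delete", "true").dynamodb("SomeTable")
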
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoScanBuilder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import com.audienceproject.spark.dynamodb.connector.{DynamoConnector, FilterPushdown}
24 | import org.apache.spark.sql.connector.read._
25 | import org.apache.spark.sql.sources.Filter
26 | import org.apache.spark.sql.types._
27 |
28 | class DynamoScanBuilder(connector: DynamoConnector, schema: StructType)
29 | extends ScanBuilder
30 | with SupportsPushDownRequiredColumns
31 | with SupportsPushDownFilters {
32 |
33 | private var acceptedFilters: Array[Filter] = Array.empty
34 | private var currentSchema: StructType = schema
35 |
36 | override def build(): Scan = new DynamoBatchReader(connector, pushedFilters(), currentSchema)
37 |
38 | override def pruneColumns(requiredSchema: StructType): Unit = {
39 | val keyFields = Seq(Some(connector.keySchema.hashKeyName), connector.keySchema.rangeKeyName).flatten
40 | .flatMap(keyName => currentSchema.fields.find(_.name == keyName))
41 | val requiredFields = keyFields ++ requiredSchema.fields
42 | val newFields = currentSchema.fields.filter(requiredFields.contains)
43 | currentSchema = StructType(newFields)
44 | }
45 |
46 | override def pushFilters(filters: Array[Filter]): Array[Filter] = {
47 | if (connector.filterPushdownEnabled) {
48 | val (acceptedFilters, postScanFilters) = FilterPushdown.acceptFilters(filters)
49 | this.acceptedFilters = acceptedFilters
50 | postScanFilters // Return filters that need to be evaluated after scanning.
51 | } else filters
52 | }
53 |
54 | override def pushedFilters(): Array[Filter] = acceptedFilters
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/src/test/scala/com/audienceproject/spark/dynamodb/DefaultSourceTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb
22 |
23 | import com.audienceproject.spark.dynamodb.implicits._
24 | import com.audienceproject.spark.dynamodb.structs.TestFruit
25 | import org.apache.spark.sql.functions._
26 |
27 | import scala.collection.JavaConverters._
28 |
29 | class DefaultSourceTest extends AbstractInMemoryTest {
30 |
31 | test("Table count is 9") {
32 | val count = spark.read.dynamodb("TestFruit")
33 | count.show()
34 | assert(count.count() === 9)
35 | }
36 |
37 | test("Column sum is 27") {
38 | val result = spark.read.dynamodb("TestFruit").collectAsList().asScala
39 | val numCols = result.map(_.length).sum
40 | assert(numCols === 27)
41 | }
42 |
43 | test("Select only first two columns") {
44 | val result = spark.read.dynamodb("TestFruit").select("name", "color").collectAsList().asScala
45 | val numCols = result.map(_.length).sum
46 | assert(numCols === 18)
47 | }
48 |
49 | test("The least occurring color is yellow") {
50 | import spark.implicits._
51 | val itemWithLeastOccurringColor = spark.read.dynamodb("TestFruit")
52 | .groupBy($"color").agg(count($"color").as("countColor"))
53 | .orderBy($"countColor")
54 | .takeAsList(1).get(0)
55 | assert(itemWithLeastOccurringColor.getAs[String]("color") === "yellow")
56 | }
57 |
58 | test("Test of attribute name alias") {
59 | import spark.implicits._
60 | val itemApple = spark.read.dynamodbAs[TestFruit]("TestFruit")
61 | .filter($"primaryKey" === "apple")
62 | .takeAsList(1).get(0)
63 | assert(itemApple.primaryKey === "apple")
64 | }
65 |
66 | }
67 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/reflect/SchemaAnalysis.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.reflect
22 |
23 | import com.audienceproject.spark.dynamodb.attribute
24 | import org.apache.spark.sql.catalyst.ScalaReflection
25 | import org.apache.spark.sql.types.{StructField, StructType}
26 |
27 | import scala.reflect.ClassTag
28 | import scala.reflect.runtime.{universe => ru}
29 |
30 | /**
31 | * Uses reflection to perform a static analysis that can derive a Spark schema from a case class of type `T`.
32 | */
33 | private[dynamodb] object SchemaAnalysis {
34 |
35 | def apply[T <: Product : ClassTag : ru.TypeTag]: (StructType, Map[String, String]) = {
36 |
37 | val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader)
38 |
39 | val classObj = scala.reflect.classTag[T].runtimeClass
40 | val classSymbol = runtimeMirror.classSymbol(classObj)
41 |
42 | val params = classSymbol.primaryConstructor.typeSignature.paramLists.head
43 | val (sparkFields, aliasMap) = params.foldLeft((List.empty[StructField], Map.empty[String, String]))({
44 | case ((list, map), field) =>
45 | val sparkType = ScalaReflection.schemaFor(field.typeSignature).dataType
46 |
47 | // Black magic from here:
48 | // https://stackoverflow.com/questions/23046958/accessing-an-annotation-value-in-scala
49 | val attrName = field.annotations.collectFirst({
50 | case ann: ru.AnnotationApi if ann.tree.tpe =:= ru.typeOf[attribute] =>
51 | ann.tree.children.tail.collectFirst({
52 | case ru.Literal(ru.Constant(name: String)) => name
53 | })
54 | }).flatten
55 |
56 | if (attrName.isDefined) {
57 | val sparkField = StructField(attrName.get, sparkType, nullable = true)
58 | (list :+ sparkField, map + (attrName.get -> field.name.toString))
59 | } else {
60 | val sparkField = StructField(field.name.toString, sparkType, nullable = true)
61 | (list :+ sparkField, map)
62 | }
63 | })
64 |
65 | (StructType(sparkFields), aliasMap)
66 | }
67 |
68 | }
69 |
--------------------------------------------------------------------------------
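
Applied to the TestFruit case class from the test sources, the analysis yields both the Spark schema (using the DynamoDB attribute names) and the alias map back to the Scala field names. A sketch, assuming it is compiled inside the com.audienceproject.spark.dynamodb package, since the object is private[dynamodb]:

    package com.audienceproject.spark.dynamodb

    import com.audienceproject.spark.dynamodb.reflect.SchemaAnalysis
    import com.audienceproject.spark.dynamodb.structs.TestFruit

    object SchemaAnalysisSketch {
      // TestFruit lives in the test sources; reusing it here is purely illustrative.
      val (schema, aliasMap) = SchemaAnalysis[TestFruit]
      // schema   has nullable fields "name": StringType, "color": StringType, "weightKg": DoubleType
      // aliasMap == Map("name" -> "primaryKey")
    }
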
/src/test/scala/com/audienceproject/spark/dynamodb/RegionTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb
22 |
23 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
24 | import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
25 | import com.amazonaws.services.dynamodbv2.document.DynamoDB
26 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput}
27 | import com.audienceproject.spark.dynamodb.implicits._
28 |
29 | class RegionTest extends AbstractInMemoryTest {
30 |
31 | test("Inserting from a local Dataset") {
32 | val tableName = "RegionTest1"
33 | dynamoDB.createTable(new CreateTableRequest()
34 | .withTableName(tableName)
35 | .withAttributeDefinitions(new AttributeDefinition("name", "S"))
36 | .withKeySchema(new KeySchemaElement("name", "HASH"))
37 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))
38 | val client: AmazonDynamoDB = AmazonDynamoDBClientBuilder.standard()
39 | .withEndpointConfiguration(new EndpointConfiguration(System.getProperty("aws.dynamodb.endpoint"), "eu-central-1"))
40 | .build()
41 | val dynamoDBEU: DynamoDB = new DynamoDB(client)
42 | dynamoDBEU.createTable(new CreateTableRequest()
43 | .withTableName(tableName)
44 | .withAttributeDefinitions(new AttributeDefinition("name", "S"))
45 | .withKeySchema(new KeySchemaElement("name", "HASH"))
46 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))
47 |
48 | import spark.implicits._
49 |
50 | val newItemsDs = spark.createDataset(Seq(
51 | ("lemon", "yellow", 0.1),
52 | ("orange", "orange", 0.2),
53 | ("pomegranate", "red", 0.2)
54 | ))
55 | .withColumnRenamed("_1", "name")
56 | .withColumnRenamed("_2", "color")
57 | .withColumnRenamed("_3", "weight")
58 | newItemsDs.write.option("region","eu-central-1").dynamodb(tableName)
59 |
60 | val validationDs = spark.read.dynamodb(tableName)
61 | assert(validationDs.count() === 0)
62 | val validationDsEU = spark.read.option("region","eu-central-1").dynamodb(tableName)
63 | assert(validationDsEU.count() === 3)
64 | }
65 |
66 | }
67 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/implicits.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb
22 |
23 | import com.audienceproject.spark.dynamodb.reflect.SchemaAnalysis
24 | import org.apache.spark.sql._
25 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
26 | import org.apache.spark.sql.functions.col
27 |
28 | import scala.reflect.ClassTag
29 | import scala.reflect.runtime.universe.TypeTag
30 |
31 | object implicits {
32 |
33 | implicit class DynamoDBDataFrameReader(reader: DataFrameReader) {
34 |
35 | def dynamodb(tableName: String): DataFrame =
36 | getDynamoDBSource(tableName).load()
37 |
38 | def dynamodb(tableName: String, indexName: String): DataFrame =
39 | getDynamoDBSource(tableName).option("indexName", indexName).load()
40 |
41 | def dynamodbAs[T <: Product : ClassTag : TypeTag](tableName: String): Dataset[T] = {
42 | implicit val encoder: Encoder[T] = ExpressionEncoder()
43 | val (schema, aliasMap) = SchemaAnalysis[T]
44 | getColumnsAlias(getDynamoDBSource(tableName).schema(schema).load(), aliasMap).as
45 | }
46 |
47 | def dynamodbAs[T <: Product : ClassTag : TypeTag](tableName: String, indexName: String): Dataset[T] = {
48 | implicit val encoder: Encoder[T] = ExpressionEncoder()
49 | val (schema, aliasMap) = SchemaAnalysis[T]
50 | getColumnsAlias(
51 | getDynamoDBSource(tableName).option("indexName", indexName).schema(schema).load(), aliasMap).as
52 | }
53 |
54 | private def getDynamoDBSource(tableName: String): DataFrameReader =
55 | reader.format("com.audienceproject.spark.dynamodb.datasource").option("tableName", tableName)
56 |
57 | private def getColumnsAlias(dataFrame: DataFrame, aliasMap: Map[String, String]): DataFrame = {
58 | if (aliasMap.isEmpty) dataFrame
59 | else {
60 | val columnsAlias = dataFrame.columns.map({
61 | case name if aliasMap.isDefinedAt(name) => col(name) as aliasMap(name)
62 | case name => col(name)
63 | })
64 | dataFrame.select(columnsAlias: _*)
65 | }
66 | }
67 |
68 | }
69 |
70 | implicit class DynamoDBDataFrameWriter[T](writer: DataFrameWriter[T]) {
71 |
72 | def dynamodb(tableName: String): Unit =
73 | writer.format("com.audienceproject.spark.dynamodb.datasource")
74 | .mode(SaveMode.Append)
75 | .option("tableName", tableName)
76 | .save()
77 |
78 | }
79 |
80 | }
81 |
--------------------------------------------------------------------------------
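
The implicit wrappers add dynamodb/dynamodbAs readers and a dynamodb writer on top of the standard DataFrameReader and DataFrameWriter, wiring in tableName (and optionally indexName) and, for the typed variant, the schema and alias map from SchemaAnalysis. A short usage sketch; the index name and the write-target table are hypothetical:

    import com.audienceproject.spark.dynamodb.implicits._
    import com.audienceproject.spark.dynamodb.structs.TestFruit
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()

    val fruit   = spark.read.dynamodb("TestFruit")              // untyped read
    val byIndex = spark.read.dynamodb("TestFruit", "SomeIndex") // read through a GSI (assumed name)
    val typed   = spark.read.dynamodbAs[TestFruit]("TestFruit") // typed read, honouring @attribute aliases

    fruit.write.dynamodb("CopyOfTestFruit")                     // append-mode write (assumed table)
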
/src/test/scala/com/audienceproject/spark/dynamodb/AbstractInMemoryTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb
22 |
23 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
24 | import com.amazonaws.services.dynamodbv2.document.{DynamoDB, Item}
25 | import com.amazonaws.services.dynamodbv2.local.main.ServerRunner
26 | import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer
27 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput}
28 | import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
29 | import org.apache.spark.sql.SparkSession
30 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
31 |
32 | class AbstractInMemoryTest extends FunSuite with BeforeAndAfterAll {
33 |
34 | val server: DynamoDBProxyServer = ServerRunner.createServerFromCommandLineArgs(Array("-inMemory"))
35 |
36 | val client: AmazonDynamoDB = AmazonDynamoDBClientBuilder.standard()
37 | .withEndpointConfiguration(new EndpointConfiguration(System.getProperty("aws.dynamodb.endpoint"), "us-east-1"))
38 | .build()
39 | val dynamoDB: DynamoDB = new DynamoDB(client)
40 |
41 | val spark: SparkSession = SparkSession.builder
42 | .master("local")
43 | .appName(this.getClass.getName)
44 | .getOrCreate()
45 |
46 | spark.sparkContext.setLogLevel("ERROR")
47 |
48 | override def beforeAll(): Unit = {
49 | server.start()
50 |
51 | // Create a test table.
52 | dynamoDB.createTable(new CreateTableRequest()
53 | .withTableName("TestFruit")
54 | .withAttributeDefinitions(new AttributeDefinition("name", "S"))
55 | .withKeySchema(new KeySchemaElement("name", "HASH"))
56 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))
57 |
58 | // Populate with test data.
59 | val table = dynamoDB.getTable("TestFruit")
60 | for ((name, color, weight) <- Seq(
61 | ("apple", "red", 0.2), ("banana", "yellow", 0.15), ("watermelon", "red", 0.5),
62 | ("grape", "green", 0.01), ("pear", "green", 0.2), ("kiwi", "green", 0.05),
63 | ("blackberry", "purple", 0.01), ("blueberry", "purple", 0.01), ("plum", "purple", 0.1)
64 | )) {
65 | table.putItem(new Item()
66 | .withString("name", name)
67 | .withString("color", color)
68 | .withDouble("weightKg", weight))
69 | }
70 | }
71 |
72 | override def afterAll(): Unit = {
73 | client.deleteTable("TestFruit")
74 | server.stop()
75 | }
76 |
77 | }
78 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/catalyst/JavaConverter.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.catalyst
22 |
23 | import java.util
24 |
25 | import org.apache.spark.sql.catalyst.InternalRow
26 | import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
27 | import org.apache.spark.sql.types._
28 | import org.apache.spark.unsafe.types.UTF8String
29 |
30 | import scala.collection.JavaConverters._
31 |
32 | object JavaConverter {
33 |
34 | def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = {
35 | elementType match {
36 | case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType)
37 | case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType)
38 | case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields)
39 | case StringType => row.getString(index)
40 | case LongType => row.getLong(index)
41 | case t: DecimalType => row.getDecimal(index, t.precision, t.scale).toBigDecimal
42 | case _ => row.get(index, elementType)
43 | }
44 | }
45 |
46 | def convertArray(array: ArrayData, elementType: DataType): Any = {
47 | elementType match {
48 | case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava
49 | case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava
50 | case structType: StructType => array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava
51 | case StringType => convertStringArray(array).asJava
52 | case _ => array.toSeq[Any](elementType).asJava
53 | }
54 | }
55 |
56 | def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = {
57 | if (keyType != StringType) throw new IllegalArgumentException(
58 | s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
59 | val keys = convertStringArray(map.keyArray())
60 | val values = valueType match {
61 | case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType))
62 | case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType))
63 | case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields))
64 | case StringType => convertStringArray(map.valueArray())
65 | case _ => map.valueArray().toSeq[Any](valueType)
66 | }
67 | val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i)
68 | Map(kvPairs: _*).asJava
69 | }
70 |
71 | def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
72 | val kvPairs = for (i <- 0 until row.numFields) yield
73 | if (row.isNullAt(i)) fields(i).name -> null
74 | else fields(i).name -> convertRowValue(row, i, fields(i).dataType)
75 | Map(kvPairs: _*).asJava
76 | }
77 |
78 | def convertStringArray(array: ArrayData): Seq[String] =
79 | array.toSeq[UTF8String](StringType).map(_.toString)
80 |
81 | }
82 |
--------------------------------------------------------------------------------
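
JavaConverter turns Catalyst's internal representations (InternalRow, ArrayData, MapData, UTF8String, Decimal) into plain Java objects that the DynamoDB Document API can serialize. A minimal sketch of the simplest path, a single string column:

    import com.audienceproject.spark.dynamodb.catalyst.JavaConverter
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.unsafe.types.UTF8String

    // A one-column row holding a UTF8String; convertRowValue hands back a java.lang.String.
    val row = new GenericInternalRow(Array[Any](UTF8String.fromString("apple")))
    val value = JavaConverter.convertRowValue(row, 0, StringType) // "apple"
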
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoReaderFactory.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import com.amazonaws.services.dynamodbv2.document.Item
24 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter
25 | import com.audienceproject.spark.dynamodb.connector.DynamoConnector
26 | import org.apache.spark.sql.catalyst.InternalRow
27 | import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
28 | import org.apache.spark.sql.types.{StructField, StructType}
29 |
30 | import scala.collection.JavaConverters._
31 |
32 | class DynamoReaderFactory(connector: DynamoConnector,
33 | schema: StructType)
34 | extends PartitionReaderFactory {
35 |
36 | override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
37 | if (connector.isEmpty) new EmptyReader
38 | else new ScanPartitionReader(partition.asInstanceOf[ScanPartition])
39 | }
40 |
41 | private class EmptyReader extends PartitionReader[InternalRow] {
42 | override def next(): Boolean = false
43 |
44 | override def get(): InternalRow = throw new IllegalStateException("Unable to call get() on empty iterator")
45 |
46 | override def close(): Unit = {}
47 | }
48 |
49 | private class ScanPartitionReader(scanPartition: ScanPartition) extends PartitionReader[InternalRow] {
50 |
51 | import scanPartition._
52 |
53 | private val pageIterator = connector.scan(partitionIndex, requiredColumns, filters).pages().iterator().asScala
54 | private val rateLimiter = RateLimiter.create(connector.readLimit)
55 |
56 | private var innerIterator: Iterator[InternalRow] = Iterator.empty
57 |
58 | private var currentRow: InternalRow = _
59 | private var proceed = false
60 |
61 | private val typeConversions = schema.collect({
62 | case StructField(name, dataType, _, _) => name -> TypeConversion(name, dataType)
63 | }).toMap
64 |
65 | override def next(): Boolean = {
66 | proceed = true
67 | innerIterator.hasNext || {
68 | if (pageIterator.hasNext) {
69 | nextPage()
70 | next()
71 | }
72 | else false
73 | }
74 | }
75 |
76 | override def get(): InternalRow = {
77 | if (proceed) {
78 | currentRow = innerIterator.next()
79 | proceed = false
80 | }
81 | currentRow
82 | }
83 |
84 | override def close(): Unit = {}
85 |
86 | private def nextPage(): Unit = {
87 | val page = pageIterator.next()
88 | val result = page.getLowLevelResult
89 | Option(result.getScanResult.getConsumedCapacity).foreach(cap => rateLimiter.acquire(cap.getCapacityUnits.toInt max 1))
90 | innerIterator = result.getItems.iterator().asScala.map(itemToRow(requiredColumns))
91 | }
92 |
93 | private def itemToRow(requiredColumns: Seq[String])(item: Item): InternalRow =
94 | if (requiredColumns.nonEmpty) InternalRow.fromSeq(requiredColumns.map(columnName => typeConversions(columnName)(item)))
95 | else InternalRow.fromSeq(item.asMap().asScala.values.toSeq.map(_.toString))
96 |
97 | }
98 |
99 | }
100 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.connector
22 |
23 | import com.amazonaws.services.dynamodbv2.document.spec.ScanSpec
24 | import com.amazonaws.services.dynamodbv2.document.{ItemCollection, ScanOutcome}
25 | import com.amazonaws.services.dynamodbv2.model.ReturnConsumedCapacity
26 | import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder
27 | import org.apache.spark.sql.sources.Filter
28 |
29 | import scala.collection.JavaConverters._
30 |
31 | private[dynamodb] class TableIndexConnector(tableName: String, indexName: String, parallelism: Int, parameters: Map[String, String])
32 | extends DynamoConnector with Serializable {
33 |
34 | private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
35 | private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
36 | private val region = parameters.get("region")
37 | private val roleArn = parameters.get("roleArn")
38 | private val providerClassName = parameters.get("providerclassname")
39 |
40 | override val filterPushdownEnabled: Boolean = filterPushdown
41 |
42 | override val (keySchema, readLimit, itemLimit, totalSegments) = {
43 | val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
44 | val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
45 |
46 | // Key schema.
47 | val keySchema = KeySchema.fromDescription(indexDesc.getKeySchema.asScala)
48 |
49 | // User parameters.
50 | val bytesPerRCU = parameters.getOrElse("bytesPerRCU", "4000").toInt
51 | val maxPartitionBytes = parameters.getOrElse("maxpartitionbytes", "128000000").toInt
52 | val targetCapacity = parameters.getOrElse("targetCapacity", "1").toDouble
53 | val readFactor = if (consistentRead) 1 else 2
54 |
55 | // Table parameters.
56 | val indexSize = indexDesc.getIndexSizeBytes
57 | val itemCount = indexDesc.getItemCount
58 |
59 | // Partitioning calculation.
60 | val numPartitions = parameters.get("readpartitions").map(_.toInt).getOrElse({
61 | val sizeBased = (indexSize / maxPartitionBytes).toInt max 1
62 | val remainder = sizeBased % parallelism
63 | if (remainder > 0) sizeBased + (parallelism - remainder)
64 | else sizeBased
65 | })
66 |
67 | // Provisioned or on-demand throughput.
68 | val readThroughput = parameters.getOrElse("throughput", Option(indexDesc.getProvisionedThroughput.getReadCapacityUnits)
69 | .filter(_ > 0).map(_.longValue().toString)
70 | .getOrElse("100")).toLong
71 |
72 | // Rate limit calculation.
73 | val avgItemSize = indexSize.toDouble / itemCount
74 | val readCapacity = readThroughput * targetCapacity
75 |
76 | val rateLimit = readCapacity / parallelism
77 | val itemLimit = ((bytesPerRCU / avgItemSize * rateLimit).toInt * readFactor) max 1
78 |
79 | (keySchema, rateLimit, itemLimit, numPartitions)
80 | }
81 |
82 | override def scan(segmentNum: Int, columns: Seq[String], filters: Seq[Filter]): ItemCollection[ScanOutcome] = {
83 | val scanSpec = new ScanSpec()
84 | .withSegment(segmentNum)
85 | .withTotalSegments(totalSegments)
86 | .withMaxPageSize(itemLimit)
87 | .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
88 | .withConsistentRead(consistentRead)
89 |
90 | if (columns.nonEmpty) {
91 | val xspec = new ExpressionSpecBuilder().addProjections(columns: _*)
92 |
93 | if (filters.nonEmpty && filterPushdown) {
94 | xspec.withCondition(FilterPushdown(filters))
95 | }
96 |
97 | scanSpec.withExpressionSpec(xspec.buildForScan())
98 | }
99 |
100 | getDynamoDB(region, roleArn, providerClassName).getTable(tableName).getIndex(indexName).scan(scanSpec)
101 | }
102 |
103 | }
104 |
--------------------------------------------------------------------------------
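
The partitioning and rate-limit arithmetic above can be followed with concrete numbers. A back-of-the-envelope sketch under assumed inputs (all values hypothetical):

    // Assumed: indexSize = 1 GB, itemCount = 5,000,000, parallelism = 8, readThroughput = 400,
    // targetCapacity = 1.0, consistentRead = false (readFactor = 2),
    // bytesPerRCU = 4000, maxPartitionBytes = 128,000,000.
    val sizeBased     = (1000000000L / 128000000).toInt max 1             // 7
    val numPartitions = if (sizeBased % 8 > 0) sizeBased + (8 - sizeBased % 8)
                        else sizeBased                                    // rounded up to 8
    val avgItemSize   = 1000000000.0 / 5000000                            // 200 bytes per item
    val rateLimit     = 400 * 1.0 / 8                                     // 50 capacity units/s per task
    val itemLimit     = ((4000 / avgItemSize * rateLimit).toInt * 2) max 1 // 2000 items per scanned page
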
/src/main/scala/com/audienceproject/spark/dynamodb/connector/FilterPushdown.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.connector
22 |
23 | import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder.{BOOL => newBOOL, N => newN, S => newS, _}
24 | import com.amazonaws.services.dynamodbv2.xspec._
25 | import org.apache.spark.sql.sources._
26 |
27 | private[dynamodb] object FilterPushdown {
28 |
29 | def apply(filters: Seq[Filter]): Condition =
30 | filters.map(buildCondition).map(parenthesize).reduce[Condition](_ and _)
31 |
32 | /**
33 | * Accepts only filters that would be considered valid input to FilterPushdown.apply()
34 | *
35 | * @param filters input list which may contain both valid and invalid filters
36 | * @return a (valid, invalid) partitioning of the input filters
37 | */
38 | def acceptFilters(filters: Array[Filter]): (Array[Filter], Array[Filter]) =
39 | filters.partition(checkFilter)
40 |
41 | private def checkFilter(filter: Filter): Boolean = filter match {
42 | case _: StringEndsWith => false
43 | case And(left, right) => checkFilter(left) && checkFilter(right)
44 | case Or(left, right) => checkFilter(left) && checkFilter(right)
45 | case Not(f) => checkFilter(f)
46 | case _ => true
47 | }
48 |
49 | private def buildCondition(filter: Filter): Condition = filter match {
50 | case EqualTo(path, value: Boolean) => newBOOL(path).eq(value)
51 | case EqualTo(path, value) => coerceAndApply(_ eq _, _ eq _)(path, value)
52 |
53 | case GreaterThan(path, value) => coerceAndApply(_ gt _, _ gt _)(path, value)
54 | case GreaterThanOrEqual(path, value) => coerceAndApply(_ ge _, _ ge _)(path, value)
55 |
56 | case LessThan(path, value) => coerceAndApply(_ lt _, _ lt _)(path, value)
57 | case LessThanOrEqual(path, value) => coerceAndApply(_ le _, _ le _)(path, value)
58 |
59 | case In(path, values) =>
60 | val valueList = values.toList
61 | valueList match {
62 | case (_: String) :: _ => newS(path).in(valueList.asInstanceOf[List[String]]: _*)
63 | case (_: Boolean) :: _ => newBOOL(path).in(valueList.asInstanceOf[List[Boolean]]: _*)
64 | case (_: Int) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*)
65 | case (_: Long) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*)
66 | case (_: Short) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*)
67 | case (_: Float) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*)
68 | case (_: Double) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*)
69 | case Nil => throw new IllegalArgumentException("Unable to apply `In` filter with empty value list")
70 | case _ => throw new IllegalArgumentException(s"Type of values supplied to `In` filter on attribute $path not supported by filter pushdown")
71 | }
72 |
73 | case IsNull(path) => attribute_not_exists(path)
74 | case IsNotNull(path) => attribute_exists(path)
75 |
76 | case StringStartsWith(path, value) => newS(path).beginsWith(value)
77 | case StringContains(path, value) => newS(path).contains(value)
78 | case StringEndsWith(_, _) => throw new UnsupportedOperationException("Filter `StringEndsWith` is not supported by DynamoDB")
79 |
80 | case And(left, right) => parenthesize(buildCondition(left)) and parenthesize(buildCondition(right))
81 | case Or(left, right) => parenthesize(buildCondition(left)) or parenthesize(buildCondition(right))
82 | case Not(f) => parenthesize(buildCondition(f)).negate()
83 | }
84 |
85 | private def coerceAndApply(stringOp: (S, String) => Condition, numOp: (N, Number) => Condition)
86 | (path: String, value: Any): Condition = value match {
87 | case string: String => stringOp(newS(path), string)
88 | case number: Int => numOp(newN(path), number)
89 | case number: Long => numOp(newN(path), number)
90 | case number: Short => numOp(newN(path), number)
91 | case number: Float => numOp(newN(path), number)
92 | case number: Double => numOp(newN(path), number)
93 | case _ => throw new IllegalArgumentException(s"Type of operand given to filter on attribute $path not supported by filter pushdown")
94 | }
95 |
96 | }
97 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoTable.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import java.util
24 |
25 | import com.audienceproject.spark.dynamodb.connector.{TableConnector, TableIndexConnector}
26 | import org.apache.spark.sql.SparkSession
27 | import org.apache.spark.sql.connector.catalog._
28 | import org.apache.spark.sql.connector.read.ScanBuilder
29 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
30 | import org.apache.spark.sql.types._
31 | import org.apache.spark.sql.util.CaseInsensitiveStringMap
32 | import org.slf4j.LoggerFactory
33 |
34 | import scala.collection.JavaConverters._
35 |
36 | class DynamoTable(options: CaseInsensitiveStringMap,
37 | userSchema: Option[StructType] = None)
38 | extends Table
39 | with SupportsRead
40 | with SupportsWrite {
41 |
42 | private val logger = LoggerFactory.getLogger(this.getClass)
43 |
44 | private val dynamoConnector = {
45 | val indexName = Option(options.get("indexname"))
46 | val defaultParallelism = Option(options.get("defaultparallelism")).map(_.toInt).getOrElse(getDefaultParallelism)
47 | val optionsMap = Map(options.asScala.toSeq: _*)
48 |
49 | if (indexName.isDefined) new TableIndexConnector(name(), indexName.get, defaultParallelism, optionsMap)
50 | else new TableConnector(name(), defaultParallelism, optionsMap)
51 | }
52 |
53 | override def name(): String = options.get("tablename")
54 |
55 | override def schema(): StructType = userSchema.getOrElse(inferSchema())
56 |
57 | override def capabilities(): util.Set[TableCapability] =
58 | Set(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava
59 |
60 | override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
61 | new DynamoScanBuilder(dynamoConnector, schema())
62 | }
63 |
64 | override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
65 | val parameters = Map(info.options().asScala.toSeq: _*)
66 | dynamoConnector match {
67 | case tableConnector: TableConnector => new DynamoWriteBuilder(tableConnector, parameters, info.schema())
68 | case _ => throw new RuntimeException("Unable to write to a GSI, please omit `indexName` option.")
69 | }
70 | }
71 |
72 | private def getDefaultParallelism: Int =
73 | SparkSession.getActiveSession match {
74 | case Some(spark) => spark.sparkContext.defaultParallelism
75 | case None =>
76 | logger.warn("Unable to read defaultParallelism from SparkSession." +
77 | " Parallelism will be 1 unless overwritten with option `defaultParallelism`")
78 | 1
79 | }
80 |
81 | private def inferSchema(): StructType = {
82 | val inferenceItems =
83 | if (dynamoConnector.nonEmpty && options.getBoolean("inferSchema", true)) dynamoConnector.scan(0, Seq.empty, Seq.empty).firstPage().getLowLevelResult.getItems.asScala
84 | else Seq.empty
85 |
86 | val typeMapping = inferenceItems.foldLeft(Map[String, DataType]())({
87 | case (map, item) => map ++ item.asMap().asScala.mapValues(inferType)
88 | })
89 | val typeSeq = typeMapping.map({ case (name, sparkType) => StructField(name, sparkType) }).toSeq
90 |
91 | if (typeSeq.size > 100) throw new RuntimeException("Schema inference not possible, too many attributes in table.")
92 |
93 | StructType(typeSeq)
94 | }
95 |
96 | private def inferType(value: Any): DataType = value match {
97 | case number: java.math.BigDecimal =>
98 | if (number.scale() == 0) {
99 | if (number.precision() < 10) IntegerType
100 | else if (number.precision() < 19) LongType
101 | else DataTypes.createDecimalType(number.precision(), number.scale())
102 | }
103 | else DoubleType
104 | case list: java.util.ArrayList[_] =>
105 | if (list.isEmpty) ArrayType(StringType)
106 | else ArrayType(inferType(list.get(0)))
107 | case set: java.util.Set[_] =>
108 | if (set.isEmpty) ArrayType(StringType)
109 | else ArrayType(inferType(set.iterator().next()))
110 | case map: java.util.Map[String, _] =>
111 | val mapFields = (for ((fieldName, fieldValue) <- map.asScala) yield {
112 | StructField(fieldName, inferType(fieldValue))
113 | }).toSeq
114 | StructType(mapFields)
115 | case _: java.lang.Boolean => BooleanType
116 | case _: Array[Byte] => BinaryType
117 | case _ => StringType
118 | }
119 |
120 | }
121 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/datasource/TypeConversion.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.datasource
22 |
23 | import com.amazonaws.services.dynamodbv2.document.{IncompatibleTypeException, Item}
24 | import org.apache.spark.sql.catalyst.InternalRow
25 | import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
26 | import org.apache.spark.sql.types._
27 | import org.apache.spark.unsafe.types.UTF8String
28 |
29 | import scala.collection.JavaConverters._
30 |
31 | private[dynamodb] object TypeConversion {
32 |
33 | def apply(attrName: String, sparkType: DataType): Item => Any =
34 |
35 | sparkType match {
36 | case BooleanType => nullableGet(_.getBOOL)(attrName)
37 | case StringType => nullableGet(item => attrName => UTF8String.fromString(item.getString(attrName)))(attrName)
38 | case IntegerType => nullableGet(_.getInt)(attrName)
39 | case LongType => nullableGet(_.getLong)(attrName)
40 | case DoubleType => nullableGet(_.getDouble)(attrName)
41 | case FloatType => nullableGet(_.getFloat)(attrName)
42 | case BinaryType => nullableGet(_.getBinary)(attrName)
43 | case DecimalType() => nullableGet(_.getNumber)(attrName)
44 | case ArrayType(innerType, _) =>
45 | nullableGet(_.getList)(attrName).andThen(extractArray(convertValue(innerType)))
46 | case MapType(keyType, valueType, _) =>
47 | if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
48 | nullableGet(_.getRawMap)(attrName).andThen(extractMap(convertValue(valueType)))
49 | case StructType(fields) =>
50 | val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) })
51 | nullableGet(_.getRawMap)(attrName).andThen(extractStruct(nestedConversions))
52 | case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.")
53 | }
54 |
55 | private val stringConverter = (value: Any) => UTF8String.fromString(value.asInstanceOf[String])
56 |
57 | private def convertValue(sparkType: DataType): Any => Any =
58 |
59 | sparkType match {
60 | case IntegerType => nullableConvert(_.intValue())
61 | case LongType => nullableConvert(_.longValue())
62 | case DoubleType => nullableConvert(_.doubleValue())
63 | case FloatType => nullableConvert(_.floatValue())
64 | case DecimalType() => nullableConvert(identity)
65 | case ArrayType(innerType, _) => extractArray(convertValue(innerType))
66 | case MapType(keyType, valueType, _) =>
67 | if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
68 | extractMap(convertValue(valueType))
69 | case StructType(fields) =>
70 | val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) })
71 | extractStruct(nestedConversions)
72 | case BooleanType => {
73 | case boolean: Boolean => boolean
74 | case _ => null
75 | }
76 | case StringType => {
77 | case string: String => UTF8String.fromString(string)
78 | case _ => null
79 | }
80 | case BinaryType => {
81 | case byteArray: Array[Byte] => byteArray
82 | case _ => null
83 | }
84 | case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.")
85 | }
86 |
87 | private def nullableGet(getter: Item => String => Any)(attrName: String): Item => Any = {
88 | case item if item.hasAttribute(attrName) => try getter(item)(attrName) catch {
89 | case _: NumberFormatException => null
90 | case _: IncompatibleTypeException => null
91 | }
92 | case _ => null
93 | }
94 |
95 | private def nullableConvert(converter: java.math.BigDecimal => Any): Any => Any = {
96 | case item: java.math.BigDecimal => converter(item)
97 | case _ => null
98 | }
99 |
100 | private def extractArray(converter: Any => Any): Any => Any = {
101 | case list: java.util.List[_] => new GenericArrayData(list.asScala.map(converter))
102 | case set: java.util.Set[_] => new GenericArrayData(set.asScala.map(converter).toSeq)
103 | case _ => null
104 | }
105 |
106 | private def extractMap(converter: Any => Any): Any => Any = {
107 | case map: java.util.Map[_, _] => ArrayBasedMapData(map, stringConverter, converter)
108 | case _ => null
109 | }
110 |
111 | private def extractStruct(conversions: Seq[(String, Any => Any)]): Any => Any = {
112 | case map: java.util.Map[_, _] => InternalRow.fromSeq(conversions.map({
113 | case (name, conv) => conv(map.get(name))
114 | }))
115 | case _ => null
116 | }
117 |
118 | }
119 |
--------------------------------------------------------------------------------
/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing,
13 | * software distributed under the License is distributed on an
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | * KIND, either express or implied. See the License for the
16 | * specific language governing permissions and limitations
17 | * under the License.
18 | *
19 | * Copyright © 2018 AudienceProject. All rights reserved.
20 | */
21 | package com.audienceproject.spark.dynamodb.connector
22 |
23 | import com.amazonaws.auth.profile.ProfileCredentialsProvider
24 | import com.amazonaws.auth.{AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
25 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
26 | import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
27 | import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBAsync, AmazonDynamoDBAsyncClientBuilder, AmazonDynamoDBClientBuilder}
28 | import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
29 | import com.amazonaws.services.securitytoken.model.AssumeRoleRequest
30 | import org.apache.spark.sql.sources.Filter
31 |
32 | private[dynamodb] trait DynamoConnector {
33 |
34 | @transient private lazy val properties = sys.props
35 |
36 | def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None, providerClassName: Option[String] = None): DynamoDB = {
37 | val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn, providerClassName)
38 | new DynamoDB(client)
39 | }
40 |
41 | private def getDynamoDBClient(region: Option[String] = None,
42 | roleArn: Option[String] = None,
43 | providerClassName: Option[String]): AmazonDynamoDB = {
44 | val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
45 | val credentials = getCredentials(chosenRegion, roleArn, providerClassName)
46 |
47 | properties.get("aws.dynamodb.endpoint").map(endpoint => {
48 | AmazonDynamoDBClientBuilder.standard()
49 | .withCredentials(credentials)
50 | .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
51 | .build()
52 | }).getOrElse(
53 | AmazonDynamoDBClientBuilder.standard()
54 | .withCredentials(credentials)
55 | .withRegion(chosenRegion)
56 | .build()
57 | )
58 | }
59 |
60 | def getDynamoDBAsyncClient(region: Option[String] = None,
61 | roleArn: Option[String] = None,
62 | providerClassName: Option[String] = None): AmazonDynamoDBAsync = {
63 | val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
64 | val credentials = getCredentials(chosenRegion, roleArn, providerClassName)
65 |
66 | properties.get("aws.dynamodb.endpoint").map(endpoint => {
67 | AmazonDynamoDBAsyncClientBuilder.standard()
68 | .withCredentials(credentials)
69 | .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
70 | .build()
71 | }).getOrElse(
72 | AmazonDynamoDBAsyncClientBuilder.standard()
73 | .withCredentials(credentials)
74 | .withRegion(chosenRegion)
75 | .build()
76 | )
77 | }
78 |
79 | /**
80 |      * Get credentials from an instance of the given provider class name,
81 |      * or by assuming the given role ARN via STS,
82 |      * or from the profile named in the `aws.profile` system property,
83 |      * or fall back to the default credentials provider chain.
84 | **/
85 | private def getCredentials(chosenRegion: String, roleArn: Option[String], providerClassName: Option[String]) = {
86 | providerClassName.map(providerClass => {
87 | Class.forName(providerClass).newInstance.asInstanceOf[AWSCredentialsProvider]
88 | }).orElse(roleArn.map(arn => {
89 | val stsClient = properties.get("aws.sts.endpoint").map(endpoint => {
90 | AWSSecurityTokenServiceClientBuilder
91 | .standard()
92 | .withCredentials(new DefaultAWSCredentialsProviderChain)
93 | .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
94 | .build()
95 | }).getOrElse(
96 | // STS without an endpoint will sign from the region, but use the global endpoint
97 | AWSSecurityTokenServiceClientBuilder
98 | .standard()
99 | .withCredentials(new DefaultAWSCredentialsProviderChain)
100 | .withRegion(chosenRegion)
101 | .build()
102 | )
103 | val assumeRoleResult = stsClient.assumeRole(
104 | new AssumeRoleRequest()
105 | .withRoleSessionName("DynamoDBAssumed")
106 | .withRoleArn(arn)
107 | )
108 | val stsCredentials = assumeRoleResult.getCredentials
109 | val assumeCreds = new BasicSessionCredentials(
110 | stsCredentials.getAccessKeyId,
111 | stsCredentials.getSecretAccessKey,
112 | stsCredentials.getSessionToken
113 | )
114 | new AWSStaticCredentialsProvider(assumeCreds)
115 | })).orElse(properties.get("aws.profile").map(new ProfileCredentialsProvider(_)))
116 | .getOrElse(new DefaultAWSCredentialsProviderChain)
117 | }
118 |
119 | val keySchema: KeySchema
120 |
121 | val readLimit: Double
122 |
123 | val itemLimit: Int
124 |
125 | val totalSegments: Int
126 |
127 | val filterPushdownEnabled: Boolean
128 |
129 | def scan(segmentNum: Int, columns: Seq[String], filters: Seq[Filter]): ItemCollection[ScanOutcome]
130 |
131 | def isEmpty: Boolean = itemLimit == 0
132 |
133 | def nonEmpty: Boolean = !isEmpty
134 |
135 | }
136 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spark+DynamoDB
2 | Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB.
3 |
4 | We have published a short article about the project; check it out here:
5 | https://www.audienceproject.com/blog/tech/sparkdynamodb-using-aws-dynamodb-data-source-apache-spark/
6 |
7 | ## News
8 |
9 | * 2021-01-28: Added option `inferSchema=false`, which is useful when writing to a table with many columns.
10 | * 2020-07-23: Releasing version 1.1.0 which supports Spark 3.0.0 and Scala 2.12. Future releases will no longer be compatible with Scala 2.11 and Spark 2.x.x.
11 | * 2020-04-28: Releasing version 1.0.4. Includes support for assuming AWS roles through custom STS endpoint (credits @jhulten).
12 | * 2020-04-09: We are releasing version 1.0.3 of the Spark+DynamoDB connector. Added option to `delete` records (thank you @rhelmstetter). Fixes (thank you @juanyunism for #46).
13 | * 2019-11-25: We are releasing version 1.0.0 of the Spark+DynamoDB connector, which is based on the Spark Data Source V2 API. Out-of-the-box throughput calculations, parallelism and partition planning should now be more reliable. We have also pulled out the external dependency on Guava, which was causing a lot of compatibility issues.
14 |
15 | ## Features
16 |
17 | - Distributed, parallel scan with lazy evaluation
18 | - Throughput control by rate limiting on target fraction of provisioned table/index capacity
19 | - Schema discovery to suit your needs
20 | - Dynamic inference
21 | - Static analysis of case class
22 | - Column and filter pushdown
23 | - Global secondary index support (see the read sketch below)
24 | - Write support
25 |
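Several of these features show up together in a typical read. The snippet below is a sketch, not part of the original README, of scanning a global secondary index with a column projection and a filter; supported filters are pushed down to the DynamoDB scan (see `FilterPushdown.scala`), while unsupported ones are still applied by Spark afterwards. The index name, table name and column names are placeholders.

```scala
import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Scan a global secondary index instead of the base table (placeholder names).
val indexDf = spark.read
    .option("indexName", "SomeIndexName")
    .dynamodb("SomeTableName")

// Column selection and simple predicates are pushed down to DynamoDB where supported.
indexDf
    .select("userId", "score")
    .where($"score" > 100)
    .show()
```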
26 | ## Getting The Dependency
27 |
28 | The library is available from [Maven Central](https://mvnrepository.com/artifact/com.audienceproject/spark-dynamodb). Add the dependency in SBT as ```"com.audienceproject" %% "spark-dynamodb" % "latest"```
29 |
30 | Spark is a "provided" dependency of the library, which means Spark has to be installed separately in the environment where the application runs, as is the case on AWS EMR.
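
For reference, here is a minimal `build.sbt` sketch matching the description above. The version numbers are taken from the News section; adjust the Scala patch version and Spark version to your environment.

```scala
// build.sbt (sketch): the connector as a regular dependency, Spark marked Provided
// because it is supplied by the cluster (e.g. AWS EMR).
scalaVersion := "2.12.12"

libraryDependencies ++= Seq(
  "com.audienceproject" %% "spark-dynamodb" % "1.1.0",
  "org.apache.spark"    %% "spark-sql"      % "3.0.0" % Provided
)
```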
31 |
32 | ## Quick Start Guide
33 |
34 | ### Scala
35 | ```scala
36 | import com.audienceproject.spark.dynamodb.implicits._
37 | import org.apache.spark.sql.SparkSession
38 |
39 | val spark = SparkSession.builder().getOrCreate()
40 |
41 | // Load a DataFrame from a Dynamo table. Only incurs the cost of a single scan for schema inference.
42 | val dynamoDf = spark.read.dynamodb("SomeTableName") // <-- DataFrame of Row objects with inferred schema.
43 |
44 | // Scan the table for the first 100 items (the order is arbitrary) and print them.
45 | dynamoDf.show(100)
46 |
47 | // Write to some other table, overwriting existing items with the same keys.
48 | dynamoDf.write.dynamodb("SomeOtherTable")
49 |
50 | // Case class representing the items in our table.
51 | import com.audienceproject.spark.dynamodb.attribute
52 | case class Vegetable (name: String, color: String, @attribute("weight_kg") weightKg: Double)
53 |
54 | // Load a Dataset[Vegetable]. Note the @attribute annotation on the case class: the weight attribute is assumed to be named "weight_kg" in DynamoDB.
55 | import org.apache.spark.sql.functions._
56 | import spark.implicits._
57 | val vegetableDs = spark.read.dynamodbAs[Vegetable]("VegeTable")
58 | val avgWeightByColor = vegetableDs.groupBy($"color").agg(avg($"weightKg")) // The column is called 'weightKg' in the Dataset.
59 | ```
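
The News section above also mentions support for deleting records. Below is a sketch of how such a write might look, assuming the write option is simply named `delete` as in the 1.0.3 release note; a delete is presumably keyed on the table's key schema, so other columns should not matter.

```scala
// Delete the items identified by the key columns of this DataFrame (sketch only;
// the `delete` option name follows the 1.0.3 release note above).
dynamoDf.write
    .option("delete", "true")
    .dynamodb("SomeOtherTable")
```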
60 |
61 | ### Python
62 | ```python
63 | # Load a DataFrame from a Dynamo table. Only incurs the cost of a single scan for schema inference.
64 | dynamoDf = spark.read.option("tableName", "SomeTableName") \
65 | .format("dynamodb") \
66 | .load() # <-- DataFrame of Row objects with inferred schema.
67 |
68 | # Scan the table for the first 100 items (the order is arbitrary) and print them.
69 | dynamoDf.show(100)
70 |
71 | # Write to some other table, overwriting existing items with the same keys.
72 | dynamoDf.write.option("tableName", "SomeOtherTable") \
73 | .format("dynamodb") \
74 | .save()
75 | ```
76 |
77 | *Note:* When running from the `pyspark` shell, you can add the library with:
78 | ```bash
79 | pyspark --packages com.audienceproject:spark-dynamodb_
80 | ```
--------------------------------------------------------------------------------
/src/main/java/com/audienceproject/shaded/google/common/util/concurrent/Uninterruptibles.java:
--------------------------------------------------------------------------------
119 | * If instead, you wish to treat {@link InterruptedException} uniformly
120 | * with other exceptions.
121 | *
122 | * @throws ExecutionException if the computation threw an exception
123 | * @throws CancellationException if the computation was cancelled
124 | */
125 | public static
148 | * If instead, you wish to treat {@link InterruptedException} uniformly
149 | * with other exceptions.
150 | *
151 | * @throws ExecutionException if the computation threw an exception
152 | * @throws CancellationException if the computation was cancelled
153 | * @throws TimeoutException if the wait timed out
154 | */
155 | public static
--------------------------------------------------------------------------------
/src/main/java/com/audienceproject/shaded/google/common/base/Preconditions.java:
--------------------------------------------------------------------------------
45 | * Warning: only the {@code "%s"} specifier is recognized as a
46 | * placeholder in these messages, not the full range of {@link
47 | * String#format(String, Object[])} specifiers.
48 | *
49 | * Take care not to confuse precondition checking with other similar types
50 | * of checks! Precondition exceptions -- including those provided here, but also
51 | * {@link IndexOutOfBoundsException}, {@link NoSuchElementException}, {@link
52 | * UnsupportedOperationException} and others -- are used to signal that the
53 | * calling method has made an error. This tells the caller that it should
54 | * not have invoked the method when it did, with the arguments it did, or
55 | * perhaps ever. Postcondition or other invariant failures should not throw
56 | * these types of exceptions.
57 | *
58 | * See the Guava User Guide on
60 | * using {@code Preconditions}.
61 | *
62 | * @author Kevin Bourrillion
63 | * @since 2.0 (imported from Google Collections Library)
64 | */
65 | public final class Preconditions {
66 | private Preconditions() {}
67 |
68 | /**
69 | * Ensures the truth of an expression involving one or more parameters to the
70 | * calling method.
71 | *
72 | * @param expression a boolean expression
73 | * @throws IllegalArgumentException if {@code expression} is false
74 | */
75 | public static void checkArgument(boolean expression) {
76 | if (!expression) {
77 | throw new IllegalArgumentException();
78 | }
79 | }
80 |
81 | /**
82 | * Ensures the truth of an expression involving one or more parameters to the
83 | * calling method.
84 | *
85 | * @param expression a boolean expression
86 | * @param errorMessage the exception message to use if the check fails; will
87 | * be converted to a string using {@link String#valueOf(Object)}
88 | * @throws IllegalArgumentException if {@code expression} is false
89 | */
90 | public static void checkArgument(
91 | boolean expression, @Nullable Object errorMessage) {
92 | if (!expression) {
93 | throw new IllegalArgumentException(String.valueOf(errorMessage));
94 | }
95 | }
96 |
97 | /**
98 | * Ensures the truth of an expression involving one or more parameters to the
99 | * calling method.
100 | *
101 | * @param expression a boolean expression
102 | * @param errorMessageTemplate a template for the exception message should the
103 | * check fail. The message is formed by replacing each {@code %s}
104 | * placeholder in the template with an argument. These are matched by
105 | * position - the first {@code %s} gets {@code errorMessageArgs[0]}, etc.
106 | * Unmatched arguments will be appended to the formatted message in square
107 | * braces. Unmatched placeholders will be left as-is.
108 | * @param errorMessageArgs the arguments to be substituted into the message
109 | * template. Arguments are converted to strings using
110 | * {@link String#valueOf(Object)}.
111 | * @throws IllegalArgumentException if {@code expression} is false
112 | * @throws NullPointerException if the check fails and either {@code
113 | * errorMessageTemplate} or {@code errorMessageArgs} is null (don't let
114 | * this happen)
115 | */
116 | public static void checkArgument(boolean expression,
117 | @Nullable String errorMessageTemplate,
118 | @Nullable Object... errorMessageArgs) {
119 | if (!expression) {
120 | throw new IllegalArgumentException(
121 | format(errorMessageTemplate, errorMessageArgs));
122 | }
123 | }
124 |
125 | /**
126 | * Ensures the truth of an expression involving the state of the calling
127 | * instance, but not involving any parameters to the calling method.
128 | *
129 | * @param expression a boolean expression
130 | * @throws IllegalStateException if {@code expression} is false
131 | */
132 | public static void checkState(boolean expression) {
133 | if (!expression) {
134 | throw new IllegalStateException();
135 | }
136 | }
137 |
138 | /**
139 | * Ensures the truth of an expression involving the state of the calling
140 | * instance, but not involving any parameters to the calling method.
141 | *
142 | * @param expression a boolean expression
143 | * @param errorMessage the exception message to use if the check fails; will
144 | * be converted to a string using {@link String#valueOf(Object)}
145 | * @throws IllegalStateException if {@code expression} is false
146 | */
147 | public static void checkState(
148 | boolean expression, @Nullable Object errorMessage) {
149 | if (!expression) {
150 | throw new IllegalStateException(String.valueOf(errorMessage));
151 | }
152 | }
153 |
154 | /**
155 | * Ensures the truth of an expression involving the state of the calling
156 | * instance, but not involving any parameters to the calling method.
157 | *
158 | * @param expression a boolean expression
159 | * @param errorMessageTemplate a template for the exception message should the
160 | * check fail. The message is formed by replacing each {@code %s}
161 | * placeholder in the template with an argument. These are matched by
162 | * position - the first {@code %s} gets {@code errorMessageArgs[0]}, etc.
163 | * Unmatched arguments will be appended to the formatted message in square
164 | * braces. Unmatched placeholders will be left as-is.
165 | * @param errorMessageArgs the arguments to be substituted into the message
166 | * template. Arguments are converted to strings using
167 | * {@link String#valueOf(Object)}.
168 | * @throws IllegalStateException if {@code expression} is false
169 | * @throws NullPointerException if the check fails and either {@code
170 | * errorMessageTemplate} or {@code errorMessageArgs} is null (don't let
171 | * this happen)
172 | */
173 | public static void checkState(boolean expression,
174 | @Nullable String errorMessageTemplate,
175 | @Nullable Object... errorMessageArgs) {
176 | if (!expression) {
177 | throw new IllegalStateException(
178 | format(errorMessageTemplate, errorMessageArgs));
179 | }
180 | }
181 |
182 | /**
183 | * Ensures that an object reference passed as a parameter to the calling
184 | * method is not null.
185 | *
186 | * @param reference an object reference
187 | * @return the non-null reference that was validated
188 | * @throws NullPointerException if {@code reference} is null
189 | */
190 | public static
--------------------------------------------------------------------------------
/src/main/java/com/audienceproject/shaded/google/common/util/concurrent/RateLimiter.java:
--------------------------------------------------------------------------------
35 | * Rate limiters are often used to restrict the rate at which some
36 | * physical or logical resource is accessed. This is in contrast to {@link
37 | * java.util.concurrent.Semaphore} which restricts the number of concurrent
38 | * accesses instead of the rate (note though that concurrency and rate are closely related,
39 | * e.g. see Little's Law).
40 | *
41 | * A {@code RateLimiter} is defined primarily by the rate at which permits
42 | * are issued. Absent additional configuration, permits will be distributed at a
43 | * fixed rate, defined in terms of permits per second. Permits will be distributed
44 | * smoothly, with the delay between individual permits being adjusted to ensure
45 | * that the configured rate is maintained.
46 | *
47 | * It is possible to configure a {@code RateLimiter} to have a warmup
48 | * period during which time the permits issued each second steadily increases until
49 | * it hits the stable rate.
50 | *
51 | * As an example, imagine that we have a list of tasks to execute, but we don't want to
52 | * submit more than 2 per second: create the limiter with {@code RateLimiter.create(2.0)}
53 | * and call {@code rateLimiter.acquire()} before submitting each task.
54 | *
63 | * As another example, imagine that we produce a stream of data, and we want to cap it
64 | * at 5kb per second. This could be accomplished by requiring a permit per byte, and specifying
65 | * a rate of 5000 permits per second: call {@code rateLimiter.acquire(packet.length)} before
66 | * sending each packet.
67 | *
74 | * It is important to note that the number of permits requested never
75 | * affect the throttling of the request itself (an invocation to {@code acquire(1)}
76 | * and an invocation to {@code acquire(1000)} will result in exactly the same throttling, if any),
77 | * but it affects the throttling of the next request. I.e., if an expensive task
78 | * arrives at an idle RateLimiter, it will be granted immediately, but it is the next
79 | * request that will experience extra throttling, thus paying for the cost of the expensive
80 | * task.
81 | *
82 | * Note: {@code RateLimiter} does not provide fairness guarantees.
83 | *
84 | * @author Dimitris Andreou
85 | * @since 13.0
86 | */
87 | // TODO(user): switch to nano precision. A natural unit of cost is "bytes", and a micro precision
88 | // would mean a maximum rate of "1MB/s", which might be small in some cases.
89 | @ThreadSafe
90 | public abstract class RateLimiter {
91 | /*
92 | * How is the RateLimiter designed, and why?
93 | *
94 | * The primary feature of a RateLimiter is its "stable rate", the maximum rate that
95 | * it should allow under normal conditions. This is enforced by "throttling" incoming
96 | * requests as needed, i.e. computing, for an incoming request, the appropriate throttle time,
97 | * and making the calling thread wait that long.
98 | *
99 | * The simplest way to maintain a rate of QPS is to keep the timestamp of the last
100 | * granted request, and ensure that (1/QPS) seconds have elapsed since then. For example,
101 | * for a rate of QPS=5 (5 tokens per second), if we ensure that a request isn't granted
102 | * earlier than 200ms after the last one, then we achieve the intended rate.
103 | * If a request comes and the last request was granted only 100ms ago, then we wait for
104 | * another 100ms. At this rate, serving 15 fresh permits (i.e. for an acquire(15) request)
105 | * naturally takes 3 seconds.
106 | *
107 | * It is important to realize that such a RateLimiter has a very superficial memory
108 | * of the past: it only remembers the last request. What if the RateLimiter was unused for
109 | * a long period of time, then a request arrived and was immediately granted?
110 | * This RateLimiter would immediately forget about that past underutilization. This may
111 | * result in either underutilization or overflow, depending on the real world consequences
112 | * of not using the expected rate.
113 | *
114 | * Past underutilization could mean that excess resources are available. Then, the RateLimiter
115 | * should speed up for a while, to take advantage of these resources. This is important
116 | * when the rate is applied to networking (limiting bandwidth), where past underutilization
117 | * typically translates to "almost empty buffers", which can be filled immediately.
118 | *
119 | * On the other hand, past underutilization could mean that "the server responsible for
120 | * handling the request has become less ready for future requests", i.e. its caches become
121 | * stale, and requests become more likely to trigger expensive operations (a more extreme
122 | * case of this example is when a server has just booted, and it is mostly busy with getting
123 | * itself up to speed).
124 | *
125 | * To deal with such scenarios, we add an extra dimension, that of "past underutilization",
126 | * modeled by "storedPermits" variable. This variable is zero when there is no
127 | * underutilization, and it can grow up to maxStoredPermits, for sufficiently large
128 | * underutilization. So, the requested permits, by an invocation acquire(permits),
129 | * are served from:
130 | * - stored permits (if available)
131 | * - fresh permits (for any remaining permits)
132 | *
133 | * How this works is best explained with an example:
134 | *
135 | * For a RateLimiter that produces 1 token per second, every second
136 | * that goes by with the RateLimiter being unused, we increase storedPermits by 1.
137 | * Say we leave the RateLimiter unused for 10 seconds (i.e., we expected a request at time
138 | * X, but we are at time X + 10 seconds before a request actually arrives; this is
139 | * also related to the point made in the last paragraph), thus storedPermits
140 | * becomes 10.0 (assuming maxStoredPermits >= 10.0). At that point, a request of acquire(3)
141 | * arrives. We serve this request out of storedPermits, and reduce that to 7.0 (how this is
142 | * translated to throttling time is discussed later). Immediately after, assume that an
143 | * acquire(10) request arrives. We serve the request partly from storedPermits,
144 | * using all the remaining 7.0 permits, and the remaining 3.0, we serve them by fresh permits
145 | * produced by the rate limiter.
146 | *
147 | * We already know how much time it takes to serve 3 fresh permits: if the rate is
148 | * "1 token per second", then this will take 3 seconds. But what does it mean to serve 7
149 | * stored permits? As explained above, there is no unique answer. If we are primarily
150 | * interested to deal with underutilization, then we want stored permits to be given out
151 | * /faster/ than fresh ones, because underutilization = free resources for the taking.
152 | * If we are primarily interested to deal with overflow, then stored permits could
153 | * be given out /slower/ than fresh ones. Thus, we require a (different in each case)
154 | * function that translates storedPermits to throttling time.
155 | *
156 | * This role is played by storedPermitsToWaitTime(double storedPermits, double permitsToTake).
157 | * The underlying model is a continuous function mapping storedPermits
158 | * (from 0.0 to maxStoredPermits) onto the 1/rate (i.e. intervals) that is effective at the given
159 | * storedPermits. "storedPermits" essentially measure unused time; we spend unused time
160 | * buying/storing permits. Rate is "permits / time", thus "1 / rate = time / permits".
161 | * Thus, "1/rate" (time / permits) times "permits" gives time, i.e., integrals on this
162 | * function (which is what storedPermitsToWaitTime() computes) correspond to minimum intervals
163 | * between subsequent requests, for the specified number of requested permits.
164 | *
165 | * Here is an example of storedPermitsToWaitTime:
166 | * If storedPermits == 10.0, and we want 3 permits, we take them from storedPermits,
167 | * reducing them to 7.0, and compute the throttling for these as a call to
168 | * storedPermitsToWaitTime(storedPermits = 10.0, permitsToTake = 3.0), which will
169 | * evaluate the integral of the function from 7.0 to 10.0.
170 | *
171 | * Using integrals guarantees that the effect of a single acquire(3) is equivalent
172 | * to { acquire(1); acquire(1); acquire(1); }, or { acquire(2); acquire(1); }, etc,
173 | * since the integral of the function in [7.0, 10.0] is equivalent to the sum of the
174 | * integrals of [7.0, 8.0], [8.0, 9.0], [9.0, 10.0] (and so on), no matter
175 | * what the function is. This guarantees that we handle correctly requests of varying weight
176 | * (permits), /no matter/ what the actual function is - so we can tweak the latter freely.
177 | * (The only requirement, obviously, is that we can compute its integrals).
178 | *
179 | * Note well that if, for this function, we chose a horizontal line, at height of exactly
180 | * (1/QPS), then the effect of the function is non-existent: we serve storedPermits at
181 | * exactly the same cost as fresh ones (1/QPS is the cost for each). We use this trick later.
182 | *
183 | * If we pick a function that goes /below/ that horizontal line, it means that we reduce
184 | * the area of the function, thus time. Thus, the RateLimiter becomes /faster/ after a
185 | * period of underutilization. If, on the other hand, we pick a function that
186 | * goes /above/ that horizontal line, then it means that the area (time) is increased,
187 | * thus storedPermits are more costly than fresh permits, thus the RateLimiter becomes
188 | * /slower/ after a period of underutilization.
189 | *
190 | * Last, but not least: consider a RateLimiter with rate of 1 permit per second, currently
191 | * completely unused, and an expensive acquire(100) request comes. It would be nonsensical
192 | * to just wait for 100 seconds, and /then/ start the actual task. Why wait without doing
193 | * anything? A much better approach is to /allow/ the request right away (as if it was an
194 | * acquire(1) request instead), and postpone /subsequent/ requests as needed. In this version,
195 | * we allow starting the task immediately, and postpone by 100 seconds future requests,
196 | * thus we allow for work to get done in the meantime instead of waiting idly.
197 | *
198 | * This has important consequences: it means that the RateLimiter doesn't remember the time
199 | * of the _last_ request, but it remembers the (expected) time of the _next_ request. This
200 | * also enables us to tell immediately (see tryAcquire(timeout)) whether a particular
201 | * timeout is enough to get us to the point of the next scheduling time, since we always
202 | * maintain that. And what we mean by "an unused RateLimiter" is also defined by that
203 | * notion: when we observe that the "expected arrival time of the next request" is actually
204 | * in the past, then the difference (now - past) is the amount of time that the RateLimiter
205 | * was formally unused, and it is that amount of time which we translate to storedPermits.
206 | * (We increase storedPermits with the amount of permits that would have been produced
207 | * in that idle time). So, if rate == 1 permit per second, and arrivals come exactly
208 | * one second after the previous, then storedPermits is _never_ increased -- we would only
209 | * increase it for arrivals _later_ than the expected one second.
210 | */
211 |
212 | /**
213 | * Creates a {@code RateLimiter} with the specified stable throughput, given as
214 | * "permits per second" (commonly referred to as QPS, queries per second).
215 | *
216 | * The returned {@code RateLimiter} ensures that on average no more than {@code
217 | * permitsPerSecond} are issued during any given second, with sustained requests
218 | * being smoothly spread over each second. When the incoming request rate exceeds
219 | * {@code permitsPerSecond} the rate limiter will release one permit every {@code
220 | * (1.0 / permitsPerSecond)} seconds. When the rate limiter is unused,
221 | * bursts of up to {@code permitsPerSecond} permits will be allowed, with subsequent
222 | * requests being smoothly limited at the stable rate of {@code permitsPerSecond}.
223 | *
224 | * @param permitsPerSecond the rate of the returned {@code RateLimiter}, measured in
225 | * how many permits become available per second.
226 | */
227 | public static RateLimiter create(double permitsPerSecond) {
228 | return create(SleepingTicker.SYSTEM_TICKER, permitsPerSecond);
229 | }
230 |
231 | static RateLimiter create(SleepingTicker ticker, double permitsPerSecond) {
232 | RateLimiter rateLimiter = new Bursty(ticker);
233 | rateLimiter.setRate(permitsPerSecond);
234 | return rateLimiter;
235 | }
236 |
237 | /**
238 | * Creates a {@code RateLimiter} with the specified stable throughput, given as
239 | * "permits per second" (commonly referred to as QPS, queries per second), and a
240 | * warmup period, during which the {@code RateLimiter} smoothly ramps up its rate,
241 | * until it reaches its maximum rate at the end of the period (as long as there are enough
242 | * requests to saturate it). Similarly, if the {@code RateLimiter} is left unused for
243 | * a duration of {@code warmupPeriod}, it will gradually return to its "cold" state,
244 | * i.e. it will go through the same warming up process as when it was first created.
245 | *
246 | * The returned {@code RateLimiter} is intended for cases where the resource that actually
247 | * fulfils the requests (e.g., a remote server) needs "warmup" time, rather than
248 | * being immediately accessed at the stable (maximum) rate.
249 | *
250 | * The returned {@code RateLimiter} starts in a "cold" state (i.e. the warmup period
251 | * will follow), and if it is left unused for long enough, it will return to that state.
252 | *
253 | * @param permitsPerSecond the rate of the returned {@code RateLimiter}, measured in
254 | * how many permits become available per second
255 | * @param warmupPeriod the duration of the period where the {@code RateLimiter} ramps up its
256 | * rate, before reaching its stable (maximum) rate
257 | * @param unit the time unit of the warmupPeriod argument
258 | */
259 | // TODO(user): add a burst size of 1-second-worth of permits, as in the metronome?
260 | public static RateLimiter create(double permitsPerSecond, long warmupPeriod, TimeUnit unit) {
261 | return create(SleepingTicker.SYSTEM_TICKER, permitsPerSecond, warmupPeriod, unit);
262 | }
263 |
264 | static RateLimiter create(
265 | SleepingTicker ticker, double permitsPerSecond, long warmupPeriod, TimeUnit timeUnit) {
266 | RateLimiter rateLimiter = new WarmingUp(ticker, warmupPeriod, timeUnit);
267 | rateLimiter.setRate(permitsPerSecond);
268 | return rateLimiter;
269 | }
270 |
271 | static RateLimiter createBursty(
272 | SleepingTicker ticker, double permitsPerSecond, int maxBurstSize) {
273 | Bursty rateLimiter = new Bursty(ticker);
274 | rateLimiter.setRate(permitsPerSecond);
275 | rateLimiter.maxPermits = maxBurstSize;
276 | return rateLimiter;
277 | }
278 |
279 | /**
280 | * The underlying timer; used both to measure elapsed time and sleep as necessary. A separate
281 | * object to facilitate testing.
282 | */
283 | private final SleepingTicker ticker;
284 |
285 | /**
286 | * The timestamp when the RateLimiter was created; used to avoid possible overflow/time-wrapping
287 | * errors.
288 | */
289 | private final long offsetNanos;
290 |
291 | /**
292 | * The currently stored permits.
293 | */
294 | double storedPermits;
295 |
296 | /**
297 | * The maximum number of stored permits.
298 | */
299 | double maxPermits;
300 |
301 | /**
302 | * The interval between two unit requests, at our stable rate. E.g., a stable rate of 5 permits
303 | * per second has a stable interval of 200ms.
304 | */
305 | volatile double stableIntervalMicros;
306 |
307 | private final Object mutex = new Object();
308 |
309 | /**
310 | * The time when the next request (no matter its size) will be granted. After granting a request,
311 | * this is pushed further in the future. Large requests push this further than small requests.
312 | */
313 | private long nextFreeTicketMicros = 0L; // could be either in the past or future
314 |
315 | private RateLimiter(SleepingTicker ticker) {
316 | this.ticker = ticker;
317 | this.offsetNanos = ticker.read();
318 | }
319 |
320 | /**
321 | * Updates the stable rate of this {@code RateLimiter}, that is, the
322 | * {@code permitsPerSecond} argument provided in the factory method that
323 | * constructed the {@code RateLimiter}. Currently throttled threads will not
324 | * be awakened as a result of this invocation, thus they do not observe the new rate;
325 | * only subsequent requests will.
326 | *
327 | * Note though that, since each request repays (by waiting, if necessary) the cost
328 | * of the previous request, this means that the very next request
329 | * after an invocation to {@code setRate} will not be affected by the new rate;
330 | * it will pay the cost of the previous request, which is in terms of the previous rate.
331 | *
332 | * The behavior of the {@code RateLimiter} is not modified in any other way,
333 | * e.g. if the {@code RateLimiter} was configured with a warmup period of 20 seconds,
334 | * it still has a warmup period of 20 seconds after this method invocation.
335 | *
336 | * @param permitsPerSecond the new stable rate of this {@code RateLimiter}.
337 | */
338 | public final void setRate(double permitsPerSecond) {
339 | Preconditions.checkArgument(permitsPerSecond > 0.0
340 | && !Double.isNaN(permitsPerSecond), "rate must be positive");
341 | synchronized (mutex) {
342 | resync(readSafeMicros());
343 | double stableIntervalMicros = TimeUnit.SECONDS.toMicros(1L) / permitsPerSecond;
344 | this.stableIntervalMicros = stableIntervalMicros;
345 | doSetRate(permitsPerSecond, stableIntervalMicros);
346 | }
347 | }
348 |
349 | abstract void doSetRate(double permitsPerSecond, double stableIntervalMicros);
350 |
351 | /**
352 | * Returns the stable rate (as {@code permits per second}) with which this
353 | * {@code RateLimiter} is configured. The initial value of this is the same as
354 | * the {@code permitsPerSecond} argument passed in the factory method that produced
355 | * this {@code RateLimiter}, and it is only updated after invocations
356 | * to {@linkplain #setRate}.
357 | */
358 | public final double getRate() {
359 | return TimeUnit.SECONDS.toMicros(1L) / stableIntervalMicros;
360 | }
361 |
362 | /**
363 | * Acquires a permit from this {@code RateLimiter}, blocking until the request can be granted.
364 | *
365 | * This method is equivalent to {@code acquire(1)}.
366 | */
367 | public void acquire() {
368 | acquire(1);
369 | }
370 |
371 | /**
372 | * Acquires the given number of permits from this {@code RateLimiter}, blocking until the
373 | * request can be granted.
374 | *
375 | * @param permits the number of permits to acquire
376 | */
377 | public void acquire(int permits) {
378 | checkPermits(permits);
379 | long microsToWait;
380 | synchronized (mutex) {
381 | microsToWait = reserveNextTicket(permits, readSafeMicros());
382 | }
383 | ticker.sleepMicrosUninterruptibly(microsToWait);
384 | }
385 |
386 | /**
387 | * Acquires a permit from this {@code RateLimiter} if it can be obtained
388 | * without exceeding the specified {@code timeout}, or returns {@code false}
389 | * immediately (without waiting) if the permit would not have been granted
390 | * before the timeout expired.
391 | *
392 | * This method is equivalent to {@code tryAcquire(1, timeout, unit)}.
393 | *
394 | * @param timeout the maximum time to wait for the permit
395 | * @param unit the time unit of the timeout argument
396 | * @return {@code true} if the permit was acquired, {@code false} otherwise
397 | */
398 | public boolean tryAcquire(long timeout, TimeUnit unit) {
399 | return tryAcquire(1, timeout, unit);
400 | }
401 |
402 | /**
403 | * Acquires permits from this {@link RateLimiter} if they can be acquired immediately without delay.
404 | *
405 | *
406 | * This method is equivalent to {@code tryAcquire(permits, 0, anyUnit)}.
407 | *
408 | * @param permits the number of permits to acquire
409 | * @return {@code true} if the permits were acquired, {@code false} otherwise
410 | * @since 14.0
411 | */
412 | public boolean tryAcquire(int permits) {
413 | return tryAcquire(permits, 0, TimeUnit.MICROSECONDS);
414 | }
415 |
416 | /**
417 | * Acquires a permit from this {@link RateLimiter} if it can be acquired immediately without
418 | * delay.
419 | *
420 | *
421 | * This method is equivalent to {@code tryAcquire(1)}.
422 | *
423 | * @return {@code true} if the permit was acquired, {@code false} otherwise
424 | * @since 14.0
425 | */
426 | public boolean tryAcquire() {
427 | return tryAcquire(1, 0, TimeUnit.MICROSECONDS);
428 | }
429 |
430 | /**
431 | * Acquires the given number of permits from this {@code RateLimiter} if it can be obtained
432 | * without exceeding the specified {@code timeout}, or returns {@code false}
433 | * immediately (without waiting) if the permits would not have been granted
434 | * before the timeout expired.
435 | *
436 | * @param permits the number of permits to acquire
437 | * @param timeout the maximum time to wait for the permits
438 | * @param unit the time unit of the timeout argument
439 | * @return {@code true} if the permits were acquired, {@code false} otherwise
440 | */
441 | public boolean tryAcquire(int permits, long timeout, TimeUnit unit) {
442 | long timeoutMicros = unit.toMicros(timeout);
443 | checkPermits(permits);
444 | long microsToWait;
445 | synchronized (mutex) {
446 | long nowMicros = readSafeMicros();
447 | if (nextFreeTicketMicros > nowMicros + timeoutMicros) {
448 | return false;
449 | } else {
450 | microsToWait = reserveNextTicket(permits, nowMicros);
451 | }
452 | }
453 | ticker.sleepMicrosUninterruptibly(microsToWait);
454 | return true;
455 | }
456 |
457 | private static void checkPermits(int permits) {
458 | Preconditions.checkArgument(permits > 0, "Requested permits must be positive");
459 | }
460 |
461 | /**
462 | * Reserves the next ticket and returns the time that the caller must wait.
463 | */
464 | private long reserveNextTicket(double requiredPermits, long nowMicros) {
465 | resync(nowMicros);
466 | long microsToNextFreeTicket = nextFreeTicketMicros - nowMicros;
467 | double storedPermitsToSpend = Math.min(requiredPermits, this.storedPermits);
468 | double freshPermits = requiredPermits - storedPermitsToSpend;
469 |
470 | long waitMicros = storedPermitsToWaitTime(this.storedPermits, storedPermitsToSpend)
471 | + (long) (freshPermits * stableIntervalMicros);
472 |
473 | this.nextFreeTicketMicros = nextFreeTicketMicros + waitMicros;
474 | this.storedPermits -= storedPermitsToSpend;
475 | return microsToNextFreeTicket;
476 | }
477 |
478 | /**
479 | * Translates a specified portion of our currently stored permits which we want to
480 | * spend/acquire, into a throttling time. Conceptually, this evaluates the integral
481 | * of the underlying function we use, for the range of
482 | * [(storedPermits - permitsToTake), storedPermits].
483 | *
484 | * This always holds: {@code 0 <= permitsToTake <= storedPermits}
485 | */
486 | abstract long storedPermitsToWaitTime(double storedPermits, double permitsToTake);
487 |
488 | private void resync(long nowMicros) {
489 | // if nextFreeTicket is in the past, resync to now
490 | if (nowMicros > nextFreeTicketMicros) {
491 | storedPermits = Math.min(maxPermits,
492 | storedPermits + (nowMicros - nextFreeTicketMicros) / stableIntervalMicros);
493 | nextFreeTicketMicros = nowMicros;
494 | }
495 | }
496 |
497 | private long readSafeMicros() {
498 | return TimeUnit.NANOSECONDS.toMicros(ticker.read() - offsetNanos);
499 | }
500 |
501 | @Override
502 | public String toString() {
503 | return String.format("RateLimiter[stableRate=%3.1fqps]", 1000000.0 / stableIntervalMicros);
504 | }
505 |
506 | /**
507 | * This implements the following function:
508 | *
509 | * ^ throttling
510 | * |
511 | * 3*stable + /
512 | * interval | /.
513 | * (cold) | / .
514 | * | / . <-- "warmup period" is the area of the trapezoid between
515 | * 2*stable + / . halfPermits and maxPermits
516 | * interval | / .
517 | * | / .
518 | * | / .
519 | * stable +----------/ WARM . }
520 | * interval | . UP . } <-- this rectangle (from 0 to maxPermits, and
521 | * | . PERIOD. } height == stableInterval) defines the cooldown period,
522 | * | . . } and we want cooldownPeriod == warmupPeriod
523 | * |---------------------------------> storedPermits
524 | * (halfPermits) (maxPermits)
525 | *
526 | * Before going into the details of this particular function, let's keep in mind the basics:
527 | * 1) The state of the RateLimiter (storedPermits) is a vertical line in this figure.
528 | * 2) When the RateLimiter is not used, this goes right (up to maxPermits)
529 | * 3) When the RateLimiter is used, this goes left (down to zero), since if we have storedPermits,
530 | * we serve from those first
531 | * 4) When _unused_, we go right at the same speed (rate)! I.e., if our rate is
532 | * 2 permits per second, and 3 unused seconds pass, we will always save 6 permits
533 | * (no matter what our initial position was), up to maxPermits.
534 | * If we invert the rate, we get the "stableInterval" (interval between two requests
535 | * in a perfectly spaced out sequence of requests of the given rate). Thus, if you
536 | * want to see "how much time it will take to go from X storedPermits to X+K storedPermits?",
537 | * the answer is always stableInterval * K. In the same example, for 2 permits per second,
538 | * stableInterval is 500ms. Thus to go from X storedPermits to X+6 storedPermits, we
539 | * require 6 * 500ms = 3 seconds.
540 | *
541 | * In short, the time it takes to move to the right (save K permits) is equal to the
542 | * rectangle of width == K and height == stableInterval.
543 | * 4) When _used_, the time it takes, as explained in the introductory class note, is
544 | * equal to the integral of our function, between X permits and X-K permits, assuming
545 | * we want to spend K saved permits.
546 | *
547 | * In summary, the time it takes to move to the left (spend K permits), is equal to the
548 | * area of the function of width == K.
549 | *
550 | * Let's dive into this function now:
551 | *
552 | * When we have storedPermits <= halfPermits (the left portion of the function), then
553 | * we spend them at the exact same rate that
554 | * fresh permits would be generated anyway (that rate is 1/stableInterval). We size
555 | * this area to be equal to _half_ the specified warmup period. Why do we need this?
556 | * And why half? We'll explain shortly below (after explaining the second part).
557 | *
558 | * Stored permits that are beyond halfPermits are mapped to an ascending line that goes
559 | * from stableInterval to 3 * stableInterval. The average height for that part is
560 | * 2 * stableInterval, and is sized appropriately to have an area _equal_ to the
561 | * specified warmup period. Thus, by point (5) above, it takes "warmupPeriod" amount of time
562 | * to go from maxPermits to halfPermits.
563 | *
564 | * BUT, by point (4) above, it only takes "warmupPeriod / 2" amount of time to return back
565 | * to maxPermits, from halfPermits! (Because the trapezoid has double the area of the rectangle
566 | * of height stableInterval and equivalent width). We decided that the "cooldown period"
567 | * time should be equivalent to "warmup period", thus a fully saturated RateLimiter
568 | * (with zero stored permits, serving only fresh ones) can go to a fully unsaturated
569 | * (with storedPermits == maxPermits) in the same amount of time it takes for a fully
570 | * unsaturated RateLimiter to return to the stableInterval -- which happens in halfPermits,
571 | * since beyond that point, we use a horizontal line of "stableInterval" height, simulating
572 | * the regular rate.
573 | *
574 | * Thus, we have figured all dimensions of this shape, to give all the desired
575 | * properties:
576 | * - the width is warmupPeriod / stableInterval, to make cooldownPeriod == warmupPeriod
577 | * - the slope starts at the middle, and goes from stableInterval to 3*stableInterval so
578 | * to have halfPermits being spent in double the usual time (half the rate), while their
579 | * respective rate is steadily ramping up
580 | */
581 | private static class WarmingUp extends RateLimiter {
582 |
583 | final long warmupPeriodMicros;
584 | /**
585 | * The slope of the line from the stable interval (when permits == 0), to the cold interval
586 | * (when permits == maxPermits)
587 | */
588 | private double slope;
589 | private double halfPermits;
590 |
591 | WarmingUp(SleepingTicker ticker, long warmupPeriod, TimeUnit timeUnit) {
592 | super(ticker);
593 | this.warmupPeriodMicros = timeUnit.toMicros(warmupPeriod);
594 | }
595 |
596 | @Override
597 | void doSetRate(double permitsPerSecond, double stableIntervalMicros) {
598 | double oldMaxPermits = maxPermits;
599 | maxPermits = warmupPeriodMicros / stableIntervalMicros;
600 | halfPermits = maxPermits / 2.0;
601 | // Stable interval is x, cold is 3x, so on average it's 2x. Double the time -> halve the rate
602 | double coldIntervalMicros = stableIntervalMicros * 3.0;
603 | slope = (coldIntervalMicros - stableIntervalMicros) / halfPermits;
604 | if (oldMaxPermits == Double.POSITIVE_INFINITY) {
605 | // if we don't special-case this, we would get storedPermits == NaN, below
606 | storedPermits = 0.0;
607 | } else {
608 | storedPermits = (oldMaxPermits == 0.0)
609 | ? maxPermits // initial state is cold
610 | : storedPermits * maxPermits / oldMaxPermits;
611 | }
612 | }
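// A small, hypothetical sketch (not part of the shaded file) that reproduces the geometry
// computed in doSetRate() above for assumed inputs: 2 permits/second and a 4-second warmup.
// maxPermits = warmupPeriod / stableInterval = 8, halfPermits = 4, and the sloped part of the
// function climbs from stableInterval to 3 * stableInterval, i.e. a slope of 250_000 per permit.
class WarmupGeometrySketch {
    public static void main(String[] args) {
        double stableIntervalMicros = 1_000_000.0 / 2.0;               // 500_000 micros (2 qps)
        long warmupPeriodMicros = 4_000_000L;                          // 4 second warmup
        double maxPermits = warmupPeriodMicros / stableIntervalMicros; // 8.0
        double halfPermits = maxPermits / 2.0;                         // 4.0
        double coldIntervalMicros = 3.0 * stableIntervalMicros;        // 1_500_000
        double slope = (coldIntervalMicros - stableIntervalMicros) / halfPermits; // 250_000
        System.out.printf("maxPermits=%.1f halfPermits=%.1f slope=%.1f%n",
                maxPermits, halfPermits, slope);
    }
}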
613 |
614 | @Override
615 | long storedPermitsToWaitTime(double storedPermits, double permitsToTake) {
616 | double availablePermitsAboveHalf = storedPermits - halfPermits;
617 | long micros = 0;
618 | // measuring the integral on the right part of the function (the climbing line)
619 | if (availablePermitsAboveHalf > 0.0) {
620 | double permitsAboveHalfToTake = Math.min(availablePermitsAboveHalf, permitsToTake);
621 | micros = (long) (permitsAboveHalfToTake * (permitsToTime(availablePermitsAboveHalf)
622 | + permitsToTime(availablePermitsAboveHalf - permitsAboveHalfToTake)) / 2.0);
623 | permitsToTake -= permitsAboveHalfToTake;
624 | }
625 | // measuring the integral on the left part of the function (the horizontal line)
626 | micros += (stableIntervalMicros * permitsToTake);
627 | return micros;
628 | }
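// A worked, hypothetical example (not part of the shaded file) of the trapezoid integral above,
// using the assumed geometry from the previous sketch: stableInterval = 500_000 micros,
// slope = 250_000, halfPermits = 4. Taking 2 permits out of 8 stored permits is charged the
// average of permitsToTime(4) = 1_500_000 and permitsToTime(2) = 1_000_000 per permit,
// i.e. 2 * 1_250_000 = 2_500_000 micros of wait.
class WarmupWaitTimeSketch {
    public static void main(String[] args) {
        double stableIntervalMicros = 500_000.0, slope = 250_000.0, halfPermits = 4.0;
        double storedPermits = 8.0, permitsToTake = 2.0;
        double availableAboveHalf = storedPermits - halfPermits;                    // 4.0
        double take = Math.min(availableAboveHalf, permitsToTake);                  // 2.0
        double top = stableIntervalMicros + availableAboveHalf * slope;             // 1_500_000
        double bottom = stableIntervalMicros + (availableAboveHalf - take) * slope; // 1_000_000
        long micros = (long) (take * (top + bottom) / 2.0);                         // 2_500_000
        System.out.println(micros);
    }
}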
629 |
630 | private double permitsToTime(double permits) {
631 | return stableIntervalMicros + permits * slope;
632 | }
633 | }
634 |
635 | /**
636 | * This implements a trivial function, where storedPermits are translated to
637 | * zero throttling - thus, a client gets an infinite speedup for permits acquired out
638 | * of the storedPermits pool. This is also used for the special case of the "metronome",
639 | * where the width of the function is also zero; maxStoredPermits is zero, thus
640 | * storedPermits and permitsToTake are always zero as well. Such a RateLimiter cannot
641 | * save permits when unused, thus all permits it serves are fresh, using the
642 | * designated rate.
643 | */
644 | private static class Bursty extends RateLimiter {
645 | Bursty(SleepingTicker ticker) {
646 | super(ticker);
647 | }
648 |
649 | @Override
650 | void doSetRate(double permitsPerSecond, double stableIntervalMicros) {
651 | double oldMaxPermits = this.maxPermits;
652 | /*
653 | * We allow the equivalent work of up to one second to be granted with zero waiting, if the
654 | * rate limiter has been unused for that long. This is to avoid potentially producing tiny
655 | * wait intervals between subsequent requests for sufficiently large rates, which would
656 | * unnecessarily overconstrain the thread scheduler.
657 | */
658 | maxPermits = permitsPerSecond; // one second worth of permits
659 | storedPermits = (oldMaxPermits == 0.0)
660 | ? 0.0 // initial state
661 | : storedPermits * maxPermits / oldMaxPermits;
662 | }
663 |
664 | @Override
665 | long storedPermitsToWaitTime(double storedPermits, double permitsToTake) {
666 | return 0L;
667 | }
668 | }
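// A tiny, hypothetical sketch (not part of the shaded file) of the Bursty bookkeeping above:
// maxPermits is one second's worth of permits, and stored permits are rescaled proportionally
// when the rate changes. The numbers below are assumptions chosen for illustration.
class BurstyRescaleSketch {
    public static void main(String[] args) {
        double oldMaxPermits = 1000.0;   // previous rate: 1000 permits/second
        double storedPermits = 400.0;    // permits accumulated while unused
        double newMaxPermits = 2000.0;   // rate doubled to 2000 permits/second
        storedPermits = (oldMaxPermits == 0.0)
                ? 0.0
                : storedPermits * newMaxPermits / oldMaxPermits;
        System.out.println(storedPermits); // 800.0, still spendable with zero wait
    }
}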
669 | }
670 |
671 | abstract class SleepingTicker extends Ticker {
672 | abstract void sleepMicrosUninterruptibly(long micros);
673 |
674 | static final SleepingTicker SYSTEM_TICKER = new SleepingTicker() {
675 | @Override
676 | public long read() {
677 | return systemTicker().read();
678 | }
679 |
680 | @Override
681 | public void sleepMicrosUninterruptibly(long micros) {
682 | if (micros > 0) {
683 | Uninterruptibles.sleepUninterruptibly(micros, TimeUnit.MICROSECONDS);
684 | }
685 | }
686 | };
687 | }
688 |
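// A hypothetical test double (not part of the shaded file) with the same shape as SleepingTicker:
// instead of actually sleeping, it advances a fake clock, which is how the timing behaviour above
// can be exercised deterministically. The class name and fields are assumptions.
class FakeSleepingTickerSketch {
    private long nowNanos = 0L;

    long read() {
        return nowNanos;
    }

    void sleepMicrosUninterruptibly(long micros) {
        if (micros > 0) {
            nowNanos += micros * 1000L; // advance the fake clock instead of blocking
        }
    }

    public static void main(String[] args) {
        FakeSleepingTickerSketch ticker = new FakeSleepingTickerSketch();
        ticker.sleepMicrosUninterruptibly(250);
        System.out.println(ticker.read()); // 250_000 nanos
    }
}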
--------------------------------------------------------------------------------
32 | * if (count <= 0) {
33 | * throw new IllegalArgumentException("must be positive: " + count);
34 | * }
35 | *
36 | * to be replaced with the more compact
37 | *
38 | * checkArgument(count > 0, "must be positive: %s", count);
39 | *
40 | * Note that the sense of the expression is inverted; with {@code Preconditions}
41 | * you declare what you expect to be true, just as you do with an
42 | * {@code assert} or a JUnit {@code assertTrue} call.
43 | *
44 | *
45 | * {@code
54 | * final RateLimiter rateLimiter = RateLimiter.create(2.0); // rate is "2 permits per second"
55 | * void submitTasks(List<Runnable> tasks, Executor executor) {
56 | *   for (Runnable task : tasks) {
57 | *     rateLimiter.acquire(); // may wait
58 | *     executor.execute(task);
59 | *   }
60 | * }
61 | *}
62 | *
63 | * {@code
67 | * final RateLimiter rateLimiter = RateLimiter.create(5000.0); // rate = 5000 permits per second
68 | * void submitPacket(byte[] packet) {
69 | *   rateLimiter.acquire(packet.length);
70 | *   networkService.send(packet);
71 | * }
72 | *}
73 | *
74 | *