├── project ├── build.properties └── plugins.sbt ├── src ├── main │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── org.apache.spark.sql.sources.DataSourceRegister │ ├── scala │ │ └── com │ │ │ └── audienceproject │ │ │ └── spark │ │ │ └── dynamodb │ │ │ ├── attribute.scala │ │ │ ├── datasource │ │ │ ├── OutputPartitioning.scala │ │ │ ├── ScanPartition.scala │ │ │ ├── DynamoDataDeleteWriter.scala │ │ │ ├── DynamoWriteBuilder.scala │ │ │ ├── DefaultSource.scala │ │ │ ├── DynamoDataUpdateWriter.scala │ │ │ ├── DynamoBatchReader.scala │ │ │ ├── DynamoDataWriter.scala │ │ │ ├── DynamoWriterFactory.scala │ │ │ ├── DynamoScanBuilder.scala │ │ │ ├── DynamoReaderFactory.scala │ │ │ ├── DynamoTable.scala │ │ │ └── TypeConversion.scala │ │ │ ├── connector │ │ │ ├── KeySchema.scala │ │ │ ├── DynamoWritable.scala │ │ │ ├── ColumnSchema.scala │ │ │ ├── TableIndexConnector.scala │ │ │ ├── FilterPushdown.scala │ │ │ ├── DynamoConnector.scala │ │ │ └── TableConnector.scala │ │ │ ├── reflect │ │ │ └── SchemaAnalysis.scala │ │ │ ├── implicits.scala │ │ │ └── catalyst │ │ │ └── JavaConverter.scala │ └── java │ │ └── com │ │ └── audienceproject │ │ └── shaded │ │ └── google │ │ └── common │ │ ├── base │ │ ├── Ticker.java │ │ └── Preconditions.java │ │ └── util │ │ └── concurrent │ │ ├── Uninterruptibles.java │ │ └── RateLimiter.java └── test │ ├── scala │ └── com │ │ └── audienceproject │ │ └── spark │ │ └── dynamodb │ │ ├── structs │ │ ├── TestFruit.scala │ │ └── TestFruitWithProperties.scala │ │ ├── NullValuesTest.scala │ │ ├── NullBooleanTest.scala │ │ ├── FilterPushdownTest.scala │ │ ├── DefaultSourceTest.scala │ │ ├── RegionTest.scala │ │ ├── AbstractInMemoryTest.scala │ │ ├── WriteRelationTest.scala │ │ └── NestedDataStructuresTest.scala │ └── resources │ └── log4j2.xml ├── .editorconfig ├── .gitignore ├── wercker.yml ├── README.md └── LICENSE /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.2.6 2 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | com.audienceproject.spark.dynamodb.datasource.DefaultSource 2 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | [*] 3 | end_of_line = lf 4 | insert_final_newline = true 5 | charset = utf-8 6 | indent_style = space 7 | indent_size = 4 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | /bin/ 3 | *.class 4 | *.log 5 | .classpath 6 | .idea 7 | .wercker 8 | project/target 9 | project/project 10 | lib_managed*/ 11 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn 2 | 3 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.0") 4 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4") 5 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2") 6 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/structs/TestFruit.scala: 
-------------------------------------------------------------------------------- 1 | package com.audienceproject.spark.dynamodb.structs 2 | 3 | import com.audienceproject.spark.dynamodb.attribute 4 | 5 | case class TestFruit(@attribute("name") primaryKey: String, 6 | color: String, 7 | weightKg: Double) 8 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/structs/TestFruitWithProperties.scala: -------------------------------------------------------------------------------- 1 | package com.audienceproject.spark.dynamodb.structs 2 | 3 | case class TestFruitProperties(freshness: String, 4 | eco: Boolean, 5 | price: Double) 6 | 7 | case class TestFruitWithProperties(name: String, 8 | color: String, 9 | weight: Double, 10 | properties: TestFruitProperties) 11 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/attribute.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import scala.annotation.StaticAnnotation 24 | 25 | final case class attribute(name: String) extends StaticAnnotation 26 | -------------------------------------------------------------------------------- /wercker.yml: -------------------------------------------------------------------------------- 1 | box: 2 | id: audienceproject/jvm 3 | username: $DOCKERHUB_ACCOUNT 4 | password: $DOCKERHUB_PASSWORD 5 | tag: latest 6 | 7 | build: 8 | steps: 9 | - script: 10 | name: Compile 11 | code: sbt clean compile 12 | - audienceproject/aws-cli-assume-role@1.0.2: 13 | aws-access-key-id: $AWS_ACCESS_KEY 14 | aws-secret-access-key: $AWS_SECRET_KEY 15 | role-arn: arn:aws:iam::$AWS_ACCOUNT_ID:role/build-$WERCKER_GIT_REPOSITORY 16 | - script: 17 | name: Test 18 | code: sbt clean compile test 19 | - script: 20 | name: Clean again 21 | code: sbt clean 22 | 23 | publish-snapshot: 24 | steps: 25 | - audienceproject/sbt-to-maven-central@2.0.0: 26 | user: $NEXUS_USER 27 | password: $NEXUS_PASSWORD 28 | private-key: $NEXUS_PK 29 | passphrase: $NEXUS_PASSPHRASE 30 | 31 | publish-release: 32 | steps: 33 | - audienceproject/sbt-to-maven-central@2.0.0: 34 | user: $NEXUS_USER 35 | password: $NEXUS_PASSWORD 36 | private-key: $NEXUS_PK 37 | passphrase: $NEXUS_PASSPHRASE 38 | destination: RELEASE 39 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/OutputPartitioning.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import org.apache.spark.sql.connector.read.partitioning.{Distribution, Partitioning} 24 | 25 | class OutputPartitioning(override val numPartitions: Int) extends Partitioning { 26 | 27 | override def satisfy(distribution: Distribution): Boolean = false 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/test/resources/log4j2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/ScanPartition.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import org.apache.spark.sql.connector.read.InputPartition 24 | import org.apache.spark.sql.sources.Filter 25 | 26 | class ScanPartition(val partitionIndex: Int, 27 | val requiredColumns: Seq[String], 28 | val filters: Array[Filter]) 29 | extends InputPartition 30 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/KeySchema.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import com.amazonaws.services.dynamodbv2.model.{KeySchemaElement, KeyType} 24 | 25 | private[dynamodb] case class KeySchema(hashKeyName: String, rangeKeyName: Option[String]) 26 | 27 | private[dynamodb] object KeySchema { 28 | 29 | def fromDescription(keySchemaElements: Seq[KeySchemaElement]): KeySchema = { 30 | val hashKeyName = keySchemaElements.find(_.getKeyType == KeyType.HASH.toString).get.getAttributeName 31 | val rangeKeyName = keySchemaElements.find(_.getKeyType == KeyType.RANGE.toString).map(_.getAttributeName) 32 | KeySchema(hashKeyName, rangeKeyName) 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataDeleteWriter.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2020 AudienceProject. All rights reserved. 
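A brief usage sketch of KeySchema.fromDescription above (key names are illustrative; the case class is private[dynamodb], so this only compiles inside the com.audienceproject.spark.dynamodb packages):

    import com.amazonaws.services.dynamodbv2.model.{KeySchemaElement, KeyType}

    // Hypothetical key schema elements, e.g. taken from a DescribeTable result.
    val elements = Seq(
        new KeySchemaElement("userId", KeyType.HASH),
        new KeySchemaElement("createdAt", KeyType.RANGE)
    )
    val keySchema = KeySchema.fromDescription(elements)
    // keySchema == KeySchema("userId", Some("createdAt")); a hash-only table gives rangeKeyName = None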
20 | */ 21 | 22 | package com.audienceproject.spark.dynamodb.datasource 23 | 24 | import com.amazonaws.services.dynamodbv2.document.DynamoDB 25 | import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, TableConnector} 26 | 27 | class DynamoDataDeleteWriter(batchSize: Int, 28 | columnSchema: ColumnSchema, 29 | connector: TableConnector, 30 | client: DynamoDB) 31 | extends DynamoDataWriter(batchSize, columnSchema, connector, client) { 32 | 33 | protected override def flush(): Unit = { 34 | if (buffer.nonEmpty) { 35 | connector.deleteItems(columnSchema, buffer)(client, rateLimiter) 36 | buffer.clear() 37 | } 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/NullValuesTest.scala: -------------------------------------------------------------------------------- 1 | package com.audienceproject.spark.dynamodb 2 | 3 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput} 4 | import com.audienceproject.spark.dynamodb.implicits._ 5 | import org.apache.spark.sql.Row 6 | import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} 7 | 8 | class NullValuesTest extends AbstractInMemoryTest { 9 | 10 | test("Insert nested StructType with null values") { 11 | dynamoDB.createTable(new CreateTableRequest() 12 | .withTableName("NullTest") 13 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 14 | .withKeySchema(new KeySchemaElement("name", "HASH")) 15 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 16 | 17 | val schema = StructType( 18 | Seq( 19 | StructField("name", StringType, nullable = false), 20 | StructField("info", StructType( 21 | Seq( 22 | StructField("age", IntegerType, nullable = true), 23 | StructField("address", StringType, nullable = true) 24 | ) 25 | ), nullable = true) 26 | ) 27 | ) 28 | 29 | val rows = spark.sparkContext.parallelize(Seq( 30 | Row("one", Row(30, "Somewhere")), 31 | Row("two", null), 32 | Row("three", Row(null, null)) 33 | )) 34 | 35 | val newItemsDs = spark.createDataFrame(rows, schema) 36 | 37 | newItemsDs.write.dynamodb("NullTest") 38 | 39 | val validationDs = spark.read.dynamodb("NullTest") 40 | 41 | validationDs.show(false) 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoWritable.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
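A hedged usage sketch of the delete path implemented by DynamoDataDeleteWriter above: writing with the "delete" option replaces batched puts with batched deletes, so the DataFrame only needs to carry the table's key column(s). The table and DataFrame names are hypothetical.

    import com.audienceproject.spark.dynamodb.implicits._

    // keysToDelete is an existing DataFrame containing at least the hash (and range) key columns.
    keysToDelete.write
        .option("delete", "true")
        .dynamodb("SomeTable")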
20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import com.amazonaws.services.dynamodbv2.document.DynamoDB 24 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter 25 | import org.apache.spark.sql.catalyst.InternalRow 26 | 27 | private[dynamodb] trait DynamoWritable { 28 | 29 | val writeLimit: Double 30 | 31 | def putItems(columnSchema: ColumnSchema, items: Seq[InternalRow]) 32 | (client: DynamoDB, rateLimiter: RateLimiter): Unit 33 | 34 | def updateItem(columnSchema: ColumnSchema, item: InternalRow) 35 | (client: DynamoDB, rateLimiter: RateLimiter): Unit 36 | 37 | def deleteItems(columnSchema: ColumnSchema, items: Seq[InternalRow]) 38 | (client: DynamoDB, rateLimiter: RateLimiter): Unit 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoWriteBuilder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.audienceproject.spark.dynamodb.connector.TableConnector 24 | import org.apache.spark.sql.connector.write._ 25 | import org.apache.spark.sql.types.StructType 26 | 27 | class DynamoWriteBuilder(connector: TableConnector, parameters: Map[String, String], schema: StructType) 28 | extends WriteBuilder { 29 | 30 | override def buildForBatch(): BatchWrite = new BatchWrite { 31 | override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = 32 | new DynamoWriterFactory(connector, parameters, schema) 33 | 34 | override def commit(messages: Array[WriterCommitMessage]): Unit = {} 35 | 36 | override def abort(messages: Array[WriterCommitMessage]): Unit = {} 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/NullBooleanTest.scala: -------------------------------------------------------------------------------- 1 | package com.audienceproject.spark.dynamodb 2 | 3 | import com.amazonaws.services.dynamodbv2.document.Item 4 | import com.amazonaws.services.dynamodbv2.model.{ 5 | AttributeDefinition, 6 | CreateTableRequest, 7 | KeySchemaElement, 8 | ProvisionedThroughput 9 | } 10 | import com.audienceproject.spark.dynamodb.implicits._ 11 | 12 | class NullBooleanTest extends AbstractInMemoryTest { 13 | test("Test Null") { 14 | dynamoDB.createTable( 15 | new CreateTableRequest() 16 | .withTableName("TestNullBoolean") 17 | .withAttributeDefinitions(new AttributeDefinition("Pk", "S")) 18 | .withKeySchema(new KeySchemaElement("Pk", "HASH")) 19 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)) 20 | ) 21 | 22 | val table = dynamoDB.getTable("TestNullBoolean") 23 | 24 | for ((_pk, _type, _value) <- Seq( 25 | ("id1", "type1", true), 26 | ("id2", "type2", null) 27 | )) { 28 | if (_type != "type2") { 29 | table.putItem( 30 | new Item() 31 | .withString("Pk", _pk) 32 | .withString("Type", _type) 33 | .withBoolean("Value", _value.asInstanceOf[Boolean]) 34 | ) 35 | } else { 36 | table.putItem( 37 | new Item() 38 | .withString("Pk", _pk) 39 | .withString("Type", _type) 40 | .withNull("Value") 41 | ) 42 | } 43 | } 44 | 45 | val df = spark.read.dynamodbAs[BooleanClass]("TestNullBoolean") 46 | 47 | import spark.implicits._ 48 | df.where($"Type" === "type2").show() 49 | client.deleteTable("TestNullBoolean") 50 | } 51 | } 52 | 53 | case class BooleanClass(Pk: String, Type: String, Value: Boolean) 54 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. 
See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import java.util 24 | 25 | import org.apache.spark.sql.connector.catalog.{Table, TableProvider} 26 | import org.apache.spark.sql.connector.expressions.Transform 27 | import org.apache.spark.sql.sources.DataSourceRegister 28 | import org.apache.spark.sql.types.StructType 29 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 30 | 31 | class DefaultSource extends TableProvider with DataSourceRegister { 32 | 33 | override def getTable(schema: StructType, 34 | partitioning: Array[Transform], 35 | properties: util.Map[String, String]): Table = { 36 | new DynamoTable(new CaseInsensitiveStringMap(properties), Some(schema)) 37 | } 38 | 39 | override def inferSchema(options: CaseInsensitiveStringMap): StructType = { 40 | new DynamoTable(options).schema() 41 | } 42 | 43 | override def supportsExternalMetadata(): Boolean = true 44 | 45 | override def shortName(): String = "dynamodb" 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataUpdateWriter.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
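Because DefaultSource above registers the short name "dynamodb" (via DataSourceRegister and the META-INF/services entry listed at the top), the connector can also be used without the implicits. A minimal sketch with a hypothetical table name:

    val df = spark.read
        .format("dynamodb")
        .option("tableName", "SomeTable")
        .load()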
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.amazonaws.services.dynamodbv2.document.DynamoDB 24 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter 25 | import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, TableConnector} 26 | import org.apache.spark.sql.catalyst.InternalRow 27 | import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} 28 | 29 | class DynamoDataUpdateWriter(columnSchema: ColumnSchema, 30 | connector: TableConnector, 31 | client: DynamoDB) 32 | extends DataWriter[InternalRow] { 33 | 34 | private val rateLimiter = RateLimiter.create(connector.writeLimit) 35 | 36 | override def write(record: InternalRow): Unit = { 37 | connector.updateItem(columnSchema, record)(client, rateLimiter) 38 | } 39 | 40 | override def commit(): WriterCommitMessage = new WriterCommitMessage {} 41 | 42 | override def abort(): Unit = {} 43 | 44 | override def close(): Unit = client.shutdown() 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/com/audienceproject/shaded/google/common/base/Ticker.java: -------------------------------------------------------------------------------- 1 | package com.audienceproject.shaded.google.common.base; 2 | 3 | /* 4 | * Notice: 5 | * This file was modified at AudienceProject ApS by Cosmin Catalin Sanda (cosmin@audienceproject.com) 6 | */ 7 | 8 | /* 9 | * Copyright (C) 2011 The Guava Authors 10 | * 11 | * Licensed under the Apache License, Version 2.0 (the "License"); 12 | * you may not use this file except in compliance with the License. 13 | * You may obtain a copy of the License at 14 | * 15 | * http://www.apache.org/licenses/LICENSE-2.0 16 | * 17 | * Unless required by applicable law or agreed to in writing, software 18 | * distributed under the License is distributed on an "AS IS" BASIS, 19 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | * See the License for the specific language governing permissions and 21 | * limitations under the License. 22 | */ 23 | 24 | /** 25 | * A time source; returns a time value representing the number of nanoseconds elapsed since some 26 | * fixed but arbitrary point in time. 27 | * 28 | *

Warning: this interface can only be used to measure elapsed time, not wall time. 29 | * 30 | * @author Kevin Bourrillion 31 | * @since 10.0 32 | * (mostly source-compatible since 9.0) 34 | */ 35 | public abstract class Ticker { 36 | /** 37 | * Constructor for use by subclasses. 38 | */ 39 | protected Ticker() {} 40 | 41 | /** 42 | * Returns the number of nanoseconds elapsed since this ticker's fixed 43 | * point of reference. 44 | */ 45 | public abstract long read(); 46 | 47 | /** 48 | * A ticker that reads the current time using {@link System#nanoTime}. 49 | * 50 | * @since 10.0 51 | */ 52 | public static Ticker systemTicker() { 53 | return SYSTEM_TICKER; 54 | } 55 | 56 | private static final Ticker SYSTEM_TICKER = new Ticker() { 57 | @Override 58 | public long read() { 59 | return System.nanoTime(); 60 | } 61 | }; 62 | } 63 | 64 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchReader.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.audienceproject.spark.dynamodb.connector.DynamoConnector 24 | import org.apache.spark.sql.connector.read._ 25 | import org.apache.spark.sql.connector.read.partitioning.Partitioning 26 | import org.apache.spark.sql.sources.Filter 27 | import org.apache.spark.sql.types.StructType 28 | 29 | class DynamoBatchReader(connector: DynamoConnector, 30 | filters: Array[Filter], 31 | schema: StructType) 32 | extends Scan with Batch with SupportsReportPartitioning { 33 | 34 | override def readSchema(): StructType = schema 35 | 36 | override def toBatch: Batch = this 37 | 38 | override def planInputPartitions(): Array[InputPartition] = { 39 | val requiredColumns = schema.map(_.name) 40 | Array.tabulate(connector.totalSegments)(new ScanPartition(_, requiredColumns, filters)) 41 | } 42 | 43 | override def createReaderFactory(): PartitionReaderFactory = 44 | new DynamoReaderFactory(connector, schema) 45 | 46 | override val outputPartitioning: Partitioning = new OutputPartitioning(connector.totalSegments) 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/FilterPushdownTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. 
See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import com.audienceproject.spark.dynamodb.implicits._ 24 | 25 | class FilterPushdownTest extends AbstractInMemoryTest { 26 | 27 | test("Count of red fruit is 2 (`EqualTo` filter)") { 28 | import spark.implicits._ 29 | val fruitCount = spark.read.dynamodb("TestFruit").where($"color" === "red").count() 30 | assert(fruitCount === 2) 31 | } 32 | 33 | test("Count of yellow and green fruit is 4 (`In` filter)") { 34 | import spark.implicits._ 35 | val fruitCount = spark.read.dynamodb("TestFruit") 36 | .where($"color" isin("yellow", "green")) 37 | .count() 38 | assert(fruitCount === 4) 39 | } 40 | 41 | test("Count of 0.01 weight fruit is 3 (`In` filter)") { 42 | import spark.implicits._ 43 | val fruitCount = spark.read.dynamodb("TestFruit") 44 | .where($"weightKg" isin 0.01) 45 | .count() 46 | assert(fruitCount === 3) 47 | } 48 | 49 | test("Only 'banana' starts with a 'b' and is >0.01 kg (`StringStartsWith`, `GreaterThan`, `And` filters)") { 50 | import spark.implicits._ 51 | val fruit = spark.read.dynamodb("TestFruit") 52 | .where(($"name" startsWith "b") && ($"weightKg" > 0.01)) 53 | .collectAsList().get(0) 54 | assert(fruit.getAs[String]("name") === "banana") 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataWriter.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved.
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.amazonaws.services.dynamodbv2.document.DynamoDB 24 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter 25 | import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, TableConnector} 26 | import org.apache.spark.sql.catalyst.InternalRow 27 | import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} 28 | 29 | import scala.collection.mutable.ArrayBuffer 30 | 31 | class DynamoDataWriter(batchSize: Int, 32 | columnSchema: ColumnSchema, 33 | connector: TableConnector, 34 | client: DynamoDB) 35 | extends DataWriter[InternalRow] { 36 | 37 | protected val buffer: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow](batchSize) 38 | protected val rateLimiter: RateLimiter = RateLimiter.create(connector.writeLimit) 39 | 40 | override def write(record: InternalRow): Unit = { 41 | buffer += record.copy() 42 | if (buffer.size == batchSize) { 43 | flush() 44 | } 45 | } 46 | 47 | override def commit(): WriterCommitMessage = { 48 | flush() 49 | new WriterCommitMessage {} 50 | } 51 | 52 | override def abort(): Unit = {} 53 | 54 | override def close(): Unit = client.shutdown() 55 | 56 | protected def flush(): Unit = { 57 | if (buffer.nonEmpty) { 58 | connector.putItems(columnSchema, buffer)(client, rateLimiter) 59 | buffer.clear() 60 | } 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/ColumnSchema.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import org.apache.spark.sql.types.{DataType, StructType} 24 | 25 | private[dynamodb] class ColumnSchema(keySchema: KeySchema, 26 | sparkSchema: StructType) { 27 | 28 | type Attr = (String, Int, DataType) 29 | 30 | private val columnNames = sparkSchema.map(_.name) 31 | 32 | private val keyIndices = keySchema match { 33 | case KeySchema(hashKey, None) => 34 | val hashKeyIndex = columnNames.indexOf(hashKey) 35 | val hashKeyType = sparkSchema(hashKey).dataType 36 | Left(hashKey, hashKeyIndex, hashKeyType) 37 | case KeySchema(hashKey, Some(rangeKey)) => 38 | val hashKeyIndex = columnNames.indexOf(hashKey) 39 | val rangeKeyIndex = columnNames.indexOf(rangeKey) 40 | val hashKeyType = sparkSchema(hashKey).dataType 41 | val rangeKeyType = sparkSchema(rangeKey).dataType 42 | Right((hashKey, hashKeyIndex, hashKeyType), (rangeKey, rangeKeyIndex, rangeKeyType)) 43 | } 44 | 45 | private val attributeIndices = columnNames.zipWithIndex.filterNot({ 46 | case (name, _) => keySchema match { 47 | case KeySchema(hashKey, None) => name == hashKey 48 | case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey 49 | } 50 | }).map({ 51 | case (name, index) => (name, index, sparkSchema(name).dataType) 52 | }) 53 | 54 | def keys(): Either[Attr, (Attr, Attr)] = keyIndices 55 | 56 | def attributes(): Seq[Attr] = attributeIndices 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoWriterFactory.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.audienceproject.spark.dynamodb.connector.{ColumnSchema, TableConnector} 24 | import org.apache.spark.sql.catalyst.InternalRow 25 | import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} 26 | import org.apache.spark.sql.types.StructType 27 | 28 | class DynamoWriterFactory(connector: TableConnector, 29 | parameters: Map[String, String], 30 | schema: StructType) 31 | extends DataWriterFactory { 32 | 33 | private val batchSize = parameters.getOrElse("writebatchsize", "25").toInt 34 | private val update = parameters.getOrElse("update", "false").toBoolean 35 | private val delete = parameters.getOrElse("delete", "false").toBoolean 36 | 37 | private val region = parameters.get("region") 38 | private val roleArn = parameters.get("rolearn") 39 | private val providerClassName = parameters.get("providerclassname") 40 | 41 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { 42 | val columnSchema = new ColumnSchema(connector.keySchema, schema) 43 | val client = connector.getDynamoDB(region, roleArn, providerClassName) 44 | if (update) { 45 | assert(!delete, "Please provide exactly one of 'update' or 'delete' options.") 46 | new DynamoDataUpdateWriter(columnSchema, connector, client) 47 | } else if (delete) { 48 | new DynamoDataDeleteWriter(batchSize, columnSchema, connector, client) 49 | } else { 50 | new DynamoDataWriter(batchSize, columnSchema, connector, client) 51 | } 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoScanBuilder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
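A hedged sketch of the write options consumed by DynamoWriterFactory above; the factory looks the keys up lower-cased, so option names appear to be matched case-insensitively, and the table name and values here are illustrative:

    import com.audienceproject.spark.dynamodb.implicits._

    // someDf is an existing DataFrame with a schema matching the target table.
    someDf.write
        .option("writeBatchSize", "25")     // items buffered per batch write (default 25)
        .option("region", "eu-central-1")   // optional explicit AWS region
        .option("update", "true")           // UpdateItem semantics; mutually exclusive with "delete"
        .dynamodb("SomeTable")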
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.audienceproject.spark.dynamodb.connector.{DynamoConnector, FilterPushdown} 24 | import org.apache.spark.sql.connector.read._ 25 | import org.apache.spark.sql.sources.Filter 26 | import org.apache.spark.sql.types._ 27 | 28 | class DynamoScanBuilder(connector: DynamoConnector, schema: StructType) 29 | extends ScanBuilder 30 | with SupportsPushDownRequiredColumns 31 | with SupportsPushDownFilters { 32 | 33 | private var acceptedFilters: Array[Filter] = Array.empty 34 | private var currentSchema: StructType = schema 35 | 36 | override def build(): Scan = new DynamoBatchReader(connector, pushedFilters(), currentSchema) 37 | 38 | override def pruneColumns(requiredSchema: StructType): Unit = { 39 | val keyFields = Seq(Some(connector.keySchema.hashKeyName), connector.keySchema.rangeKeyName).flatten 40 | .flatMap(keyName => currentSchema.fields.find(_.name == keyName)) 41 | val requiredFields = keyFields ++ requiredSchema.fields 42 | val newFields = currentSchema.fields.filter(requiredFields.contains) 43 | currentSchema = StructType(newFields) 44 | } 45 | 46 | override def pushFilters(filters: Array[Filter]): Array[Filter] = { 47 | if (connector.filterPushdownEnabled) { 48 | val (acceptedFilters, postScanFilters) = FilterPushdown.acceptFilters(filters) 49 | this.acceptedFilters = acceptedFilters 50 | postScanFilters // Return filters that need to be evaluated after scanning. 51 | } else filters 52 | } 53 | 54 | override def pushedFilters(): Array[Filter] = acceptedFilters 55 | 56 | } 57 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/DefaultSourceTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import com.audienceproject.spark.dynamodb.implicits._ 24 | import com.audienceproject.spark.dynamodb.structs.TestFruit 25 | import org.apache.spark.sql.functions._ 26 | 27 | import scala.collection.JavaConverters._ 28 | 29 | class DefaultSourceTest extends AbstractInMemoryTest { 30 | 31 | test("Table count is 9") { 32 | val count = spark.read.dynamodb("TestFruit") 33 | count.show() 34 | assert(count.count() === 9) 35 | } 36 | 37 | test("Column sum is 27") { 38 | val result = spark.read.dynamodb("TestFruit").collectAsList().asScala 39 | val numCols = result.map(_.length).sum 40 | assert(numCols === 27) 41 | } 42 | 43 | test("Select only first two columns") { 44 | val result = spark.read.dynamodb("TestFruit").select("name", "color").collectAsList().asScala 45 | val numCols = result.map(_.length).sum 46 | assert(numCols === 18) 47 | } 48 | 49 | test("The least occurring color is yellow") { 50 | import spark.implicits._ 51 | val itemWithLeastOccurringColor = spark.read.dynamodb("TestFruit") 52 | .groupBy($"color").agg(count($"color").as("countColor")) 53 | .orderBy($"countColor") 54 | .takeAsList(1).get(0) 55 | assert(itemWithLeastOccurringColor.getAs[String]("color") === "yellow") 56 | } 57 | 58 | test("Test of attribute name alias") { 59 | import spark.implicits._ 60 | val itemApple = spark.read.dynamodbAs[TestFruit]("TestFruit") 61 | .filter($"primaryKey" === "apple") 62 | .takeAsList(1).get(0) 63 | assert(itemApple.primaryKey === "apple") 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/reflect/SchemaAnalysis.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.reflect 22 | 23 | import com.audienceproject.spark.dynamodb.attribute 24 | import org.apache.spark.sql.catalyst.ScalaReflection 25 | import org.apache.spark.sql.types.{StructField, StructType} 26 | 27 | import scala.reflect.ClassTag 28 | import scala.reflect.runtime.{universe => ru} 29 | 30 | /** 31 | * Uses reflection to perform a static analysis that can derive a Spark schema from a case class of type `T`. 
32 | */ 33 | private[dynamodb] object SchemaAnalysis { 34 | 35 | def apply[T <: Product : ClassTag : ru.TypeTag]: (StructType, Map[String, String]) = { 36 | 37 | val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader) 38 | 39 | val classObj = scala.reflect.classTag[T].runtimeClass 40 | val classSymbol = runtimeMirror.classSymbol(classObj) 41 | 42 | val params = classSymbol.primaryConstructor.typeSignature.paramLists.head 43 | val (sparkFields, aliasMap) = params.foldLeft((List.empty[StructField], Map.empty[String, String]))({ 44 | case ((list, map), field) => 45 | val sparkType = ScalaReflection.schemaFor(field.typeSignature).dataType 46 | 47 | // Black magic from here: 48 | // https://stackoverflow.com/questions/23046958/accessing-an-annotation-value-in-scala 49 | val attrName = field.annotations.collectFirst({ 50 | case ann: ru.AnnotationApi if ann.tree.tpe =:= ru.typeOf[attribute] => 51 | ann.tree.children.tail.collectFirst({ 52 | case ru.Literal(ru.Constant(name: String)) => name 53 | }) 54 | }).flatten 55 | 56 | if (attrName.isDefined) { 57 | val sparkField = StructField(attrName.get, sparkType, nullable = true) 58 | (list :+ sparkField, map + (attrName.get -> field.name.toString)) 59 | } else { 60 | val sparkField = StructField(field.name.toString, sparkType, nullable = true) 61 | (list :+ sparkField, map) 62 | } 63 | }) 64 | 65 | (StructType(sparkFields), aliasMap) 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/RegionTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
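An illustrative sketch of what SchemaAnalysis above produces for the TestFruit case class from the test sources (the object is private[dynamodb], so this only compiles inside the com.audienceproject.spark.dynamodb packages):

    import com.audienceproject.spark.dynamodb.structs.TestFruit

    val (schema, aliasMap) = SchemaAnalysis[TestFruit]
    // schema   : StructType with fields "name", "color", "weightKg"
    // aliasMap : Map("name" -> "primaryKey"), because primaryKey is annotated with @attribute("name")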
20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration 24 | import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder} 25 | import com.amazonaws.services.dynamodbv2.document.DynamoDB 26 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput} 27 | import com.audienceproject.spark.dynamodb.implicits._ 28 | 29 | class RegionTest extends AbstractInMemoryTest { 30 | 31 | test("Inserting from a local Dataset") { 32 | val tableName = "RegionTest1" 33 | dynamoDB.createTable(new CreateTableRequest() 34 | .withTableName(tableName) 35 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 36 | .withKeySchema(new KeySchemaElement("name", "HASH")) 37 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 38 | val client: AmazonDynamoDB = AmazonDynamoDBClientBuilder.standard() 39 | .withEndpointConfiguration(new EndpointConfiguration(System.getProperty("aws.dynamodb.endpoint"), "eu-central-1")) 40 | .build() 41 | val dynamoDBEU: DynamoDB = new DynamoDB(client) 42 | dynamoDBEU.createTable(new CreateTableRequest() 43 | .withTableName(tableName) 44 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 45 | .withKeySchema(new KeySchemaElement("name", "HASH")) 46 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 47 | 48 | import spark.implicits._ 49 | 50 | val newItemsDs = spark.createDataset(Seq( 51 | ("lemon", "yellow", 0.1), 52 | ("orange", "orange", 0.2), 53 | ("pomegranate", "red", 0.2) 54 | )) 55 | .withColumnRenamed("_1", "name") 56 | .withColumnRenamed("_2", "color") 57 | .withColumnRenamed("_3", "weight") 58 | newItemsDs.write.option("region","eu-central-1").dynamodb(tableName) 59 | 60 | val validationDs = spark.read.dynamodb(tableName) 61 | assert(validationDs.count() === 0) 62 | val validationDsEU = spark.read.option("region","eu-central-1").dynamodb(tableName) 63 | assert(validationDsEU.count() === 3) 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/implicits.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import com.audienceproject.spark.dynamodb.reflect.SchemaAnalysis 24 | import org.apache.spark.sql._ 25 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 26 | import org.apache.spark.sql.functions.col 27 | 28 | import scala.reflect.ClassTag 29 | import scala.reflect.runtime.universe.TypeTag 30 | 31 | object implicits { 32 | 33 | implicit class DynamoDBDataFrameReader(reader: DataFrameReader) { 34 | 35 | def dynamodb(tableName: String): DataFrame = 36 | getDynamoDBSource(tableName).load() 37 | 38 | def dynamodb(tableName: String, indexName: String): DataFrame = 39 | getDynamoDBSource(tableName).option("indexName", indexName).load() 40 | 41 | def dynamodbAs[T <: Product : ClassTag : TypeTag](tableName: String): Dataset[T] = { 42 | implicit val encoder: Encoder[T] = ExpressionEncoder() 43 | val (schema, aliasMap) = SchemaAnalysis[T] 44 | getColumnsAlias(getDynamoDBSource(tableName).schema(schema).load(), aliasMap).as 45 | } 46 | 47 | def dynamodbAs[T <: Product : ClassTag : TypeTag](tableName: String, indexName: String): Dataset[T] = { 48 | implicit val encoder: Encoder[T] = ExpressionEncoder() 49 | val (schema, aliasMap) = SchemaAnalysis[T] 50 | getColumnsAlias( 51 | getDynamoDBSource(tableName).option("indexName", indexName).schema(schema).load(), aliasMap).as 52 | } 53 | 54 | private def getDynamoDBSource(tableName: String): DataFrameReader = 55 | reader.format("com.audienceproject.spark.dynamodb.datasource").option("tableName", tableName) 56 | 57 | private def getColumnsAlias(dataFrame: DataFrame, aliasMap: Map[String, String]): DataFrame = { 58 | if (aliasMap.isEmpty) dataFrame 59 | else { 60 | val columnsAlias = dataFrame.columns.map({ 61 | case name if aliasMap.isDefinedAt(name) => col(name) as aliasMap(name) 62 | case name => col(name) 63 | }) 64 | dataFrame.select(columnsAlias: _*) 65 | } 66 | } 67 | 68 | } 69 | 70 | implicit class DynamoDBDataFrameWriter[T](writer: DataFrameWriter[T]) { 71 | 72 | def dynamodb(tableName: String): Unit = 73 | writer.format("com.audienceproject.spark.dynamodb.datasource") 74 | .mode(SaveMode.Append) 75 | .option("tableName", tableName) 76 | .save() 77 | 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/AbstractInMemoryTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
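A minimal usage sketch of the reader and writer implicits defined in implicits.scala above (table and index names are hypothetical; spark is an existing SparkSession):

    import com.audienceproject.spark.dynamodb.implicits._
    import com.audienceproject.spark.dynamodb.structs.TestFruit

    val df = spark.read.dynamodb("SomeTable")                    // scan the table into a DataFrame
    val byIndex = spark.read.dynamodb("SomeTable", "SomeIndex")  // scan a secondary index
    val typed = spark.read.dynamodbAs[TestFruit]("TestFruit")    // typed Dataset via SchemaAnalysis
    df.write.dynamodb("SomeOtherTable")                          // batched writes in Append mode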
20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration 24 | import com.amazonaws.services.dynamodbv2.document.{DynamoDB, Item} 25 | import com.amazonaws.services.dynamodbv2.local.main.ServerRunner 26 | import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer 27 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput} 28 | import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder} 29 | import org.apache.spark.sql.SparkSession 30 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 31 | 32 | class AbstractInMemoryTest extends FunSuite with BeforeAndAfterAll { 33 | 34 | val server: DynamoDBProxyServer = ServerRunner.createServerFromCommandLineArgs(Array("-inMemory")) 35 | 36 | val client: AmazonDynamoDB = AmazonDynamoDBClientBuilder.standard() 37 | .withEndpointConfiguration(new EndpointConfiguration(System.getProperty("aws.dynamodb.endpoint"), "us-east-1")) 38 | .build() 39 | val dynamoDB: DynamoDB = new DynamoDB(client) 40 | 41 | val spark: SparkSession = SparkSession.builder 42 | .master("local") 43 | .appName(this.getClass.getName) 44 | .getOrCreate() 45 | 46 | spark.sparkContext.setLogLevel("ERROR") 47 | 48 | override def beforeAll(): Unit = { 49 | server.start() 50 | 51 | // Create a test table. 52 | dynamoDB.createTable(new CreateTableRequest() 53 | .withTableName("TestFruit") 54 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 55 | .withKeySchema(new KeySchemaElement("name", "HASH")) 56 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 57 | 58 | // Populate with test data. 59 | val table = dynamoDB.getTable("TestFruit") 60 | for ((name, color, weight) <- Seq( 61 | ("apple", "red", 0.2), ("banana", "yellow", 0.15), ("watermelon", "red", 0.5), 62 | ("grape", "green", 0.01), ("pear", "green", 0.2), ("kiwi", "green", 0.05), 63 | ("blackberry", "purple", 0.01), ("blueberry", "purple", 0.01), ("plum", "purple", 0.1) 64 | )) { 65 | table.putItem(new Item() 66 | .withString("name", name) 67 | .withString("color", color) 68 | .withDouble("weightKg", weight)) 69 | } 70 | } 71 | 72 | override def afterAll(): Unit = { 73 | client.deleteTable("TestFruit") 74 | server.stop() 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/catalyst/JavaConverter.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.catalyst 22 | 23 | import java.util 24 | 25 | import org.apache.spark.sql.catalyst.InternalRow 26 | import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} 27 | import org.apache.spark.sql.types._ 28 | import org.apache.spark.unsafe.types.UTF8String 29 | 30 | import scala.collection.JavaConverters._ 31 | 32 | object JavaConverter { 33 | 34 | def convertRowValue(row: InternalRow, index: Int, elementType: DataType): Any = { 35 | elementType match { 36 | case ArrayType(innerType, _) => convertArray(row.getArray(index), innerType) 37 | case MapType(keyType, valueType, _) => convertMap(row.getMap(index), keyType, valueType) 38 | case StructType(fields) => convertStruct(row.getStruct(index, fields.length), fields) 39 | case StringType => row.getString(index) 40 | case LongType => row.getLong(index) 41 | case t: DecimalType => row.getDecimal(index, t.precision, t.scale).toBigDecimal 42 | case _ => row.get(index, elementType) 43 | } 44 | } 45 | 46 | def convertArray(array: ArrayData, elementType: DataType): Any = { 47 | elementType match { 48 | case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(convertArray(_, innerType)).asJava 49 | case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(convertMap(_, keyType, valueType)).asJava 50 | case structType: StructType => array.toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)).asJava 51 | case StringType => convertStringArray(array).asJava 52 | case _ => array.toSeq[Any](elementType).asJava 53 | } 54 | } 55 | 56 | def convertMap(map: MapData, keyType: DataType, valueType: DataType): util.Map[String, Any] = { 57 | if (keyType != StringType) throw new IllegalArgumentException( 58 | s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.") 59 | val keys = convertStringArray(map.keyArray()) 60 | val values = valueType match { 61 | case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(convertArray(_, innerType)) 62 | case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(convertMap(_, innerKeyType, innerValueType)) 63 | case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(convertStruct(_, structType.fields)) 64 | case StringType => convertStringArray(map.valueArray()) 65 | case _ => map.valueArray().toSeq[Any](valueType) 66 | } 67 | val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i) 68 | Map(kvPairs: _*).asJava 69 | } 70 | 71 | def convertStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = { 72 | val kvPairs = for (i <- 0 until row.numFields) yield 73 | if (row.isNullAt(i)) fields(i).name -> null 74 | else fields(i).name -> convertRowValue(row, i, fields(i).dataType) 75 | Map(kvPairs: _*).asJava 76 | } 77 | 78 | def convertStringArray(array: ArrayData): Seq[String] = 79 | array.toSeq[UTF8String](StringType).map(_.toString) 80 | 81 | } 82 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoReaderFactory.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. 
The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.amazonaws.services.dynamodbv2.document.Item 24 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter 25 | import com.audienceproject.spark.dynamodb.connector.DynamoConnector 26 | import org.apache.spark.sql.catalyst.InternalRow 27 | import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} 28 | import org.apache.spark.sql.types.{StructField, StructType} 29 | 30 | import scala.collection.JavaConverters._ 31 | 32 | class DynamoReaderFactory(connector: DynamoConnector, 33 | schema: StructType) 34 | extends PartitionReaderFactory { 35 | 36 | override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { 37 | if (connector.isEmpty) new EmptyReader 38 | else new ScanPartitionReader(partition.asInstanceOf[ScanPartition]) 39 | } 40 | 41 | private class EmptyReader extends PartitionReader[InternalRow] { 42 | override def next(): Boolean = false 43 | 44 | override def get(): InternalRow = throw new IllegalStateException("Unable to call get() on empty iterator") 45 | 46 | override def close(): Unit = {} 47 | } 48 | 49 | private class ScanPartitionReader(scanPartition: ScanPartition) extends PartitionReader[InternalRow] { 50 | 51 | import scanPartition._ 52 | 53 | private val pageIterator = connector.scan(partitionIndex, requiredColumns, filters).pages().iterator().asScala 54 | private val rateLimiter = RateLimiter.create(connector.readLimit) 55 | 56 | private var innerIterator: Iterator[InternalRow] = Iterator.empty 57 | 58 | private var currentRow: InternalRow = _ 59 | private var proceed = false 60 | 61 | private val typeConversions = schema.collect({ 62 | case StructField(name, dataType, _, _) => name -> TypeConversion(name, dataType) 63 | }).toMap 64 | 65 | override def next(): Boolean = { 66 | proceed = true 67 | innerIterator.hasNext || { 68 | if (pageIterator.hasNext) { 69 | nextPage() 70 | next() 71 | } 72 | else false 73 | } 74 | } 75 | 76 | override def get(): InternalRow = { 77 | if (proceed) { 78 | currentRow = innerIterator.next() 79 | proceed = false 80 | } 81 | currentRow 82 | } 83 | 84 | override def close(): Unit = {} 85 | 86 | private def nextPage(): Unit = { 87 | val page = pageIterator.next() 88 | val result = page.getLowLevelResult 89 | Option(result.getScanResult.getConsumedCapacity).foreach(cap => rateLimiter.acquire(cap.getCapacityUnits.toInt max 1)) 90 | innerIterator = result.getItems.iterator().asScala.map(itemToRow(requiredColumns)) 91 | } 92 | 93 | private def itemToRow(requiredColumns: Seq[String])(item: Item): InternalRow = 94 | if (requiredColumns.nonEmpty) InternalRow.fromSeq(requiredColumns.map(columnName => typeConversions(columnName)(item))) 95 | else 
InternalRow.fromSeq(item.asMap().asScala.values.toSeq.map(_.toString)) 96 | 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import com.amazonaws.services.dynamodbv2.document.spec.ScanSpec 24 | import com.amazonaws.services.dynamodbv2.document.{ItemCollection, ScanOutcome} 25 | import com.amazonaws.services.dynamodbv2.model.ReturnConsumedCapacity 26 | import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder 27 | import org.apache.spark.sql.sources.Filter 28 | 29 | import scala.collection.JavaConverters._ 30 | 31 | private[dynamodb] class TableIndexConnector(tableName: String, indexName: String, parallelism: Int, parameters: Map[String, String]) 32 | extends DynamoConnector with Serializable { 33 | 34 | private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean 35 | private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean 36 | private val region = parameters.get("region") 37 | private val roleArn = parameters.get("roleArn") 38 | private val providerClassName = parameters.get("providerclassname") 39 | 40 | override val filterPushdownEnabled: Boolean = filterPushdown 41 | 42 | override val (keySchema, readLimit, itemLimit, totalSegments) = { 43 | val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName) 44 | val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get 45 | 46 | // Key schema. 47 | val keySchema = KeySchema.fromDescription(indexDesc.getKeySchema.asScala) 48 | 49 | // User parameters. 50 | val bytesPerRCU = parameters.getOrElse("bytesPerRCU", "4000").toInt 51 | val maxPartitionBytes = parameters.getOrElse("maxpartitionbytes", "128000000").toInt 52 | val targetCapacity = parameters.getOrElse("targetCapacity", "1").toDouble 53 | val readFactor = if (consistentRead) 1 else 2 54 | 55 | // Table parameters. 56 | val indexSize = indexDesc.getIndexSizeBytes 57 | val itemCount = indexDesc.getItemCount 58 | 59 | // Partitioning calculation. 
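// Unless `readpartitions` is supplied, the count computed below targets partitions of roughly
// `maxPartitionBytes` each and is rounded up to a multiple of the Spark parallelism so every
// core gets a scan segment. Worked example with illustrative figures (not from any real table):
// indexSize = 1,000,000,000 bytes, maxPartitionBytes = 128,000,000 and parallelism = 8 give
// sizeBased = 7 and remainder = 7, hence numPartitions = 7 + (8 - 7) = 8.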
60 | val numPartitions = parameters.get("readpartitions").map(_.toInt).getOrElse({ 61 | val sizeBased = (indexSize / maxPartitionBytes).toInt max 1 62 | val remainder = sizeBased % parallelism 63 | if (remainder > 0) sizeBased + (parallelism - remainder) 64 | else sizeBased 65 | }) 66 | 67 | // Provisioned or on-demand throughput. 68 | val readThroughput = parameters.getOrElse("throughput", Option(indexDesc.getProvisionedThroughput.getReadCapacityUnits) 69 | .filter(_ > 0).map(_.longValue().toString) 70 | .getOrElse("100")).toLong 71 | 72 | // Rate limit calculation. 73 | val avgItemSize = indexSize.toDouble / itemCount 74 | val readCapacity = readThroughput * targetCapacity 75 | 76 | val rateLimit = readCapacity / parallelism 77 | val itemLimit = ((bytesPerRCU / avgItemSize * rateLimit).toInt * readFactor) max 1 78 | 79 | (keySchema, rateLimit, itemLimit, numPartitions) 80 | } 81 | 82 | override def scan(segmentNum: Int, columns: Seq[String], filters: Seq[Filter]): ItemCollection[ScanOutcome] = { 83 | val scanSpec = new ScanSpec() 84 | .withSegment(segmentNum) 85 | .withTotalSegments(totalSegments) 86 | .withMaxPageSize(itemLimit) 87 | .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL) 88 | .withConsistentRead(consistentRead) 89 | 90 | if (columns.nonEmpty) { 91 | val xspec = new ExpressionSpecBuilder().addProjections(columns: _*) 92 | 93 | if (filters.nonEmpty && filterPushdown) { 94 | xspec.withCondition(FilterPushdown(filters)) 95 | } 96 | 97 | scanSpec.withExpressionSpec(xspec.buildForScan()) 98 | } 99 | 100 | getDynamoDB(region, roleArn, providerClassName).getTable(tableName).getIndex(indexName).scan(scanSpec) 101 | } 102 | 103 | } 104 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/FilterPushdown.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder.{BOOL => newBOOL, N => newN, S => newS, _} 24 | import com.amazonaws.services.dynamodbv2.xspec._ 25 | import org.apache.spark.sql.sources._ 26 | 27 | private[dynamodb] object FilterPushdown { 28 | 29 | def apply(filters: Seq[Filter]): Condition = 30 | filters.map(buildCondition).map(parenthesize).reduce[Condition](_ and _) 31 | 32 | /** 33 | * Accepts only filters that would be considered valid input to FilterPushdown.apply() 34 | * 35 | * @param filters input list which may contain both valid and invalid filters 36 | * @return a (valid, invalid) partitioning of the input filters 37 | */ 38 | def acceptFilters(filters: Array[Filter]): (Array[Filter], Array[Filter]) = 39 | filters.partition(checkFilter) 40 | 41 | private def checkFilter(filter: Filter): Boolean = filter match { 42 | case _: StringEndsWith => false 43 | case And(left, right) => checkFilter(left) && checkFilter(right) 44 | case Or(left, right) => checkFilter(left) && checkFilter(right) 45 | case Not(f) => checkFilter(f) 46 | case _ => true 47 | } 48 | 49 | private def buildCondition(filter: Filter): Condition = filter match { 50 | case EqualTo(path, value: Boolean) => newBOOL(path).eq(value) 51 | case EqualTo(path, value) => coerceAndApply(_ eq _, _ eq _)(path, value) 52 | 53 | case GreaterThan(path, value) => coerceAndApply(_ gt _, _ gt _)(path, value) 54 | case GreaterThanOrEqual(path, value) => coerceAndApply(_ ge _, _ ge _)(path, value) 55 | 56 | case LessThan(path, value) => coerceAndApply(_ lt _, _ lt _)(path, value) 57 | case LessThanOrEqual(path, value) => coerceAndApply(_ le _, _ le _)(path, value) 58 | 59 | case In(path, values) => 60 | val valueList = values.toList 61 | valueList match { 62 | case (_: String) :: _ => newS(path).in(valueList.asInstanceOf[List[String]]: _*) 63 | case (_: Boolean) :: _ => newBOOL(path).in(valueList.asInstanceOf[List[Boolean]]: _*) 64 | case (_: Int) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*) 65 | case (_: Long) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*) 66 | case (_: Short) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*) 67 | case (_: Float) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*) 68 | case (_: Double) :: _ => newN(path).in(valueList.map(_.asInstanceOf[Number]): _*) 69 | case Nil => throw new IllegalArgumentException("Unable to apply `In` filter with empty value list") 70 | case _ => throw new IllegalArgumentException(s"Type of values supplied to `In` filter on attribute $path not supported by filter pushdown") 71 | } 72 | 73 | case IsNull(path) => attribute_not_exists(path) 74 | case IsNotNull(path) => attribute_exists(path) 75 | 76 | case StringStartsWith(path, value) => newS(path).beginsWith(value) 77 | case StringContains(path, value) => newS(path).contains(value) 78 | case StringEndsWith(_, _) => throw new UnsupportedOperationException("Filter `StringEndsWith` is not supported by DynamoDB") 79 | 80 | case And(left, right) => parenthesize(buildCondition(left)) and parenthesize(buildCondition(right)) 81 | case Or(left, right) => parenthesize(buildCondition(left)) or parenthesize(buildCondition(right)) 82 | case Not(f) => parenthesize(buildCondition(f)).negate() 83 | } 84 | 85 | private def coerceAndApply(stringOp: (S, String) => Condition, numOp: (N, Number) => Condition) 86 | (path: String, value: Any): Condition = value match { 87 | case string: String 
=> stringOp(newS(path), string) 88 | case number: Int => numOp(newN(path), number) 89 | case number: Long => numOp(newN(path), number) 90 | case number: Short => numOp(newN(path), number) 91 | case number: Float => numOp(newN(path), number) 92 | case number: Double => numOp(newN(path), number) 93 | case _ => throw new IllegalArgumentException(s"Type of operand given to filter on attribute $path not supported by filter pushdown") 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoTable.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import java.util 24 | 25 | import com.audienceproject.spark.dynamodb.connector.{TableConnector, TableIndexConnector} 26 | import org.apache.spark.sql.SparkSession 27 | import org.apache.spark.sql.connector.catalog._ 28 | import org.apache.spark.sql.connector.read.ScanBuilder 29 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} 30 | import org.apache.spark.sql.types._ 31 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 32 | import org.slf4j.LoggerFactory 33 | 34 | import scala.collection.JavaConverters._ 35 | 36 | class DynamoTable(options: CaseInsensitiveStringMap, 37 | userSchema: Option[StructType] = None) 38 | extends Table 39 | with SupportsRead 40 | with SupportsWrite { 41 | 42 | private val logger = LoggerFactory.getLogger(this.getClass) 43 | 44 | private val dynamoConnector = { 45 | val indexName = Option(options.get("indexname")) 46 | val defaultParallelism = Option(options.get("defaultparallelism")).map(_.toInt).getOrElse(getDefaultParallelism) 47 | val optionsMap = Map(options.asScala.toSeq: _*) 48 | 49 | if (indexName.isDefined) new TableIndexConnector(name(), indexName.get, defaultParallelism, optionsMap) 50 | else new TableConnector(name(), defaultParallelism, optionsMap) 51 | } 52 | 53 | override def name(): String = options.get("tablename") 54 | 55 | override def schema(): StructType = userSchema.getOrElse(inferSchema()) 56 | 57 | override def capabilities(): util.Set[TableCapability] = 58 | Set(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava 59 | 60 | override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { 61 | new DynamoScanBuilder(dynamoConnector, schema()) 62 | } 63 | 64 | override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { 65 | val parameters = 
Map(info.options().asScala.toSeq: _*) 66 | dynamoConnector match { 67 | case tableConnector: TableConnector => new DynamoWriteBuilder(tableConnector, parameters, info.schema()) 68 | case _ => throw new RuntimeException("Unable to write to a GSI, please omit `indexName` option.") 69 | } 70 | } 71 | 72 | private def getDefaultParallelism: Int = 73 | SparkSession.getActiveSession match { 74 | case Some(spark) => spark.sparkContext.defaultParallelism 75 | case None => 76 | logger.warn("Unable to read defaultParallelism from SparkSession." + 77 | " Parallelism will be 1 unless overwritten with option `defaultParallelism`") 78 | 1 79 | } 80 | 81 | private def inferSchema(): StructType = { 82 | val inferenceItems = 83 | if (dynamoConnector.nonEmpty && options.getBoolean("inferSchema",true)) dynamoConnector.scan(0, Seq.empty, Seq.empty).firstPage().getLowLevelResult.getItems.asScala 84 | else Seq.empty 85 | 86 | val typeMapping = inferenceItems.foldLeft(Map[String, DataType]())({ 87 | case (map, item) => map ++ item.asMap().asScala.mapValues(inferType) 88 | }) 89 | val typeSeq = typeMapping.map({ case (name, sparkType) => StructField(name, sparkType) }).toSeq 90 | 91 | if (typeSeq.size > 100) throw new RuntimeException("Schema inference not possible, too many attributes in table.") 92 | 93 | StructType(typeSeq) 94 | } 95 | 96 | private def inferType(value: Any): DataType = value match { 97 | case number: java.math.BigDecimal => 98 | if (number.scale() == 0) { 99 | if (number.precision() < 10) IntegerType 100 | else if (number.precision() < 19) LongType 101 | else DataTypes.createDecimalType(number.precision(), number.scale()) 102 | } 103 | else DoubleType 104 | case list: java.util.ArrayList[_] => 105 | if (list.isEmpty) ArrayType(StringType) 106 | else ArrayType(inferType(list.get(0))) 107 | case set: java.util.Set[_] => 108 | if (set.isEmpty) ArrayType(StringType) 109 | else ArrayType(inferType(set.iterator().next())) 110 | case map: java.util.Map[String, _] => 111 | val mapFields = (for ((fieldName, fieldValue) <- map.asScala) yield { 112 | StructField(fieldName, inferType(fieldValue)) 113 | }).toSeq 114 | StructType(mapFields) 115 | case _: java.lang.Boolean => BooleanType 116 | case _: Array[Byte] => BinaryType 117 | case _ => StringType 118 | } 119 | 120 | } 121 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/datasource/TypeConversion.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2019 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.datasource 22 | 23 | import com.amazonaws.services.dynamodbv2.document.{IncompatibleTypeException, Item} 24 | import org.apache.spark.sql.catalyst.InternalRow 25 | import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} 26 | import org.apache.spark.sql.types._ 27 | import org.apache.spark.unsafe.types.UTF8String 28 | 29 | import scala.collection.JavaConverters._ 30 | 31 | private[dynamodb] object TypeConversion { 32 | 33 | def apply(attrName: String, sparkType: DataType): Item => Any = 34 | 35 | sparkType match { 36 | case BooleanType => nullableGet(_.getBOOL)(attrName) 37 | case StringType => nullableGet(item => attrName => UTF8String.fromString(item.getString(attrName)))(attrName) 38 | case IntegerType => nullableGet(_.getInt)(attrName) 39 | case LongType => nullableGet(_.getLong)(attrName) 40 | case DoubleType => nullableGet(_.getDouble)(attrName) 41 | case FloatType => nullableGet(_.getFloat)(attrName) 42 | case BinaryType => nullableGet(_.getBinary)(attrName) 43 | case DecimalType() => nullableGet(_.getNumber)(attrName) 44 | case ArrayType(innerType, _) => 45 | nullableGet(_.getList)(attrName).andThen(extractArray(convertValue(innerType))) 46 | case MapType(keyType, valueType, _) => 47 | if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.") 48 | nullableGet(_.getRawMap)(attrName).andThen(extractMap(convertValue(valueType))) 49 | case StructType(fields) => 50 | val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) }) 51 | nullableGet(_.getRawMap)(attrName).andThen(extractStruct(nestedConversions)) 52 | case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.") 53 | } 54 | 55 | private val stringConverter = (value: Any) => UTF8String.fromString(value.asInstanceOf[String]) 56 | 57 | private def convertValue(sparkType: DataType): Any => Any = 58 | 59 | sparkType match { 60 | case IntegerType => nullableConvert(_.intValue()) 61 | case LongType => nullableConvert(_.longValue()) 62 | case DoubleType => nullableConvert(_.doubleValue()) 63 | case FloatType => nullableConvert(_.floatValue()) 64 | case DecimalType() => nullableConvert(identity) 65 | case ArrayType(innerType, _) => extractArray(convertValue(innerType)) 66 | case MapType(keyType, valueType, _) => 67 | if (keyType != StringType) throw new IllegalArgumentException(s"Invalid Map key type '${keyType.typeName}'. 
DynamoDB only supports String as Map key type.") 68 | extractMap(convertValue(valueType)) 69 | case StructType(fields) => 70 | val nestedConversions = fields.collect({ case StructField(name, dataType, _, _) => name -> convertValue(dataType) }) 71 | extractStruct(nestedConversions) 72 | case BooleanType => { 73 | case boolean: Boolean => boolean 74 | case _ => null 75 | } 76 | case StringType => { 77 | case string: String => UTF8String.fromString(string) 78 | case _ => null 79 | } 80 | case BinaryType => { 81 | case byteArray: Array[Byte] => byteArray 82 | case _ => null 83 | } 84 | case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.") 85 | } 86 | 87 | private def nullableGet(getter: Item => String => Any)(attrName: String): Item => Any = { 88 | case item if item.hasAttribute(attrName) => try getter(item)(attrName) catch { 89 | case _: NumberFormatException => null 90 | case _: IncompatibleTypeException => null 91 | } 92 | case _ => null 93 | } 94 | 95 | private def nullableConvert(converter: java.math.BigDecimal => Any): Any => Any = { 96 | case item: java.math.BigDecimal => converter(item) 97 | case _ => null 98 | } 99 | 100 | private def extractArray(converter: Any => Any): Any => Any = { 101 | case list: java.util.List[_] => new GenericArrayData(list.asScala.map(converter)) 102 | case set: java.util.Set[_] => new GenericArrayData(set.asScala.map(converter).toSeq) 103 | case _ => null 104 | } 105 | 106 | private def extractMap(converter: Any => Any): Any => Any = { 107 | case map: java.util.Map[_, _] => ArrayBasedMapData(map, stringConverter, converter) 108 | case _ => null 109 | } 110 | 111 | private def extractStruct(conversions: Seq[(String, Any => Any)]): Any => Any = { 112 | case map: java.util.Map[_, _] => InternalRow.fromSeq(conversions.map({ 113 | case (name, conv) => conv(map.get(name)) 114 | })) 115 | case _ => null 116 | } 117 | 118 | } 119 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import com.amazonaws.auth.profile.ProfileCredentialsProvider 24 | import com.amazonaws.auth.{AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain} 25 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration 26 | import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome} 27 | import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBAsync, AmazonDynamoDBAsyncClientBuilder, AmazonDynamoDBClientBuilder} 28 | import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder 29 | import com.amazonaws.services.securitytoken.model.AssumeRoleRequest 30 | import org.apache.spark.sql.sources.Filter 31 | 32 | private[dynamodb] trait DynamoConnector { 33 | 34 | @transient private lazy val properties = sys.props 35 | 36 | def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None, providerClassName: Option[String] = None): DynamoDB = { 37 | val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn, providerClassName) 38 | new DynamoDB(client) 39 | } 40 | 41 | private def getDynamoDBClient(region: Option[String] = None, 42 | roleArn: Option[String] = None, 43 | providerClassName: Option[String]): AmazonDynamoDB = { 44 | val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1")) 45 | val credentials = getCredentials(chosenRegion, roleArn, providerClassName) 46 | 47 | properties.get("aws.dynamodb.endpoint").map(endpoint => { 48 | AmazonDynamoDBClientBuilder.standard() 49 | .withCredentials(credentials) 50 | .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion)) 51 | .build() 52 | }).getOrElse( 53 | AmazonDynamoDBClientBuilder.standard() 54 | .withCredentials(credentials) 55 | .withRegion(chosenRegion) 56 | .build() 57 | ) 58 | } 59 | 60 | def getDynamoDBAsyncClient(region: Option[String] = None, 61 | roleArn: Option[String] = None, 62 | providerClassName: Option[String] = None): AmazonDynamoDBAsync = { 63 | val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1")) 64 | val credentials = getCredentials(chosenRegion, roleArn, providerClassName) 65 | 66 | properties.get("aws.dynamodb.endpoint").map(endpoint => { 67 | AmazonDynamoDBAsyncClientBuilder.standard() 68 | .withCredentials(credentials) 69 | .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion)) 70 | .build() 71 | }).getOrElse( 72 | AmazonDynamoDBAsyncClientBuilder.standard() 73 | .withCredentials(credentials) 74 | .withRegion(chosenRegion) 75 | .build() 76 | ) 77 | } 78 | 79 | /** 80 | * Get credentials from an instantiated object of the class name given 81 | * or a passed in arn 82 | * or from profile 83 | * or return the default credential provider 84 | **/ 85 | private def getCredentials(chosenRegion: String, roleArn: Option[String], providerClassName: Option[String]) = { 86 | providerClassName.map(providerClass => { 87 | Class.forName(providerClass).newInstance.asInstanceOf[AWSCredentialsProvider] 88 | }).orElse(roleArn.map(arn => { 89 | val stsClient = properties.get("aws.sts.endpoint").map(endpoint => { 90 | AWSSecurityTokenServiceClientBuilder 91 | .standard() 92 | .withCredentials(new DefaultAWSCredentialsProviderChain) 93 | .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion)) 94 | .build() 95 | }).getOrElse( 96 | // STS without an endpoint will sign from the 
region, but use the global endpoint 97 | AWSSecurityTokenServiceClientBuilder 98 | .standard() 99 | .withCredentials(new DefaultAWSCredentialsProviderChain) 100 | .withRegion(chosenRegion) 101 | .build() 102 | ) 103 | val assumeRoleResult = stsClient.assumeRole( 104 | new AssumeRoleRequest() 105 | .withRoleSessionName("DynamoDBAssumed") 106 | .withRoleArn(arn) 107 | ) 108 | val stsCredentials = assumeRoleResult.getCredentials 109 | val assumeCreds = new BasicSessionCredentials( 110 | stsCredentials.getAccessKeyId, 111 | stsCredentials.getSecretAccessKey, 112 | stsCredentials.getSessionToken 113 | ) 114 | new AWSStaticCredentialsProvider(assumeCreds) 115 | })).orElse(properties.get("aws.profile").map(new ProfileCredentialsProvider(_))) 116 | .getOrElse(new DefaultAWSCredentialsProviderChain) 117 | } 118 | 119 | val keySchema: KeySchema 120 | 121 | val readLimit: Double 122 | 123 | val itemLimit: Int 124 | 125 | val totalSegments: Int 126 | 127 | val filterPushdownEnabled: Boolean 128 | 129 | def scan(segmentNum: Int, columns: Seq[String], filters: Seq[Filter]): ItemCollection[ScanOutcome] 130 | 131 | def isEmpty: Boolean = itemLimit == 0 132 | 133 | def nonEmpty: Boolean = !isEmpty 134 | 135 | } 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spark+DynamoDB 2 | Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB. 3 | 4 | We published a small article about the project, check it out here: 5 | https://www.audienceproject.com/blog/tech/sparkdynamodb-using-aws-dynamodb-data-source-apache-spark/ 6 | 7 | ## News 8 | 9 | * 2021-01-28: Added option `inferSchema=false` which is useful when writing to a table with many columns 10 | * 2020-07-23: Releasing version 1.1.0 which supports Spark 3.0.0 and Scala 2.12. Future releases will no longer be compatible with Scala 2.11 and Spark 2.x.x. 11 | * 2020-04-28: Releasing version 1.0.4. Includes support for assuming AWS roles through custom STS endpoint (credits @jhulten). 12 | * 2020-04-09: We are releasing version 1.0.3 of the Spark+DynamoDB connector. Added option to `delete` records (thank you @rhelmstetter). Fixes (thank you @juanyunism for #46). 13 | * 2019-11-25: We are releasing version 1.0.0 of the Spark+DynamoDB connector, which is based on the Spark Data Source V2 API. Out-of-the-box throughput calculations, parallelism and partition planning should now be more reliable. We have also pulled out the external dependency on Guava, which was causing a lot of compatibility issues. 14 | 15 | ## Features 16 | 17 | - Distributed, parallel scan with lazy evaluation 18 | - Throughput control by rate limiting on target fraction of provisioned table/index capacity 19 | - Schema discovery to suit your needs 20 | - Dynamic inference 21 | - Static analysis of case class 22 | - Column and filter pushdown 23 | - Global secondary index support 24 | - Write support 25 | 26 | ## Getting The Dependency 27 | 28 | The library is available from [Maven Central](https://mvnrepository.com/artifact/com.audienceproject/spark-dynamodb). Add the dependency in SBT as ```"com.audienceproject" %% "spark-dynamodb" % "latest"``` 29 | 30 | Spark is used in the library as a "provided" dependency, which means Spark has to be installed separately on the container where the application is running, such as is the case on AWS EMR. 
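For reference, a minimal `build.sbt` could look like the sketch below. The version numbers are only illustrative (1.1.0 is the Spark 3.0.0 / Scala 2.12 release mentioned in the news above); align them with your cluster, and keep Spark itself as `provided`.

```scala
// Minimal build.sbt sketch -- versions are illustrative, pick the ones matching your setup.
scalaVersion := "2.12.12"

libraryDependencies ++= Seq(
  "com.audienceproject" %% "spark-dynamodb" % "1.1.0",
  "org.apache.spark"    %% "spark-sql"      % "3.0.0" % "provided"
)
```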
31 | 32 | ## Quick Start Guide 33 | 34 | ### Scala 35 | ```scala 36 | import com.audienceproject.spark.dynamodb.implicits._ 37 | import org.apache.spark.sql.SparkSession 38 | 39 | val spark = SparkSession.builder().getOrCreate() 40 | 41 | // Load a DataFrame from a Dynamo table. Only incurs the cost of a single scan for schema inference. 42 | val dynamoDf = spark.read.dynamodb("SomeTableName") // <-- DataFrame of Row objects with inferred schema. 43 | 44 | // Scan the table for the first 100 items (the order is arbitrary) and print them. 45 | dynamoDf.show(100) 46 | 47 | // write to some other table overwriting existing item with same keys 48 | dynamoDf.write.dynamodb("SomeOtherTable") 49 | 50 | // Case class representing the items in our table. 51 | import com.audienceproject.spark.dynamodb.attribute 52 | case class Vegetable (name: String, color: String, @attribute("weight_kg") weightKg: Double) 53 | 54 | // Load a Dataset[Vegetable]. Notice the @attribute annotation on the case class - we imagine the weight attribute is named with an underscore in DynamoDB. 55 | import org.apache.spark.sql.functions._ 56 | import spark.implicits._ 57 | val vegetableDs = spark.read.dynamodbAs[Vegetable]("VegeTable") 58 | val avgWeightByColor = vegetableDs.agg($"color", avg($"weightKg")) // The column is called 'weightKg' in the Dataset. 59 | ``` 60 | 61 | ### Python 62 | ```python 63 | # Load a DataFrame from a Dynamo table. Only incurs the cost of a single scan for schema inference. 64 | dynamoDf = spark.read.option("tableName", "SomeTableName") \ 65 | .format("dynamodb") \ 66 | .load() # <-- DataFrame of Row objects with inferred schema. 67 | 68 | # Scan the table for the first 100 items (the order is arbitrary) and print them. 69 | dynamoDf.show(100) 70 | 71 | # write to some other table overwriting existing item with same keys 72 | dynamoDf.write.option("tableName", "SomeOtherTable") \ 73 | .format("dynamodb") \ 74 | .save() 75 | ``` 76 | 77 | *Note:* When running from `pyspark` shell, you can add the library as: 78 | ```bash 79 | pyspark --packages com.audienceproject:spark-dynamodb_: 80 | ``` 81 | 82 | ## Parameters 83 | The following parameters can be set as options on the Spark reader and writer object before loading/saving. 84 | - `region` sets the region where the dynamodb table. Default is environment specific. 85 | - `roleArn` sets an IAM role to assume. This allows for access to a DynamoDB in a different account than the Spark cluster. Defaults to the standard role configuration. 86 | 87 | The following parameters can be set as options on the Spark reader object before loading. 88 | 89 | - `readPartitions` number of partitions to split the initial RDD when loading the data into Spark. Defaults to the size of the DynamoDB table divided into chunks of `maxPartitionBytes` 90 | - `maxPartitionBytes` the maximum size of a single input partition. Default 128 MB 91 | - `defaultParallelism` the number of input partitions that can be read from DynamoDB simultaneously. Defaults to `sparkContext.defaultParallelism` 92 | - `targetCapacity` fraction of provisioned read capacity on the table (or index) to consume for reading. Default 1 (i.e. 100% capacity). 93 | - `stronglyConsistentReads` whether or not to use strongly consistent reads. Default false. 94 | - `bytesPerRCU` number of bytes that can be read per second with a single Read Capacity Unit. Default 4000 (4 KB). 
This value is multiplied by two when `stronglyConsistentReads=false` 95 | - `filterPushdown` whether or not to use filter pushdown to DynamoDB on scan requests. Default true. 96 | - `throughput` the desired read throughput to use. It overwrites any calculation used by the package. It is intended to be used with tables that are on-demand. Defaults to 100 for on-demand. 97 | 98 | The following parameters can be set as options on the Spark writer object before saving. 99 | 100 | - `writeBatchSize` number of items to send per call to DynamoDB BatchWriteItem. Default 25. 101 | - `targetCapacity` fraction of provisioned write capacity on the table to consume for writing or updating. Default 1 (i.e. 100% capacity). 102 | - `update` if true items will be written using UpdateItem on keys rather than BatchWriteItem. Default false. 103 | - `throughput` the desired write throughput to use. It overwrites any calculation used by the package. It is intended to be used with tables that are on-demand. Defaults to 100 for on-demand. 104 | - `inferSchema` if false will not automatically infer schema - this is useful when writing to a table with many columns 105 | 106 | ## System Properties 107 | The following Java system properties are available for configuration. 108 | 109 | - `aws.profile` IAM profile to use for default credentials provider. 110 | - `aws.dynamodb.region` region in which to access the AWS APIs. 111 | - `aws.dynamodb.endpoint` endpoint to use for accessing the DynamoDB API. 112 | - `aws.sts.endpoint` endpoint to use for accessing the STS API when assuming the role indicated by the `roleArn` parameter. 113 | 114 | ## Acknowledgements 115 | Usage of parallel scan and rate limiter inspired by work in https://github.com/traviscrawford/spark-dynamodb 116 | -------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/WriteRelationTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 
20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import java.util 24 | 25 | import collection.JavaConverters._ 26 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, KeyType, ProvisionedThroughput} 27 | import com.audienceproject.spark.dynamodb.implicits._ 28 | import org.apache.spark.sql.functions.{lit, when, length => sqlLength} 29 | import org.scalatest.Matchers 30 | 31 | class WriteRelationTest extends AbstractInMemoryTest with Matchers { 32 | 33 | test("Inserting from a local Dataset") { 34 | dynamoDB.createTable(new CreateTableRequest() 35 | .withTableName("InsertTest1") 36 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 37 | .withKeySchema(new KeySchemaElement("name", "HASH")) 38 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 39 | 40 | import spark.implicits._ 41 | 42 | val newItemsDs = spark.createDataset(Seq( 43 | ("lemon", "yellow", 0.1), 44 | ("orange", "orange", 0.2), 45 | ("pomegranate", "red", 0.2) 46 | )) 47 | .withColumnRenamed("_1", "name") 48 | .withColumnRenamed("_2", "color") 49 | .withColumnRenamed("_3", "weight") 50 | newItemsDs.write.dynamodb("InsertTest1") 51 | 52 | val validationDs = spark.read.dynamodb("InsertTest1") 53 | assert(validationDs.count() === 3) 54 | assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _)) 55 | assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _)) 56 | assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.1, 0.2, 0.2) contains _)) 57 | } 58 | 59 | test("Deleting from a local Dataset with a HashKey only") { 60 | val tablename = "DeleteTest1" 61 | dynamoDB.createTable(new CreateTableRequest() 62 | .withTableName(tablename) 63 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 64 | .withKeySchema(new KeySchemaElement("name", "HASH")) 65 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 66 | 67 | import spark.implicits._ 68 | 69 | val newItemsDs = Seq( 70 | ("lemon", "yellow", 0.1), 71 | ("orange", "orange", 0.2), 72 | ("pomegranate", "red", 0.2) 73 | ).toDF("name", "color", "weight") 74 | newItemsDs.write.dynamodb(tablename) 75 | 76 | val toDelete = Seq( 77 | ("lemon", "yellow"), 78 | ("orange", "blue"), 79 | ("doesn't exist", "black") 80 | ).toDF("name", "color") 81 | toDelete.write.option("delete", "true").dynamodb(tablename) 82 | 83 | val validationDs = spark.read.dynamodb(tablename) 84 | validationDs.count() shouldEqual 1 85 | val rec = validationDs.first 86 | rec.getString(rec.fieldIndex("name")) shouldEqual "pomegranate" 87 | rec.getString(rec.fieldIndex("color")) shouldEqual "red" 88 | rec.getDouble(rec.fieldIndex("weight")) shouldEqual 0.2 89 | } 90 | 91 | test("Deleting from a local Dataset with a HashKey and RangeKey") { 92 | val tablename = "DeleteTest2" 93 | 94 | dynamoDB.createTable(new CreateTableRequest() 95 | .withTableName(tablename) 96 | .withAttributeDefinitions(Seq( 97 | new AttributeDefinition("name", "S"), 98 | new AttributeDefinition("weight", "N") 99 | ).asJavaCollection) 100 | .withKeySchema(Seq( 101 | new KeySchemaElement("name", KeyType.HASH), 102 | // also test that non-string key works 103 | new KeySchemaElement("weight", KeyType.RANGE) 104 | ).asJavaCollection) 105 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 106 | 107 | import spark.implicits._ 108 | 109 | val newItemsDs = Seq( 110 | ("lemon", "yellow", 0.1), 111 | ("lemon", 
"blue", 4.0), 112 | ("orange", "orange", 0.2), 113 | ("pomegranate", "red", 0.2) 114 | ).toDF("name", "color", "weight") 115 | newItemsDs.write.dynamodb(tablename) 116 | 117 | val toDelete = Seq( 118 | ("lemon", "yellow", 0.1), 119 | ("orange", "orange", 0.2), 120 | ("pomegranate", "shouldn'tdelete", 0.5) 121 | ).toDF("name", "color", "weight") 122 | toDelete.write.option("delete", "true").dynamodb(tablename) 123 | 124 | val validationDs = spark.read.dynamodb(tablename) 125 | validationDs.show 126 | validationDs.count() shouldEqual 2 127 | validationDs.select("name").as[String].collect should contain theSameElementsAs Seq("lemon", "pomegranate") 128 | validationDs.select("color").as[String].collect should contain theSameElementsAs Seq("blue", "red") 129 | } 130 | 131 | test("Updating from a local Dataset with new and only some previous columns") { 132 | val tablename = "UpdateTest1" 133 | dynamoDB.createTable(new CreateTableRequest() 134 | .withTableName(tablename) 135 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 136 | .withKeySchema(new KeySchemaElement("name", "HASH")) 137 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 138 | 139 | import spark.implicits._ 140 | 141 | val newItemsDs = Seq( 142 | ("lemon", "yellow", 0.1), 143 | ("orange", "orange", 0.2), 144 | ("pomegranate", "red", 0.2) 145 | ).toDF("name", "color", "weight") 146 | newItemsDs.write.dynamodb(tablename) 147 | 148 | newItemsDs 149 | .withColumn("size", sqlLength($"color")) 150 | .drop("color") 151 | .withColumn("weight", $"weight" * 2) 152 | .write.option("update", "true").dynamodb(tablename) 153 | 154 | val validationDs = spark.read.dynamodb(tablename) 155 | validationDs.show 156 | assert(validationDs.count() === 3) 157 | assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _)) 158 | assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _)) 159 | assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.2, 0.4, 0.4) contains _)) 160 | assert(validationDs.select("size").as[Long].collect().forall(Seq(6, 3) contains _)) 161 | } 162 | 163 | test("Updating from a local Dataset with null values") { 164 | val tablename = "UpdateTest2" 165 | dynamoDB.createTable(new CreateTableRequest() 166 | .withTableName(tablename) 167 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 168 | .withKeySchema(new KeySchemaElement("name", "HASH")) 169 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 170 | 171 | import spark.implicits._ 172 | 173 | val newItemsDs = Seq( 174 | ("lemon", "yellow", 0.1), 175 | ("orange", "orange", 0.2), 176 | ("pomegranate", "red", 0.2) 177 | ).toDF("name", "color", "weight") 178 | newItemsDs.write.dynamodb(tablename) 179 | 180 | val alteredDs = newItemsDs 181 | .withColumn("weight", when($"weight" < 0.2, $"weight").otherwise(lit(null))) 182 | alteredDs.show 183 | alteredDs.write.option("update", "true").dynamodb(tablename) 184 | 185 | val validationDs = spark.read.dynamodb(tablename) 186 | validationDs.show 187 | assert(validationDs.count() === 3) 188 | assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _)) 189 | assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _)) 190 | assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.2, 0.1) contains _)) 191 | } 192 | 193 | } 194 | 
-------------------------------------------------------------------------------- /src/test/scala/com/audienceproject/spark/dynamodb/NestedDataStructuresTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb 22 | 23 | import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput} 24 | import com.audienceproject.spark.dynamodb.implicits._ 25 | import com.audienceproject.spark.dynamodb.structs.{TestFruitProperties, TestFruitWithProperties} 26 | import org.apache.spark.sql.Row 27 | import org.apache.spark.sql.functions.struct 28 | import org.apache.spark.sql.types._ 29 | 30 | class NestedDataStructuresTest extends AbstractInMemoryTest { 31 | 32 | test("Insert ArrayType") { 33 | dynamoDB.createTable(new CreateTableRequest() 34 | .withTableName("InsertTestList") 35 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 36 | .withKeySchema(new KeySchemaElement("name", "HASH")) 37 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 38 | 39 | import spark.implicits._ 40 | 41 | val fruitSchema = StructType( 42 | Seq( 43 | StructField("name", StringType, nullable = false), 44 | StructField("color", StringType, nullable = false), 45 | StructField("weight", DoubleType, nullable = false), 46 | StructField("properties", ArrayType(StringType, containsNull = false), nullable = false) 47 | )) 48 | 49 | val rows = spark.sparkContext.parallelize(Seq( 50 | Row("lemon", "yellow", 0.1, Seq("fresh", "2 dkk")), 51 | Row("orange", "orange", 0.2, Seq("too ripe", "1 dkk")), 52 | Row("pomegranate", "red", 0.2, Seq("freshness", "4 dkk")) 53 | )) 54 | 55 | val newItemsDs = spark.createDataFrame(rows, fruitSchema) 56 | 57 | newItemsDs.printSchema() 58 | newItemsDs.show(false) 59 | 60 | newItemsDs.write.dynamodb("InsertTestList") 61 | 62 | println("Writing successful.") 63 | 64 | val validationDs = spark.read.dynamodb("InsertTestList") 65 | assert(validationDs.count() === 3) 66 | assert(validationDs.select($"properties".as[Seq[String]]).collect().forall(Seq( 67 | Seq("fresh", "2 dkk"), 68 | Seq("too ripe", "1 dkk"), 69 | Seq("freshness", "4 dkk") 70 | ) contains _)) 71 | } 72 | 73 | test("Insert MapType") { 74 | dynamoDB.createTable(new CreateTableRequest() 75 | .withTableName("InsertTestMap") 76 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 77 | .withKeySchema(new KeySchemaElement("name", "HASH")) 78 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 79 | 80 | import spark.implicits._ 81 | 82 | 
val fruitSchema = StructType( 83 | Seq( 84 | StructField("name", StringType, nullable = false), 85 | StructField("color", StringType, nullable = false), 86 | StructField("weight", DoubleType, nullable = false), 87 | StructField("properties", MapType(StringType, StringType, valueContainsNull = false)) 88 | )) 89 | 90 | val rows = spark.sparkContext.parallelize(Seq( 91 | Row("lemon", "yellow", 0.1, Map("freshness" -> "fresh", "eco" -> "yes", "price" -> "2 dkk")), 92 | Row("orange", "orange", 0.2, Map("freshness" -> "too ripe", "eco" -> "no", "price" -> "1 dkk")), 93 | Row("pomegranate", "red", 0.2, Map("freshness" -> "green", "eco" -> "yes", "price" -> "4 dkk")) 94 | )) 95 | 96 | val newItemsDs = spark.createDataFrame(rows, fruitSchema) 97 | 98 | newItemsDs.printSchema() 99 | newItemsDs.show(false) 100 | 101 | newItemsDs.write.dynamodb("InsertTestMap") 102 | 103 | println("Writing successful.") 104 | 105 | val validationDs = spark.read.schema(fruitSchema).dynamodb("InsertTestMap") 106 | validationDs.show(false) 107 | assert(validationDs.count() === 3) 108 | assert(validationDs.select($"properties".as[Map[String, String]]).collect().forall(Seq( 109 | Map("freshness" -> "fresh", "eco" -> "yes", "price" -> "2 dkk"), 110 | Map("freshness" -> "too ripe", "eco" -> "no", "price" -> "1 dkk"), 111 | Map("freshness" -> "green", "eco" -> "yes", "price" -> "4 dkk") 112 | ) contains _)) 113 | } 114 | 115 | test("Insert ArrayType with nested MapType") { 116 | dynamoDB.createTable(new CreateTableRequest() 117 | .withTableName("InsertTestListMap") 118 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 119 | .withKeySchema(new KeySchemaElement("name", "HASH")) 120 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 121 | 122 | import spark.implicits._ 123 | 124 | val fruitSchema = StructType( 125 | Seq( 126 | StructField("name", StringType, nullable = false), 127 | StructField("color", StringType, nullable = false), 128 | StructField("weight", DoubleType, nullable = false), 129 | StructField("properties", ArrayType(MapType(StringType, StringType, valueContainsNull = false), containsNull = false), nullable = false) 130 | )) 131 | 132 | val rows = spark.sparkContext.parallelize(Seq( 133 | Row("lemon", "yellow", 0.1, Seq(Map("freshness" -> "fresh", "eco" -> "yes", "price" -> "2 dkk"))), 134 | Row("orange", "orange", 0.2, Seq(Map("freshness" -> "too ripe", "eco" -> "no", "price" -> "1 dkk"))), 135 | Row("pomegranate", "red", 0.2, Seq(Map("freshness" -> "green", "eco" -> "yes", "price" -> "4 dkk"))) 136 | )) 137 | 138 | val newItemsDs = spark.createDataFrame(rows, fruitSchema) 139 | 140 | newItemsDs.printSchema() 141 | newItemsDs.show(false) 142 | 143 | newItemsDs.write.dynamodb("InsertTestListMap") 144 | 145 | println("Writing successful.") 146 | 147 | val validationDs = spark.read.schema(fruitSchema).dynamodb("InsertTestListMap") 148 | validationDs.show(false) 149 | assert(validationDs.count() === 3) 150 | assert(validationDs.select($"properties".as[Seq[Map[String, String]]]).collect().forall(Seq( 151 | Seq(Map("freshness" -> "fresh", "eco" -> "yes", "price" -> "2 dkk")), 152 | Seq(Map("freshness" -> "too ripe", "eco" -> "no", "price" -> "1 dkk")), 153 | Seq(Map("freshness" -> "green", "eco" -> "yes", "price" -> "4 dkk")) 154 | ) contains _)) 155 | } 156 | 157 | test("Insert StructType") { 158 | dynamoDB.createTable(new CreateTableRequest() 159 | .withTableName("InsertTestStruct") 160 | .withAttributeDefinitions(new AttributeDefinition("name", "S")) 161 | .withKeySchema(new 
KeySchemaElement("name", "HASH")) 162 | .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L))) 163 | 164 | import spark.implicits._ 165 | 166 | val fruitSchema = StructType( 167 | Seq( 168 | StructField("name", StringType, nullable = false), 169 | StructField("color", StringType, nullable = false), 170 | StructField("weight", DoubleType, nullable = false), 171 | StructField("freshness", StringType, nullable = false), 172 | StructField("eco", BooleanType, nullable = false), 173 | StructField("price", DoubleType, nullable = false) 174 | )) 175 | 176 | val rows = spark.sparkContext.parallelize(Seq( 177 | Row("lemon", "yellow", 0.1, "fresh", true, 2.0), 178 | Row("pomegranate", "red", 0.2, "green", true, 4.0) 179 | )) 180 | 181 | val newItemsDs = spark.createDataFrame(rows, fruitSchema).select( 182 | $"name", 183 | $"color", 184 | $"weight", 185 | struct($"freshness", $"eco", $"price") as "properties" 186 | ) 187 | 188 | newItemsDs.printSchema() 189 | newItemsDs.show(false) 190 | 191 | newItemsDs.write.dynamodb("InsertTestStruct") 192 | 193 | println("Writing successful.") 194 | 195 | val validationDs = spark.read.dynamodbAs[TestFruitWithProperties]("InsertTestStruct") 196 | assert(validationDs.count() === 2) 197 | assert(validationDs.select($"properties".as[TestFruitProperties]).collect().forall(Seq( 198 | TestFruitProperties("fresh", eco = true, 2.0), 199 | TestFruitProperties("green", eco = true, 4.0) 200 | ) contains _)) 201 | } 202 | 203 | } 204 | -------------------------------------------------------------------------------- /src/main/java/com/audienceproject/shaded/google/common/util/concurrent/Uninterruptibles.java: -------------------------------------------------------------------------------- 1 | package com.audienceproject.shaded.google.common.util.concurrent; 2 | 3 | /* 4 | * Notice: 5 | * This file was modified at AudienceProject ApS by Cosmin Catalin Sanda (cosmin@audienceproject.com) 6 | */ 7 | 8 | /* 9 | * Copyright (C) 2011 The Guava Authors 10 | * 11 | * Licensed under the Apache License, Version 2.0 (the "License"); 12 | * you may not use this file except in compliance with the License. 13 | * You may obtain a copy of the License at 14 | * 15 | * http://www.apache.org/licenses/LICENSE-2.0 16 | * 17 | * Unless required by applicable law or agreed to in writing, software 18 | * distributed under the License is distributed on an "AS IS" BASIS, 19 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | * See the License for the specific language governing permissions and 21 | * limitations under the License. 22 | */ 23 | 24 | import com.audienceproject.shaded.google.common.base.Preconditions; 25 | 26 | import java.util.concurrent.*; 27 | 28 | import static java.util.concurrent.TimeUnit.NANOSECONDS; 29 | 30 | /** 31 | * Utilities for treating interruptible operations as uninterruptible. 32 | * In all cases, if a thread is interrupted during such a call, the call 33 | * continues to block until the result is available or the timeout elapses, 34 | * and only then re-interrupts the thread. 35 | * 36 | * @author Anthony Zana 37 | * @since 10.0 38 | */ 39 | public final class Uninterruptibles { 40 | 41 | // Implementation Note: As of 3-7-11, the logic for each blocking/timeout 42 | // methods is identical, save for method being invoked. 43 | 44 | /** 45 | * Invokes {@code latch.}{@link CountDownLatch#await() await()} 46 | * uninterruptibly. 
47 | */ 48 | public static void awaitUninterruptibly(CountDownLatch latch) { 49 | boolean interrupted = false; 50 | try { 51 | while (true) { 52 | try { 53 | latch.await(); 54 | return; 55 | } catch (InterruptedException e) { 56 | interrupted = true; 57 | } 58 | } 59 | } finally { 60 | if (interrupted) { 61 | Thread.currentThread().interrupt(); 62 | } 63 | } 64 | } 65 | 66 | /** 67 | * Invokes 68 | * {@code latch.}{@link CountDownLatch#await(long, TimeUnit) 69 | * await(timeout, unit)} uninterruptibly. 70 | */ 71 | public static boolean awaitUninterruptibly(CountDownLatch latch, 72 | long timeout, TimeUnit unit) { 73 | boolean interrupted = false; 74 | try { 75 | long remainingNanos = unit.toNanos(timeout); 76 | long end = System.nanoTime() + remainingNanos; 77 | 78 | while (true) { 79 | try { 80 | // CountDownLatch treats negative timeouts just like zero. 81 | return latch.await(remainingNanos, NANOSECONDS); 82 | } catch (InterruptedException e) { 83 | interrupted = true; 84 | remainingNanos = end - System.nanoTime(); 85 | } 86 | } 87 | } finally { 88 | if (interrupted) { 89 | Thread.currentThread().interrupt(); 90 | } 91 | } 92 | } 93 | 94 | /** 95 | * Invokes {@code toJoin.}{@link Thread#join() join()} uninterruptibly. 96 | */ 97 | public static void joinUninterruptibly(Thread toJoin) { 98 | boolean interrupted = false; 99 | try { 100 | while (true) { 101 | try { 102 | toJoin.join(); 103 | return; 104 | } catch (InterruptedException e) { 105 | interrupted = true; 106 | } 107 | } 108 | } finally { 109 | if (interrupted) { 110 | Thread.currentThread().interrupt(); 111 | } 112 | } 113 | } 114 | 115 | /** 116 | * Invokes {@code future.}{@link Future#get() get()} uninterruptibly. 117 | * To get uninterruptibility and remove checked exceptions. 118 | * 119 | *

If instead, you wish to treat {@link InterruptedException} uniformly 120 | * with other exceptions. 121 | * 122 | * @throws ExecutionException if the computation threw an exception 123 | * @throws CancellationException if the computation was cancelled 124 | */ 125 | public static V getUninterruptibly(Future future) 126 | throws ExecutionException { 127 | boolean interrupted = false; 128 | try { 129 | while (true) { 130 | try { 131 | return future.get(); 132 | } catch (InterruptedException e) { 133 | interrupted = true; 134 | } 135 | } 136 | } finally { 137 | if (interrupted) { 138 | Thread.currentThread().interrupt(); 139 | } 140 | } 141 | } 142 | 143 | /** 144 | * Invokes 145 | * {@code future.}{@link Future#get(long, TimeUnit) get(timeout, unit)} 146 | * uninterruptibly. 147 | * 148 | *

If instead, you wish to treat {@link InterruptedException} uniformly 149 | * with other exceptions. 150 | * 151 | * @throws ExecutionException if the computation threw an exception 152 | * @throws CancellationException if the computation was cancelled 153 | * @throws TimeoutException if the wait timed out 154 | */ 155 | public static V getUninterruptibly( 156 | Future future, long timeout, TimeUnit unit) 157 | throws ExecutionException, TimeoutException { 158 | boolean interrupted = false; 159 | try { 160 | long remainingNanos = unit.toNanos(timeout); 161 | long end = System.nanoTime() + remainingNanos; 162 | 163 | while (true) { 164 | try { 165 | // Future treats negative timeouts just like zero. 166 | return future.get(remainingNanos, NANOSECONDS); 167 | } catch (InterruptedException e) { 168 | interrupted = true; 169 | remainingNanos = end - System.nanoTime(); 170 | } 171 | } 172 | } finally { 173 | if (interrupted) { 174 | Thread.currentThread().interrupt(); 175 | } 176 | } 177 | } 178 | 179 | /** 180 | * Invokes 181 | * {@code unit.}{@link TimeUnit#timedJoin(Thread, long) 182 | * timedJoin(toJoin, timeout)} uninterruptibly. 183 | */ 184 | public static void joinUninterruptibly(Thread toJoin, 185 | long timeout, TimeUnit unit) { 186 | Preconditions.checkNotNull(toJoin); 187 | boolean interrupted = false; 188 | try { 189 | long remainingNanos = unit.toNanos(timeout); 190 | long end = System.nanoTime() + remainingNanos; 191 | while (true) { 192 | try { 193 | // TimeUnit.timedJoin() treats negative timeouts just like zero. 194 | NANOSECONDS.timedJoin(toJoin, remainingNanos); 195 | return; 196 | } catch (InterruptedException e) { 197 | interrupted = true; 198 | remainingNanos = end - System.nanoTime(); 199 | } 200 | } 201 | } finally { 202 | if (interrupted) { 203 | Thread.currentThread().interrupt(); 204 | } 205 | } 206 | } 207 | 208 | /** 209 | * Invokes {@code queue.}{@link BlockingQueue#take() take()} uninterruptibly. 210 | */ 211 | public static E takeUninterruptibly(BlockingQueue queue) { 212 | boolean interrupted = false; 213 | try { 214 | while (true) { 215 | try { 216 | return queue.take(); 217 | } catch (InterruptedException e) { 218 | interrupted = true; 219 | } 220 | } 221 | } finally { 222 | if (interrupted) { 223 | Thread.currentThread().interrupt(); 224 | } 225 | } 226 | } 227 | 228 | /** 229 | * Invokes {@code queue.}{@link BlockingQueue#put(Object) put(element)} 230 | * uninterruptibly. 231 | * 232 | * @throws ClassCastException if the class of the specified element prevents 233 | * it from being added to the given queue 234 | * @throws IllegalArgumentException if some property of the specified element 235 | * prevents it from being added to the given queue 236 | */ 237 | public static void putUninterruptibly(BlockingQueue queue, E element) { 238 | boolean interrupted = false; 239 | try { 240 | while (true) { 241 | try { 242 | queue.put(element); 243 | return; 244 | } catch (InterruptedException e) { 245 | interrupted = true; 246 | } 247 | } 248 | } finally { 249 | if (interrupted) { 250 | Thread.currentThread().interrupt(); 251 | } 252 | } 253 | } 254 | 255 | // TODO(user): Support Sleeper somehow (wrapper or interface method)? 256 | /** 257 | * Invokes {@code unit.}{@link TimeUnit#sleep(long) sleep(sleepFor)} 258 | * uninterruptibly. 
259 | */ 260 | public static void sleepUninterruptibly(long sleepFor, TimeUnit unit) { 261 | boolean interrupted = false; 262 | try { 263 | long remainingNanos = unit.toNanos(sleepFor); 264 | long end = System.nanoTime() + remainingNanos; 265 | while (true) { 266 | try { 267 | // TimeUnit.sleep() treats negative timeouts just like zero. 268 | NANOSECONDS.sleep(remainingNanos); 269 | return; 270 | } catch (InterruptedException e) { 271 | interrupted = true; 272 | remainingNanos = end - System.nanoTime(); 273 | } 274 | } 275 | } finally { 276 | if (interrupted) { 277 | Thread.currentThread().interrupt(); 278 | } 279 | } 280 | } 281 | 282 | // TODO(user): Add support for waitUninterruptibly. 283 | 284 | private Uninterruptibles() {} 285 | } 286 | 287 | -------------------------------------------------------------------------------- /src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, 13 | * software distributed under the License is distributed on an 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | * KIND, either express or implied. See the License for the 16 | * specific language governing permissions and limitations 17 | * under the License. 18 | * 19 | * Copyright © 2018 AudienceProject. All rights reserved. 20 | */ 21 | package com.audienceproject.spark.dynamodb.connector 22 | 23 | import com.amazonaws.services.dynamodbv2.document._ 24 | import com.amazonaws.services.dynamodbv2.document.spec.{BatchWriteItemSpec, ScanSpec, UpdateItemSpec} 25 | import com.amazonaws.services.dynamodbv2.model.ReturnConsumedCapacity 26 | import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder 27 | import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter 28 | import com.audienceproject.spark.dynamodb.catalyst.JavaConverter 29 | import org.apache.spark.sql.catalyst.InternalRow 30 | import org.apache.spark.sql.sources.Filter 31 | 32 | import scala.annotation.tailrec 33 | import scala.collection.JavaConverters._ 34 | 35 | private[dynamodb] class TableConnector(tableName: String, parallelism: Int, parameters: Map[String, String]) 36 | extends DynamoConnector with DynamoWritable with Serializable { 37 | 38 | private val consistentRead = parameters.getOrElse("stronglyconsistentreads", "false").toBoolean 39 | private val filterPushdown = parameters.getOrElse("filterpushdown", "true").toBoolean 40 | private val region = parameters.get("region") 41 | private val roleArn = parameters.get("rolearn") 42 | private val providerClassName = parameters.get("providerclassname") 43 | 44 | override val filterPushdownEnabled: Boolean = filterPushdown 45 | 46 | override val (keySchema, readLimit, writeLimit, itemLimit, totalSegments) = { 47 | val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName) 48 | val desc = table.describe() 49 | 50 | // Key schema. 
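        // Editorial worked example of the sizing and rate-limit calculation below (hedged; the
        // concrete figures are illustrative assumptions, not values taken from this project):
        // a 1 GB table holding 10M items (average item ~100 bytes), read with parallelism 8,
        // provisioned read capacity 400, and the defaults bytesPerRCU = 4000,
        // maxPartitionBytes = 128000000, targetCapacity = 1, eventually consistent reads
        // (readFactor = 2) works out to:
        //   numPartitions = 7 size-based segments, rounded up to a multiple of 8   => 8
        //   readLimit     = 400 * 1 / 8                                            => 50 RCU per task
        //   itemLimit     = ((4000 / 100.0) * 50).toInt * 2                        => 4000 items per scan page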
51 | val keySchema = KeySchema.fromDescription(desc.getKeySchema.asScala) 52 | 53 | // User parameters. 54 | val bytesPerRCU = parameters.getOrElse("bytesperrcu", "4000").toInt 55 | val maxPartitionBytes = parameters.getOrElse("maxpartitionbytes", "128000000").toInt 56 | val targetCapacity = parameters.getOrElse("targetcapacity", "1").toDouble 57 | val readFactor = if (consistentRead) 1 else 2 58 | 59 | // Table parameters. 60 | val tableSize = desc.getTableSizeBytes 61 | val itemCount = desc.getItemCount 62 | 63 | // Partitioning calculation. 64 | val numPartitions = parameters.get("readpartitions").map(_.toInt).getOrElse({ 65 | val sizeBased = (tableSize / maxPartitionBytes).toInt max 1 66 | val remainder = sizeBased % parallelism 67 | if (remainder > 0) sizeBased + (parallelism - remainder) 68 | else sizeBased 69 | }) 70 | 71 | // Provisioned or on-demand throughput. 72 | val readThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getReadCapacityUnits) 73 | .filter(_ > 0).map(_.longValue().toString) 74 | .getOrElse("100")).toLong 75 | val writeThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getWriteCapacityUnits) 76 | .filter(_ > 0).map(_.longValue().toString) 77 | .getOrElse("100")).toLong 78 | 79 | // Rate limit calculation. 80 | val avgItemSize = tableSize.toDouble / itemCount 81 | val readCapacity = readThroughput * targetCapacity 82 | val writeCapacity = writeThroughput * targetCapacity 83 | 84 | val readLimit = readCapacity / parallelism 85 | val itemLimit = ((bytesPerRCU / avgItemSize * readLimit).toInt * readFactor) max 1 86 | 87 | val writeLimit = writeCapacity / parallelism 88 | 89 | (keySchema, readLimit, writeLimit, itemLimit, numPartitions) 90 | } 91 | 92 | override def scan(segmentNum: Int, columns: Seq[String], filters: Seq[Filter]): ItemCollection[ScanOutcome] = { 93 | val scanSpec = new ScanSpec() 94 | .withSegment(segmentNum) 95 | .withTotalSegments(totalSegments) 96 | .withMaxPageSize(itemLimit) 97 | .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL) 98 | .withConsistentRead(consistentRead) 99 | 100 | if (columns.nonEmpty) { 101 | val xspec = new ExpressionSpecBuilder().addProjections(columns: _*) 102 | 103 | if (filters.nonEmpty && filterPushdown) { 104 | xspec.withCondition(FilterPushdown(filters)) 105 | } 106 | 107 | scanSpec.withExpressionSpec(xspec.buildForScan()) 108 | } 109 | 110 | getDynamoDB(region, roleArn, providerClassName).getTable(tableName).scan(scanSpec) 111 | } 112 | 113 | override def putItems(columnSchema: ColumnSchema, items: Seq[InternalRow]) 114 | (client: DynamoDB, rateLimiter: RateLimiter): Unit = { 115 | // For each batch. 116 | val batchWriteItemSpec = new BatchWriteItemSpec().withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL) 117 | batchWriteItemSpec.withTableWriteItems(new TableWriteItems(tableName).withItemsToPut( 118 | // Map the items. 119 | items.map(row => { 120 | val item = new Item() 121 | 122 | // Map primary key. 
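                // Editorial note (reflecting the usage below): columnSchema.keys() yields an Either.
                // Left carries a single (attributeName, rowIndex, dataType) triple for a hash-key-only
                // table; Right carries the hash-key and range-key triples for a composite key, so the
                // match below writes one or two primary-key attributes on the Item accordingly.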
123 | columnSchema.keys() match { 124 | case Left((hashKey, hashKeyIndex, hashKeyType)) => 125 | item.withPrimaryKey(hashKey, JavaConverter.convertRowValue(row, hashKeyIndex, hashKeyType)) 126 | case Right(((hashKey, hashKeyIndex, hashKeyType), (rangeKey, rangeKeyIndex, rangeKeyType))) => 127 | val hashKeyValue = JavaConverter.convertRowValue(row, hashKeyIndex, hashKeyType) 128 | val rangeKeyValue = JavaConverter.convertRowValue(row, rangeKeyIndex, rangeKeyType) 129 | item.withPrimaryKey(hashKey, hashKeyValue, rangeKey, rangeKeyValue) 130 | } 131 | 132 | // Map remaining columns. 133 | columnSchema.attributes().foreach({ 134 | case (name, index, dataType) if !row.isNullAt(index) => 135 | item.`with`(name, JavaConverter.convertRowValue(row, index, dataType)) 136 | case _ => 137 | }) 138 | 139 | item 140 | }): _* 141 | )) 142 | 143 | val response = client.batchWriteItem(batchWriteItemSpec) 144 | handleBatchWriteResponse(client, rateLimiter)(response) 145 | } 146 | 147 | override def updateItem(columnSchema: ColumnSchema, row: InternalRow) 148 | (client: DynamoDB, rateLimiter: RateLimiter): Unit = { 149 | val updateItemSpec = new UpdateItemSpec().withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL) 150 | 151 | // Map primary key. 152 | columnSchema.keys() match { 153 | case Left((hashKey, hashKeyIndex, hashKeyType)) => 154 | updateItemSpec.withPrimaryKey(hashKey, JavaConverter.convertRowValue(row, hashKeyIndex, hashKeyType)) 155 | case Right(((hashKey, hashKeyIndex, hashKeyType), (rangeKey, rangeKeyIndex, rangeKeyType))) => 156 | val hashKeyValue = JavaConverter.convertRowValue(row, hashKeyIndex, hashKeyType) 157 | val rangeKeyValue = JavaConverter.convertRowValue(row, rangeKeyIndex, rangeKeyType) 158 | updateItemSpec.withPrimaryKey(hashKey, hashKeyValue, rangeKey, rangeKeyValue) 159 | } 160 | 161 | // Map remaining columns. 162 | val attributeUpdates = columnSchema.attributes().collect({ 163 | case (name, index, dataType) if !row.isNullAt(index) => 164 | new AttributeUpdate(name).put(JavaConverter.convertRowValue(row, index, dataType)) 165 | }) 166 | 167 | updateItemSpec.withAttributeUpdate(attributeUpdates: _*) 168 | 169 | // Update item and rate limit on write capacity. 170 | val response = client.getTable(tableName).updateItem(updateItemSpec) 171 | Option(response.getUpdateItemResult.getConsumedCapacity) 172 | .foreach(cap => rateLimiter.acquire(cap.getCapacityUnits.toInt max 1)) 173 | } 174 | 175 | override def deleteItems(columnSchema: ColumnSchema, items: Seq[InternalRow]) 176 | (client: DynamoDB, rateLimiter: RateLimiter): Unit = { 177 | // For each batch. 178 | val batchWriteItemSpec = new BatchWriteItemSpec().withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL) 179 | 180 | val tableWriteItems = new TableWriteItems(tableName) 181 | val tableWriteItemsWithItems: TableWriteItems = 182 | // Check if hash key only or also range key. 
183 | columnSchema.keys() match { 184 | case Left((hashKey, hashKeyIndex, hashKeyType)) => 185 | val hashKeys = items.map(row => 186 | JavaConverter.convertRowValue(row, hashKeyIndex, hashKeyType).asInstanceOf[AnyRef]) 187 | tableWriteItems.withHashOnlyKeysToDelete(hashKey, hashKeys: _*) 188 | case Right(((hashKey, hashKeyIndex, hashKeyType), (rangeKey, rangeKeyIndex, rangeKeyType))) => 189 | val alternatingHashAndRangeKeys = items.flatMap { row => 190 | val hashKeyValue = JavaConverter.convertRowValue(row, hashKeyIndex, hashKeyType) 191 | val rangeKeyValue = JavaConverter.convertRowValue(row, rangeKeyIndex, rangeKeyType) 192 | Seq(hashKeyValue.asInstanceOf[AnyRef], rangeKeyValue.asInstanceOf[AnyRef]) 193 | } 194 | tableWriteItems.withHashAndRangeKeysToDelete(hashKey, rangeKey, alternatingHashAndRangeKeys: _*) 195 | } 196 | 197 | batchWriteItemSpec.withTableWriteItems(tableWriteItemsWithItems) 198 | 199 | val response = client.batchWriteItem(batchWriteItemSpec) 200 | handleBatchWriteResponse(client, rateLimiter)(response) 201 | } 202 | 203 | @tailrec 204 | private def handleBatchWriteResponse(client: DynamoDB, rateLimiter: RateLimiter) 205 | (response: BatchWriteItemOutcome): Unit = { 206 | // Rate limit on write capacity. 207 | if (response.getBatchWriteItemResult.getConsumedCapacity != null) { 208 | response.getBatchWriteItemResult.getConsumedCapacity.asScala.map(cap => { 209 | cap.getTableName -> cap.getCapacityUnits.toInt 210 | }).toMap.get(tableName).foreach(units => rateLimiter.acquire(units max 1)) 211 | } 212 | // Retry unprocessed items. 213 | if (response.getUnprocessedItems != null && !response.getUnprocessedItems.isEmpty) { 214 | val newResponse = client.batchWriteItemUnprocessed(response.getUnprocessedItems) 215 | handleBatchWriteResponse(client, rateLimiter)(newResponse) 216 | } 217 | } 218 | 219 | } 220 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/main/java/com/audienceproject/shaded/google/common/base/Preconditions.java: -------------------------------------------------------------------------------- 1 | package com.audienceproject.shaded.google.common.base; 2 | 3 | /* 4 | * Notice: 5 | * This file was modified at AudienceProject ApS by Cosmin Catalin Sanda (cosmin@audienceproject.com) 6 | */ 7 | 8 | /* 9 | * Copyright (C) 2007 The Guava Authors 10 | * 11 | * Licensed under the Apache License, Version 2.0 (the "License"); 12 | * you may not use this file except in compliance with the License. 13 | * You may obtain a copy of the License at 14 | * 15 | * http://www.apache.org/licenses/LICENSE-2.0 16 | * 17 | * Unless required by applicable law or agreed to in writing, software 18 | * distributed under the License is distributed on an "AS IS" BASIS, 19 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | * See the License for the specific language governing permissions and 21 | * limitations under the License. 22 | */ 23 | 24 | import java.util.NoSuchElementException; 25 | 26 | import javax.annotation.Nullable; 27 | 28 | /** 29 | * Simple static methods to be called at the start of your own methods to verify 30 | * correct arguments and state. This allows constructs such as 31 | *

 32 |  *     if (count <= 0) {
 33 |  *       throw new IllegalArgumentException("must be positive: " + count);
 34 |  *     }
35 | * 36 | * to be replaced with the more compact 37 | *
 38 |  *     checkArgument(count > 0, "must be positive: %s", count);
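 *     // editorial, hedged example: with count == 0 this compact form throws
 *     // IllegalArgumentException("must be positive: 0"), the %s being filled in
 *     // via String.valueOf as described for checkArgument below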
39 | * 40 | * Note that the sense of the expression is inverted; with {@code Preconditions} 41 | * you declare what you expect to be true, just as you do with an 42 | * 43 | * {@code assert} or a JUnit {@code assertTrue} call. 44 | * 45 | *

Warning: only the {@code "%s"} specifier is recognized as a 46 | * placeholder in these messages, not the full range of {@link 47 | * String#format(String, Object[])} specifiers. 48 | * 49 | *

Take care not to confuse precondition checking with other similar types 50 | * of checks! Precondition exceptions -- including those provided here, but also 51 | * {@link IndexOutOfBoundsException}, {@link NoSuchElementException}, {@link 52 | * UnsupportedOperationException} and others -- are used to signal that the 53 | * calling method has made an error. This tells the caller that it should 54 | * not have invoked the method when it did, with the arguments it did, or 55 | * perhaps ever. Postcondition or other invariant failures should not throw 56 | * these types of exceptions. 57 | * 58 | *

See the Guava User Guide on 60 | * using {@code Preconditions}. 61 | * 62 | * @author Kevin Bourrillion 63 | * @since 2.0 (imported from Google Collections Library) 64 | */ 65 | public final class Preconditions { 66 | private Preconditions() {} 67 | 68 | /** 69 | * Ensures the truth of an expression involving one or more parameters to the 70 | * calling method. 71 | * 72 | * @param expression a boolean expression 73 | * @throws IllegalArgumentException if {@code expression} is false 74 | */ 75 | public static void checkArgument(boolean expression) { 76 | if (!expression) { 77 | throw new IllegalArgumentException(); 78 | } 79 | } 80 | 81 | /** 82 | * Ensures the truth of an expression involving one or more parameters to the 83 | * calling method. 84 | * 85 | * @param expression a boolean expression 86 | * @param errorMessage the exception message to use if the check fails; will 87 | * be converted to a string using {@link String#valueOf(Object)} 88 | * @throws IllegalArgumentException if {@code expression} is false 89 | */ 90 | public static void checkArgument( 91 | boolean expression, @Nullable Object errorMessage) { 92 | if (!expression) { 93 | throw new IllegalArgumentException(String.valueOf(errorMessage)); 94 | } 95 | } 96 | 97 | /** 98 | * Ensures the truth of an expression involving one or more parameters to the 99 | * calling method. 100 | * 101 | * @param expression a boolean expression 102 | * @param errorMessageTemplate a template for the exception message should the 103 | * check fail. The message is formed by replacing each {@code %s} 104 | * placeholder in the template with an argument. These are matched by 105 | * position - the first {@code %s} gets {@code errorMessageArgs[0]}, etc. 106 | * Unmatched arguments will be appended to the formatted message in square 107 | * braces. Unmatched placeholders will be left as-is. 108 | * @param errorMessageArgs the arguments to be substituted into the message 109 | * template. Arguments are converted to strings using 110 | * {@link String#valueOf(Object)}. 111 | * @throws IllegalArgumentException if {@code expression} is false 112 | * @throws NullPointerException if the check fails and either {@code 113 | * errorMessageTemplate} or {@code errorMessageArgs} is null (don't let 114 | * this happen) 115 | */ 116 | public static void checkArgument(boolean expression, 117 | @Nullable String errorMessageTemplate, 118 | @Nullable Object... errorMessageArgs) { 119 | if (!expression) { 120 | throw new IllegalArgumentException( 121 | format(errorMessageTemplate, errorMessageArgs)); 122 | } 123 | } 124 | 125 | /** 126 | * Ensures the truth of an expression involving the state of the calling 127 | * instance, but not involving any parameters to the calling method. 128 | * 129 | * @param expression a boolean expression 130 | * @throws IllegalStateException if {@code expression} is false 131 | */ 132 | public static void checkState(boolean expression) { 133 | if (!expression) { 134 | throw new IllegalStateException(); 135 | } 136 | } 137 | 138 | /** 139 | * Ensures the truth of an expression involving the state of the calling 140 | * instance, but not involving any parameters to the calling method. 
141 | * 142 | * @param expression a boolean expression 143 | * @param errorMessage the exception message to use if the check fails; will 144 | * be converted to a string using {@link String#valueOf(Object)} 145 | * @throws IllegalStateException if {@code expression} is false 146 | */ 147 | public static void checkState( 148 | boolean expression, @Nullable Object errorMessage) { 149 | if (!expression) { 150 | throw new IllegalStateException(String.valueOf(errorMessage)); 151 | } 152 | } 153 | 154 | /** 155 | * Ensures the truth of an expression involving the state of the calling 156 | * instance, but not involving any parameters to the calling method. 157 | * 158 | * @param expression a boolean expression 159 | * @param errorMessageTemplate a template for the exception message should the 160 | * check fail. The message is formed by replacing each {@code %s} 161 | * placeholder in the template with an argument. These are matched by 162 | * position - the first {@code %s} gets {@code errorMessageArgs[0]}, etc. 163 | * Unmatched arguments will be appended to the formatted message in square 164 | * braces. Unmatched placeholders will be left as-is. 165 | * @param errorMessageArgs the arguments to be substituted into the message 166 | * template. Arguments are converted to strings using 167 | * {@link String#valueOf(Object)}. 168 | * @throws IllegalStateException if {@code expression} is false 169 | * @throws NullPointerException if the check fails and either {@code 170 | * errorMessageTemplate} or {@code errorMessageArgs} is null (don't let 171 | * this happen) 172 | */ 173 | public static void checkState(boolean expression, 174 | @Nullable String errorMessageTemplate, 175 | @Nullable Object... errorMessageArgs) { 176 | if (!expression) { 177 | throw new IllegalStateException( 178 | format(errorMessageTemplate, errorMessageArgs)); 179 | } 180 | } 181 | 182 | /** 183 | * Ensures that an object reference passed as a parameter to the calling 184 | * method is not null. 185 | * 186 | * @param reference an object reference 187 | * @return the non-null reference that was validated 188 | * @throws NullPointerException if {@code reference} is null 189 | */ 190 | public static T checkNotNull(T reference) { 191 | if (reference == null) { 192 | throw new NullPointerException(); 193 | } 194 | return reference; 195 | } 196 | 197 | /** 198 | * Ensures that an object reference passed as a parameter to the calling 199 | * method is not null. 200 | * 201 | * @param reference an object reference 202 | * @param errorMessage the exception message to use if the check fails; will 203 | * be converted to a string using {@link String#valueOf(Object)} 204 | * @return the non-null reference that was validated 205 | * @throws NullPointerException if {@code reference} is null 206 | */ 207 | public static T checkNotNull(T reference, @Nullable Object errorMessage) { 208 | if (reference == null) { 209 | throw new NullPointerException(String.valueOf(errorMessage)); 210 | } 211 | return reference; 212 | } 213 | 214 | /** 215 | * Ensures that an object reference passed as a parameter to the calling 216 | * method is not null. 217 | * 218 | * @param reference an object reference 219 | * @param errorMessageTemplate a template for the exception message should the 220 | * check fail. The message is formed by replacing each {@code %s} 221 | * placeholder in the template with an argument. These are matched by 222 | * position - the first {@code %s} gets {@code errorMessageArgs[0]}, etc. 
223 | * Unmatched arguments will be appended to the formatted message in square 224 | * braces. Unmatched placeholders will be left as-is. 225 | * @param errorMessageArgs the arguments to be substituted into the message 226 | * template. Arguments are converted to strings using 227 | * {@link String#valueOf(Object)}. 228 | * @return the non-null reference that was validated 229 | * @throws NullPointerException if {@code reference} is null 230 | */ 231 | public static T checkNotNull(T reference, 232 | @Nullable String errorMessageTemplate, 233 | @Nullable Object... errorMessageArgs) { 234 | if (reference == null) { 235 | // If either of these parameters is null, the right thing happens anyway 236 | throw new NullPointerException( 237 | format(errorMessageTemplate, errorMessageArgs)); 238 | } 239 | return reference; 240 | } 241 | 242 | /* 243 | * All recent hotspots (as of 2009) *really* like to have the natural code 244 | * 245 | * if (guardExpression) { 246 | * throw new BadException(messageExpression); 247 | * } 248 | * 249 | * refactored so that messageExpression is moved to a separate 250 | * String-returning method. 251 | * 252 | * if (guardExpression) { 253 | * throw new BadException(badMsg(...)); 254 | * } 255 | * 256 | * The alternative natural refactorings into void or Exception-returning 257 | * methods are much slower. This is a big deal - we're talking factors of 258 | * 2-8 in microbenchmarks, not just 10-20%. (This is a hotspot optimizer 259 | * bug, which should be fixed, but that's a separate, big project). 260 | * 261 | * The coding pattern above is heavily used in java.util, e.g. in ArrayList. 262 | * There is a RangeCheckMicroBenchmark in the JDK that was used to test this. 263 | * 264 | * But the methods in this class want to throw different exceptions, 265 | * depending on the args, so it appears that this pattern is not directly 266 | * applicable. But we can use the ridiculous, devious trick of throwing an 267 | * exception in the middle of the construction of another exception. 268 | * Hotspot is fine with that. 269 | */ 270 | 271 | /** 272 | * Ensures that {@code index} specifies a valid element in an array, 273 | * list or string of size {@code size}. An element index may range from zero, 274 | * inclusive, to {@code size}, exclusive. 275 | * 276 | * @param index a user-supplied index identifying an element of an array, list 277 | * or string 278 | * @param size the size of that array, list or string 279 | * @return the value of {@code index} 280 | * @throws IndexOutOfBoundsException if {@code index} is negative or is not 281 | * less than {@code size} 282 | * @throws IllegalArgumentException if {@code size} is negative 283 | */ 284 | public static int checkElementIndex(int index, int size) { 285 | return checkElementIndex(index, size, "index"); 286 | } 287 | 288 | /** 289 | * Ensures that {@code index} specifies a valid element in an array, 290 | * list or string of size {@code size}. An element index may range from zero, 291 | * inclusive, to {@code size}, exclusive. 
292 | * 293 | * @param index a user-supplied index identifying an element of an array, list 294 | * or string 295 | * @param size the size of that array, list or string 296 | * @param desc the text to use to describe this index in an error message 297 | * @return the value of {@code index} 298 | * @throws IndexOutOfBoundsException if {@code index} is negative or is not 299 | * less than {@code size} 300 | * @throws IllegalArgumentException if {@code size} is negative 301 | */ 302 | public static int checkElementIndex( 303 | int index, int size, @Nullable String desc) { 304 | // Carefully optimized for execution by hotspot (explanatory comment above) 305 | if (index < 0 || index >= size) { 306 | throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); 307 | } 308 | return index; 309 | } 310 | 311 | private static String badElementIndex(int index, int size, String desc) { 312 | if (index < 0) { 313 | return format("%s (%s) must not be negative", desc, index); 314 | } else if (size < 0) { 315 | throw new IllegalArgumentException("negative size: " + size); 316 | } else { // index >= size 317 | return format("%s (%s) must be less than size (%s)", desc, index, size); 318 | } 319 | } 320 | 321 | /** 322 | * Ensures that {@code index} specifies a valid position in an array, 323 | * list or string of size {@code size}. A position index may range from zero 324 | * to {@code size}, inclusive. 325 | * 326 | * @param index a user-supplied index identifying a position in an array, list 327 | * or string 328 | * @param size the size of that array, list or string 329 | * @return the value of {@code index} 330 | * @throws IndexOutOfBoundsException if {@code index} is negative or is 331 | * greater than {@code size} 332 | * @throws IllegalArgumentException if {@code size} is negative 333 | */ 334 | public static int checkPositionIndex(int index, int size) { 335 | return checkPositionIndex(index, size, "index"); 336 | } 337 | 338 | /** 339 | * Ensures that {@code index} specifies a valid position in an array, 340 | * list or string of size {@code size}. A position index may range from zero 341 | * to {@code size}, inclusive. 342 | * 343 | * @param index a user-supplied index identifying a position in an array, list 344 | * or string 345 | * @param size the size of that array, list or string 346 | * @param desc the text to use to describe this index in an error message 347 | * @return the value of {@code index} 348 | * @throws IndexOutOfBoundsException if {@code index} is negative or is 349 | * greater than {@code size} 350 | * @throws IllegalArgumentException if {@code size} is negative 351 | */ 352 | public static int checkPositionIndex( 353 | int index, int size, @Nullable String desc) { 354 | // Carefully optimized for execution by hotspot (explanatory comment above) 355 | if (index < 0 || index > size) { 356 | throw new IndexOutOfBoundsException(badPositionIndex(index, size, desc)); 357 | } 358 | return index; 359 | } 360 | 361 | private static String badPositionIndex(int index, int size, String desc) { 362 | if (index < 0) { 363 | return format("%s (%s) must not be negative", desc, index); 364 | } else if (size < 0) { 365 | throw new IllegalArgumentException("negative size: " + size); 366 | } else { // index > size 367 | return format("%s (%s) must not be greater than size (%s)", 368 | desc, index, size); 369 | } 370 | } 371 | 372 | /** 373 | * Ensures that {@code start} and {@code end} specify a valid positions 374 | * in an array, list or string of size {@code size}, and are in order. 
A 375 | * position index may range from zero to {@code size}, inclusive. 376 | * 377 | * @param start a user-supplied index identifying a starting position in an 378 | * array, list or string 379 | * @param end a user-supplied index identifying a ending position in an array, 380 | * list or string 381 | * @param size the size of that array, list or string 382 | * @throws IndexOutOfBoundsException if either index is negative or is 383 | * greater than {@code size}, or if {@code end} is less than {@code start} 384 | * @throws IllegalArgumentException if {@code size} is negative 385 | */ 386 | public static void checkPositionIndexes(int start, int end, int size) { 387 | // Carefully optimized for execution by hotspot (explanatory comment above) 388 | if (start < 0 || end < start || end > size) { 389 | throw new IndexOutOfBoundsException(badPositionIndexes(start, end, size)); 390 | } 391 | } 392 | 393 | private static String badPositionIndexes(int start, int end, int size) { 394 | if (start < 0 || start > size) { 395 | return badPositionIndex(start, size, "start index"); 396 | } 397 | if (end < 0 || end > size) { 398 | return badPositionIndex(end, size, "end index"); 399 | } 400 | // end < start 401 | return format("end index (%s) must not be less than start index (%s)", 402 | end, start); 403 | } 404 | 405 | /** 406 | * Substitutes each {@code %s} in {@code template} with an argument. These 407 | * are matched by position - the first {@code %s} gets {@code args[0]}, etc. 408 | * If there are more arguments than placeholders, the unmatched arguments will 409 | * be appended to the end of the formatted message in square braces. 410 | * 411 | * @param template a non-null string containing 0 or more {@code %s} 412 | * placeholders. 413 | * @param args the arguments to be substituted into the message 414 | * template. Arguments are converted to strings using 415 | * {@link String#valueOf(Object)}. Arguments can be null. 416 | */ 417 | static String format(String template, 418 | @Nullable Object... 
args) { 419 | template = String.valueOf(template); // null -> "null" 420 | 421 | // start substituting the arguments into the '%s' placeholders 422 | StringBuilder builder = new StringBuilder( 423 | template.length() + 16 * args.length); 424 | int templateStart = 0; 425 | int i = 0; 426 | while (i < args.length) { 427 | int placeholderStart = template.indexOf("%s", templateStart); 428 | if (placeholderStart == -1) { 429 | break; 430 | } 431 | builder.append(template.substring(templateStart, placeholderStart)); 432 | builder.append(args[i++]); 433 | templateStart = placeholderStart + 2; 434 | } 435 | builder.append(template.substring(templateStart)); 436 | 437 | // if we run out of placeholders, append the extra args in square braces 438 | if (i < args.length) { 439 | builder.append(" ["); 440 | builder.append(args[i++]); 441 | while (i < args.length) { 442 | builder.append(", "); 443 | builder.append(args[i++]); 444 | } 445 | builder.append(']'); 446 | } 447 | 448 | return builder.toString(); 449 | } 450 | } 451 | 452 | -------------------------------------------------------------------------------- /src/main/java/com/audienceproject/shaded/google/common/util/concurrent/RateLimiter.java: -------------------------------------------------------------------------------- 1 | package com.audienceproject.shaded.google.common.util.concurrent; 2 | 3 | /* 4 | * Notice: 5 | * This file was modified at AudienceProject ApS by Cosmin Catalin Sanda (cosmin@audienceproject.com) 6 | */ 7 | 8 | /* 9 | * Copyright (C) 2012 The Guava Authors 10 | * 11 | * Licensed under the Apache License, Version 2.0 (the "License"); 12 | * you may not use this file except in compliance with the License. 13 | * You may obtain a copy of the License at 14 | * 15 | * http://www.apache.org/licenses/LICENSE-2.0 16 | * 17 | * Unless required by applicable law or agreed to in writing, software 18 | * distributed under the License is distributed on an "AS IS" BASIS, 19 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | * See the License for the specific language governing permissions and 21 | * limitations under the License. 22 | */ 23 | 24 | import com.audienceproject.shaded.google.common.base.Preconditions; 25 | import com.audienceproject.shaded.google.common.base.Ticker; 26 | 27 | import javax.annotation.concurrent.ThreadSafe; 28 | import java.util.concurrent.TimeUnit; 29 | 30 | /** 31 | * A rate limiter. Conceptually, a rate limiter distributes permits at a 32 | * configurable rate. Each {@link #acquire()} blocks if necessary until a permit is 33 | * available, and then takes it. Once acquired, permits need not be released. 34 | * 35 | *

Rate limiters are often used to restrict the rate at which some 36 | * physical or logical resource is accessed. This is in contrast to {@link 37 | * java.util.concurrent.Semaphore} which restricts the number of concurrent 38 | * accesses instead of the rate (note though that concurrency and rate are closely related, 39 | * e.g. see Little's Law). 40 | * 41 | *

A {@code RateLimiter} is defined primarily by the rate at which permits 42 | * are issued. Absent additional configuration, permits will be distributed at a 43 | * fixed rate, defined in terms of permits per second. Permits will be distributed 44 | * smoothly, with the delay between individual permits being adjusted to ensure 45 | * that the configured rate is maintained. 46 | * 47 | *

It is possible to configure a {@code RateLimiter} to have a warmup 48 | * period during which time the permits issued each second steadily increases until 49 | * it hits the stable rate. 50 | * 51 | *

As an example, imagine that we have a list of tasks to execute, but we don't want to 52 | * submit more than 2 per second: 53 | *

  {@code
 54 |   *  final RateLimiter rateLimiter = RateLimiter.create(2.0); // rate is "2 permits per second"
 55 |  *  void submitTasks(List<Runnable> tasks, Executor executor) {
 56 |  *    for (Runnable task : tasks) {
 57 |  *      rateLimiter.acquire(); // may wait
 58 |  *      executor.execute(task);
 59 |  *    }
 60 |   *  }
 61 |   *}
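 * (Editorial note: at the 2.0 permits-per-second rate created above, once the limiter is
 * saturated the {@code acquire()} calls in the loop are paced roughly 500 ms apart, so at
 * most about two tasks per second reach the executor.)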
62 | * 63 | *

As another example, imagine that we produce a stream of data, and we want to cap it 64 | * at 5kb per second. This could be accomplished by requiring a permit per byte, and specifying 65 | * a rate of 5000 permits per second: 66 | *

  {@code
 67 |   *  final RateLimiter rateLimiter = RateLimiter.create(5000.0); // rate = 5000 permits per second
 68 |  *  void submitPacket(byte[] packet) {
 69 |  *    rateLimiter.acquire(packet.length);
 70 |  *    networkService.send(packet);
 71 |  *  }
 72 |   *}
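 * (Editorial note: with one permit requested per byte at 5000 permits per second, a
 * 1000-byte packet consumes 1000 permits, so sustained throughput settles at roughly
 * 5 kB per second, matching the cap described above.)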
73 | * 74 | *

It is important to note that the number of permits requested never 75 | * affect the throttling of the request itself (an invocation to {@code acquire(1)} 76 | * and an invocation to {@code acquire(1000)} will result in exactly the same throttling, if any), 77 | * but it affects the throttling of the next request. I.e., if an expensive task 78 | * arrives at an idle RateLimiter, it will be granted immediately, but it is the next 79 | * request that will experience extra throttling, thus paying for the cost of the expensive 80 | * task. 81 | * 82 | *

Note: {@code RateLimiter} does not provide fairness guarantees. 83 | * 84 | * @author Dimitris Andreou 85 | * @since 13.0 86 | */ 87 | // TODO(user): switch to nano precision. A natural unit of cost is "bytes", and a micro precision 88 | // would mean a maximum rate of "1MB/s", which might be small in some cases. 89 | @ThreadSafe 90 | public abstract class RateLimiter { 91 | /* 92 | * How is the RateLimiter designed, and why? 93 | * 94 | * The primary feature of a RateLimiter is its "stable rate", the maximum rate that 95 | * is should allow at normal conditions. This is enforced by "throttling" incoming 96 | * requests as needed, i.e. compute, for an incoming request, the appropriate throttle time, 97 | * and make the calling thread wait as much. 98 | * 99 | * The simplest way to maintain a rate of QPS is to keep the timestamp of the last 100 | * granted request, and ensure that (1/QPS) seconds have elapsed since then. For example, 101 | * for a rate of QPS=5 (5 tokens per second), if we ensure that a request isn't granted 102 | * earlier than 200ms after the the last one, then we achieve the intended rate. 103 | * If a request comes and the last request was granted only 100ms ago, then we wait for 104 | * another 100ms. At this rate, serving 15 fresh permits (i.e. for an acquire(15) request) 105 | * naturally takes 3 seconds. 106 | * 107 | * It is important to realize that such a RateLimiter has a very superficial memory 108 | * of the past: it only remembers the last request. What if the RateLimiter was unused for 109 | * a long period of time, then a request arrived and was immediately granted? 110 | * This RateLimiter would immediately forget about that past underutilization. This may 111 | * result in either underutilization or overflow, depending on the real world consequences 112 | * of not using the expected rate. 113 | * 114 | * Past underutilization could mean that excess resources are available. Then, the RateLimiter 115 | * should speed up for a while, to take advantage of these resources. This is important 116 | * when the rate is applied to networking (limiting bandwidth), where past underutilization 117 | * typically translates to "almost empty buffers", which can be filled immediately. 118 | * 119 | * On the other hand, past underutilization could mean that "the server responsible for 120 | * handling the request has become less ready for future requests", i.e. its caches become 121 | * stale, and requests become more likely to trigger expensive operations (a more extreme 122 | * case of this example is when a server has just booted, and it is mostly busy with getting 123 | * itself up to speed). 124 | * 125 | * To deal with such scenarios, we add an extra dimension, that of "past underutilization", 126 | * modeled by "storedPermits" variable. This variable is zero when there is no 127 | * underutilization, and it can grow up to maxStoredPermits, for sufficiently large 128 | * underutilization. So, the requested permits, by an invocation acquire(permits), 129 | * are served from: 130 | * - stored permits (if available) 131 | * - fresh permits (for any remaining permits) 132 | * 133 | * How this works is best explained with an example: 134 | * 135 | * For a RateLimiter that produces 1 token per second, every second 136 | * that goes by with the RateLimiter being unused, we increase storedPermits by 1. 
137 | * Say we leave the RateLimiter unused for 10 seconds (i.e., we expected a request at time 138 | * X, but we are at time X + 10 seconds before a request actually arrives; this is 139 | * also related to the point made in the last paragraph), thus storedPermits 140 | * becomes 10.0 (assuming maxStoredPermits >= 10.0). At that point, a request of acquire(3) 141 | * arrives. We serve this request out of storedPermits, and reduce that to 7.0 (how this is 142 | * translated to throttling time is discussed later). Immediately after, assume that an 143 | * acquire(10) request arrives. We serve the request partly from storedPermits, 144 | * using all the remaining 7.0 permits, and the remaining 3.0, we serve them by fresh permits 145 | * produced by the rate limiter. 146 | * 147 | * We already know how much time it takes to serve 3 fresh permits: if the rate is 148 | * "1 token per second", then this will take 3 seconds. But what does it mean to serve 7 149 | * stored permits? As explained above, there is no unique answer. If we are primarily 150 | * interested in dealing with underutilization, then we want stored permits to be given out 151 | * /faster/ than fresh ones, because underutilization = free resources for the taking. 152 | * If we are primarily interested in dealing with overflow, then stored permits could 153 | * be given out /slower/ than fresh ones. Thus, we require a (different in each case) 154 | * function that translates storedPermits to throttling time. 155 | * 156 | * This role is played by storedPermitsToWaitTime(double storedPermits, double permitsToTake). 157 | * The underlying model is a continuous function mapping storedPermits 158 | * (from 0.0 to maxStoredPermits) onto the 1/rate (i.e. intervals) that is effective at the given 159 | * storedPermits. "storedPermits" essentially measures unused time; we spend unused time 160 | * buying/storing permits. Rate is "permits / time", thus "1 / rate = time / permits". 161 | * Thus, "1/rate" (time / permits) times "permits" gives time, i.e., integrals on this 162 | * function (which is what storedPermitsToWaitTime() computes) correspond to minimum intervals 163 | * between subsequent requests, for the specified number of requested permits. 164 | * 165 | * Here is an example of storedPermitsToWaitTime: 166 | * If storedPermits == 10.0, and we want 3 permits, we take them from storedPermits, 167 | * reducing them to 7.0, and compute the throttling for these as a call to 168 | * storedPermitsToWaitTime(storedPermits = 10.0, permitsToTake = 3.0), which will 169 | * evaluate the integral of the function from 7.0 to 10.0. 170 | * 171 | * Using integrals guarantees that the effect of a single acquire(3) is equivalent 172 | * to { acquire(1); acquire(1); acquire(1); }, or { acquire(2); acquire(1); }, etc, 173 | * since the integral of the function in [7.0, 10.0] is equivalent to the sum of the 174 | * integrals of [7.0, 8.0], [8.0, 9.0], [9.0, 10.0] (and so on), no matter 175 | * what the function is. This guarantees that we handle correctly requests of varying weight 176 | * (permits), /no matter/ what the actual function is - so we can tweak the latter freely. 177 | * (The only requirement, obviously, is that we can compute its integrals). 178 | * 179 | * Note well that if, for this function, we chose a horizontal line, at height of exactly 180 | * (1/QPS), then the effect of the function is non-existent: we serve storedPermits at 181 | * exactly the same cost as fresh ones (1/QPS is the cost for each). We use this trick later.
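 * A quick worked example of that trick: suppose the rate is 1 permit per second and
 * storedPermits == 10.0. With a horizontal line at height (1/QPS) == 1 second, spending
 * 3 stored permits costs the integral of that constant over a width of 3, i.e. 3 seconds --
 * exactly what 3 fresh permits would cost. With a horizontal line at height zero (the
 * Bursty implementation at the bottom of this file), the same 3 stored permits cost nothing,
 * which is what lets a long-idle limiter grant a burst immediately while spacing only the
 * subsequent requests at the stable rate.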
182 | * 183 | * If we pick a function that goes /below/ that horizontal line, it means that we reduce 184 | * the area of the function, thus time. Thus, the RateLimiter becomes /faster/ after a 185 | * period of underutilization. If, on the other hand, we pick a function that 186 | * goes /above/ that horizontal line, then it means that the area (time) is increased, 187 | * thus storedPermits are more costly than fresh permits, thus the RateLimiter becomes 188 | * /slower/ after a period of underutilization. 189 | * 190 | * Last, but not least: consider a RateLimiter with rate of 1 permit per second, currently 191 | * completely unused, and an expensive acquire(100) request comes. It would be nonsensical 192 | * to just wait for 100 seconds, and /then/ start the actual task. Why wait without doing 193 | * anything? A much better approach is to /allow/ the request right away (as if it was an 194 | * acquire(1) request instead), and postpone /subsequent/ requests as needed. In this version, 195 | * we allow starting the task immediately, and postpone by 100 seconds future requests, 196 | * thus we allow for work to get done in the meantime instead of waiting idly. 197 | * 198 | * This has important consequences: it means that the RateLimiter doesn't remember the time 199 | * of the _last_ request, but it remembers the (expected) time of the _next_ request. This 200 | * also enables us to tell immediately (see tryAcquire(timeout)) whether a particular 201 | * timeout is enough to get us to the point of the next scheduling time, since we always 202 | * maintain that. And what we mean by "an unused RateLimiter" is also defined by that 203 | * notion: when we observe that the "expected arrival time of the next request" is actually 204 | * in the past, then the difference (now - past) is the amount of time that the RateLimiter 205 | * was formally unused, and it is that amount of time which we translate to storedPermits. 206 | * (We increase storedPermits with the amount of permits that would have been produced 207 | * in that idle time). So, if rate == 1 permit per second, and arrivals come exactly 208 | * one second after the previous, then storedPermits is _never_ increased -- we would only 209 | * increase it for arrivals _later_ than the expected one second. 210 | */ 211 | 212 | /** 213 | * Creates a {@code RateLimiter} with the specified stable throughput, given as 214 | * "permits per second" (commonly referred to as QPS, queries per second). 215 | * 216 | *
The returned {@code RateLimiter} ensures that on average no more than {@code 217 | * permitsPerSecond} are issued during any given second, with sustained requests 218 | * being smoothly spread over each second. When the incoming request rate exceeds 219 | * {@code permitsPerSecond} the rate limiter will release one permit every {@code 220 | * (1.0 / permitsPerSecond)} seconds. When the rate limiter is unused, 221 | * bursts of up to {@code permitsPerSecond} permits will be allowed, with subsequent 222 | * requests being smoothly limited at the stable rate of {@code permitsPerSecond}. 223 | * 224 | * @param permitsPerSecond the rate of the returned {@code RateLimiter}, measured in 225 | * how many permits become available per second. 226 | */ 227 | public static RateLimiter create(double permitsPerSecond) { 228 | return create(SleepingTicker.SYSTEM_TICKER, permitsPerSecond); 229 | } 230 | 231 | static RateLimiter create(SleepingTicker ticker, double permitsPerSecond) { 232 | RateLimiter rateLimiter = new Bursty(ticker); 233 | rateLimiter.setRate(permitsPerSecond); 234 | return rateLimiter; 235 | } 236 | 237 | /** 238 | * Creates a {@code RateLimiter} with the specified stable throughput, given as 239 | * "permits per second" (commonly referred to as QPS, queries per second), and a 240 | * warmup period, during which the {@code RateLimiter} smoothly ramps up its rate, 241 | * until it reaches its maximum rate at the end of the period (as long as there are enough 242 | * requests to saturate it). Similarly, if the {@code RateLimiter} is left unused for 243 | * a duration of {@code warmupPeriod}, it will gradually return to its "cold" state, 244 | * i.e. it will go through the same warming up process as when it was first created. 245 | * 246 | *
The returned {@code RateLimiter} is intended for cases where the resource that actually 247 | * fulfils the requests (e.g., a remote server) needs "warmup" time, rather than 248 | * being immediately accessed at the stable (maximum) rate. 249 | * 250 | *
The returned {@code RateLimiter} starts in a "cold" state (i.e. the warmup period 251 | * will follow), and if it is left unused for long enough, it will return to that state. 252 | * 253 | * @param permitsPerSecond the rate of the returned {@code RateLimiter}, measured in 254 | * how many permits become available per second 255 | * @param warmupPeriod the duration of the period where the {@code RateLimiter} ramps up its 256 | * rate, before reaching its stable (maximum) rate 257 | * @param unit the time unit of the warmupPeriod argument 258 | */ 259 | // TODO(user): add a burst size of 1-second-worth of permits, as in the metronome? 260 | public static RateLimiter create(double permitsPerSecond, long warmupPeriod, TimeUnit unit) { 261 | return create(SleepingTicker.SYSTEM_TICKER, permitsPerSecond, warmupPeriod, unit); 262 | } 263 | 264 | static RateLimiter create( 265 | SleepingTicker ticker, double permitsPerSecond, long warmupPeriod, TimeUnit timeUnit) { 266 | RateLimiter rateLimiter = new WarmingUp(ticker, warmupPeriod, timeUnit); 267 | rateLimiter.setRate(permitsPerSecond); 268 | return rateLimiter; 269 | } 270 | 271 | static RateLimiter createBursty( 272 | SleepingTicker ticker, double permitsPerSecond, int maxBurstSize) { 273 | Bursty rateLimiter = new Bursty(ticker); 274 | rateLimiter.setRate(permitsPerSecond); 275 | rateLimiter.maxPermits = maxBurstSize; 276 | return rateLimiter; 277 | } 278 | 279 | /** 280 | * The underlying timer; used both to measure elapsed time and sleep as necessary. A separate 281 | * object to facilitate testing. 282 | */ 283 | private final SleepingTicker ticker; 284 | 285 | /** 286 | * The timestamp when the RateLimiter was created; used to avoid possible overflow/time-wrapping 287 | * errors. 288 | */ 289 | private final long offsetNanos; 290 | 291 | /** 292 | * The currently stored permits. 293 | */ 294 | double storedPermits; 295 | 296 | /** 297 | * The maximum number of stored permits. 298 | */ 299 | double maxPermits; 300 | 301 | /** 302 | * The interval between two unit requests, at our stable rate. E.g., a stable rate of 5 permits 303 | * per second has a stable interval of 200ms. 304 | */ 305 | volatile double stableIntervalMicros; 306 | 307 | private final Object mutex = new Object(); 308 | 309 | /** 310 | * The time when the next request (no matter its size) will be granted. After granting a request, 311 | * this is pushed further in the future. Large requests push this further than small requests. 312 | */ 313 | private long nextFreeTicketMicros = 0L; // could be either in the past or future 314 | 315 | private RateLimiter(SleepingTicker ticker) { 316 | this.ticker = ticker; 317 | this.offsetNanos = ticker.read(); 318 | } 319 | 320 | /** 321 | * Updates the stable rate of this {@code RateLimiter}, that is, the 322 | * {@code permitsPerSecond} argument provided in the factory method that 323 | * constructed the {@code RateLimiter}. Currently throttled threads will not 324 | * be awakened as a result of this invocation, thus they do not observe the new rate; 325 | * only subsequent requests will. 326 | * 327 | *
Note though that, since each request repays (by waiting, if necessary) the cost 328 | * of the previous request, this means that the very next request 329 | * after an invocation to {@code setRate} will not be affected by the new rate; 330 | * it will pay the cost of the previous request, which is in terms of the previous rate. 331 | * 332 | *
The behavior of the {@code RateLimiter} is not modified in any other way, 333 | * e.g. if the {@code RateLimiter} was configured with a warmup period of 20 seconds, 334 | * it still has a warmup period of 20 seconds after this method invocation. 335 | * 336 | * @param permitsPerSecond the new stable rate of this {@code RateLimiter}. 337 | */ 338 | public final void setRate(double permitsPerSecond) { 339 | Preconditions.checkArgument(permitsPerSecond > 0.0 340 | && !Double.isNaN(permitsPerSecond), "rate must be positive"); 341 | synchronized (mutex) { 342 | resync(readSafeMicros()); 343 | double stableIntervalMicros = TimeUnit.SECONDS.toMicros(1L) / permitsPerSecond; 344 | this.stableIntervalMicros = stableIntervalMicros; 345 | doSetRate(permitsPerSecond, stableIntervalMicros); 346 | } 347 | } 348 | 349 | abstract void doSetRate(double permitsPerSecond, double stableIntervalMicros); 350 | 351 | /** 352 | * Returns the stable rate (as {@code permits per seconds}) with which this 353 | * {@code RateLimiter} is configured with. The initial value of this is the same as 354 | * the {@code permitsPerSecond} argument passed in the factory method that produced 355 | * this {@code RateLimiter}, and it is only updated after invocations 356 | * to {@linkplain #setRate}. 357 | */ 358 | public final double getRate() { 359 | return TimeUnit.SECONDS.toMicros(1L) / stableIntervalMicros; 360 | } 361 | 362 | /** 363 | * Acquires a permit from this {@code RateLimiter}, blocking until the request can be granted. 364 | * 365 | *
This method is equivalent to {@code acquire(1)}. 366 | */ 367 | public void acquire() { 368 | acquire(1); 369 | } 370 | 371 | /** 372 | * Acquires the given number of permits from this {@code RateLimiter}, blocking until the 373 | * request be granted. 374 | * 375 | * @param permits the number of permits to acquire 376 | */ 377 | public void acquire(int permits) { 378 | checkPermits(permits); 379 | long microsToWait; 380 | synchronized (mutex) { 381 | microsToWait = reserveNextTicket(permits, readSafeMicros()); 382 | } 383 | ticker.sleepMicrosUninterruptibly(microsToWait); 384 | } 385 | 386 | /** 387 | * Acquires a permit from this {@code RateLimiter} if it can be obtained 388 | * without exceeding the specified {@code timeout}, or returns {@code false} 389 | * immediately (without waiting) if the permit would not have been granted 390 | * before the timeout expired. 391 | * 392 | *
This method is equivalent to {@code tryAcquire(1, timeout, unit)}. 393 | * 394 | * @param timeout the maximum time to wait for the permit 395 | * @param unit the time unit of the timeout argument 396 | * @return {@code true} if the permit was acquired, {@code false} otherwise 397 | */ 398 | public boolean tryAcquire(long timeout, TimeUnit unit) { 399 | return tryAcquire(1, timeout, unit); 400 | } 401 | 402 | /** 403 | * Acquires permits from this {@link RateLimiter} if it can be acquired immediately without delay. 404 | * 405 | *
406 | * This method is equivalent to {@code tryAcquire(permits, 0, anyUnit)}. 407 | * 408 | * @param permits the number of permits to acquire 409 | * @return {@code true} if the permits were acquired, {@code false} otherwise 410 | * @since 14.0 411 | */ 412 | public boolean tryAcquire(int permits) { 413 | return tryAcquire(permits, 0, TimeUnit.MICROSECONDS); 414 | } 415 | 416 | /** 417 | * Acquires a permit from this {@link RateLimiter} if it can be acquired immediately without 418 | * delay. 419 | * 420 | *
421 | * This method is equivalent to {@code tryAcquire(1)}. 422 | * 423 | * @return {@code true} if the permit was acquired, {@code false} otherwise 424 | * @since 14.0 425 | */ 426 | public boolean tryAcquire() { 427 | return tryAcquire(1, 0, TimeUnit.MICROSECONDS); 428 | } 429 | 430 | /** 431 | * Acquires the given number of permits from this {@code RateLimiter} if it can be obtained 432 | * without exceeding the specified {@code timeout}, or returns {@code false} 433 | * immediately (without waiting) if the permits would not have been granted 434 | * before the timeout expired. 435 | * 436 | * @param permits the number of permits to acquire 437 | * @param timeout the maximum time to wait for the permits 438 | * @param unit the time unit of the timeout argument 439 | * @return {@code true} if the permits were acquired, {@code false} otherwise 440 | */ 441 | public boolean tryAcquire(int permits, long timeout, TimeUnit unit) { 442 | long timeoutMicros = unit.toMicros(timeout); 443 | checkPermits(permits); 444 | long microsToWait; 445 | synchronized (mutex) { 446 | long nowMicros = readSafeMicros(); 447 | if (nextFreeTicketMicros > nowMicros + timeoutMicros) { 448 | return false; 449 | } else { 450 | microsToWait = reserveNextTicket(permits, nowMicros); 451 | } 452 | } 453 | ticker.sleepMicrosUninterruptibly(microsToWait); 454 | return true; 455 | } 456 | 457 | private static void checkPermits(int permits) { 458 | Preconditions.checkArgument(permits > 0, "Requested permits must be positive"); 459 | } 460 | 461 | /** 462 | * Reserves next ticket and returns the wait time that the caller must wait for. 463 | */ 464 | private long reserveNextTicket(double requiredPermits, long nowMicros) { 465 | resync(nowMicros); 466 | long microsToNextFreeTicket = nextFreeTicketMicros - nowMicros; 467 | double storedPermitsToSpend = Math.min(requiredPermits, this.storedPermits); 468 | double freshPermits = requiredPermits - storedPermitsToSpend; 469 | 470 | long waitMicros = storedPermitsToWaitTime(this.storedPermits, storedPermitsToSpend) 471 | + (long) (freshPermits * stableIntervalMicros); 472 | 473 | this.nextFreeTicketMicros = nextFreeTicketMicros + waitMicros; 474 | this.storedPermits -= storedPermitsToSpend; 475 | return microsToNextFreeTicket; 476 | } 477 | 478 | /** 479 | * Translates a specified portion of our currently stored permits which we want to 480 | * spend/acquire, into a throttling time. Conceptually, this evaluates the integral 481 | * of the underlying function we use, for the range of 482 | * [(storedPermits - permitsToTake), storedPermits]. 483 | * 484 | * This always holds: {@code 0 <= permitsToTake <= storedPermits} 485 | */ 486 | abstract long storedPermitsToWaitTime(double storedPermits, double permitsToTake); 487 | 488 | private void resync(long nowMicros) { 489 | // if nextFreeTicket is in the past, resync to now 490 | if (nowMicros > nextFreeTicketMicros) { 491 | storedPermits = Math.min(maxPermits, 492 | storedPermits + (nowMicros - nextFreeTicketMicros) / stableIntervalMicros); 493 | nextFreeTicketMicros = nowMicros; 494 | } 495 | } 496 | 497 | private long readSafeMicros() { 498 | return TimeUnit.NANOSECONDS.toMicros(ticker.read() - offsetNanos); 499 | } 500 | 501 | @Override 502 | public String toString() { 503 | return String.format("RateLimiter[stableRate=%3.1fqps]", 1000000.0 / stableIntervalMicros); 504 | } 505 | 506 | /** 507 | * This implements the following function: 508 | * 509 | * ^ throttling 510 | * | 511 | * 3*stable + / 512 | * interval | /. 
513 | * (cold) | / . 514 | * | / . <-- "warmup period" is the area of the trapezoid between 515 | * 2*stable + / . halfPermits and maxPermits 516 | * interval | / . 517 | * | / . 518 | * | / . 519 | * stable +----------/ WARM . } 520 | * interval | . UP . } <-- this rectangle (from 0 to maxPermits, and 521 | * | . PERIOD. } height == stableInterval) defines the cooldown period, 522 | * | . . } and we want cooldownPeriod == warmupPeriod 523 | * |---------------------------------> storedPermits 524 | * (halfPermits) (maxPermits) 525 | * 526 | * Before going into the details of this particular function, let's keep in mind the basics: 527 | * 1) The state of the RateLimiter (storedPermits) is a vertical line in this figure. 528 | * 2) When the RateLimiter is not used, this goes right (up to maxPermits) 529 | * 3) When the RateLimiter is used, this goes left (down to zero), since if we have storedPermits, 530 | * we serve from those first 531 | * 4) When _unused_, we go right at the same speed (rate)! I.e., if our rate is 532 | * 2 permits per second, and 3 unused seconds pass, we will always save 6 permits 533 | * (no matter what our initial position was), up to maxPermits. 534 | * If we invert the rate, we get the "stableInterval" (interval between two requests 535 | * in a perfectly spaced out sequence of requests of the given rate). Thus, if you 536 | * want to see "how much time it will take to go from X storedPermits to X+K storedPermits?", 537 | * the answer is always stableInterval * K. In the same example, for 2 permits per second, 538 | * stableInterval is 500ms. Thus to go from X storedPermits to X+6 storedPermits, we 539 | * require 6 * 500ms = 3 seconds. 540 | * 541 | * In short, the time it takes to move to the right (save K permits) is equal to the 542 | * rectangle of width == K and height == stableInterval. 543 | * 4) When _used_, the time it takes, as explained in the introductory class note, is 544 | * equal to the integral of our function, between X permits and X-K permits, assuming 545 | * we want to spend K saved permits. 546 | * 547 | * In summary, the time it takes to move to the left (spend K permits), is equal to the 548 | * area of the function of width == K. 549 | * 550 | * Let's dive into this function now: 551 | * 552 | * When we have storedPermits <= halfPermits (the left portion of the function), then 553 | * we spend them at the exact same rate that 554 | * fresh permits would be generated anyway (that rate is 1/stableInterval). We size 555 | * this area to be equal to _half_ the specified warmup period. Why we need this? 556 | * And why half? We'll explain shortly below (after explaining the second part). 557 | * 558 | * Stored permits that are beyond halfPermits, are mapped to an ascending line, that goes 559 | * from stableInterval to 3 * stableInterval. The average height for that part is 560 | * 2 * stableInterval, and is sized appropriately to have an area _equal_ to the 561 | * specified warmup period. Thus, by point (4) above, it takes "warmupPeriod" amount of time 562 | * to go from maxPermits to halfPermits. 563 | * 564 | * BUT, by point (3) above, it only takes "warmupPeriod / 2" amount of time to return back 565 | * to maxPermits, from halfPermits! (Because the trapezoid has double the area of the rectangle 566 | * of height stableInterval and equivalent width). 
We decided that the "cooldown period" 567 | * time should be equivalent to "warmup period", thus a fully saturated RateLimiter 568 | * (with zero stored permits, serving only fresh ones) can go to a fully unsaturated 569 | * (with storedPermits == maxPermits) in the same amount of time it takes for a fully 570 | * unsaturated RateLimiter to return to the stableInterval -- which happens in halfPermits, 571 | * since beyond that point, we use a horizontal line of "stableInterval" height, simulating 572 | * the regular rate. 573 | * 574 | * Thus, we have figured all dimensions of this shape, to give all the desired 575 | * properties: 576 | * - the width is warmupPeriod / stableInterval, to make cooldownPeriod == warmupPeriod 577 | * - the slope starts at the middle, and goes from stableInterval to 3*stableInterval so 578 | * to have halfPermits being spend in double the usual time (half the rate), while their 579 | * respective rate is steadily ramping up 580 | */ 581 | private static class WarmingUp extends RateLimiter { 582 | 583 | final long warmupPeriodMicros; 584 | /** 585 | * The slope of the line from the stable interval (when permits == 0), to the cold interval 586 | * (when permits == maxPermits) 587 | */ 588 | private double slope; 589 | private double halfPermits; 590 | 591 | WarmingUp(SleepingTicker ticker, long warmupPeriod, TimeUnit timeUnit) { 592 | super(ticker); 593 | this.warmupPeriodMicros = timeUnit.toMicros(warmupPeriod); 594 | } 595 | 596 | @Override 597 | void doSetRate(double permitsPerSecond, double stableIntervalMicros) { 598 | double oldMaxPermits = maxPermits; 599 | maxPermits = warmupPeriodMicros / stableIntervalMicros; 600 | halfPermits = maxPermits / 2.0; 601 | // Stable interval is x, cold is 3x, so on average it's 2x. Double the time -> halve the rate 602 | double coldIntervalMicros = stableIntervalMicros * 3.0; 603 | slope = (coldIntervalMicros - stableIntervalMicros) / halfPermits; 604 | if (oldMaxPermits == Double.POSITIVE_INFINITY) { 605 | // if we don't special-case this, we would get storedPermits == NaN, below 606 | storedPermits = 0.0; 607 | } else { 608 | storedPermits = (oldMaxPermits == 0.0) 609 | ? maxPermits // initial state is cold 610 | : storedPermits * maxPermits / oldMaxPermits; 611 | } 612 | } 613 | 614 | @Override 615 | long storedPermitsToWaitTime(double storedPermits, double permitsToTake) { 616 | double availablePermitsAboveHalf = storedPermits - halfPermits; 617 | long micros = 0; 618 | // measuring the integral on the right part of the function (the climbing line) 619 | if (availablePermitsAboveHalf > 0.0) { 620 | double permitsAboveHalfToTake = Math.min(availablePermitsAboveHalf, permitsToTake); 621 | micros = (long) (permitsAboveHalfToTake * (permitsToTime(availablePermitsAboveHalf) 622 | + permitsToTime(availablePermitsAboveHalf - permitsAboveHalfToTake)) / 2.0); 623 | permitsToTake -= permitsAboveHalfToTake; 624 | } 625 | // measuring the integral on the left part of the function (the horizontal line) 626 | micros += (stableIntervalMicros * permitsToTake); 627 | return micros; 628 | } 629 | 630 | private double permitsToTime(double permits) { 631 | return stableIntervalMicros + permits * slope; 632 | } 633 | } 634 | 635 | /** 636 | * This implements a trivial function, where storedPermits are translated to 637 | * zero throttling - thus, a client gets an infinite speedup for permits acquired out 638 | * of the storedPermits pool. 
This is also used for the special case of the "metronome", 639 | * where the width of the function is also zero; maxStoredPermits is zero, thus 640 | * storedPermits and permitsToTake are always zero as well. Such a RateLimiter can 641 | * not save permits when unused, thus all permits it serves are fresh, using the 642 | * designated rate. 643 | */ 644 | private static class Bursty extends RateLimiter { 645 | Bursty(SleepingTicker ticker) { 646 | super(ticker); 647 | } 648 | 649 | @Override 650 | void doSetRate(double permitsPerSecond, double stableIntervalMicros) { 651 | double oldMaxPermits = this.maxPermits; 652 | /* 653 | * We allow the equivalent work of up to one second to be granted with zero waiting, if the 654 | * rate limiter has been unused for as much. This is to avoid potentially producing tiny 655 | * wait interval between subsequent requests for sufficiently large rates, which would 656 | * unnecessarily overconstrain the thread scheduler. 657 | */ 658 | maxPermits = permitsPerSecond; // one second worth of permits 659 | storedPermits = (oldMaxPermits == 0.0) 660 | ? 0.0 // initial state 661 | : storedPermits * maxPermits / oldMaxPermits; 662 | } 663 | 664 | @Override 665 | long storedPermitsToWaitTime(double storedPermits, double permitsToTake) { 666 | return 0L; 667 | } 668 | } 669 | } 670 | 671 | abstract class SleepingTicker extends Ticker { 672 | abstract void sleepMicrosUninterruptibly(long micros); 673 | 674 | static final SleepingTicker SYSTEM_TICKER = new SleepingTicker() { 675 | @Override 676 | public long read() { 677 | return systemTicker().read(); 678 | } 679 | 680 | @Override 681 | public void sleepMicrosUninterruptibly(long micros) { 682 | if (micros > 0) { 683 | Uninterruptibles.sleepUninterruptibly(micros, TimeUnit.MICROSECONDS); 684 | } 685 | } 686 | }; 687 | } 688 | --------------------------------------------------------------------------------
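For readers who want to see the shaded limiter in action, here is a minimal, hypothetical usage sketch. It uses only methods declared public in RateLimiter.java above; the driver class name, the loop, and the printed messages are illustrative, and the import path is inferred from this repository's shaded package layout rather than taken from the source.

import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter;
import java.util.concurrent.TimeUnit;

public class RateLimiterUsageSketch {

    public static void main(String[] args) {
        // Roughly 5 permits per second; an idle limiter may grant up to one
        // second's worth of permits as an immediate burst (see Bursty above).
        RateLimiter limiter = RateLimiter.create(5.0);

        for (int i = 0; i < 10; i++) {
            limiter.acquire(); // blocks only as long as needed to hold the stable rate
            System.out.println("request " + i + " granted");
        }

        // Non-blocking variant: give up if a permit cannot be reserved within 100 ms.
        if (!limiter.tryAcquire(1, 100, TimeUnit.MILLISECONDS)) {
            System.out.println("limiter saturated, dropping request");
        }
    }
}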