├── version.sbt
├── src
│   ├── it
│   │   ├── resources
│   │   │   └── log4j.properties
│   │   └── scala
│   │       └── io
│   │           └── github
│   │               └── spark_redshift_community
│   │                   └── spark
│   │                       └── redshift
│   │                           ├── PostgresDriverIntegrationSuite.scala
│   │                           ├── RedshiftCredentialsInConfIntegrationSuite.scala
│   │                           ├── CrossRegionIntegrationSuite.scala
│   │                           ├── IAMIntegrationSuite.scala
│   │                           ├── DecimalIntegrationSuite.scala
│   │                           ├── ColumnMetadataSuite.scala
│   │                           ├── SaveModeIntegrationSuite.scala
│   │                           ├── RedshiftWriteSuite.scala
│   │                           ├── IntegrationSuiteBase.scala
│   │                           └── RedshiftReadSuite.scala
│   ├── test
│   │   ├── resources
│   │   │   ├── redshift_unload_data.txt
│   │   │   └── hive-site.xml
│   │   ├── scala
│   │   │   └── io
│   │   │       └── github
│   │   │           └── spark_redshift_community
│   │   │               └── spark
│   │   │                   └── redshift
│   │   │                       ├── TableNameSuite.scala
│   │   │                       ├── SerializableConfigurationSuite.scala
│   │   │                       ├── DirectMapredOutputCommitter.scala
│   │   │                       ├── DirectMapreduceOutputCommitter.scala
│   │   │                       ├── SeekableByteArrayInputStream.java
│   │   │                       ├── UtilsSuite.scala
│   │   │                       ├── QueryTest.scala
│   │   │                       ├── FilterPushdownSuite.scala
│   │   │                       ├── InMemoryS3AFileSystemSuite.scala
│   │   │                       ├── MockRedshift.scala
│   │   │                       ├── TestUtils.scala
│   │   │                       ├── RedshiftInputFormatSuite.scala
│   │   │                       ├── AWSCredentialsUtilsSuite.scala
│   │   │                       ├── ConversionsSuite.scala
│   │   │                       └── ParametersSuite.scala
│   │   └── java
│   │       └── io
│   │           └── github
│   │               └── spark_redshift_community
│   │                   └── spark
│   │                       └── redshift
│   │                           └── InMemoryS3AFileSystem.java
│   └── main
│       └── scala
│           └── io
│               └── github
│                   └── spark_redshift_community
│                       └── spark
│                           └── redshift
│                               ├── SerializableConfiguration.scala
│                               ├── package.scala
│                               ├── RecordReaderIterator.scala
│                               ├── TableName.scala
│                               ├── FilterPushdown.scala
│                               ├── RedshiftFileFormat.scala
│                               ├── DefaultSource.scala
│                               ├── AWSCredentialsUtils.scala
│                               ├── Conversions.scala
│                               ├── Utils.scala
│                               ├── RedshiftInputFormat.scala
│                               └── RedshiftRelation.scala
├── tutorial
│   ├── images
│   │   ├── loadreadstep.png
│   │   ├── loadunloadstep.png
│   │   └── savetoredshift.png
│   ├── how_to_build.md
│   └── SparkRedshiftTutorial.scala
├── .gitignore
├── .jvmopts
├── codecov.yml
├── NOTICE
├── .travis.yml
├── dev
│   └── run-tests-travis.sh
├── project
│   ├── plugins.sbt
│   └── build.properties
├── CHANGELOG
├── LICENSE
└── scalastyle-config.xml
/version.sbt:
--------------------------------------------------------------------------------
1 | version in ThisBuild := "4.1.0"
2 |
--------------------------------------------------------------------------------
/src/it/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootLogger=OFF
2 |
--------------------------------------------------------------------------------
/tutorial/images/loadreadstep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tes/spark-redshift/master/tutorial/images/loadreadstep.png
--------------------------------------------------------------------------------
/tutorial/images/loadunloadstep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tes/spark-redshift/master/tutorial/images/loadunloadstep.png
--------------------------------------------------------------------------------
/tutorial/images/savetoredshift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tes/spark-redshift/master/tutorial/images/savetoredshift.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | project/target
3 | .idea/
4 | .idea_modules/
5 | *.DS_Store
6 | build/*.jar
7 | aws_variables.env
8 | derby.log
9 |
--------------------------------------------------------------------------------
/.jvmopts:
--------------------------------------------------------------------------------
1 | -Dfile.encoding=UTF8
2 | -Xms1024M
3 | -Xmx1024M
4 | -Xss6M
5 | -XX:MaxPermSize=512m
6 | -XX:+CMSClassUnloadingEnabled
7 | -XX:+UseConcMarkSweepGC
8 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | comment:
2 | layout: header, changes, diff
3 | coverage:
4 | status:
5 | patch: false
6 | project:
7 | default:
8 | target: 85
9 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Apache Accumulo
2 | Copyright 2011-2019 The Apache Software Foundation
3 |
4 | This product includes software developed at
5 | The Apache Software Foundation (http://www.apache.org/).
6 |
7 |
--------------------------------------------------------------------------------
/src/test/resources/redshift_unload_data.txt:
--------------------------------------------------------------------------------
1 | 1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001
2 | 1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0
3 | 0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00
4 | 0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123|
5 | |||||||||
6 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 | sudo: false
3 | # Cache settings here are based on latest SBT documentation.
4 | cache:
5 | directories:
6 | - $HOME/.ivy2/cache
7 | - $HOME/.sbt/boot/
8 | before_cache:
9 | # Tricks to avoid unnecessary cache updates
10 | - find $HOME/.ivy2 -name "ivydata-*.properties" -delete
11 | - find $HOME/.sbt -name "*.lock" -delete
12 | matrix:
13 | include:
14 | - jdk: openjdk8
15 | scala: 2.11.7
16 | env: HADOOP_VERSION="2.7.7" SPARK_VERSION="2.4.3" AWS_JAVA_SDK_VERSION="1.7.4"
17 |
18 | script:
19 | - ./dev/run-tests-travis.sh
20 |
21 | after_success:
22 | - bash <(curl -s https://codecov.io/bash)
23 |
--------------------------------------------------------------------------------
/dev/run-tests-travis.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 |
5 | sbt ++$TRAVIS_SCALA_VERSION scalastyle
6 | sbt ++$TRAVIS_SCALA_VERSION "test:scalastyle"
7 | sbt ++$TRAVIS_SCALA_VERSION "it:scalastyle"
8 |
9 | sbt \
10 | -Daws.testVersion=$AWS_JAVA_SDK_VERSION \
11 | -Dhadoop.testVersion=$HADOOP_VERSION \
12 | -Dspark.testVersion=$SPARK_VERSION \
13 | ++$TRAVIS_SCALA_VERSION \
14 | coverage test coverageReport
15 |
16 | if [ "$TRAVIS_SECURE_ENV_VARS" == "true" ]; then
17 | sbt \
18 | -Daws.testVersion=$AWS_JAVA_SDK_VERSION \
19 | -Dhadoop.testVersion=$HADOOP_VERSION \
20 | -Dspark.testVersion=$SPARK_VERSION \
21 | ++$TRAVIS_SCALA_VERSION \
22 | coverage it:test coverageReport 2> /dev/null;
23 | fi
24 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
2 |
3 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.5")
4 |
5 | resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"
6 |
7 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.2")
8 |
9 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")
10 |
11 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0")
12 |
13 | addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")
14 |
15 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0")
16 |
17 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
18 |
19 | libraryDependencies += "org.apache.maven" % "maven-artifact" % "3.3.9"
20 |
--------------------------------------------------------------------------------
/tutorial/how_to_build.md:
--------------------------------------------------------------------------------
1 | If you are building this project from source, you can use the following commands:
2 |
3 | ```
4 | git clone https://github.com/spark-redshift-community/spark-redshift.git
5 | ```
6 |
7 | ```
8 | cd spark-redshift
9 | ```
10 |
11 | ```
12 | ./build/sbt -v compile
13 | ```
14 |
15 | ```
16 | ./build/sbt -v package
17 | ```
18 |
19 | To run the tests
20 |
21 | ```
22 | ./build/sbt -v test
23 | ```
24 |
25 | To run the integration tests
26 |
27 | Before running them for the first time, you need to set up all the environment variables required to connect to Redshift (see https://github.com/spark-redshift-community/spark-redshift/blob/master/src/it/scala/io/github/spark_redshift_community/spark/redshift/IntegrationSuiteBase.scala#L54).
28 |
29 | ```
30 | ./build/sbt -v it:test
31 | ```
32 |
--------------------------------------------------------------------------------
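Note on the integration-test configuration referenced above: the suites read their connection settings from environment variables (for example AWS_REDSHIFT_USER and AWS_REDSHIFT_PASSWORD, which appear in the suites later in this dump). The following is a minimal, hypothetical sketch of that pattern in Scala; the authoritative variable list is in IntegrationSuiteBase.

```scala
// Hypothetical sketch of env-driven test configuration (not part of the repository).
// Variable names mirror those referenced by the integration suites; see
// IntegrationSuiteBase for the authoritative list.
object IntegrationConfigSketch {
  private def loadConfigFromEnv(name: String): String =
    sys.env.getOrElse(name, sys.error(s"Missing required environment variable: $name"))

  lazy val redshiftUser: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
  lazy val redshiftPassword: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
}
```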
/project/build.properties:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | sbt.version=0.13.13
18 |
--------------------------------------------------------------------------------
/src/test/resources/hive-site.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
17 |
18 | <configuration>
19 |
20 | <property>
21 |
22 | <name>fs.permissions.umask-mode</name>
23 | <value>022</value>
24 | <description>Setting a value for fs.permissions.umask-mode to work around issue in HIVE-6962.
25 | It has no impact in Hadoop 1.x line on HDFS operations.
26 | </description>
27 | </property>
28 |
29 | </configuration>
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/TableNameSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.scalatest.FunSuite
20 |
21 | class TableNameSuite extends FunSuite {
22 | test("TableName.parseFromEscaped") {
23 | assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar"))
24 | assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo"))
25 | assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo"))
26 | assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar"))
27 | // Dots (.) can also appear inside of valid identifiers.
28 | assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz"))
29 | assert(TableName.parseFromEscaped("\"foo\"\".bar\".baz") === TableName("foo\".bar", "baz"))
30 | }
31 |
32 | test("TableName.toString") {
33 | assert(TableName("foo", "bar").toString === """"foo"."bar"""")
34 | assert(TableName("PUBLIC", "bar").toString === """"PUBLIC"."bar"""")
35 | assert(TableName("\"foo\"", "bar").toString === "\"\"\"foo\"\"\".\"bar\"")
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/PostgresDriverIntegrationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.spark.sql.Row
20 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
21 |
22 | /**
23 | * Basic integration tests with the Postgres JDBC driver.
24 | */
25 | class PostgresDriverIntegrationSuite extends IntegrationSuiteBase {
26 |
27 | override def jdbcUrl: String = {
28 | super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql")
29 | }
30 |
31 | // TODO (luca|issue #9) Fix tests when using postgresql driver
32 | ignore("postgresql driver takes precedence for jdbc:postgresql:// URIs") {
33 | val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
34 | try {
35 | assert(conn.getClass.getName === "org.postgresql.jdbc4.Jdbc4Connection")
36 | } finally {
37 | conn.close()
38 | }
39 | }
40 |
41 | ignore("roundtrip save and load") {
42 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
43 | StructType(StructField("foo", IntegerType) :: Nil))
44 | testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df)
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/SerializableConfigurationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.hadoop.conf.Configuration
20 | import org.apache.spark.SparkConf
21 | import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
22 | import org.scalatest.FunSuite
23 |
24 | class SerializableConfigurationSuite extends FunSuite {
25 |
26 | private def testSerialization(serializer: SerializerInstance): Unit = {
27 | val conf = new SerializableConfiguration(new Configuration())
28 |
29 | val serialized = serializer.serialize(conf)
30 |
31 | serializer.deserialize[Any](serialized) match {
32 | case c: SerializableConfiguration =>
33 | assert(c.log != null, "log was null")
34 | assert(c.value != null, "value was null")
35 | case other => fail(
36 | s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.")
37 | }
38 | }
39 |
40 | test("serialization with JavaSerializer") {
41 | testSerialization(new JavaSerializer(new SparkConf()).newInstance())
42 | }
43 |
44 | test("serialization with KryoSerializer") {
45 | testSerialization(new KryoSerializer(new SparkConf()).newInstance())
46 | }
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.spark.sql.Row
20 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
21 |
22 | /**
23 | * This suite performs basic integration tests where the Redshift credentials have been
24 | * specified via `spark-redshift`'s configuration rather than as part of the JDBC URL.
25 | */
26 | class RedshiftCredentialsInConfIntegrationSuite extends IntegrationSuiteBase {
27 |
28 | test("roundtrip save and load") {
29 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
30 | StructType(StructField("foo", IntegerType) :: Nil))
31 | val tableName = s"roundtrip_save_and_load_$randomSuffix"
32 | try {
33 | write(df)
34 | .option("url", jdbcUrlNoUserPassword)
35 | .option("user", AWS_REDSHIFT_USER)
36 | .option("password", AWS_REDSHIFT_PASSWORD)
37 | .option("dbtable", tableName)
38 | .save()
39 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
40 | val loadedDf = read
41 | .option("url", jdbcUrlNoUserPassword)
42 | .option("user", AWS_REDSHIFT_USER)
43 | .option("password", AWS_REDSHIFT_PASSWORD)
44 | .option("dbtable", tableName)
45 | .load()
46 | assert(loadedDf.schema === df.schema)
47 | checkAnswer(loadedDf, df.collect())
48 | } finally {
49 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
50 | }
51 | }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/SerializableConfiguration.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.io._
20 |
21 | import com.esotericsoftware.kryo.io.{Input, Output}
22 | import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
23 | import org.apache.hadoop.conf.Configuration
24 | import org.slf4j.LoggerFactory
25 |
26 | import scala.util.control.NonFatal
27 |
28 | class SerializableConfiguration(@transient var value: Configuration)
29 | extends Serializable with KryoSerializable {
30 | @transient private[redshift] lazy val log = LoggerFactory.getLogger(getClass)
31 |
32 | private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
33 | out.defaultWriteObject()
34 | value.write(out)
35 | }
36 |
37 | private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
38 | value = new Configuration(false)
39 | value.readFields(in)
40 | }
41 |
42 | private def tryOrIOException[T](block: => T): T = {
43 | try {
44 | block
45 | } catch {
46 | case e: IOException =>
47 | log.error("Exception encountered", e)
48 | throw e
49 | case NonFatal(e) =>
50 | log.error("Exception encountered", e)
51 | throw new IOException(e)
52 | }
53 | }
54 |
55 | def write(kryo: Kryo, out: Output): Unit = {
56 | val dos = new DataOutputStream(out)
57 | value.write(dos)
58 | dos.flush()
59 | }
60 |
61 | def read(kryo: Kryo, in: Input): Unit = {
62 | value = new Configuration(false)
63 | value.readFields(new DataInputStream(in))
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
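The class above exists because Hadoop's Configuration is not java.io.Serializable, so it cannot be captured by Spark closures or broadcast directly. A minimal usage sketch (assumes a local SparkSession and access to the SerializableConfiguration class above):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: wrap the non-serializable Hadoop configuration so it can be shipped
// to executors, e.g. via a broadcast variable.
val spark = SparkSession.builder().master("local[*]").appName("conf-example").getOrCreate()
val serializableConf = new SerializableConfiguration(spark.sparkContext.hadoopConfiguration)
val broadcastConf = spark.sparkContext.broadcast(serializableConf)
// On an executor, broadcastConf.value.value yields a Hadoop Configuration again.
```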
/src/main/scala/io/github/spark_redshift_community/spark/redshift/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | * Copyright 2015 TouchType Ltd. (Added JDBC-based Data Source API implementation)
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package io.github.spark_redshift_community.spark
19 |
20 | import org.apache.spark.sql.functions.col
21 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
22 | import org.apache.spark.sql.{DataFrame, Row, SQLContext}
23 |
24 | package object redshift {
25 |
26 | /**
27 | * Wrapper of SQLContext that provides the `redshiftFile` method.
28 | */
29 | implicit class RedshiftContext(sqlContext: SQLContext) {
30 |
31 | /**
32 | * Read a file unloaded from Redshift into a DataFrame.
33 | * @param path input path
34 | * @return a DataFrame with all string columns
35 | */
36 | def redshiftFile(path: String, columns: Seq[String]): DataFrame = {
37 | val sc = sqlContext.sparkContext
38 | val rdd = sc.newAPIHadoopFile(path, classOf[RedshiftInputFormat],
39 | classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration)
40 | // TODO: allow setting NULL string.
41 | val nullable = rdd.values.map(_.map(f => if (f.isEmpty) null else f)).map(x => Row(x: _*))
42 | val schema = StructType(columns.map(c => StructField(c, StringType, nullable = true)))
43 | sqlContext.createDataFrame(nullable, schema)
44 | }
45 |
46 | /**
47 | * Reads a table unloaded from Redshift with its schema.
48 | */
49 | def redshiftFile(path: String, schema: StructType): DataFrame = {
50 | val casts = schema.fields.map { field =>
51 | col(field.name).cast(field.dataType).as(field.name)
52 | }
53 | redshiftFile(path, schema.fieldNames).select(casts: _*)
54 | }
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
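A short usage sketch of the implicit RedshiftContext defined above (the path and column names are hypothetical; it assumes files produced by a Redshift UNLOAD):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import io.github.spark_redshift_community.spark.redshift._ // brings RedshiftContext into scope

// Sketch: read unloaded Redshift data either as all-string columns or with a schema.
def readUnloaded(sqlContext: SQLContext, path: String) = {
  // Column names only: every column comes back as StringType.
  val asStrings = sqlContext.redshiftFile(path, Seq("id", "name"))

  // With a schema: the string columns are cast to the declared types.
  val schema = StructType(Seq(
    StructField("id", IntegerType),
    StructField("name", StringType)))
  val typed = sqlContext.redshiftFile(path, schema)

  (asStrings, typed)
}
```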
/src/test/scala/io/github/spark_redshift_community/spark/redshift/DirectMapredOutputCommitter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License"); you may
5 | * not use this file except in compliance with the License. You may obtain
6 | * a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.hadoop.fs.Path
20 | import org.apache.hadoop.mapred._
21 |
22 | class DirectMapredOutputCommitter extends OutputCommitter {
23 | override def setupJob(jobContext: JobContext): Unit = { }
24 |
25 | override def setupTask(taskContext: TaskAttemptContext): Unit = { }
26 |
27 | override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = {
28 | // We return true here to guard against implementations that do not handle false correctly.
29 | // The meaning of returning false is not entirely clear, so it could be interpreted
30 | // as an error. Returning true just means that commitTask() will be called, which is a no-op.
31 | true
32 | }
33 |
34 | override def commitTask(taskContext: TaskAttemptContext): Unit = { }
35 |
36 | override def abortTask(taskContext: TaskAttemptContext): Unit = { }
37 |
38 | /**
39 | * Creates a _SUCCESS file to indicate the entire job was successful.
40 | * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option.
41 | */
42 | override def commitJob(context: JobContext): Unit = {
43 | val conf = context.getJobConf
44 | if (shouldCreateSuccessFile(conf)) {
45 | val outputPath = FileOutputFormat.getOutputPath(conf)
46 | if (outputPath != null) {
47 | val fileSys = outputPath.getFileSystem(conf)
48 | val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME)
49 | fileSys.create(filePath).close()
50 | }
51 | }
52 | }
53 |
54 | /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */
55 | private def shouldCreateSuccessFile(conf: JobConf): Boolean = {
56 | conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/DirectMapreduceOutputCommitter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks, Inc.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License"); you may
5 | * not use this file except in compliance with the License. You may obtain
6 | * a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.hadoop.conf.Configuration
20 | import org.apache.hadoop.fs.Path
21 | import org.apache.hadoop.mapreduce._
22 | import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
23 |
24 | class DirectMapreduceOutputCommitter extends OutputCommitter {
25 | override def setupJob(jobContext: JobContext): Unit = { }
26 |
27 | override def setupTask(taskContext: TaskAttemptContext): Unit = { }
28 |
29 | override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = {
30 | // We return true here to guard against implementations that do not handle false correctly.
31 | // The meaning of returning false is not entirely clear, so it could be interpreted
32 | // as an error. Returning true just means that commitTask() will be called, which is a no-op.
33 | true
34 | }
35 |
36 | override def commitTask(taskContext: TaskAttemptContext): Unit = { }
37 |
38 | override def abortTask(taskContext: TaskAttemptContext): Unit = { }
39 |
40 | /**
41 | * Creates a _SUCCESS file to indicate the entire job was successful.
42 | * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option.
43 | */
44 | override def commitJob(context: JobContext): Unit = {
45 | val conf = context.getConfiguration
46 | if (shouldCreateSuccessFile(conf)) {
47 | val outputPath = FileOutputFormat.getOutputPath(context)
48 | if (outputPath != null) {
49 | val fileSys = outputPath.getFileSystem(conf)
50 | val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME)
51 | fileSys.create(filePath).close()
52 | }
53 | }
54 | }
55 |
56 | /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */
57 | private def shouldCreateSuccessFile(conf: Configuration): Boolean = {
58 | conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
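These two committers are test-only helpers that commit task output in place (no temporary-directory rename) and optionally write a _SUCCESS marker. A minimal sketch of wiring the mapred-API variant in through the Hadoop configuration (illustrative only; it assumes the class is visible on the test classpath):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: use the direct committer for old-style (mapred API) outputs in tests.
// "mapred.output.committer.class" is the standard Hadoop key for this setting.
val spark = SparkSession.builder().master("local[*]").appName("committer-example").getOrCreate()
spark.sparkContext.hadoopConfiguration.set(
  "mapred.output.committer.class",
  classOf[DirectMapredOutputCommitter].getName)
```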
/src/it/scala/io/github/spark_redshift_community/spark/redshift/CrossRegionIntegrationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import com.amazonaws.auth.BasicAWSCredentials
20 | import com.amazonaws.services.s3.AmazonS3Client
21 | import org.apache.spark.sql.Row
22 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
23 |
24 | /**
25 | * Integration tests where the Redshift cluster and the S3 bucket are in different AWS regions.
26 | */
27 | class CrossRegionIntegrationSuite extends IntegrationSuiteBase {
28 |
29 | protected val AWS_S3_CROSS_REGION_SCRATCH_SPACE: String =
30 | loadConfigFromEnv("AWS_S3_CROSS_REGION_SCRATCH_SPACE")
31 | require(AWS_S3_CROSS_REGION_SCRATCH_SPACE.contains("s3a"), "must use s3a:// URL")
32 |
33 | override protected val tempDir: String = AWS_S3_CROSS_REGION_SCRATCH_SPACE + randomSuffix + "/"
34 |
35 | test("write") {
36 | val bucketRegion = Utils.getRegionForS3Bucket(
37 | tempDir,
38 | new AmazonS3Client(new BasicAWSCredentials(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY))).get
39 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
40 | StructType(StructField("foo", IntegerType) :: Nil))
41 | val tableName = s"roundtrip_save_and_load_$randomSuffix"
42 | try {
43 | write(df)
44 | .option("dbtable", tableName)
45 | .option("extracopyoptions", s"region '$bucketRegion'")
46 | .save()
47 | // Check that the table exists. It appears that creating a table in one connection then
48 | // immediately querying for existence from another connection may result in spurious "table
49 | // doesn't exist" errors; this caused the "save with all empty partitions" test to become
50 | // flaky (see #146). To work around this, add a small sleep and check again:
51 | if (!DefaultJDBCWrapper.tableExists(conn, tableName)) {
52 | Thread.sleep(1000)
53 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
54 | }
55 | } finally {
56 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
57 | }
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/RecordReaderIterator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package io.github.spark_redshift_community.spark.redshift
19 |
20 | import java.io.Closeable
21 |
22 | import org.apache.hadoop.mapreduce.RecordReader
23 |
24 | /**
25 | * An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned.
26 | *
27 | * This is copied from Apache Spark and is inlined here to avoid depending on Spark internals
28 | * in this external library.
29 | */
30 | private[redshift] class RecordReaderIterator[T](
31 | private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable {
32 | private[this] var havePair = false
33 | private[this] var finished = false
34 |
35 | override def hasNext: Boolean = {
36 | if (!finished && !havePair) {
37 | finished = !rowReader.nextKeyValue
38 | if (finished) {
39 | // Close and release the reader here; close() will also be called when the task
40 | // completes, but for tasks that read from many files, it helps to release the
41 | // resources early.
42 | close()
43 | }
44 | havePair = !finished
45 | }
46 | !finished
47 | }
48 |
49 | override def next(): T = {
50 | if (!hasNext) {
51 | throw new java.util.NoSuchElementException("End of stream")
52 | }
53 | havePair = false
54 | rowReader.getCurrentValue
55 | }
56 |
57 | override def close(): Unit = {
58 | if (rowReader != null) {
59 | try {
60 | // Close the reader and release it. Note: it's very important that we don't close the
61 | // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
62 | // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues
63 | // when reading compressed input.
64 | rowReader.close()
65 | } finally {
66 | rowReader = null
67 | }
68 | }
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
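A usage sketch of the adaptor above: wrap an already-initialized Hadoop RecordReader and consume its values as a Scala Iterator (in this codebase the reader is constructed by RedshiftFileFormat; the helper below is hypothetical and, because RecordReaderIterator is private[redshift], would have to live in the same package):

```scala
import org.apache.hadoop.mapreduce.RecordReader

// Sketch: adapt a Hadoop RecordReader into an Iterator over its values.
// The iterator closes the underlying reader once it is exhausted.
def valuesOf[T](reader: RecordReader[_, T]): Iterator[T] =
  new RecordReaderIterator[T](reader)

// e.g. valuesOf(reader).map(parseLine).foreach(process)
```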
/src/test/scala/io/github/spark_redshift_community/spark/redshift/SeekableByteArrayInputStream.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /*
19 | SeekableByteArrayInputStream copied from
20 | https://github.com/apache/accumulo/blob/master/core/src/test/java/org/apache/accumulo/core/file/rfile/RFileTest.java
21 | */
22 |
23 | package io.github.spark_redshift_community.spark.redshift;
24 |
25 | import org.apache.hadoop.fs.PositionedReadable;
26 | import org.apache.hadoop.fs.Seekable;
27 |
28 | import java.io.ByteArrayInputStream;
29 | import java.io.IOException;
30 |
31 |
32 | class SeekableByteArrayInputStream extends ByteArrayInputStream
33 | implements Seekable, PositionedReadable {
34 |
35 | public SeekableByteArrayInputStream(byte[] buf) {
36 | super(buf);
37 | }
38 |
39 | @Override
40 | public long getPos() {
41 | return pos;
42 | }
43 |
44 | @Override
45 | public void seek(long pos) throws IOException {
46 | if (mark != 0)
47 | throw new IllegalStateException();
48 |
49 | reset();
50 | long skipped = skip(pos);
51 |
52 | if (skipped != pos)
53 | throw new IOException();
54 | }
55 |
56 | @Override
57 | public boolean seekToNewSource(long targetPos) {
58 | return false;
59 | }
60 |
61 | @Override
62 | public int read(long position, byte[] buffer, int offset, int length) {
63 |
64 | if (position >= buf.length)
65 | throw new IllegalArgumentException();
66 | if (position + length > buf.length)
67 | throw new IllegalArgumentException();
68 | if (length > buffer.length)
69 | throw new IllegalArgumentException();
70 |
71 | System.arraycopy(buf, (int) position, buffer, offset, length);
72 | return length;
73 | }
74 |
75 | @Override
76 | public void readFully(long position, byte[] buffer) {
77 | read(position, buffer, 0, buffer.length);
78 |
79 | }
80 |
81 | @Override
82 | public void readFully(long position, byte[] buffer, int offset, int length) {
83 | read(position, buffer, offset, length);
84 | }
85 |
86 | }
87 |
88 |
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/IAMIntegrationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.SQLException
20 |
21 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
22 | import org.apache.spark.sql.{Row, SaveMode}
23 |
24 | /**
25 | * Integration tests for configuring Redshift to access S3 using Amazon IAM roles.
26 | */
27 | class IAMIntegrationSuite extends IntegrationSuiteBase {
28 |
29 | private val IAM_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN")
30 |
31 | // TODO (luca|issue #8) Fix IAM Authentication tests
32 | ignore("roundtrip save and load") {
33 | val tableName = s"iam_roundtrip_save_and_load$randomSuffix"
34 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
35 | StructType(StructField("a", IntegerType) :: Nil))
36 | try {
37 | write(df)
38 | .option("dbtable", tableName)
39 | .option("forward_spark_s3_credentials", "false")
40 | .option("aws_iam_role", IAM_ROLE_ARN)
41 | .mode(SaveMode.ErrorIfExists)
42 | .save()
43 |
44 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
45 | val loadedDf = read
46 | .option("dbtable", tableName)
47 | .option("forward_spark_s3_credentials", "false")
48 | .option("aws_iam_role", IAM_ROLE_ARN)
49 | .load()
50 | assert(loadedDf.schema.length === 1)
51 | assert(loadedDf.columns === Seq("a"))
52 | checkAnswer(loadedDf, Seq(Row(1)))
53 | } finally {
54 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
55 | }
56 | }
57 |
58 | ignore("load fails if IAM role cannot be assumed") {
59 | val tableName = s"iam_load_fails_if_role_cannot_be_assumed$randomSuffix"
60 | try {
61 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
62 | StructType(StructField("a", IntegerType) :: Nil))
63 | val err = intercept[SQLException] {
64 | write(df)
65 | .option("dbtable", tableName)
66 | .option("forward_spark_s3_credentials", "false")
67 | .option("aws_iam_role", IAM_ROLE_ARN + "-some-bogus-suffix")
68 | .mode(SaveMode.ErrorIfExists)
69 | .save()
70 | }
71 | assert(err.getCause.getMessage.contains("is not authorized to assume IAM Role"))
72 | } finally {
73 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
74 | }
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/TableName.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import scala.collection.mutable.ArrayBuffer
20 |
21 | /**
22 | * Wrapper class for representing the name of a Redshift table.
23 | */
24 | private[redshift] case class TableName(unescapedSchemaName: String, unescapedTableName: String) {
25 | private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"'
26 | def escapedSchemaName: String = quote(unescapedSchemaName)
27 | def escapedTableName: String = quote(unescapedTableName)
28 | override def toString: String = s"$escapedSchemaName.$escapedTableName"
29 | }
30 |
31 | private[redshift] object TableName {
32 | /**
33 | * Parses a table name which is assumed to have been escaped according to Redshift's rules for
34 | * delimited identifiers.
35 | */
36 | def parseFromEscaped(str: String): TableName = {
37 | def dropOuterQuotes(s: String) =
38 | if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s
39 | def unescapeQuotes(s: String) = s.replace("\"\"", "\"")
40 | def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s))
41 | splitByDots(str) match {
42 | case Seq(tableName) => TableName("PUBLIC", unescape(tableName))
43 | case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName))
44 | case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'")
45 | }
46 | }
47 |
48 | /**
49 | * Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear
50 | * inside of quoted identifiers.
51 | */
52 | private def splitByDots(str: String): Seq[String] = {
53 | val parts: ArrayBuffer[String] = ArrayBuffer.empty
54 | val sb = new StringBuilder
55 | var inQuotes: Boolean = false
56 | for (c <- str) c match {
57 | case '"' =>
58 | // Note that double quotes are escaped by pairs of double quotes (""), so we don't need
59 | // any extra code to handle them; we'll be back in inQuotes=true after seeing the pair.
60 | sb.append('"')
61 | inQuotes = !inQuotes
62 | case '.' =>
63 | if (!inQuotes) {
64 | parts.append(sb.toString())
65 | sb.clear()
66 | } else {
67 | sb.append('.')
68 | }
69 | case other =>
70 | sb.append(other)
71 | }
72 | if (sb.nonEmpty) {
73 | parts.append(sb.toString())
74 | }
75 | parts
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/UtilsSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.net.URI
20 |
21 | import org.scalatest.{FunSuite, Matchers}
22 |
23 | /**
24 | * Unit tests for helper functions
25 | */
26 | class UtilsSuite extends FunSuite with Matchers {
27 |
28 | test("joinUrls preserves protocol information") {
29 | Utils.joinUrls("s3n://foo/bar/", "/baz") shouldBe "s3n://foo/bar/baz/"
30 | Utils.joinUrls("s3n://foo/bar/", "/baz/") shouldBe "s3n://foo/bar/baz/"
31 | Utils.joinUrls("s3n://foo/bar/", "baz/") shouldBe "s3n://foo/bar/baz/"
32 | Utils.joinUrls("s3n://foo/bar/", "baz") shouldBe "s3n://foo/bar/baz/"
33 | Utils.joinUrls("s3n://foo/bar", "baz") shouldBe "s3n://foo/bar/baz/"
34 | }
35 |
36 | test("joinUrls preserves credentials") {
37 | assert(
38 | Utils.joinUrls("s3n://ACCESSKEY:SECRETKEY@bucket/tempdir", "subdir") ===
39 | "s3n://ACCESSKEY:SECRETKEY@bucket/tempdir/subdir/")
40 | }
41 |
42 | test("fixUrl produces Redshift-compatible equivalents") {
43 | Utils.fixS3Url("s3a://foo/bar/12345") shouldBe "s3://foo/bar/12345"
44 | Utils.fixS3Url("s3n://foo/bar/baz") shouldBe "s3://foo/bar/baz"
45 | }
46 |
47 | test("addEndpointToUrl produces urls with endpoints added to host") {
48 | Utils.addEndpointToUrl("s3a://foo/bar/12345") shouldBe "s3a://foo.s3.amazonaws.com/bar/12345"
49 | Utils.addEndpointToUrl("s3n://foo/bar/baz") shouldBe "s3n://foo.s3.amazonaws.com/bar/baz"
50 | }
51 |
52 | test("temp paths are random subdirectories of root") {
53 | val root = "s3n://temp/"
54 | val firstTempPath = Utils.makeTempPath(root)
55 |
56 | Utils.makeTempPath(root) should (startWith (root) and endWith ("/")
57 | and not equal root and not equal firstTempPath)
58 | }
59 |
60 | test("removeCredentialsFromURI removes AWS access keys") {
61 | def removeCreds(uri: String): String = {
62 | Utils.removeCredentialsFromURI(URI.create(uri)).toString
63 | }
64 | assert(removeCreds("s3n://bucket/path/to/temp/dir") === "s3n://bucket/path/to/temp/dir")
65 | assert(
66 | removeCreds("s3n://ACCESSKEY:SECRETKEY@bucket/path/to/temp/dir") ===
67 | "s3n://bucket/path/to/temp/dir")
68 | }
69 |
70 | test("getRegionForRedshiftCluster") {
71 | val redshiftUrl =
72 | "jdbc:redshift://example.secret.us-west-2.redshift.amazonaws.com:5439/database"
73 | assert(Utils.getRegionForRedshiftCluster("mycluster.example.com") === None)
74 | assert(Utils.getRegionForRedshiftCluster(redshiftUrl) === Some("us-west-2"))
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/FilterPushdown.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.{Date, Timestamp}
20 |
21 | import org.apache.spark.sql.sources._
22 | import org.apache.spark.sql.types._
23 |
24 | /**
25 | * Helper methods for pushing filters into Redshift queries.
26 | */
27 | private[redshift] object FilterPushdown {
28 | /**
29 | * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no
30 | * condition will be added to the WHERE clause. If none of the filters can be pushed down then
31 | * an empty string will be returned.
32 | *
33 | * @param schema the schema of the table being queried
34 | * @param filters an array of filters, the conjunction of which is the filter condition for the
35 | * scan.
36 | */
37 | def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = {
38 | val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ")
39 | if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions
40 | }
41 |
42 | /**
43 | * Attempt to convert the given filter into a SQL expression. Returns None if the expression
44 | * could not be converted.
45 | */
46 | def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = {
47 | def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = {
48 | getTypeForAttribute(schema, attr).map { dataType =>
49 | val sqlEscapedValue: String = dataType match {
50 | case StringType => s"\\'${value.toString.replace("'", "\\'\\'")}\\'"
51 | case DateType => s"\\'${value.asInstanceOf[Date]}\\'"
52 | case TimestampType => s"\\'${value.asInstanceOf[Timestamp]}\\'"
53 | case _ => value.toString
54 | }
55 | s""""$attr" $comparisonOp $sqlEscapedValue"""
56 | }
57 | }
58 |
59 | filter match {
60 | case EqualTo(attr, value) => buildComparison(attr, value, "=")
61 | case LessThan(attr, value) => buildComparison(attr, value, "<")
62 | case GreaterThan(attr, value) => buildComparison(attr, value, ">")
63 | case LessThanOrEqual(attr, value) => buildComparison(attr, value, "<=")
64 | case GreaterThanOrEqual(attr, value) => buildComparison(attr, value, ">=")
65 | case IsNotNull(attr) =>
66 | getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NOT NULL""")
67 | case IsNull(attr) =>
68 | getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NULL""")
69 | case _ => None
70 | }
71 | }
72 |
73 | /**
74 | * Use the given schema to look up the attribute's data type. Returns None if the attribute could
75 | * not be resolved.
76 | */
77 | private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = {
78 | if (schema.fieldNames.contains(attribute)) {
79 | Some(schema(attribute).dataType)
80 | } else {
81 | None
82 | }
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
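A quick illustration of the helper above (it is also exercised by FilterPushdownSuite further below): given a schema and Spark data source filters, buildWhereClause emits a WHERE clause covering the pushable filters, or an empty string if none can be pushed down. Because FilterPushdown is private[redshift], the sketch assumes same-package access.

```scala
import org.apache.spark.sql.sources.{EqualTo, GreaterThan, IsNotNull}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Sketch: build a WHERE clause from pushable Spark data source filters.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

val where = FilterPushdown.buildWhereClause(
  schema,
  Seq(EqualTo("id", 1), GreaterThan("id", 0), IsNotNull("name")))
// where == """WHERE "id" = 1 AND "id" > 0 AND "name" IS NOT NULL"""
```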
/src/it/scala/io/github/spark_redshift_community/spark/redshift/DecimalIntegrationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.spark.sql.Row
20 | import org.apache.spark.sql.types.DecimalType
21 |
22 | /**
23 | * Integration tests for decimal support. For a reference on Redshift's DECIMAL type, see
24 | * http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html
25 | */
26 | class DecimalIntegrationSuite extends IntegrationSuiteBase {
27 |
28 | private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = {
29 | test(s"reading DECIMAL($precision, $scale)") {
30 | val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix"
31 | val expectedRows = decimalStrings.map { d =>
32 | if (d == null) {
33 | Row(null)
34 | } else {
35 | Row(Conversions.createRedshiftDecimalFormat().parse(d).asInstanceOf[java.math.BigDecimal])
36 | }
37 | }
38 | try {
39 | conn.createStatement().executeUpdate(
40 | s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))")
41 | for (x <- decimalStrings) {
42 | conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES ($x)")
43 | }
44 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
45 | val loadedDf = read.option("dbtable", tableName).load()
46 | checkAnswer(loadedDf, expectedRows)
47 | checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows)
48 | } finally {
49 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
50 | }
51 | }
52 | }
53 |
54 | testReadingDecimals(19, 0, Seq(
55 | // Max and min values of DECIMAL(19, 0) column according to Redshift docs:
56 | "9223372036854775807", // 2^63 - 1
57 | "-9223372036854775807",
58 | "0",
59 | "12345678910",
60 | null
61 | ))
62 |
63 | testReadingDecimals(19, 4, Seq(
64 | "922337203685477.5807",
65 | "-922337203685477.5807",
66 | "0",
67 | "1234567.8910",
68 | null
69 | ))
70 |
71 | testReadingDecimals(38, 4, Seq(
72 | "922337203685477.5808",
73 | "9999999999999999999999999999999999.0000",
74 | "-9999999999999999999999999999999999.0000",
75 | "0",
76 | "1234567.8910",
77 | null
78 | ))
79 |
80 | test("Decimal precision is preserved when reading from query (regression test for issue #203)") {
81 | withTempRedshiftTable("issue203") { tableName =>
82 | conn.createStatement().executeUpdate(s"CREATE TABLE $tableName (foo BIGINT)")
83 | conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES (91593373)")
84 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
85 | val df = read
86 | .option("query", s"select foo / 1000000.0 from $tableName limit 1")
87 | .load()
88 | val res: Double = df.collect().toSeq.head.getDecimal(0).doubleValue()
89 | assert(res === (91593373L / 1000000.0) +- 0.01)
90 | assert(df.schema.fields.head.dataType === DecimalType(28, 8))
91 | }
92 | }
93 | }
94 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/QueryTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package io.github.spark_redshift_community.spark.redshift
19 |
20 | import org.apache.spark.sql.catalyst.plans.logical
21 | import org.apache.spark.sql.{DataFrame, Row}
22 | import org.scalatest.FunSuite
23 |
24 | /**
25 | * Copy of Spark SQL's `QueryTest` trait.
26 | */
27 | trait QueryTest extends FunSuite {
28 | /**
29 | * Runs the plan and makes sure the answer matches the expected result.
30 | * @param df the [[DataFrame]] to be executed
31 | * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
32 | */
33 | def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
34 | val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty
35 | def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
36 | // Converts data to types that we can do equality comparison using Scala collections.
37 | // For BigDecimal type, the Scala type has a better definition of equality test (similar to
38 | // Java's java.math.BigDecimal.compareTo).
39 | // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
40 | // equality test.
41 | val converted: Seq[Row] = answer.map { s =>
42 | Row.fromSeq(s.toSeq.map {
43 | case d: java.math.BigDecimal => BigDecimal(d)
44 | case b: Array[Byte] => b.toSeq
45 | case o => o
46 | })
47 | }
48 | if (!isSorted) converted.sortBy(_.toString()) else converted
49 | }
50 | val sparkAnswer = try df.collect().toSeq catch {
51 | case e: Exception =>
52 | val errorMessage =
53 | s"""
54 | |Exception thrown while executing query:
55 | |${df.queryExecution}
56 | |== Exception ==
57 | |$e
58 | |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
59 | """.stripMargin
60 | fail(errorMessage)
61 | }
62 |
63 | if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
64 | val errorMessage =
65 | s"""
66 | |Results do not match for query:
67 | |${df.queryExecution}
68 | |== Results ==
69 | |${sideBySide(
70 | s"== Correct Answer - ${expectedAnswer.size} ==" +:
71 | prepareAnswer(expectedAnswer).map(_.toString()),
72 | s"== Spark Answer - ${sparkAnswer.size} ==" +:
73 | prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
74 | """.stripMargin
75 | fail(errorMessage)
76 | }
77 | }
78 |
79 | private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = {
80 | val maxLeftSize = left.map(_.length).max
81 | val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("")
82 | val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("")
83 |
84 | leftPadded.zip(rightPadded).map {
85 | case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r
86 | }
87 | }
88 | }
89 |
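// Illustrative usage sketch (not part of the original source): a hypothetical suite that mixes
// QueryTest into a test class and compares a small DataFrame against its expected rows. The
// suite name, the local SparkSession and the sample data are assumptions made for the example;
// checkAnswer sorts both sides before comparing unless the logical plan already contains a Sort.
class QueryTestUsageSketch extends QueryTest {
  private lazy val spark = org.apache.spark.sql.SparkSession.builder()
    .master("local[1]")
    .appName("QueryTestUsageSketch")
    .getOrCreate()

  test("rows are compared after normalisation") {
    import spark.implicits._
    // Row order does not matter here because the plan contains no Sort.
    val df = Seq(3, 1, 2).toDF("id")
    checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
  }
}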
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/FilterPushdownSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import io.github.spark_redshift_community.spark.redshift.FilterPushdown._
20 | import org.apache.spark.sql.sources._
21 | import org.apache.spark.sql.types._
22 | import org.scalatest.FunSuite
23 |
24 |
25 | class FilterPushdownSuite extends FunSuite {
26 | test("buildWhereClause with empty list of filters") {
27 | assert(buildWhereClause(StructType(Nil), Seq.empty) === "")
28 | }
29 |
30 | test("buildWhereClause with no filters that can be pushed down") {
31 | assert(buildWhereClause(StructType(Nil), Seq(NewFilter, NewFilter)) === "")
32 | }
33 |
34 |   test("buildWhereClause with some filters that cannot be pushed down") {
35 | val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), NewFilter))
36 | assert(whereClause === """WHERE "test_int" = 1""")
37 | }
38 |
39 | test("buildWhereClause with string literals that contain Unicode characters") {
40 | // scalastyle:off
41 | val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_string", "Unicode's樂趣")))
42 | // Here, the apostrophe in the string needs to be replaced with two single quotes, '', but we
43 | // also need to escape those quotes with backslashes because this WHERE clause is going to
44 | // eventually be embedded inside of a single-quoted string that's embedded inside of a larger
45 | // Redshift query.
46 | assert(whereClause === """WHERE "test_string" = \'Unicode\'\'s樂趣\'""")
47 | // scalastyle:on
48 | }
49 |
50 | test("buildWhereClause with multiple filters") {
51 | val filters = Seq(
52 | EqualTo("test_bool", true),
53 | // scalastyle:off
54 | EqualTo("test_string", "Unicode是樂趣"),
55 | // scalastyle:on
56 | GreaterThan("test_double", 1000.0),
57 | LessThan("test_double", Double.MaxValue),
58 | GreaterThanOrEqual("test_float", 1.0f),
59 | LessThanOrEqual("test_int", 43),
60 | IsNotNull("test_int"),
61 | IsNull("test_int"))
62 | val whereClause = buildWhereClause(testSchema, filters)
63 | // scalastyle:off
64 | val expectedWhereClause =
65 | """
66 | |WHERE "test_bool" = true
67 | |AND "test_string" = \'Unicode是樂趣\'
68 | |AND "test_double" > 1000.0
69 | |AND "test_double" < 1.7976931348623157E308
70 | |AND "test_float" >= 1.0
71 | |AND "test_int" <= 43
72 | |AND "test_int" IS NOT NULL
73 | |AND "test_int" IS NULL
74 | """.stripMargin.lines.mkString(" ").trim
75 | // scalastyle:on
76 | assert(whereClause === expectedWhereClause)
77 | }
78 |
79 | private val testSchema: StructType = StructType(Seq(
80 | StructField("test_byte", ByteType),
81 | StructField("test_bool", BooleanType),
82 | StructField("test_date", DateType),
83 | StructField("test_double", DoubleType),
84 | StructField("test_float", FloatType),
85 | StructField("test_int", IntegerType),
86 | StructField("test_long", LongType),
87 | StructField("test_short", ShortType),
88 | StructField("test_string", StringType),
89 | StructField("test_timestamp", TimestampType)))
90 |
91 |   /** A new filter subclass which our pushdown logic does not know how to handle */
92 | private case object NewFilter extends Filter {
93 | override def references: Array[String] = Array.empty
94 | }
95 | }
96 |
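// Illustrative sketch (hypothetical helper, not the actual FilterPushdown implementation): it
// reproduces the escaping described in the Unicode test above -- each single quote in a
// pushed-down string literal is doubled for SQL and then backslash-escaped, because the WHERE
// clause is eventually embedded in a single-quoted string inside a larger Redshift query.
object StringLiteralEscapingSketch {
  def embedStringLiteral(s: String): String =
    "\\'" + s.replace("'", "\\'\\'") + "\\'"

  def main(args: Array[String]): Unit = {
    // scalastyle:off
    assert(embedStringLiteral("Unicode's樂趣") == """\'Unicode\'\'s樂趣\'""")
    // scalastyle:on
    println("escaped literal matches the WHERE clause fragment expected by the suite")
  }
}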
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftFileFormat.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.net.URI
20 |
21 | import org.apache.hadoop.conf.Configuration
22 | import org.apache.hadoop.fs.{FileStatus, Path}
23 | import org.apache.hadoop.mapreduce._
24 | import org.apache.hadoop.mapreduce.lib.input.FileSplit
25 | import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
26 | import org.apache.spark.TaskContext
27 | import org.apache.spark.sql.SparkSession
28 | import org.apache.spark.sql.catalyst.InternalRow
29 | import org.apache.spark.sql.execution.datasources._
30 | import org.apache.spark.sql.sources.Filter
31 | import org.apache.spark.sql.types.{DataType, StructType}
32 |
33 | /**
34 | * Internal data source used for reading Redshift UNLOAD files.
35 | *
36 | * This is not intended for public consumption / use outside of this package and therefore
37 | * no API stability is guaranteed.
38 | */
39 | private[redshift] class RedshiftFileFormat extends FileFormat {
40 | override def inferSchema(
41 | sparkSession: SparkSession,
42 | options: Map[String, String],
43 | files: Seq[FileStatus]): Option[StructType] = {
44 | // Schema is provided by caller.
45 | None
46 | }
47 |
48 | override def prepareWrite(
49 | sparkSession: SparkSession,
50 | job: Job,
51 | options: Map[String, String],
52 | dataSchema: StructType): OutputWriterFactory = {
53 | throw new UnsupportedOperationException(s"prepareWrite is not supported for $this")
54 | }
55 |
56 | override def isSplitable(
57 | sparkSession: SparkSession,
58 | options: Map[String, String],
59 | path: Path): Boolean = {
60 | // Our custom InputFormat handles split records properly
61 | true
62 | }
63 |
64 | override def buildReader(
65 | sparkSession: SparkSession,
66 | dataSchema: StructType,
67 | partitionSchema: StructType,
68 | requiredSchema: StructType,
69 | filters: Seq[Filter],
70 | options: Map[String, String],
71 | hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
72 |
73 | require(partitionSchema.isEmpty)
74 | require(filters.isEmpty)
75 | require(dataSchema == requiredSchema)
76 |
77 | val broadcastedConf =
78 | sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
79 |
80 | (file: PartitionedFile) => {
81 | val conf = broadcastedConf.value.value
82 |
83 | val fileSplit = new FileSplit(
84 | new Path(new URI(file.filePath)),
85 | file.start,
86 | file.length,
87 | // TODO: Implement Locality
88 | Array.empty)
89 | val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
90 | val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
91 | val reader = new RedshiftRecordReader
92 | reader.initialize(fileSplit, hadoopAttemptContext)
93 | val iter = new RecordReaderIterator[Array[String]](reader)
94 | // Ensure that the record reader is closed upon task completion. It will ordinarily
95 | // be closed once it is completely iterated, but this is necessary to guard against
96 | // resource leaks in case the task fails or is interrupted.
97 | Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
98 | val converter = Conversions.createRowConverter(requiredSchema,
99 | options.getOrElse("nullString", Parameters.DEFAULT_PARAMETERS("csvnullstring")))
100 | iter.map(converter)
101 | }
102 | }
103 |
104 | override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true
105 | }
106 |
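// Illustrative sketch (assumptions: the directory below already contains Redshift
// UNLOAD-formatted text files, and the `redshiftFile` helper from this package's package object
// is in scope, as used by RedshiftInputFormatSuite). Because inferSchema always returns None,
// the caller must always supply the schema explicitly.
object RedshiftFileFormatUsageSketch {
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField}

    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("RedshiftFileFormatUsageSketch")
      .getOrCreate()
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = true),
      StructField("name", StringType, nullable = true)))
    // Hypothetical path; point it at real UNLOAD output before running.
    val df = spark.sqlContext.redshiftFile("/tmp/redshift-unload-output", schema)
    df.printSchema()
  }
}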
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import com.amazonaws.auth.AWSCredentialsProvider
20 | import com.amazonaws.services.s3.AmazonS3Client
21 | import io.github.spark_redshift_community.spark.redshift
22 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
23 | import org.apache.spark.sql.types.StructType
24 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
25 | import org.slf4j.LoggerFactory
26 |
27 | /**
28 | * Redshift Source implementation for Spark SQL
29 | */
30 | class DefaultSource(
31 | jdbcWrapper: JDBCWrapper,
32 | s3ClientFactory: AWSCredentialsProvider => AmazonS3Client)
33 | extends RelationProvider
34 | with SchemaRelationProvider
35 | with CreatableRelationProvider {
36 |
37 | private val log = LoggerFactory.getLogger(getClass)
38 |
39 | /**
40 | * Default constructor required by Data Source API
41 | */
42 | def this() = this(DefaultJDBCWrapper, awsCredentials => new AmazonS3Client(awsCredentials))
43 |
44 | /**
45 | * Create a new RedshiftRelation instance using parameters from Spark SQL DDL. Resolves the schema
46 |    * using a JDBC connection to the provided URL, which must contain credentials.
47 | */
48 | override def createRelation(
49 | sqlContext: SQLContext,
50 | parameters: Map[String, String]): BaseRelation = {
51 | val params = Parameters.mergeParameters(parameters)
52 | redshift.RedshiftRelation(jdbcWrapper, s3ClientFactory, params, None)(sqlContext)
53 | }
54 |
55 | /**
56 | * Load a RedshiftRelation using user-provided schema, so no inference over JDBC will be used.
57 | */
58 | override def createRelation(
59 | sqlContext: SQLContext,
60 | parameters: Map[String, String],
61 | schema: StructType): BaseRelation = {
62 | val params = Parameters.mergeParameters(parameters)
63 | redshift.RedshiftRelation(jdbcWrapper, s3ClientFactory, params, Some(schema))(sqlContext)
64 | }
65 |
66 | /**
67 | * Creates a Relation instance by first writing the contents of the given DataFrame to Redshift
68 | */
69 | override def createRelation(
70 | sqlContext: SQLContext,
71 | saveMode: SaveMode,
72 | parameters: Map[String, String],
73 | data: DataFrame): BaseRelation = {
74 | val params = Parameters.mergeParameters(parameters)
75 | val table = params.table.getOrElse {
76 | throw new IllegalArgumentException(
77 | "For save operations you must specify a Redshift table name with the 'dbtable' parameter")
78 | }
79 |
80 | def tableExists: Boolean = {
81 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
82 | try {
83 | jdbcWrapper.tableExists(conn, table.toString)
84 | } finally {
85 | conn.close()
86 | }
87 | }
88 |
89 | val (doSave, dropExisting) = saveMode match {
90 | case SaveMode.Append => (true, false)
91 | case SaveMode.Overwrite => (true, true)
92 | case SaveMode.ErrorIfExists =>
93 | if (tableExists) {
94 | sys.error(s"Table $table already exists! (SaveMode is set to ErrorIfExists)")
95 | } else {
96 | (true, false)
97 | }
98 | case SaveMode.Ignore =>
99 | if (tableExists) {
100 | log.info(s"Table $table already exists -- ignoring save request.")
101 | (false, false)
102 | } else {
103 | (true, false)
104 | }
105 | }
106 |
107 | if (doSave) {
108 | val updatedParams = parameters.updated("overwrite", dropExisting.toString)
109 | new RedshiftWriter(jdbcWrapper, s3ClientFactory).saveToRedshift(
110 | sqlContext, data, saveMode, Parameters.mergeParameters(updatedParams))
111 | }
112 |
113 | createRelation(sqlContext, parameters)
114 | }
115 | }
116 |
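// Illustrative usage sketch (hypothetical endpoint, table and bucket names; the option keys are
// the ones exercised elsewhere in this repository, e.g. in AWSCredentialsUtilsSuite and the
// tutorial). It shows the DataFrame-level entry points that end up in the three createRelation
// overloads above: load() for reads and save() with a SaveMode for writes.
object DefaultSourceUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = org.apache.spark.sql.SparkSession.builder()
      .master("local[*]")
      .appName("DefaultSourceUsageSketch")
      .getOrCreate()
    val jdbcUrl = "jdbc:redshift://host:5439/db?user=user&password=password"

    val eventsDF = spark.read
      .format("io.github.spark_redshift_community.spark.redshift")
      .option("url", jdbcUrl)
      .option("dbtable", "event")
      .option("tempdir", "s3a://my-bucket/tmp/")
      .option("forward_spark_s3_credentials", "true")
      .load()

    eventsDF.write
      .format("io.github.spark_redshift_community.spark.redshift")
      .option("url", jdbcUrl)
      .option("dbtable", "event_copy")
      .option("tempdir", "s3a://my-bucket/tmp/")
      .option("forward_spark_s3_credentials", "true")
      .mode(SaveMode.Overwrite) // Overwrite drops and recreates the table; Append keeps existing rows
      .save()
  }
}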
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/InMemoryS3AFileSystemSuite.scala:
--------------------------------------------------------------------------------
1 | package io.github.spark_redshift_community.spark.redshift
2 |
3 | import java.io.FileNotFoundException
4 |
5 | import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path}
6 | import org.scalatest.{FunSuite, Matchers}
7 |
8 | class InMemoryS3AFileSystemSuite extends FunSuite with Matchers {
9 |
10 |   test("Creating a file creates all prefixes in the hierarchy") {
11 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
12 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS")
13 |
14 | inMemoryS3AFileSystem.create(path)
15 |
16 | assert(
17 | inMemoryS3AFileSystem.exists(
18 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS")))
19 |
20 | assert(
21 | inMemoryS3AFileSystem.exists(
22 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/")))
23 |
24 | assert(inMemoryS3AFileSystem.exists(new Path("s3a://test-bucket/temp-dir/")))
25 |
26 | }
27 |
28 | test("List all statuses for a dir") {
29 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
30 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS")
31 | val path2 = new Path(
32 | "s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/manifest.json")
33 |
34 | inMemoryS3AFileSystem.create(path)
35 | inMemoryS3AFileSystem.create(path2)
36 |
37 | assert(
38 | inMemoryS3AFileSystem.listStatus(
39 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328")
40 | ).length == 2)
41 |
42 | assert(
43 | inMemoryS3AFileSystem.listStatus(
44 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328")
45 | ) === Array[FileStatus] (
46 | inMemoryS3AFileSystem.getFileStatus(path2),
47 | inMemoryS3AFileSystem.getFileStatus(path))
48 | )
49 |
50 | assert(
51 | inMemoryS3AFileSystem.listStatus(
52 | new Path("s3a://test-bucket/temp-dir/")).length == 1)
53 | }
54 |
55 | test("getFileStatus for file and dir") {
56 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
57 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS")
58 |
59 | inMemoryS3AFileSystem.create(path)
60 |
61 | assert(inMemoryS3AFileSystem.getFileStatus(path).isDirectory === false)
62 |
63 | val dirPath = new Path(
64 | "s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328")
65 | val dirPathFileStatus = inMemoryS3AFileSystem.getFileStatus(dirPath)
66 | assert(dirPathFileStatus.isDirectory === true)
67 | assert(dirPathFileStatus.isEmptyDirectory === false)
68 |
69 | }
70 |
71 | test("Open a file from InMemoryS3AFileSystem") {
72 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
73 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000")
74 |
75 | inMemoryS3AFileSystem.create(path).write("some data".getBytes())
76 |
77 | var result = new Array[Byte](9)
78 | inMemoryS3AFileSystem.open(path).read(result)
79 |
80 | assert(result === "some data".getBytes())
81 |
82 | }
83 |
84 | test ("delete file from FileSystem") {
85 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
86 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000")
87 |
88 | inMemoryS3AFileSystem.create(path)
89 |
90 | assert(inMemoryS3AFileSystem.exists(path))
91 |
92 | inMemoryS3AFileSystem.delete(path, false)
93 | assert(inMemoryS3AFileSystem.exists(path) === false)
94 |
95 | }
96 |
97 | test("create already existing file throws FileAlreadyExistsException"){
98 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
99 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000")
100 | inMemoryS3AFileSystem.create(path)
101 | assertThrows[FileAlreadyExistsException](inMemoryS3AFileSystem.create(path))
102 | }
103 |
104 | test("getFileStatus can't find file"){
105 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
106 |
107 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000")
108 | assertThrows[FileNotFoundException](inMemoryS3AFileSystem.getFileStatus(path))
109 | }
110 |
111 | test("listStatus can't find path"){
112 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem()
113 |
114 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000")
115 | assertThrows[FileNotFoundException](inMemoryS3AFileSystem.listStatus(path))
116 | }
117 |
118 | }
119 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/AWSCredentialsUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.net.URI
20 |
21 | import com.amazonaws.auth._
22 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters
23 | import org.apache.hadoop.conf.Configuration
24 |
25 | private[redshift] object AWSCredentialsUtils {
26 |
27 | /**
28 | * Generates a credentials string for use in Redshift COPY and UNLOAD statements.
29 | * Favors a configured `aws_iam_role` if available in the parameters.
30 | */
31 | def getRedshiftCredentialsString(
32 | params: MergedParameters,
33 | sparkAwsCredentials: AWSCredentials): String = {
34 |
35 | def awsCredsToString(credentials: AWSCredentials): String = {
36 | credentials match {
37 | case creds: AWSSessionCredentials =>
38 | s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
39 | s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}"
40 | case creds =>
41 | s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
42 | s"aws_secret_access_key=${creds.getAWSSecretKey}"
43 | }
44 | }
45 | if (params.iamRole.isDefined) {
46 | s"aws_iam_role=${params.iamRole.get}"
47 | } else if (params.temporaryAWSCredentials.isDefined) {
48 | awsCredsToString(params.temporaryAWSCredentials.get.getCredentials)
49 | } else if (params.forwardSparkS3Credentials) {
50 | awsCredsToString(sparkAwsCredentials)
51 | } else {
52 | throw new IllegalStateException("No Redshift S3 authentication mechanism was specified")
53 | }
54 | }
55 |
56 | def staticCredentialsProvider(credentials: AWSCredentials): AWSCredentialsProvider = {
57 | new AWSCredentialsProvider {
58 | override def getCredentials: AWSCredentials = credentials
59 | override def refresh(): Unit = {}
60 | }
61 | }
62 |
63 | def load(params: MergedParameters, hadoopConfiguration: Configuration): AWSCredentialsProvider = {
64 | params.temporaryAWSCredentials.getOrElse(loadFromURI(params.rootTempDir, hadoopConfiguration))
65 | }
66 |
67 | private def loadFromURI(
68 | tempPath: String,
69 | hadoopConfiguration: Configuration): AWSCredentialsProvider = {
70 | // scalastyle:off
71 | // A good reference on Hadoop's configuration loading / precedence is
72 | // https://github.com/apache/hadoop/blob/trunk/hadoop-tools/hadoop-aws/src/site/markdown/tools/hadoop-aws/index.md
73 | // scalastyle:on
74 | val uri = new URI(tempPath)
75 | val uriScheme = uri.getScheme
76 |
77 | uriScheme match {
78 | case "s3" | "s3n" | "s3a" =>
79 |         // WARNING: embedding credentials in the URI is potentially unsafe. The test
80 |         // AWSCredentialsInUriIntegrationSuite was removed, so the following may or may not still work.
81 |
82 | // This matches what S3A does, with one exception: we don't support anonymous credentials.
83 | // First, try to parse from URI:
84 | Option(uri.getUserInfo).flatMap { userInfo =>
85 | if (userInfo.contains(":")) {
86 | val Array(accessKey, secretKey) = userInfo.split(":")
87 | Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)))
88 | } else {
89 | None
90 | }
91 | }.orElse {
92 | // Next, try to read from configuration
93 | val accessKeyConfig = if (uriScheme == "s3a") "access.key" else "awsAccessKeyId"
94 | val secretKeyConfig = if (uriScheme == "s3a") "secret.key" else "awsSecretAccessKey"
95 |
96 | val accessKey = hadoopConfiguration.get(s"fs.$uriScheme.$accessKeyConfig", null)
97 | val secretKey = hadoopConfiguration.get(s"fs.$uriScheme.$secretKeyConfig", null)
98 | if (accessKey != null && secretKey != null) {
99 | Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)))
100 | } else {
101 | None
102 | }
103 | }.getOrElse {
104 | // Finally, fall back on the instance profile provider
105 | new DefaultAWSCredentialsProviderChain()
106 | }
107 | case other =>
108 | throw new IllegalArgumentException(s"Unrecognized scheme $other; expected s3, s3n, or s3a")
109 | }
110 | }
111 | }
112 |
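// Illustrative sketch (the parameter keys mirror those used by AWSCredentialsUtilsSuite; the
// key values are placeholders). It prints the COPY/UNLOAD credentials string produced for plain
// access keys versus a configured IAM role, showing that aws_iam_role takes precedence.
object RedshiftCredentialsStringSketch {
  def main(args: Array[String]): Unit = {
    val baseParams = Map(
      "tempdir" -> "s3://foo/bar",
      "dbtable" -> "test_table",
      "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
    val creds = new BasicAWSCredentials("ACCESSKEYID", "SECRET/KEY")

    val keyParams = Parameters.mergeParameters(
      baseParams + ("forward_spark_s3_credentials" -> "true"))
    // Prints: aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY
    println(AWSCredentialsUtils.getRedshiftCredentialsString(keyParams, creds))

    val roleParams = Parameters.mergeParameters(
      baseParams + ("aws_iam_role" -> "arn:aws:iam::123456789000:role/redshift_iam_role"))
    // Prints: aws_iam_role=arn:aws:iam::123456789000:role/redshift_iam_role
    println(AWSCredentialsUtils.getRedshiftCredentialsString(roleParams, creds))
  }
}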
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/MockRedshift.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}
20 |
21 | import org.apache.spark.sql.types.StructType
22 | import org.mockito.Matchers._
23 | import org.mockito.Mockito._
24 | import org.mockito.invocation.InvocationOnMock
25 | import org.mockito.stubbing.Answer
26 | import org.scalatest.Assertions._
27 |
28 | import scala.collection.mutable
29 | import scala.util.matching.Regex
30 |
31 |
32 | /**
33 | * Helper class for mocking Redshift / JDBC in unit tests.
34 | */
35 | class MockRedshift(
36 | jdbcUrl: String,
37 | existingTablesAndSchemas: Map[String, StructType],
38 | jdbcQueriesThatShouldFail: Seq[Regex] = Seq.empty) {
39 |
40 | private[this] val queriesIssued: mutable.Buffer[String] = mutable.Buffer.empty
41 | def getQueriesIssuedAgainstRedshift: Seq[String] = queriesIssued.toSeq
42 |
43 | private[this] val jdbcConnections: mutable.Buffer[Connection] = mutable.Buffer.empty
44 |
45 | val jdbcWrapper: JDBCWrapper = spy(new JDBCWrapper)
46 |
47 | private def createMockConnection(): Connection = {
48 | val conn = mock(classOf[Connection], RETURNS_SMART_NULLS)
49 | jdbcConnections.append(conn)
50 | when(conn.prepareStatement(anyString())).thenAnswer(new Answer[PreparedStatement] {
51 | override def answer(invocation: InvocationOnMock): PreparedStatement = {
52 | val query = invocation.getArguments()(0).asInstanceOf[String]
53 | queriesIssued.append(query)
54 | val mockStatement = mock(classOf[PreparedStatement], RETURNS_SMART_NULLS)
55 | if (jdbcQueriesThatShouldFail.forall(_.findFirstMatchIn(query).isEmpty)) {
56 | when(mockStatement.execute()).thenReturn(true)
57 | when(mockStatement.executeQuery()).thenReturn(
58 | mock(classOf[ResultSet], RETURNS_SMART_NULLS))
59 | } else {
60 | when(mockStatement.execute()).thenThrow(new SQLException(s"Error executing $query"))
61 | when(mockStatement.executeQuery()).thenThrow(new SQLException(s"Error executing $query"))
62 | }
63 | mockStatement
64 | }
65 | })
66 | conn
67 | }
68 |
69 | doAnswer(new Answer[Connection] {
70 | override def answer(invocation: InvocationOnMock): Connection = createMockConnection()
71 | }).when(jdbcWrapper)
72 | .getConnector(any[Option[String]](), same(jdbcUrl), any[Option[(String, String)]]())
73 |
74 | doAnswer(new Answer[Boolean] {
75 | override def answer(invocation: InvocationOnMock): Boolean = {
76 | existingTablesAndSchemas.contains(invocation.getArguments()(1).asInstanceOf[String])
77 | }
78 | }).when(jdbcWrapper).tableExists(any[Connection], anyString())
79 |
80 | doAnswer(new Answer[StructType] {
81 | override def answer(invocation: InvocationOnMock): StructType = {
82 | existingTablesAndSchemas(invocation.getArguments()(1).asInstanceOf[String])
83 | }
84 | }).when(jdbcWrapper).resolveTable(any[Connection], anyString())
85 |
86 | def verifyThatConnectionsWereClosed(): Unit = {
87 | jdbcConnections.foreach { conn =>
88 | verify(conn).close()
89 | }
90 | }
91 |
92 | def verifyThatRollbackWasCalled(): Unit = {
93 | jdbcConnections.foreach { conn =>
94 | verify(conn, atLeastOnce()).rollback()
95 | }
96 | }
97 |
98 | def verifyThatCommitWasNotCalled(): Unit = {
99 | jdbcConnections.foreach { conn =>
100 | verify(conn, never()).commit()
101 | }
102 | }
103 |
104 | def verifyThatExpectedQueriesWereIssued(expectedQueries: Seq[Regex]): Unit = {
105 | expectedQueries.zip(queriesIssued).foreach { case (expected, actual) =>
106 | if (expected.findFirstMatchIn(actual).isEmpty) {
107 | fail(
108 | s"""
109 | |Actual and expected JDBC queries did not match:
110 | |Expected: $expected
111 | |Actual: $actual
112 | """.stripMargin)
113 | }
114 | }
115 | if (expectedQueries.length > queriesIssued.length) {
116 | val missingQueries = expectedQueries.drop(queriesIssued.length)
117 | fail(s"Missing ${missingQueries.length} expected JDBC queries:" +
118 | s"\n${missingQueries.mkString("\n")}")
119 | } else if (queriesIssued.length > expectedQueries.length) {
120 | val extraQueries = queriesIssued.drop(expectedQueries.length)
121 | fail(s"Got ${extraQueries.length} unexpected JDBC queries:\n${extraQueries.mkString("\n")}")
122 | }
123 | }
124 | }
125 |
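// Illustrative usage sketch (hypothetical URL and table name; the calls mirror how the unit
// tests in this package drive MockRedshift): issue a statement through the mocked JDBCWrapper,
// then verify which SQL was sent and that the connection was closed.
object MockRedshiftUsageSketch {
  def demo(): Unit = {
    val jdbcUrl = "jdbc:redshift://host:5439/db?user=user&password=password"
    val mockRedshift = new MockRedshift(jdbcUrl, Map("test_table" -> TestUtils.testSchema))

    val conn = mockRedshift.jdbcWrapper.getConnector(None, jdbcUrl, None)
    try {
      conn.prepareStatement("DROP TABLE IF EXISTS test_table").execute()
    } finally {
      conn.close()
    }

    mockRedshift.verifyThatExpectedQueriesWereIssued(Seq("DROP TABLE IF EXISTS test_table".r))
    mockRedshift.verifyThatConnectionsWereClosed()
  }
}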
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/TestUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.{Date, Timestamp}
20 | import java.time.ZoneId
21 | import java.util.{Calendar, Locale, TimeZone}
22 |
23 | import org.apache.spark.sql.Row
24 | import org.apache.spark.sql.types._
25 |
26 | /**
27 | * Helpers for Redshift tests that require common mocking
28 | */
29 | object TestUtils {
30 |
31 | /**
32 | * Simple schema that includes all data types we support
33 | */
34 | val testSchema: StructType = {
35 | // These column names need to be lowercase; see #51
36 | StructType(Seq(
37 | StructField("testbyte", ByteType),
38 | StructField("testbool", BooleanType),
39 | StructField("testdate", DateType),
40 | StructField("testdouble", DoubleType),
41 | StructField("testfloat", FloatType),
42 | StructField("testint", IntegerType),
43 | StructField("testlong", LongType),
44 | StructField("testshort", ShortType),
45 | StructField("teststring", StringType),
46 | StructField("testtimestamp", TimestampType)))
47 | }
48 |
49 | // scalastyle:off
50 | /**
51 | * Expected parsed output corresponding to the output of testData.
52 | */
53 | val expectedData: Seq[Row] = Seq(
54 | Row(1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498,
55 | 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣",
56 | TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1)),
57 | Row(1.toByte, false, TestUtils.toDate(2015, 6, 2), 0.0, 0.0f, 42,
58 | 1239012341823719L, -13.toShort, "asdf", TestUtils.toTimestamp(2015, 6, 2, 0, 0, 0, 0)),
59 | Row(0.toByte, null, TestUtils.toDate(2015, 6, 3), 0.0, -1.0f, 4141214,
60 | 1239012341823719L, null, "f", TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0)),
61 | Row(0.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, 24.toShort,
62 | "___|_123", null),
63 | Row(List.fill(10)(null): _*))
64 | // scalastyle:on
65 |
66 | /**
67 | * The same as `expectedData`, but with dates and timestamps converted into string format.
68 | * See #39 for context.
69 | */
70 | val expectedDataWithConvertedTimesAndDates: Seq[Row] = expectedData.map { row =>
71 | Row.fromSeq(row.toSeq.map {
72 | case t: Timestamp => Conversions.createRedshiftTimestampFormat().format(t)
73 | case d: Date => Conversions.createRedshiftDateFormat().format(d)
74 | case other => other
75 | })
76 | }
77 |
78 | /**
79 | * Convert date components to a millisecond timestamp
80 | */
81 | def toMillis(
82 | year: Int,
83 | zeroBasedMonth: Int,
84 | date: Int,
85 | hour: Int,
86 | minutes: Int,
87 | seconds: Int,
88 | millis: Int = 0,
89 | timeZone: String = null): Long = {
90 | val calendar = Calendar.getInstance()
91 | calendar.set(year, zeroBasedMonth, date, hour, minutes, seconds)
92 | calendar.set(Calendar.MILLISECOND, millis)
93 | if (timeZone != null) calendar.setTimeZone(TimeZone.getTimeZone(ZoneId.of(timeZone)))
94 | calendar.getTime.getTime
95 | }
96 |
97 | def toNanosTimestamp(
98 | year: Int,
99 | zeroBasedMonth: Int,
100 | date: Int,
101 | hour: Int,
102 | minutes: Int,
103 | seconds: Int,
104 | nanos: Int
105 | ): Timestamp = {
106 | val ts = new Timestamp(
107 | toMillis(
108 | year,
109 | zeroBasedMonth,
110 | date,
111 | hour,
112 | minutes,
113 | seconds
114 | )
115 | )
116 | ts.setNanos(nanos)
117 | ts
118 | }
119 |
120 | /**
121 | * Convert date components to a SQL Timestamp
122 | */
123 | def toTimestamp(
124 | year: Int,
125 | zeroBasedMonth: Int,
126 | date: Int,
127 | hour: Int,
128 | minutes: Int,
129 | seconds: Int,
130 | millis: Int = 0): Timestamp = {
131 | new Timestamp(toMillis(year, zeroBasedMonth, date, hour, minutes, seconds, millis))
132 | }
133 |
134 | /**
135 | * Convert date components to a SQL [[Date]].
136 | */
137 | def toDate(year: Int, zeroBasedMonth: Int, date: Int): Date = {
138 | new Date(toTimestamp(year, zeroBasedMonth, date, 0, 0, 0).getTime)
139 | }
140 |
141 | def withDefaultLocale[T](newDefaultLocale: Locale)(block: => T): T = {
142 | val originalDefaultLocale = Locale.getDefault
143 | try {
144 | Locale.setDefault(newDefaultLocale)
145 | block
146 | } finally {
147 | Locale.setDefault(originalDefaultLocale)
148 | }
149 | }
150 | }
151 |
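// Illustrative sketch (not part of the original sources; the sample values are assumptions
// chosen only to demonstrate the helpers). The month arguments above are zero-based because
// they are passed straight to java.util.Calendar, so month 6 means July.
object TestUtilsDateSketch {
  def main(args: Array[String]): Unit = {
    // Prints 2015-07-01 (in the JVM's default time zone)
    println(TestUtils.toDate(2015, 6, 1))
    // Prints 2015-07-01 00:00:00.001 (the millis argument defaults to 0 when omitted)
    println(TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1))
  }
}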
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/ColumnMetadataSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.SQLException
20 |
21 | import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType}
22 | import org.apache.spark.sql.{Row, SaveMode}
23 |
24 | /**
25 | * End-to-end tests of features which depend on per-column metadata (such as comments, maxlength).
26 | */
27 | class ColumnMetadataSuite extends IntegrationSuiteBase {
28 |
29 | test("configuring maxlength on string columns") {
30 | val tableName = s"configuring_maxlength_on_string_column_$randomSuffix"
31 | try {
32 | val metadata = new MetadataBuilder().putLong("maxlength", 512).build()
33 | val schema = StructType(
34 | StructField("x", StringType, metadata = metadata) :: Nil)
35 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))), schema))
36 | .option("dbtable", tableName)
37 | .mode(SaveMode.ErrorIfExists)
38 | .save()
39 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
40 | checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 512)))
41 | // This append should fail due to the string being longer than the maxlength
42 | intercept[SQLException] {
43 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 513))), schema))
44 | .option("dbtable", tableName)
45 | .mode(SaveMode.Append)
46 | .save()
47 | }
48 | } finally {
49 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
50 | }
51 | }
52 |
53 | test("configuring compression on columns") {
54 | val tableName = s"configuring_compression_on_columns_$randomSuffix"
55 | try {
56 | val metadata = new MetadataBuilder().putString("encoding", "LZO").build()
57 | val schema = StructType(
58 | StructField("x", StringType, metadata = metadata) :: Nil)
59 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema))
60 | .option("dbtable", tableName)
61 | .mode(SaveMode.ErrorIfExists)
62 | .save()
63 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
64 | checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128)))
65 | val encodingDF = sqlContext.read
66 | .format("jdbc")
67 | .option("url", jdbcUrl)
68 | .option("dbtable",
69 | s"""(SELECT "column", lower(encoding) FROM pg_table_def WHERE tablename='$tableName')""")
70 | .load()
71 | checkAnswer(encodingDF, Seq(Row("x", "lzo")))
72 | } finally {
73 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
74 | }
75 | }
76 |
77 | test("configuring comments on columns") {
78 | val tableName = s"configuring_comments_on_columns_$randomSuffix"
79 | try {
80 | val metadata = new MetadataBuilder().putString("description", "Hello Column").build()
81 | val schema = StructType(
82 | StructField("x", StringType, metadata = metadata) :: Nil)
83 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema))
84 | .option("dbtable", tableName)
85 | .option("description", "Hello Table")
86 | .mode(SaveMode.ErrorIfExists)
87 | .save()
88 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
89 | checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128)))
90 | val tableDF = sqlContext.read
91 | .format("jdbc")
92 | .option("url", jdbcUrl)
93 | .option("dbtable", s"(SELECT pg_catalog.obj_description('$tableName'::regclass))")
94 | .load()
95 | checkAnswer(tableDF, Seq(Row("Hello Table")))
96 | val commentQuery =
97 | s"""
98 | |(SELECT c.column_name, pgd.description
99 | |FROM pg_catalog.pg_statio_all_tables st
100 | |INNER JOIN pg_catalog.pg_description pgd
101 | | ON (pgd.objoid=st.relid)
102 | |INNER JOIN information_schema.columns c
103 | | ON (pgd.objsubid=c.ordinal_position AND c.table_name=st.relname)
104 | |WHERE c.table_name='$tableName')
105 | """.stripMargin
106 | val columnDF = sqlContext.read
107 | .format("jdbc")
108 | .option("url", jdbcUrl)
109 | .option("dbtable", commentQuery)
110 | .load()
111 | checkAnswer(columnDF, Seq(Row("x", "Hello Column")))
112 | } finally {
113 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
114 | }
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/SaveModeIntegrationSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
20 | import org.apache.spark.sql.{Row, SaveMode}
21 |
22 | /**
23 | * End-to-end tests of [[SaveMode]] behavior.
24 | */
25 | class SaveModeIntegrationSuite extends IntegrationSuiteBase {
26 | test("SaveMode.Overwrite with schema-qualified table name (#97)") {
27 | withTempRedshiftTable("overwrite_schema_qualified_table_name") { tableName =>
28 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
29 | StructType(StructField("a", IntegerType) :: Nil))
30 | // Ensure that the table exists:
31 | write(df)
32 | .option("dbtable", tableName)
33 | .mode(SaveMode.ErrorIfExists)
34 | .save()
35 | assert(DefaultJDBCWrapper.tableExists(conn, s"PUBLIC.$tableName"))
36 | // Try overwriting that table while using the schema-qualified table name:
37 | write(df)
38 | .option("dbtable", s"PUBLIC.$tableName")
39 | .mode(SaveMode.Overwrite)
40 | .save()
41 | }
42 | }
43 |
44 | test("SaveMode.Overwrite with non-existent table") {
45 | testRoundtripSaveAndLoad(
46 | s"overwrite_non_existent_table$randomSuffix",
47 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
48 | StructType(StructField("a", IntegerType) :: Nil)),
49 | saveMode = SaveMode.Overwrite)
50 | }
51 |
52 | test("SaveMode.Overwrite with existing table") {
53 | withTempRedshiftTable("overwrite_existing_table") { tableName =>
54 | // Create a table to overwrite
55 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
56 | StructType(StructField("a", IntegerType) :: Nil)))
57 | .option("dbtable", tableName)
58 | .mode(SaveMode.ErrorIfExists)
59 | .save()
60 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
61 |
62 | val overwritingDf =
63 | sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)
64 | write(overwritingDf)
65 | .option("dbtable", tableName)
66 | .mode(SaveMode.Overwrite)
67 | .save()
68 |
69 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
70 | checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
71 | }
72 | }
73 |
74 |   // TODO: test an overwrite that fails.
75 |
76 | // TODO (luca|issue #7) make SaveMode work
77 | ignore("Append SaveMode doesn't destroy existing data") {
78 | withTempRedshiftTable("append_doesnt_destroy_existing_data") { tableName =>
79 | createTestDataInRedshift(tableName)
80 | val extraData = Seq(
81 | Row(2.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L,
82 | 24.toShort, "___|_123", null))
83 |
84 | write(sqlContext.createDataFrame(sc.parallelize(extraData), TestUtils.testSchema))
85 | .option("dbtable", tableName)
86 | .mode(SaveMode.Append)
87 | .saveAsTable(tableName)
88 |
89 | checkAnswer(
90 | sqlContext.sql(s"select * from $tableName"),
91 | TestUtils.expectedData ++ extraData)
92 | }
93 | }
94 |
95 | ignore("Respect SaveMode.ErrorIfExists when table exists") {
96 | withTempRedshiftTable("respect_savemode_error_if_exists") { tableName =>
97 | val rdd = sc.parallelize(TestUtils.expectedData)
98 | val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema)
99 | createTestDataInRedshift(tableName) // to ensure that the table already exists
100 |
101 | // Check that SaveMode.ErrorIfExists throws an exception
102 | val e = intercept[Exception] {
103 | write(df)
104 | .option("dbtable", tableName)
105 | .mode(SaveMode.ErrorIfExists)
106 | .saveAsTable(tableName)
107 | }
108 | assert(e.getMessage.contains("exists"))
109 | }
110 | }
111 |
112 | ignore("Do nothing when table exists if SaveMode = Ignore") {
113 | withTempRedshiftTable("do_nothing_when_savemode_ignore") { tableName =>
114 | val rdd = sc.parallelize(TestUtils.expectedData.drop(1))
115 | val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema)
116 | createTestDataInRedshift(tableName) // to ensure that the table already exists
117 | write(df)
118 | .option("dbtable", tableName)
119 | .mode(SaveMode.Ignore)
120 | .saveAsTable(tableName)
121 |
122 | // Check that SaveMode.Ignore does nothing
123 | checkAnswer(
124 | sqlContext.sql(s"select * from $tableName"),
125 | TestUtils.expectedData)
126 | }
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftInputFormatSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2014 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package io.github.spark_redshift_community.spark.redshift
17 |
18 | import java.io.{DataOutputStream, File, FileOutputStream}
19 |
20 | import com.google.common.io.Files
21 | import io.github.spark_redshift_community.spark.redshift.RedshiftInputFormat._
22 | import org.apache.hadoop.conf.Configuration
23 | import org.apache.spark.SparkContext
24 | import org.apache.spark.sql.types._
25 | import org.apache.spark.sql.{Row, SQLContext}
26 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
27 |
28 | import scala.language.implicitConversions
29 |
30 | class RedshiftInputFormatSuite extends FunSuite with BeforeAndAfterAll {
31 |
32 | import RedshiftInputFormatSuite._
33 |
34 | private var sc: SparkContext = _
35 |
36 | override def beforeAll(): Unit = {
37 | super.beforeAll()
38 | sc = new SparkContext("local", this.getClass.getName)
39 | }
40 |
41 | override def afterAll(): Unit = {
42 | sc.stop()
43 | super.afterAll()
44 | }
45 |
46 | private def writeToFile(contents: String, file: File): Unit = {
47 | val bytes = contents.getBytes
48 | val out = new DataOutputStream(new FileOutputStream(file))
49 | out.write(bytes, 0, bytes.length)
50 | out.close()
51 | }
52 |
53 | private def escape(records: Set[Seq[String]], delimiter: Char): String = {
54 | require(delimiter != '\\' && delimiter != '\n')
55 | records.map { r =>
56 | r.map { f =>
57 | f.replace("\\", "\\\\")
58 | .replace("\n", "\\\n")
59 | .replace(delimiter, "\\" + delimiter)
60 | }.mkString(delimiter)
61 | }.mkString("", "\n", "\n")
62 | }
63 |
64 | private final val KEY_BLOCK_SIZE = "fs.local.block.size"
65 |
66 | private final val TAB = '\t'
67 |
68 | private val records = Set(
69 | Seq("a\n", DEFAULT_DELIMITER + "b\\"),
70 | Seq("c", TAB + "d"),
71 | Seq("\ne", "\\\\f"))
72 |
73 | private def withTempDir(func: File => Unit): Unit = {
74 | val dir = Files.createTempDir()
75 | dir.deleteOnExit()
76 | func(dir)
77 | }
78 |
79 | test("default delimiter") {
80 | withTempDir { dir =>
81 | val escaped = escape(records, DEFAULT_DELIMITER)
82 | writeToFile(escaped, new File(dir, "part-00000"))
83 |
84 | val conf = new Configuration
85 | conf.setLong(KEY_BLOCK_SIZE, 4)
86 |
87 | val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat],
88 | classOf[java.lang.Long], classOf[Array[String]], conf)
89 |
90 |       // TODO: Check this assertion - it fails on Travis only; it is unclear why, or what it verifies
91 |       // assert(rdd.partitions.size > records.size) // so there exists at least one empty partition
92 |
93 | val actual = rdd.values.map(_.toSeq).collect()
94 | assert(actual.size === records.size)
95 | assert(actual.toSet === records)
96 | }
97 | }
98 |
99 | test("customized delimiter") {
100 | withTempDir { dir =>
101 | val escaped = escape(records, TAB)
102 | writeToFile(escaped, new File(dir, "part-00000"))
103 |
104 | val conf = new Configuration
105 | conf.setLong(KEY_BLOCK_SIZE, 4)
106 | conf.set(KEY_DELIMITER, TAB)
107 |
108 | val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat],
109 | classOf[java.lang.Long], classOf[Array[String]], conf)
110 |
111 |       // TODO: Check this assertion - it fails on Travis only; it is unclear why, or what it verifies
112 |       // assert(rdd.partitions.size > records.size) // so there exists at least one empty partition
113 |
114 | val actual = rdd.values.map(_.toSeq).collect()
115 | assert(actual.size === records.size)
116 | assert(actual.toSet === records)
117 | }
118 | }
119 |
120 | test("schema parser") {
121 | withTempDir { dir =>
122 | val testRecords = Set(
123 | Seq("a\n", "TX", 1, 1.0, 1000L, 200000000000L),
124 | Seq("b", "CA", 2, 2.0, 2000L, 1231412314L))
125 | val escaped = escape(testRecords.map(_.map(_.toString)), DEFAULT_DELIMITER)
126 | writeToFile(escaped, new File(dir, "part-00000"))
127 |
128 | val sqlContext = new SQLContext(sc)
129 | val expectedSchema = StructType(Seq(
130 | StructField("name", StringType, nullable = true),
131 | StructField("state", StringType, nullable = true),
132 | StructField("id", IntegerType, nullable = true),
133 | StructField("score", DoubleType, nullable = true),
134 | StructField("big_score", LongType, nullable = true),
135 | StructField("some_long", LongType, nullable = true)))
136 |
137 | val df = sqlContext.redshiftFile(dir.toString, expectedSchema)
138 | assert(df.schema === expectedSchema)
139 |
140 | val parsed = df.rdd.map {
141 | case Row(
142 | name: String, state: String, id: Int, score: Double, bigScore: Long, someLong: Long
143 | ) => Seq(name, state, id, score, bigScore, someLong)
144 | }.collect().toSet
145 |
146 | assert(parsed === testRecords)
147 | }
148 | }
149 | }
150 |
151 | object RedshiftInputFormatSuite {
152 | implicit def charToString(c: Char): String = c.toString
153 | }
154 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/AWSCredentialsUtilsSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import com.amazonaws.auth.{AWSSessionCredentials, BasicAWSCredentials, BasicSessionCredentials}
20 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters
21 | import org.apache.hadoop.conf.Configuration
22 | import org.scalatest.FunSuite
23 |
24 | import scala.language.implicitConversions
25 |
26 | class AWSCredentialsUtilsSuite extends FunSuite {
27 |
28 | val baseParams = Map(
29 | "tempdir" -> "s3://foo/bar",
30 | "dbtable" -> "test_schema.test_table",
31 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
32 |
33 | private implicit def string2Params(tempdir: String): MergedParameters = {
34 | Parameters.mergeParameters(baseParams ++ Map(
35 | "tempdir" -> tempdir,
36 | "forward_spark_s3_credentials" -> "true"))
37 | }
38 |
39 | test("credentialsString with regular keys") {
40 | val creds = new BasicAWSCredentials("ACCESSKEYID", "SECRET/KEY/WITH/SLASHES")
41 | val params =
42 | Parameters.mergeParameters(baseParams ++ Map("forward_spark_s3_credentials" -> "true"))
43 | assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) ===
44 | "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY/WITH/SLASHES")
45 | }
46 |
47 | test("credentialsString with STS temporary keys") {
48 | val params = Parameters.mergeParameters(baseParams ++ Map(
49 | "temporary_aws_access_key_id" -> "ACCESSKEYID",
50 | "temporary_aws_secret_access_key" -> "SECRET/KEY",
51 | "temporary_aws_session_token" -> "SESSION/Token"))
52 | assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
53 | "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY;token=SESSION/Token")
54 | }
55 |
56 | test("Configured IAM roles should take precedence") {
57 | val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token")
58 | val iamRole = "arn:aws:iam::123456789000:role/redshift_iam_role"
59 | val params = Parameters.mergeParameters(baseParams ++ Map("aws_iam_role" -> iamRole))
60 | assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
61 | s"aws_iam_role=$iamRole")
62 | }
63 |
64 | test("AWSCredentials.load() STS temporary keys should take precedence") {
65 | val conf = new Configuration(false)
66 | conf.set("fs.s3.awsAccessKeyId", "CONFID")
67 | conf.set("fs.s3.awsSecretAccessKey", "CONFKEY")
68 |
69 | val params = Parameters.mergeParameters(baseParams ++ Map(
70 | "tempdir" -> "s3://URIID:URIKEY@bucket/path",
71 | "temporary_aws_access_key_id" -> "key_id",
72 | "temporary_aws_secret_access_key" -> "secret",
73 | "temporary_aws_session_token" -> "token"
74 | ))
75 |
76 | val creds = AWSCredentialsUtils.load(params, conf).getCredentials
77 | assert(creds.isInstanceOf[AWSSessionCredentials])
78 | assert(creds.getAWSAccessKeyId === "key_id")
79 | assert(creds.getAWSSecretKey === "secret")
80 | assert(creds.asInstanceOf[AWSSessionCredentials].getSessionToken === "token")
81 | }
82 |
83 | test("AWSCredentials.load() credentials precedence for s3:// URIs") {
84 | val conf = new Configuration(false)
85 | conf.set("fs.s3.awsAccessKeyId", "CONFID")
86 | conf.set("fs.s3.awsSecretAccessKey", "CONFKEY")
87 |
88 | {
89 | val creds = AWSCredentialsUtils.load("s3://URIID:URIKEY@bucket/path", conf).getCredentials
90 | assert(creds.getAWSAccessKeyId === "URIID")
91 | assert(creds.getAWSSecretKey === "URIKEY")
92 | }
93 |
94 | {
95 | val creds = AWSCredentialsUtils.load("s3://bucket/path", conf).getCredentials
96 | assert(creds.getAWSAccessKeyId === "CONFID")
97 | assert(creds.getAWSSecretKey === "CONFKEY")
98 | }
99 |
100 | }
101 |
102 | test("AWSCredentials.load() credentials precedence for s3n:// URIs") {
103 | val conf = new Configuration(false)
104 | conf.set("fs.s3n.awsAccessKeyId", "CONFID")
105 | conf.set("fs.s3n.awsSecretAccessKey", "CONFKEY")
106 |
107 | {
108 | val creds = AWSCredentialsUtils.load("s3n://URIID:URIKEY@bucket/path", conf).getCredentials
109 | assert(creds.getAWSAccessKeyId === "URIID")
110 | assert(creds.getAWSSecretKey === "URIKEY")
111 | }
112 |
113 | {
114 | val creds = AWSCredentialsUtils.load("s3n://bucket/path", conf).getCredentials
115 | assert(creds.getAWSAccessKeyId === "CONFID")
116 | assert(creds.getAWSSecretKey === "CONFKEY")
117 | }
118 |
119 | }
120 |
121 | test("AWSCredentials.load() credentials precedence for s3a:// URIs") {
122 | val conf = new Configuration(false)
123 | conf.set("fs.s3a.access.key", "CONFID")
124 | conf.set("fs.s3a.secret.key", "CONFKEY")
125 |
126 | {
127 | val creds = AWSCredentialsUtils.load("s3a://URIID:URIKEY@bucket/path", conf).getCredentials
128 | assert(creds.getAWSAccessKeyId === "URIID")
129 | assert(creds.getAWSSecretKey === "URIKEY")
130 | }
131 |
132 | {
133 | val creds = AWSCredentialsUtils.load("s3a://bucket/path", conf).getCredentials
134 | assert(creds.getAWSAccessKeyId === "CONFID")
135 | assert(creds.getAWSSecretKey === "CONFKEY")
136 | }
137 |
138 | }
139 | }
140 |
--------------------------------------------------------------------------------
/tutorial/SparkRedshiftTutorial.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift.tutorial
18 | import org.apache.spark.sql.{SQLContext, SaveMode}
19 | import org.apache.spark.{SparkConf, SparkContext}
20 |
21 |
22 | /**
23 | * Source code accompanying the spark-redshift tutorial.
24 | * The following parameters need to be passed
25 | * 1. AWS Access Key
26 | * 2. AWS Secret Access Key
27 | * 3. Redshift Database Name
28 | * 4. Redshift UserId
29 | * 5. Redshift Password
30 | * 6. Redshift URL (Ex. swredshift.czac2vcs84ci.us-east-1.redshift.amazonaws.com:5439)
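 *
 * Example invocation (a sketch only; the jar name and argument values are hypothetical):
 *   spark-submit --class io.github.spark_redshift_community.spark.redshift.tutorial.SparkRedshiftTutorial \
 *     spark-redshift-tutorial.jar <awsAccessKey> <awsSecretKey> <rsDbName> <rsUser> <rsPassword> \
 *     swredshift.czac2vcs84ci.us-east-1.redshift.amazonaws.com:5439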
31 | */
32 | object SparkRedshiftTutorial {
33 | /*
34 | * For Windows Users only
35 | * 1. Download contents from link
36 | * https://github.com/srccodes/hadoop-common-2.2.0-bin/archive/master.zip
37 | * 2. Unzip the file in step 1 into your %HADOOP_HOME%/bin.
38 | * 3. Pass the system property -Dhadoop.home.dir=%HADOOP_HOME%/bin, where %HADOOP_HOME%
39 | * must be an absolute (not relative) path
40 | */
41 |
42 | def main(args: Array[String]): Unit = {
43 |
44 | if (args.length < 6) {
45 | println("Expected 6 parameters but got only " + args.length)
46 | println("parameters needed - $awsAccessKey $awsSecretKey $rsDbName $rsUser $rsPassword $rsURL")
47 | }
48 | val awsAccessKey = args(0)
49 | val awsSecretKey = args(1)
50 | val rsDbName = args(2)
51 | val rsUser = args(3)
52 | val rsPassword = args(4)
53 | //Sample Redshift URL is swredshift.czac2vcs84ci.us-east-1.redshift.amazonaws.com:5439
54 | val rsURL = args(5)
55 | val jdbcURL = s"""jdbc:redshift://$rsURL/$rsDbName?user=$rsUser&password=$rsPassword"""
56 | println(jdbcURL)
57 | val sc = new SparkContext(new SparkConf().setAppName("SparkSQL").setMaster("local"))
58 |
59 | val tempS3Dir = "s3n://redshift-spark/temp/"
60 | sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", awsAccessKey)
61 | sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", awsSecretKey)
62 |
63 | val sqlContext = new SQLContext(sc)
64 |
65 | //Load from a table
66 | val eventsDF = sqlContext.read
67 | .format("io.github.spark_redshift_community.spark.redshift")
68 | .option("url", jdbcURL)
69 | .option("tempdir", tempS3Dir)
70 | .option("dbtable", "event")
71 | .load()
72 | eventsDF.show()
73 | eventsDF.printSchema()
74 |
75 | //Load from a query
76 | val salesQuery = """SELECT salesid, listid, sellerid, buyerid,
77 | eventid, dateid, qtysold, pricepaid, commission
78 | FROM sales
79 | ORDER BY saletime DESC LIMIT 10000"""
80 | val salesDF = sqlContext.read
81 | .format("io.github.spark_redshift_community.spark.redshift")
82 | .option("url", jdbcURL)
83 | .option("tempdir", tempS3Dir)
84 | .option("query", salesQuery)
85 | .load()
86 | salesDF.show()
87 |
88 | val eventQuery = "SELECT * FROM event"
89 | val eventDF = sqlContext.read
90 | .format("io.github.spark_redshift_community.spark.redshift")
91 | .option("url", jdbcURL)
92 | .option("tempdir", tempS3Dir)
93 | .option("query", eventQuery)
94 | .load()
95 |
96 | /*
97 | * Register 'event' table as temporary table 'myevent'
98 | * so that it can be queried via sqlContext.sql
99 | */
100 | eventsDF.registerTempTable("myevent")
101 |
102 | //Save to a Redshift table from a table registered in Spark
103 |
104 | /*
105 | * Create a new table redshiftevent after dropping any existing redshiftevent table
106 | * and write event records with event id less than 1000
107 | */
108 | sqlContext.sql("SELECT * FROM myevent WHERE eventid<=1000").withColumnRenamed("eventid", "id")
109 | .write.format("io.github.spark_redshift_community.spark.redshift")
110 | .option("url", jdbcURL)
111 | .option("tempdir", tempS3Dir)
112 | .option("dbtable", "redshiftevent")
113 | .mode(SaveMode.Overwrite)
114 | .save()
115 |
116 | /*
117 | * Append to an existing table redshiftevent if it exists or create a new one if it does not
118 | * exist and write event records with event id greater than 1000
119 | */
120 | sqlContext.sql("SELECT * FROM myevent WHERE eventid>1000").withColumnRenamed("eventid", "id")
121 | .write.format("io.github.spark_redshift_community.spark.redshift")
122 | .option("url", jdbcURL)
123 | .option("tempdir", tempS3Dir)
124 | .option("dbtable", "redshiftevent")
125 | .mode(SaveMode.Append)
126 | .save()
127 |
128 | /** Demonstration of interoperability */
129 | val salesAGGQuery = """SELECT sales.eventid AS id, SUM(qtysold) AS totalqty, SUM(pricepaid) AS salesamt
130 | FROM sales
131 | GROUP BY (sales.eventid)
132 | """
133 | val salesAGGDF = sqlContext.read
134 | .format("io.github.spark_redshift_community.spark.redshift")
135 | .option("url", jdbcURL)
136 | .option("tempdir", tempS3Dir)
137 | .option("query", salesAGGQuery)
138 | .load()
139 | salesAGGDF.registerTempTable("salesagg")
140 |
141 | /*
142 | * Join two DataFrame instances. Each could be sourced from any
143 | * compatible Data Source
144 | */
145 | val salesAGGDF2 = salesAGGDF.join(eventsDF, salesAGGDF("id") === eventsDF("eventid"))
146 | .select("id", "eventname", "totalqty", "salesamt")
147 |
148 | salesAGGDF2.registerTempTable("redshift_sales_agg")
149 |
150 | sqlContext.sql("SELECT * FROM redshift_sales_agg")
151 | .write.format("io.github.spark_redshift_community.spark.redshift")
152 | .option("url", jdbcURL)
153 | .option("tempdir", tempS3Dir)
154 | .option("dbtable", "redshift_sales_agg")
155 | .mode(SaveMode.Overwrite)
156 | .save()
157 | }
158 | }
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/Conversions.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.Timestamp
20 | import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
21 | import java.time.{DateTimeException, LocalDateTime, ZonedDateTime}
22 | import java.time.format.DateTimeFormatter
23 | import java.util.Locale
24 |
25 | import org.apache.spark.sql.catalyst.InternalRow
26 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
27 | import org.apache.spark.sql.catalyst.expressions.GenericRow
28 | import org.apache.spark.sql.types._
29 |
30 | /**
31 | * Data type conversions for Redshift unloaded data
32 | */
33 | private[redshift] object Conversions {
34 |
35 | /**
36 | * From the DateTimeFormatter docs (Java 8):
37 | * "A formatter created from a pattern can be used as many times as necessary,
38 | * it is immutable and is thread-safe."
39 | */
40 | private val formatter = DateTimeFormatter.ofPattern(
41 | "yyyy-MM-dd HH:mm:ss[.SSSSSS][.SSSSS][.SSSS][.SSS][.SS][.S][X]")
42 |
43 | /**
44 | * Parse a boolean using Redshift's UNLOAD bool syntax
45 | */
46 | private def parseBoolean(s: String): Boolean = {
47 | if (s == "t") true
48 | else if (s == "f") false
49 | else throw new IllegalArgumentException(s"Expected 't' or 'f' but got '$s'")
50 | }
51 |
52 | /**
53 | * Formatter for parsing decimals unloaded from Redshift.
54 | *
55 | * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this
56 | * DecimalFormat across threads.
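 *
 * e.g. createRedshiftDecimalFormat().parse("151.20") yields a java.math.BigDecimal,
 * independent of the JVM's default Locale (see ConversionsSuite).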
57 | */
58 | def createRedshiftDecimalFormat(): DecimalFormat = {
59 | val format = new DecimalFormat()
60 | format.setParseBigDecimal(true)
61 | format.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US))
62 | format
63 | }
64 |
65 | /**
66 | * Formatter for parsing strings exported from Redshift DATE columns.
67 | *
68 | * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this
69 | * SimpleDateFormat across threads.
70 | */
71 | def createRedshiftDateFormat(): SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd")
72 |
73 | /**
74 | * Formatter for formatting timestamps for insertion into Redshift TIMESTAMP columns.
75 | *
76 | * This formatter should not be used to parse timestamps returned from Redshift UNLOAD commands;
77 | * instead, use [[Timestamp.valueOf()]].
78 | *
79 | * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this
80 | * SimpleDateFormat across threads.
81 | */
82 | def createRedshiftTimestampFormat(): SimpleDateFormat = {
83 | new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS")
84 | }
85 |
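/**
 * Parses a timestamp string unloaded from Redshift, e.g. "2014-03-01 00:00:01.123456"
 * (no zone, read as a LocalDateTime) or "2014-03-01 00:00:01.123-03" (timestamptz,
 * converted to an instant). See ConversionsSuite for the precisions covered by the
 * pattern above.
 */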
86 | def parseRedshiftTimestamp(s: String): Timestamp = {
87 | val temporalAccessor = formatter.parse(s)
88 |
89 | try {
90 | // timestamptz
91 | Timestamp.from(ZonedDateTime.from(temporalAccessor).toInstant)
92 | }
93 | catch {
94 | // Case timestamp without timezone
95 | case e: DateTimeException =>
96 | Timestamp.valueOf(LocalDateTime.from(temporalAccessor))
97 | }
98 | }
99 |
100 | /**
101 | * Return a function that will convert arrays of strings conforming to the given schema to Rows.
102 | *
103 | * Note that instances of this function are NOT thread-safe.
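 *
 * A test-side usage sketch (mirroring ConversionsSuite) composes the converter with a
 * RowEncoder to obtain external Rows:
 * {{{
 * Conversions.createRowConverter(schema, Parameters.DEFAULT_PARAMETERS("csvnullstring"))
 *   .andThen(RowEncoder(schema).resolveAndBind().fromRow)
 * }}}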
104 | */
105 | def createRowConverter(schema: StructType, nullString: String): Array[String] => InternalRow = {
106 | val dateFormat = createRedshiftDateFormat()
107 | val decimalFormat = createRedshiftDecimalFormat()
108 | val conversionFunctions: Array[String => Any] = schema.fields.map { field =>
109 | field.dataType match {
110 | case ByteType => (data: String) => java.lang.Byte.parseByte(data)
111 | case BooleanType => (data: String) => parseBoolean(data)
112 | case DateType => (data: String) => new java.sql.Date(dateFormat.parse(data).getTime)
113 | case DoubleType => (data: String) => data match {
114 | case "nan" => Double.NaN
115 | case "inf" => Double.PositiveInfinity
116 | case "-inf" => Double.NegativeInfinity
117 | case _ => java.lang.Double.parseDouble(data)
118 | }
119 | case FloatType => (data: String) => data match {
120 | case "nan" => Float.NaN
121 | case "inf" => Float.PositiveInfinity
122 | case "-inf" => Float.NegativeInfinity
123 | case _ => java.lang.Float.parseFloat(data)
124 | }
125 | case dt: DecimalType =>
126 | (data: String) => decimalFormat.parse(data).asInstanceOf[java.math.BigDecimal]
127 | case IntegerType => (data: String) => java.lang.Integer.parseInt(data)
128 | case LongType => (data: String) => java.lang.Long.parseLong(data)
129 | case ShortType => (data: String) => java.lang.Short.parseShort(data)
130 | case StringType => (data: String) => data
131 | case TimestampType => (data: String) => parseRedshiftTimestamp(data)
132 | case _ => (data: String) => data
133 | }
134 | }
135 | // As a performance optimization, re-use the same mutable row / array:
136 | val converted: Array[Any] = Array.fill(schema.length)(null)
137 | val externalRow = new GenericRow(converted)
138 | val encoder = RowEncoder(schema)
139 | (inputRow: Array[String]) => {
140 | var i = 0
141 | while (i < schema.length) {
142 | val data = inputRow(i)
143 | converted(i) = if ((data == null || data == nullString) ||
144 | (data.isEmpty && schema.fields(i).dataType != StringType)) {
145 | null
146 | }
147 | else if (data.isEmpty) {
148 | ""
149 | }
150 | else {
151 | conversionFunctions(i)(data)
152 | }
153 | i += 1
154 | }
155 | encoder.toRow(externalRow)
156 | }
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
/CHANGELOG:
--------------------------------------------------------------------------------
1 | # spark-redshift Changelog
2 |
3 | ## 4.1.0
4 |
5 | - Add `include_column_list` parameter
6 |
7 | ## 4.0.2
8 |
9 | - Trim SQL text for preactions and postactions, to fix a bug with empty SQL queries.
10 |
11 | ## 4.0.1
12 |
13 | - Fix bug when parsing microseconds from Redshift
14 |
15 | ## 4.0.0
16 |
17 | This major release makes spark-redshift compatible with Spark 2.4. It has been tested in production.
18 |
19 | While upgrading the package we dropped some features due to time constraints.
20 |
21 | - Support for hadoop 1.x has been dropped.
22 | - STS and IAM authentication support has been dropped.
23 | - postgresql driver tests are inactive.
24 | - SaveMode tests (and possibly the functionality) are broken. This is a bit concerning, but we are not sure the functionality is used,
25 |   and fixing them did not make it into this version (spark-snowflake removed them too).
26 | - S3Native has been deprecated. We created an InMemoryS3AFileSystem to test S3A.
27 |
28 | ## 4.0.0-SNAPSHOT
29 | - SNAPSHOT version to test publishing to Maven Central.
30 |
31 | ## 4.0.0-preview20190730 (2019-07-30)
32 |
33 | - The library is tested in production using Spark 2.4
34 | - RedshiftSourceSuite is again among the Scala test suites.
35 |
36 | ## 4.0.0-preview20190715 (2019-07-15)
37 |
38 | Move to pre-4.0.0 'preview' releases rather than SNAPSHOT
39 |
40 | ## 4.0.0-SNAPSHOT-20190710 (2019-07-10)
41 |
42 | Remove AWSCredentialsInUriIntegrationSuite test and require s3a path in CrossRegionIntegrationSuite.scala
43 |
44 | ## 4.0.0-SNAPSHOT-20190627 (2019-06-27)
45 |
46 | Baseline SNAPSHOT version working with 2.4
47 |
48 | #### Deprecation
49 | In order to get this baseline snapshot out, we dropped some features and package versions,
50 | and disabled some tests.
51 | Some of these changes are temporary, others - such as dropping hadoop 1.x - are meant to stay.
52 |
53 | Our intent is to do the best job possible supporting the minimal set of features
54 | that the community needs. Other non-essential features may be dropped before the
55 | first non-snapshot release.
56 | The community's feedback and contributions are vitally important.
57 |
58 |
59 | * Support for hadoop 1.x has been dropped.
60 | * STS and IAM authentication support has been dropped (as have the corresponding tests).
61 | * postgresql driver tests are inactive.
62 | * SaveMode tests (and possibly the functionality) are broken. This is a bit concerning, but we are not sure the functionality is used, and fixing them did not make it into this version (spark-snowflake removed them too).
63 | * S3Native has been deprecated. It is our intention to phase it out of this repo. The test util ‘inMemoryFilesystem’ is no longer present, so the RedshiftSourceSuite test suite lost its major mock object and had to be removed. We plan to re-write it using s3a.
64 |
65 | #### Commits changelog
66 | - 5b0f949 (HEAD -> master, origin_community/master) Merge pull request #6 from spark-redshift-community/luca-spark-2.4
67 | - 25acded (origin_community/luca-spark-2.4, origin/luca-spark-2.4, luca-spark-2.4) Revert sbt scripts to an older version
68 | - 866d4fd Moving to external github issues - rename spName to spark-redshift-community
69 | - 094cc15 remove in Memory FileSystem class and clean up comments in the sbt build file
70 | - 0666bc6 aws_variables.env gitignored
71 | - f3bbdb7 sbt assembly the package into a fat jar - found the perfect coordination between different libraries versions! Tests pass and can compile spark-on-paasta and spark successfullygit add src/ project/
72 | - b1fa3f6 Ignoring a bunch of tests as did snowflake - close to have a green build to try out
73 | - 95cdf94 Removing conn.commit() everywhere - got 88% of integration tests to run - fix for STS token aws access in progress
74 | - da10897 Compiling - managed to run tests but they mostly fail
75 | - 0fe37d2 Compiles with spark 2.4.0 - amazon unmarshal error
76 | - ea5da29 force spark.avro - hadoop 2.7.7 and awsjavasdk downgraded
77 | - 834f0d6 Upgraded jackson by excluding it in aws
78 | - 90581a8 Fixed NewFilter - including hadoop-aws - s3n test is failing
79 | - 50dfd98 (tag: v3.0.0, tag: gtig, origin/master, origin/HEAD) Merge pull request #5 from Yelp/fdc_first-version
80 | - fbb58b3 (origin/fdc_first-version) First Yelp release
81 | - 0d2a130 Merge pull request #4 from Yelp/fdc_DATALAKE-4899_empty-string-to-null
82 | - 689635c (origin/fdc_DATALAKE-4899_empty-string-to-null) Fix File line length exceeds 100 characters
83 | - d06fe3b Fix scalastyle
84 | - e15ccb5 Fix parenthesis
85 | - d16317e Fix indentation
86 | - 475e7a1 Fix convertion bit and test
87 | - 3ae6a9b Fix Empty string is converted to null
88 | - 967dddb Merge pull request #3 from Yelp/fdc_DATALAKE-486_avoid-log-creds
89 | - 040b4a9 Merge pull request #2 from Yelp/fdc_DATALAKE-488_cleanup-fix-double-to-float
90 | - 58fb829 (origin/fdc_DATALAKE-488_cleanup-fix-double-to-float) Fix test
91 | - 3384333 Add bit and default types
92 | - 3230aaa (origin/fdc_DATALAKE-486_avoid-log-creds) Avoid logging creds. log sql query statement only
93 | - ab8124a Fix double type to float and cleanup
94 | - cafa05f Merge pull request #1 from Yelp/fdc_DATALAKE-563_remove-itests-from-public
95 | - a3a39a2 (origin/fdc_DATALAKE-563_remove-itests-from-public) Remove itests. Fix jdbc url. Update Redshift jdbc driver
96 | - 184b442 Make the note more obvious.
97 | - 717a4ad Notes about inlining this in Databricks Runtime.
98 | - 8adfe95 (origin/fdc_first-test-branch-2) Fix decimal precision loss when reading the results of a Redshift query
99 | - 8da2d92 Test infra housekeeping: reduce SBT memory, update plugin versions, update SBT
100 | - 79bac6d Add instructions on using JitPack master SNAPSHOT builds
101 | - 7a4a08e Use PreparedStatement.getMetaData() to retrieve Redshift query schemas
102 | - b4c6053 Wrap and re-throw Await.result exceptions in order to capture full stacktrace
103 | - 1092c7c Update version in README to 3.0.0-preview1
104 | - 320748a Setting version to 3.0.0-SNAPSHOT
105 | - a28832b (tag: v3.0.0-preview1, origin/fdc_30-review) Setting version to 3.0.0-preview1
106 | - 8afde06 Make Redshift to S3 authentication mechanisms mutually exclusive
107 | - 9ed18a0 Use FileFormat-based data source instead of HadoopRDD for reads
108 | - 6cc49da Add option to use CSV as an intermediate data format during writes
109 | - d508d3e Add documentation and warnings related to using different regions for Redshift and S3
110 | - cdf192a Break RedshiftIntegrationSuite into smaller suites; refactor to remove some redundancy
111 | - bdf4462 Pass around AWSCredentialProviders instead of AWSCredentials
112 | - 51c29e6 Add codecov.yml file.
113 | - a9963da Update AWSCredentialUtils to be uniform between URI schemes.
114 |
115 | ## 3.0.0-SNAPSHOT (2017-11-08)
116 |
117 | Databricks spark-redshift pre-fork, changes not tracked.
118 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/ConversionsSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.Timestamp
20 | import java.util.Locale
21 |
22 | import org.apache.spark.sql.Row
23 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
24 | import org.apache.spark.sql.types._
25 | import org.scalatest.FunSuite
26 |
27 | /**
28 | * Unit test for data type conversions
29 | */
30 | class ConversionsSuite extends FunSuite {
31 |
32 | private def createRowConverter(schema: StructType) = {
33 | Conversions.createRowConverter(schema, Parameters.DEFAULT_PARAMETERS("csvnullstring"))
34 | .andThen(RowEncoder(schema).resolveAndBind().fromRow)
35 | }
36 |
37 | test("Data should be correctly converted") {
38 | val convertRow = createRowConverter(TestUtils.testSchema)
39 | val doubleMin = Double.MinValue.toString
40 | val longMax = Long.MaxValue.toString
41 | // scalastyle:off
42 | val unicodeString = "Unicode是樂趣"
43 | // scalastyle:on
44 |
45 | val timestampWithMillis = "2014-03-01 00:00:01.123"
46 | val expectedTimestampMillis = TestUtils.toMillis(2014, 2, 1, 0, 0, 1, 123)
47 |
48 | val expectedDateMillis = TestUtils.toMillis(2015, 6, 1, 0, 0, 0)
49 |
50 | val convertedRow = convertRow(
51 | Array("1", "t", "2015-07-01", doubleMin, "1.0", "42",
52 | longMax, "23", unicodeString, timestampWithMillis))
53 |
54 | val expectedRow = Row(1.asInstanceOf[Byte], true, new Timestamp(expectedDateMillis),
55 | Double.MinValue, 1.0f, 42, Long.MaxValue, 23.toShort, unicodeString,
56 | new Timestamp(expectedTimestampMillis))
57 |
58 | assert(convertedRow == expectedRow)
59 | }
60 |
61 | test("Regression test for parsing timestamptz (bug #25 in spark_redshift_community)") {
62 | val rowConverter = createRowConverter(
63 | StructType(Seq(StructField("timestampWithTimezone", TimestampType))))
64 |
65 | // when converting to timestamp, we discard the TZ info.
66 | val timestampWithTimezone = "2014-03-01 00:00:01.123-03"
67 |
68 | val expectedTimestampWithTimezoneMillis = TestUtils.toMillis(
69 | 2014, 2, 1, 0, 0, 1, 123, "-03")
70 |
71 | val convertedRow = rowConverter(Array(timestampWithTimezone))
72 | val expectedRow = Row(new Timestamp(expectedTimestampWithTimezoneMillis))
73 |
74 | assert(convertedRow == expectedRow)
75 | }
76 |
77 | test("Row conversion handles null values") {
78 | val convertRow = createRowConverter(TestUtils.testSchema)
79 | val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
80 | assert(convertRow(emptyRow) === Row(emptyRow: _*))
81 | }
82 |
83 | test("Booleans are correctly converted") {
84 | val convertRow = createRowConverter(StructType(Seq(StructField("a", BooleanType))))
85 | assert(convertRow(Array("t")) === Row(true))
86 | assert(convertRow(Array("f")) === Row(false))
87 | assert(convertRow(Array(null)) === Row(null))
88 | intercept[IllegalArgumentException] {
89 | convertRow(Array("not-a-boolean"))
90 | }
91 | }
92 |
93 | test("timestamp conversion handles millisecond-level precision (regression test for #214)") {
94 | val schema = StructType(Seq(StructField("a", TimestampType)))
95 | val convertRow = createRowConverter(schema)
96 | Seq(
97 | "2014-03-01 00:00:01.123456" ->
98 | TestUtils.toNanosTimestamp(2014, 2, 1, 0, 0, 1, nanos = 123456000),
99 | "2014-03-01 00:00:01.12345" ->
100 | TestUtils.toNanosTimestamp(2014, 2, 1, 0, 0, 1, nanos = 123450000),
101 | "2014-03-01 00:00:01.1234" ->
102 | TestUtils.toNanosTimestamp(2014, 2, 1, 0, 0, 1, nanos = 123400000),
103 | "2014-03-01 00:00:01" ->
104 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 1000),
105 | "2014-03-01 00:00:01.000" ->
106 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 1000),
107 | "2014-03-01 00:00:00.1" ->
108 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 100),
109 | "2014-03-01 00:00:00.10" ->
110 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 100),
111 | "2014-03-01 00:00:00.100" ->
112 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 100),
113 | "2014-03-01 00:00:00.01" ->
114 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 10),
115 | "2014-03-01 00:00:00.010" ->
116 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 10),
117 | "2014-03-01 00:00:00.001" ->
118 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 1)
119 | ).foreach { case (timestampString, expectedTime) =>
120 | withClue(s"timestamp string is '$timestampString'") {
121 | val convertedRow = convertRow(Array(timestampString))
122 | val convertedTimestamp = convertedRow.get(0).asInstanceOf[Timestamp]
123 | assert(convertedTimestamp === expectedTime)
124 | }
125 | }
126 | }
127 |
128 | test("RedshiftDecimalFormat is locale-insensitive (regression test for #243)") {
129 | for (locale <- Seq(Locale.US, Locale.GERMAN, Locale.UK)) {
130 | withClue(s"locale = $locale") {
131 | TestUtils.withDefaultLocale(locale) {
132 | val decimalFormat = Conversions.createRedshiftDecimalFormat()
133 | val parsed = decimalFormat.parse("151.20").asInstanceOf[java.math.BigDecimal]
134 | assert(parsed.doubleValue() === 151.20)
135 | }
136 | }
137 | }
138 | }
139 |
140 | test("Row conversion properly handles NaN and Inf float values (regression test for #261)") {
141 | val convertRow = createRowConverter(StructType(Seq(StructField("a", FloatType))))
142 | assert(java.lang.Float.isNaN(convertRow(Array("nan")).getFloat(0)))
143 | assert(convertRow(Array("inf")) === Row(Float.PositiveInfinity))
144 | assert(convertRow(Array("-inf")) === Row(Float.NegativeInfinity))
145 | }
146 |
147 | test("Row conversion properly handles NaN and Inf double values (regression test for #261)") {
148 | val convertRow = createRowConverter(StructType(Seq(StructField("a", DoubleType))))
149 | assert(java.lang.Double.isNaN(convertRow(Array("nan")).getDouble(0)))
150 | assert(convertRow(Array("inf")) === Row(Double.PositiveInfinity))
151 | assert(convertRow(Array("-inf")) === Row(Double.NegativeInfinity))
152 | }
153 | }
154 |
--------------------------------------------------------------------------------
/src/test/java/io/github/spark_redshift_community/spark/redshift/InMemoryS3AFileSystem.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one
3 | * or more contributor license agreements. See the NOTICE file
4 | * distributed with this work for additional information
5 | * regarding copyright ownership. The ASF licenses this file
6 | * to you under the Apache License, Version 2.0 (the
7 | * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 |
19 | package io.github.spark_redshift_community.spark.redshift;
20 |
21 | import java.io.*;
22 | import java.net.URI;
23 | import java.util.*;
24 |
25 | import org.apache.hadoop.fs.*;
26 | import org.apache.hadoop.fs.permission.FsPermission;
27 | import org.apache.hadoop.fs.s3a.S3AFileStatus;
28 |
29 | import org.apache.hadoop.conf.Configuration;
30 | import org.apache.hadoop.util.Progressable;
31 |
32 |
33 | /**
34 | * An in-memory stub FileSystem implementation for testing code that targets
35 | * S3AFileSystem without actually connecting to S3.
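 *
 * A minimal usage sketch (assuming a test constructs the file system directly):
 *   InMemoryS3AFileSystem fs = new InMemoryS3AFileSystem();
 *   fs.initialize(InMemoryS3AFileSystem.FS_URI, new Configuration());
 *   FSDataOutputStream out = fs.create(new Path(InMemoryS3AFileSystem.FS_URI + "output.txt"));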
36 | */
37 | public class InMemoryS3AFileSystem extends FileSystem {
38 | public static final String BUCKET = "test-bucket";
39 | public static final URI FS_URI = URI.create("s3a://" + BUCKET + "/");
40 |
41 | private static final long DEFAULT_BLOCK_SIZE_TEST = 33554432;
42 |
43 | private final Path root = new Path(FS_URI.toString());
44 |
45 | private SortedMap<String, ByteArrayOutputStream> dataMap = new TreeMap<>();
46 |
47 | private Configuration conf;
48 |
49 | @Override
50 | public URI getUri() {
51 | return FS_URI;
52 | }
53 |
54 | @Override
55 | public Path getWorkingDirectory() {
56 | return new Path(root, "work");
57 | }
58 |
59 | @Override
60 | public boolean mkdirs(Path f, FsPermission permission) throws IOException {
61 | // Not implemented
62 | return false;
63 | }
64 |
65 | @Override
66 | public void initialize(URI name, Configuration originalConf)
67 | throws IOException {
68 | conf = originalConf;
69 | }
70 |
71 | @Override
72 | public Configuration getConf() {
73 | return conf;
74 | }
75 |
76 | @Override
77 | public boolean exists(Path f) throws IOException {
78 |
79 | SortedMap<String, ByteArrayOutputStream> subMap = dataMap.tailMap(toS3Key(f));
80 | for (String filePath: subMap.keySet()) {
81 | if (filePath.contains(toS3Key(f))) {
82 | return true;
83 | }
84 | }
85 | return false;
86 | }
87 |
88 | private String toS3Key(Path f) {
89 | return f.toString();
90 | }
91 |
92 | @Override
93 | public FSDataInputStream open(Path f) throws IOException {
94 | if (getFileStatus(f).isDirectory())
95 | throw new IOException("TESTING: path can't be opened - it's a directory");
96 |
97 | return new FSDataInputStream(
98 | new SeekableByteArrayInputStream(
99 | dataMap.get(toS3Key(f)).toByteArray()
100 | )
101 | );
102 | }
103 |
104 | @Override
105 | public FSDataInputStream open(Path f, int bufferSize) throws IOException {
106 | return open(f);
107 | }
108 |
109 | @Override
110 | public FSDataOutputStream create(Path f) throws IOException {
111 |
112 | if (exists(f)) {
113 | throw new FileAlreadyExistsException();
114 | }
115 |
116 | String key = toS3Key(f);
117 | ByteArrayOutputStream inMemoryS3File = new ByteArrayOutputStream();
118 |
119 | dataMap.put(key, inMemoryS3File);
120 |
121 | return new FSDataOutputStream(inMemoryS3File);
122 |
123 | }
124 |
125 | @Override
126 | public FSDataOutputStream create(Path f, FsPermission permission, boolean overwrite, int bufferSize, short replication, long blockSize, Progressable progress) throws IOException {
127 | // Not Implemented
128 | return null;
129 | }
130 |
131 | @Override
132 | public FSDataOutputStream append(Path f, int bufferSize, Progressable progress) throws IOException {
133 | // Not Implemented
134 | return null;
135 | }
136 |
137 | @Override
138 | public boolean rename(Path src, Path dst) throws IOException {
139 | dataMap.put(toS3Key(dst), dataMap.get(toS3Key(src)));
140 | return true;
141 | }
142 |
143 | @Override
144 | public boolean delete(Path f, boolean recursive) throws IOException {
145 | dataMap.remove(toS3Key(f));
146 | return true;
147 | }
148 |
149 | private Set<String> childPaths(Path f) {
150 | Set<String> children = new HashSet<>();
151 |
152 | String fDir = f + "/";
153 | for (String subKey: dataMap.tailMap(toS3Key(f)).keySet()){
154 | children.add(
155 | fDir + subKey.replace(fDir, "").split("/")[0]
156 | );
157 | }
158 | return children;
159 | }
160 |
161 | @Override
162 | public FileStatus[] listStatus(Path f) throws IOException {
163 |
164 | if (!exists(f)) throw new FileNotFoundException();
165 |
166 | if (getFileStatus(f).isDirectory()){
167 | ArrayList<FileStatus> statuses = new ArrayList<>();
168 |
169 | for (String child: childPaths(f)) {
170 | statuses.add(getFileStatus(new Path(child)));
171 | }
172 |
173 | FileStatus[] arrayStatuses = new FileStatus[statuses.size()];
174 | return statuses.toArray(arrayStatuses);
175 | }
176 |
177 | else {
178 | FileStatus[] statuses = new FileStatus[1];
179 | statuses[0] = this.getFileStatus(f);
180 | return statuses;
181 | }
182 | }
183 |
184 | @Override
185 | public void setWorkingDirectory(Path new_dir) {
186 | // Not implemented
187 | }
188 |
189 | private boolean isDir(Path f) throws IOException{
190 | return exists(f) && dataMap.get(toS3Key(f)) == null;
191 | }
192 |
193 |
194 | @Override
195 | public S3AFileStatus getFileStatus(Path f) throws IOException {
196 |
197 | if (!exists(f)) throw new FileNotFoundException();
198 |
199 | if (isDir(f)) {
200 | return new S3AFileStatus(
201 | true,
202 | dataMap.tailMap(toS3Key(f)).size() == 1 && dataMap.containsKey(toS3Key(f)),
203 | f
204 | );
205 | }
206 | else {
207 | return new S3AFileStatus(
208 | dataMap.get(toS3Key(f)).toByteArray().length,
209 | System.currentTimeMillis(),
210 | f,
211 | this.getDefaultBlockSize()
212 | );
213 | }
214 | }
215 |
216 | @Override
217 | @SuppressWarnings("deprecation")
218 | public long getDefaultBlockSize() {
219 | return DEFAULT_BLOCK_SIZE_TEST;
220 | }
221 | }
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriteSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.SQLException
20 |
21 | import org.apache.spark.sql._
22 | import org.apache.spark.sql.types._
23 |
24 | /**
25 | * End-to-end tests of functionality which involves writing to Redshift via the connector.
26 | */
27 | abstract class BaseRedshiftWriteSuite extends IntegrationSuiteBase {
28 |
29 | protected val tempformat: String
30 |
31 | override protected def write(df: DataFrame): DataFrameWriter[Row] =
32 | super.write(df).option("tempformat", tempformat)
33 |
34 | test("roundtrip save and load") {
35 | // This test can be simplified once #98 is fixed.
36 | val tableName = s"roundtrip_save_and_load_$randomSuffix"
37 | try {
38 | write(
39 | sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema))
40 | .option("dbtable", tableName)
41 | .mode(SaveMode.ErrorIfExists)
42 | .save()
43 |
44 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
45 | checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
46 | } finally {
47 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
48 | }
49 | }
50 |
51 | test("roundtrip save and load with uppercase column names") {
52 | testRoundtripSaveAndLoad(
53 | s"roundtrip_write_and_read_with_uppercase_column_names_$randomSuffix",
54 | sqlContext.createDataFrame(
55 | sc.parallelize(Seq(Row(1))), StructType(StructField("SomeColumn", IntegerType) :: Nil)
56 | ),
57 | expectedSchemaAfterLoad = Some(StructType(StructField("somecolumn", IntegerType) :: Nil))
58 | )
59 | }
60 |
61 | test("save with column names that are reserved words") {
62 | testRoundtripSaveAndLoad(
63 | s"save_with_column_names_that_are_reserved_words_$randomSuffix",
64 | sqlContext.createDataFrame(
65 | sc.parallelize(Seq(Row(1))),
66 | StructType(StructField("table", IntegerType) :: Nil)
67 | )
68 | )
69 | }
70 |
71 | test("save with one empty partition (regression test for #96)") {
72 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 2),
73 | StructType(StructField("foo", IntegerType) :: Nil))
74 | assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array(Row(1))))
75 | testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df)
76 | }
77 |
78 | test("save with all empty partitions (regression test for #96)") {
79 | val df = sqlContext.createDataFrame(sc.parallelize(Seq.empty[Row], 2),
80 | StructType(StructField("foo", IntegerType) :: Nil))
81 | assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array.empty[Row]))
82 | testRoundtripSaveAndLoad(s"save_with_all_empty_partitions_$randomSuffix", df)
83 | // Now try overwriting that table. Although the new table is empty, it should still overwrite
84 | // the existing table.
85 | val df2 = df.withColumnRenamed("foo", "bar")
86 | testRoundtripSaveAndLoad(
87 | s"save_with_all_empty_partitions_$randomSuffix", df2, saveMode = SaveMode.Overwrite)
88 | }
89 |
90 | test("informative error message when saving a table with string that is longer than max length") {
91 | val tableName = s"error_message_when_string_too_long_$randomSuffix"
92 | try {
93 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))),
94 | StructType(StructField("A", StringType) :: Nil))
95 | val e = intercept[SQLException] {
96 | write(df)
97 | .option("dbtable", tableName)
98 | .mode(SaveMode.ErrorIfExists)
99 | .save()
100 | }
101 | assert(e.getMessage.contains("while loading data into Redshift"))
102 | } finally {
103 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
104 | }
105 | }
106 |
107 | test("full timestamp precision is preserved in loads (regression test for #214)") {
108 | val timestamps = Seq(
109 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1),
110 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 10),
111 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 100),
112 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1000))
113 | testRoundtripSaveAndLoad(
114 | s"full_timestamp_precision_is_preserved$randomSuffix",
115 | sqlContext.createDataFrame(sc.parallelize(timestamps.map(Row(_))),
116 | StructType(StructField("ts", TimestampType) :: Nil))
117 | )
118 | }
119 | }
120 |
121 | class AvroRedshiftWriteSuite extends BaseRedshiftWriteSuite {
122 | override protected val tempformat: String = "AVRO"
123 |
124 | test("informative error message when saving with column names that contain spaces (#84)") {
125 | intercept[IllegalArgumentException] {
126 | testRoundtripSaveAndLoad(
127 | s"error_when_saving_column_name_with_spaces_$randomSuffix",
128 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
129 | StructType(StructField("column name with spaces", IntegerType) :: Nil)))
130 | }
131 | }
132 | }
133 |
134 | class CSVRedshiftWriteSuite extends BaseRedshiftWriteSuite {
135 | override protected val tempformat: String = "CSV"
136 |
137 | test("save with column names that contain spaces (#84)") {
138 | testRoundtripSaveAndLoad(
139 | s"save_with_column_names_that_contain_spaces_$randomSuffix",
140 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
141 | StructType(StructField("column name with spaces", IntegerType) :: Nil)))
142 | }
143 | }
144 |
145 | class CSVGZIPRedshiftWriteSuite extends IntegrationSuiteBase {
146 | // Note: we purposely don't inherit from BaseRedshiftWriteSuite because we're only interested in
147 | // testing basic functionality of the GZIP code; the rest of the write path should be unaffected
148 | // by compression here.
149 |
150 | override protected def write(df: DataFrame): DataFrameWriter[Row] =
151 | super.write(df).option("tempformat", "CSV GZIP")
152 |
153 | test("roundtrip save and load") {
154 | // This test can be simplified once #98 is fixed.
155 | val tableName = s"roundtrip_save_and_load_$randomSuffix"
156 | try {
157 | write(
158 | sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema))
159 | .option("dbtable", tableName)
160 | .mode(SaveMode.ErrorIfExists)
161 | .save()
162 |
163 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
164 | checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData)
165 | } finally {
166 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
167 | }
168 | }
169 | }
170 |
--------------------------------------------------------------------------------
/src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import org.scalatest.{FunSuite, Matchers}
20 |
21 | /**
22 | * Check validation of parameter config
23 | */
24 | class ParametersSuite extends FunSuite with Matchers {
25 |
26 | test("Minimal valid parameter map is accepted") {
27 | val params = Map(
28 | "tempdir" -> "s3://foo/bar",
29 | "dbtable" -> "test_schema.test_table",
30 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
31 | "forward_spark_s3_credentials" -> "true",
32 | "include_column_list" -> "true")
33 |
34 | val mergedParams = Parameters.mergeParameters(params)
35 |
36 | mergedParams.rootTempDir should startWith(params("tempdir"))
37 | mergedParams.createPerQueryTempDir() should startWith(params("tempdir"))
38 | mergedParams.jdbcUrl shouldBe params("url")
39 | mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
40 | assert(mergedParams.forwardSparkS3Credentials)
41 | assert(mergedParams.includeColumnList)
42 |
43 | // Check that the defaults have been added
44 | (
45 | Parameters.DEFAULT_PARAMETERS
46 | - "forward_spark_s3_credentials"
47 | - "include_column_list"
48 | ).foreach {
49 | case (key, value) => mergedParams.parameters(key) shouldBe value
50 | }
51 | }
52 |
53 | test("createPerQueryTempDir() returns distinct temp paths") {
54 | val params = Map(
55 | "forward_spark_s3_credentials" -> "true",
56 | "tempdir" -> "s3://foo/bar",
57 | "dbtable" -> "test_table",
58 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
59 |
60 | val mergedParams = Parameters.mergeParameters(params)
61 |
62 | mergedParams.createPerQueryTempDir() should not equal mergedParams.createPerQueryTempDir()
63 | }
64 |
65 | test("Errors are thrown when mandatory parameters are not provided") {
66 | def checkMerge(params: Map[String, String], err: String): Unit = {
67 | val e = intercept[IllegalArgumentException] {
68 | Parameters.mergeParameters(params)
69 | }
70 | assert(e.getMessage.contains(err))
71 | }
72 |
73 | val testURL = "jdbc:redshift://foo/bar?user=user&password=password"
74 | checkMerge(Map("dbtable" -> "test_table", "url" -> testURL), "tempdir")
75 | checkMerge(Map("tempdir" -> "s3://foo/bar", "url" -> testURL), "Redshift table name")
76 | checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar"), "JDBC URL")
77 | checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar", "url" -> testURL),
78 | "method for authenticating")
79 | }
80 |
81 | test("Must specify either 'dbtable' or 'query' parameter, but not both") {
82 | intercept[IllegalArgumentException] {
83 | Parameters.mergeParameters(Map(
84 | "forward_spark_s3_credentials" -> "true",
85 | "tempdir" -> "s3://foo/bar",
86 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
87 | }.getMessage should (include("dbtable") and include("query"))
88 |
89 | intercept[IllegalArgumentException] {
90 | Parameters.mergeParameters(Map(
91 | "forward_spark_s3_credentials" -> "true",
92 | "tempdir" -> "s3://foo/bar",
93 | "dbtable" -> "test_table",
94 | "query" -> "select * from test_table",
95 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
96 | }.getMessage should (include("dbtable") and include("query") and include("both"))
97 |
98 | Parameters.mergeParameters(Map(
99 | "forward_spark_s3_credentials" -> "true",
100 | "tempdir" -> "s3://foo/bar",
101 | "query" -> "select * from test_table",
102 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
103 | }
104 |
105 | test("Must specify credentials in either URL or 'user' and 'password' parameters, but not both") {
106 | intercept[IllegalArgumentException] {
107 | Parameters.mergeParameters(Map(
108 | "forward_spark_s3_credentials" -> "true",
109 | "tempdir" -> "s3://foo/bar",
110 | "query" -> "select * from test_table",
111 | "url" -> "jdbc:redshift://foo/bar"))
112 | }.getMessage should (include("credentials"))
113 |
114 | intercept[IllegalArgumentException] {
115 | Parameters.mergeParameters(Map(
116 | "forward_spark_s3_credentials" -> "true",
117 | "tempdir" -> "s3://foo/bar",
118 | "query" -> "select * from test_table",
119 | "user" -> "user",
120 | "password" -> "password",
121 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
122 | }.getMessage should (include("credentials") and include("both"))
123 |
124 | Parameters.mergeParameters(Map(
125 | "forward_spark_s3_credentials" -> "true",
126 | "tempdir" -> "s3://foo/bar",
127 | "query" -> "select * from test_table",
128 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password"))
129 | }
130 |
131 | test("tempformat option is case-insensitive") {
132 | val params = Map(
133 | "forward_spark_s3_credentials" -> "true",
134 | "tempdir" -> "s3://foo/bar",
135 | "dbtable" -> "test_schema.test_table",
136 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")
137 |
138 | Parameters.mergeParameters(params + ("tempformat" -> "csv"))
139 | Parameters.mergeParameters(params + ("tempformat" -> "CSV"))
140 |
141 | intercept[IllegalArgumentException] {
142 | Parameters.mergeParameters(params + ("tempformat" -> "invalid-temp-format"))
143 | }
144 | }
145 |
146 | test("can only specify one Redshift to S3 authentication mechanism") {
147 | val e = intercept[IllegalArgumentException] {
148 | Parameters.mergeParameters(Map(
149 | "tempdir" -> "s3://foo/bar",
150 | "dbtable" -> "test_schema.test_table",
151 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
152 | "forward_spark_s3_credentials" -> "true",
153 | "aws_iam_role" -> "role"))
154 | }
155 | assert(e.getMessage.contains("mutually-exclusive"))
156 | }
157 |
158 | test("preaction and postactions should be trimmed before splitting by semicolon") {
159 | val params = Parameters.mergeParameters(Map(
160 | "forward_spark_s3_credentials" -> "true",
161 | "tempdir" -> "s3://foo/bar",
162 | "dbtable" -> "test_schema.test_table",
163 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
164 | "preactions" -> "update table1 set col1 = val1;update table1 set col2 = val2; ",
165 | "postactions" -> "update table2 set col1 = val1;update table2 set col2 = val2; "
166 | ))
167 |
168 | assert(params.preActions.length == 2)
169 | assert(params.preActions.head == "update table1 set col1 = val1")
170 | assert(params.preActions.last == "update table1 set col2 = val2")
171 | assert(params.postActions.length == 2)
172 | assert(params.postActions.head == "update table2 set col1 = val1")
173 | assert(params.postActions.last == "update table2 set col2 = val2")
174 | }
175 |
176 | }
177 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/Utils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.net.URI
20 | import java.util.UUID
21 |
22 | import com.amazonaws.services.s3.model.BucketLifecycleConfiguration
23 | import com.amazonaws.services.s3.{AmazonS3Client, AmazonS3URI}
24 | import org.apache.hadoop.conf.Configuration
25 | import org.apache.hadoop.fs.FileSystem
26 | import org.slf4j.LoggerFactory
27 |
28 | import scala.collection.JavaConverters._
29 | import scala.util.control.NonFatal
30 |
31 | /**
32 | * Various arbitrary helper functions
33 | */
34 | private[redshift] object Utils {
35 |
36 | private val log = LoggerFactory.getLogger(getClass)
37 |
38 | def classForName(className: String): Class[_] = {
39 | val classLoader =
40 | Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader)
41 | // scalastyle:off
42 | Class.forName(className, true, classLoader)
43 | // scalastyle:on
44 | }
45 |
46 | /**
47 | * Joins prefix URL a to path suffix b, and appends a trailing /, in order to create
48 | * a temp directory path for S3.
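 *
 * e.g. joinUrls("s3n://foo/bar", "baz") returns "s3n://foo/bar/baz/".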
49 | */
50 | def joinUrls(a: String, b: String): String = {
51 | a.stripSuffix("/") + "/" + b.stripPrefix("/").stripSuffix("/") + "/"
52 | }
53 |
54 | /**
55 | * Redshift COPY and UNLOAD commands don't support s3n or s3a, but users may wish to use them
56 | * for data loads. This function converts the URL back to the s3:// format.
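 *
 * e.g. fixS3Url("s3a://foo/bar/baz") returns "s3://foo/bar/baz".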
57 | */
58 | def fixS3Url(url: String): String = {
59 | url.replaceAll("s3[an]://", "s3://")
60 | }
61 |
62 | /**
63 | * Factory method to create new S3URI in order to handle various library incompatibilities with
64 | * older AWS Java Libraries
65 | */
66 | def createS3URI(url: String): AmazonS3URI = {
67 | try {
68 | // try to instantiate AmazonS3URI with url
69 | new AmazonS3URI(url)
70 | } catch {
71 | case e: IllegalArgumentException if e.getMessage.
72 | startsWith("Invalid S3 URI: hostname does not appear to be a valid S3 endpoint") => {
73 | new AmazonS3URI(addEndpointToUrl(url))
74 | }
75 | }
76 | }
77 |
78 | /**
79 | * Since older AWS Java Libraries do not handle S3 urls that have just the bucket name
80 | * as the host, add the endpoint to the host
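 *
 * e.g. addEndpointToUrl("s3://foo/bar") returns "s3://foo.s3.amazonaws.com/bar".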
81 | */
82 | def addEndpointToUrl(url: String, domain: String = "s3.amazonaws.com"): String = {
83 | val uri = new URI(url)
84 | val hostWithEndpoint = uri.getHost + "." + domain
85 | new URI(uri.getScheme,
86 | uri.getUserInfo,
87 | hostWithEndpoint,
88 | uri.getPort,
89 | uri.getPath,
90 | uri.getQuery,
91 | uri.getFragment).toString
92 | }
93 |
94 | /**
95 | * Returns a copy of the given URI with the user credentials removed.
96 | */
97 | def removeCredentialsFromURI(uri: URI): URI = {
98 | new URI(
99 | uri.getScheme,
100 | null, // no user info
101 | uri.getHost,
102 | uri.getPort,
103 | uri.getPath,
104 | uri.getQuery,
105 | uri.getFragment)
106 | }
107 |
108 | // Visible for testing
109 | private[redshift] var lastTempPathGenerated: String = null
110 |
111 | /**
112 | * Creates a randomly named temp directory path for intermediate data
113 | */
114 | def makeTempPath(tempRoot: String): String = {
115 | lastTempPathGenerated = Utils.joinUrls(tempRoot, UUID.randomUUID().toString)
116 | lastTempPathGenerated
117 | }
118 |
119 | /**
120 | * Checks whether the S3 bucket for the given URI has an object lifecycle configuration to
121 | * ensure cleanup of temporary files. If no applicable configuration is found, this method logs
122 | * a helpful warning for the user.
123 | */
124 | def checkThatBucketHasObjectLifecycleConfiguration(
125 | tempDir: String,
126 | s3Client: AmazonS3Client): Unit = {
127 | try {
128 | val s3URI = createS3URI(Utils.fixS3Url(tempDir))
129 | val bucket = s3URI.getBucket
130 | assert(bucket != null, "Could not get bucket from S3 URI")
131 | val key = Option(s3URI.getKey).getOrElse("")
132 | val hasMatchingBucketLifecycleRule: Boolean = {
133 | val rules = Option(s3Client.getBucketLifecycleConfiguration(bucket))
134 | .map(_.getRules.asScala)
135 | .getOrElse(Seq.empty)
136 | rules.exists { rule =>
137 | // Note: this only checks that there is an active rule which matches the temp directory;
138 | // it does not actually check that the rule will delete the files. This check is still
139 | // better than nothing, though, and we can always improve it later.
140 | rule.getStatus == BucketLifecycleConfiguration.ENABLED && key.startsWith(rule.getPrefix)
141 | }
142 | }
143 | if (!hasMatchingBucketLifecycleRule) {
144 | log.warn(s"The S3 bucket $bucket does not have an object lifecycle configuration to " +
145 | "ensure cleanup of temporary files. Consider configuring `tempdir` to point to a " +
146 | "bucket with an object lifecycle policy that automatically deletes files after an " +
147 | "expiration period. For more information, see " +
148 | "https://docs.aws.amazon.com/AmazonS3/latest/dev/object-lifecycle-mgmt.html")
149 | }
150 | } catch {
151 | case NonFatal(e) =>
152 | log.warn("An error occurred while trying to read the S3 bucket lifecycle configuration", e)
153 | }
154 | }
155 |
156 | /**
157 | * Given a URI, verify that the Hadoop FileSystem for that URI is not the S3 block FileSystem.
158 | * `spark-redshift` cannot use this FileSystem because the files written to it will not be
159 | * readable by Redshift (and vice versa).
160 | */
161 | def assertThatFileSystemIsNotS3BlockFileSystem(uri: URI, hadoopConfig: Configuration): Unit = {
162 | val fs = FileSystem.get(uri, hadoopConfig)
163 | // Note that we do not want to use isInstanceOf here, since we're only interested in detecting
164 | // exact matches. We compare the class names as strings in order to avoid introducing a binary
165 | // dependency on classes which belong to the `hadoop-aws` JAR, as that artifact is not present
166 | // in some environments (such as EMR). See #92 for details.
167 | if (fs.getClass.getCanonicalName == "org.apache.hadoop.fs.s3.S3FileSystem") {
168 | throw new IllegalArgumentException(
169 | "spark-redshift does not support the S3 Block FileSystem. Please reconfigure `tempdir` to " +
170 | "use an s3n:// or s3a:// scheme.")
171 | }
172 | }
173 |
174 | /**
175 | * Attempts to retrieve the region of the S3 bucket.
176 | */
177 | def getRegionForS3Bucket(tempDir: String, s3Client: AmazonS3Client): Option[String] = {
178 | try {
179 | val s3URI = createS3URI(Utils.fixS3Url(tempDir))
180 | val bucket = s3URI.getBucket
181 | assert(bucket != null, "Could not get bucket from S3 URI")
182 | val region = s3Client.getBucketLocation(bucket) match {
183 | // Map "US Standard" to us-east-1
184 | case null | "US" => "us-east-1"
185 | case other => other
186 | }
187 | Some(region)
188 | } catch {
189 | case NonFatal(e) =>
190 | log.warn("An error occurred while trying to determine the S3 bucket's region", e)
191 | None
192 | }
193 | }
194 |
195 | /**
196 | * Attempts to determine the region of a Redshift cluster based on its URL. It may not be possible
197 | * to determine the region in some cases, such as when the Redshift cluster is placed behind a
198 | * proxy.
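 *
 * e.g. a URL containing "example.us-west-2.redshift.amazonaws.com" yields Some("us-west-2").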
199 | */
200 | def getRegionForRedshiftCluster(url: String): Option[String] = {
201 | val regionRegex = """.*\.([^.]+)\.redshift\.amazonaws\.com.*""".r
202 | url match {
203 | case regionRegex(region) => Some(region)
204 | case _ => None
205 | }
206 | }
207 | }
208 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftInputFormat.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2014 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.io.{BufferedInputStream, IOException}
20 | import java.lang.{Long => JavaLong}
21 | import java.nio.charset.Charset
22 |
23 | import org.apache.hadoop.conf.Configuration
24 | import org.apache.hadoop.fs.{FileSystem, Path}
25 | import org.apache.hadoop.io.compress.CompressionCodecFactory
26 | import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
27 | import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
28 |
29 | import scala.collection.mutable.ArrayBuffer
30 |
31 | /**
32 | * Input format for text records saved with in-record delimiter and newline characters escaped.
33 | *
34 | * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|`
35 | * should be the following:
36 | * {{{
37 | * a\\\n|\\|b\\\\\n
38 | * }}},
39 | * where the in-record `|`, `\r`, `\n`, and `\\` characters are escaped by `\\`.
40 | * Users can configure the delimiter via [[RedshiftInputFormat$#KEY_DELIMITER]].
41 | * Its default value [[RedshiftInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's UNLOAD
42 | * with the ESCAPE option:
43 | * {{{
44 | * UNLOAD ('select_statement')
45 | * TO 's3://object_path_prefix'
46 | * ESCAPE
47 | * }}}
48 | *
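 * A minimal read sketch (the path here is a hypothetical UNLOAD output prefix):
 * {{{
 * val records = sc.newAPIHadoopFile(
 *   "s3n://object_path_prefix",
 *   classOf[RedshiftInputFormat],
 *   classOf[java.lang.Long],
 *   classOf[Array[String]])
 * }}}
 *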
49 | * @see org.apache.spark.SparkContext#newAPIHadoopFile
50 | */
51 | class RedshiftInputFormat extends FileInputFormat[JavaLong, Array[String]] {
52 |
53 | override def createRecordReader(
54 | split: InputSplit,
55 | context: TaskAttemptContext): RecordReader[JavaLong, Array[String]] = {
56 | new RedshiftRecordReader
57 | }
58 | }
59 |
60 | object RedshiftInputFormat {
61 |
62 | /** configuration key for delimiter */
63 | val KEY_DELIMITER = "redshift.delimiter"
64 | /** default delimiter */
65 | val DEFAULT_DELIMITER = '|'
66 |
67 | /** Gets the delimiter char from conf or the default. */
68 | private[redshift] def getDelimiterOrDefault(conf: Configuration): Char = {
69 | val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString)
70 | if (c.length != 1) {
71 |       throw new IllegalArgumentException(s"Expected the delimiter to be a single character but got '$c'.")
72 | } else {
73 | c.charAt(0)
74 | }
75 | }
76 | }
77 |
78 | private[redshift] class RedshiftRecordReader extends RecordReader[JavaLong, Array[String]] {
79 |
80 | private var reader: BufferedInputStream = _
81 |
82 | private var key: JavaLong = _
83 | private var value: Array[String] = _
84 |
85 | private var start: Long = _
86 | private var end: Long = _
87 | private var cur: Long = _
88 |
89 | private var eof: Boolean = false
90 |
91 | private var delimiter: Byte = _
92 | @inline private[this] final val escapeChar: Byte = '\\'
93 | @inline private[this] final val lineFeed: Byte = '\n'
94 | @inline private[this] final val carriageReturn: Byte = '\r'
95 |
96 | @inline private[this] final val defaultBufferSize = 1024 * 1024
97 |
98 | private[this] val chars = ArrayBuffer.empty[Byte]
99 |
100 | override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = {
101 | val split = inputSplit.asInstanceOf[FileSplit]
102 | val file = split.getPath
103 | val conf: Configuration = context.getConfiguration
104 | delimiter = RedshiftInputFormat.getDelimiterOrDefault(conf).asInstanceOf[Byte]
105 | require(delimiter != escapeChar,
106 | s"The delimiter and the escape char cannot be the same but found $delimiter.")
107 | require(delimiter != lineFeed, "The delimiter cannot be the lineFeed character.")
108 | require(delimiter != carriageReturn, "The delimiter cannot be the carriage return.")
109 | val compressionCodecs = new CompressionCodecFactory(conf)
110 | val codec = compressionCodecs.getCodec(file)
111 | if (codec != null) {
112 |       throw new IOException(s"Compressed files are not supported, but found $file.")
113 | }
114 | val fs = file.getFileSystem(conf)
115 | val size = fs.getFileStatus(file).getLen
116 | start = findNext(fs, file, size, split.getStart)
117 | end = findNext(fs, file, size, split.getStart + split.getLength)
118 | cur = start
119 | val in = fs.open(file)
120 | if (cur > 0L) {
121 | in.seek(cur - 1L)
122 | in.read()
123 | }
124 | reader = new BufferedInputStream(in, defaultBufferSize)
125 | }
126 |
127 | override def getProgress: Float = {
128 | if (start >= end) {
129 | 1.0f
130 | } else {
131 | math.min((cur - start).toFloat / (end - start), 1.0f)
132 | }
133 | }
134 |
135 | override def nextKeyValue(): Boolean = {
136 | if (cur < end && !eof) {
137 | key = cur
138 | value = nextValue()
139 | true
140 | } else {
141 | key = null
142 | value = null
143 | false
144 | }
145 | }
146 |
147 | override def getCurrentValue: Array[String] = value
148 |
149 | override def getCurrentKey: JavaLong = key
150 |
151 | override def close(): Unit = {
152 | if (reader != null) {
153 | reader.close()
154 | }
155 | }
156 |
157 | /**
158 | * Finds the start of the next record.
159 | * Because we don't know whether the first char is escaped or not, we need to first find a
160 | * position that is not escaped.
161 | *
162 | * @param fs file system
163 | * @param file file path
164 | * @param size file size
165 | * @param offset start offset
166 | * @return the start position of the next record
167 | */
168 | private def findNext(fs: FileSystem, file: Path, size: Long, offset: Long): Long = {
169 | if (offset == 0L) {
170 | return 0L
171 | } else if (offset >= size) {
172 | return size
173 | }
174 | val in = fs.open(file)
175 | var pos = offset
176 | in.seek(pos)
177 | val bis = new BufferedInputStream(in, defaultBufferSize)
178 | // Find the first unescaped char.
179 | var escaped = true
180 | var thisEof = false
181 | while (escaped && !thisEof) {
182 | val v = bis.read()
183 | if (v < 0) {
184 | thisEof = true
185 | } else {
186 | pos += 1
187 | if (v != escapeChar) {
188 | escaped = false
189 | }
190 | }
191 | }
192 | // Find the next unescaped line feed.
193 | var endOfRecord = false
194 | while ((escaped || !endOfRecord) && !thisEof) {
195 | val v = bis.read()
196 | if (v < 0) {
197 | thisEof = true
198 | } else {
199 | pos += 1
200 | if (v == escapeChar) {
201 | escaped = true
202 | } else {
203 | if (!escaped) {
204 | endOfRecord = v == lineFeed
205 | } else {
206 | escaped = false
207 | }
208 | }
209 | }
210 | }
211 | in.close()
212 | pos
213 | }
214 |
215 | private def nextValue(): Array[String] = {
216 | val fields = ArrayBuffer.empty[String]
217 | var escaped = false
218 | var endOfRecord = false
219 | while (!endOfRecord && !eof) {
220 | var endOfField = false
221 | chars.clear()
222 | while (!endOfField && !endOfRecord && !eof) {
223 | val v = reader.read()
224 | if (v < 0) {
225 | eof = true
226 | } else {
227 | cur += 1L
228 | val c = v.asInstanceOf[Byte]
229 | if (escaped) {
230 | if (c != escapeChar && c != delimiter && c != lineFeed && c != carriageReturn) {
231 |             throw new IllegalStateException(
232 |               s"Found unexpected character (ASCII code $v) after the escape character.")
233 | }
234 | chars.append(c)
235 | escaped = false
236 | } else {
237 | if (c == escapeChar) {
238 | escaped = true
239 | } else if (c == delimiter) {
240 | endOfField = true
241 | } else if (c == lineFeed) {
242 | endOfRecord = true
243 | } else {
244 | // also copy carriage return
245 | chars.append(c)
246 | }
247 | }
248 | }
249 | }
250 | // TODO: charset?
251 | fields.append(new String(chars.toArray, Charset.forName("UTF-8")))
252 | }
253 | if (escaped) {
254 |       throw new IllegalStateException("Found a hanging escape character at the end of the input.")
255 | }
256 | fields.toArray
257 | }
258 | }
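A hedged usage sketch (not part of the original file) for the input format defined above: it reads hypothetical UNLOAD ... ESCAPE output via SparkContext#newAPIHadoopFile, as suggested by the class scaladoc. The S3 path is made up, and a suitable S3 filesystem implementation is assumed to be on the classpath.

package io.github.spark_redshift_community.spark.redshift

import java.lang.{Long => JavaLong}

import org.apache.spark.SparkContext

object RedshiftInputFormatSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "RedshiftInputFormatSketch")
    // Optional: override the field delimiter; when unset, DEFAULT_DELIMITER ('|') is used.
    sc.hadoopConfiguration.set(RedshiftInputFormat.KEY_DELIMITER, "|")
    // The path is hypothetical; it would point at files produced by UNLOAD ... ESCAPE.
    val records = sc.newAPIHadoopFile(
      "s3a://example-bucket/unload-output/",
      classOf[RedshiftInputFormat],
      classOf[JavaLong],
      classOf[Array[String]])
    // Keys are byte offsets of each record; values are that record's unescaped fields.
    records.values.take(5).foreach(fields => println(fields.mkString(", ")))
    sc.stop()
  }
}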
259 |
260 |
--------------------------------------------------------------------------------
/src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftRelation.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 TouchType Ltd
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.io.InputStreamReader
20 | import java.net.URI
21 |
22 | import com.amazonaws.auth.AWSCredentialsProvider
23 | import com.amazonaws.services.s3.AmazonS3Client
24 | import com.eclipsesource.json.Json
25 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters
26 | import org.apache.spark.rdd.RDD
27 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
28 | import org.apache.spark.sql.sources._
29 | import org.apache.spark.sql.types._
30 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
31 | import org.slf4j.LoggerFactory
32 |
33 | import scala.collection.JavaConverters._
34 |
35 | /**
36 | * Data Source API implementation for Amazon Redshift database tables
37 | */
38 | private[redshift] case class RedshiftRelation(
39 | jdbcWrapper: JDBCWrapper,
40 | s3ClientFactory: AWSCredentialsProvider => AmazonS3Client,
41 | params: MergedParameters,
42 | userSchema: Option[StructType])
43 | (@transient val sqlContext: SQLContext)
44 | extends BaseRelation
45 | with PrunedFilteredScan
46 | with InsertableRelation {
47 |
48 | private val log = LoggerFactory.getLogger(getClass)
49 |
50 | if (sqlContext != null) {
51 | Utils.assertThatFileSystemIsNotS3BlockFileSystem(
52 | new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration)
53 | }
54 |
55 | private val tableNameOrSubquery =
56 | params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
57 |
58 | override lazy val schema: StructType = {
59 | userSchema.getOrElse {
60 | val tableNameOrSubquery =
61 | params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
62 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
63 | try {
64 | jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
65 | } finally {
66 | conn.close()
67 | }
68 | }
69 | }
70 |
71 | override def toString: String = s"RedshiftRelation($tableNameOrSubquery)"
72 |
73 | override def insert(data: DataFrame, overwrite: Boolean): Unit = {
74 | val saveMode = if (overwrite) {
75 | SaveMode.Overwrite
76 | } else {
77 | SaveMode.Append
78 | }
79 | val writer = new RedshiftWriter(jdbcWrapper, s3ClientFactory)
80 | writer.saveToRedshift(sqlContext, data, saveMode, params)
81 | }
82 |
83 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
84 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined)
85 | }
86 |
87 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
88 | val creds = AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration)
89 | for (
90 | redshiftRegion <- Utils.getRegionForRedshiftCluster(params.jdbcUrl);
91 | s3Region <- Utils.getRegionForS3Bucket(params.rootTempDir, s3ClientFactory(creds))
92 | ) {
93 | if (redshiftRegion != s3Region) {
94 |         // We don't currently support `extraunloadoptions`, so even if Amazon _did_ add a `region`
95 |         // option for this we wouldn't be able to pass it through. However, we err on the side of
96 |         // caution and only log an error rather than throwing an exception, because we don't want
97 |         // to break existing workloads in case the region-detection logic is wrong.
98 | log.error("The Redshift cluster and S3 bucket are in different regions " +
99 | s"($redshiftRegion and $s3Region, respectively). Redshift's UNLOAD command requires " +
100 | s"that the Redshift cluster and Amazon S3 bucket be located in the same region, so " +
101 | s"this read will fail.")
102 | }
103 | }
104 | Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds))
105 | if (requiredColumns.isEmpty) {
106 | // In the special case where no columns were requested, issue a `count(*)` against Redshift
107 | // rather than unloading data.
108 | val whereClause = FilterPushdown.buildWhereClause(schema, filters)
109 | val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause"
110 | log.info(countQuery)
111 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
112 | try {
113 | val results = jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(countQuery))
114 | if (results.next()) {
115 | val numRows = results.getLong(1)
116 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
117 | val emptyRow = RowEncoder(StructType(Seq.empty)).toRow(Row(Seq.empty))
118 | sqlContext.sparkContext
119 | .parallelize(1L to numRows, parallelism)
120 | .map(_ => emptyRow)
121 | .asInstanceOf[RDD[Row]]
122 | } else {
123 | throw new IllegalStateException("Could not read count from Redshift")
124 | }
125 | } finally {
126 | conn.close()
127 | }
128 | } else {
129 | // Unload data from Redshift into a temporary directory in S3:
130 | val tempDir = params.createPerQueryTempDir()
131 | val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds)
132 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
133 | try {
134 | jdbcWrapper.executeInterruptibly(conn.prepareStatement(unloadSql))
135 | } finally {
136 | conn.close()
137 | }
138 | // Read the MANIFEST file to get the list of S3 part files that were written by Redshift.
139 | // We need to use a manifest in order to guard against S3's eventually-consistent listings.
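      // For illustration (not part of the original file): the manifest written by
      // "UNLOAD ... MANIFEST" is a small JSON document of the form
      //   {"entries": [{"url": "s3://example-bucket/temp/0000_part_00"},
      //                {"url": "s3://example-bucket/temp/0001_part_00"}]}
      // (bucket and key names made up), which is why the parsing below extracts the "url"
      // field of each entry.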
140 | val filesToRead: Seq[String] = {
141 | val cleanedTempDirUri =
142 | Utils.fixS3Url(Utils.removeCredentialsFromURI(URI.create(tempDir)).toString)
143 | val s3URI = Utils.createS3URI(cleanedTempDirUri)
144 | val s3Client = s3ClientFactory(creds)
145 | val is = s3Client.getObject(s3URI.getBucket, s3URI.getKey + "manifest").getObjectContent
146 | val s3Files = try {
147 | val entries = Json.parse(new InputStreamReader(is)).asObject().get("entries").asArray()
148 | entries.iterator().asScala.map(_.asObject().get("url").asString()).toSeq
149 | } finally {
150 | is.close()
151 | }
152 | // The filenames in the manifest are of the form s3://bucket/key, without credentials.
153 | // If the S3 credentials were originally specified in the tempdir's URI, then we need to
154 | // reintroduce them here
155 | s3Files.map { file =>
156 | tempDir.stripSuffix("/") + '/' + file.stripPrefix(cleanedTempDirUri).stripPrefix("/")
157 | }
158 | }
159 |
160 | val prunedSchema = pruneSchema(schema, requiredColumns)
161 |
162 | sqlContext.read
163 | .format(classOf[RedshiftFileFormat].getName)
164 | .schema(prunedSchema)
165 | .option("nullString", params.nullString)
166 | .load(filesToRead: _*)
167 | .queryExecution.executedPlan.execute().asInstanceOf[RDD[Row]]
168 | }
169 | }
170 |
171 | override def needConversion: Boolean = false
172 |
173 | private def buildUnloadStmt(
174 | requiredColumns: Array[String],
175 | filters: Array[Filter],
176 | tempDir: String,
177 | creds: AWSCredentialsProvider): String = {
178 | assert(!requiredColumns.isEmpty)
179 | // Always quote column names:
180 | val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ")
181 | val whereClause = FilterPushdown.buildWhereClause(schema, filters)
182 | val credsString: String =
183 | AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials)
184 | val query = {
185 | // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
186 | // any backslashes and single quotes that appear in the query itself
187 |       val escapedTableNameOrSubquery = tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'")
188 |       s"SELECT $columnList FROM $escapedTableNameOrSubquery $whereClause"
189 | }
190 | log.info(query)
191 | // We need to remove S3 credentials from the unload path URI because they will conflict with
192 | // the credentials passed via `credsString`.
193 | val fixedUrl = Utils.fixS3Url(Utils.removeCredentialsFromURI(new URI(tempDir)).toString)
194 |
195 | s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString'" +
196 | s" ESCAPE MANIFEST NULL AS '${params.nullString}'"
197 | }
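  // For illustration only (not part of the original file): with a hypothetical table "my_table",
  // required columns ("id", "name"), a pushed-down EqualTo("id", 1) filter, and key-based
  // credentials, the statement assembled above would look roughly like:
  //
  //   UNLOAD ('SELECT "id", "name" FROM my_table WHERE "id" = 1')
  //   TO 's3://example-bucket/temp-dir/<per-query-uuid>/'
  //   WITH CREDENTIALS 'aws_access_key_id=<ID>;aws_secret_access_key=<KEY>'
  //   ESCAPE MANIFEST NULL AS '<nullString>'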
198 |
199 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
200 | val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
201 | new StructType(columns.map(name => fieldMap(name)))
202 | }
203 | }
204 |
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/IntegrationSuiteBase.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2015 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.net.URI
20 | import java.sql.Connection
21 |
22 | import org.apache.hadoop.conf.Configuration
23 | import org.apache.hadoop.fs.s3native.NativeS3FileSystem
24 | import org.apache.hadoop.fs.{FileSystem, Path}
25 | import org.apache.spark.SparkContext
26 | import org.apache.spark.sql._
27 | import org.apache.spark.sql.hive.test.TestHiveContext
28 | import org.apache.spark.sql.types.StructType
29 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers}
30 |
31 | import scala.util.Random
32 |
33 |
34 | /**
35 | * Base class for writing integration tests which run against a real Redshift cluster.
36 | */
37 | trait IntegrationSuiteBase
38 | extends QueryTest
39 | with Matchers
40 | with BeforeAndAfterAll
41 | with BeforeAndAfterEach {
42 |
43 | protected def loadConfigFromEnv(envVarName: String): String = {
44 | Option(System.getenv(envVarName)).getOrElse {
45 | fail(s"Must set $envVarName environment variable")
46 | }
47 | }
48 |
49 | // The following configurations must be set in order to run these tests. In Travis, these
50 | // environment variables are set using Travis's encrypted environment variables feature:
51 | // http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables
52 |
53 | // JDBC URL listed in the AWS console (should not contain username and password).
54 | protected val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL")
55 | protected val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
56 | protected val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
57 | protected val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("AWS_ACCESS_KEY_ID")
58 | protected val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("AWS_SECRET_ACCESS_KEY")
59 |   // Path to a directory in S3 (e.g. 's3a://bucket-name/path/to/scratch/space').
60 | protected val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
61 | require(AWS_S3_SCRATCH_SPACE.contains("s3a"), "must use s3a:// URL")
62 |
63 | protected def jdbcUrl: String = {
64 | s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD&ssl=true"
65 | }
66 |
67 | protected def jdbcUrlNoUserPassword: String = {
68 | s"$AWS_REDSHIFT_JDBC_URL?ssl=true"
69 | }
70 | /**
71 |    * Random suffix appended to table and directory names in order to avoid collisions
72 | * between separate Travis builds.
73 | */
74 | protected val randomSuffix: String = Math.abs(Random.nextLong()).toString
75 |
76 | protected val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/"
77 |
78 | /**
79 | * Spark Context with Hadoop file overridden to point at our local test data file for this suite,
80 |    * no matter what temp directory was generated and requested.
81 | */
82 | protected var sc: SparkContext = _
83 | protected var sqlContext: SQLContext = _
84 | protected var conn: Connection = _
85 |
86 | override def beforeAll(): Unit = {
87 | super.beforeAll()
88 | sc = new SparkContext("local", "RedshiftSourceSuite")
89 | // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials:
90 | sc.hadoopConfiguration.setBoolean("fs.s3.impl.disable.cache", true)
91 | sc.hadoopConfiguration.setBoolean("fs.s3n.impl.disable.cache", true)
92 | sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
93 | sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
94 | sc.hadoopConfiguration.set("fs.s3a.access.key", AWS_ACCESS_KEY_ID)
95 | sc.hadoopConfiguration.set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY)
96 | conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None)
97 | }
98 |
99 | override def afterAll(): Unit = {
100 | try {
101 | val conf = new Configuration(false)
102 | conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
103 | conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
104 | conf.set("fs.s3a.access.key", AWS_ACCESS_KEY_ID)
105 | conf.set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY)
106 | // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials:
107 | conf.setBoolean("fs.s3.impl.disable.cache", true)
108 | conf.setBoolean("fs.s3n.impl.disable.cache", true)
109 | conf.setBoolean("fs.s3a.impl.disable.cache", true)
110 | conf.set("fs.s3.impl", classOf[InMemoryS3AFileSystem].getCanonicalName)
111 | conf.set("fs.s3a.impl", classOf[InMemoryS3AFileSystem].getCanonicalName)
112 | val fs = FileSystem.get(URI.create(tempDir), conf)
113 | fs.delete(new Path(tempDir), true)
114 | fs.close()
115 | } finally {
116 | try {
117 | conn.close()
118 | } finally {
119 | try {
120 | sc.stop()
121 | } finally {
122 | super.afterAll()
123 | }
124 | }
125 | }
126 | }
127 |
128 | override protected def beforeEach(): Unit = {
129 | super.beforeEach()
130 | sqlContext = new TestHiveContext(sc, loadTestTables = false)
131 | }
132 |
133 | /**
134 | * Create a new DataFrameReader using common options for reading from Redshift.
135 | */
136 | protected def read: DataFrameReader = {
137 | sqlContext.read
138 | .format("io.github.spark_redshift_community.spark.redshift")
139 | .option("url", jdbcUrl)
140 | .option("tempdir", tempDir)
141 | .option("forward_spark_s3_credentials", "true")
142 | }
143 | /**
144 | * Create a new DataFrameWriter using common options for writing to Redshift.
145 | */
146 | protected def write(df: DataFrame): DataFrameWriter[Row] = {
147 | df.write
148 | .format("io.github.spark_redshift_community.spark.redshift")
149 | .option("url", jdbcUrl)
150 | .option("tempdir", tempDir)
151 | .option("forward_spark_s3_credentials", "true")
152 | }
153 |
154 | protected def createTestDataInRedshift(tableName: String): Unit = {
155 | conn.createStatement().executeUpdate(
156 | s"""
157 | |create table $tableName (
158 | |testbyte int2,
159 | |testbool boolean,
160 | |testdate date,
161 | |testdouble float8,
162 | |testfloat float4,
163 | |testint int4,
164 | |testlong int8,
165 | |testshort int2,
166 | |teststring varchar(256),
167 | |testtimestamp timestamp
168 | |)
169 | """.stripMargin
170 | )
171 | // scalastyle:off
172 | conn.createStatement().executeUpdate(
173 | s"""
174 | |insert into $tableName values
175 | |(null, null, null, null, null, null, null, null, null, null),
176 | |(0, null, '2015-07-03', 0.0, -1.0, 4141214, 1239012341823719, null, 'f', '2015-07-03 00:00:00.000'),
177 | |(0, false, null, -1234152.12312498, 100000.0, null, 1239012341823719, 24, '___|_123', null),
178 | |(1, false, '2015-07-02', 0.0, 0.0, 42, 1239012341823719, -13, 'asdf', '2015-07-02 00:00:00.000'),
179 | |(1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', '2015-07-01 00:00:00.001')
180 | """.stripMargin
181 | )
182 | // scalastyle:on
183 | }
184 |
185 | protected def withTempRedshiftTable[T](namePrefix: String)(body: String => T): T = {
186 | val tableName = s"$namePrefix$randomSuffix"
187 | try {
188 | body(tableName)
189 | } finally {
190 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
191 | }
192 | }
193 |
194 | /**
195 | * Save the given DataFrame to Redshift, then load the results back into a DataFrame and check
196 | * that the returned DataFrame matches the one that we saved.
197 | *
198 | * @param tableName the table name to use
199 | * @param df the DataFrame to save
200 | * @param expectedSchemaAfterLoad if specified, the expected schema after loading the data back
201 | * from Redshift. This should be used in cases where you expect
202 | * the schema to differ due to reasons like case-sensitivity.
203 | * @param saveMode the [[SaveMode]] to use when writing data back to Redshift
204 | */
205 | def testRoundtripSaveAndLoad(
206 | tableName: String,
207 | df: DataFrame,
208 | expectedSchemaAfterLoad: Option[StructType] = None,
209 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = {
210 | try {
211 | write(df)
212 | .option("dbtable", tableName)
213 | .mode(saveMode)
214 | .save()
215 | // Check that the table exists. It appears that creating a table in one connection then
216 | // immediately querying for existence from another connection may result in spurious "table
217 | // doesn't exist" errors; this caused the "save with all empty partitions" test to become
218 | // flaky (see #146). To work around this, add a small sleep and check again:
219 | if (!DefaultJDBCWrapper.tableExists(conn, tableName)) {
220 | Thread.sleep(1000)
221 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
222 | }
223 | val loadedDf = read.option("dbtable", tableName).load()
224 | assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema))
225 | checkAnswer(loadedDf, df.collect())
226 | } finally {
227 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
228 | }
229 | }
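  // A hedged usage sketch (not part of the original file) of how a concrete suite might combine
  // the helpers above; the table-name prefix and the data are made up.
  //
  //   withTempRedshiftTable("roundtrip_example") { tableName =>
  //     val df = sqlContext.createDataFrame(Seq(Tuple1(1), Tuple1(2))).toDF("x")
  //     testRoundtripSaveAndLoad(tableName, df)
  //   }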
230 | }
231 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
--------------------------------------------------------------------------------
/src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftReadSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 Databricks
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.github.spark_redshift_community.spark.redshift
18 |
19 | import java.sql.Timestamp
20 |
21 | import org.apache.spark.sql.types.LongType
22 | import org.apache.spark.sql.{Row, execution}
23 |
24 | /**
25 | * End-to-end tests of functionality which only impacts the read path (e.g. filter pushdown).
26 | */
27 | class RedshiftReadSuite extends IntegrationSuiteBase {
28 |
29 | private val test_table: String = s"read_suite_test_table_$randomSuffix"
30 |
31 | override def beforeAll(): Unit = {
32 | super.beforeAll()
33 | conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
34 | createTestDataInRedshift(test_table)
35 | }
36 |
37 | override def afterAll(): Unit = {
38 | try {
39 | conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
40 | } finally {
41 | super.afterAll()
42 | }
43 | }
44 |
45 | override def beforeEach(): Unit = {
46 | super.beforeEach()
47 | read.option("dbtable", test_table).load().createOrReplaceTempView("test_table")
48 | }
49 |
50 | test("DefaultSource can load Redshift UNLOAD output to a DataFrame") {
51 | checkAnswer(
52 | sqlContext.sql("select * from test_table"),
53 | TestUtils.expectedData)
54 | }
55 |
56 | test("count() on DataFrame created from a Redshift table") {
57 | checkAnswer(
58 | sqlContext.sql("select count(*) from test_table"),
59 | Seq(Row(TestUtils.expectedData.length))
60 | )
61 | }
62 |
63 | test("count() on DataFrame created from a Redshift query") {
64 | val loadedDf =
65 | // scalastyle:off
66 | read.option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'").load()
67 | // scalastyle:on
68 | checkAnswer(
69 | loadedDf.selectExpr("count(*)"),
70 | Seq(Row(1))
71 | )
72 | }
73 |
74 | test("backslashes in queries/subqueries are escaped (regression test for #215)") {
75 | val loadedDf =
76 | read.option("query", s"select replace(teststring, '\\\\', '') as col from $test_table").load()
77 | checkAnswer(
78 | loadedDf.filter("col = 'asdf'"),
79 | Seq(Row("asdf"))
80 | )
81 | }
82 |
83 | test("Can load output when 'dbtable' is a subquery wrapped in parentheses") {
84 | // scalastyle:off
85 | val query =
86 | s"""
87 | |(select testbyte, testbool
88 | |from $test_table
89 | |where testbool = true
90 | | and teststring = 'Unicode''s樂趣'
91 | | and testdouble = 1234152.12312498
92 | | and testfloat = 1.0
93 | | and testint = 42)
94 | """.stripMargin
95 | // scalastyle:on
96 | checkAnswer(read.option("dbtable", query).load(), Seq(Row(1, true)))
97 | }
98 |
99 | test("Can load output when 'query' is specified instead of 'dbtable'") {
100 | // scalastyle:off
101 | val query =
102 | s"""
103 | |select testbyte, testbool
104 | |from $test_table
105 | |where testbool = true
106 | | and teststring = 'Unicode''s樂趣'
107 | | and testdouble = 1234152.12312498
108 | | and testfloat = 1.0
109 | | and testint = 42
110 | """.stripMargin
111 | // scalastyle:on
112 | checkAnswer(read.option("query", query).load(), Seq(Row(1, true)))
113 | }
114 |
115 | test("Can load output of Redshift aggregation queries") {
116 | checkAnswer(
117 | read.option("query", s"select testbool, count(*) from $test_table group by testbool").load(),
118 | Seq(Row(true, 1), Row(false, 2), Row(null, 2)))
119 | }
120 |
121 | test("multiple scans on same table") {
122 | // .rdd() forces the first query to be unloaded from Redshift
123 | val rdd1 = sqlContext.sql("select testint from test_table").rdd
124 | // Similarly, this also forces an unload:
125 | sqlContext.sql("select testdouble from test_table").rdd
126 | // If the unloads were performed into the same directory then this call would fail: the
127 |     // second unload would have overwritten the integers with doubles, so we'd get
128 | // a NumberFormatException.
129 | rdd1.count()
130 | }
131 |
132 | test("DefaultSource supports simple column filtering") {
133 | checkAnswer(
134 | sqlContext.sql("select testbyte, testbool from test_table"),
135 | Seq(
136 | Row(null, null),
137 | Row(0.toByte, null),
138 | Row(0.toByte, false),
139 | Row(1.toByte, false),
140 | Row(1.toByte, true)))
141 | }
142 |
143 | test("query with pruned and filtered scans") {
144 | // scalastyle:off
145 | checkAnswer(
146 | sqlContext.sql(
147 | """
148 | |select testbyte, testbool
149 | |from test_table
150 | |where testbool = true
151 | | and teststring = "Unicode's樂趣"
152 | | and testdouble = 1234152.12312498
153 | | and testfloat = 1.0
154 | | and testint = 42
155 | """.stripMargin),
156 | Seq(Row(1, true)))
157 | // scalastyle:on
158 | }
159 |
160 | test("RedshiftRelation implements Spark 1.6+'s unhandledFilters API") {
161 | assume(org.apache.spark.SPARK_VERSION.take(3) >= "1.6")
162 | val df = sqlContext.sql("select testbool from test_table where testbool = true")
163 | val physicalPlan = df.queryExecution.sparkPlan
164 | physicalPlan.collectFirst { case f: execution.FilterExec => f }.foreach { filter =>
165 | fail(s"Filter should have been eliminated:\n${df.queryExecution}")
166 | }
167 | }
168 |
169 | test("filtering based on date constants (regression test for #152)") {
170 | val date = TestUtils.toDate(year = 2015, zeroBasedMonth = 6, date = 3)
171 | val df = sqlContext.sql("select testdate from test_table")
172 |
173 | checkAnswer(df.filter(df("testdate") === date), Seq(Row(date)))
174 | // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs
175 | // constant-folding, whereas earlier Spark versions would preserve the cast which prevented
176 | // filter pushdown.
177 | checkAnswer(df.filter("testdate = to_date('2015-07-03')"), Seq(Row(date)))
178 | }
179 |
180 | test("filtering based on timestamp constants (regression test for #152)") {
181 | val timestamp = TestUtils.toTimestamp(2015, zeroBasedMonth = 6, 1, 0, 0, 0, 1)
182 | val df = sqlContext.sql("select testtimestamp from test_table")
183 |
184 | checkAnswer(df.filter(df("testtimestamp") === timestamp), Seq(Row(timestamp)))
185 | // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs
186 | // constant-folding, whereas earlier Spark versions would preserve the cast which prevented
187 | // filter pushdown.
188 | checkAnswer(df.filter("testtimestamp = '2015-07-01 00:00:00.001'"), Seq(Row(timestamp)))
189 | }
190 |
191 | test("read special float values (regression test for #261)") {
192 | val tableName = s"roundtrip_special_float_values_$randomSuffix"
193 | try {
194 | conn.createStatement().executeUpdate(
195 | s"CREATE TABLE $tableName (x real)")
196 | conn.createStatement().executeUpdate(
197 | s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')")
198 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
199 | checkAnswer(
200 | read.option("dbtable", tableName).load(),
201 | Seq(Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity).map(x => Row.apply(x)))
202 | } finally {
203 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
204 | }
205 | }
206 |
207 | test("test empty string and null") {
208 | withTempRedshiftTable("records_with_empty_and_null_characters") { tableName =>
209 | conn.createStatement().executeUpdate(
210 | s"CREATE TABLE $tableName (x varchar(256))")
211 | conn.createStatement().executeUpdate(
212 | s"INSERT INTO $tableName VALUES ('null'), (''), (null)")
213 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
214 | checkAnswer(
215 | read.option("dbtable", tableName).load(),
216 | Seq("null", "", null).map(x => Row.apply(x)))
217 | }
218 | }
219 |
220 | test("test timestamptz parsing") {
221 | withTempRedshiftTable("luca_test_timestamptz_spark_redshift") { tableName =>
222 | conn.createStatement().executeUpdate(
223 | s"CREATE TABLE $tableName (x timestamptz)"
224 | )
225 | conn.createStatement().executeUpdate(
226 | s"INSERT INTO $tableName VALUES ('2015-07-03 00:00:00.000 -0300')"
227 | )
228 |
229 | checkAnswer(
230 | read.option("dbtable", tableName).load(),
231 | Seq(Row.apply(
232 | new Timestamp(TestUtils.toMillis(
233 | 2015, 6, 3, 0, 0, 0, 0, "-03"))))
234 | )
235 | }
236 | }
237 |
238 | test("read special double values (regression test for #261)") {
239 | val tableName = s"roundtrip_special_double_values_$randomSuffix"
240 | try {
241 | conn.createStatement().executeUpdate(
242 | s"CREATE TABLE $tableName (x double precision)")
243 | conn.createStatement().executeUpdate(
244 | s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')")
245 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
246 | checkAnswer(
247 | read.option("dbtable", tableName).load(),
248 | Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).map(x => Row.apply(x)))
249 | } finally {
250 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
251 | }
252 | }
253 |
254 | test("read records containing escaped characters") {
255 | withTempRedshiftTable("records_with_escaped_characters") { tableName =>
256 | conn.createStatement().executeUpdate(
257 | s"CREATE TABLE $tableName (x text)")
258 | conn.createStatement().executeUpdate(
259 | s"""INSERT INTO $tableName VALUES ('a\\nb'), ('\\\\'), ('"')""")
260 | assert(DefaultJDBCWrapper.tableExists(conn, tableName))
261 | checkAnswer(
262 | read.option("dbtable", tableName).load(),
263 | Seq("a\nb", "\\", "\"").map(x => Row.apply(x)))
264 | }
265 | }
266 |
267 | test("read result of approximate count(distinct) query (#300)") {
268 | val df = read
269 | .option("query", s"select approximate count(distinct testbool) as c from $test_table")
270 | .load()
271 | assert(df.schema.fields(0).dataType === LongType)
272 | }
273 | }
274 |
--------------------------------------------------------------------------------
/scalastyle-config.xml:
--------------------------------------------------------------------------------
[Note: the XML markup of scalastyle-config.xml did not survive extraction; only scattered
parameter values remain. Recoverable fragments: the description "Scalastyle standard
configuration"; whitespace-check token lists "ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW,
RARROW" and "ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY,
LARROW, RARROW"; banned-construct patterns "^println$", "org.apache.spark.Logging", and
"Class\.forName"; the numeric limits 800, 30, 10, and 50; and the magic-number ignore list
"-1,0,1,2,3".]
--------------------------------------------------------------------------------