├── version.sbt ├── src ├── it │ ├── resources │ │ └── log4j.properties │ └── scala │ │ └── io │ │ └── github │ │ └── spark_redshift_community │ │ └── spark │ │ └── redshift │ │ ├── PostgresDriverIntegrationSuite.scala │ │ ├── RedshiftCredentialsInConfIntegrationSuite.scala │ │ ├── CrossRegionIntegrationSuite.scala │ │ ├── IAMIntegrationSuite.scala │ │ ├── DecimalIntegrationSuite.scala │ │ ├── ColumnMetadataSuite.scala │ │ ├── SaveModeIntegrationSuite.scala │ │ ├── RedshiftWriteSuite.scala │ │ ├── IntegrationSuiteBase.scala │ │ └── RedshiftReadSuite.scala ├── test │ ├── resources │ │ ├── redshift_unload_data.txt │ │ └── hive-site.xml │ ├── scala │ │ └── io │ │ │ └── github │ │ │ └── spark_redshift_community │ │ │ └── spark │ │ │ └── redshift │ │ │ ├── TableNameSuite.scala │ │ │ ├── SerializableConfigurationSuite.scala │ │ │ ├── DirectMapredOutputCommitter.scala │ │ │ ├── DirectMapreduceOutputCommitter.scala │ │ │ ├── SeekableByteArrayInputStream.java │ │ │ ├── UtilsSuite.scala │ │ │ ├── QueryTest.scala │ │ │ ├── FilterPushdownSuite.scala │ │ │ ├── InMemoryS3AFileSystemSuite.scala │ │ │ ├── MockRedshift.scala │ │ │ ├── TestUtils.scala │ │ │ ├── RedshiftInputFormatSuite.scala │ │ │ ├── AWSCredentialsUtilsSuite.scala │ │ │ ├── ConversionsSuite.scala │ │ │ └── ParametersSuite.scala │ └── java │ │ └── io │ │ └── github │ │ └── spark_redshift_community │ │ └── spark │ │ └── redshift │ │ └── InMemoryS3AFileSystem.java └── main │ └── scala │ └── io │ └── github │ └── spark_redshift_community │ └── spark │ └── redshift │ ├── SerializableConfiguration.scala │ ├── package.scala │ ├── RecordReaderIterator.scala │ ├── TableName.scala │ ├── FilterPushdown.scala │ ├── RedshiftFileFormat.scala │ ├── DefaultSource.scala │ ├── AWSCredentialsUtils.scala │ ├── Conversions.scala │ ├── Utils.scala │ ├── RedshiftInputFormat.scala │ └── RedshiftRelation.scala ├── tutorial ├── images │ ├── loadreadstep.png │ ├── loadunloadstep.png │ └── savetoredshift.png ├── how_to_build.md └── SparkRedshiftTutorial.scala ├── .gitignore ├── .jvmopts ├── codecov.yml ├── NOTICE ├── .travis.yml ├── dev └── run-tests-travis.sh ├── project ├── plugins.sbt └── build.properties ├── CHANGELOG ├── LICENSE └── scalastyle-config.xml /version.sbt: -------------------------------------------------------------------------------- 1 | version in ThisBuild := "4.1.0" 2 | -------------------------------------------------------------------------------- /src/it/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=OFF 2 | -------------------------------------------------------------------------------- /tutorial/images/loadreadstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tes/spark-redshift/master/tutorial/images/loadreadstep.png -------------------------------------------------------------------------------- /tutorial/images/loadunloadstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tes/spark-redshift/master/tutorial/images/loadunloadstep.png -------------------------------------------------------------------------------- /tutorial/images/savetoredshift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tes/spark-redshift/master/tutorial/images/savetoredshift.png -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | target/ 2 | project/target 3 | .idea/ 4 | .idea_modules/ 5 | *.DS_Store 6 | build/*.jar 7 | aws_variables.env 8 | derby.log 9 | -------------------------------------------------------------------------------- /.jvmopts: -------------------------------------------------------------------------------- 1 | -Dfile.encoding=UTF8 2 | -Xms1024M 3 | -Xmx1024M 4 | -Xss6M 5 | -XX:MaxPermSize=512m 6 | -XX:+CMSClassUnloadingEnabled 7 | -XX:+UseConcMarkSweepGC 8 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: 2 | layout: header, changes, diff 3 | coverage: 4 | status: 5 | patch: false 6 | project: 7 | default: 8 | target: 85 9 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Apache Accumulo 2 | Copyright 2011-2019 The Apache Software Foundation 3 | 4 | This product includes software developed at 5 | The Apache Software Foundation (http://www.apache.org/). 6 | 7 | -------------------------------------------------------------------------------- /src/test/resources/redshift_unload_data.txt: -------------------------------------------------------------------------------- 1 | 1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001 2 | 1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0 3 | 0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00 4 | 0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123| 5 | ||||||||| 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | sudo: false 3 | # Cache settings here are based on latest SBT documentation. 
4 | cache: 5 | directories: 6 | - $HOME/.ivy2/cache 7 | - $HOME/.sbt/boot/ 8 | before_cache: 9 | # Tricks to avoid unnecessary cache updates 10 | - find $HOME/.ivy2 -name "ivydata-*.properties" -delete 11 | - find $HOME/.sbt -name "*.lock" -delete 12 | matrix: 13 | include: 14 | - jdk: openjdk8 15 | scala: 2.11.7 16 | env: HADOOP_VERSION="2.7.7" SPARK_VERSION="2.4.3" AWS_JAVA_SDK_VERSION="1.7.4" 17 | 18 | script: 19 | - ./dev/run-tests-travis.sh 20 | 21 | after_success: 22 | - bash <(curl -s https://codecov.io/bash) 23 | -------------------------------------------------------------------------------- /dev/run-tests-travis.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | sbt ++$TRAVIS_SCALA_VERSION scalastyle 6 | sbt ++$TRAVIS_SCALA_VERSION "test:scalastyle" 7 | sbt ++$TRAVIS_SCALA_VERSION "it:scalastyle" 8 | 9 | sbt \ 10 | -Daws.testVersion=$AWS_JAVA_SDK_VERSION \ 11 | -Dhadoop.testVersion=$HADOOP_VERSION \ 12 | -Dspark.testVersion=$SPARK_VERSION \ 13 | ++$TRAVIS_SCALA_VERSION \ 14 | coverage test coverageReport 15 | 16 | if [ "$TRAVIS_SECURE_ENV_VARS" == "true" ]; then 17 | sbt \ 18 | -Daws.testVersion=$AWS_JAVA_SDK_VERSION \ 19 | -Dhadoop.testVersion=$HADOOP_VERSION \ 20 | -Dspark.testVersion=$SPARK_VERSION \ 21 | ++$TRAVIS_SCALA_VERSION \ 22 | coverage it:test coverageReport 2> /dev/null; 23 | fi 24 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") 2 | 3 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.5") 4 | 5 | resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven" 6 | 7 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.2") 8 | 9 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") 10 | 11 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") 12 | 13 | addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0") 14 | 15 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0") 16 | 17 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") 18 | 19 | libraryDependencies += "org.apache.maven" % "maven-artifact" % "3.3.9" 20 | -------------------------------------------------------------------------------- /tutorial/how_to_build.md: -------------------------------------------------------------------------------- 1 | If you are building this project from source, you can try the following: 2 | 3 | ``` 4 | git clone https://github.com/spark-redshift-community/spark-redshift.git 5 | ``` 6 | 7 | ``` 8 | cd spark-redshift 9 | ``` 10 | 11 | ``` 12 | ./build/sbt -v compile 13 | ``` 14 | 15 | ``` 16 | ./build/sbt -v package 17 | ``` 18 | 19 | To run the unit tests: 20 | 21 | ``` 22 | ./build/sbt -v test 23 | ``` 24 | 25 | To run the integration tests: 26 | 27 | Before running them for the first time, you need to set up all the environment variables required to connect to Redshift (see https://github.com/spark-redshift-community/spark-redshift/blob/master/src/it/scala/io/github/spark_redshift_community/spark/redshift/IntegrationSuiteBase.scala#L54).
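The exact list of required variables is defined in IntegrationSuiteBase.scala (linked above). As a rough sketch, using the variable names that appear in the integration suites of this repository and placeholder values, the setup could look like the following (the variables can also be kept in an `aws_variables.env` file, which is already git-ignored):

```
# Placeholder values only -- substitute your own cluster, credentials and bucket.
export AWS_REDSHIFT_USER="your_redshift_user"
export AWS_REDSHIFT_PASSWORD="your_redshift_password"
export AWS_ACCESS_KEY_ID="your_aws_access_key_id"
export AWS_SECRET_ACCESS_KEY="your_aws_secret_access_key"
export AWS_S3_CROSS_REGION_SCRATCH_SPACE="s3a://your-cross-region-bucket/temp/"
export STS_ROLE_ARN="arn:aws:iam::123456789012:role/your-integration-test-role"
```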
28 | 29 | ``` 30 | ./build/sbt -v it:test 31 | ``` 32 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | sbt.version=0.13.13 18 | -------------------------------------------------------------------------------- /src/test/resources/hive-site.xml: -------------------------------------------------------------------------------- 1 | <?xml version="1.0"?> 2 | <!-- 3 | Licensed to the Apache Software Foundation (ASF) under one or more 4 | contributor license agreements. See the NOTICE file distributed with 5 | this work for additional information regarding copyright ownership. 6 | The ASF licenses this file to You under the Apache License, Version 2.0 7 | (the "License"); you may not use this file except in compliance with 8 | the License. You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | --> 18 | 19 | <configuration> 20 | 21 | <property> 22 | <name>fs.permissions.umask-mode</name> 23 | <value>022</value> 24 | <description>Setting a value for fs.permissions.umask-mode to work around issue in HIVE-6962. 25 | It has no impact in Hadoop 1.x line on HDFS operations.</description> 26 | </property> 27 | 28 | </configuration> 29 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/TableNameSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.scalatest.FunSuite 20 | 21 | class TableNameSuite extends FunSuite { 22 | test("TableName.parseFromEscaped") { 23 | assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar")) 24 | assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo")) 25 | assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo")) 26 | assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar")) 27 | // Dots (.) can also appear inside of valid identifiers.
28 | assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz")) 29 | assert(TableName.parseFromEscaped("\"foo\"\".bar\".baz") === TableName("foo\".bar", "baz")) 30 | } 31 | 32 | test("TableName.toString") { 33 | assert(TableName("foo", "bar").toString === """"foo"."bar"""") 34 | assert(TableName("PUBLIC", "bar").toString === """"PUBLIC"."bar"""") 35 | assert(TableName("\"foo\"", "bar").toString === "\"\"\"foo\"\"\".\"bar\"") 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/PostgresDriverIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.Row 20 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 21 | 22 | /** 23 | * Basic integration tests with the Postgres JDBC driver. 24 | */ 25 | class PostgresDriverIntegrationSuite extends IntegrationSuiteBase { 26 | 27 | override def jdbcUrl: String = { 28 | super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql") 29 | } 30 | 31 | // TODO (luca|issue #9) Fix tests when using postgresql driver 32 | ignore("postgresql driver takes precedence for jdbc:postgresql:// URIs") { 33 | val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None) 34 | try { 35 | assert(conn.getClass.getName === "org.postgresql.jdbc4.Jdbc4Connection") 36 | } finally { 37 | conn.close() 38 | } 39 | } 40 | 41 | ignore("roundtrip save and load") { 42 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), 43 | StructType(StructField("foo", IntegerType) :: Nil)) 44 | testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/SerializableConfigurationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.hadoop.conf.Configuration 20 | import org.apache.spark.SparkConf 21 | import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} 22 | import org.scalatest.FunSuite 23 | 24 | class SerializableConfigurationSuite extends FunSuite { 25 | 26 | private def testSerialization(serializer: SerializerInstance): Unit = { 27 | val conf = new SerializableConfiguration(new Configuration()) 28 | 29 | val serialized = serializer.serialize(conf) 30 | 31 | serializer.deserialize[Any](serialized) match { 32 | case c: SerializableConfiguration => 33 | assert(c.log != null, "log was null") 34 | assert(c.value != null, "value was null") 35 | case other => fail( 36 | s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") 37 | } 38 | } 39 | 40 | test("serialization with JavaSerializer") { 41 | testSerialization(new JavaSerializer(new SparkConf()).newInstance()) 42 | } 43 | 44 | test("serialization with KryoSerializer") { 45 | testSerialization(new KryoSerializer(new SparkConf()).newInstance()) 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.Row 20 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 21 | 22 | /** 23 | * This suite performs basic integration tests where the Redshift credentials have been 24 | * specified via `spark-redshift`'s configuration rather than as part of the JDBC URL. 
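* The credentials are supplied through the `user` and `password` data source options instead.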
25 | */ 26 | class RedshiftCredentialsInConfIntegrationSuite extends IntegrationSuiteBase { 27 | 28 | test("roundtrip save and load") { 29 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), 30 | StructType(StructField("foo", IntegerType) :: Nil)) 31 | val tableName = s"roundtrip_save_and_load_$randomSuffix" 32 | try { 33 | write(df) 34 | .option("url", jdbcUrlNoUserPassword) 35 | .option("user", AWS_REDSHIFT_USER) 36 | .option("password", AWS_REDSHIFT_PASSWORD) 37 | .option("dbtable", tableName) 38 | .save() 39 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 40 | val loadedDf = read 41 | .option("url", jdbcUrlNoUserPassword) 42 | .option("user", AWS_REDSHIFT_USER) 43 | .option("password", AWS_REDSHIFT_PASSWORD) 44 | .option("dbtable", tableName) 45 | .load() 46 | assert(loadedDf.schema === df.schema) 47 | checkAnswer(loadedDf, df.collect()) 48 | } finally { 49 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 50 | } 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/SerializableConfiguration.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.io._ 20 | 21 | import com.esotericsoftware.kryo.io.{Input, Output} 22 | import com.esotericsoftware.kryo.{Kryo, KryoSerializable} 23 | import org.apache.hadoop.conf.Configuration 24 | import org.slf4j.LoggerFactory 25 | 26 | import scala.util.control.NonFatal 27 | 28 | class SerializableConfiguration(@transient var value: Configuration) 29 | extends Serializable with KryoSerializable { 30 | @transient private[redshift] lazy val log = LoggerFactory.getLogger(getClass) 31 | 32 | private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { 33 | out.defaultWriteObject() 34 | value.write(out) 35 | } 36 | 37 | private def readObject(in: ObjectInputStream): Unit = tryOrIOException { 38 | value = new Configuration(false) 39 | value.readFields(in) 40 | } 41 | 42 | private def tryOrIOException[T](block: => T): T = { 43 | try { 44 | block 45 | } catch { 46 | case e: IOException => 47 | log.error("Exception encountered", e) 48 | throw e 49 | case NonFatal(e) => 50 | log.error("Exception encountered", e) 51 | throw new IOException(e) 52 | } 53 | } 54 | 55 | def write(kryo: Kryo, out: Output): Unit = { 56 | val dos = new DataOutputStream(out) 57 | value.write(dos) 58 | dos.flush() 59 | } 60 | 61 | def read(kryo: Kryo, in: Input): Unit = { 62 | value = new Configuration(false) 63 | value.readFields(new DataInputStream(in)) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * Copyright 2015 TouchType Ltd. (Added JDBC-based Data Source API implementation) 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark 19 | 20 | import org.apache.spark.sql.functions.col 21 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 22 | import org.apache.spark.sql.{DataFrame, Row, SQLContext} 23 | 24 | package object redshift { 25 | 26 | /** 27 | * Wrapper of SQLContext that provides the `redshiftFile` method. 28 | */ 29 | implicit class RedshiftContext(sqlContext: SQLContext) { 30 | 31 | /** 32 | * Read a file unloaded from Redshift into a DataFrame. 33 | * @param path input path 34 | * @return a DataFrame with all string columns 35 | */ 36 | def redshiftFile(path: String, columns: Seq[String]): DataFrame = { 37 | val sc = sqlContext.sparkContext 38 | val rdd = sc.newAPIHadoopFile(path, classOf[RedshiftInputFormat], 39 | classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration) 40 | // TODO: allow setting NULL string.
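// At the moment an empty field is interpreted as NULL, so empty strings and SQL NULLs cannot be distinguished here.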
41 | val nullable = rdd.values.map(_.map(f => if (f.isEmpty) null else f)).map(x => Row(x: _*)) 42 | val schema = StructType(columns.map(c => StructField(c, StringType, nullable = true))) 43 | sqlContext.createDataFrame(nullable, schema) 44 | } 45 | 46 | /** 47 | * Reads a table unload from Redshift with its schema. 48 | */ 49 | def redshiftFile(path: String, schema: StructType): DataFrame = { 50 | val casts = schema.fields.map { field => 51 | col(field.name).cast(field.dataType).as(field.name) 52 | } 53 | redshiftFile(path, schema.fieldNames).select(casts: _*) 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/DirectMapredOutputCommitter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | * not use this file except in compliance with the License. You may obtain 6 | * a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.hadoop.fs.Path 20 | import org.apache.hadoop.mapred._ 21 | 22 | class DirectMapredOutputCommitter extends OutputCommitter { 23 | override def setupJob(jobContext: JobContext): Unit = { } 24 | 25 | override def setupTask(taskContext: TaskAttemptContext): Unit = { } 26 | 27 | override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { 28 | // We return true here to guard against implementations that do not handle false correctly. 29 | // The meaning of returning false is not entirely clear, so it's possible to be interpreted 30 | // as an error. Returning true just means that commitTask() will be called, which is a no-op. 31 | true 32 | } 33 | 34 | override def commitTask(taskContext: TaskAttemptContext): Unit = { } 35 | 36 | override def abortTask(taskContext: TaskAttemptContext): Unit = { } 37 | 38 | /** 39 | * Creates a _SUCCESS file to indicate the entire job was successful. 40 | * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. 41 | */ 42 | override def commitJob(context: JobContext): Unit = { 43 | val conf = context.getJobConf 44 | if (shouldCreateSuccessFile(conf)) { 45 | val outputPath = FileOutputFormat.getOutputPath(conf) 46 | if (outputPath != null) { 47 | val fileSys = outputPath.getFileSystem(conf) 48 | val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) 49 | fileSys.create(filePath).close() 50 | } 51 | } 52 | } 53 | 54 | /** By default, we do create the _SUCCESS file, but we allow it to be turned off. 
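This behavior is controlled by the mapreduce.fileoutputcommitter.marksuccessfuljobs setting, which defaults to true.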
*/ 55 | private def shouldCreateSuccessFile(conf: JobConf): Boolean = { 56 | conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/DirectMapreduceOutputCommitter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | * not use this file except in compliance with the License. You may obtain 6 | * a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.hadoop.conf.Configuration 20 | import org.apache.hadoop.fs.Path 21 | import org.apache.hadoop.mapreduce._ 22 | import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} 23 | 24 | class DirectMapreduceOutputCommitter extends OutputCommitter { 25 | override def setupJob(jobContext: JobContext): Unit = { } 26 | 27 | override def setupTask(taskContext: TaskAttemptContext): Unit = { } 28 | 29 | override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { 30 | // We return true here to guard against implementations that do not handle false correctly. 31 | // The meaning of returning false is not entirely clear, so it's possible to be interpreted 32 | // as an error. Returning true just means that commitTask() will be called, which is a no-op. 33 | true 34 | } 35 | 36 | override def commitTask(taskContext: TaskAttemptContext): Unit = { } 37 | 38 | override def abortTask(taskContext: TaskAttemptContext): Unit = { } 39 | 40 | /** 41 | * Creates a _SUCCESS file to indicate the entire job was successful. 42 | * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. 43 | */ 44 | override def commitJob(context: JobContext): Unit = { 45 | val conf = context.getConfiguration 46 | if (shouldCreateSuccessFile(conf)) { 47 | val outputPath = FileOutputFormat.getOutputPath(context) 48 | if (outputPath != null) { 49 | val fileSys = outputPath.getFileSystem(conf) 50 | val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) 51 | fileSys.create(filePath).close() 52 | } 53 | } 54 | } 55 | 56 | /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */ 57 | private def shouldCreateSuccessFile(conf: Configuration): Boolean = { 58 | conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/CrossRegionIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import com.amazonaws.auth.BasicAWSCredentials 20 | import com.amazonaws.services.s3.AmazonS3Client 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 23 | 24 | /** 25 | * Integration tests where the Redshift cluster and the S3 bucket are in different AWS regions. 26 | */ 27 | class CrossRegionIntegrationSuite extends IntegrationSuiteBase { 28 | 29 | protected val AWS_S3_CROSS_REGION_SCRATCH_SPACE: String = 30 | loadConfigFromEnv("AWS_S3_CROSS_REGION_SCRATCH_SPACE") 31 | require(AWS_S3_CROSS_REGION_SCRATCH_SPACE.contains("s3a"), "must use s3a:// URL") 32 | 33 | override protected val tempDir: String = AWS_S3_CROSS_REGION_SCRATCH_SPACE + randomSuffix + "/" 34 | 35 | test("write") { 36 | val bucketRegion = Utils.getRegionForS3Bucket( 37 | tempDir, 38 | new AmazonS3Client(new BasicAWSCredentials(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY))).get 39 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), 40 | StructType(StructField("foo", IntegerType) :: Nil)) 41 | val tableName = s"roundtrip_save_and_load_$randomSuffix" 42 | try { 43 | write(df) 44 | .option("dbtable", tableName) 45 | .option("extracopyoptions", s"region '$bucketRegion'") 46 | .save() 47 | // Check that the table exists. It appears that creating a table in one connection then 48 | // immediately querying for existence from another connection may result in spurious "table 49 | // doesn't exist" errors; this caused the "save with all empty partitions" test to become 50 | // flaky (see #146). To work around this, add a small sleep and check again: 51 | if (!DefaultJDBCWrapper.tableExists(conn, tableName)) { 52 | Thread.sleep(1000) 53 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 54 | } 55 | } finally { 56 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/RecordReaderIterator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift 19 | 20 | import java.io.Closeable 21 | 22 | import org.apache.hadoop.mapreduce.RecordReader 23 | 24 | /** 25 | * An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned. 26 | * 27 | * This is copied from Apache Spark and is inlined here to avoid depending on Spark internals 28 | * in this external library. 29 | */ 30 | private[redshift] class RecordReaderIterator[T]( 31 | private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { 32 | private[this] var havePair = false 33 | private[this] var finished = false 34 | 35 | override def hasNext: Boolean = { 36 | if (!finished && !havePair) { 37 | finished = !rowReader.nextKeyValue 38 | if (finished) { 39 | // Close and release the reader here; close() will also be called when the task 40 | // completes, but for tasks that read from many files, it helps to release the 41 | // resources early. 42 | close() 43 | } 44 | havePair = !finished 45 | } 46 | !finished 47 | } 48 | 49 | override def next(): T = { 50 | if (!hasNext) { 51 | throw new java.util.NoSuchElementException("End of stream") 52 | } 53 | havePair = false 54 | rowReader.getCurrentValue 55 | } 56 | 57 | override def close(): Unit = { 58 | if (rowReader != null) { 59 | try { 60 | // Close the reader and release it. Note: it's very important that we don't close the 61 | // reader more than once, since that exposes us to MAPREDUCE-5918 when running against 62 | // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues 63 | // when reading compressed input. 64 | rowReader.close() 65 | } finally { 66 | rowReader = null 67 | } 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/SeekableByteArrayInputStream.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | /* 19 | SeekableByteArrayInputStream copied from 20 | https://github.com/apache/accumulo/blob/master/core/src/test/java/org/apache/accumulo/core/file/rfile/RFileTest.java 21 | */ 22 | 23 | package io.github.spark_redshift_community.spark.redshift; 24 | 25 | import org.apache.hadoop.fs.PositionedReadable; 26 | import org.apache.hadoop.fs.Seekable; 27 | 28 | import java.io.ByteArrayInputStream; 29 | import java.io.IOException; 30 | 31 | 32 | class SeekableByteArrayInputStream extends ByteArrayInputStream 33 | implements Seekable, PositionedReadable { 34 | 35 | public SeekableByteArrayInputStream(byte[] buf) { 36 | super(buf); 37 | } 38 | 39 | @Override 40 | public long getPos() { 41 | return pos; 42 | } 43 | 44 | @Override 45 | public void seek(long pos) throws IOException { 46 | if (mark != 0) 47 | throw new IllegalStateException(); 48 | 49 | reset(); 50 | long skipped = skip(pos); 51 | 52 | if (skipped != pos) 53 | throw new IOException(); 54 | } 55 | 56 | @Override 57 | public boolean seekToNewSource(long targetPos) { 58 | return false; 59 | } 60 | 61 | @Override 62 | public int read(long position, byte[] buffer, int offset, int length) { 63 | 64 | if (position >= buf.length) 65 | throw new IllegalArgumentException(); 66 | if (position + length > buf.length) 67 | throw new IllegalArgumentException(); 68 | if (length > buffer.length) 69 | throw new IllegalArgumentException(); 70 | 71 | System.arraycopy(buf, (int) position, buffer, offset, length); 72 | return length; 73 | } 74 | 75 | @Override 76 | public void readFully(long position, byte[] buffer) { 77 | read(position, buffer, 0, buffer.length); 78 | 79 | } 80 | 81 | @Override 82 | public void readFully(long position, byte[] buffer, int offset, int length) { 83 | read(position, buffer, offset, length); 84 | } 85 | 86 | } 87 | 88 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/IAMIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.SQLException 20 | 21 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 22 | import org.apache.spark.sql.{Row, SaveMode} 23 | 24 | /** 25 | * Integration tests for configuring Redshift to access S3 using Amazon IAM roles. 
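* The role ARN is read from the STS_ROLE_ARN environment variable and passed to Redshift through the `aws_iam_role` option.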
26 | */ 27 | class IAMIntegrationSuite extends IntegrationSuiteBase { 28 | 29 | private val IAM_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN") 30 | 31 | // TODO (luca|issue #8) Fix IAM Authentication tests 32 | ignore("roundtrip save and load") { 33 | val tableName = s"iam_roundtrip_save_and_load$randomSuffix" 34 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 35 | StructType(StructField("a", IntegerType) :: Nil)) 36 | try { 37 | write(df) 38 | .option("dbtable", tableName) 39 | .option("forward_spark_s3_credentials", "false") 40 | .option("aws_iam_role", IAM_ROLE_ARN) 41 | .mode(SaveMode.ErrorIfExists) 42 | .save() 43 | 44 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 45 | val loadedDf = read 46 | .option("dbtable", tableName) 47 | .option("forward_spark_s3_credentials", "false") 48 | .option("aws_iam_role", IAM_ROLE_ARN) 49 | .load() 50 | assert(loadedDf.schema.length === 1) 51 | assert(loadedDf.columns === Seq("a")) 52 | checkAnswer(loadedDf, Seq(Row(1))) 53 | } finally { 54 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 55 | } 56 | } 57 | 58 | ignore("load fails if IAM role cannot be assumed") { 59 | val tableName = s"iam_load_fails_if_role_cannot_be_assumed$randomSuffix" 60 | try { 61 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 62 | StructType(StructField("a", IntegerType) :: Nil)) 63 | val err = intercept[SQLException] { 64 | write(df) 65 | .option("dbtable", tableName) 66 | .option("forward_spark_s3_credentials", "false") 67 | .option("aws_iam_role", IAM_ROLE_ARN + "-some-bogus-suffix") 68 | .mode(SaveMode.ErrorIfExists) 69 | .save() 70 | } 71 | assert(err.getCause.getMessage.contains("is not authorized to assume IAM Role")) 72 | } finally { 73 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/TableName.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import scala.collection.mutable.ArrayBuffer 20 | 21 | /** 22 | * Wrapper class for representing the name of a Redshift table. 23 | */ 24 | private[redshift] case class TableName(unescapedSchemaName: String, unescapedTableName: String) { 25 | private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"' 26 | def escapedSchemaName: String = quote(unescapedSchemaName) 27 | def escapedTableName: String = quote(unescapedTableName) 28 | override def toString: String = s"$escapedSchemaName.$escapedTableName" 29 | } 30 | 31 | private[redshift] object TableName { 32 | /** 33 | * Parses a table name which is assumed to have been escaped according to Redshift's rules for 34 | * delimited identifiers. 
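* For example, `"foo.bar".baz` parses to TableName("foo.bar", "baz"), while an unqualified name such as `foo` defaults to the PUBLIC schema.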
35 | */ 36 | def parseFromEscaped(str: String): TableName = { 37 | def dropOuterQuotes(s: String) = 38 | if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s 39 | def unescapeQuotes(s: String) = s.replace("\"\"", "\"") 40 | def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s)) 41 | splitByDots(str) match { 42 | case Seq(tableName) => TableName("PUBLIC", unescape(tableName)) 43 | case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName)) 44 | case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'") 45 | } 46 | } 47 | 48 | /** 49 | * Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear 50 | * inside of quoted identifiers. 51 | */ 52 | private def splitByDots(str: String): Seq[String] = { 53 | val parts: ArrayBuffer[String] = ArrayBuffer.empty 54 | val sb = new StringBuilder 55 | var inQuotes: Boolean = false 56 | for (c <- str) c match { 57 | case '"' => 58 | // Note that double quotes are escaped by pairs of double quotes (""), so we don't need 59 | // any extra code to handle them; we'll be back in inQuotes=true after seeing the pair. 60 | sb.append('"') 61 | inQuotes = !inQuotes 62 | case '.' => 63 | if (!inQuotes) { 64 | parts.append(sb.toString()) 65 | sb.clear() 66 | } else { 67 | sb.append('.') 68 | } 69 | case other => 70 | sb.append(other) 71 | } 72 | if (sb.nonEmpty) { 73 | parts.append(sb.toString()) 74 | } 75 | parts 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/UtilsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.net.URI 20 | 21 | import org.scalatest.{FunSuite, Matchers} 22 | 23 | /** 24 | * Unit tests for helper functions 25 | */ 26 | class UtilsSuite extends FunSuite with Matchers { 27 | 28 | test("joinUrls preserves protocol information") { 29 | Utils.joinUrls("s3n://foo/bar/", "/baz") shouldBe "s3n://foo/bar/baz/" 30 | Utils.joinUrls("s3n://foo/bar/", "/baz/") shouldBe "s3n://foo/bar/baz/" 31 | Utils.joinUrls("s3n://foo/bar/", "baz/") shouldBe "s3n://foo/bar/baz/" 32 | Utils.joinUrls("s3n://foo/bar/", "baz") shouldBe "s3n://foo/bar/baz/" 33 | Utils.joinUrls("s3n://foo/bar", "baz") shouldBe "s3n://foo/bar/baz/" 34 | } 35 | 36 | test("joinUrls preserves credentials") { 37 | assert( 38 | Utils.joinUrls("s3n://ACCESSKEY:SECRETKEY@bucket/tempdir", "subdir") === 39 | "s3n://ACCESSKEY:SECRETKEY@bucket/tempdir/subdir/") 40 | } 41 | 42 | test("fixUrl produces Redshift-compatible equivalents") { 43 | Utils.fixS3Url("s3a://foo/bar/12345") shouldBe "s3://foo/bar/12345" 44 | Utils.fixS3Url("s3n://foo/bar/baz") shouldBe "s3://foo/bar/baz" 45 | } 46 | 47 | test("addEndpointToUrl produces urls with endpoints added to host") { 48 | Utils.addEndpointToUrl("s3a://foo/bar/12345") shouldBe "s3a://foo.s3.amazonaws.com/bar/12345" 49 | Utils.addEndpointToUrl("s3n://foo/bar/baz") shouldBe "s3n://foo.s3.amazonaws.com/bar/baz" 50 | } 51 | 52 | test("temp paths are random subdirectories of root") { 53 | val root = "s3n://temp/" 54 | val firstTempPath = Utils.makeTempPath(root) 55 | 56 | Utils.makeTempPath(root) should (startWith (root) and endWith ("/") 57 | and not equal root and not equal firstTempPath) 58 | } 59 | 60 | test("removeCredentialsFromURI removes AWS access keys") { 61 | def removeCreds(uri: String): String = { 62 | Utils.removeCredentialsFromURI(URI.create(uri)).toString 63 | } 64 | assert(removeCreds("s3n://bucket/path/to/temp/dir") === "s3n://bucket/path/to/temp/dir") 65 | assert( 66 | removeCreds("s3n://ACCESSKEY:SECRETKEY@bucket/path/to/temp/dir") === 67 | "s3n://bucket/path/to/temp/dir") 68 | } 69 | 70 | test("getRegionForRedshiftCluster") { 71 | val redshiftUrl = 72 | "jdbc:redshift://example.secret.us-west-2.redshift.amazonaws.com:5439/database" 73 | assert(Utils.getRegionForRedshiftCluster("mycluster.example.com") === None) 74 | assert(Utils.getRegionForRedshiftCluster(redshiftUrl) === Some("us-west-2")) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/FilterPushdown.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.{Date, Timestamp} 20 | 21 | import org.apache.spark.sql.sources._ 22 | import org.apache.spark.sql.types._ 23 | 24 | /** 25 | * Helper methods for pushing filters into Redshift queries. 26 | */ 27 | private[redshift] object FilterPushdown { 28 | /** 29 | * Build a SQL WHERE clause for the given filters. If a filter cannot be pushed down then no 30 | * condition will be added to the WHERE clause. If none of the filters can be pushed down then 31 | * an empty string will be returned. 32 | * 33 | * @param schema the schema of the table being queried 34 | * @param filters an array of filters, the conjunction of which is the filter condition for the 35 | * scan. 36 | */ 37 | def buildWhereClause(schema: StructType, filters: Seq[Filter]): String = { 38 | val filterExpressions = filters.flatMap(f => buildFilterExpression(schema, f)).mkString(" AND ") 39 | if (filterExpressions.isEmpty) "" else "WHERE " + filterExpressions 40 | } 41 | 42 | /** 43 | * Attempt to convert the given filter into a SQL expression. Returns None if the expression 44 | * could not be converted. 45 | */ 46 | def buildFilterExpression(schema: StructType, filter: Filter): Option[String] = { 47 | def buildComparison(attr: String, value: Any, comparisonOp: String): Option[String] = { 48 | getTypeForAttribute(schema, attr).map { dataType => 49 | val sqlEscapedValue: String = dataType match { 50 | case StringType => s"\\'${value.toString.replace("'", "\\'\\'")}\\'" 51 | case DateType => s"\\'${value.asInstanceOf[Date]}\\'" 52 | case TimestampType => s"\\'${value.asInstanceOf[Timestamp]}\\'" 53 | case _ => value.toString 54 | } 55 | s""""$attr" $comparisonOp $sqlEscapedValue""" 56 | } 57 | } 58 | 59 | filter match { 60 | case EqualTo(attr, value) => buildComparison(attr, value, "=") 61 | case LessThan(attr, value) => buildComparison(attr, value, "<") 62 | case GreaterThan(attr, value) => buildComparison(attr, value, ">") 63 | case LessThanOrEqual(attr, value) => buildComparison(attr, value, "<=") 64 | case GreaterThanOrEqual(attr, value) => buildComparison(attr, value, ">=") 65 | case IsNotNull(attr) => 66 | getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NOT NULL""") 67 | case IsNull(attr) => 68 | getTypeForAttribute(schema, attr).map(dataType => s""""$attr" IS NULL""") 69 | case _ => None 70 | } 71 | } 72 | 73 | /** 74 | * Use the given schema to look up the attribute's data type. Returns None if the attribute could 75 | * not be resolved. 76 | */ 77 | private def getTypeForAttribute(schema: StructType, attribute: String): Option[DataType] = { 78 | if (schema.fieldNames.contains(attribute)) { 79 | Some(schema(attribute).dataType) 80 | } else { 81 | None 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/DecimalIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.Row 20 | import org.apache.spark.sql.types.DecimalType 21 | 22 | /** 23 | * Integration tests for decimal support. For a reference on Redshift's DECIMAL type, see 24 | * http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html 25 | */ 26 | class DecimalIntegrationSuite extends IntegrationSuiteBase { 27 | 28 | private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = { 29 | test(s"reading DECIMAL($precision, $scale)") { 30 | val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix" 31 | val expectedRows = decimalStrings.map { d => 32 | if (d == null) { 33 | Row(null) 34 | } else { 35 | Row(Conversions.createRedshiftDecimalFormat().parse(d).asInstanceOf[java.math.BigDecimal]) 36 | } 37 | } 38 | try { 39 | conn.createStatement().executeUpdate( 40 | s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))") 41 | for (x <- decimalStrings) { 42 | conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES ($x)") 43 | } 44 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 45 | val loadedDf = read.option("dbtable", tableName).load() 46 | checkAnswer(loadedDf, expectedRows) 47 | checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows) 48 | } finally { 49 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 50 | } 51 | } 52 | } 53 | 54 | testReadingDecimals(19, 0, Seq( 55 | // Max and min values of DECIMAL(19, 0) column according to Redshift docs: 56 | "9223372036854775807", // 2^63 - 1 57 | "-9223372036854775807", 58 | "0", 59 | "12345678910", 60 | null 61 | )) 62 | 63 | testReadingDecimals(19, 4, Seq( 64 | "922337203685477.5807", 65 | "-922337203685477.5807", 66 | "0", 67 | "1234567.8910", 68 | null 69 | )) 70 | 71 | testReadingDecimals(38, 4, Seq( 72 | "922337203685477.5808", 73 | "9999999999999999999999999999999999.0000", 74 | "-9999999999999999999999999999999999.0000", 75 | "0", 76 | "1234567.8910", 77 | null 78 | )) 79 | 80 | test("Decimal precision is preserved when reading from query (regression test for issue #203)") { 81 | withTempRedshiftTable("issue203") { tableName => 82 | conn.createStatement().executeUpdate(s"CREATE TABLE $tableName (foo BIGINT)") 83 | conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES (91593373)") 84 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 85 | val df = read 86 | .option("query", s"select foo / 1000000.0 from $tableName limit 1") 87 | .load() 88 | val res: Double = df.collect().toSeq.head.getDecimal(0).doubleValue() 89 | assert(res === (91593373L / 1000000.0) +- 0.01) 90 | assert(df.schema.fields.head.dataType === DecimalType(28, 8)) 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/QueryTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 
3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift 19 | 20 | import org.apache.spark.sql.catalyst.plans.logical 21 | import org.apache.spark.sql.{DataFrame, Row} 22 | import org.scalatest.FunSuite 23 | 24 | /** 25 | * Copy of Spark SQL's `QueryTest` trait. 26 | */ 27 | trait QueryTest extends FunSuite { 28 | /** 29 | * Runs the plan and makes sure the answer matches the expected result. 30 | * @param df the [[DataFrame]] to be executed 31 | * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 32 | */ 33 | def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { 34 | val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty 35 | def prepareAnswer(answer: Seq[Row]): Seq[Row] = { 36 | // Converts data to types that we can do equality comparison using Scala collections. 37 | // For BigDecimal type, the Scala type has a better definition of equality test (similar to 38 | // Java's java.math.BigDecimal.compareTo). 39 | // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for 40 | // equality test.
41 | val converted: Seq[Row] = answer.map { s => 42 | Row.fromSeq(s.toSeq.map { 43 | case d: java.math.BigDecimal => BigDecimal(d) 44 | case b: Array[Byte] => b.toSeq 45 | case o => o 46 | }) 47 | } 48 | if (!isSorted) converted.sortBy(_.toString()) else converted 49 | } 50 | val sparkAnswer = try df.collect().toSeq catch { 51 | case e: Exception => 52 | val errorMessage = 53 | s""" 54 | |Exception thrown while executing query: 55 | |${df.queryExecution} 56 | |== Exception == 57 | |$e 58 | |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} 59 | """.stripMargin 60 | fail(errorMessage) 61 | } 62 | 63 | if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { 64 | val errorMessage = 65 | s""" 66 | |Results do not match for query: 67 | |${df.queryExecution} 68 | |== Results == 69 | |${sideBySide( 70 | s"== Correct Answer - ${expectedAnswer.size} ==" +: 71 | prepareAnswer(expectedAnswer).map(_.toString()), 72 | s"== Spark Answer - ${sparkAnswer.size} ==" +: 73 | prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} 74 | """.stripMargin 75 | fail(errorMessage) 76 | } 77 | } 78 | 79 | private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { 80 | val maxLeftSize = left.map(_.length).max 81 | val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") 82 | val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") 83 | 84 | leftPadded.zip(rightPadded).map { 85 | case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/FilterPushdownSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import io.github.spark_redshift_community.spark.redshift.FilterPushdown._ 20 | import org.apache.spark.sql.sources._ 21 | import org.apache.spark.sql.types._ 22 | import org.scalatest.FunSuite 23 | 24 | 25 | class FilterPushdownSuite extends FunSuite { 26 | test("buildWhereClause with empty list of filters") { 27 | assert(buildWhereClause(StructType(Nil), Seq.empty) === "") 28 | } 29 | 30 | test("buildWhereClause with no filters that can be pushed down") { 31 | assert(buildWhereClause(StructType(Nil), Seq(NewFilter, NewFilter)) === "") 32 | } 33 | 34 | test("buildWhereClause with some filters that cannot be pushed down") { 35 | val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), NewFilter)) 36 | assert(whereClause === """WHERE "test_int" = 1""") 37 | } 38 | 39 | test("buildWhereClause with string literals that contain Unicode characters") { 40 | // scalastyle:off 41 | val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_string", "Unicode's樂趣"))) 42 | // Here, the apostrophe in the string needs to be replaced with two single quotes, '', but we 43 | // also need to escape those quotes with backslashes because this WHERE clause is going to 44 | // eventually be embedded inside of a single-quoted string that's embedded inside of a larger 45 | // Redshift query. 46 | assert(whereClause === """WHERE "test_string" = \'Unicode\'\'s樂趣\'""") 47 | // scalastyle:on 48 | } 49 | 50 | test("buildWhereClause with multiple filters") { 51 | val filters = Seq( 52 | EqualTo("test_bool", true), 53 | // scalastyle:off 54 | EqualTo("test_string", "Unicode是樂趣"), 55 | // scalastyle:on 56 | GreaterThan("test_double", 1000.0), 57 | LessThan("test_double", Double.MaxValue), 58 | GreaterThanOrEqual("test_float", 1.0f), 59 | LessThanOrEqual("test_int", 43), 60 | IsNotNull("test_int"), 61 | IsNull("test_int")) 62 | val whereClause = buildWhereClause(testSchema, filters) 63 | // scalastyle:off 64 | val expectedWhereClause = 65 | """ 66 | |WHERE "test_bool" = true 67 | |AND "test_string" = \'Unicode是樂趣\' 68 | |AND "test_double" > 1000.0 69 | |AND "test_double" < 1.7976931348623157E308 70 | |AND "test_float" >= 1.0 71 | |AND "test_int" <= 43 72 | |AND "test_int" IS NOT NULL 73 | |AND "test_int" IS NULL 74 | """.stripMargin.lines.mkString(" ").trim 75 | // scalastyle:on 76 | assert(whereClause === expectedWhereClause) 77 | } 78 | 79 | private val testSchema: StructType = StructType(Seq( 80 | StructField("test_byte", ByteType), 81 | StructField("test_bool", BooleanType), 82 | StructField("test_date", DateType), 83 | StructField("test_double", DoubleType), 84 | StructField("test_float", FloatType), 85 | StructField("test_int", IntegerType), 86 | StructField("test_long", LongType), 87 | StructField("test_short", ShortType), 88 | StructField("test_string", StringType), 89 | StructField("test_timestamp", TimestampType))) 90 | 91 | /** A new filter subclass which our pushdown logic does not know how to handle */ 92 | private case object NewFilter extends Filter { 93 | override def references: Array[String] = Array.empty 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftFileFormat.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 |
* you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.net.URI 20 | 21 | import org.apache.hadoop.conf.Configuration 22 | import org.apache.hadoop.fs.{FileStatus, Path} 23 | import org.apache.hadoop.mapreduce._ 24 | import org.apache.hadoop.mapreduce.lib.input.FileSplit 25 | import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl 26 | import org.apache.spark.TaskContext 27 | import org.apache.spark.sql.SparkSession 28 | import org.apache.spark.sql.catalyst.InternalRow 29 | import org.apache.spark.sql.execution.datasources._ 30 | import org.apache.spark.sql.sources.Filter 31 | import org.apache.spark.sql.types.{DataType, StructType} 32 | 33 | /** 34 | * Internal data source used for reading Redshift UNLOAD files. 35 | * 36 | * This is not intended for public consumption / use outside of this package and therefore 37 | * no API stability is guaranteed. 38 | */ 39 | private[redshift] class RedshiftFileFormat extends FileFormat { 40 | override def inferSchema( 41 | sparkSession: SparkSession, 42 | options: Map[String, String], 43 | files: Seq[FileStatus]): Option[StructType] = { 44 | // Schema is provided by caller. 45 | None 46 | } 47 | 48 | override def prepareWrite( 49 | sparkSession: SparkSession, 50 | job: Job, 51 | options: Map[String, String], 52 | dataSchema: StructType): OutputWriterFactory = { 53 | throw new UnsupportedOperationException(s"prepareWrite is not supported for $this") 54 | } 55 | 56 | override def isSplitable( 57 | sparkSession: SparkSession, 58 | options: Map[String, String], 59 | path: Path): Boolean = { 60 | // Our custom InputFormat handles split records properly 61 | true 62 | } 63 | 64 | override def buildReader( 65 | sparkSession: SparkSession, 66 | dataSchema: StructType, 67 | partitionSchema: StructType, 68 | requiredSchema: StructType, 69 | filters: Seq[Filter], 70 | options: Map[String, String], 71 | hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { 72 | 73 | require(partitionSchema.isEmpty) 74 | require(filters.isEmpty) 75 | require(dataSchema == requiredSchema) 76 | 77 | val broadcastedConf = 78 | sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) 79 | 80 | (file: PartitionedFile) => { 81 | val conf = broadcastedConf.value.value 82 | 83 | val fileSplit = new FileSplit( 84 | new Path(new URI(file.filePath)), 85 | file.start, 86 | file.length, 87 | // TODO: Implement Locality 88 | Array.empty) 89 | val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) 90 | val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) 91 | val reader = new RedshiftRecordReader 92 | reader.initialize(fileSplit, hadoopAttemptContext) 93 | val iter = new RecordReaderIterator[Array[String]](reader) 94 | // Ensure that the record reader is closed upon task completion. It will ordinarily 95 | // be closed once it is completely iterated, but this is necessary to guard against 96 | // resource leaks in case the task fails or is interrupted. 
97 | Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) 98 | val converter = Conversions.createRowConverter(requiredSchema, 99 | options.getOrElse("nullString", Parameters.DEFAULT_PARAMETERS("csvnullstring"))) 100 | iter.map(converter) 101 | } 102 | } 103 | 104 | override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true 105 | } 106 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import com.amazonaws.auth.AWSCredentialsProvider 20 | import com.amazonaws.services.s3.AmazonS3Client 21 | import io.github.spark_redshift_community.spark.redshift 22 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider} 23 | import org.apache.spark.sql.types.StructType 24 | import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} 25 | import org.slf4j.LoggerFactory 26 | 27 | /** 28 | * Redshift Source implementation for Spark SQL 29 | */ 30 | class DefaultSource( 31 | jdbcWrapper: JDBCWrapper, 32 | s3ClientFactory: AWSCredentialsProvider => AmazonS3Client) 33 | extends RelationProvider 34 | with SchemaRelationProvider 35 | with CreatableRelationProvider { 36 | 37 | private val log = LoggerFactory.getLogger(getClass) 38 | 39 | /** 40 | * Default constructor required by Data Source API 41 | */ 42 | def this() = this(DefaultJDBCWrapper, awsCredentials => new AmazonS3Client(awsCredentials)) 43 | 44 | /** 45 | * Create a new RedshiftRelation instance using parameters from Spark SQL DDL. Resolves the schema 46 | * using JDBC connection over provided URL, which must contain credentials. 47 | */ 48 | override def createRelation( 49 | sqlContext: SQLContext, 50 | parameters: Map[String, String]): BaseRelation = { 51 | val params = Parameters.mergeParameters(parameters) 52 | redshift.RedshiftRelation(jdbcWrapper, s3ClientFactory, params, None)(sqlContext) 53 | } 54 | 55 | /** 56 | * Load a RedshiftRelation using user-provided schema, so no inference over JDBC will be used. 
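   * For illustration only (a sketch, not part of the original docs; the option values are
   * placeholders modeled on the tutorial elsewhere in this repo), a caller typically reaches
   * this overload by supplying a schema through the DataFrameReader:
   *
   * {{{
   *   sqlContext.read
   *     .format("io.github.spark_redshift_community.spark.redshift")
   *     .schema(userSchema)          // user-provided schema; skips JDBC schema resolution
   *     .option("url", jdbcUrl)
   *     .option("tempdir", tempS3Dir)
   *     .option("dbtable", "event")
   *     .load()
   * }}}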
57 | */ 58 | override def createRelation( 59 | sqlContext: SQLContext, 60 | parameters: Map[String, String], 61 | schema: StructType): BaseRelation = { 62 | val params = Parameters.mergeParameters(parameters) 63 | redshift.RedshiftRelation(jdbcWrapper, s3ClientFactory, params, Some(schema))(sqlContext) 64 | } 65 | 66 | /** 67 | * Creates a Relation instance by first writing the contents of the given DataFrame to Redshift 68 | */ 69 | override def createRelation( 70 | sqlContext: SQLContext, 71 | saveMode: SaveMode, 72 | parameters: Map[String, String], 73 | data: DataFrame): BaseRelation = { 74 | val params = Parameters.mergeParameters(parameters) 75 | val table = params.table.getOrElse { 76 | throw new IllegalArgumentException( 77 | "For save operations you must specify a Redshift table name with the 'dbtable' parameter") 78 | } 79 | 80 | def tableExists: Boolean = { 81 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) 82 | try { 83 | jdbcWrapper.tableExists(conn, table.toString) 84 | } finally { 85 | conn.close() 86 | } 87 | } 88 | 89 | val (doSave, dropExisting) = saveMode match { 90 | case SaveMode.Append => (true, false) 91 | case SaveMode.Overwrite => (true, true) 92 | case SaveMode.ErrorIfExists => 93 | if (tableExists) { 94 | sys.error(s"Table $table already exists! (SaveMode is set to ErrorIfExists)") 95 | } else { 96 | (true, false) 97 | } 98 | case SaveMode.Ignore => 99 | if (tableExists) { 100 | log.info(s"Table $table already exists -- ignoring save request.") 101 | (false, false) 102 | } else { 103 | (true, false) 104 | } 105 | } 106 | 107 | if (doSave) { 108 | val updatedParams = parameters.updated("overwrite", dropExisting.toString) 109 | new RedshiftWriter(jdbcWrapper, s3ClientFactory).saveToRedshift( 110 | sqlContext, data, saveMode, Parameters.mergeParameters(updatedParams)) 111 | } 112 | 113 | createRelation(sqlContext, parameters) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/InMemoryS3AFileSystemSuite.scala: -------------------------------------------------------------------------------- 1 | package io.github.spark_redshift_community.spark.redshift 2 | 3 | import java.io.FileNotFoundException 4 | 5 | import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path} 6 | import org.scalatest.{FunSuite, Matchers} 7 | 8 | class InMemoryS3AFileSystemSuite extends FunSuite with Matchers { 9 | 10 | test("Create a file creates all prefixes in the hierarchy") { 11 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 12 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS") 13 | 14 | inMemoryS3AFileSystem.create(path) 15 | 16 | assert( 17 | inMemoryS3AFileSystem.exists( 18 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS"))) 19 | 20 | assert( 21 | inMemoryS3AFileSystem.exists( 22 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/"))) 23 | 24 | assert(inMemoryS3AFileSystem.exists(new Path("s3a://test-bucket/temp-dir/"))) 25 | 26 | } 27 | 28 | test("List all statuses for a dir") { 29 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 30 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS") 31 | val path2 = new Path( 32 | "s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/manifest.json") 33 | 34 | inMemoryS3AFileSystem.create(path) 35 | 
inMemoryS3AFileSystem.create(path2) 36 | 37 | assert( 38 | inMemoryS3AFileSystem.listStatus( 39 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328") 40 | ).length == 2) 41 | 42 | assert( 43 | inMemoryS3AFileSystem.listStatus( 44 | new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328") 45 | ) === Array[FileStatus] ( 46 | inMemoryS3AFileSystem.getFileStatus(path2), 47 | inMemoryS3AFileSystem.getFileStatus(path)) 48 | ) 49 | 50 | assert( 51 | inMemoryS3AFileSystem.listStatus( 52 | new Path("s3a://test-bucket/temp-dir/")).length == 1) 53 | } 54 | 55 | test("getFileStatus for file and dir") { 56 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 57 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/_SUCCESS") 58 | 59 | inMemoryS3AFileSystem.create(path) 60 | 61 | assert(inMemoryS3AFileSystem.getFileStatus(path).isDirectory === false) 62 | 63 | val dirPath = new Path( 64 | "s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328") 65 | val dirPathFileStatus = inMemoryS3AFileSystem.getFileStatus(dirPath) 66 | assert(dirPathFileStatus.isDirectory === true) 67 | assert(dirPathFileStatus.isEmptyDirectory === false) 68 | 69 | } 70 | 71 | test("Open a file from InMemoryS3AFileSystem") { 72 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 73 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000") 74 | 75 | inMemoryS3AFileSystem.create(path).write("some data".getBytes()) 76 | 77 | var result = new Array[Byte](9) 78 | inMemoryS3AFileSystem.open(path).read(result) 79 | 80 | assert(result === "some data".getBytes()) 81 | 82 | } 83 | 84 | test ("delete file from FileSystem") { 85 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 86 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000") 87 | 88 | inMemoryS3AFileSystem.create(path) 89 | 90 | assert(inMemoryS3AFileSystem.exists(path)) 91 | 92 | inMemoryS3AFileSystem.delete(path, false) 93 | assert(inMemoryS3AFileSystem.exists(path) === false) 94 | 95 | } 96 | 97 | test("create already existing file throws FileAlreadyExistsException"){ 98 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 99 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000") 100 | inMemoryS3AFileSystem.create(path) 101 | assertThrows[FileAlreadyExistsException](inMemoryS3AFileSystem.create(path)) 102 | } 103 | 104 | test("getFileStatus can't find file"){ 105 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 106 | 107 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000") 108 | assertThrows[FileNotFoundException](inMemoryS3AFileSystem.getFileStatus(path)) 109 | } 110 | 111 | test("listStatus can't find path"){ 112 | val inMemoryS3AFileSystem = new InMemoryS3AFileSystem() 113 | 114 | val path = new Path("s3a://test-bucket/temp-dir/ba7e0bf3-25a0-4435-b7a5-fdb6b3d2d328/part0000") 115 | assertThrows[FileNotFoundException](inMemoryS3AFileSystem.listStatus(path)) 116 | } 117 | 118 | } 119 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/AWSCredentialsUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in 
compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.net.URI 20 | 21 | import com.amazonaws.auth._ 22 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters 23 | import org.apache.hadoop.conf.Configuration 24 | 25 | private[redshift] object AWSCredentialsUtils { 26 | 27 | /** 28 | * Generates a credentials string for use in Redshift COPY and UNLOAD statements. 29 | * Favors a configured `aws_iam_role` if available in the parameters. 30 | */ 31 | def getRedshiftCredentialsString( 32 | params: MergedParameters, 33 | sparkAwsCredentials: AWSCredentials): String = { 34 | 35 | def awsCredsToString(credentials: AWSCredentials): String = { 36 | credentials match { 37 | case creds: AWSSessionCredentials => 38 | s"aws_access_key_id=${creds.getAWSAccessKeyId};" + 39 | s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}" 40 | case creds => 41 | s"aws_access_key_id=${creds.getAWSAccessKeyId};" + 42 | s"aws_secret_access_key=${creds.getAWSSecretKey}" 43 | } 44 | } 45 | if (params.iamRole.isDefined) { 46 | s"aws_iam_role=${params.iamRole.get}" 47 | } else if (params.temporaryAWSCredentials.isDefined) { 48 | awsCredsToString(params.temporaryAWSCredentials.get.getCredentials) 49 | } else if (params.forwardSparkS3Credentials) { 50 | awsCredsToString(sparkAwsCredentials) 51 | } else { 52 | throw new IllegalStateException("No Redshift S3 authentication mechanism was specified") 53 | } 54 | } 55 | 56 | def staticCredentialsProvider(credentials: AWSCredentials): AWSCredentialsProvider = { 57 | new AWSCredentialsProvider { 58 | override def getCredentials: AWSCredentials = credentials 59 | override def refresh(): Unit = {} 60 | } 61 | } 62 | 63 | def load(params: MergedParameters, hadoopConfiguration: Configuration): AWSCredentialsProvider = { 64 | params.temporaryAWSCredentials.getOrElse(loadFromURI(params.rootTempDir, hadoopConfiguration)) 65 | } 66 | 67 | private def loadFromURI( 68 | tempPath: String, 69 | hadoopConfiguration: Configuration): AWSCredentialsProvider = { 70 | // scalastyle:off 71 | // A good reference on Hadoop's configuration loading / precedence is 72 | // https://github.com/apache/hadoop/blob/trunk/hadoop-tools/hadoop-aws/src/site/markdown/tools/hadoop-aws/index.md 73 | // scalastyle:on 74 | val uri = new URI(tempPath) 75 | val uriScheme = uri.getScheme 76 | 77 | uriScheme match { 78 | case "s3" | "s3n" | "s3a" => 79 | // WARNING: credentials in the URI is a potentially unsafe practice. I'm removing the test 80 | // AWSCredentialsInUriIntegrationSuite, so the following might or might not work. 81 | 82 | // This matches what S3A does, with one exception: we don't support anonymous credentials. 
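        // Illustrative inputs for the lookup order implemented below (all values are placeholders):
        //   tempdir = "s3a://ACCESSKEY:SECRETKEY@bucket/temp/"          -> keys parsed from the URI user-info
        //   tempdir = "s3a://bucket/temp/" with fs.s3a.access.key and
        //             fs.s3a.secret.key set in the Hadoop configuration -> keys read from the configuration
        //   neither of the above                                        -> DefaultAWSCredentialsProviderChain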
83 | // First, try to parse from URI: 84 | Option(uri.getUserInfo).flatMap { userInfo => 85 | if (userInfo.contains(":")) { 86 | val Array(accessKey, secretKey) = userInfo.split(":") 87 | Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey))) 88 | } else { 89 | None 90 | } 91 | }.orElse { 92 | // Next, try to read from configuration 93 | val accessKeyConfig = if (uriScheme == "s3a") "access.key" else "awsAccessKeyId" 94 | val secretKeyConfig = if (uriScheme == "s3a") "secret.key" else "awsSecretAccessKey" 95 | 96 | val accessKey = hadoopConfiguration.get(s"fs.$uriScheme.$accessKeyConfig", null) 97 | val secretKey = hadoopConfiguration.get(s"fs.$uriScheme.$secretKeyConfig", null) 98 | if (accessKey != null && secretKey != null) { 99 | Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey))) 100 | } else { 101 | None 102 | } 103 | }.getOrElse { 104 | // Finally, fall back on the instance profile provider 105 | new DefaultAWSCredentialsProviderChain() 106 | } 107 | case other => 108 | throw new IllegalArgumentException(s"Unrecognized scheme $other; expected s3, s3n, or s3a") 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/MockRedshift.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.{Connection, PreparedStatement, ResultSet, SQLException} 20 | 21 | import org.apache.spark.sql.types.StructType 22 | import org.mockito.Matchers._ 23 | import org.mockito.Mockito._ 24 | import org.mockito.invocation.InvocationOnMock 25 | import org.mockito.stubbing.Answer 26 | import org.scalatest.Assertions._ 27 | 28 | import scala.collection.mutable 29 | import scala.util.matching.Regex 30 | 31 | 32 | /** 33 | * Helper class for mocking Redshift / JDBC in unit tests. 
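 * A minimal usage sketch (the JDBC URL, table name, and expected query are placeholders,
 * not taken from any existing suite):
 *
 * {{{
 *   val mockRedshift = new MockRedshift(
 *     jdbcUrl = "jdbc:redshift://foo/bar",
 *     existingTablesAndSchemas = Map("test_table" -> TestUtils.testSchema))
 *   // Inject mockRedshift.jdbcWrapper into the code under test, run it, then verify:
 *   mockRedshift.verifyThatConnectionsWereClosed()
 *   mockRedshift.verifyThatExpectedQueriesWereIssued(Seq("CREATE TABLE".r))
 * }}}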
34 | */ 35 | class MockRedshift( 36 | jdbcUrl: String, 37 | existingTablesAndSchemas: Map[String, StructType], 38 | jdbcQueriesThatShouldFail: Seq[Regex] = Seq.empty) { 39 | 40 | private[this] val queriesIssued: mutable.Buffer[String] = mutable.Buffer.empty 41 | def getQueriesIssuedAgainstRedshift: Seq[String] = queriesIssued.toSeq 42 | 43 | private[this] val jdbcConnections: mutable.Buffer[Connection] = mutable.Buffer.empty 44 | 45 | val jdbcWrapper: JDBCWrapper = spy(new JDBCWrapper) 46 | 47 | private def createMockConnection(): Connection = { 48 | val conn = mock(classOf[Connection], RETURNS_SMART_NULLS) 49 | jdbcConnections.append(conn) 50 | when(conn.prepareStatement(anyString())).thenAnswer(new Answer[PreparedStatement] { 51 | override def answer(invocation: InvocationOnMock): PreparedStatement = { 52 | val query = invocation.getArguments()(0).asInstanceOf[String] 53 | queriesIssued.append(query) 54 | val mockStatement = mock(classOf[PreparedStatement], RETURNS_SMART_NULLS) 55 | if (jdbcQueriesThatShouldFail.forall(_.findFirstMatchIn(query).isEmpty)) { 56 | when(mockStatement.execute()).thenReturn(true) 57 | when(mockStatement.executeQuery()).thenReturn( 58 | mock(classOf[ResultSet], RETURNS_SMART_NULLS)) 59 | } else { 60 | when(mockStatement.execute()).thenThrow(new SQLException(s"Error executing $query")) 61 | when(mockStatement.executeQuery()).thenThrow(new SQLException(s"Error executing $query")) 62 | } 63 | mockStatement 64 | } 65 | }) 66 | conn 67 | } 68 | 69 | doAnswer(new Answer[Connection] { 70 | override def answer(invocation: InvocationOnMock): Connection = createMockConnection() 71 | }).when(jdbcWrapper) 72 | .getConnector(any[Option[String]](), same(jdbcUrl), any[Option[(String, String)]]()) 73 | 74 | doAnswer(new Answer[Boolean] { 75 | override def answer(invocation: InvocationOnMock): Boolean = { 76 | existingTablesAndSchemas.contains(invocation.getArguments()(1).asInstanceOf[String]) 77 | } 78 | }).when(jdbcWrapper).tableExists(any[Connection], anyString()) 79 | 80 | doAnswer(new Answer[StructType] { 81 | override def answer(invocation: InvocationOnMock): StructType = { 82 | existingTablesAndSchemas(invocation.getArguments()(1).asInstanceOf[String]) 83 | } 84 | }).when(jdbcWrapper).resolveTable(any[Connection], anyString()) 85 | 86 | def verifyThatConnectionsWereClosed(): Unit = { 87 | jdbcConnections.foreach { conn => 88 | verify(conn).close() 89 | } 90 | } 91 | 92 | def verifyThatRollbackWasCalled(): Unit = { 93 | jdbcConnections.foreach { conn => 94 | verify(conn, atLeastOnce()).rollback() 95 | } 96 | } 97 | 98 | def verifyThatCommitWasNotCalled(): Unit = { 99 | jdbcConnections.foreach { conn => 100 | verify(conn, never()).commit() 101 | } 102 | } 103 | 104 | def verifyThatExpectedQueriesWereIssued(expectedQueries: Seq[Regex]): Unit = { 105 | expectedQueries.zip(queriesIssued).foreach { case (expected, actual) => 106 | if (expected.findFirstMatchIn(actual).isEmpty) { 107 | fail( 108 | s""" 109 | |Actual and expected JDBC queries did not match: 110 | |Expected: $expected 111 | |Actual: $actual 112 | """.stripMargin) 113 | } 114 | } 115 | if (expectedQueries.length > queriesIssued.length) { 116 | val missingQueries = expectedQueries.drop(queriesIssued.length) 117 | fail(s"Missing ${missingQueries.length} expected JDBC queries:" + 118 | s"\n${missingQueries.mkString("\n")}") 119 | } else if (queriesIssued.length > expectedQueries.length) { 120 | val extraQueries = queriesIssued.drop(expectedQueries.length) 121 | fail(s"Got ${extraQueries.length} unexpected JDBC 
queries:\n${extraQueries.mkString("\n")}") 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/TestUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.{Date, Timestamp} 20 | import java.time.ZoneId 21 | import java.util.{Calendar, Locale, TimeZone} 22 | 23 | import org.apache.spark.sql.Row 24 | import org.apache.spark.sql.types._ 25 | 26 | /** 27 | * Helpers for Redshift tests that require common mocking 28 | */ 29 | object TestUtils { 30 | 31 | /** 32 | * Simple schema that includes all data types we support 33 | */ 34 | val testSchema: StructType = { 35 | // These column names need to be lowercase; see #51 36 | StructType(Seq( 37 | StructField("testbyte", ByteType), 38 | StructField("testbool", BooleanType), 39 | StructField("testdate", DateType), 40 | StructField("testdouble", DoubleType), 41 | StructField("testfloat", FloatType), 42 | StructField("testint", IntegerType), 43 | StructField("testlong", LongType), 44 | StructField("testshort", ShortType), 45 | StructField("teststring", StringType), 46 | StructField("testtimestamp", TimestampType))) 47 | } 48 | 49 | // scalastyle:off 50 | /** 51 | * Expected parsed output corresponding to the output of testData. 52 | */ 53 | val expectedData: Seq[Row] = Seq( 54 | Row(1.toByte, true, TestUtils.toDate(2015, 6, 1), 1234152.12312498, 55 | 1.0f, 42, 1239012341823719L, 23.toShort, "Unicode's樂趣", 56 | TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1)), 57 | Row(1.toByte, false, TestUtils.toDate(2015, 6, 2), 0.0, 0.0f, 42, 58 | 1239012341823719L, -13.toShort, "asdf", TestUtils.toTimestamp(2015, 6, 2, 0, 0, 0, 0)), 59 | Row(0.toByte, null, TestUtils.toDate(2015, 6, 3), 0.0, -1.0f, 4141214, 60 | 1239012341823719L, null, "f", TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0)), 61 | Row(0.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, 24.toShort, 62 | "___|_123", null), 63 | Row(List.fill(10)(null): _*)) 64 | // scalastyle:on 65 | 66 | /** 67 | * The same as `expectedData`, but with dates and timestamps converted into string format. 68 | * See #39 for context. 
69 | */ 70 | val expectedDataWithConvertedTimesAndDates: Seq[Row] = expectedData.map { row => 71 | Row.fromSeq(row.toSeq.map { 72 | case t: Timestamp => Conversions.createRedshiftTimestampFormat().format(t) 73 | case d: Date => Conversions.createRedshiftDateFormat().format(d) 74 | case other => other 75 | }) 76 | } 77 | 78 | /** 79 | * Convert date components to a millisecond timestamp 80 | */ 81 | def toMillis( 82 | year: Int, 83 | zeroBasedMonth: Int, 84 | date: Int, 85 | hour: Int, 86 | minutes: Int, 87 | seconds: Int, 88 | millis: Int = 0, 89 | timeZone: String = null): Long = { 90 | val calendar = Calendar.getInstance() 91 | calendar.set(year, zeroBasedMonth, date, hour, minutes, seconds) 92 | calendar.set(Calendar.MILLISECOND, millis) 93 | if (timeZone != null) calendar.setTimeZone(TimeZone.getTimeZone(ZoneId.of(timeZone))) 94 | calendar.getTime.getTime 95 | } 96 | 97 | def toNanosTimestamp( 98 | year: Int, 99 | zeroBasedMonth: Int, 100 | date: Int, 101 | hour: Int, 102 | minutes: Int, 103 | seconds: Int, 104 | nanos: Int 105 | ): Timestamp = { 106 | val ts = new Timestamp( 107 | toMillis( 108 | year, 109 | zeroBasedMonth, 110 | date, 111 | hour, 112 | minutes, 113 | seconds 114 | ) 115 | ) 116 | ts.setNanos(nanos) 117 | ts 118 | } 119 | 120 | /** 121 | * Convert date components to a SQL Timestamp 122 | */ 123 | def toTimestamp( 124 | year: Int, 125 | zeroBasedMonth: Int, 126 | date: Int, 127 | hour: Int, 128 | minutes: Int, 129 | seconds: Int, 130 | millis: Int = 0): Timestamp = { 131 | new Timestamp(toMillis(year, zeroBasedMonth, date, hour, minutes, seconds, millis)) 132 | } 133 | 134 | /** 135 | * Convert date components to a SQL [[Date]]. 136 | */ 137 | def toDate(year: Int, zeroBasedMonth: Int, date: Int): Date = { 138 | new Date(toTimestamp(year, zeroBasedMonth, date, 0, 0, 0).getTime) 139 | } 140 | 141 | def withDefaultLocale[T](newDefaultLocale: Locale)(block: => T): T = { 142 | val originalDefaultLocale = Locale.getDefault 143 | try { 144 | Locale.setDefault(newDefaultLocale) 145 | block 146 | } finally { 147 | Locale.setDefault(originalDefaultLocale) 148 | } 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/ColumnMetadataSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.SQLException 20 | 21 | import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType} 22 | import org.apache.spark.sql.{Row, SaveMode} 23 | 24 | /** 25 | * End-to-end tests of features which depend on per-column metadata (such as comments, maxlength). 
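 * For reference, a sketch mirroring the tests below (the column name is a placeholder, and
 * combining all three keys on one column is only illustrative): the metadata is attached
 * per column via Spark's MetadataBuilder.
 *
 * {{{
 *   val metadata = new MetadataBuilder()
 *     .putLong("maxlength", 512)            // column width
 *     .putString("encoding", "LZO")         // compression encoding
 *     .putString("description", "a comment")
 *     .build()
 *   val schema = StructType(StructField("x", StringType, metadata = metadata) :: Nil)
 * }}}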
26 | */ 27 | class ColumnMetadataSuite extends IntegrationSuiteBase { 28 | 29 | test("configuring maxlength on string columns") { 30 | val tableName = s"configuring_maxlength_on_string_column_$randomSuffix" 31 | try { 32 | val metadata = new MetadataBuilder().putLong("maxlength", 512).build() 33 | val schema = StructType( 34 | StructField("x", StringType, metadata = metadata) :: Nil) 35 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))), schema)) 36 | .option("dbtable", tableName) 37 | .mode(SaveMode.ErrorIfExists) 38 | .save() 39 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 40 | checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 512))) 41 | // This append should fail due to the string being longer than the maxlength 42 | intercept[SQLException] { 43 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 513))), schema)) 44 | .option("dbtable", tableName) 45 | .mode(SaveMode.Append) 46 | .save() 47 | } 48 | } finally { 49 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 50 | } 51 | } 52 | 53 | test("configuring compression on columns") { 54 | val tableName = s"configuring_compression_on_columns_$randomSuffix" 55 | try { 56 | val metadata = new MetadataBuilder().putString("encoding", "LZO").build() 57 | val schema = StructType( 58 | StructField("x", StringType, metadata = metadata) :: Nil) 59 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema)) 60 | .option("dbtable", tableName) 61 | .mode(SaveMode.ErrorIfExists) 62 | .save() 63 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 64 | checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128))) 65 | val encodingDF = sqlContext.read 66 | .format("jdbc") 67 | .option("url", jdbcUrl) 68 | .option("dbtable", 69 | s"""(SELECT "column", lower(encoding) FROM pg_table_def WHERE tablename='$tableName')""") 70 | .load() 71 | checkAnswer(encodingDF, Seq(Row("x", "lzo"))) 72 | } finally { 73 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 74 | } 75 | } 76 | 77 | test("configuring comments on columns") { 78 | val tableName = s"configuring_comments_on_columns_$randomSuffix" 79 | try { 80 | val metadata = new MetadataBuilder().putString("description", "Hello Column").build() 81 | val schema = StructType( 82 | StructField("x", StringType, metadata = metadata) :: Nil) 83 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 128))), schema)) 84 | .option("dbtable", tableName) 85 | .option("description", "Hello Table") 86 | .mode(SaveMode.ErrorIfExists) 87 | .save() 88 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 89 | checkAnswer(read.option("dbtable", tableName).load(), Seq(Row("a" * 128))) 90 | val tableDF = sqlContext.read 91 | .format("jdbc") 92 | .option("url", jdbcUrl) 93 | .option("dbtable", s"(SELECT pg_catalog.obj_description('$tableName'::regclass))") 94 | .load() 95 | checkAnswer(tableDF, Seq(Row("Hello Table"))) 96 | val commentQuery = 97 | s""" 98 | |(SELECT c.column_name, pgd.description 99 | |FROM pg_catalog.pg_statio_all_tables st 100 | |INNER JOIN pg_catalog.pg_description pgd 101 | | ON (pgd.objoid=st.relid) 102 | |INNER JOIN information_schema.columns c 103 | | ON (pgd.objsubid=c.ordinal_position AND c.table_name=st.relname) 104 | |WHERE c.table_name='$tableName') 105 | """.stripMargin 106 | val columnDF = sqlContext.read 107 | .format("jdbc") 108 | .option("url", jdbcUrl) 109 | .option("dbtable", commentQuery) 110 | .load() 111 | checkAnswer(columnDF, 
Seq(Row("x", "Hello Column"))) 112 | } finally { 113 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/SaveModeIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 20 | import org.apache.spark.sql.{Row, SaveMode} 21 | 22 | /** 23 | * End-to-end tests of [[SaveMode]] behavior. 24 | */ 25 | class SaveModeIntegrationSuite extends IntegrationSuiteBase { 26 | test("SaveMode.Overwrite with schema-qualified table name (#97)") { 27 | withTempRedshiftTable("overwrite_schema_qualified_table_name") { tableName => 28 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 29 | StructType(StructField("a", IntegerType) :: Nil)) 30 | // Ensure that the table exists: 31 | write(df) 32 | .option("dbtable", tableName) 33 | .mode(SaveMode.ErrorIfExists) 34 | .save() 35 | assert(DefaultJDBCWrapper.tableExists(conn, s"PUBLIC.$tableName")) 36 | // Try overwriting that table while using the schema-qualified table name: 37 | write(df) 38 | .option("dbtable", s"PUBLIC.$tableName") 39 | .mode(SaveMode.Overwrite) 40 | .save() 41 | } 42 | } 43 | 44 | test("SaveMode.Overwrite with non-existent table") { 45 | testRoundtripSaveAndLoad( 46 | s"overwrite_non_existent_table$randomSuffix", 47 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 48 | StructType(StructField("a", IntegerType) :: Nil)), 49 | saveMode = SaveMode.Overwrite) 50 | } 51 | 52 | test("SaveMode.Overwrite with existing table") { 53 | withTempRedshiftTable("overwrite_existing_table") { tableName => 54 | // Create a table to overwrite 55 | write(sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 56 | StructType(StructField("a", IntegerType) :: Nil))) 57 | .option("dbtable", tableName) 58 | .mode(SaveMode.ErrorIfExists) 59 | .save() 60 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 61 | 62 | val overwritingDf = 63 | sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema) 64 | write(overwritingDf) 65 | .option("dbtable", tableName) 66 | .mode(SaveMode.Overwrite) 67 | .save() 68 | 69 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 70 | checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData) 71 | } 72 | } 73 | 74 | // TODO:test overwrite that fails. 
75 | 76 | // TODO (luca|issue #7) make SaveMode work 77 | ignore("Append SaveMode doesn't destroy existing data") { 78 | withTempRedshiftTable("append_doesnt_destroy_existing_data") { tableName => 79 | createTestDataInRedshift(tableName) 80 | val extraData = Seq( 81 | Row(2.toByte, false, null, -1234152.12312498, 100000.0f, null, 1239012341823719L, 82 | 24.toShort, "___|_123", null)) 83 | 84 | write(sqlContext.createDataFrame(sc.parallelize(extraData), TestUtils.testSchema)) 85 | .option("dbtable", tableName) 86 | .mode(SaveMode.Append) 87 | .saveAsTable(tableName) 88 | 89 | checkAnswer( 90 | sqlContext.sql(s"select * from $tableName"), 91 | TestUtils.expectedData ++ extraData) 92 | } 93 | } 94 | 95 | ignore("Respect SaveMode.ErrorIfExists when table exists") { 96 | withTempRedshiftTable("respect_savemode_error_if_exists") { tableName => 97 | val rdd = sc.parallelize(TestUtils.expectedData) 98 | val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema) 99 | createTestDataInRedshift(tableName) // to ensure that the table already exists 100 | 101 | // Check that SaveMode.ErrorIfExists throws an exception 102 | val e = intercept[Exception] { 103 | write(df) 104 | .option("dbtable", tableName) 105 | .mode(SaveMode.ErrorIfExists) 106 | .saveAsTable(tableName) 107 | } 108 | assert(e.getMessage.contains("exists")) 109 | } 110 | } 111 | 112 | ignore("Do nothing when table exists if SaveMode = Ignore") { 113 | withTempRedshiftTable("do_nothing_when_savemode_ignore") { tableName => 114 | val rdd = sc.parallelize(TestUtils.expectedData.drop(1)) 115 | val df = sqlContext.createDataFrame(rdd, TestUtils.testSchema) 116 | createTestDataInRedshift(tableName) // to ensure that the table already exists 117 | write(df) 118 | .option("dbtable", tableName) 119 | .mode(SaveMode.Ignore) 120 | .saveAsTable(tableName) 121 | 122 | // Check that SaveMode.Ignore does nothing 123 | checkAnswer( 124 | sqlContext.sql(s"select * from $tableName"), 125 | TestUtils.expectedData) 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftInputFormatSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import java.io.{DataOutputStream, File, FileOutputStream} 19 | 20 | import com.google.common.io.Files 21 | import io.github.spark_redshift_community.spark.redshift.RedshiftInputFormat._ 22 | import org.apache.hadoop.conf.Configuration 23 | import org.apache.spark.SparkContext 24 | import org.apache.spark.sql.types._ 25 | import org.apache.spark.sql.{Row, SQLContext} 26 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 27 | 28 | import scala.language.implicitConversions 29 | 30 | class RedshiftInputFormatSuite extends FunSuite with BeforeAndAfterAll { 31 | 32 | import RedshiftInputFormatSuite._ 33 | 34 | private var sc: SparkContext = _ 35 | 36 | override def beforeAll(): Unit = { 37 | super.beforeAll() 38 | sc = new SparkContext("local", this.getClass.getName) 39 | } 40 | 41 | override def afterAll(): Unit = { 42 | sc.stop() 43 | super.afterAll() 44 | } 45 | 46 | private def writeToFile(contents: String, file: File): Unit = { 47 | val bytes = contents.getBytes 48 | val out = new DataOutputStream(new FileOutputStream(file)) 49 | out.write(bytes, 0, bytes.length) 50 | out.close() 51 | } 52 | 53 | private def escape(records: Set[Seq[String]], delimiter: Char): String = { 54 | require(delimiter != '\\' && delimiter != '\n') 55 | records.map { r => 56 | r.map { f => 57 | f.replace("\\", "\\\\") 58 | .replace("\n", "\\\n") 59 | .replace(delimiter, "\\" + delimiter) 60 | }.mkString(delimiter) 61 | }.mkString("", "\n", "\n") 62 | } 63 | 64 | private final val KEY_BLOCK_SIZE = "fs.local.block.size" 65 | 66 | private final val TAB = '\t' 67 | 68 | private val records = Set( 69 | Seq("a\n", DEFAULT_DELIMITER + "b\\"), 70 | Seq("c", TAB + "d"), 71 | Seq("\ne", "\\\\f")) 72 | 73 | private def withTempDir(func: File => Unit): Unit = { 74 | val dir = Files.createTempDir() 75 | dir.deleteOnExit() 76 | func(dir) 77 | } 78 | 79 | test("default delimiter") { 80 | withTempDir { dir => 81 | val escaped = escape(records, DEFAULT_DELIMITER) 82 | writeToFile(escaped, new File(dir, "part-00000")) 83 | 84 | val conf = new Configuration 85 | conf.setLong(KEY_BLOCK_SIZE, 4) 86 | 87 | val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat], 88 | classOf[java.lang.Long], classOf[Array[String]], conf) 89 | 90 | // TODO: Check this assertion - fails on Travis only, no idea what, or what it's for 91 | // assert(rdd.partitions.size > records.size) // so there exist at least one empty partition 92 | 93 | val actual = rdd.values.map(_.toSeq).collect() 94 | assert(actual.size === records.size) 95 | assert(actual.toSet === records) 96 | } 97 | } 98 | 99 | test("customized delimiter") { 100 | withTempDir { dir => 101 | val escaped = escape(records, TAB) 102 | writeToFile(escaped, new File(dir, "part-00000")) 103 | 104 | val conf = new Configuration 105 | conf.setLong(KEY_BLOCK_SIZE, 4) 106 | conf.set(KEY_DELIMITER, TAB) 107 | 108 | val rdd = sc.newAPIHadoopFile(dir.toString, classOf[RedshiftInputFormat], 109 | classOf[java.lang.Long], classOf[Array[String]], conf) 110 | 111 | // TODO: Check this assertion - fails on Travis only, no idea what, or what it's for 112 | // assert(rdd.partitions.size > records.size) // so there exist at least one empty partitions 113 | 114 | val actual = rdd.values.map(_.toSeq).collect() 115 | assert(actual.size === records.size) 116 | assert(actual.toSet === records) 117 | } 118 | } 119 | 120 | test("schema parser") { 121 | withTempDir { dir => 122 | val testRecords = Set( 123 | Seq("a\n", "TX", 1, 
1.0, 1000L, 200000000000L), 124 | Seq("b", "CA", 2, 2.0, 2000L, 1231412314L)) 125 | val escaped = escape(testRecords.map(_.map(_.toString)), DEFAULT_DELIMITER) 126 | writeToFile(escaped, new File(dir, "part-00000")) 127 | 128 | val sqlContext = new SQLContext(sc) 129 | val expectedSchema = StructType(Seq( 130 | StructField("name", StringType, nullable = true), 131 | StructField("state", StringType, nullable = true), 132 | StructField("id", IntegerType, nullable = true), 133 | StructField("score", DoubleType, nullable = true), 134 | StructField("big_score", LongType, nullable = true), 135 | StructField("some_long", LongType, nullable = true))) 136 | 137 | val df = sqlContext.redshiftFile(dir.toString, expectedSchema) 138 | assert(df.schema === expectedSchema) 139 | 140 | val parsed = df.rdd.map { 141 | case Row( 142 | name: String, state: String, id: Int, score: Double, bigScore: Long, someLong: Long 143 | ) => Seq(name, state, id, score, bigScore, someLong) 144 | }.collect().toSet 145 | 146 | assert(parsed === testRecords) 147 | } 148 | } 149 | } 150 | 151 | object RedshiftInputFormatSuite { 152 | implicit def charToString(c: Char): String = c.toString 153 | } 154 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/AWSCredentialsUtilsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import com.amazonaws.auth.{AWSSessionCredentials, BasicAWSCredentials, BasicSessionCredentials} 20 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters 21 | import org.apache.hadoop.conf.Configuration 22 | import org.scalatest.FunSuite 23 | 24 | import scala.language.implicitConversions 25 | 26 | class AWSCredentialsUtilsSuite extends FunSuite { 27 | 28 | val baseParams = Map( 29 | "tempdir" -> "s3://foo/bar", 30 | "dbtable" -> "test_schema.test_table", 31 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password") 32 | 33 | private implicit def string2Params(tempdir: String): MergedParameters = { 34 | Parameters.mergeParameters(baseParams ++ Map( 35 | "tempdir" -> tempdir, 36 | "forward_spark_s3_credentials" -> "true")) 37 | } 38 | 39 | test("credentialsString with regular keys") { 40 | val creds = new BasicAWSCredentials("ACCESSKEYID", "SECRET/KEY/WITH/SLASHES") 41 | val params = 42 | Parameters.mergeParameters(baseParams ++ Map("forward_spark_s3_credentials" -> "true")) 43 | assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) === 44 | "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY/WITH/SLASHES") 45 | } 46 | 47 | test("credentialsString with STS temporary keys") { 48 | val params = Parameters.mergeParameters(baseParams ++ Map( 49 | "temporary_aws_access_key_id" -> "ACCESSKEYID", 50 | "temporary_aws_secret_access_key" -> "SECRET/KEY", 51 | "temporary_aws_session_token" -> "SESSION/Token")) 52 | assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) === 53 | "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY;token=SESSION/Token") 54 | } 55 | 56 | test("Configured IAM roles should take precedence") { 57 | val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token") 58 | val iamRole = "arn:aws:iam::123456789000:role/redshift_iam_role" 59 | val params = Parameters.mergeParameters(baseParams ++ Map("aws_iam_role" -> iamRole)) 60 | assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) === 61 | s"aws_iam_role=$iamRole") 62 | } 63 | 64 | test("AWSCredentials.load() STS temporary keys should take precedence") { 65 | val conf = new Configuration(false) 66 | conf.set("fs.s3.awsAccessKeyId", "CONFID") 67 | conf.set("fs.s3.awsSecretAccessKey", "CONFKEY") 68 | 69 | val params = Parameters.mergeParameters(baseParams ++ Map( 70 | "tempdir" -> "s3://URIID:URIKEY@bucket/path", 71 | "temporary_aws_access_key_id" -> "key_id", 72 | "temporary_aws_secret_access_key" -> "secret", 73 | "temporary_aws_session_token" -> "token" 74 | )) 75 | 76 | val creds = AWSCredentialsUtils.load(params, conf).getCredentials 77 | assert(creds.isInstanceOf[AWSSessionCredentials]) 78 | assert(creds.getAWSAccessKeyId === "key_id") 79 | assert(creds.getAWSSecretKey === "secret") 80 | assert(creds.asInstanceOf[AWSSessionCredentials].getSessionToken === "token") 81 | } 82 | 83 | test("AWSCredentials.load() credentials precedence for s3:// URIs") { 84 | val conf = new Configuration(false) 85 | conf.set("fs.s3.awsAccessKeyId", "CONFID") 86 | conf.set("fs.s3.awsSecretAccessKey", "CONFKEY") 87 | 88 | { 89 | val creds = AWSCredentialsUtils.load("s3://URIID:URIKEY@bucket/path", conf).getCredentials 90 | assert(creds.getAWSAccessKeyId === "URIID") 91 | assert(creds.getAWSSecretKey === "URIKEY") 92 | } 93 | 94 | { 95 | val creds = AWSCredentialsUtils.load("s3://bucket/path", conf).getCredentials 96 | 
assert(creds.getAWSAccessKeyId === "CONFID") 97 | assert(creds.getAWSSecretKey === "CONFKEY") 98 | } 99 | 100 | } 101 | 102 | test("AWSCredentials.load() credentials precedence for s3n:// URIs") { 103 | val conf = new Configuration(false) 104 | conf.set("fs.s3n.awsAccessKeyId", "CONFID") 105 | conf.set("fs.s3n.awsSecretAccessKey", "CONFKEY") 106 | 107 | { 108 | val creds = AWSCredentialsUtils.load("s3n://URIID:URIKEY@bucket/path", conf).getCredentials 109 | assert(creds.getAWSAccessKeyId === "URIID") 110 | assert(creds.getAWSSecretKey === "URIKEY") 111 | } 112 | 113 | { 114 | val creds = AWSCredentialsUtils.load("s3n://bucket/path", conf).getCredentials 115 | assert(creds.getAWSAccessKeyId === "CONFID") 116 | assert(creds.getAWSSecretKey === "CONFKEY") 117 | } 118 | 119 | } 120 | 121 | test("AWSCredentials.load() credentials precedence for s3a:// URIs") { 122 | val conf = new Configuration(false) 123 | conf.set("fs.s3a.access.key", "CONFID") 124 | conf.set("fs.s3a.secret.key", "CONFKEY") 125 | 126 | { 127 | val creds = AWSCredentialsUtils.load("s3a://URIID:URIKEY@bucket/path", conf).getCredentials 128 | assert(creds.getAWSAccessKeyId === "URIID") 129 | assert(creds.getAWSSecretKey === "URIKEY") 130 | } 131 | 132 | { 133 | val creds = AWSCredentialsUtils.load("s3a://bucket/path", conf).getCredentials 134 | assert(creds.getAWSAccessKeyId === "CONFID") 135 | assert(creds.getAWSSecretKey === "CONFKEY") 136 | } 137 | 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /tutorial/SparkRedshiftTutorial.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.tutorial 18 | import org.apache.spark.sql.{SQLContext, SaveMode} 19 | import org.apache.spark.{SparkConf, SparkContext} 20 | 21 | 22 | /** 23 | * Source code accompanying the spark-redshift tutorial. 24 | * The following parameters need to be passed 25 | * 1. AWS Access Key 26 | * 2. AWS Secret Access Key 27 | * 3. Redshift Database Name 28 | * 4. Redshift UserId 29 | * 5. Redshift Password 30 | * 6. Redshift URL (Ex. swredshift.czac2vcs84ci.us-east-1.redshift.amazonaws.com:5439) 31 | */ 32 | object SparkRedshiftTutorial { 33 | /* 34 | * For Windows Users only 35 | * 1. Download contents from link 36 | * https://github.com/srccodes/hadoop-common-2.2.0-bin/archive/master.zip 37 | * 2. Unzip the file in step 1 into your %HADOOP_HOME%/bin. 38 | * 3. 
pass System parameter -Dhadoop.home.dir=%HADOOP_HOME/bin where %HADOOP_HOME 39 | * must be an absolute not relative path 40 | */ 41 | 42 | def main(args: Array[String]): Unit = { 43 | 44 | if (args.length < 6) { 45 | println("Needs 6 parameters only passed " + args.length) 46 | println("parameters needed - $awsAccessKey $awsSecretKey $rsDbName $rsUser $rsPassword $rsURL") 47 | } 48 | val awsAccessKey = args(0) 49 | val awsSecretKey = args(1) 50 | val rsDbName = args(2) 51 | val rsUser = args(3) 52 | val rsPassword = args(4) 53 | //Sample Redshift URL is swredshift.czac2vcs84ci.us-east-1.redshift.amazonaws.com:5439 54 | val rsURL = args(5) 55 | val jdbcURL = s"""jdbc:redshift://$rsURL/$rsDbName?user=$rsUser&password=$rsPassword""" 56 | println(jdbcURL) 57 | val sc = new SparkContext(new SparkConf().setAppName("SparkSQL").setMaster("local")) 58 | 59 | val tempS3Dir = "s3n://redshift-spark/temp/" 60 | sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", awsAccessKey) 61 | sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", awsSecretKey) 62 | 63 | val sqlContext = new SQLContext(sc) 64 | 65 | //Load from a table 66 | val eventsDF = sqlContext.read 67 | .format("io.github.spark_redshift_community.spark.redshift") 68 | .option("url", jdbcURL) 69 | .option("tempdir", tempS3Dir) 70 | .option("dbtable", "event") 71 | .load() 72 | eventsDF.show() 73 | eventsDF.printSchema() 74 | 75 | //Load from a query 76 | val salesQuery = """SELECT salesid, listid, sellerid, buyerid, 77 | eventid, dateid, qtysold, pricepaid, commission 78 | FROM sales 79 | ORDER BY saletime DESC LIMIT 10000""" 80 | val salesDF = sqlContext.read 81 | .format("io.github.spark_redshift_community.spark.redshift") 82 | .option("url", jdbcURL) 83 | .option("tempdir", tempS3Dir) 84 | .option("query", salesQuery) 85 | .load() 86 | salesDF.show() 87 | 88 | val eventQuery = "SELECT * FROM event" 89 | val eventDF = sqlContext.read 90 | .format("io.github.spark_redshift_community.spark.redshift") 91 | .option("url", jdbcURL) 92 | .option("tempdir", tempS3Dir) 93 | .option("query", eventQuery) 94 | .load() 95 | 96 | /* 97 | * Register 'event' table as temporary table 'myevent' 98 | * so that it can be queried via sqlContext.sql 99 | */ 100 | eventsDF.registerTempTable("myevent") 101 | 102 | //Save to a Redshift table from a table registered in Spark 103 | 104 | /* 105 | * Create a new table redshiftevent after dropping any existing redshiftevent table 106 | * and write event records with event id less than 1000 107 | */ 108 | sqlContext.sql("SELECT * FROM myevent WHERE eventid<=1000").withColumnRenamed("eventid", "id") 109 | .write.format("io.github.spark_redshift_community.spark.redshift") 110 | .option("url", jdbcURL) 111 | .option("tempdir", tempS3Dir) 112 | .option("dbtable", "redshiftevent") 113 | .mode(SaveMode.Overwrite) 114 | .save() 115 | 116 | /* 117 | * Append to an existing table redshiftevent if it exists or create a new one if it does not 118 | * exist and write event records with event id greater than 1000 119 | */ 120 | sqlContext.sql("SELECT * FROM myevent WHERE eventid>1000").withColumnRenamed("eventid", "id") 121 | .write.format("io.github.spark_redshift_community.spark.redshift") 122 | .option("url", jdbcURL) 123 | .option("tempdir", tempS3Dir) 124 | .option("dbtable", "redshiftevent") 125 | .mode(SaveMode.Append) 126 | .save() 127 | 128 | /** Demonstration of interoperability */ 129 | val salesAGGQuery = """SELECT sales.eventid AS id, SUM(qtysold) AS totalqty, SUM(pricepaid) AS salesamt 130 | FROM sales 131 | GROUP BY 
(sales.eventid) 132 | """ 133 | val salesAGGDF = sqlContext.read 134 | .format("io.github.spark_redshift_community.spark.redshift") 135 | .option("url", jdbcURL) 136 | .option("tempdir", tempS3Dir) 137 | .option("query", salesAGGQuery) 138 | .load() 139 | salesAGGDF.registerTempTable("salesagg") 140 | 141 | /* 142 | * Join two DataFrame instances. Each could be sourced from any 143 | * compatible Data Source 144 | */ 145 | val salesAGGDF2 = salesAGGDF.join(eventsDF, salesAGGDF("id") === eventsDF("eventid")) 146 | .select("id", "eventname", "totalqty", "salesamt") 147 | 148 | salesAGGDF2.registerTempTable("redshift_sales_agg") 149 | 150 | sqlContext.sql("SELECT * FROM redshift_sales_agg") 151 | .write.format("io.github.spark_redshift_community.spark.redshift") 152 | .option("url", jdbcURL) 153 | .option("tempdir", tempS3Dir) 154 | .option("dbtable", "redshift_sales_agg") 155 | .mode(SaveMode.Overwrite) 156 | .save() 157 | } 158 | } -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/Conversions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.Timestamp 20 | import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat} 21 | import java.time.{DateTimeException, LocalDateTime, ZonedDateTime} 22 | import java.time.format.DateTimeFormatter 23 | import java.util.Locale 24 | 25 | import org.apache.spark.sql.catalyst.InternalRow 26 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 27 | import org.apache.spark.sql.catalyst.expressions.GenericRow 28 | import org.apache.spark.sql.types._ 29 | 30 | /** 31 | * Data type conversions for Redshift unloaded data 32 | */ 33 | private[redshift] object Conversions { 34 | 35 | /** 36 | * From the DateTimeFormatter docs (Java 8): 37 | * "A formatter created from a pattern can be used as many times as necessary, 38 | * it is immutable and is thread-safe." 39 | */ 40 | private val formatter = DateTimeFormatter.ofPattern( 41 | "yyyy-MM-dd HH:mm:ss[.SSSSSS][.SSSSS][.SSSS][.SSS][.SS][.S][X]") 42 | 43 | /** 44 | * Parse a boolean using Redshift's UNLOAD bool syntax 45 | */ 46 | private def parseBoolean(s: String): Boolean = { 47 | if (s == "t") true 48 | else if (s == "f") false 49 | else throw new IllegalArgumentException(s"Expected 't' or 'f' but got '$s'") 50 | } 51 | 52 | /** 53 | * Formatter for writing decimals unloaded from Redshift. 54 | * 55 | * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this 56 | * DecimalFormat across threads. 
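*
   * A minimal usage sketch (each thread/task creates its own instance):
   * {{{
   *   val fmt = Conversions.createRedshiftDecimalFormat()
   *   val parsed = fmt.parse("151.20").asInstanceOf[java.math.BigDecimal]
   * }}}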
57 | */ 58 | def createRedshiftDecimalFormat(): DecimalFormat = { 59 | val format = new DecimalFormat() 60 | format.setParseBigDecimal(true) 61 | format.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US)) 62 | format 63 | } 64 | 65 | /** 66 | * Formatter for parsing strings exported from Redshift DATE columns. 67 | * 68 | * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this 69 | * SimpleDateFormat across threads. 70 | */ 71 | def createRedshiftDateFormat(): SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd") 72 | 73 | /** 74 | * Formatter for formatting timestamps for insertion into Redshift TIMESTAMP columns. 75 | * 76 | * This formatter should not be used to parse timestamps returned from Redshift UNLOAD commands; 77 | * instead, use [[Timestamp.valueOf()]]. 78 | * 79 | * Note that Java Formatters are NOT thread-safe, so you should not re-use instances of this 80 | * SimpleDateFormat across threads. 81 | */ 82 | def createRedshiftTimestampFormat(): SimpleDateFormat = { 83 | new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS") 84 | } 85 | 86 | def parseRedshiftTimestamp(s: String): Timestamp = { 87 | val temporalAccessor = formatter.parse(s) 88 | 89 | try { 90 | // timestamptz 91 | Timestamp.from(ZonedDateTime.from(temporalAccessor).toInstant) 92 | } 93 | catch { 94 | // Case timestamp without timezone 95 | case e: DateTimeException => 96 | Timestamp.valueOf(LocalDateTime.from(temporalAccessor)) 97 | } 98 | } 99 | 100 | /** 101 | * Return a function that will convert arrays of strings conforming to the given schema to Rows. 102 | * 103 | * Note that instances of this function are NOT thread-safe. 104 | */ 105 | def createRowConverter(schema: StructType, nullString: String): Array[String] => InternalRow = { 106 | val dateFormat = createRedshiftDateFormat() 107 | val decimalFormat = createRedshiftDecimalFormat() 108 | val conversionFunctions: Array[String => Any] = schema.fields.map { field => 109 | field.dataType match { 110 | case ByteType => (data: String) => java.lang.Byte.parseByte(data) 111 | case BooleanType => (data: String) => parseBoolean(data) 112 | case DateType => (data: String) => new java.sql.Date(dateFormat.parse(data).getTime) 113 | case DoubleType => (data: String) => data match { 114 | case "nan" => Double.NaN 115 | case "inf" => Double.PositiveInfinity 116 | case "-inf" => Double.NegativeInfinity 117 | case _ => java.lang.Double.parseDouble(data) 118 | } 119 | case FloatType => (data: String) => data match { 120 | case "nan" => Float.NaN 121 | case "inf" => Float.PositiveInfinity 122 | case "-inf" => Float.NegativeInfinity 123 | case _ => java.lang.Float.parseFloat(data) 124 | } 125 | case dt: DecimalType => 126 | (data: String) => decimalFormat.parse(data).asInstanceOf[java.math.BigDecimal] 127 | case IntegerType => (data: String) => java.lang.Integer.parseInt(data) 128 | case LongType => (data: String) => java.lang.Long.parseLong(data) 129 | case ShortType => (data: String) => java.lang.Short.parseShort(data) 130 | case StringType => (data: String) => data 131 | case TimestampType => (data: String) => parseRedshiftTimestamp(data) 132 | case _ => (data: String) => data 133 | } 134 | } 135 | // As a performance optimization, re-use the same mutable row / array: 136 | val converted: Array[Any] = Array.fill(schema.length)(null) 137 | val externalRow = new GenericRow(converted) 138 | val encoder = RowEncoder(schema) 139 | (inputRow: Array[String]) => { 140 | var i = 0 141 | while (i < schema.length) { 142 | val data = inputRow(i) 
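// Null handling for the branches below: the configured null string, a genuine null, and an
        // empty field in a non-string column all become SQL NULL; an empty string column stays "".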
143 | converted(i) = if ((data == null || data == nullString) || 144 | (data.isEmpty && schema.fields(i).dataType != StringType)) { 145 | null 146 | } 147 | else if (data.isEmpty) { 148 | "" 149 | } 150 | else { 151 | conversionFunctions(i)(data) 152 | } 153 | i += 1 154 | } 155 | encoder.toRow(externalRow) 156 | } 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | # spark-redshift Changelog 2 | 3 | ## 4.1.0 4 | 5 | - Add `include_column_list` parameter 6 | 7 | ## 4.0.2 8 | 9 | - Trim SQL text for preactions and postactions, to fix empty SQL queries bug. 10 | 11 | ## 4.0.1 12 | 13 | - Fix bug when parsing microseconds from Redshift 14 | 15 | ## 4.0.0 16 | 17 | This major release makes spark-redshift compatible with Spark 2.4. This was tested in production. 18 | 19 | While upgrading the package we dropped some features due to time constraints. 20 | 21 | - Support for Hadoop 1.x has been dropped. 22 | - STS and IAM authentication support has been dropped. 23 | - PostgreSQL driver tests are inactive. 24 | - SaveMode tests (or functionality?) are broken. This is a bit scary, but I'm not sure we use the functionality, 25 | and fixing them didn't make it into this version (spark-snowflake removed them too). 26 | - S3Native has been deprecated. We created an InMemoryS3AFileSystem to test S3A. 27 | 28 | ## 4.0.0-SNAPSHOT 29 | - SNAPSHOT version to test publishing to Maven Central. 30 | 31 | ## 4.0.0-preview20190730 (2019-07-30) 32 | 33 | - The library is tested in production using Spark 2.4 34 | - RedshiftSourceSuite is again among the Scala test suites. 35 | 36 | ## 4.0.0-preview20190715 (2019-07-15) 37 | 38 | Move to pre-4.0.0 'preview' releases rather than SNAPSHOT 39 | 40 | ## 4.0.0-SNAPSHOT-20190710 (2019-07-10) 41 | 42 | Remove the AWSCredentialsInUriIntegrationSuite test and require an s3a path in CrossRegionIntegrationSuite.scala 43 | 44 | ## 4.0.0-SNAPSHOT-20190627 (2019-06-27) 45 | 46 | Baseline SNAPSHOT version working with Spark 2.4 47 | 48 | #### Deprecation 49 | In order to get this baseline snapshot out, we dropped some features and package versions, 50 | and disabled some tests. 51 | Some of these changes are temporary; others - such as dropping Hadoop 1.x - are meant to stay. 52 | 53 | Our intent is to do the best job possible supporting the minimal set of features 54 | that the community needs. Other non-essential features may be dropped before the 55 | first non-snapshot release. 56 | The community's feedback and contributions are vitally important. 57 | 58 | 59 | * Support for Hadoop 1.x has been dropped. 60 | * STS and IAM authentication support has been dropped (so are the tests). 61 | * PostgreSQL driver tests are inactive. 62 | * SaveMode tests (or functionality?) are broken. This is a bit scarier, but I'm not sure we use the functionality, and fixing them didn't make it into this version (spark-snowflake removed them too). 63 | * S3Native has been deprecated. It's our intention to phase it out from this repo. The test util 'inMemoryFilesystem' is not present anymore, so an entire test suite, RedshiftSourceSuite, lost its major mock object and I had to remove it. We plan to re-write it using s3a.
64 | 65 | #### Commits changelog 66 | - 5b0f949 (HEAD -> master, origin_community/master) Merge pull request #6 from spark-redshift-community/luca-spark-2.4 67 | - 25acded (origin_community/luca-spark-2.4, origin/luca-spark-2.4, luca-spark-2.4) Revert sbt scripts to an older version 68 | - 866d4fd Moving to external github issues - rename spName to spark-redshift-community 69 | - 094cc15 remove in Memory FileSystem class and clean up comments in the sbt build file 70 | - 0666bc6 aws_variables.env gitignored 71 | - f3bbdb7 sbt assembly the package into a fat jar - found the perfect coordination between different libraries versions! Tests pass and can compile spark-on-paasta and spark successfullygit add src/ project/ 72 | - b1fa3f6 Ignoring a bunch of tests as did snowflake - close to have a green build to try out 73 | - 95cdf94 Removing conn.commit() everywhere - got 88% of integration tests to run - fix for STS token aws access in progress 74 | - da10897 Compiling - managed to run tests but they mostly fail 75 | - 0fe37d2 Compiles with spark 2.4.0 - amazon unmarshal error 76 | - ea5da29 force spark.avro - hadoop 2.7.7 and awsjavasdk downgraded 77 | - 834f0d6 Upgraded jackson by excluding it in aws 78 | - 90581a8 Fixed NewFilter - including hadoop-aws - s3n test is failing 79 | - 50dfd98 (tag: v3.0.0, tag: gtig, origin/master, origin/HEAD) Merge pull request #5 from Yelp/fdc_first-version 80 | - fbb58b3 (origin/fdc_first-version) First Yelp release 81 | - 0d2a130 Merge pull request #4 from Yelp/fdc_DATALAKE-4899_empty-string-to-null 82 | - 689635c (origin/fdc_DATALAKE-4899_empty-string-to-null) Fix File line length exceeds 100 characters 83 | - d06fe3b Fix scalastyle 84 | - e15ccb5 Fix parenthesis 85 | - d16317e Fix indentation 86 | - 475e7a1 Fix convertion bit and test 87 | - 3ae6a9b Fix Empty string is converted to null 88 | - 967dddb Merge pull request #3 from Yelp/fdc_DATALAKE-486_avoid-log-creds 89 | - 040b4a9 Merge pull request #2 from Yelp/fdc_DATALAKE-488_cleanup-fix-double-to-float 90 | - 58fb829 (origin/fdc_DATALAKE-488_cleanup-fix-double-to-float) Fix test 91 | - 3384333 Add bit and default types 92 | - 3230aaa (origin/fdc_DATALAKE-486_avoid-log-creds) Avoid logging creds. log sql query statement only 93 | - ab8124a Fix double type to float and cleanup 94 | - cafa05f Merge pull request #1 from Yelp/fdc_DATALAKE-563_remove-itests-from-public 95 | - a3a39a2 (origin/fdc_DATALAKE-563_remove-itests-from-public) Remove itests. Fix jdbc url. Update Redshift jdbc driver 96 | - 184b442 Make the note more obvious. 97 | - 717a4ad Notes about inlining this in Databricks Runtime. 
98 | - 8adfe95 (origin/fdc_first-test-branch-2) Fix decimal precision loss when reading the results of a Redshift query 99 | - 8da2d92 Test infra housekeeping: reduce SBT memory, update plugin versions, update SBT 100 | - 79bac6d Add instructions on using JitPack master SNAPSHOT builds 101 | - 7a4a08e Use PreparedStatement.getMetaData() to retrieve Redshift query schemas 102 | - b4c6053 Wrap and re-throw Await.result exceptions in order to capture full stacktrace 103 | - 1092c7c Update version in README to 3.0.0-preview1 104 | - 320748a Setting version to 3.0.0-SNAPSHOT 105 | - a28832b (tag: v3.0.0-preview1, origin/fdc_30-review) Setting version to 3.0.0-preview1 106 | - 8afde06 Make Redshift to S3 authentication mechanisms mutually exclusive 107 | - 9ed18a0 Use FileFormat-based data source instead of HadoopRDD for reads 108 | - 6cc49da Add option to use CSV as an intermediate data format during writes 109 | - d508d3e Add documentation and warnings related to using different regions for Redshift and S3 110 | - cdf192a Break RedshiftIntegrationSuite into smaller suites; refactor to remove some redundancy 111 | - bdf4462 Pass around AWSCredentialProviders instead of AWSCredentials 112 | - 51c29e6 Add codecov.yml file. 113 | - a9963da Update AWSCredentialUtils to be uniform between URI schemes. 114 | 115 | ## 3.0.0-SNAPSHOT (2017-11-08) 116 | 117 | Databricks spark-redshift pre-fork, changes not tracked. 118 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/ConversionsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.Timestamp 20 | import java.util.Locale 21 | 22 | import org.apache.spark.sql.Row 23 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 24 | import org.apache.spark.sql.types._ 25 | import org.scalatest.FunSuite 26 | 27 | /** 28 | * Unit test for data type conversions 29 | */ 30 | class ConversionsSuite extends FunSuite { 31 | 32 | private def createRowConverter(schema: StructType) = { 33 | Conversions.createRowConverter(schema, Parameters.DEFAULT_PARAMETERS("csvnullstring")) 34 | .andThen(RowEncoder(schema).resolveAndBind().fromRow) 35 | } 36 | 37 | test("Data should be correctly converted") { 38 | val convertRow = createRowConverter(TestUtils.testSchema) 39 | val doubleMin = Double.MinValue.toString 40 | val longMax = Long.MaxValue.toString 41 | // scalastyle:off 42 | val unicodeString = "Unicode是樂趣" 43 | // scalastyle:on 44 | 45 | val timestampWithMillis = "2014-03-01 00:00:01.123" 46 | val expectedTimestampMillis = TestUtils.toMillis(2014, 2, 1, 0, 0, 1, 123) 47 | 48 | val expectedDateMillis = TestUtils.toMillis(2015, 6, 1, 0, 0, 0) 49 | 50 | val convertedRow = convertRow( 51 | Array("1", "t", "2015-07-01", doubleMin, "1.0", "42", 52 | longMax, "23", unicodeString, timestampWithMillis)) 53 | 54 | val expectedRow = Row(1.asInstanceOf[Byte], true, new Timestamp(expectedDateMillis), 55 | Double.MinValue, 1.0f, 42, Long.MaxValue, 23.toShort, unicodeString, 56 | new Timestamp(expectedTimestampMillis)) 57 | 58 | assert(convertedRow == expectedRow) 59 | } 60 | 61 | test("Regression test for parsing timestamptz (bug #25 in spark_redshift_community)") { 62 | val rowConverter = createRowConverter( 63 | StructType(Seq(StructField("timestampWithTimezone", TimestampType)))) 64 | 65 | // when converting to timestamp, we discard the TZ info. 
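// More precisely, the "-03" offset is applied when computing the instant; the resulting
    // java.sql.Timestamp simply no longer carries any zone information.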
66 | val timestampWithTimezone = "2014-03-01 00:00:01.123-03" 67 | 68 | val expectedTimestampWithTimezoneMillis = TestUtils.toMillis( 69 | 2014, 2, 1, 0, 0, 1, 123, "-03") 70 | 71 | val convertedRow = rowConverter(Array(timestampWithTimezone)) 72 | val expectedRow = Row(new Timestamp(expectedTimestampWithTimezoneMillis)) 73 | 74 | assert(convertedRow == expectedRow) 75 | } 76 | 77 | test("Row conversion handles null values") { 78 | val convertRow = createRowConverter(TestUtils.testSchema) 79 | val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String] 80 | assert(convertRow(emptyRow) === Row(emptyRow: _*)) 81 | } 82 | 83 | test("Booleans are correctly converted") { 84 | val convertRow = createRowConverter(StructType(Seq(StructField("a", BooleanType)))) 85 | assert(convertRow(Array("t")) === Row(true)) 86 | assert(convertRow(Array("f")) === Row(false)) 87 | assert(convertRow(Array(null)) === Row(null)) 88 | intercept[IllegalArgumentException] { 89 | convertRow(Array("not-a-boolean")) 90 | } 91 | } 92 | 93 | test("timestamp conversion handles millisecond-level precision (regression test for #214)") { 94 | val schema = StructType(Seq(StructField("a", TimestampType))) 95 | val convertRow = createRowConverter(schema) 96 | Seq( 97 | "2014-03-01 00:00:01.123456" -> 98 | TestUtils.toNanosTimestamp(2014, 2, 1, 0, 0, 1, nanos = 123456000), 99 | "2014-03-01 00:00:01.12345" -> 100 | TestUtils.toNanosTimestamp(2014, 2, 1, 0, 0, 1, nanos = 123450000), 101 | "2014-03-01 00:00:01.1234" -> 102 | TestUtils.toNanosTimestamp(2014, 2, 1, 0, 0, 1, nanos = 123400000), 103 | "2014-03-01 00:00:01" -> 104 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 1000), 105 | "2014-03-01 00:00:01.000" -> 106 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 1000), 107 | "2014-03-01 00:00:00.1" -> 108 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 100), 109 | "2014-03-01 00:00:00.10" -> 110 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 100), 111 | "2014-03-01 00:00:00.100" -> 112 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 100), 113 | "2014-03-01 00:00:00.01" -> 114 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 10), 115 | "2014-03-01 00:00:00.010" -> 116 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 10), 117 | "2014-03-01 00:00:00.001" -> 118 | TestUtils.toTimestamp(2014, 2, 1, 0, 0, 0, millis = 1) 119 | ).foreach { case (timestampString, expectedTime) => 120 | withClue(s"timestamp string is '$timestampString'") { 121 | val convertedRow = convertRow(Array(timestampString)) 122 | val convertedTimestamp = convertedRow.get(0).asInstanceOf[Timestamp] 123 | assert(convertedTimestamp === expectedTime) 124 | } 125 | } 126 | } 127 | 128 | test("RedshiftDecimalFormat is locale-insensitive (regression test for #243)") { 129 | for (locale <- Seq(Locale.US, Locale.GERMAN, Locale.UK)) { 130 | withClue(s"locale = $locale") { 131 | TestUtils.withDefaultLocale(locale) { 132 | val decimalFormat = Conversions.createRedshiftDecimalFormat() 133 | val parsed = decimalFormat.parse("151.20").asInstanceOf[java.math.BigDecimal] 134 | assert(parsed.doubleValue() === 151.20) 135 | } 136 | } 137 | } 138 | } 139 | 140 | test("Row conversion properly handles NaN and Inf float values (regression test for #261)") { 141 | val convertRow = createRowConverter(StructType(Seq(StructField("a", FloatType)))) 142 | assert(java.lang.Float.isNaN(convertRow(Array("nan")).getFloat(0))) 143 | assert(convertRow(Array("inf")) === Row(Float.PositiveInfinity)) 144 | assert(convertRow(Array("-inf")) === 
Row(Float.NegativeInfinity)) 145 | } 146 | 147 | test("Row conversion properly handles NaN and Inf double values (regression test for #261)") { 148 | val convertRow = createRowConverter(StructType(Seq(StructField("a", DoubleType)))) 149 | assert(java.lang.Double.isNaN(convertRow(Array("nan")).getDouble(0))) 150 | assert(convertRow(Array("inf")) === Row(Double.PositiveInfinity)) 151 | assert(convertRow(Array("-inf")) === Row(Double.NegativeInfinity)) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/test/java/io/github/spark_redshift_community/spark/redshift/InMemoryS3AFileSystem.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one 3 | * or more contributor license agreements. See the NOTICE file 4 | * distributed with this work for additional information 5 | * regarding copyright ownership. The ASF licenses this file 6 | * to you under the Apache License, Version 2.0 (the 7 | * "License"); you may not use this file except in compliance 8 | * with the License. You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | package io.github.spark_redshift_community.spark.redshift; 20 | 21 | import java.io.*; 22 | import java.net.URI; 23 | import java.util.*; 24 | 25 | import org.apache.hadoop.fs.*; 26 | import org.apache.hadoop.fs.permission.FsPermission; 27 | import org.apache.hadoop.fs.s3a.S3AFileStatus; 28 | 29 | import org.apache.hadoop.conf.Configuration; 30 | import org.apache.hadoop.util.Progressable; 31 | 32 | 33 | /** 34 | * A stub implementation of NativeFileSystemStore for testing 35 | * S3AFileSystem without actually connecting to S3. 
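*
 * File contents live in an in-memory sorted map keyed by the full s3a:// path; "directories"
 * exist only implicitly as key prefixes, mirroring S3's flat keyspace.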
36 | */ 37 | public class InMemoryS3AFileSystem extends FileSystem { 38 | public static final String BUCKET = "test-bucket"; 39 | public static final URI FS_URI = URI.create("s3a://" + BUCKET + "/"); 40 | 41 | private static final long DEFAULT_BLOCK_SIZE_TEST = 33554432; 42 | 43 | private final Path root = new Path(FS_URI.toString()); 44 | 45 | private SortedMap dataMap = new TreeMap(); 46 | 47 | private Configuration conf; 48 | 49 | @Override 50 | public URI getUri() { 51 | return FS_URI; 52 | } 53 | 54 | @Override 55 | public Path getWorkingDirectory() { 56 | return new Path(root, "work"); 57 | } 58 | 59 | @Override 60 | public boolean mkdirs(Path f, FsPermission permission) throws IOException { 61 | // Not implemented 62 | return false; 63 | } 64 | 65 | @Override 66 | public void initialize(URI name, Configuration originalConf) 67 | throws IOException { 68 | conf = originalConf; 69 | } 70 | 71 | @Override 72 | public Configuration getConf() { 73 | return conf; 74 | } 75 | 76 | @Override 77 | public boolean exists(Path f) throws IOException { 78 | 79 | SortedMap subMap = dataMap.tailMap(toS3Key(f)); 80 | for (String filePath: subMap.keySet()) { 81 | if (filePath.contains(toS3Key(f))) { 82 | return true; 83 | } 84 | } 85 | return false; 86 | } 87 | 88 | private String toS3Key(Path f) { 89 | return f.toString(); 90 | } 91 | 92 | @Override 93 | public FSDataInputStream open(Path f) throws IOException { 94 | if (getFileStatus(f).isDirectory()) 95 | throw new IOException("TESTING: path can't be opened - it's a directory"); 96 | 97 | return new FSDataInputStream( 98 | new SeekableByteArrayInputStream( 99 | dataMap.get(toS3Key(f)).toByteArray() 100 | ) 101 | ); 102 | } 103 | 104 | @Override 105 | public FSDataInputStream open(Path f, int bufferSize) throws IOException { 106 | return open(f); 107 | } 108 | 109 | @Override 110 | public FSDataOutputStream create(Path f) throws IOException { 111 | 112 | if (exists(f)) { 113 | throw new FileAlreadyExistsException(); 114 | } 115 | 116 | String key = toS3Key(f); 117 | ByteArrayOutputStream inMemoryS3File = new ByteArrayOutputStream(); 118 | 119 | dataMap.put(key, inMemoryS3File); 120 | 121 | return new FSDataOutputStream(inMemoryS3File); 122 | 123 | } 124 | 125 | @Override 126 | public FSDataOutputStream create(Path f, FsPermission permission, boolean overwrite, int bufferSize, short replication, long blockSize, Progressable progress) throws IOException { 127 | // Not Implemented 128 | return null; 129 | } 130 | 131 | @Override 132 | public FSDataOutputStream append(Path f, int bufferSize, Progressable progress) throws IOException { 133 | // Not Implemented 134 | return null; 135 | } 136 | 137 | @Override 138 | public boolean rename(Path src, Path dst) throws IOException { 139 | dataMap.put(toS3Key(dst), dataMap.get(toS3Key(src))); 140 | return true; 141 | } 142 | 143 | @Override 144 | public boolean delete(Path f, boolean recursive) throws IOException { 145 | dataMap.remove(toS3Key(f)); 146 | return true; 147 | } 148 | 149 | private Set childPaths(Path f) { 150 | Set children = new HashSet<>(); 151 | 152 | String fDir = f + "/"; 153 | for (String subKey: dataMap.tailMap(toS3Key(f)).keySet()){ 154 | children.add( 155 | fDir + subKey.replace(fDir, "").split("/")[0] 156 | ); 157 | } 158 | return children; 159 | } 160 | 161 | @Override 162 | public FileStatus[] listStatus(Path f) throws IOException { 163 | 164 | if (!exists(f)) throw new FileNotFoundException(); 165 | 166 | if (getFileStatus(f).isDirectory()){ 167 | ArrayList statuses = new 
ArrayList<>(); 168 | 169 | for (String child: childPaths(f)) { 170 | statuses.add(getFileStatus(new Path(child))); 171 | } 172 | 173 | FileStatus[] arrayStatuses = new FileStatus[statuses.size()]; 174 | return statuses.toArray(arrayStatuses); 175 | } 176 | 177 | else { 178 | FileStatus[] statuses = new FileStatus[1]; 179 | statuses[0] = this.getFileStatus(f); 180 | return statuses; 181 | } 182 | } 183 | 184 | @Override 185 | public void setWorkingDirectory(Path new_dir) { 186 | // Not implemented 187 | } 188 | 189 | private boolean isDir(Path f) throws IOException{ 190 | return exists(f) && dataMap.get(toS3Key(f)) == null; 191 | } 192 | 193 | 194 | @Override 195 | public S3AFileStatus getFileStatus(Path f) throws IOException { 196 | 197 | if (!exists(f)) throw new FileNotFoundException(); 198 | 199 | if (isDir(f)) { 200 | return new S3AFileStatus( 201 | true, 202 | dataMap.tailMap(toS3Key(f)).size() == 1 && dataMap.containsKey(toS3Key(f)), 203 | f 204 | ); 205 | } 206 | else { 207 | return new S3AFileStatus( 208 | dataMap.get(toS3Key(f)).toByteArray().length, 209 | System.currentTimeMillis(), 210 | f, 211 | this.getDefaultBlockSize() 212 | ); 213 | } 214 | } 215 | 216 | @Override 217 | @SuppressWarnings("deprecation") 218 | public long getDefaultBlockSize() { 219 | return DEFAULT_BLOCK_SIZE_TEST; 220 | } 221 | } -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftWriteSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.SQLException 20 | 21 | import org.apache.spark.sql._ 22 | import org.apache.spark.sql.types._ 23 | 24 | /** 25 | * End-to-end tests of functionality which involves writing to Redshift via the connector. 26 | */ 27 | abstract class BaseRedshiftWriteSuite extends IntegrationSuiteBase { 28 | 29 | protected val tempformat: String 30 | 31 | override protected def write(df: DataFrame): DataFrameWriter[Row] = 32 | super.write(df).option("tempformat", tempformat) 33 | 34 | test("roundtrip save and load") { 35 | // This test can be simplified once #98 is fixed. 
36 | val tableName = s"roundtrip_save_and_load_$randomSuffix" 37 | try { 38 | write( 39 | sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)) 40 | .option("dbtable", tableName) 41 | .mode(SaveMode.ErrorIfExists) 42 | .save() 43 | 44 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 45 | checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData) 46 | } finally { 47 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 48 | } 49 | } 50 | 51 | test("roundtrip save and load with uppercase column names") { 52 | testRoundtripSaveAndLoad( 53 | s"roundtrip_write_and_read_with_uppercase_column_names_$randomSuffix", 54 | sqlContext.createDataFrame( 55 | sc.parallelize(Seq(Row(1))), StructType(StructField("SomeColumn", IntegerType) :: Nil) 56 | ), 57 | expectedSchemaAfterLoad = Some(StructType(StructField("somecolumn", IntegerType) :: Nil)) 58 | ) 59 | } 60 | 61 | test("save with column names that are reserved words") { 62 | testRoundtripSaveAndLoad( 63 | s"save_with_column_names_that_are_reserved_words_$randomSuffix", 64 | sqlContext.createDataFrame( 65 | sc.parallelize(Seq(Row(1))), 66 | StructType(StructField("table", IntegerType) :: Nil) 67 | ) 68 | ) 69 | } 70 | 71 | test("save with one empty partition (regression test for #96)") { 72 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 2), 73 | StructType(StructField("foo", IntegerType) :: Nil)) 74 | assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array(Row(1)))) 75 | testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df) 76 | } 77 | 78 | test("save with all empty partitions (regression test for #96)") { 79 | val df = sqlContext.createDataFrame(sc.parallelize(Seq.empty[Row], 2), 80 | StructType(StructField("foo", IntegerType) :: Nil)) 81 | assert(df.rdd.glom.collect() === Array(Array.empty[Row], Array.empty[Row])) 82 | testRoundtripSaveAndLoad(s"save_with_all_empty_partitions_$randomSuffix", df) 83 | // Now try overwriting that table. Although the new table is empty, it should still overwrite 84 | // the existing table. 
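// (Renaming the column gives the second, overwriting write a different schema, so a
    // successful overwrite is observable when the table is read back.)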
85 | val df2 = df.withColumnRenamed("foo", "bar") 86 | testRoundtripSaveAndLoad( 87 | s"save_with_all_empty_partitions_$randomSuffix", df2, saveMode = SaveMode.Overwrite) 88 | } 89 | 90 | test("informative error message when saving a table with string that is longer than max length") { 91 | val tableName = s"error_message_when_string_too_long_$randomSuffix" 92 | try { 93 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row("a" * 512))), 94 | StructType(StructField("A", StringType) :: Nil)) 95 | val e = intercept[SQLException] { 96 | write(df) 97 | .option("dbtable", tableName) 98 | .mode(SaveMode.ErrorIfExists) 99 | .save() 100 | } 101 | assert(e.getMessage.contains("while loading data into Redshift")) 102 | } finally { 103 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 104 | } 105 | } 106 | 107 | test("full timestamp precision is preserved in loads (regression test for #214)") { 108 | val timestamps = Seq( 109 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1), 110 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 10), 111 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 100), 112 | TestUtils.toTimestamp(1970, 0, 1, 0, 0, 0, millis = 1000)) 113 | testRoundtripSaveAndLoad( 114 | s"full_timestamp_precision_is_preserved$randomSuffix", 115 | sqlContext.createDataFrame(sc.parallelize(timestamps.map(Row(_))), 116 | StructType(StructField("ts", TimestampType) :: Nil)) 117 | ) 118 | } 119 | } 120 | 121 | class AvroRedshiftWriteSuite extends BaseRedshiftWriteSuite { 122 | override protected val tempformat: String = "AVRO" 123 | 124 | test("informative error message when saving with column names that contain spaces (#84)") { 125 | intercept[IllegalArgumentException] { 126 | testRoundtripSaveAndLoad( 127 | s"error_when_saving_column_name_with_spaces_$randomSuffix", 128 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 129 | StructType(StructField("column name with spaces", IntegerType) :: Nil))) 130 | } 131 | } 132 | } 133 | 134 | class CSVRedshiftWriteSuite extends BaseRedshiftWriteSuite { 135 | override protected val tempformat: String = "CSV" 136 | 137 | test("save with column names that contain spaces (#84)") { 138 | testRoundtripSaveAndLoad( 139 | s"save_with_column_names_that_contain_spaces_$randomSuffix", 140 | sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 141 | StructType(StructField("column name with spaces", IntegerType) :: Nil))) 142 | } 143 | } 144 | 145 | class CSVGZIPRedshiftWriteSuite extends IntegrationSuiteBase { 146 | // Note: we purposely don't inherit from BaseRedshiftWriteSuite because we're only interested in 147 | // testing basic functionality of the GZIP code; the rest of the write path should be unaffected 148 | // by compression here. 149 | 150 | override protected def write(df: DataFrame): DataFrameWriter[Row] = 151 | super.write(df).option("tempformat", "CSV GZIP") 152 | 153 | test("roundtrip save and load") { 154 | // This test can be simplified once #98 is fixed. 
155 | val tableName = s"roundtrip_save_and_load_$randomSuffix" 156 | try { 157 | write( 158 | sqlContext.createDataFrame(sc.parallelize(TestUtils.expectedData), TestUtils.testSchema)) 159 | .option("dbtable", tableName) 160 | .mode(SaveMode.ErrorIfExists) 161 | .save() 162 | 163 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 164 | checkAnswer(read.option("dbtable", tableName).load(), TestUtils.expectedData) 165 | } finally { 166 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 167 | } 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/ParametersSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.scalatest.{FunSuite, Matchers} 20 | 21 | /** 22 | * Check validation of parameter config 23 | */ 24 | class ParametersSuite extends FunSuite with Matchers { 25 | 26 | test("Minimal valid parameter map is accepted") { 27 | val params = Map( 28 | "tempdir" -> "s3://foo/bar", 29 | "dbtable" -> "test_schema.test_table", 30 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password", 31 | "forward_spark_s3_credentials" -> "true", 32 | "include_column_list" -> "true") 33 | 34 | val mergedParams = Parameters.mergeParameters(params) 35 | 36 | mergedParams.rootTempDir should startWith(params("tempdir")) 37 | mergedParams.createPerQueryTempDir() should startWith(params("tempdir")) 38 | mergedParams.jdbcUrl shouldBe params("url") 39 | mergedParams.table shouldBe Some(TableName("test_schema", "test_table")) 40 | assert(mergedParams.forwardSparkS3Credentials) 41 | assert(mergedParams.includeColumnList) 42 | 43 | // Check that the defaults have been added 44 | ( 45 | Parameters.DEFAULT_PARAMETERS 46 | - "forward_spark_s3_credentials" 47 | - "include_column_list" 48 | ).foreach { 49 | case (key, value) => mergedParams.parameters(key) shouldBe value 50 | } 51 | } 52 | 53 | test("createPerQueryTempDir() returns distinct temp paths") { 54 | val params = Map( 55 | "forward_spark_s3_credentials" -> "true", 56 | "tempdir" -> "s3://foo/bar", 57 | "dbtable" -> "test_table", 58 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password") 59 | 60 | val mergedParams = Parameters.mergeParameters(params) 61 | 62 | mergedParams.createPerQueryTempDir() should not equal mergedParams.createPerQueryTempDir() 63 | } 64 | 65 | test("Errors are thrown when mandatory parameters are not provided") { 66 | def checkMerge(params: Map[String, String], err: String): Unit = { 67 | val e = intercept[IllegalArgumentException] { 68 | Parameters.mergeParameters(params) 69 | } 70 | assert(e.getMessage.contains(err)) 71 | } 72 | 73 | val testURL = "jdbc:redshift://foo/bar?user=user&password=password" 74 | 
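// Each map below omits exactly one mandatory setting; the error message should name the
    // missing piece (tempdir, Redshift table name, JDBC URL, or an S3 authentication method).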
checkMerge(Map("dbtable" -> "test_table", "url" -> testURL), "tempdir") 75 | checkMerge(Map("tempdir" -> "s3://foo/bar", "url" -> testURL), "Redshift table name") 76 | checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar"), "JDBC URL") 77 | checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar", "url" -> testURL), 78 | "method for authenticating") 79 | } 80 | 81 | test("Must specify either 'dbtable' or 'query' parameter, but not both") { 82 | intercept[IllegalArgumentException] { 83 | Parameters.mergeParameters(Map( 84 | "forward_spark_s3_credentials" -> "true", 85 | "tempdir" -> "s3://foo/bar", 86 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) 87 | }.getMessage should (include("dbtable") and include("query")) 88 | 89 | intercept[IllegalArgumentException] { 90 | Parameters.mergeParameters(Map( 91 | "forward_spark_s3_credentials" -> "true", 92 | "tempdir" -> "s3://foo/bar", 93 | "dbtable" -> "test_table", 94 | "query" -> "select * from test_table", 95 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) 96 | }.getMessage should (include("dbtable") and include("query") and include("both")) 97 | 98 | Parameters.mergeParameters(Map( 99 | "forward_spark_s3_credentials" -> "true", 100 | "tempdir" -> "s3://foo/bar", 101 | "query" -> "select * from test_table", 102 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) 103 | } 104 | 105 | test("Must specify credentials in either URL or 'user' and 'password' parameters, but not both") { 106 | intercept[IllegalArgumentException] { 107 | Parameters.mergeParameters(Map( 108 | "forward_spark_s3_credentials" -> "true", 109 | "tempdir" -> "s3://foo/bar", 110 | "query" -> "select * from test_table", 111 | "url" -> "jdbc:redshift://foo/bar")) 112 | }.getMessage should (include("credentials")) 113 | 114 | intercept[IllegalArgumentException] { 115 | Parameters.mergeParameters(Map( 116 | "forward_spark_s3_credentials" -> "true", 117 | "tempdir" -> "s3://foo/bar", 118 | "query" -> "select * from test_table", 119 | "user" -> "user", 120 | "password" -> "password", 121 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) 122 | }.getMessage should (include("credentials") and include("both")) 123 | 124 | Parameters.mergeParameters(Map( 125 | "forward_spark_s3_credentials" -> "true", 126 | "tempdir" -> "s3://foo/bar", 127 | "query" -> "select * from test_table", 128 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password")) 129 | } 130 | 131 | test("tempformat option is case-insensitive") { 132 | val params = Map( 133 | "forward_spark_s3_credentials" -> "true", 134 | "tempdir" -> "s3://foo/bar", 135 | "dbtable" -> "test_schema.test_table", 136 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password") 137 | 138 | Parameters.mergeParameters(params + ("tempformat" -> "csv")) 139 | Parameters.mergeParameters(params + ("tempformat" -> "CSV")) 140 | 141 | intercept[IllegalArgumentException] { 142 | Parameters.mergeParameters(params + ("tempformat" -> "invalid-temp-format")) 143 | } 144 | } 145 | 146 | test("can only specify one Redshift to S3 authentication mechanism") { 147 | val e = intercept[IllegalArgumentException] { 148 | Parameters.mergeParameters(Map( 149 | "tempdir" -> "s3://foo/bar", 150 | "dbtable" -> "test_schema.test_table", 151 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password", 152 | "forward_spark_s3_credentials" -> "true", 153 | "aws_iam_role" -> "role")) 154 | } 155 | assert(e.getMessage.contains("mutually-exclusive")) 156 | } 157 | 
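  // Sketch of valid configurations (hypothetical values; `base` stands for a map with tempdir,
  // dbtable and url): exactly one Redshift-to-S3 authentication mechanism may be set at a time.
  //   Parameters.mergeParameters(base + ("forward_spark_s3_credentials" -> "true"))
  //   Parameters.mergeParameters(base + ("aws_iam_role" -> "arn:aws:iam::123456789000:role/redshift_iam_role"))
  //   Parameters.mergeParameters(base ++ Map(
  //     "temporary_aws_access_key_id" -> "KEY",
  //     "temporary_aws_secret_access_key" -> "SECRET",
  //     "temporary_aws_session_token" -> "TOKEN"))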
158 | test("preaction and postactions should be trimmed before splitting by semicolon") { 159 | val params = Parameters.mergeParameters(Map( 160 | "forward_spark_s3_credentials" -> "true", 161 | "tempdir" -> "s3://foo/bar", 162 | "dbtable" -> "test_schema.test_table", 163 | "url" -> "jdbc:redshift://foo/bar?user=user&password=password", 164 | "preactions" -> "update table1 set col1 = val1;update table1 set col2 = val2; ", 165 | "postactions" -> "update table2 set col1 = val1;update table2 set col2 = val2; " 166 | )) 167 | 168 | assert(params.preActions.length == 2) 169 | assert(params.preActions.head == "update table1 set col1 = val1") 170 | assert(params.preActions.last == "update table1 set col2 = val2") 171 | assert(params.postActions.length == 2) 172 | assert(params.postActions.head == "update table2 set col1 = val1") 173 | assert(params.postActions.last == "update table2 set col2 = val2") 174 | } 175 | 176 | } 177 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.net.URI 20 | import java.util.UUID 21 | 22 | import com.amazonaws.services.s3.model.BucketLifecycleConfiguration 23 | import com.amazonaws.services.s3.{AmazonS3Client, AmazonS3URI} 24 | import org.apache.hadoop.conf.Configuration 25 | import org.apache.hadoop.fs.FileSystem 26 | import org.slf4j.LoggerFactory 27 | 28 | import scala.collection.JavaConverters._ 29 | import scala.util.control.NonFatal 30 | 31 | /** 32 | * Various arbitrary helper functions 33 | */ 34 | private[redshift] object Utils { 35 | 36 | private val log = LoggerFactory.getLogger(getClass) 37 | 38 | def classForName(className: String): Class[_] = { 39 | val classLoader = 40 | Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader) 41 | // scalastyle:off 42 | Class.forName(className, true, classLoader) 43 | // scalastyle:on 44 | } 45 | 46 | /** 47 | * Joins prefix URL a to path suffix b, and appends a trailing /, in order to create 48 | * a temp directory path for S3. 49 | */ 50 | def joinUrls(a: String, b: String): String = { 51 | a.stripSuffix("/") + "/" + b.stripPrefix("/").stripSuffix("/") + "/" 52 | } 53 | 54 | /** 55 | * Redshift COPY and UNLOAD commands don't support s3n or s3a, but users may wish to use them 56 | * for data loads. This function converts the URL back to the s3:// format. 
57 | */ 58 | def fixS3Url(url: String): String = { 59 | url.replaceAll("s3[an]://", "s3://") 60 | } 61 | 62 | /** 63 | * Factory method to create new S3URI in order to handle various library incompatibilities with 64 | * older AWS Java Libraries 65 | */ 66 | def createS3URI(url: String): AmazonS3URI = { 67 | try { 68 | // try to instantiate AmazonS3URI with url 69 | new AmazonS3URI(url) 70 | } catch { 71 | case e: IllegalArgumentException if e.getMessage. 72 | startsWith("Invalid S3 URI: hostname does not appear to be a valid S3 endpoint") => { 73 | new AmazonS3URI(addEndpointToUrl(url)) 74 | } 75 | } 76 | } 77 | 78 | /** 79 | * Since older AWS Java Libraries do not handle S3 urls that have just the bucket name 80 | * as the host, add the endpoint to the host 81 | */ 82 | def addEndpointToUrl(url: String, domain: String = "s3.amazonaws.com"): String = { 83 | val uri = new URI(url) 84 | val hostWithEndpoint = uri.getHost + "." + domain 85 | new URI(uri.getScheme, 86 | uri.getUserInfo, 87 | hostWithEndpoint, 88 | uri.getPort, 89 | uri.getPath, 90 | uri.getQuery, 91 | uri.getFragment).toString 92 | } 93 | 94 | /** 95 | * Returns a copy of the given URI with the user credentials removed. 96 | */ 97 | def removeCredentialsFromURI(uri: URI): URI = { 98 | new URI( 99 | uri.getScheme, 100 | null, // no user info 101 | uri.getHost, 102 | uri.getPort, 103 | uri.getPath, 104 | uri.getQuery, 105 | uri.getFragment) 106 | } 107 | 108 | // Visible for testing 109 | private[redshift] var lastTempPathGenerated: String = null 110 | 111 | /** 112 | * Creates a randomly named temp directory path for intermediate data 113 | */ 114 | def makeTempPath(tempRoot: String): String = { 115 | lastTempPathGenerated = Utils.joinUrls(tempRoot, UUID.randomUUID().toString) 116 | lastTempPathGenerated 117 | } 118 | 119 | /** 120 | * Checks whether the S3 bucket for the given UI has an object lifecycle configuration to 121 | * ensure cleanup of temporary files. If no applicable configuration is found, this method logs 122 | * a helpful warning for the user. 123 | */ 124 | def checkThatBucketHasObjectLifecycleConfiguration( 125 | tempDir: String, 126 | s3Client: AmazonS3Client): Unit = { 127 | try { 128 | val s3URI = createS3URI(Utils.fixS3Url(tempDir)) 129 | val bucket = s3URI.getBucket 130 | assert(bucket != null, "Could not get bucket from S3 URI") 131 | val key = Option(s3URI.getKey).getOrElse("") 132 | val hasMatchingBucketLifecycleRule: Boolean = { 133 | val rules = Option(s3Client.getBucketLifecycleConfiguration(bucket)) 134 | .map(_.getRules.asScala) 135 | .getOrElse(Seq.empty) 136 | rules.exists { rule => 137 | // Note: this only checks that there is an active rule which matches the temp directory; 138 | // it does not actually check that the rule will delete the files. This check is still 139 | // better than nothing, though, and we can always improve it later. 140 | rule.getStatus == BucketLifecycleConfiguration.ENABLED && key.startsWith(rule.getPrefix) 141 | } 142 | } 143 | if (!hasMatchingBucketLifecycleRule) { 144 | log.warn(s"The S3 bucket $bucket does not have an object lifecycle configuration to " + 145 | "ensure cleanup of temporary files. Consider configuring `tempdir` to point to a " + 146 | "bucket with an object lifecycle policy that automatically deletes files after an " + 147 | "expiration period. 
For more information, see " + 148 | "https://docs.aws.amazon.com/AmazonS3/latest/dev/object-lifecycle-mgmt.html") 149 | } 150 | } catch { 151 | case NonFatal(e) => 152 | log.warn("An error occurred while trying to read the S3 bucket lifecycle configuration", e) 153 | } 154 | } 155 | 156 | /** 157 | * Given a URI, verify that the Hadoop FileSystem for that URI is not the S3 block FileSystem. 158 | * `spark-redshift` cannot use this FileSystem because the files written to it will not be 159 | * readable by Redshift (and vice versa). 160 | */ 161 | def assertThatFileSystemIsNotS3BlockFileSystem(uri: URI, hadoopConfig: Configuration): Unit = { 162 | val fs = FileSystem.get(uri, hadoopConfig) 163 | // Note that we do not want to use isInstanceOf here, since we're only interested in detecting 164 | // exact matches. We compare the class names as strings in order to avoid introducing a binary 165 | // dependency on classes which belong to the `hadoop-aws` JAR, as that artifact is not present 166 | // in some environments (such as EMR). See #92 for details. 167 | if (fs.getClass.getCanonicalName == "org.apache.hadoop.fs.s3.S3FileSystem") { 168 | throw new IllegalArgumentException( 169 | "spark-redshift does not support the S3 Block FileSystem. Please reconfigure `tempdir` to" + 170 | "use a s3n:// or s3a:// scheme.") 171 | } 172 | } 173 | 174 | /** 175 | * Attempts to retrieve the region of the S3 bucket. 176 | */ 177 | def getRegionForS3Bucket(tempDir: String, s3Client: AmazonS3Client): Option[String] = { 178 | try { 179 | val s3URI = createS3URI(Utils.fixS3Url(tempDir)) 180 | val bucket = s3URI.getBucket 181 | assert(bucket != null, "Could not get bucket from S3 URI") 182 | val region = s3Client.getBucketLocation(bucket) match { 183 | // Map "US Standard" to us-east-1 184 | case null | "US" => "us-east-1" 185 | case other => other 186 | } 187 | Some(region) 188 | } catch { 189 | case NonFatal(e) => 190 | log.warn("An error occurred while trying to determine the S3 bucket's region", e) 191 | None 192 | } 193 | } 194 | 195 | /** 196 | * Attempts to determine the region of a Redshift cluster based on its URL. It may not be possible 197 | * to determine the region in some cases, such as when the Redshift cluster is placed behind a 198 | * proxy. 199 | */ 200 | def getRegionForRedshiftCluster(url: String): Option[String] = { 201 | val regionRegex = """.*\.([^.]+)\.redshift\.amazonaws\.com.*""".r 202 | url match { 203 | case regionRegex(region) => Some(region) 204 | case _ => None 205 | } 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftInputFormat.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.io.{BufferedInputStream, IOException} 20 | import java.lang.{Long => JavaLong} 21 | import java.nio.charset.Charset 22 | 23 | import org.apache.hadoop.conf.Configuration 24 | import org.apache.hadoop.fs.{FileSystem, Path} 25 | import org.apache.hadoop.io.compress.CompressionCodecFactory 26 | import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} 27 | import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} 28 | 29 | import scala.collection.mutable.ArrayBuffer 30 | 31 | /** 32 | * Input format for text records saved with in-record delimiter and newline characters escaped. 33 | * 34 | * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|` 35 | * should be the following: 36 | * {{{ 37 | * a\\\n|\\|b\\\\\n 38 | * }}}, 39 | * where the in-record `|`, `\r`, `\n`, and `\\` characters are escaped by `\\`. 40 | * Users can configure the delimiter via [[RedshiftInputFormat$#KEY_DELIMITER]]. 41 | * Its default value [[RedshiftInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's UNLOAD 42 | * with the ESCAPE option: 43 | * {{{ 44 | * UNLOAD ('select_statement') 45 | * TO 's3://object_path_prefix' 46 | * ESCAPE 47 | * }}} 48 | * 49 | * @see org.apache.spark.SparkContext#newAPIHadoopFile 50 | */ 51 | class RedshiftInputFormat extends FileInputFormat[JavaLong, Array[String]] { 52 | 53 | override def createRecordReader( 54 | split: InputSplit, 55 | context: TaskAttemptContext): RecordReader[JavaLong, Array[String]] = { 56 | new RedshiftRecordReader 57 | } 58 | } 59 | 60 | object RedshiftInputFormat { 61 | 62 | /** configuration key for delimiter */ 63 | val KEY_DELIMITER = "redshift.delimiter" 64 | /** default delimiter */ 65 | val DEFAULT_DELIMITER = '|' 66 | 67 | /** Gets the delimiter char from conf or the default. 
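* For example, `conf.set(RedshiftInputFormat.KEY_DELIMITER, ",")` selects comma-delimited records.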
*/ 68 | private[redshift] def getDelimiterOrDefault(conf: Configuration): Char = { 69 | val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString) 70 | if (c.length != 1) { 71 | throw new IllegalArgumentException(s"Expect delimiter be a single character but got '$c'.") 72 | } else { 73 | c.charAt(0) 74 | } 75 | } 76 | } 77 | 78 | private[redshift] class RedshiftRecordReader extends RecordReader[JavaLong, Array[String]] { 79 | 80 | private var reader: BufferedInputStream = _ 81 | 82 | private var key: JavaLong = _ 83 | private var value: Array[String] = _ 84 | 85 | private var start: Long = _ 86 | private var end: Long = _ 87 | private var cur: Long = _ 88 | 89 | private var eof: Boolean = false 90 | 91 | private var delimiter: Byte = _ 92 | @inline private[this] final val escapeChar: Byte = '\\' 93 | @inline private[this] final val lineFeed: Byte = '\n' 94 | @inline private[this] final val carriageReturn: Byte = '\r' 95 | 96 | @inline private[this] final val defaultBufferSize = 1024 * 1024 97 | 98 | private[this] val chars = ArrayBuffer.empty[Byte] 99 | 100 | override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = { 101 | val split = inputSplit.asInstanceOf[FileSplit] 102 | val file = split.getPath 103 | val conf: Configuration = context.getConfiguration 104 | delimiter = RedshiftInputFormat.getDelimiterOrDefault(conf).asInstanceOf[Byte] 105 | require(delimiter != escapeChar, 106 | s"The delimiter and the escape char cannot be the same but found $delimiter.") 107 | require(delimiter != lineFeed, "The delimiter cannot be the lineFeed character.") 108 | require(delimiter != carriageReturn, "The delimiter cannot be the carriage return.") 109 | val compressionCodecs = new CompressionCodecFactory(conf) 110 | val codec = compressionCodecs.getCodec(file) 111 | if (codec != null) { 112 | throw new IOException(s"Do not support compressed files but found $file.") 113 | } 114 | val fs = file.getFileSystem(conf) 115 | val size = fs.getFileStatus(file).getLen 116 | start = findNext(fs, file, size, split.getStart) 117 | end = findNext(fs, file, size, split.getStart + split.getLength) 118 | cur = start 119 | val in = fs.open(file) 120 | if (cur > 0L) { 121 | in.seek(cur - 1L) 122 | in.read() 123 | } 124 | reader = new BufferedInputStream(in, defaultBufferSize) 125 | } 126 | 127 | override def getProgress: Float = { 128 | if (start >= end) { 129 | 1.0f 130 | } else { 131 | math.min((cur - start).toFloat / (end - start), 1.0f) 132 | } 133 | } 134 | 135 | override def nextKeyValue(): Boolean = { 136 | if (cur < end && !eof) { 137 | key = cur 138 | value = nextValue() 139 | true 140 | } else { 141 | key = null 142 | value = null 143 | false 144 | } 145 | } 146 | 147 | override def getCurrentValue: Array[String] = value 148 | 149 | override def getCurrentKey: JavaLong = key 150 | 151 | override def close(): Unit = { 152 | if (reader != null) { 153 | reader.close() 154 | } 155 | } 156 | 157 | /** 158 | * Finds the start of the next record. 159 | * Because we don't know whether the first char is escaped or not, we need to first find a 160 | * position that is not escaped. 
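* (For example, the byte at the split offset might itself be escaped by a backslash in the
    * previous split, so we first skip past any leading run of escape characters and only then
    * scan for the next unescaped line feed.)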
161 | * 162 | * @param fs file system 163 | * @param file file path 164 | * @param size file size 165 | * @param offset start offset 166 | * @return the start position of the next record 167 | */ 168 | private def findNext(fs: FileSystem, file: Path, size: Long, offset: Long): Long = { 169 | if (offset == 0L) { 170 | return 0L 171 | } else if (offset >= size) { 172 | return size 173 | } 174 | val in = fs.open(file) 175 | var pos = offset 176 | in.seek(pos) 177 | val bis = new BufferedInputStream(in, defaultBufferSize) 178 | // Find the first unescaped char. 179 | var escaped = true 180 | var thisEof = false 181 | while (escaped && !thisEof) { 182 | val v = bis.read() 183 | if (v < 0) { 184 | thisEof = true 185 | } else { 186 | pos += 1 187 | if (v != escapeChar) { 188 | escaped = false 189 | } 190 | } 191 | } 192 | // Find the next unescaped line feed. 193 | var endOfRecord = false 194 | while ((escaped || !endOfRecord) && !thisEof) { 195 | val v = bis.read() 196 | if (v < 0) { 197 | thisEof = true 198 | } else { 199 | pos += 1 200 | if (v == escapeChar) { 201 | escaped = true 202 | } else { 203 | if (!escaped) { 204 | endOfRecord = v == lineFeed 205 | } else { 206 | escaped = false 207 | } 208 | } 209 | } 210 | } 211 | in.close() 212 | pos 213 | } 214 | 215 | private def nextValue(): Array[String] = { 216 | val fields = ArrayBuffer.empty[String] 217 | var escaped = false 218 | var endOfRecord = false 219 | while (!endOfRecord && !eof) { 220 | var endOfField = false 221 | chars.clear() 222 | while (!endOfField && !endOfRecord && !eof) { 223 | val v = reader.read() 224 | if (v < 0) { 225 | eof = true 226 | } else { 227 | cur += 1L 228 | val c = v.asInstanceOf[Byte] 229 | if (escaped) { 230 | if (c != escapeChar && c != delimiter && c != lineFeed && c != carriageReturn) { 231 | throw new IllegalStateException( 232 | s"Found `$c` (ASCII $v) after $escapeChar.") 233 | } 234 | chars.append(c) 235 | escaped = false 236 | } else { 237 | if (c == escapeChar) { 238 | escaped = true 239 | } else if (c == delimiter) { 240 | endOfField = true 241 | } else if (c == lineFeed) { 242 | endOfRecord = true 243 | } else { 244 | // also copy carriage return 245 | chars.append(c) 246 | } 247 | } 248 | } 249 | } 250 | // TODO: charset? 251 | fields.append(new String(chars.toArray, Charset.forName("UTF-8"))) 252 | } 253 | if (escaped) { 254 | throw new IllegalStateException(s"Found hanging escape char.") 255 | } 256 | fields.toArray 257 | } 258 | } 259 | 260 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/RedshiftRelation.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 TouchType Ltd 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.io.InputStreamReader 20 | import java.net.URI 21 | 22 | import com.amazonaws.auth.AWSCredentialsProvider 23 | import com.amazonaws.services.s3.AmazonS3Client 24 | import com.eclipsesource.json.Json 25 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters 26 | import org.apache.spark.rdd.RDD 27 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 28 | import org.apache.spark.sql.sources._ 29 | import org.apache.spark.sql.types._ 30 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} 31 | import org.slf4j.LoggerFactory 32 | 33 | import scala.collection.JavaConverters._ 34 | 35 | /** 36 | * Data Source API implementation for Amazon Redshift database tables 37 | */ 38 | private[redshift] case class RedshiftRelation( 39 | jdbcWrapper: JDBCWrapper, 40 | s3ClientFactory: AWSCredentialsProvider => AmazonS3Client, 41 | params: MergedParameters, 42 | userSchema: Option[StructType]) 43 | (@transient val sqlContext: SQLContext) 44 | extends BaseRelation 45 | with PrunedFilteredScan 46 | with InsertableRelation { 47 | 48 | private val log = LoggerFactory.getLogger(getClass) 49 | 50 | if (sqlContext != null) { 51 | Utils.assertThatFileSystemIsNotS3BlockFileSystem( 52 | new URI(params.rootTempDir), sqlContext.sparkContext.hadoopConfiguration) 53 | } 54 | 55 | private val tableNameOrSubquery = 56 | params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get 57 | 58 | override lazy val schema: StructType = { 59 | userSchema.getOrElse { 60 | val tableNameOrSubquery = 61 | params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get 62 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) 63 | try { 64 | jdbcWrapper.resolveTable(conn, tableNameOrSubquery) 65 | } finally { 66 | conn.close() 67 | } 68 | } 69 | } 70 | 71 | override def toString: String = s"RedshiftRelation($tableNameOrSubquery)" 72 | 73 | override def insert(data: DataFrame, overwrite: Boolean): Unit = { 74 | val saveMode = if (overwrite) { 75 | SaveMode.Overwrite 76 | } else { 77 | SaveMode.Append 78 | } 79 | val writer = new RedshiftWriter(jdbcWrapper, s3ClientFactory) 80 | writer.saveToRedshift(sqlContext, data, saveMode, params) 81 | } 82 | 83 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { 84 | filters.filterNot(filter => FilterPushdown.buildFilterExpression(schema, filter).isDefined) 85 | } 86 | 87 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 88 | val creds = AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration) 89 | for ( 90 | redshiftRegion <- Utils.getRegionForRedshiftCluster(params.jdbcUrl); 91 | s3Region <- Utils.getRegionForS3Bucket(params.rootTempDir, s3ClientFactory(creds)) 92 | ) { 93 | if (redshiftRegion != s3Region) { 94 | // We don't currently support `extraunloadoptions`, so even if Amazon _did_ add a `region` 95 | // option for this we wouldn't be able to pass in the new option. However, we choose to 96 | // err on the side of caution and don't throw an exception because we don't want to break 97 | // existing workloads in case the region detection logic is wrong. 98 | log.error("The Redshift cluster and S3 bucket are in different regions " + 99 | s"($redshiftRegion and $s3Region, respectively). 
Redshift's UNLOAD command requires " + 100 | s"that the Redshift cluster and Amazon S3 bucket be located in the same region, so " + 101 | s"this read will fail.") 102 | } 103 | } 104 | Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds)) 105 | if (requiredColumns.isEmpty) { 106 | // In the special case where no columns were requested, issue a `count(*)` against Redshift 107 | // rather than unloading data. 108 | val whereClause = FilterPushdown.buildWhereClause(schema, filters) 109 | val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause" 110 | log.info(countQuery) 111 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) 112 | try { 113 | val results = jdbcWrapper.executeQueryInterruptibly(conn.prepareStatement(countQuery)) 114 | if (results.next()) { 115 | val numRows = results.getLong(1) 116 | val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt 117 | val emptyRow = RowEncoder(StructType(Seq.empty)).toRow(Row(Seq.empty)) 118 | sqlContext.sparkContext 119 | .parallelize(1L to numRows, parallelism) 120 | .map(_ => emptyRow) 121 | .asInstanceOf[RDD[Row]] 122 | } else { 123 | throw new IllegalStateException("Could not read count from Redshift") 124 | } 125 | } finally { 126 | conn.close() 127 | } 128 | } else { 129 | // Unload data from Redshift into a temporary directory in S3: 130 | val tempDir = params.createPerQueryTempDir() 131 | val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds) 132 | val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials) 133 | try { 134 | jdbcWrapper.executeInterruptibly(conn.prepareStatement(unloadSql)) 135 | } finally { 136 | conn.close() 137 | } 138 | // Read the MANIFEST file to get the list of S3 part files that were written by Redshift. 139 | // We need to use a manifest in order to guard against S3's eventually-consistent listings. 140 | val filesToRead: Seq[String] = { 141 | val cleanedTempDirUri = 142 | Utils.fixS3Url(Utils.removeCredentialsFromURI(URI.create(tempDir)).toString) 143 | val s3URI = Utils.createS3URI(cleanedTempDirUri) 144 | val s3Client = s3ClientFactory(creds) 145 | val is = s3Client.getObject(s3URI.getBucket, s3URI.getKey + "manifest").getObjectContent 146 | val s3Files = try { 147 | val entries = Json.parse(new InputStreamReader(is)).asObject().get("entries").asArray() 148 | entries.iterator().asScala.map(_.asObject().get("url").asString()).toSeq 149 | } finally { 150 | is.close() 151 | } 152 | // The filenames in the manifest are of the form s3://bucket/key, without credentials. 
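          // For example (hypothetical names): a manifest entry "s3://bucket/tmp/query1/part_0000"
          // combined with a tempDir of "s3a://bucket/tmp/query1/" maps to the read path
          // "s3a://bucket/tmp/query1/part_0000".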
153 | // If the S3 credentials were originally specified in the tempdir's URI, then we need to 154 | // reintroduce them here 155 | s3Files.map { file => 156 | tempDir.stripSuffix("/") + '/' + file.stripPrefix(cleanedTempDirUri).stripPrefix("/") 157 | } 158 | } 159 | 160 | val prunedSchema = pruneSchema(schema, requiredColumns) 161 | 162 | sqlContext.read 163 | .format(classOf[RedshiftFileFormat].getName) 164 | .schema(prunedSchema) 165 | .option("nullString", params.nullString) 166 | .load(filesToRead: _*) 167 | .queryExecution.executedPlan.execute().asInstanceOf[RDD[Row]] 168 | } 169 | } 170 | 171 | override def needConversion: Boolean = false 172 | 173 | private def buildUnloadStmt( 174 | requiredColumns: Array[String], 175 | filters: Array[Filter], 176 | tempDir: String, 177 | creds: AWSCredentialsProvider): String = { 178 | assert(!requiredColumns.isEmpty) 179 | // Always quote column names: 180 | val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ") 181 | val whereClause = FilterPushdown.buildWhereClause(schema, filters) 182 | val credsString: String = 183 | AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials) 184 | val query = { 185 | // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape 186 | // any backslashes and single quotes that appear in the query itself 187 | val escapedTableNameOrSubqury = tableNameOrSubquery.replace("\\", "\\\\").replace("'", "\\'") 188 | s"SELECT $columnList FROM $escapedTableNameOrSubqury $whereClause" 189 | } 190 | log.info(query) 191 | // We need to remove S3 credentials from the unload path URI because they will conflict with 192 | // the credentials passed via `credsString`. 193 | val fixedUrl = Utils.fixS3Url(Utils.removeCredentialsFromURI(new URI(tempDir)).toString) 194 | 195 | s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString'" + 196 | s" ESCAPE MANIFEST NULL AS '${params.nullString}'" 197 | } 198 | 199 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { 200 | val fieldMap = Map(schema.fields.map(x => x.name -> x): _*) 201 | new StructType(columns.map(name => fieldMap(name))) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/IntegrationSuiteBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.net.URI 20 | import java.sql.Connection 21 | 22 | import org.apache.hadoop.conf.Configuration 23 | import org.apache.hadoop.fs.s3native.NativeS3FileSystem 24 | import org.apache.hadoop.fs.{FileSystem, Path} 25 | import org.apache.spark.SparkContext 26 | import org.apache.spark.sql._ 27 | import org.apache.spark.sql.hive.test.TestHiveContext 28 | import org.apache.spark.sql.types.StructType 29 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Matchers} 30 | 31 | import scala.util.Random 32 | 33 | 34 | /** 35 | * Base class for writing integration tests which run against a real Redshift cluster. 36 | */ 37 | trait IntegrationSuiteBase 38 | extends QueryTest 39 | with Matchers 40 | with BeforeAndAfterAll 41 | with BeforeAndAfterEach { 42 | 43 | protected def loadConfigFromEnv(envVarName: String): String = { 44 | Option(System.getenv(envVarName)).getOrElse { 45 | fail(s"Must set $envVarName environment variable") 46 | } 47 | } 48 | 49 | // The following configurations must be set in order to run these tests. In Travis, these 50 | // environment variables are set using Travis's encrypted environment variables feature: 51 | // http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables 52 | 53 | // JDBC URL listed in the AWS console (should not contain username and password). 54 | protected val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL") 55 | protected val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER") 56 | protected val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD") 57 | protected val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("AWS_ACCESS_KEY_ID") 58 | protected val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("AWS_SECRET_ACCESS_KEY") 59 | // Path to a directory in S3 (e.g. 's3a://bucket-name/path/to/scratch/space'). 60 | protected val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE") 61 | require(AWS_S3_SCRATCH_SPACE.contains("s3a"), "must use s3a:// URL") 62 | 63 | protected def jdbcUrl: String = { 64 | s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD&ssl=true" 65 | } 66 | 67 | protected def jdbcUrlNoUserPassword: String = { 68 | s"$AWS_REDSHIFT_JDBC_URL?ssl=true" 69 | } 70 | /** 71 | * Random suffix appended to table and directory names in order to avoid collisions 72 | * between separate Travis builds. 73 | */ 74 | protected val randomSuffix: String = Math.abs(Random.nextLong()).toString 75 | 76 | protected val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/" 77 | 78 | /** 79 | * Spark Context with Hadoop file overridden to point at our local test data file for this suite, 80 | * no matter what temp directory was generated and requested.
81 | */ 82 | protected var sc: SparkContext = _ 83 | protected var sqlContext: SQLContext = _ 84 | protected var conn: Connection = _ 85 | 86 | override def beforeAll(): Unit = { 87 | super.beforeAll() 88 | sc = new SparkContext("local", "RedshiftSourceSuite") 89 | // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials: 90 | sc.hadoopConfiguration.setBoolean("fs.s3.impl.disable.cache", true) 91 | sc.hadoopConfiguration.setBoolean("fs.s3n.impl.disable.cache", true) 92 | sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID) 93 | sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY) 94 | sc.hadoopConfiguration.set("fs.s3a.access.key", AWS_ACCESS_KEY_ID) 95 | sc.hadoopConfiguration.set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY) 96 | conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl, None) 97 | } 98 | 99 | override def afterAll(): Unit = { 100 | try { 101 | val conf = new Configuration(false) 102 | conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID) 103 | conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY) 104 | conf.set("fs.s3a.access.key", AWS_ACCESS_KEY_ID) 105 | conf.set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY) 106 | // Bypass Hadoop's FileSystem caching mechanism so that we don't cache the credentials: 107 | conf.setBoolean("fs.s3.impl.disable.cache", true) 108 | conf.setBoolean("fs.s3n.impl.disable.cache", true) 109 | conf.setBoolean("fs.s3a.impl.disable.cache", true) 110 | conf.set("fs.s3.impl", classOf[InMemoryS3AFileSystem].getCanonicalName) 111 | conf.set("fs.s3a.impl", classOf[InMemoryS3AFileSystem].getCanonicalName) 112 | val fs = FileSystem.get(URI.create(tempDir), conf) 113 | fs.delete(new Path(tempDir), true) 114 | fs.close() 115 | } finally { 116 | try { 117 | conn.close() 118 | } finally { 119 | try { 120 | sc.stop() 121 | } finally { 122 | super.afterAll() 123 | } 124 | } 125 | } 126 | } 127 | 128 | override protected def beforeEach(): Unit = { 129 | super.beforeEach() 130 | sqlContext = new TestHiveContext(sc, loadTestTables = false) 131 | } 132 | 133 | /** 134 | * Create a new DataFrameReader using common options for reading from Redshift. 135 | */ 136 | protected def read: DataFrameReader = { 137 | sqlContext.read 138 | .format("io.github.spark_redshift_community.spark.redshift") 139 | .option("url", jdbcUrl) 140 | .option("tempdir", tempDir) 141 | .option("forward_spark_s3_credentials", "true") 142 | } 143 | /** 144 | * Create a new DataFrameWriter using common options for writing to Redshift. 
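 *
 * A minimal usage sketch (the table name is hypothetical), mirroring how the suites below call it:
 * {{{
 *   write(df).option("dbtable", "my_test_table").mode(SaveMode.Append).save()
 * }}}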
145 | */ 146 | protected def write(df: DataFrame): DataFrameWriter[Row] = { 147 | df.write 148 | .format("io.github.spark_redshift_community.spark.redshift") 149 | .option("url", jdbcUrl) 150 | .option("tempdir", tempDir) 151 | .option("forward_spark_s3_credentials", "true") 152 | } 153 | 154 | protected def createTestDataInRedshift(tableName: String): Unit = { 155 | conn.createStatement().executeUpdate( 156 | s""" 157 | |create table $tableName ( 158 | |testbyte int2, 159 | |testbool boolean, 160 | |testdate date, 161 | |testdouble float8, 162 | |testfloat float4, 163 | |testint int4, 164 | |testlong int8, 165 | |testshort int2, 166 | |teststring varchar(256), 167 | |testtimestamp timestamp 168 | |) 169 | """.stripMargin 170 | ) 171 | // scalastyle:off 172 | conn.createStatement().executeUpdate( 173 | s""" 174 | |insert into $tableName values 175 | |(null, null, null, null, null, null, null, null, null, null), 176 | |(0, null, '2015-07-03', 0.0, -1.0, 4141214, 1239012341823719, null, 'f', '2015-07-03 00:00:00.000'), 177 | |(0, false, null, -1234152.12312498, 100000.0, null, 1239012341823719, 24, '___|_123', null), 178 | |(1, false, '2015-07-02', 0.0, 0.0, 42, 1239012341823719, -13, 'asdf', '2015-07-02 00:00:00.000'), 179 | |(1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', '2015-07-01 00:00:00.001') 180 | """.stripMargin 181 | ) 182 | // scalastyle:on 183 | } 184 | 185 | protected def withTempRedshiftTable[T](namePrefix: String)(body: String => T): T = { 186 | val tableName = s"$namePrefix$randomSuffix" 187 | try { 188 | body(tableName) 189 | } finally { 190 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 191 | } 192 | } 193 | 194 | /** 195 | * Save the given DataFrame to Redshift, then load the results back into a DataFrame and check 196 | * that the returned DataFrame matches the one that we saved. 197 | * 198 | * @param tableName the table name to use 199 | * @param df the DataFrame to save 200 | * @param expectedSchemaAfterLoad if specified, the expected schema after loading the data back 201 | * from Redshift. This should be used in cases where you expect 202 | * the schema to differ due to reasons like case-sensitivity. 203 | * @param saveMode the [[SaveMode]] to use when writing data back to Redshift 204 | */ 205 | def testRoundtripSaveAndLoad( 206 | tableName: String, 207 | df: DataFrame, 208 | expectedSchemaAfterLoad: Option[StructType] = None, 209 | saveMode: SaveMode = SaveMode.ErrorIfExists): Unit = { 210 | try { 211 | write(df) 212 | .option("dbtable", tableName) 213 | .mode(saveMode) 214 | .save() 215 | // Check that the table exists. It appears that creating a table in one connection then 216 | // immediately querying for existence from another connection may result in spurious "table 217 | // doesn't exist" errors; this caused the "save with all empty partitions" test to become 218 | // flaky (see #146). 
To work around this, add a small sleep and check again: 219 | if (!DefaultJDBCWrapper.tableExists(conn, tableName)) { 220 | Thread.sleep(1000) 221 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 222 | } 223 | val loadedDf = read.option("dbtable", tableName).load() 224 | assert(loadedDf.schema === expectedSchemaAfterLoad.getOrElse(df.schema)) 225 | checkAnswer(loadedDf, df.collect()) 226 | } finally { 227 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 228 | } 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftReadSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.sql.Timestamp 20 | 21 | import org.apache.spark.sql.types.LongType 22 | import org.apache.spark.sql.{Row, execution} 23 | 24 | /** 25 | * End-to-end tests of functionality which only impacts the read path (e.g. filter pushdown). 
26 | */ 27 | class RedshiftReadSuite extends IntegrationSuiteBase { 28 | 29 | private val test_table: String = s"read_suite_test_table_$randomSuffix" 30 | 31 | override def beforeAll(): Unit = { 32 | super.beforeAll() 33 | conn.prepareStatement(s"drop table if exists $test_table").executeUpdate() 34 | createTestDataInRedshift(test_table) 35 | } 36 | 37 | override def afterAll(): Unit = { 38 | try { 39 | conn.prepareStatement(s"drop table if exists $test_table").executeUpdate() 40 | } finally { 41 | super.afterAll() 42 | } 43 | } 44 | 45 | override def beforeEach(): Unit = { 46 | super.beforeEach() 47 | read.option("dbtable", test_table).load().createOrReplaceTempView("test_table") 48 | } 49 | 50 | test("DefaultSource can load Redshift UNLOAD output to a DataFrame") { 51 | checkAnswer( 52 | sqlContext.sql("select * from test_table"), 53 | TestUtils.expectedData) 54 | } 55 | 56 | test("count() on DataFrame created from a Redshift table") { 57 | checkAnswer( 58 | sqlContext.sql("select count(*) from test_table"), 59 | Seq(Row(TestUtils.expectedData.length)) 60 | ) 61 | } 62 | 63 | test("count() on DataFrame created from a Redshift query") { 64 | val loadedDf = 65 | // scalastyle:off 66 | read.option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'").load() 67 | // scalastyle:on 68 | checkAnswer( 69 | loadedDf.selectExpr("count(*)"), 70 | Seq(Row(1)) 71 | ) 72 | } 73 | 74 | test("backslashes in queries/subqueries are escaped (regression test for #215)") { 75 | val loadedDf = 76 | read.option("query", s"select replace(teststring, '\\\\', '') as col from $test_table").load() 77 | checkAnswer( 78 | loadedDf.filter("col = 'asdf'"), 79 | Seq(Row("asdf")) 80 | ) 81 | } 82 | 83 | test("Can load output when 'dbtable' is a subquery wrapped in parentheses") { 84 | // scalastyle:off 85 | val query = 86 | s""" 87 | |(select testbyte, testbool 88 | |from $test_table 89 | |where testbool = true 90 | | and teststring = 'Unicode''s樂趣' 91 | | and testdouble = 1234152.12312498 92 | | and testfloat = 1.0 93 | | and testint = 42) 94 | """.stripMargin 95 | // scalastyle:on 96 | checkAnswer(read.option("dbtable", query).load(), Seq(Row(1, true))) 97 | } 98 | 99 | test("Can load output when 'query' is specified instead of 'dbtable'") { 100 | // scalastyle:off 101 | val query = 102 | s""" 103 | |select testbyte, testbool 104 | |from $test_table 105 | |where testbool = true 106 | | and teststring = 'Unicode''s樂趣' 107 | | and testdouble = 1234152.12312498 108 | | and testfloat = 1.0 109 | | and testint = 42 110 | """.stripMargin 111 | // scalastyle:on 112 | checkAnswer(read.option("query", query).load(), Seq(Row(1, true))) 113 | } 114 | 115 | test("Can load output of Redshift aggregation queries") { 116 | checkAnswer( 117 | read.option("query", s"select testbool, count(*) from $test_table group by testbool").load(), 118 | Seq(Row(true, 1), Row(false, 2), Row(null, 2))) 119 | } 120 | 121 | test("multiple scans on same table") { 122 | // .rdd() forces the first query to be unloaded from Redshift 123 | val rdd1 = sqlContext.sql("select testint from test_table").rdd 124 | // Similarly, this also forces an unload: 125 | sqlContext.sql("select testdouble from test_table").rdd 126 | // If the unloads were performed into the same directory then this call would fail: the 127 | // second unload from rdd2 would have overwritten the integers with doubles, so we'd get 128 | // a NumberFormatException. 
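    // Counting rdd1 only now, after the second unload has already run, is what exercises that
    // per-query temp-directory isolation.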
129 | rdd1.count() 130 | } 131 | 132 | test("DefaultSource supports simple column filtering") { 133 | checkAnswer( 134 | sqlContext.sql("select testbyte, testbool from test_table"), 135 | Seq( 136 | Row(null, null), 137 | Row(0.toByte, null), 138 | Row(0.toByte, false), 139 | Row(1.toByte, false), 140 | Row(1.toByte, true))) 141 | } 142 | 143 | test("query with pruned and filtered scans") { 144 | // scalastyle:off 145 | checkAnswer( 146 | sqlContext.sql( 147 | """ 148 | |select testbyte, testbool 149 | |from test_table 150 | |where testbool = true 151 | | and teststring = "Unicode's樂趣" 152 | | and testdouble = 1234152.12312498 153 | | and testfloat = 1.0 154 | | and testint = 42 155 | """.stripMargin), 156 | Seq(Row(1, true))) 157 | // scalastyle:on 158 | } 159 | 160 | test("RedshiftRelation implements Spark 1.6+'s unhandledFilters API") { 161 | assume(org.apache.spark.SPARK_VERSION.take(3) >= "1.6") 162 | val df = sqlContext.sql("select testbool from test_table where testbool = true") 163 | val physicalPlan = df.queryExecution.sparkPlan 164 | physicalPlan.collectFirst { case f: execution.FilterExec => f }.foreach { filter => 165 | fail(s"Filter should have been eliminated:\n${df.queryExecution}") 166 | } 167 | } 168 | 169 | test("filtering based on date constants (regression test for #152)") { 170 | val date = TestUtils.toDate(year = 2015, zeroBasedMonth = 6, date = 3) 171 | val df = sqlContext.sql("select testdate from test_table") 172 | 173 | checkAnswer(df.filter(df("testdate") === date), Seq(Row(date))) 174 | // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs 175 | // constant-folding, whereas earlier Spark versions would preserve the cast which prevented 176 | // filter pushdown. 177 | checkAnswer(df.filter("testdate = to_date('2015-07-03')"), Seq(Row(date))) 178 | } 179 | 180 | test("filtering based on timestamp constants (regression test for #152)") { 181 | val timestamp = TestUtils.toTimestamp(2015, zeroBasedMonth = 6, 1, 0, 0, 0, 1) 182 | val df = sqlContext.sql("select testtimestamp from test_table") 183 | 184 | checkAnswer(df.filter(df("testtimestamp") === timestamp), Seq(Row(timestamp))) 185 | // This query failed in Spark 1.6.0 but not in earlier versions. It looks like 1.6.0 performs 186 | // constant-folding, whereas earlier Spark versions would preserve the cast which prevented 187 | // filter pushdown. 
188 | checkAnswer(df.filter("testtimestamp = '2015-07-01 00:00:00.001'"), Seq(Row(timestamp))) 189 | } 190 | 191 | test("read special float values (regression test for #261)") { 192 | val tableName = s"roundtrip_special_float_values_$randomSuffix" 193 | try { 194 | conn.createStatement().executeUpdate( 195 | s"CREATE TABLE $tableName (x real)") 196 | conn.createStatement().executeUpdate( 197 | s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')") 198 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 199 | checkAnswer( 200 | read.option("dbtable", tableName).load(), 201 | Seq(Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity).map(x => Row.apply(x))) 202 | } finally { 203 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 204 | } 205 | } 206 | 207 | test("test empty string and null") { 208 | withTempRedshiftTable("records_with_empty_and_null_characters") { tableName => 209 | conn.createStatement().executeUpdate( 210 | s"CREATE TABLE $tableName (x varchar(256))") 211 | conn.createStatement().executeUpdate( 212 | s"INSERT INTO $tableName VALUES ('null'), (''), (null)") 213 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 214 | checkAnswer( 215 | read.option("dbtable", tableName).load(), 216 | Seq("null", "", null).map(x => Row.apply(x))) 217 | } 218 | } 219 | 220 | test("test timestamptz parsing") { 221 | withTempRedshiftTable("luca_test_timestamptz_spark_redshift") { tableName => 222 | conn.createStatement().executeUpdate( 223 | s"CREATE TABLE $tableName (x timestamptz)" 224 | ) 225 | conn.createStatement().executeUpdate( 226 | s"INSERT INTO $tableName VALUES ('2015-07-03 00:00:00.000 -0300')" 227 | ) 228 | 229 | checkAnswer( 230 | read.option("dbtable", tableName).load(), 231 | Seq(Row.apply( 232 | new Timestamp(TestUtils.toMillis( 233 | 2015, 6, 3, 0, 0, 0, 0, "-03")))) 234 | ) 235 | } 236 | } 237 | 238 | test("read special double values (regression test for #261)") { 239 | val tableName = s"roundtrip_special_double_values_$randomSuffix" 240 | try { 241 | conn.createStatement().executeUpdate( 242 | s"CREATE TABLE $tableName (x double precision)") 243 | conn.createStatement().executeUpdate( 244 | s"INSERT INTO $tableName VALUES ('NaN'), ('Infinity'), ('-Infinity')") 245 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 246 | checkAnswer( 247 | read.option("dbtable", tableName).load(), 248 | Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).map(x => Row.apply(x))) 249 | } finally { 250 | conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() 251 | } 252 | } 253 | 254 | test("read records containing escaped characters") { 255 | withTempRedshiftTable("records_with_escaped_characters") { tableName => 256 | conn.createStatement().executeUpdate( 257 | s"CREATE TABLE $tableName (x text)") 258 | conn.createStatement().executeUpdate( 259 | s"""INSERT INTO $tableName VALUES ('a\\nb'), ('\\\\'), ('"')""") 260 | assert(DefaultJDBCWrapper.tableExists(conn, tableName)) 261 | checkAnswer( 262 | read.option("dbtable", tableName).load(), 263 | Seq("a\nb", "\\", "\"").map(x => Row.apply(x))) 264 | } 265 | } 266 | 267 | test("read result of approximate count(distinct) query (#300)") { 268 | val df = read 269 | .option("query", s"select approximate count(distinct testbool) as c from $test_table") 270 | .load() 271 | assert(df.schema.fields(0).dataType === LongType) 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /scalastyle-config.xml: 
-------------------------------------------------------------------------------- 1 | <!-- Scalastyle standard configuration: checks include bans on println, Class.forName, and org.apache.spark.Logging, plus whitespace, line-length, and method/file-size limits; the XML rule markup is omitted. --> --------------------------------------------------------------------------------