├── NOTICE ├── sonatype_central_credentials ├── sonatype_credentials ├── tutorial ├── images │ ├── loadreadstep.png │ ├── loadunloadstep.png │ └── savetoredshift.png └── how_to_build.md ├── .jvmopts ├── codecov.yml ├── ARTICLES.md ├── README-DEV.md ├── src ├── test │ ├── resources │ │ ├── redshift_unload_data.txt │ │ ├── log4j2.properties │ │ └── hive-site.xml │ └── scala │ │ └── io │ │ └── github │ │ └── spark_redshift_community │ │ └── spark │ │ └── redshift │ │ ├── SerializableConfigurationSuite.scala │ │ ├── RedshiftQuerySuite.scala │ │ ├── DirectMapredOutputCommitter.scala │ │ ├── DirectMapreduceOutputCommitter.scala │ │ ├── test │ │ └── SeekableByteArrayInputStream.java │ │ ├── TableNameSuite.scala │ │ ├── FilterPushdownSuite.scala │ │ └── QueryTest.scala ├── main │ └── scala │ │ └── io │ │ └── github │ │ └── spark_redshift_community │ │ └── spark │ │ ├── redshift │ │ ├── SetAccumulator.scala │ │ ├── pushdown │ │ │ ├── RedshiftDMLExec.scala │ │ │ ├── querygeneration │ │ │ │ ├── BinaryOp.scala │ │ │ │ ├── UnaryOp.scala │ │ │ │ ├── ScalarSubqueryExtractor.scala │ │ │ │ ├── PassthroughStatement.scala │ │ │ │ ├── DateStatement.scala │ │ │ │ ├── UnsupportedStatement.scala │ │ │ │ ├── NumericStatement.scala │ │ │ │ ├── BooleanStatement.scala │ │ │ │ └── AggregationStatement.scala │ │ │ ├── SqlToS3TempCache.scala │ │ │ ├── RedshiftPlan.scala │ │ │ ├── deoptimize │ │ │ │ └── UndoCharTypePadding.scala │ │ │ └── RedshiftScanExec.scala │ │ ├── data │ │ │ ├── DataApiRuntimeException.scala │ │ │ ├── RedshiftWrapperFactory.scala │ │ │ ├── RedshiftConnection.scala │ │ │ └── RedshiftResults.scala │ │ ├── SerializableConfiguration.scala │ │ ├── ComparableVersion.scala │ │ ├── package.scala │ │ ├── RecordReaderIterator.scala │ │ ├── TableName.scala │ │ └── AWSCredentialsUtils.scala │ │ ├── redshift_spark_3_3_ │ │ ├── RowEncoderUtils.scala │ │ ├── TimestampNTZTypeExtractor.scala │ │ ├── RedshiftFileFormatUtils.scala │ │ └── pushdown │ │ │ └── querygeneration │ │ │ ├── GetMapValueExtractor.scala │ │ │ ├── StringStatementExtensions.scala │ │ │ ├── PromotePrecisionExtractor.scala │ │ │ ├── ExistsExtractor.scala │ │ │ ├── RoundExtractor.scala │ │ │ ├── CastExtractor.scala │ │ │ └── MergeIntoTableExtractor.scala │ │ ├── redshift_spark_3_4_ │ │ ├── RowEncoderUtils.scala │ │ ├── TimestampNTZTypeExtractor.scala │ │ ├── RedshiftFileFormatUtils.scala │ │ └── pushdown │ │ │ └── querygeneration │ │ │ ├── PromotePrecisionExtractor.scala │ │ │ ├── StringStatementExtensions.scala │ │ │ ├── RoundExtractor.scala │ │ │ ├── CastExtractor.scala │ │ │ ├── GetMapValueExtractor.scala │ │ │ ├── ExistsExtractor.scala │ │ │ └── MergeIntoTableExtractor.scala │ │ └── redshift_spark_3_5_ │ │ ├── TimestampNTZTypeExtractor.scala │ │ ├── RedshiftFileFormatUtils.scala │ │ ├── RowEncoderUtils.scala │ │ └── pushdown │ │ └── querygeneration │ │ ├── PromotePrecisionExtractor.scala │ │ ├── RoundExtractor.scala │ │ ├── CastExtractor.scala │ │ ├── GetMapValueExtractor.scala │ │ ├── ExistsExtractor.scala │ │ ├── StringStatementExtensions.scala │ │ └── MergeIntoTableExtractor.scala └── it │ ├── resources │ ├── lst │ │ ├── 1_create_inventory.sql │ │ ├── 1_create_store_returns.sql │ │ ├── 2_load_inventory.sql │ │ ├── 1_create_web_returns.sql │ │ ├── 1_create_web_sales.sql │ │ ├── 1_create_catalog_returns.sql │ │ ├── 1_create_date_dim.sql │ │ ├── 1_create_catalog_sales.sql │ │ ├── 2_load_store_returns.sql │ │ ├── 2_load_web_returns.sql │ │ └── 2_load_date_dim.sql │ ├── log4j.properties │ └── log4j2.properties │ └── scala │ └── io │ └── github │ 
└── spark_redshift_community │ └── spark │ └── redshift │ ├── pushdown │ ├── TestCase.scala │ ├── AggregateStatisticalOperatorsCorrectnessSuite.scala │ ├── PushdownRedshiftReadSuite.scala │ ├── StringIntegrationPushdownSuiteBase.scala │ ├── lst │ │ └── LSTIntegrationPushdownSuiteBase.scala │ ├── StringSelectCorrectnessSuite.scala │ └── PushdownLocalRelationSuite.scala │ ├── QueryGroupIntegrationSuite.scala │ ├── PostgresDriverIntegrationSuite.scala │ ├── RedshiftCredentialsInConfIntegrationSuite.scala │ ├── CrossRegionIntegrationSuite.scala │ ├── IAMIntegrationSuite.scala │ ├── DecimalIntegrationSuite.scala │ └── OverrideNullableSuite.scala ├── .travis.yml ├── dev └── run-tests-travis.sh ├── version.sbt ├── .env.example ├── project ├── PackagingTypePlugin.scala ├── build.properties └── plugins.sbt ├── .gitignore ├── .gitignore.bak └── .pre-commit-config.yaml /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /sonatype_central_credentials: -------------------------------------------------------------------------------- 1 | host=central.sonatype.com 2 | user=************* 3 | password=************* -------------------------------------------------------------------------------- /sonatype_credentials: -------------------------------------------------------------------------------- 1 | realm=Sonatype Nexus Repository Manager 2 | host=oss.sonatype.org 3 | user=************* 4 | password=************* 5 | -------------------------------------------------------------------------------- /tutorial/images/loadreadstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spark-redshift-community/spark-redshift/HEAD/tutorial/images/loadreadstep.png -------------------------------------------------------------------------------- /tutorial/images/loadunloadstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spark-redshift-community/spark-redshift/HEAD/tutorial/images/loadunloadstep.png -------------------------------------------------------------------------------- /tutorial/images/savetoredshift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spark-redshift-community/spark-redshift/HEAD/tutorial/images/savetoredshift.png -------------------------------------------------------------------------------- /.jvmopts: -------------------------------------------------------------------------------- 1 | -Dfile.encoding=UTF8 2 | -Xms1024M 3 | -Xmx1024M 4 | -Xss6M 5 | -XX:MaxPermSize=512m 6 | -XX:+CMSClassUnloadingEnabled 7 | -XX:+UseConcMarkSweepGC 8 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: 2 | layout: header, changes, diff 3 | coverage: 4 | status: 5 | patch: false 6 | project: 7 | default: 8 | target: 85 9 | -------------------------------------------------------------------------------- /ARTICLES.md: -------------------------------------------------------------------------------- 1 | External articles on Spark-Redshift Connector 2 | 3 | [AWS] 4 | * [Amazon Redshift Integration with Apache 
Spark](https://aws.amazon.com/blogs/aws/new-amazon-redshift-integration-with-apache-spark/) (2022-11-29) -------------------------------------------------------------------------------- /README-DEV.md: -------------------------------------------------------------------------------- 1 | # Developer notes to iterate on the code 2 | 3 | ## Prerequisite 4 | 5 | * JDK 11 (since sbt currently supports up to JDK 11) 6 | * sbt script version: 1.9.2 7 | * Python 3 (to use pre-commit) 8 | * pre-commit 3.4.0 -------------------------------------------------------------------------------- /src/test/resources/redshift_unload_data.txt: -------------------------------------------------------------------------------- 1 | 1|t|2015-07-01|1234152.12312498|1.0|42|1239012341823719|23|Unicode's樂趣|2015-07-01 00:00:00.001 2 | 1|f|2015-07-02|0|0.0|42|1239012341823719|-13|asdf|2015-07-02 00:00:00.0 3 | 0||2015-07-03|0.0|-1.0|4141214|1239012341823719||f|2015-07-03 00:00:00 4 | 0|f||-1234152.12312498|100000.0||1239012341823719|24|___\|_123| 5 | ||||||||| 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | sudo: false 3 | # Cache settings here are based on latest SBT documentation. 4 | cache: 5 | directories: 6 | - $HOME/.ivy2/cache 7 | - $HOME/.sbt/boot/ 8 | before_cache: 9 | # Tricks to avoid unnecessary cache updates 10 | - find $HOME/.ivy2 -name "ivydata-*.properties" -delete 11 | - find $HOME/.sbt -name "*.lock" -delete 12 | matrix: 13 | include: 14 | - jdk: openjdk8 15 | scala: 2.12.11 16 | env: HADOOP_VERSION="3.2.1" SPARK_VERSION="3.0.2" AWS_JAVA_SDK_VERSION="1.11.1033" 17 | 18 | script: 19 | - ./dev/run-tests-travis.sh 20 | 21 | after_success: 22 | - bash <(curl -s https://codecov.io/bash) 23 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/SetAccumulator.scala: -------------------------------------------------------------------------------- 1 | package io.github.spark_redshift_community.spark.redshift 2 | 3 | import org.apache.spark.util.AccumulatorV2 4 | 5 | class SetAccumulator[T](var value: Set[T]) extends AccumulatorV2[T, Set[T]] { 6 | def this() = this(Set.empty[T]) 7 | override def isZero: Boolean = value.isEmpty 8 | override def copy(): AccumulatorV2[T, Set[T]] = new SetAccumulator[T](value) 9 | override def reset(): Unit = value = Set.empty[T] 10 | override def add(v: T): Unit = value = value + v 11 | override def merge(other: AccumulatorV2[T, Set[T]]): Unit = 12 | value = value ++ other.value 13 | } 14 | -------------------------------------------------------------------------------- /dev/run-tests-travis.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | sbt ++$TRAVIS_SCALA_VERSION scalastyle 6 | sbt ++$TRAVIS_SCALA_VERSION "test:scalastyle" 7 | sbt ++$TRAVIS_SCALA_VERSION "it:scalastyle" 8 | 9 | sbt \ 10 | -Daws.testVersion=$AWS_JAVA_SDK_VERSION \ 11 | -Dhadoop.testVersion=$HADOOP_VERSION \ 12 | -Dspark.testVersion=$SPARK_VERSION \ 13 | ++$TRAVIS_SCALA_VERSION \ 14 | coverage test coverageReport 15 | 16 | if [ "$TRAVIS_SECURE_ENV_VARS" == "true" ]; then 17 | sbt \ 18 | -Daws.testVersion=$AWS_JAVA_SDK_VERSION \ 19 | -Dhadoop.testVersion=$HADOOP_VERSION \ 20 | -Dspark.testVersion=$SPARK_VERSION \ 21 | ++$TRAVIS_SCALA_VERSION \ 22 | coverage it:test coverageReport 2>
/dev/null; 23 | fi 24 | -------------------------------------------------------------------------------- /version.sbt: -------------------------------------------------------------------------------- 1 | /* 2 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | ThisBuild / version := "6.5.1" 18 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # You need to use real values locally to integration test 2 | # We already add .env into .gitignore. Do not try to 3 | # include it in pull requests. 4 | AWS_ACCESS_KEY_ID=foo 5 | AWS_SECRET_ACCESS_KEY=bar 6 | AWS_SESSION_TOKEN= 7 | 8 | AWS_REDSHIFT_USER= 9 | AWS_REDSHIFT_PASSWORD= 10 | AWS_REDSHIFT_JDBC_URL= 11 | 12 | AWS_S3_CROSS_REGION_SCRATCH_SPACE= 13 | AWS_S3_CROSS_REGION_SCRATCH_SPACE_REGION= 14 | 15 | STS_ROLE_ARN= 16 | 17 | AWS_S3_SCRATCH_SPACE= 18 | AWS_S3_SCRATCH_SPACE_REGION= -------------------------------------------------------------------------------- /tutorial/how_to_build.md: -------------------------------------------------------------------------------- 1 | If you are building this project from source, you can try the following 2 | 3 | ``` 4 | git clone https://github.com/spark-redshift-community/spark-redshift.git 5 | ``` 6 | 7 | ``` 8 | cd spark-redshift 9 | ``` 10 | 11 | ``` 12 | ./build/sbt -v compile 13 | ``` 14 | 15 | ``` 16 | ./build/sbt -v package 17 | ``` 18 | 19 | To run the test 20 | 21 | ``` 22 | ./build/sbt -v test 23 | ``` 24 | 25 | To run the integration test 26 | 27 | For the first time, you need to set up all the environment variables to connect to Redshift (see https://github.com/spark-redshift-community/spark-redshift/blob/master/src/it/scala/io/github/spark_redshift_community/spark/redshift/IntegrationSuiteBase.scala#L54).
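For example, a minimal sketch of that setup, assuming the integration suite reads these values from your shell environment (the variable names are taken from the .env.example file in this repository; the values below are placeholders, not working credentials):

```
export AWS_ACCESS_KEY_ID=<access-key-id>
export AWS_SECRET_ACCESS_KEY=<secret-access-key>
export AWS_REDSHIFT_USER=<redshift-user>
export AWS_REDSHIFT_PASSWORD=<redshift-password>
export AWS_REDSHIFT_JDBC_URL=<jdbc-url>
export AWS_S3_SCRATCH_SPACE=<s3-scratch-space-uri>
export AWS_S3_SCRATCH_SPACE_REGION=<s3-scratch-space-region>
```

With those values set in the same shell, the command below runs the integration tests.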
28 | 29 | ``` 30 | ./build/sbt -v it:test 31 | ``` 32 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/RedshiftDMLExec.scala: -------------------------------------------------------------------------------- 1 | package io.github.spark_redshift_community.spark.redshift.pushdown 2 | 3 | import io.github.spark_redshift_community.spark.redshift.RedshiftRelation 4 | import org.apache.spark.sql.{Row, SparkSession} 5 | import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} 6 | import org.apache.spark.sql.execution.command.LeafRunnableCommand 7 | import org.apache.spark.sql.types.LongType 8 | 9 | 10 | case class RedshiftDMLExec (query: RedshiftSQLStatement, relation: RedshiftRelation) 11 | extends LeafRunnableCommand { 12 | 13 | override def output: Seq[Attribute] = Seq(AttributeReference("num_affected_rows", LongType)()) 14 | 15 | override def run(sparkSession: SparkSession): Seq[Row] = { 16 | relation.runDMLFromSQL(query) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /project/PackagingTypePlugin.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | import sbt._ 17 | 18 | object PackagingTypePlugin extends AutoPlugin { 19 | override val buildSettings = { 20 | sys.props += "packaging.type" -> "jar" 21 | Nil 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | target/ 18 | project/target 19 | .idea/ 20 | .idea_modules/ 21 | *.DS_Store 22 | build/ 23 | aws_variables.env 24 | derby.log 25 | metastore_db/ 26 | lib_managed/ 27 | .vscode/ 28 | .metals/ 29 | .bloop/ 30 | venv/ 31 | .env 32 | .bsp/ 33 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/data/DataApiRuntimeException.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.data 17 | 18 | class DataApiRuntimeException(message: String) extends RuntimeException(message) 19 | -------------------------------------------------------------------------------- /.gitignore.bak: -------------------------------------------------------------------------------- 1 | # 2 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | target/ 18 | project/target 19 | .idea/ 20 | .idea_modules/ 21 | *.DS_Store 22 | build/ 23 | aws_variables.env 24 | derby.log 25 | metastore_db/ 26 | lib_managed/ 27 | .vscode/ 28 | .metals/ 29 | .bloop/ 30 | venv/ 31 | .env 32 | .bsp/ 33 | /.peru-sbt 34 | /.bsp 35 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | sbt.version=1.11.2 20 | -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_inventory.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | CREATE TABLE "PUBLIC"."inventory" 19 | ( 20 | inv_item_sk int, 21 | inv_warehouse_sk int, 22 | inv_quantity_on_hand int, 23 | inv_date_sk int 24 | ); -------------------------------------------------------------------------------- /src/it/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # Set everything to be logged to the console 18 | log4j.rootCategory=WARN, console 19 | log4j.appender.console=org.apache.log4j.ConsoleAppender 20 | log4j.appender.console.target=System.err 21 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 22 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 23 | 24 | # Project-level logging (disabled by default) 25 | log4j.logger.io.github.spark_redshift_community.spark.redshift=OFF 26 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/RowEncoderUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import org.apache.spark.sql.Row 19 | import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} 20 | import org.apache.spark.sql.types.StructType 21 | 22 | private[redshift] object RowEncoderUtils { 23 | def expressionEncoderForSchema(schema: StructType): ExpressionEncoder[Row] = { 24 | RowEncoder(schema) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/RowEncoderUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import org.apache.spark.sql.Row 19 | import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} 20 | import org.apache.spark.sql.types.StructType 21 | 22 | private[redshift] object RowEncoderUtils { 23 | def expressionEncoderForSchema(schema: StructType): ExpressionEncoder[Row] = { 24 | RowEncoder(schema) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/TimestampNTZTypeExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import org.apache.spark.sql.types.{DataType, TimestampNTZType} 19 | 20 | private[redshift] object TimestampNTZTypeExtractor { 21 | def unapply(dataType: DataType): Option[Boolean] = dataType match { 22 | case TimestampNTZType => Some(true) 23 | case _ => None 24 | } 25 | 26 | val defaultType: DataType = TimestampNTZType 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/TimestampNTZTypeExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import org.apache.spark.sql.types.{DataType, TimestampNTZType} 19 | 20 | private[redshift] object TimestampNTZTypeExtractor { 21 | def unapply(dataType: DataType): Option[Boolean] = dataType match { 22 | case TimestampNTZType => Some(true) 23 | case _ => None 24 | } 25 | 26 | val defaultType: DataType = TimestampNTZType 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/TimestampNTZTypeExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import org.apache.spark.sql.types.{DataType, TimestampType} 19 | 20 | private[redshift] object TimestampNTZTypeExtractor { 21 | def unapply(dataType: DataType): Option[Boolean] = dataType match { 22 | // TimestampNTZType does not exist until Spark 3.4 23 | case _ => None 24 | } 25 | 26 | val defaultType: DataType = TimestampType 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/RedshiftFileFormatUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.execution.datasources.PartitionedFile 20 | 21 | import java.net.URI 22 | 23 | // cannot be a companion object since it must be in a different file 24 | private[redshift] object RedshiftFileFormatUtils { 25 | def uriFromPartitionedFile(file: PartitionedFile): URI = file.pathUri 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/RedshiftFileFormatUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.execution.datasources.PartitionedFile 20 | 21 | import java.net.URI 22 | 23 | // cannot be a companion object since it must be in a different file 24 | private[redshift] object RedshiftFileFormatUtils { 25 | def uriFromPartitionedFile(file: PartitionedFile): URI = file.pathUri 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/RowEncoderUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import org.apache.spark.sql.Row 19 | import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} 20 | import org.apache.spark.sql.types.StructType 21 | 22 | private[redshift] object RowEncoderUtils { 23 | def expressionEncoderForSchema(schema: StructType): ExpressionEncoder[Row] = { 24 | ExpressionEncoder(RowEncoder.encoderFor(schema)) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/RedshiftFileFormatUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import org.apache.spark.sql.execution.datasources.PartitionedFile 20 | 21 | import java.net.URI 22 | 23 | // cannot be a companion object since it must be in a different file 24 | private[redshift] object RedshiftFileFormatUtils { 25 | def uriFromPartitionedFile(file: PartitionedFile): URI = new URI(file.filePath) 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/GetMapValueExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 17 | 18 | import org.apache.spark.sql.catalyst.expressions.{Expression, GetMapValue} 19 | 20 | object GetMapValueExtractor { 21 | def unapply(expr: Expression): Option[(Expression, Expression, Boolean)] = expr match { 22 | case GetMapValue(child, key, failOnError) => Some(child, key, failOnError) 23 | case _ => None 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/PromotePrecisionExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.Expression 20 | 21 | private[querygeneration] object PromotePrecisionExtractor { 22 | // Always return none, this operator was removed from 3.4 23 | def unapply(expr: Expression): Option[Expression] = None 24 | } 25 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/PromotePrecisionExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.Expression 20 | 21 | private[querygeneration] object PromotePrecisionExtractor { 22 | // Always return none, this operator was removed from 3.4 23 | def unapply(expr: Expression): Option[Expression] = None 24 | } 25 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/TestCase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.test 17 | 18 | import org.apache.spark.sql.Row 19 | 20 | 21 | // class representing each test case we'd like to run 22 | case class TestCase( 23 | sparkStatement: String, // the spark statement executed 24 | expectedResult: Seq[Row], // the expected result 25 | expectedAnswers: String * // the expected pushdown sql (one or more) 26 | ) -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/StringStatementExtensions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import io.github.spark_redshift_community.spark.redshift.pushdown.RedshiftSQLStatement 20 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} 21 | 22 | private[querygeneration] object StringStatementExtensions { 23 | def unapply(expAttr: (Expression, Seq[Attribute])): Option[RedshiftSQLStatement] = { 24 | None 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/StringStatementExtensions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import io.github.spark_redshift_community.spark.redshift.pushdown.RedshiftSQLStatement 20 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} 21 | 22 | private[querygeneration] object StringStatementExtensions { 23 | def unapply(expAttr: (Expression, Seq[Attribute])): Option[RedshiftSQLStatement] = { 24 | None 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/it/resources/log4j2.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # Set everything to be logged to the console 18 | rootLogger.level = WARN 19 | rootLogger.appenderRef.stdout.ref = console 20 | 21 | appender.console.type = Console 22 | appender.console.name = console 23 | appender.console.target = SYSTEM_ERR 24 | appender.console.layout.type = PatternLayout 25 | appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 26 | 27 | # Project-level logging (disabled by default) 28 | logger.spark_redshift_community.name = io.github.spark_redshift_community.spark.redshift 29 | logger.spark_redshift_community.level = OFF 30 | -------------------------------------------------------------------------------- /src/test/resources/log4j2.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | # Set everything to be logged to the console 18 | rootLogger.level = WARN 19 | rootLogger.appenderRef.stdout.ref = console 20 | 21 | appender.console.type = Console 22 | appender.console.name = console 23 | appender.console.target = SYSTEM_ERR 24 | appender.console.layout.type = PatternLayout 25 | appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 26 | 27 | # Project-level logging (disabled by default) 28 | logger.spark_redshift_community.name = io.github.spark_redshift_community.spark.redshift 29 | logger.spark_redshift_community.level = OFF 30 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/PromotePrecisionExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Expression, PromotePrecision} 20 | 21 | private[querygeneration] object PromotePrecisionExtractor { 22 | def unapply(expr: Expression): Option[Expression] = expr match { 23 | case PromotePrecision(expression) => Some(expression) 24 | case _ => None 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/RoundExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Expression, Round} 20 | 21 | private[querygeneration] object RoundExtractor { 22 | def unapply(expr: Expression): Option[(Expression, Expression, Boolean)] = expr match { 23 | case Round(child, scale, ansiEnabled) => Some(child, scale, ansiEnabled) 24 | case _ => None 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/RoundExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Expression, Round} 20 | 21 | private[querygeneration] object RoundExtractor { 22 | def unapply(expr: Expression): Option[(Expression, Expression, Boolean)] = expr match { 23 | case Round(child, scale, ansiEnabled) => Some(child, scale, ansiEnabled) 24 | case _ => None 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/CastExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} 20 | import org.apache.spark.sql.types.DataType 21 | 22 | private[querygeneration] object CastExtractor { 23 | def unapply(expr: Expression): Option[(Expression, DataType, Boolean)] = expr match { 24 | case c : Cast => Some(c.child, c.dataType, c.ansiEnabled) 25 | case _ => None 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/CastExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} 20 | import org.apache.spark.sql.types.DataType 21 | 22 | private[querygeneration] object CastExtractor { 23 | def unapply(expr: Expression): Option[(Expression, DataType, Boolean)] = expr match { 24 | case c : Cast => Some(c.child, c.dataType, c.ansiEnabled) 25 | case _ => None 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/BinaryOp.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan} 21 | 22 | /** Extractor for binary logical operations (e.g., joins). 
*/ 23 | private[querygeneration] object BinaryOp { 24 | 25 | def unapply(node: BinaryNode): Option[(LogicalPlan, LogicalPlan)] = 26 | Option(node match { 27 | case _: Join => (node.left, node.right) 28 | case _ => null 29 | }) 30 | } 31 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/SqlToS3TempCache.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.pushdown 18 | 19 | import java.util.concurrent.ConcurrentHashMap 20 | 21 | object SqlToS3TempCache { 22 | private val sqlToS3Cache = new ConcurrentHashMap[String, String]() 23 | 24 | def getS3Path(sql : String): Option[String] = { 25 | Option(sqlToS3Cache.get(sql)) 26 | } 27 | 28 | def setS3Path(sql : String, s3Path : String): Option[String] = { 29 | Option(sqlToS3Cache.put(sql, s3Path)) 30 | } 31 | 32 | def clearCache(): Unit = { 33 | sqlToS3Cache.clear() 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/test/resources/hive-site.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 | fs.permissions.umask-mode 23 | 022 24 | Setting a value for fs.permissions.umask-mode to work around issue in HIVE-6962. 25 | It has no impact in Hadoop 1.x line on HDFS operations. 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/ExistsExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Exists, Expression} 20 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 21 | 22 | private[querygeneration] object ExistsExtractor { 23 | def unapply(expr: Expression): Option[(LogicalPlan, Seq[Expression])] = expr match { 24 | case Exists(subQuery, _, _, joinCond) => Some(subQuery, joinCond) 25 | case _ => None 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/RoundExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Expression, Round} 20 | 21 | private[querygeneration] object RoundExtractor { 22 | def unapply(expr: Expression): Option[(Expression, Expression, Boolean)] = expr match { 23 | // always return false for ansiEnabled since spark 3.3 connector 24 | // acted as though it was always false 25 | case Round(child, scale) => Some(child, scale, false) 26 | case _ => None 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/UnaryOp.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import org.apache.spark.sql.catalyst.plans.logical._ 21 | 22 | 23 | /** Extractor for supported unary operations. 
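Version-specific extractors such as ExistsExtractor and RoundExtractor above give shared query-generation code one stable shape to match on, whatever Spark release the build targets. A rough sketch of that idea, assuming it compiles inside the same querygeneration package since the extractor is package-private; describeRound is a hypothetical helper, not connector code:

import org.apache.spark.sql.catalyst.expressions.Expression

def describeRound(expr: Expression): Option[String] = expr match {
  // Shared code never references the version-specific Round constructor directly.
  case RoundExtractor(child, scale, ansiEnabled) =>
    Some(s"ROUND(${child.sql}, ${scale.sql}) with ansi=$ansiEnabled")
  case _ => None
}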
*/ 24 | private[querygeneration] object UnaryOp { 25 | 26 | def unapply(node: UnaryNode): Option[LogicalPlan] = 27 | node match { 28 | case _: Filter | _: Project | _: GlobalLimit | _: LocalLimit | 29 | _: Aggregate | _: Sort | _: ReturnAnswer | _: Window => 30 | Some(node.child) 31 | 32 | case _ => None 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/GetMapValueExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 17 | 18 | import org.apache.spark.sql.catalyst.expressions.{Expression, GetMapValue} 19 | 20 | private[querygeneration] object GetMapValueExtractor { 21 | def unapply(expr: Expression): Option[(Expression, Expression, Boolean)] = expr match { 22 | // set third tuple value representing failOnError to false 23 | // this is how GetMapValue in spark 3.4 works. Since the 24 | // parameter has been removed 25 | case GetMapValue(child, key) => Some(child, key, false) 26 | case _ => None 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/GetMapValueExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 17 | 18 | import org.apache.spark.sql.catalyst.expressions.{Expression, GetMapValue} 19 | 20 | private[querygeneration] object GetMapValueExtractor { 21 | def unapply(expr: Expression): Option[(Expression, Expression, Boolean)] = expr match { 22 | // set third tuple value representing failOnError to false 23 | // this is how GetMapValue in spark 3.4 works. 
Since the 24 | // parameter has been removed 25 | case GetMapValue(child, key) => Some(child, key, false) 26 | case _ => None 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/ExistsExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Exists, Expression} 20 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 21 | 22 | private[querygeneration] object ExistsExtractor { 23 | def unapply(expr: Expression): Option[(LogicalPlan, Seq[Expression])] = expr match { 24 | // Fifth parameter (hint) was added after Spark 3.3 25 | case Exists(subQuery, _, _, joinCond, _) => Some(subQuery, joinCond) 26 | case _ => None 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/ExistsExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{Exists, Expression} 20 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 21 | 22 | private[querygeneration] object ExistsExtractor { 23 | def unapply(expr: Expression): Option[(LogicalPlan, Seq[Expression])] = expr match { 24 | // Fifth parameter (hint) was added after Spark 3.3 25 | case Exists(subQuery, _, _, joinCond, _) => Some(subQuery, joinCond) 26 | case _ => None 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | repos: 17 | - repo: local 18 | hooks: 19 | - id: sbt-compile-package 20 | name: sbt-compile-package 21 | entry: bash -c 'sbt compile package' 22 | language: system 23 | types: [file] 24 | pass_filenames: false 25 | - id: unit-tests 26 | name: unit-tests 27 | entry: bash -c 'sbt test' 28 | language: system 29 | types: [file] 30 | pass_filenames: false 31 | - id: compile-integration-tests 32 | name: compile-integration-tests 33 | entry: bash -c 'sbt it:compile' 34 | language: system 35 | types: [file] 36 | pass_filenames: false 37 | - repo: https://github.com/pre-commit/pre-commit-hooks 38 | rev: v4.5.0 39 | hooks: 40 | - id: check-yaml 41 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/CastExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Expression} 20 | import org.apache.spark.sql.types.DataType 21 | 22 | private[querygeneration] object CastExtractor { 23 | def unapply(expr: Expression): Option[(Expression, DataType, Boolean)] = expr match { 24 | case c : Cast => Some(c.child, c.dataType, c.ansiEnabled) 25 | // AnsiCast was deprecated after Spark 3.3 26 | case c : AnsiCast => Some(c.child, c.dataType, true) 27 | case _ => None 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/ScalarSubqueryExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import org.apache.spark.sql.catalyst.expressions.{ExprId, Expression, ScalarSubquery} 20 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 21 | 22 | object ScalarSubqueryExtractor { 23 | def unapply(expr: Expression): Option[(LogicalPlan, Seq[Expression], ExprId, Seq[Expression])] = 24 | expr match { 25 | // ignoring hintinfo and mayHaveCountBug 26 | case sq : ScalarSubquery => 27 | Some(sq.plan, sq.outerAttrs, sq.exprId, sq.joinCond) 28 | case _ => None 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | 2 | /* 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.7") 19 | 20 | addSbtPlugin("com.github.sbt" % "sbt-release" % "1.4.0") 21 | 22 | addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.3.1") 23 | 24 | addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.12.0") 25 | 26 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.2.0") 27 | 28 | addSbtPlugin("nl.gn0s1s" % "sbt-dotenv" % "3.1.1") 29 | 30 | libraryDependencies += "org.apache.maven" % "maven-artifact" % "3.3.9" 31 | 32 | // use built-in sbt plugin in sbt 1.4+ for dependency tree generation 33 | addDependencyTreePlugin 34 | 35 | // https://github.com/sbt/sbt/issues/6997 36 | ThisBuild / libraryDependencySchemes ++= Seq( 37 | "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always 38 | ) 39 | -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_store_returns.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | create table "PUBLIC"."store_returns" 19 | ( 20 | sr_returned_date_sk int4 , 21 | sr_return_time_sk int4 , 22 | sr_item_sk int4 not null , 23 | sr_customer_sk int4 , 24 | sr_cdemo_sk int4 , 25 | sr_hdemo_sk int4 , 26 | sr_addr_sk int4 , 27 | sr_store_sk int4 , 28 | sr_reason_sk int4 , 29 | sr_ticket_number int8 not null, 30 | sr_return_quantity int4 , 31 | sr_return_amt numeric(7,2) , 32 | sr_return_tax numeric(7,2) , 33 | sr_return_amt_inc_tax numeric(7,2) , 34 | sr_fee numeric(7,2) , 35 | sr_return_ship_cost numeric(7,2) , 36 | sr_refunded_cash numeric(7,2) , 37 | sr_reversed_charge numeric(7,2) , 38 | sr_store_credit numeric(7,2) , 39 | sr_net_loss numeric(7,2) 40 | ); -------------------------------------------------------------------------------- /src/it/resources/lst/2_load_inventory.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | INSERT INTO "PUBLIC"."inventory" (inv_item_sk, inv_warehouse_sk, inv_quantity_on_hand, inv_date_sk) 19 | VALUES (1, 101, 500, 20230101), 20 | (2, 102, 300, 20230102), 21 | (3, 103, 250, 20230103), 22 | (4, 104, 400, 20230104), 23 | (5, 105, 150, 20230105), 24 | (6, 106, 800, 20230106), 25 | (7, 107, 600, 20230107), 26 | (8, 108, 450, 20230108), 27 | (9, 109, 200, 20230109), 28 | (10, 110, 900, 20230110), 29 | (11, 111, 350, 20230111), 30 | (12, 112, 275, 20230112), 31 | (13, 113, 650, 20230113), 32 | (14, 114, 180, 20230114), 33 | (15, 115, 425, 20230115), 34 | (16, 116, 760, 20230116), 35 | (17, 117, 320, 20230117), 36 | (18, 118, 560, 20230118), 37 | (19, 119, 215, 20230119), 38 | (20, 120, 790, 20230120); -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_web_returns.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | create table "PUBLIC"."web_returns" 19 | ( 20 | wr_returned_date_sk int4 , 21 | wr_returned_time_sk int4 , 22 | wr_item_sk int4 not null , 23 | wr_refunded_customer_sk int4 , 24 | wr_refunded_cdemo_sk int4 , 25 | wr_refunded_hdemo_sk int4 , 26 | wr_refunded_addr_sk int4 , 27 | wr_returning_customer_sk int4 , 28 | wr_returning_cdemo_sk int4 , 29 | wr_returning_hdemo_sk int4 , 30 | wr_returning_addr_sk int4 , 31 | wr_web_page_sk int4 , 32 | wr_reason_sk int4 , 33 | wr_order_number int8 not null, 34 | wr_return_quantity int4 , 35 | wr_return_amt numeric(7,2) , 36 | wr_return_tax numeric(7,2) , 37 | wr_return_amt_inc_tax numeric(7,2) , 38 | wr_fee numeric(7,2) , 39 | wr_return_ship_cost numeric(7,2) , 40 | wr_refunded_cash numeric(7,2) , 41 | wr_reversed_charge numeric(7,2) , 42 | wr_account_credit numeric(7,2) , 43 | wr_net_loss numeric(7,2) 44 | ); -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/StringStatementExtensions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import io.github.spark_redshift_community.spark.redshift.pushdown.{ConstantString, RedshiftSQLStatement} 20 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, ToPrettyString} 21 | import org.apache.spark.sql.types._ 22 | 23 | private[querygeneration] object StringStatementExtensions { 24 | def unapply(expAttr: (Expression, Seq[Attribute])): Option[RedshiftSQLStatement] = { 25 | val expr = expAttr._1 26 | val fields = expAttr._2 27 | 28 | Option(expr match { 29 | 30 | case ToPrettyString(child: Expression, timeZoneId: Option[String]) => 31 | ConstantString("CASE WHEN") + 32 | convertStatement(child, fields) + 33 | ConstantString("IS NULL THEN 'NULL' ELSE") + 34 | convertStatement(Cast(child, StringType, timeZoneId), fields) + 35 | ConstantString("END") 36 | 37 | case _ => null 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_web_sales.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | create table "PUBLIC"."web_sales" 19 | ( 20 | ws_sold_date_sk int4 , 21 | ws_sold_time_sk int4 , 22 | ws_ship_date_sk int4 , 23 | ws_item_sk int4 not null , 24 | ws_bill_customer_sk int4 , 25 | ws_bill_cdemo_sk int4 , 26 | ws_bill_hdemo_sk int4 , 27 | ws_bill_addr_sk int4 , 28 | ws_ship_customer_sk int4 , 29 | ws_ship_cdemo_sk int4 , 30 | ws_ship_hdemo_sk int4 , 31 | ws_ship_addr_sk int4 , 32 | ws_web_page_sk int4 , 33 | ws_web_site_sk int4 , 34 | ws_ship_mode_sk int4 , 35 | ws_warehouse_sk int4 , 36 | ws_promo_sk int4 , 37 | ws_order_number int8 not null, 38 | ws_quantity int4 , 39 | ws_wholesale_cost numeric(7,2) , 40 | ws_list_price numeric(7,2) , 41 | ws_sales_price numeric(7,2) , 42 | ws_ext_discount_amt numeric(7,2) , 43 | ws_ext_sales_price numeric(7,2) , 44 | ws_ext_wholesale_cost numeric(7,2) , 45 | ws_ext_list_price numeric(7,2) , 46 | ws_ext_tax numeric(7,2) , 47 | ws_coupon_amt numeric(7,2) , 48 | ws_ext_ship_cost numeric(7,2) , 49 | ws_net_paid numeric(7,2) , 50 | ws_net_paid_inc_tax numeric(7,2) , 51 | ws_net_paid_inc_ship numeric(7,2) , 52 | ws_net_paid_inc_ship_tax numeric(7,2) , 53 | ws_net_profit numeric(7,2) 54 | ); -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/data/RedshiftWrapperFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.data 17 | 18 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters 19 | import io.github.spark_redshift_community.spark.redshift.data.RedshiftWrapperType.{DataAPI, JDBC, RedshiftInterfaceType} 20 | import org.slf4j.LoggerFactory 21 | 22 | private[redshift] object RedshiftWrapperType extends Enumeration { 23 | type RedshiftInterfaceType = Value 24 | val DataAPI, JDBC = Value 25 | } 26 | 27 | private[redshift] object RedshiftWrapperFactory { 28 | private val log = LoggerFactory.getLogger(getClass) 29 | private val jdbcSingleton = RedshiftWrapperFactory(RedshiftWrapperType.JDBC) 30 | private val dataAPISingleton = RedshiftWrapperFactory(RedshiftWrapperType.DataAPI) 31 | 32 | def apply(params: MergedParameters): RedshiftWrapper = { 33 | if (params.dataAPICreds) { 34 | log.debug("Using Data API to communicate with Redshift") 35 | dataAPISingleton 36 | } else { 37 | log.debug("Using JDBC to communicate with Redshift") 38 | jdbcSingleton 39 | } 40 | } 41 | 42 | private def apply(dataInterfaceType: RedshiftInterfaceType): RedshiftWrapper = { 43 | dataInterfaceType match { 44 | case DataAPI => new DataApiWrapper() 45 | case _ => new JDBCWrapper() 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/SerializableConfigurationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
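RedshiftWrapperFactory above picks a Data API or JDBC implementation from the merged options. A hedged sketch of how it might be exercised, assuming the snippet sits inside the connector's own package tree because the factory and Parameters are package-private; the option values are placeholders, and only keys that already appear elsewhere in this repository are used:

import io.github.spark_redshift_community.spark.redshift.Parameters
import io.github.spark_redshift_community.spark.redshift.data.RedshiftWrapperFactory

val opts = Map(
  "url" -> "jdbc:redshift://example-host:5439/dev?user=user&password=pass",
  "dbtable" -> "public.example_table",
  "tempdir" -> "s3a://example-bucket/temp/",
  "forward_spark_s3_credentials" -> "true")

val wrapper = RedshiftWrapperFactory(Parameters.mergeParameters(opts))
// With no Data API credentials among the options, this presumably hands back the shared JDBC wrapper.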
16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.test 19 | 20 | import io.github.spark_redshift_community.spark.redshift.SerializableConfiguration 21 | import org.apache.hadoop.conf.Configuration 22 | import org.apache.spark.SparkConf 23 | import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} 24 | import org.scalatest.funsuite.AnyFunSuite 25 | 26 | class SerializableConfigurationSuite extends AnyFunSuite { 27 | 28 | private def testSerialization(serializer: SerializerInstance): Unit = { 29 | val conf = new SerializableConfiguration(new Configuration()) 30 | 31 | val serialized = serializer.serialize(conf) 32 | 33 | serializer.deserialize[Any](serialized) match { 34 | case c: SerializableConfiguration => 35 | assert(c.log != null, "log was null") 36 | assert(c.value != null, "value was null") 37 | case other => fail( 38 | s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") 39 | } 40 | } 41 | 42 | test("serialization with JavaSerializer") { 43 | testSerialization(new JavaSerializer(new SparkConf()).newInstance()) 44 | } 45 | 46 | test("serialization with KryoSerializer") { 47 | testSerialization(new KryoSerializer(new SparkConf()).newInstance()) 48 | } 49 | 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/PassthroughStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 18 | 19 | import io.github.spark_redshift_community.spark.redshift.pushdown._ 20 | import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull 21 | import org.apache.spark.sql.catalyst.expressions.{Attribute, CheckOverflowInTableInsert, Expression} 22 | 23 | import scala.language.postfixOps 24 | 25 | /** 26 | * Extractor for expressions that are ignored. 27 | */ 28 | private[querygeneration] object PassthroughStatement { 29 | 30 | /** Used mainly by QueryGeneration.convertExpression. This matches 31 | * a tuple of (Expression, Seq[Attribute]) representing the expression to 32 | * be matched and the fields that define the valid fields in the current expression 33 | * scope, respectively. 34 | * 35 | * @param expAttr A pair-tuple representing the expression to be matched and the 36 | * attribute fields. 37 | * @return An option containing the translated SQL, if there is a match, or None if there 38 | * is no match. 
39 | */ 40 | def unapply( 41 | expAttr: (Expression, Seq[Attribute]) 42 | ): Option[RedshiftSQLStatement] = { 43 | val expr = expAttr._1 44 | val fields = expAttr._2 45 | 46 | Option(expr match { 47 | case CheckOverflowInTableInsert(child, _) => convertStatement(child, fields) 48 | case _ => null 49 | }) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_3_/pushdown/querygeneration/MergeIntoTableExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 17 | 18 | import io.github.spark_redshift_community.spark.redshift.RedshiftRelation 19 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MergeAction, MergeIntoTable} 20 | import org.apache.spark.sql.catalyst.expressions.Expression 21 | import org.apache.spark.sql.execution.datasources.LogicalRelation 22 | 23 | object MergeIntoTableExtractor { 24 | def unapply(plan: LogicalPlan): Option[(LogicalPlan, 25 | RedshiftRelation, 26 | LogicalPlan, 27 | Expression, 28 | Seq[MergeAction], 29 | Seq[MergeAction], 30 | Seq[MergeAction])] = 31 | plan match { 32 | case MergeIntoTable(targetTable@LogicalRelation(targetRelation: RedshiftRelation, _, _, _), 33 | sourcePlan, 34 | mergeCondition, 35 | matchedActions, 36 | notMatchedActions) => 37 | Some(targetTable, 38 | targetRelation, 39 | sourcePlan, 40 | mergeCondition, 41 | matchedActions, 42 | notMatchedActions, 43 | Seq.empty[MergeAction]) 44 | case _ => None 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/RedshiftQuerySuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
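The Spark 3.3 MergeIntoTableExtractor above pads its result with an empty Seq[MergeAction] so callers see the same seven fields as the 3.4 and 3.5 copies later in this dump, which pass notMatchedBySourceActions through. Paraphrased side by side (illustrative comments, not captured output):

// Spark 3.3 : MergeIntoTable(target, source, cond, matched, notMatched)
//             -> extractor fills notMatchedBySourceActions with Seq.empty[MergeAction]
// Spark 3.4+: MergeIntoTable(target, source, cond, matched, notMatched, notMatchedBySource)
//             -> extractor passes notMatchedBySource through unchanged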
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.test 17 | 18 | import io.github.spark_redshift_community.spark.redshift.pushdown.RedshiftScanExec 19 | import org.apache.spark.sql.SparkSession 20 | import org.scalatest.BeforeAndAfterAll 21 | import org.scalatest.funsuite.AnyFunSuite 22 | 23 | class RedshiftQuerySuite extends AnyFunSuite with BeforeAndAfterAll { 24 | 25 | private var spark: SparkSession = _ 26 | 27 | override def beforeAll(): Unit = { 28 | super.beforeAll() 29 | spark = SparkSession.builder() 30 | .master("local") 31 | .getOrCreate() 32 | } 33 | 34 | override def afterAll(): Unit = { 35 | spark.stop() 36 | super.afterAll() 37 | } 38 | 39 | test("test q1") { 40 | spark.sql(""" 41 | create table student(id int) 42 | using io.github.spark_redshift_community.spark.redshift 43 | OPTIONS ( 44 | dbtable 'public.parquet_struct_table_view', 45 | tempdir '/tmp/dir', 46 | url '', 47 | forward_spark_s3_credentials 'true' 48 | ) 49 | """).show() 50 | 51 | val df = spark.sql( 52 | """ 53 | |SELECT * FROM student 54 | |""".stripMargin) 55 | val plan = df.queryExecution.executedPlan 56 | 57 | assert(plan.isInstanceOf[RedshiftScanExec]) 58 | val rsPlan = plan.asInstanceOf[RedshiftScanExec] 59 | assert(rsPlan.query.statementString == 60 | """SELECT * FROM "public"."parquet_struct_table_view" AS "RCQ_ALIAS"""" 61 | .stripMargin) 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/SerializableConfiguration.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift 18 | 19 | import java.io._ 20 | 21 | import com.esotericsoftware.kryo.io.{Input, Output} 22 | import com.esotericsoftware.kryo.{Kryo, KryoSerializable} 23 | import org.apache.hadoop.conf.Configuration 24 | import org.slf4j.LoggerFactory 25 | 26 | import scala.util.control.NonFatal 27 | 28 | class SerializableConfiguration(@transient var value: Configuration) 29 | extends Serializable with KryoSerializable { 30 | @transient private[redshift] lazy val log = LoggerFactory.getLogger(getClass) 31 | 32 | private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { 33 | out.defaultWriteObject() 34 | value.write(out) 35 | } 36 | 37 | private def readObject(in: ObjectInputStream): Unit = tryOrIOException { 38 | value = new Configuration(false) 39 | value.readFields(in) 40 | } 41 | 42 | private def tryOrIOException[T](block: => T): T = { 43 | try { 44 | block 45 | } catch { 46 | case e: IOException => 47 | log.error("Exception encountered: {}", e.getMessage) 48 | throw e 49 | case NonFatal(e) => 50 | log.error("Exception encountered: {}", e.getMessage) 51 | throw new IOException(e) 52 | } 53 | } 54 | 55 | def write(kryo: Kryo, out: Output): Unit = { 56 | val dos = new DataOutputStream(out) 57 | value.write(dos) 58 | dos.flush() 59 | } 60 | 61 | def read(kryo: Kryo, in: Input): Unit = { 62 | value = new Configuration(false) 63 | value.readFields(new DataInputStream(in)) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_5_/pushdown/querygeneration/MergeIntoTableExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
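SerializableConfiguration above makes a Hadoop Configuration usable with both Java and Kryo serialization. A minimal sketch of the typical broadcast pattern; the SparkSession setup is assumed, and only SerializableConfiguration comes from this file:

import io.github.spark_redshift_community.spark.redshift.SerializableConfiguration
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()
val broadcastConf = spark.sparkContext.broadcast(
  new SerializableConfiguration(spark.sparkContext.hadoopConfiguration))
// On executors, broadcastConf.value.value yields a usable org.apache.hadoop.conf.Configuration.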
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 17 | 18 | import io.github.spark_redshift_community.spark.redshift.RedshiftRelation 19 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MergeAction, MergeIntoTable} 20 | import org.apache.spark.sql.catalyst.expressions.Expression 21 | import org.apache.spark.sql.execution.datasources.LogicalRelation 22 | object MergeIntoTableExtractor { 23 | def unapply(plan: LogicalPlan): Option[(LogicalPlan, 24 | RedshiftRelation, 25 | LogicalPlan, 26 | Expression, 27 | Seq[MergeAction], 28 | Seq[MergeAction], 29 | Seq[MergeAction])] = 30 | plan match { 31 | case MergeIntoTable(targetTable@LogicalRelation(targetRelation: RedshiftRelation, _, _, _), 32 | sourcePlan, 33 | mergeCondition, 34 | matchedActions, 35 | notMatchedActions, 36 | notMatchedBySourceActions) => 37 | Some(targetTable, 38 | targetRelation, 39 | sourcePlan, 40 | mergeCondition, 41 | matchedActions, 42 | notMatchedActions, 43 | notMatchedBySourceActions) 44 | case _ => None 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift_spark_3_4_/pushdown/querygeneration/MergeIntoTableExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 17 | 18 | import io.github.spark_redshift_community.spark.redshift.RedshiftRelation 19 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MergeAction, MergeIntoTable} 20 | import org.apache.spark.sql.catalyst.expressions.Expression 21 | import org.apache.spark.sql.execution.datasources.LogicalRelation 22 | 23 | object MergeIntoTableExtractor { 24 | def unapply(plan: LogicalPlan): Option[(LogicalPlan, 25 | RedshiftRelation, 26 | LogicalPlan, 27 | Expression, 28 | Seq[MergeAction], 29 | Seq[MergeAction], 30 | Seq[MergeAction])] = 31 | plan match { 32 | case MergeIntoTable(targetTable@LogicalRelation(targetRelation: RedshiftRelation, _, _, _), 33 | sourcePlan, 34 | mergeCondition, 35 | matchedActions, 36 | notMatchedActions, 37 | notMatchedBySourceActions) => 38 | Some(targetTable, 39 | targetRelation, 40 | sourcePlan, 41 | mergeCondition, 42 | matchedActions, 43 | notMatchedActions, 44 | notMatchedBySourceActions) 45 | case _ => None 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/QueryGroupIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.test 17 | 18 | import io.github.spark_redshift_community.spark.redshift.Parameters 19 | import io.github.spark_redshift_community.spark.redshift.data.JDBCWrapper 20 | import org.mockito.Mockito.verify 21 | import org.scalatestplus.mockito.MockitoSugar.mock 22 | import org.slf4j.Logger 23 | 24 | class QueryGroupIntegrationSuite extends IntegrationSuiteBase { 25 | test("getConnectorWithQueryGroup returns a working connection when setting query group fails") { 26 | // This test is only valid for JDBC-based connections 27 | if (redshiftWrapper.isInstanceOf[JDBCWrapper]) { 28 | val invalidQueryGroup = "'" 29 | val params: Map[String, String] = defaultOptions() + ("dbtable" -> "fake_table") 30 | val mergedParams = Parameters.mergeParameters(params) 31 | val conn = TestJdbcWrapper.getConnectorWithQueryGroup(mergedParams, invalidQueryGroup) 32 | verify(TestJdbcWrapper.getLogger).debug("Unable to set query group: " + 33 | "Unterminated string literal started at position 21 in SQL set query_group to '''. Expected char") 34 | try { 35 | val results = TestJdbcWrapper.executeQueryInterruptibly(conn, "select 1") 36 | assert(results.next()) 37 | assert(results.getInt(1) == 1) 38 | assert(!results.next()) 39 | } finally { 40 | conn.close() 41 | } 42 | } 43 | } 44 | } 45 | 46 | private object TestJdbcWrapper extends JDBCWrapper { 47 | override protected val log = mock[Logger] 48 | def getLogger: Logger = log 49 | } 50 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/ComparableVersion.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift 17 | 18 | import com.fasterxml.jackson.core.Version 19 | 20 | private[redshift] case class ComparableVersion(strVersion: String) { 21 | private val version: Version = parseVersion(strVersion) 22 | 23 | private def parseVersion(strVersion: String): Version = { 24 | assert(strVersion != null && strVersion.nonEmpty) 25 | val versionComponents = strVersion.split('.') 26 | new Version( 27 | if (versionComponents.length > 0) versionComponents(0).toInt else 0, 28 | if (versionComponents.length > 1) versionComponents(1).toInt else 0, 29 | if (versionComponents.length > 2) versionComponents(2).toInt else 0, 30 | null, null, null) 31 | } 32 | 33 | def lessThan(strOtherVersion: String): Boolean = { 34 | val otherVersion = parseVersion(strOtherVersion) 35 | version.compareTo(otherVersion) < 0 36 | } 37 | 38 | def lessThanOrEqualTo(strOtherVersion: String): Boolean = { 39 | val otherVersion = parseVersion(strOtherVersion) 40 | version.compareTo(otherVersion) <= 0 41 | } 42 | 43 | def greaterThan(strOtherVersion: String): Boolean = { 44 | val otherVersion = parseVersion(strOtherVersion) 45 | version.compareTo(otherVersion) > 0 46 | } 47 | 48 | def greaterThanOrEqualTo(strOtherVersion: String): Boolean = { 49 | val otherVersion = parseVersion(strOtherVersion) 50 | version.compareTo(otherVersion) >= 0 51 | } 52 | 53 | def equalTo(strOtherVersion: String): Boolean = { 54 | val otherVersion = parseVersion(strOtherVersion) 55 | version.compareTo(otherVersion) == 0 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/data/RedshiftConnection.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
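A short sketch of the ComparableVersion helper above; the version strings are placeholders, and because the class is package-private the snippet assumes code inside the connector's package:

val driverVersion = ComparableVersion("2.1.0")
assert(driverVersion.greaterThanOrEqualTo("2.1"))   // components beyond those given default to 0
assert(driverVersion.lessThan("3.0.0"))
assert(!driverVersion.equalTo("2.1.1"))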
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.data 17 | 18 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters 19 | 20 | import java.sql.Connection 21 | import scala.collection.mutable.ArrayBuffer 22 | 23 | private[redshift] abstract class RedshiftConnection { 24 | def getAutoCommit(): Boolean 25 | def setAutoCommit(autoCommit: Boolean): Unit 26 | def close(): Unit 27 | } 28 | 29 | private[redshift] case class JDBCConnection(conn: Connection) extends RedshiftConnection { 30 | override def getAutoCommit(): Boolean = { 31 | conn.getAutoCommit() 32 | } 33 | 34 | override def setAutoCommit(autoCommit: Boolean): Unit = { 35 | conn.setAutoCommit(autoCommit) 36 | } 37 | 38 | override def close(): Unit = { 39 | conn.close() 40 | } 41 | } 42 | 43 | private[redshift] case class DataAPIConnection(params: MergedParameters, 44 | queryGroup: Option[String] = None 45 | ) extends RedshiftConnection { 46 | val bufferedCommands: ArrayBuffer[String] = ArrayBuffer.empty 47 | var autoCommit: Boolean = true 48 | 49 | override def getAutoCommit(): Boolean = { 50 | autoCommit 51 | } 52 | 53 | override def setAutoCommit(autoCommit: Boolean): Unit = { 54 | this.autoCommit = autoCommit 55 | } 56 | 57 | override def close(): Unit = { 58 | // Reset in case someone tries to reuse this object. However, this object should 59 | // not be used after closing. We may want to enforce this at some point. 60 | bufferedCommands.clear() 61 | autoCommit = true 62 | } 63 | } -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/DateStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import io.github.spark_redshift_community.spark.redshift.pushdown.{ConstantString, RedshiftSQLStatement} 21 | import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, DateAdd, DateSub, Expression, TruncTimestamp} 22 | 23 | /** Extractor for date expressions (DateAdd, DateSub, AddMonths, TruncTimestamp). 
*/ 24 | private[querygeneration] object DateStatement { 25 | val REDSHIFT_DATEADD = "DATEADD" 26 | 27 | def unapply( 28 | expAttr: (Expression, Seq[Attribute]) 29 | ): Option[RedshiftSQLStatement] = { 30 | val expr = expAttr._1 31 | val fields = expAttr._2 32 | 33 | Option(expr match { 34 | case DateAdd(startDate, days) => 35 | ConstantString(REDSHIFT_DATEADD) + 36 | blockStatement( 37 | ConstantString("day,") + 38 | convertStatement(days, fields) + "," + 39 | convertStatement(startDate, fields) 40 | ) 41 | 42 | // it is pushdown by DATEADD with negative days 43 | case DateSub(startDate, days) => 44 | ConstantString(REDSHIFT_DATEADD) + 45 | blockStatement( 46 | ConstantString("day, (0 - (") + 47 | convertStatement(days, fields) + ") )," + 48 | convertStatement(startDate, fields) 49 | ) 50 | 51 | case _: AddMonths | _: TruncTimestamp => 52 | ConstantString(expr.prettyName.toUpperCase) + 53 | blockStatement(convertStatements(fields, expr.children: _*)) 54 | 55 | case _ => null 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/PostgresDriverIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.test 18 | 19 | import io.github.spark_redshift_community.spark.redshift.Parameters 20 | import io.github.spark_redshift_community.spark.redshift.data.{JDBCConnection, JDBCWrapper, RedshiftWrapperFactory} 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 23 | 24 | /** 25 | * Basic integration tests with the Postgres JDBC driver. 
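Paraphrasing the rewrites DateStatement above produces (illustrative, not captured connector output):

// DateAdd(startDate, days)   -> DATEADD ( day, <days> , <startDate> )
// DateSub(startDate, days)   -> DATEADD ( day, (0 - ( <days> ) ) , <startDate> )
// AddMonths, TruncTimestamp  -> <expr.prettyName.toUpperCase> ( <children...> )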
26 | */ 27 | class PostgresDriverIntegrationSuite extends IntegrationSuiteBase { 28 | override def jdbcUrl: String = { 29 | super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql") 30 | } 31 | 32 | // TODO (luca|issue #9) Fix tests when using postgresql driver 33 | ignore("postgresql driver takes precedence for jdbc:postgresql:// URIs") { 34 | // This test is only for JDBC-based credentials 35 | if (redshiftWrapper.isInstanceOf[JDBCWrapper]) { 36 | val params: Map[String, String] = defaultOptions() + 37 | ("dbtable" -> "fake_table") + ("url" -> jdbcUrl) 38 | val mergedParams = Parameters.mergeParameters(params) 39 | val conn = redshiftWrapper.getConnector(mergedParams) 40 | val jdbcConn = conn.asInstanceOf[JDBCConnection] 41 | try { 42 | assert(jdbcConn.conn.getClass.getName === "org.postgresql.jdbc4.Jdbc4Connection") 43 | } finally { 44 | conn.close() 45 | } 46 | } 47 | } 48 | 49 | ignore("roundtrip save and load") { 50 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), 51 | StructType(StructField("foo", IntegerType) :: Nil)) 52 | testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/DirectMapredOutputCommitter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | * not use this file except in compliance with the License. You may obtain 6 | * a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.test 18 | 19 | import org.apache.hadoop.fs.Path 20 | import org.apache.hadoop.mapred._ 21 | 22 | class DirectMapredOutputCommitter extends OutputCommitter { 23 | override def setupJob(jobContext: JobContext): Unit = { } 24 | 25 | override def setupTask(taskContext: TaskAttemptContext): Unit = { } 26 | 27 | override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { 28 | // We return true here to guard against implementations that do not handle false correctly. 29 | // The meaning of returning false is not entirely clear, so it's possible to be interpreted 30 | // as an error. Returning true just means that commitTask() will be called, which is a no-op. 31 | true 32 | } 33 | 34 | override def commitTask(taskContext: TaskAttemptContext): Unit = { } 35 | 36 | override def abortTask(taskContext: TaskAttemptContext): Unit = { } 37 | 38 | /** 39 | * Creates a _SUCCESS file to indicate the entire job was successful. 40 | * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. 
41 | */ 42 | override def commitJob(context: JobContext): Unit = { 43 | val conf = context.getJobConf 44 | if (shouldCreateSuccessFile(conf)) { 45 | val outputPath = FileOutputFormat.getOutputPath(conf) 46 | if (outputPath != null) { 47 | val fileSys = outputPath.getFileSystem(conf) 48 | val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) 49 | fileSys.create(filePath).close() 50 | } 51 | } 52 | } 53 | 54 | /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */ 55 | private def shouldCreateSuccessFile(conf: JobConf): Boolean = { 56 | conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * Copyright 2015 TouchType Ltd. (Added JDBC-based Data Source API implementation) 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark 19 | 20 | import org.apache.spark.sql.functions.col 21 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 22 | import org.apache.spark.sql.{DataFrame, Row, SQLContext} 23 | 24 | package object redshift { 25 | 26 | private val LEGACY_TIMESTAMP_HANDLING_KEY = 27 | "spark.datasource.redshift.community.legacy_timestamp_handling" 28 | lazy val legacyTimestampHandling = 29 | Utils.getSparkConfigValue(LEGACY_TIMESTAMP_HANDLING_KEY, "true").toBoolean 30 | 31 | /** 32 | * Wrapper of SQLContext that provides the `redshiftFile` method. 33 | */ 34 | implicit class RedshiftContext(sqlContext: SQLContext) { 35 | 36 | /** 37 | * Read a file unloaded from Redshift into a DataFrame. 38 | * @param path input path 39 | * @return a DataFrame with all string columns 40 | */ 41 | def redshiftFile(path: String, columns: Seq[String]): DataFrame = { 42 | val sc = sqlContext.sparkContext 43 | val rdd = sc.newAPIHadoopFile(path, classOf[RedshiftInputFormat], 44 | classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration) 45 | // TODO: allow setting NULL string. 46 | val nullable = rdd.values.map(_.map(f => if (f.isEmpty) null else f)).map(x => Row(x: _*)) 47 | val schema = StructType(columns.map(c => StructField(c, StringType, nullable = true))) 48 | sqlContext.createDataFrame(nullable, schema) 49 | } 50 | 51 | /** 52 | * Reads a table unloaded from Redshift with its schema.
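 * @param path input path
 * @param schema schema whose field names and types are used to cast the all-string columns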
53 | */ 54 | def redshiftFile(path: String, schema: StructType): DataFrame = { 55 | val casts = schema.fields.map { field => 56 | col(field.name).cast(field.dataType).as(field.name) 57 | } 58 | redshiftFile(path, schema.fieldNames).select(casts: _*) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/AggregateStatisticalOperatorsCorrectnessSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.test 17 | 18 | import org.apache.spark.sql.Row 19 | 20 | abstract class AggregateStatisticalOperatorsCorrectnessSuite 21 | extends AggregateStddevSampCorrectnessSuite 22 | with AggregateStddevPopCorrectnessSuite 23 | with AggregateVarSampCorrectnessSuite 24 | with AggregateVarPopCorrectnessSuite { 25 | 26 | override protected val preloaded_data: String = "true" 27 | override def setTestTableName(): String = """"PUBLIC"."all_shapes_dist_all_sort_compound_12col"""" 28 | 29 | } 30 | 31 | class TextAggregateStatisticalOperatorsCorrectnessSuite 32 | extends AggregateStatisticalOperatorsCorrectnessSuite { 33 | override protected val s3format: String = "TEXT" 34 | } 35 | 36 | class ParquetAggregateStatisticalOperatorsCorrectnessSuite 37 | extends AggregateStatisticalOperatorsCorrectnessSuite { 38 | override protected val s3format: String = "PARQUET" 39 | } 40 | 41 | class TextPushdownNoCacheAggregateStatisticalOperatorsCorrectnessSuite 42 | extends AggregateStatisticalOperatorsCorrectnessSuite { 43 | override protected val s3_result_cache = "false" 44 | } 45 | 46 | class ParquetPushdownNoCacheAggregateStatisticalOperatorsCorrectnessSuite 47 | extends AggregateStatisticalOperatorsCorrectnessSuite { 48 | override protected val s3_result_cache = "false" 49 | } 50 | 51 | class ParquetNoPushdownAggregateStatisticalOperatorsCorrectnessSuite 52 | extends AggregateStatisticalOperatorsCorrectnessSuite { 53 | override protected val s3format: String = "PARQUET" 54 | override protected val auto_pushdown: String = "false" 55 | } 56 | 57 | class TextNoPushdownAggregateStatisticalOperatorsCorrectnessSuite 58 | extends AggregateStatisticalOperatorsCorrectnessSuite { 59 | override protected val s3format: String = "TEXT" 60 | override protected val auto_pushdown: String = "false" 61 | } 62 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/RedshiftCredentialsInConfIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.test 19 | 20 | import io.github.spark_redshift_community.spark.redshift.data.JDBCWrapper 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField, StructType} 23 | 24 | /** 25 | * This suite performs basic integration tests where the Redshift credentials have been 26 | * specified via `spark-redshift`'s configuration rather than as part of the JDBC URL. 27 | */ 28 | class RedshiftCredentialsInConfIntegrationSuite extends IntegrationSuiteBase { 29 | 30 | test("roundtrip save and load") { 31 | // This test is only valid for JDBC-based connections 32 | if (redshiftWrapper.isInstanceOf[JDBCWrapper]) { 33 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), 34 | StructType(StructField("foo", IntegerType, true, 35 | new MetadataBuilder().putString("redshift_type", "int4").build()) :: Nil)) 36 | val tableName = s"roundtrip_save_and_load_$randomSuffix" 37 | try { 38 | write(df) 39 | .option("url", jdbcUrlNoUserPassword) 40 | .option("user", AWS_REDSHIFT_USER) 41 | .option("password", AWS_REDSHIFT_PASSWORD) 42 | .option("dbtable", tableName) 43 | .save() 44 | assert(redshiftWrapper.tableExists(conn, tableName)) 45 | val loadedDf = read 46 | .option("url", jdbcUrlNoUserPassword) 47 | .option("user", AWS_REDSHIFT_USER) 48 | .option("password", AWS_REDSHIFT_PASSWORD) 49 | .option("dbtable", tableName) 50 | .load() 51 | assert(loadedDf.schema === df.schema) 52 | checkAnswer(loadedDf, df.collect()) 53 | } finally { 54 | redshiftWrapper.executeUpdate(conn, s"drop table if exists $tableName") 55 | } 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_catalog_returns.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | CREATE TABLE "PUBLIC"."catalog_returns"( 19 | cr_returned_time_sk int , 20 | cr_item_sk int , 21 | cr_refunded_customer_sk int , 22 | cr_refunded_cdemo_sk int , 23 | cr_refunded_hdemo_sk int , 24 | cr_refunded_addr_sk int , 25 | cr_returning_customer_sk int , 26 | cr_returning_cdemo_sk int , 27 | cr_returning_hdemo_sk int , 28 | cr_returning_addr_sk int , 29 | cr_call_center_sk int , 30 | cr_catalog_page_sk int , 31 | cr_ship_mode_sk int , 32 | cr_warehouse_sk int , 33 | cr_reason_sk int , 34 | cr_order_number bigint , 35 | cr_return_quantity int , 36 | cr_return_amount decimal(7,2) , 37 | cr_return_tax decimal(7,2) , 38 | cr_return_amt_inc_tax decimal(7,2) , 39 | cr_fee decimal(7,2) , 40 | cr_return_ship_cost decimal(7,2) , 41 | cr_refunded_cash decimal(7,2) , 42 | cr_reversed_charge decimal(7,2) , 43 | cr_store_credit decimal(7,2) , 44 | cr_net_loss decimal(7,2) , 45 | cr_returned_date_sk int 46 | ) 47 | -- WITH (location='${data_path}${experiment_start_time}/${repetition}/catalog_returns/', ${partition_spec_keyword}=ARRAY['cr_returned_date_sk'] ${tblproperties_suffix}); -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/DirectMapreduceOutputCommitter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | * not use this file except in compliance with the License. You may obtain 6 | * a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.test 18 | 19 | import org.apache.hadoop.conf.Configuration 20 | import org.apache.hadoop.fs.Path 21 | import org.apache.hadoop.mapreduce._ 22 | import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} 23 | 24 | class DirectMapreduceOutputCommitter extends OutputCommitter { 25 | override def setupJob(jobContext: JobContext): Unit = { } 26 | 27 | override def setupTask(taskContext: TaskAttemptContext): Unit = { } 28 | 29 | override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { 30 | // We return true here to guard against implementations that do not handle false correctly. 31 | // The meaning of returning false is not entirely clear, so it's possible to be interpreted 32 | // as an error. Returning true just means that commitTask() will be called, which is a no-op. 33 | true 34 | } 35 | 36 | override def commitTask(taskContext: TaskAttemptContext): Unit = { } 37 | 38 | override def abortTask(taskContext: TaskAttemptContext): Unit = { } 39 | 40 | /** 41 | * Creates a _SUCCESS file to indicate the entire job was successful. 42 | * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. 
43 | */ 44 | override def commitJob(context: JobContext): Unit = { 45 | val conf = context.getConfiguration 46 | if (shouldCreateSuccessFile(conf)) { 47 | val outputPath = FileOutputFormat.getOutputPath(context) 48 | if (outputPath != null) { 49 | val fileSys = outputPath.getFileSystem(conf) 50 | val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) 51 | fileSys.create(filePath).close() 52 | } 53 | } 54 | } 55 | 56 | /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */ 57 | private def shouldCreateSuccessFile(conf: Configuration): Boolean = { 58 | conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_date_dim.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | CREATE TABLE "PUBLIC"."date_dim"( 19 | d_date_sk int , 20 | d_date_id varchar(16) , 21 | d_date date , 22 | d_month_seq int , 23 | d_week_seq int , 24 | d_quarter_seq int , 25 | d_year int , 26 | d_dow int , 27 | d_moy int , 28 | d_dom int , 29 | d_qoy int , 30 | d_fy_year int , 31 | d_fy_quarter_seq int , 32 | d_fy_week_seq int , 33 | d_day_name varchar(9) , 34 | d_quarter_name varchar(6) , 35 | d_holiday varchar(1) , 36 | d_weekend varchar(1) , 37 | d_following_holiday varchar(1) , 38 | d_first_dom int , 39 | d_last_dom int , 40 | d_same_day_ly int , 41 | d_same_day_lq int , 42 | d_current_day varchar(1) , 43 | d_current_week varchar(1) , 44 | d_current_month varchar(1) , 45 | d_current_quarter varchar(1) , 46 | d_current_year varchar(1) 47 | ) -- WITH (location='${data_path}${experiment_start_time}/${repetition}/date_dim/' ${tblproperties_suffix}); -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/RecordReaderIterator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift 19 | 20 | import java.io.Closeable 21 | 22 | import org.apache.hadoop.mapreduce.RecordReader 23 | 24 | /** 25 | * An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned. 26 | * 27 | * This is copied from Apache Spark and is inlined here to avoid depending on Spark internals 28 | * in this external library. 29 | */ 30 | private[redshift] class RecordReaderIterator[T]( 31 | private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { 32 | private[this] var havePair = false 33 | private[this] var finished = false 34 | 35 | override def hasNext: Boolean = { 36 | if (!finished && !havePair) { 37 | finished = !rowReader.nextKeyValue 38 | if (finished) { 39 | // Close and release the reader here; close() will also be called when the task 40 | // completes, but for tasks that read from many files, it helps to release the 41 | // resources early. 42 | close() 43 | } 44 | havePair = !finished 45 | } 46 | !finished 47 | } 48 | 49 | override def next(): T = { 50 | if (!hasNext) { 51 | throw new java.util.NoSuchElementException("End of stream") 52 | } 53 | havePair = false 54 | rowReader.getCurrentValue 55 | } 56 | 57 | override def close(): Unit = { 58 | if (rowReader != null) { 59 | try { 60 | // Close the reader and release it. Note: it's very important that we don't close the 61 | // reader more than once, since that exposes us to MAPREDUCE-5918 when running against 62 | // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues 63 | // when reading compressed input. 64 | rowReader.close() 65 | } finally { 66 | rowReader = null 67 | } 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/CrossRegionIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.test 19 | 20 | import io.github.spark_redshift_community.spark.redshift.Parameters.PARAM_TEMPDIR_REGION 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 23 | 24 | /** 25 | * Integration tests where the Redshift cluster and the S3 bucket are in different AWS regions. 
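 * The bucket's region is supplied explicitly via the PARAM_TEMPDIR_REGION option in the
 * write test below.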
26 | */ 27 | class CrossRegionIntegrationSuite extends IntegrationSuiteBase { 28 | 29 | protected val AWS_S3_CROSS_REGION_SCRATCH_SPACE: String = 30 | loadConfigFromEnv("AWS_S3_CROSS_REGION_SCRATCH_SPACE") 31 | protected val AWS_S3_CROSS_REGION_SCRATCH_SPACE_REGION: String = 32 | loadConfigFromEnv("AWS_S3_CROSS_REGION_SCRATCH_SPACE_REGION") 33 | require(AWS_S3_CROSS_REGION_SCRATCH_SPACE.contains("s3a"), "must use s3a:// URL") 34 | 35 | override protected val tempDir: String = AWS_S3_CROSS_REGION_SCRATCH_SPACE + randomSuffix + "/" 36 | 37 | test("write") { 38 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), 39 | StructType(StructField("foo", IntegerType) :: Nil)) 40 | val tableName = s"roundtrip_save_and_load_$randomSuffix" 41 | try { 42 | write(df) 43 | .option("dbtable", tableName) 44 | .option(PARAM_TEMPDIR_REGION, AWS_S3_CROSS_REGION_SCRATCH_SPACE_REGION) 45 | .save() 46 | // Check that the table exists. It appears that creating a table in one connection then 47 | // immediately querying for existence from another connection may result in spurious "table 48 | // doesn't exist" errors; this caused the "save with all empty partitions" test to become 49 | // flaky (see #146). To work around this, add a small sleep and check again: 50 | if (!redshiftWrapper.tableExists(conn, tableName)) { 51 | Thread.sleep(1000) 52 | assert(redshiftWrapper.tableExists(conn, tableName)) 53 | } 54 | } finally { 55 | redshiftWrapper.executeUpdate(conn, s"drop table if exists $tableName") 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/RedshiftPlan.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed to the Apache Software Foundation (ASF) under one or more 6 | * contributor license agreements. See the NOTICE file distributed with 7 | * this work for additional information regarding copyright ownership. 8 | * The ASF licenses this file to You under the Apache License, Version 2.0 9 | * (the "License"); you may not use this file except in compliance with 10 | * the License. You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | package io.github.spark_redshift_community.spark.redshift.pushdown 22 | 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.spark.sql.catalyst.InternalRow 25 | import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} 26 | import org.apache.spark.sql.execution.SparkPlan 27 | import org.apache.spark.sql.types.{StructField, StructType} 28 | 29 | /** RedshiftPlan, with RDD defined by custom query. 
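 * Wraps an already-computed RDD so that it can be executed as a leaf node of the Spark plan.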
*/ 30 | case class RedshiftPlan(output: Seq[Attribute], rdd: RDD[InternalRow]) 31 | extends SparkPlan { 32 | 33 | override def children: Seq[SparkPlan] = Nil 34 | protected override def doExecute(): RDD[InternalRow] = { 35 | rdd 36 | } 37 | 38 | override def simpleString(maxFields: Int): String = { 39 | super.simpleString(maxFields) + " " + output.mkString("[", ",", "]") 40 | } 41 | 42 | override def simpleStringWithNodeId(): String = { 43 | super.simpleStringWithNodeId() + " " + output.mkString("[", ",", "]") 44 | } 45 | 46 | // withNewChildrenInternal() is a new interface function from spark 3.2 in 47 | // org.apache.spark.sql.catalyst.trees.TreeNode. For details refer to 48 | // https://github.com/apache/spark/pull/32030 49 | // As for spark connector the RedshiftPlan is a leaf Node, we don't expect 50 | // caller to set any new children for it. 51 | // RedshiftPlan is only used for spark connector PushDown. Even if the Exception is 52 | // raised, the PushDown will not be used and it still works. 53 | override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = { 54 | if (newChildren.nonEmpty) { 55 | throw new Exception("Spark connector internal error: " + 56 | "RedshiftPlan.withNewChildrenInternal() is called to set some children nodes.") 57 | } 58 | this 59 | } 60 | } -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/UnsupportedStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import io.github.spark_redshift_community.spark.redshift.{RedshiftFailMessage, RedshiftPushdownUnsupportedException} 21 | import io.github.spark_redshift_community.spark.redshift.pushdown.RedshiftSQLStatement 22 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF, ScalaUDF} 23 | import org.apache.spark.sql.execution.aggregate.ScalaUDAF 24 | 25 | /** 26 | * This class is used to catch unsupported statement and raise an exception 27 | * to stop the push-down to Redshift. 28 | */ 29 | private[querygeneration] object UnsupportedStatement { 30 | /** Used mainly by QueryGeneration.convertStatement. This matches 31 | * a tuple of (Expression, Seq[Attribute]) representing the expression to 32 | * be matched and the fields that define the valid fields in the current expression 33 | * scope, respectively. 34 | * 35 | * @param expAttr A pair-tuple representing the expression to be matched and the 36 | * attribute fields. 37 | * @return An option containing the translated SQL, if there is a match, or None if there 38 | * is no match. 
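 * Note: for UnsupportedStatement this never returns normally; it always throws
 * RedshiftPushdownUnsupportedException, which QueryBuilder.treeRoot catches so that the
 * plan falls back to execution without pushdown.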
39 | */ 40 | def unapply( 41 | expAttr: (Expression, Seq[Attribute]) 42 | ): Option[RedshiftSQLStatement] = { 43 | val expr = expAttr._1 44 | 45 | // This exception is not a real issue. It will be caught in 46 | // QueryBuilder.treeRoot. 47 | throw new RedshiftPushdownUnsupportedException( 48 | RedshiftFailMessage.FAIL_PUSHDOWN_STATEMENT, 49 | expr.prettyName, 50 | expr.sql, 51 | isKnownUnsupportedOperation(expr)) 52 | } 53 | 54 | // Determine whether the unsupported operation is known or not. 55 | private def isKnownUnsupportedOperation(expr: Expression): Boolean = { 56 | // The pushdown for UDF is known unsupported 57 | (expr.isInstanceOf[PythonUDF] 58 | || expr.isInstanceOf[ScalaUDF] 59 | || expr.isInstanceOf[ScalaUDAF]) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/PushdownRedshiftReadSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.test 17 | 18 | import io.github.spark_redshift_community.spark.redshift.pushdown.RedshiftScanExec 19 | import io.github.spark_redshift_community.spark.redshift.data.JDBCWrapper 20 | import io.github.spark_redshift_community.spark.redshift.test.{IntegrationSuiteBase, OverrideNullableSuite} 21 | import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec 22 | 23 | class PushdownRedshiftReadSuite extends IntegrationSuiteBase with OverrideNullableSuite { 24 | override val auto_pushdown: String = "true" 25 | 26 | test("pushdowns across multiple clusters are executed separately") { 27 | // This method only works for JDBC. 28 | if (redshiftWrapper.isInstanceOf[JDBCWrapper]) { 29 | // A single pushdown operation cannot query data from multiple clusters 30 | // this verifies that separate scans are generated for each cluster 31 | val expectedUrl1 = jdbcUrl + "&ApplicationName=1" 32 | val expectedUrl2 = jdbcUrl + "&ApplicationName=2" 33 | 34 | withTempRedshiftTable("testTable") { name => 35 | redshiftWrapper.executeUpdate(conn, s"create table $name (id integer)") 36 | read 37 | .option("url", expectedUrl1) 38 | .option("dbtable", name) 39 | .load().createOrReplaceTempView("view1") 40 | read 41 | .option("url", expectedUrl2) 42 | .option("dbtable", name) 43 | .load().createOrReplaceTempView("view2") 44 | 45 | val plan = sqlContext.sql("select count(*) from view1 union select count(*) from view2"). 
46 | queryExecution.executedPlan 47 | 48 | val traversablePlan = plan match { 49 | case p: AdaptiveSparkPlanExec => p.executedPlan 50 | case _ => plan 51 | } 52 | 53 | assert(traversablePlan.exists { 54 | case RedshiftScanExec(_, _, relation) => relation.params.jdbcUrl.get == expectedUrl1 55 | case _ => false 56 | }) 57 | assert(traversablePlan.exists { 58 | case RedshiftScanExec(_, _, relation) => relation.params.jdbcUrl.get == expectedUrl2 59 | case _ => false 60 | }) 61 | } 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/test/SeekableByteArrayInputStream.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /* 19 | SeekableByteArrayInputStream copied from 20 | https://github.com/apache/accumulo/blob/master/core/src/test/java/org/apache/accumulo/core/file/rfile/RFileTest.java 21 | */ 22 | 23 | package io.github.spark_redshift_community.spark.redshift.test; 24 | 25 | import org.apache.hadoop.fs.PositionedReadable; 26 | import org.apache.hadoop.fs.Seekable; 27 | 28 | import java.io.ByteArrayInputStream; 29 | import java.io.IOException; 30 | 31 | 32 | class SeekableByteArrayInputStream extends ByteArrayInputStream 33 | implements Seekable, PositionedReadable { 34 | 35 | public SeekableByteArrayInputStream(byte[] buf) { 36 | super(buf); 37 | } 38 | 39 | @Override 40 | public long getPos() { 41 | return pos; 42 | } 43 | 44 | @Override 45 | public void seek(long pos) throws IOException { 46 | if (mark != 0) 47 | throw new IllegalStateException(); 48 | 49 | reset(); 50 | long skipped = skip(pos); 51 | 52 | if (skipped != pos) 53 | throw new IOException(); 54 | } 55 | 56 | @Override 57 | public boolean seekToNewSource(long targetPos) { 58 | return false; 59 | } 60 | 61 | @Override 62 | public int read(long position, byte[] buffer, int offset, int length) { 63 | 64 | if (position >= buf.length) 65 | throw new IllegalArgumentException(); 66 | if (position + length > buf.length) 67 | throw new IllegalArgumentException(); 68 | if (length > buffer.length) 69 | throw new IllegalArgumentException(); 70 | 71 | System.arraycopy(buf, (int) position, buffer, offset, length); 72 | return length; 73 | } 74 | 75 | @Override 76 | public void readFully(long position, byte[] buffer) { 77 | read(position, buffer, 0, buffer.length); 78 | 79 | } 80 | 81 | @Override 82 | public void readFully(long position, byte[] buffer, int offset, int length) { 83 | read(position, buffer, offset, length); 84 | } 85 | 86 | } 87 | 88 | -------------------------------------------------------------------------------- 
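A minimal sketch of how the stream above can be used: FSDataInputStream only accepts wrapped streams that implement Seekable and PositionedReadable, so an in-memory buffer can stand in for an HDFS file in unit tests. The object name and sample bytes below are hypothetical, and the snippet assumes it lives in the same test package as the class above.

package io.github.spark_redshift_community.spark.redshift.test

import java.nio.charset.StandardCharsets
import org.apache.hadoop.fs.FSDataInputStream

object SeekableByteArrayInputStreamExample {
  def main(args: Array[String]): Unit = {
    // FSDataInputStream requires its wrapped stream to implement Seekable and
    // PositionedReadable, which SeekableByteArrayInputStream provides.
    val bytes = "1|foo\n2|bar\n".getBytes(StandardCharsets.UTF_8)
    val in = new FSDataInputStream(new SeekableByteArrayInputStream(bytes))
    in.seek(6) // jump to the start of the second record
    val rest = new Array[Byte](bytes.length - 6)
    in.readFully(6, rest) // positioned read from offset 6
    println(new String(rest, StandardCharsets.UTF_8)) // prints the second record
    in.close()
  }
}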
/src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/StringIntegrationPushdownSuiteBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.test 17 | 18 | class StringIntegrationPushdownSuiteBase extends IntegrationPushdownSuiteBase { 19 | override def createTestDataInRedshift(tableName: String): Unit = { 20 | redshiftWrapper.executeUpdate(conn, 21 | s""" 22 | |create table $tableName ( 23 | |testid int, 24 | |testbyte int2, 25 | |testbool boolean, 26 | |testdate date, 27 | |testdouble float8, 28 | |testfloat float4, 29 | |testint int4, 30 | |testlong int8, 31 | |testshort int2, 32 | |teststring varchar(256), 33 | |testfixedstring char(256), 34 | |testvarstring varchar(256), 35 | |testtimestamp timestamp 36 | |) 37 | """.stripMargin 38 | ) 39 | // scalastyle:off 40 | redshiftWrapper.executeUpdate(conn, 41 | s""" 42 | |insert into $tableName values 43 | |(0, null, null, null, null, null, null, null, null, null, null, null, null), 44 | |(1, 0, null, '2015-07-03', 0.0, -1.0, 4141214, 1239012341823719, null, 'f', 'Hello World', 'Hello World', '2015-07-03 00:00:00.000'), 45 | |(2, 0, false, null, -1234152.12312498, 100000.0, null, 1239012341823719, 24, '___|_123', 'Controls\t \b\n\r\f\\\\''\"', 'Controls\t \b\n\r\f\\\\''\"', null), 46 | |(3, 1, false, '2015-07-02', 0.0, 0.0, 42, 1239012341823719, -13, 'asdf', 'Specials/%', 'Specials/%', '2015-07-02 00:00:00.000'), 47 | |(4, 1, true, '2015-07-01', 1234152.12312498, 1.0, 42, 1239012341823719, 23, 'Unicode''s樂趣', 'Singl_Byte_Chars', 'Multi樂Byte趣Chars', '2015-07-01 00:00:00.001'), 48 | |(5, null, null, null, null, null, null, null, null, null, '', '', null), 49 | |(6, null, null, null, null, null, null, null, null, null, ' Hello World ', ' Hello World ', null), 50 | |(7, null, null, null, null, null, null, null, null, null, ' \t\b\nFoo\r\f\\\\''\" ', ' \t\b\nFoo\r\f\\\\''\" ', null), 51 | |(8, null, null, null, null, null, null, null, null, null, ' /%Foo%/ ', ' /%Foo%/ ', null), 52 | |(9, null, null, null, null, null, null, null, null, null, ' _Single_ ', ' 樂Multi趣 ', null) 53 | """.stripMargin 54 | ) 55 | // scalastyle:on 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/IAMIntegrationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.test 18 | 19 | import java.sql.SQLException 20 | 21 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 22 | import org.apache.spark.sql.{Row, SaveMode} 23 | 24 | /** 25 | * Integration tests for configuring Redshift to access S3 using Amazon IAM roles. 26 | */ 27 | class IAMIntegrationSuite extends IntegrationSuiteBase { 28 | 29 | private val IAM_ROLE_ARN: String = loadConfigFromEnv("STS_ROLE_ARN") 30 | 31 | // TODO (luca|issue #8) Fix IAM Authentication tests 32 | ignore("roundtrip save and load") { 33 | val tableName = s"iam_roundtrip_save_and_load$randomSuffix" 34 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 35 | StructType(StructField("a", IntegerType) :: Nil)) 36 | try { 37 | write(df) 38 | .option("dbtable", tableName) 39 | .option("forward_spark_s3_credentials", "false") 40 | .option("aws_iam_role", IAM_ROLE_ARN) 41 | .mode(SaveMode.ErrorIfExists) 42 | .save() 43 | 44 | assert(redshiftWrapper.tableExists(conn, tableName)) 45 | val loadedDf = read 46 | .option("dbtable", tableName) 47 | .option("forward_spark_s3_credentials", "false") 48 | .option("aws_iam_role", IAM_ROLE_ARN) 49 | .load() 50 | assert(loadedDf.schema.length === 1) 51 | assert(loadedDf.columns === Seq("a")) 52 | checkAnswer(loadedDf, Seq(Row(1))) 53 | } finally { 54 | redshiftWrapper.executeUpdate(conn, s"drop table if exists $tableName") 55 | } 56 | } 57 | 58 | ignore("load fails if IAM role cannot be assumed") { 59 | val tableName = s"iam_load_fails_if_role_cannot_be_assumed$randomSuffix" 60 | try { 61 | val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), 62 | StructType(StructField("a", IntegerType) :: Nil)) 63 | val err = intercept[SQLException] { 64 | write(df) 65 | .option("dbtable", tableName) 66 | .option("forward_spark_s3_credentials", "false") 67 | .option("aws_iam_role", IAM_ROLE_ARN + "-some-bogus-suffix") 68 | .mode(SaveMode.ErrorIfExists) 69 | .save() 70 | } 71 | assert(err.getCause.getMessage.contains("is not authorized to assume IAM Role")) 72 | } finally { 73 | redshiftWrapper.executeUpdate(conn, s"drop table if exists $tableName") 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/it/resources/lst/1_create_catalog_sales.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | CREATE TABLE "PUBLIC"."catalog_sales"( 19 | cs_sold_time_sk int , 20 | cs_ship_date_sk int , 21 | cs_bill_customer_sk int , 22 | cs_bill_cdemo_sk int , 23 | cs_bill_hdemo_sk int , 24 | cs_bill_addr_sk int , 25 | cs_ship_customer_sk int , 26 | cs_ship_cdemo_sk int , 27 | cs_ship_hdemo_sk int , 28 | cs_ship_addr_sk int , 29 | cs_call_center_sk int , 30 | cs_catalog_page_sk int , 31 | cs_ship_mode_sk int , 32 | cs_warehouse_sk int , 33 | cs_item_sk int , 34 | cs_promo_sk int , 35 | cs_order_number bigint , 36 | cs_quantity int , 37 | cs_wholesale_cost decimal(7,2) , 38 | cs_list_price decimal(7,2) , 39 | cs_sales_price decimal(7,2) , 40 | cs_ext_discount_amt decimal(7,2) , 41 | cs_ext_sales_price decimal(7,2) , 42 | cs_ext_wholesale_cost decimal(7,2) , 43 | cs_ext_list_price decimal(7,2) , 44 | cs_ext_tax decimal(7,2) , 45 | cs_coupon_amt decimal(7,2) , 46 | cs_ext_ship_cost decimal(7,2) , 47 | cs_net_paid decimal(7,2) , 48 | cs_net_paid_inc_tax decimal(7,2) , 49 | cs_net_paid_inc_ship decimal(7,2) , 50 | cs_net_paid_inc_ship_tax decimal(7,2) , 51 | cs_net_profit decimal(7,2) , 52 | cs_sold_date_sk int 53 | ) -- WITH (location='${data_path}${experiment_start_time}/${repetition}/catalog_sales/', ${partition_spec_keyword}=ARRAY['cs_sold_date_sk'] ${tblproperties_suffix}); -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/deoptimize/UndoCharTypePadding.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.deoptimize 17 | 18 | import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, IsNotNull, Literal} 19 | import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke 20 | import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} 21 | import org.apache.spark.sql.catalyst.rules.Rule 22 | import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, CharVarcharUtils} 23 | import org.apache.spark.sql.types.{CharType, IntegerType, StringType} 24 | 25 | object UndoCharTypePadding extends Rule[LogicalPlan] { 26 | 27 | // Remove padding as the SQL that will be generated shouldn't include padding which is an internal 28 | // detail of the engine. Redshift should handle this internally. 
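  // For example, for a CHAR(10) column c, Spark's read-side padding wraps the attribute as
  //   Alias(StaticInvoke(classOf[CharVarcharCodegenUtils], StringType, "readSidePadding", Seq(c, Literal(10))), "c")
  // and this rule rewrites it back to Alias(c, "c") (and IsNotNull(c) inside filters), so the
  // generated Redshift SQL references the raw column. (Illustrative shape; see ReadSidePadding below.)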
29 | override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithSubqueries { 30 | // Scope down to project alias cases as observed in TPC-DS queries 31 | case project @ Project(projectList, child) => 32 | var modified = false 33 | val newProjectList = projectList.map { 34 | case alias @ Alias(ReadSidePadding(ref), _) => 35 | modified = true 36 | alias.withNewChildren(ref :: Nil).asInstanceOf[Alias] 37 | case other => 38 | other 39 | } 40 | if (modified) { 41 | Project(newProjectList, child) 42 | } else { 43 | project 44 | } 45 | 46 | // Scope down to IsNotNull Filter cases as observed in TPC-DS queries 47 | case filter @ Filter(condition, child) => 48 | val newCondition = condition.transform { 49 | case IsNotNull(ReadSidePadding(ref)) => 50 | IsNotNull(ref) 51 | } 52 | if (condition eq newCondition) { 53 | filter 54 | } else { 55 | filter.copy(newCondition, child) 56 | } 57 | } 58 | } 59 | 60 | object ReadSidePadding { 61 | def unapply(s: StaticInvoke): Option[Expression] = s match { 62 | case StaticInvoke(clazz, StringType, "readSidePadding", ref +: Literal(length, IntegerType) 63 | +: Nil, _, _, _, _) 64 | if ref.isInstanceOf[AttributeReference] && 65 | clazz == classOf[CharVarcharCodegenUtils] && 66 | length.isInstanceOf[Int] => 67 | val metadata = ref.asInstanceOf[AttributeReference].metadata 68 | val optionalDataType = CharVarcharUtils.getRawType(metadata) 69 | optionalDataType.filter { dataType => 70 | dataType == CharType(length.asInstanceOf[Int]) 71 | }.map(_ => ref) 72 | case _ => None 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/lst/LSTIntegrationPushdownSuiteBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package io.github.spark_redshift_community.spark.redshift.pushdown.lst.test 18 | 19 | import software.amazon.awssdk.utils.IoUtils 20 | import io.github.spark_redshift_community.spark.redshift.pushdown.test.IntegrationPushdownSuiteBase 21 | 22 | class LSTIntegrationPushdownSuiteBase extends IntegrationPushdownSuiteBase{ 23 | 24 | // a list of all the tables used for the LST dataset testing 25 | protected val tableNames: List[String] = List( 26 | "catalog_returns", 27 | "catalog_sales", 28 | "date_dim", 29 | "inventory", 30 | "store_returns", 31 | "web_returns", 32 | "web_sales" 33 | ) 34 | 35 | // drops the tables necessary for running the TPC-DS correctness suite 36 | def tableCleanUpHelper(stmt: String): Unit = { 37 | for ( tpcds_table <- tableNames) { 38 | redshiftWrapper.executeUpdate(conn, s"${stmt} $tpcds_table") 39 | redshiftWrapper.executeUpdate(conn, s"${stmt} ${tpcds_table}_copy") 40 | } 41 | } 42 | 43 | // creates and populates the tables necessary for running the TPC-DS correctness suite 44 | def tableSetUpHelper(filename_prefix: String): Unit = { 45 | 46 | // for each of the defined tables, we want to run both the create and load SQL 47 | for ( lst_table <- tableNames) { 48 | val create_stmt = IoUtils.toUtf8String( 49 | getClass().getClassLoader().getResourceAsStream( 50 | s"lst/${filename_prefix}_${lst_table}.sql") 51 | ) 52 | 53 | val create_stmt_copy = s"CREATE TABLE IF NOT EXISTS ${lst_table}_copy (LIKE ${lst_table});" 54 | 55 | redshiftWrapper.executeUpdate(conn, create_stmt) 56 | redshiftWrapper.executeUpdate(conn, create_stmt_copy) 57 | } 58 | } 59 | 60 | // drops any dangling tables from previous LST test runs, then re-creates them 61 | override def beforeAll(): Unit = { 62 | super.beforeAll() 63 | tableCleanUpHelper("drop table if exists") 64 | tableSetUpHelper("1_create") 65 | } 66 | 67 | // drops any dangling tables from previous LST test runs 68 | override def afterAll(): Unit = { 69 | try { 70 | tableCleanUpHelper("drop table if exists") 71 | } finally { 72 | super.afterAll() 73 | } 74 | } 75 | 76 | // truncates existing tables from the previous LST test case, then re-loads them 77 | override def beforeEach(): Unit = { 78 | super.beforeEach() 79 | try { 80 | tableCleanUpHelper("truncate")} catch { 81 | case _ : Exception => tableSetUpHelper("1_create") 82 | } 83 | tableSetUpHelper("2_load") 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/it/resources/lst/2_load_store_returns.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | INSERT INTO store_returns ( 19 | sr_returned_date_sk, 20 | sr_return_time_sk, 21 | sr_item_sk, 22 | sr_customer_sk, 23 | sr_cdemo_sk, 24 | sr_hdemo_sk, 25 | sr_addr_sk, 26 | sr_store_sk, 27 | sr_reason_sk, 28 | sr_ticket_number, 29 | sr_return_quantity, 30 | sr_return_amt, 31 | sr_return_tax, 32 | sr_return_amt_inc_tax, 33 | sr_fee, 34 | sr_return_ship_cost, 35 | sr_refunded_cash, 36 | sr_reversed_charge, 37 | sr_store_credit, 38 | sr_net_loss 39 | ) 40 | VALUES 41 | (1, 2, 3, 4, 5, 6, 7, 8, 9, 123456789, 1, 10.00, 1.00, 11.00, 0.50, 2.00, 8.00, 1.00, 0.50, 0.50), 42 | (2, 3, 4, 5, 6, 7, 8, 9, 10, 234567890, 2, 20.00, 2.00, 22.00, 1.00, 3.00, 16.00, 2.00, 1.00, 1.00), 43 | (3, 4, 5, 6, 7, 8, 9, 10, 11, 345678901, 3, 30.00, 3.00, 33.00, 1.50, 4.00, 24.00, 3.00, 1.50, 1.50), 44 | (4, 5, 6, 7, 8, 9, 10, 11, 12, 456789012, 4, 40.00, 4.00, 44.00, 2.00, 5.00, 32.00, 4.00, 2.00, 2.00), 45 | (5, 6, 7, 8, 9, 10, 11, 12, 13, 567890123, 5, 50.00, 5.00, 55.00, 2.50, 6.00, 40.00, 5.00, 2.50, 2.50), 46 | (6, 7, 8, 9, 10, 11, 12, 13, 14, 678901234, 6, 60.00, 6.00, 66.00, 3.00, 7.00, 48.00, 6.00, 3.00, 3.00), 47 | (7, 8, 9, 10, 11, 12, 13, 14, 15, 789012345, 7, 70.00, 7.00, 77.00, 3.50, 8.00, 56.00, 7.00, 3.50, 3.50), 48 | (8, 9, 10, 11, 12, 13, 14, 15, 16, 890123456, 8, 80.00, 8.00, 88.00, 4.00, 9.00, 64.00, 8.00, 4.00, 4.00), 49 | (9, 10, 11, 12, 13, 14, 15, 16, 17, 901234567, 9, 90.00, 9.00, 99.00, 4.50, 10.00, 72.00, 9.00, 4.50, 4.50), 50 | (10, 11, 12, 13, 14, 15, 16, 17, 18, 1012345678, 10, 100.00, 10.00, 110.00, 5.00, 11.00, 80.00, 10.00, 5.00, 5.00), 51 | (11, 12, 13, 14, 15, 16, 17, 18, 19, 1123456789, 11, 110.00, 11.00, 121.00, 5.50, 12.00, 88.00, 11.00, 5.50, 5.50), 52 | (12, 13, 14, 15, 16, 17, 18, 19, 20, 1234567890, 12, 120.00, 12.00, 132.00, 6.00, 13.00, 96.00, 12.00, 6.00, 6.00), 53 | (13, 14, 15, 16, 17, 18, 19, 20, 21, 1345678901, 13, 130.00, 13.00, 143.00, 6.50, 14.00, 104.00, 13.00, 6.50, 6.50), 54 | (14, 15, 16, 17, 18, 19, 20, 21, 22, 1456789012, 14, 140.00, 14.00, 154.00, 7.00, 15.00, 112.00, 14.00, 7.00, 7.00), 55 | (15, 16, 17, 18, 19, 20, 21, 22, 23, 1567890123, 15, 150.00, 15.00, 165.00, 7.50, 16.00, 120.00, 15.00, 7.50, 7.50), 56 | (16, 17, 18, 19, 20, 21, 22, 23, 24, 1678901234, 16, 160.00, 16.00, 176.00, 8.00, 17.00, 128.00, 16.00, 8.00, 8.00), 57 | (17, 18, 19, 20, 21, 22, 23, 24, 25, 1789012345, 17, 170.00, 17.00, 187.00, 8.50, 18.00, 136.00, 17.00, 8.50, 8.50), 58 | (18, 19, 20, 21, 22, 23, 24, 25, 26, 1890123456, 18, 180.00, 18.00, 198.00, 9.00, 19.00, 144.00, 18.00, 9.00, 9.00), 59 | (19, 20, 21, 22, 23, 24, 25, 26, 27, 1991234567, 19, 190.00, 19.00, 209.00, 9.50, 20.00, 152.00, 19.00, 9.50, 9.50), 60 | (20, 21, 22, 23, 24, 25, 26, 27, 28, 2092345678, 20, 200.00, 20.00, 220.00, 10.00, 21.00, 160.00, 20.00, 10.00, 10.00) 61 | ; -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/StringSelectCorrectnessSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.test 17 | 18 | import org.apache.spark.sql.Row 19 | 20 | abstract class StringSelectCorrectnessSuite extends StringIntegrationPushdownSuiteBase { 21 | 22 | test("Select StringType Column") { 23 | // (id, column, result) 24 | val paramTuples = List( 25 | (0, "testfixedstring", null), 26 | (0, "testvarstring", null), 27 | (1, "testfixedstring", "Hello World"), 28 | (1, "testvarstring", "Hello World"), 29 | (2, "testfixedstring", "Controls\t \b\n\r\f\\'\""), 30 | (2, "testvarstring", "Controls\t \b\n\r\f\\'\""), 31 | (3, "testfixedstring", "Specials/%"), 32 | (3, "testvarstring", "Specials/%"), 33 | (4, "testfixedstring", "Singl_Byte_Chars" ), 34 | (4, "testvarstring", "Multi樂Byte趣Chars"), 35 | (5, "testfixedstring", ""), 36 | (5, "testvarstring", ""), 37 | (6, "testfixedstring", " Hello World"), 38 | (6, "testvarstring", " Hello World "), 39 | (7, "testfixedstring", " \t\b\nFoo\r\f\\'\""), 40 | (7, "testvarstring", " \t\b\nFoo\r\f\\'\" "), 41 | (8, "testfixedstring", " /%Foo%/"), 42 | (8, "testvarstring", " /%Foo%/ "), 43 | (9, "testfixedstring", " _Single_"), 44 | (9, "testvarstring", " 樂Multi趣 ") 45 | ) 46 | 47 | paramTuples.par.foreach(paramTuple => { 48 | val id = paramTuple._1 49 | val column = paramTuple._2 50 | val result = paramTuple._3 51 | 52 | checkAnswer( 53 | sqlContext.sql( 54 | s"""SELECT $column FROM test_table WHERE testid=$id""".stripMargin), 55 | Seq(Row(result))) 56 | }) 57 | } 58 | } 59 | 60 | class TextStringSelectCorrectnessSuite extends StringSelectCorrectnessSuite { 61 | override protected val s3format: String = "TEXT" 62 | override protected val auto_pushdown: String = "true" 63 | } 64 | 65 | class ParquetStringSelectCorrectnessSuite extends StringSelectCorrectnessSuite { 66 | override protected val s3format: String = "PARQUET" 67 | override protected val auto_pushdown: String = "true" 68 | } 69 | 70 | class TextNoPushdownStringSelectCorrectnessSuite extends StringSelectCorrectnessSuite { 71 | override protected val s3format: String = "TEXT" 72 | override protected val auto_pushdown: String = "false" 73 | } 74 | 75 | class ParquetNoPushdownStringSelectCorrectnessSuite extends StringSelectCorrectnessSuite { 76 | override protected val s3format: String = "PARQUET" 77 | override protected val auto_pushdown: String = "false" 78 | } 79 | 80 | class TextPushdownNoCacheStringSelectCorrectnessSuite 81 | extends TextStringSelectCorrectnessSuite { 82 | override protected val s3_result_cache = "false" 83 | } 84 | 85 | class ParquetPushdownNoCacheStringSelectCorrectnessSuite 86 | extends ParquetStringSelectCorrectnessSuite { 87 | override protected val s3_result_cache = "false" 88 | } 89 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/RedshiftScanExec.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.pushdown 18 | 19 | import io.github.spark_redshift_community.spark.redshift.RedshiftRelation 20 | import org.apache.spark.rdd.RDD 21 | import org.apache.spark.sql.catalyst.InternalRow 22 | import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} 23 | import org.apache.spark.sql.execution.LeafExecNode 24 | 25 | import java.util.concurrent.{Callable, ExecutorService, Executors, Future} 26 | 27 | /** 28 | * Redshift Scan Plan for pushing query fragment to redshift endpoint and 29 | * reading data from UNLOAD location 30 | * 31 | * @param output projected columns 32 | * @param query SQL query that is pushed to redshift for evaluation 33 | * @param relation Redshift node aiding in redshift cluster connection 34 | */ 35 | case class RedshiftScanExec(output: Seq[Attribute], 36 | query: RedshiftSQLStatement, 37 | relation: RedshiftRelation) 38 | extends LeafExecNode { 39 | 40 | @transient implicit private var data: Future[RedshiftPushdownResult] = _ 41 | @transient implicit private val service: ExecutorService = Executors.newCachedThreadPool() 42 | 43 | // this is the thread which constructed this not necessarily the executing thread 44 | private val threadName = Thread.currentThread.getName 45 | 46 | override protected def doPrepare(): Unit = { 47 | logInfo("Preparing query to push down to redshift") 48 | 49 | val work = new Callable[RedshiftPushdownResult]() { 50 | override def call(): RedshiftPushdownResult = { 51 | val result = { 52 | try { 53 | val data = relation.buildScanFromSQL[InternalRow](query, Some(schema), threadName) 54 | RedshiftPushdownResult(data = Some(data)) 55 | } catch { 56 | case e: Exception => 57 | logError(s"Failure in redshift query execution: ${e.getMessage}") 58 | RedshiftPushdownResult(failure = Some(e)) 59 | } 60 | } 61 | result 62 | } 63 | } 64 | data = service.submit(work) 65 | logInfo("submitted query to redshift asynchronously") 66 | } 67 | 68 | override protected def doExecute(): RDD[InternalRow] = { 69 | if (data.get().failure.nonEmpty) { 70 | // raise original exception 71 | throw data.get().failure.get 72 | } 73 | 74 | data.get().data.get 75 | } 76 | 77 | override def cleanupResources(): Unit = { 78 | logDebug(s"shutting down service to clean up resources") 79 | if (service != null) { 80 | service.shutdown() 81 | } 82 | } 83 | } 84 | 85 | /** 86 | * Result holder 87 | * 88 | * @param data RDD that holds the data from UNLOAD location 89 | * @param failure failure information if we unable to push down to 90 | * redshift or read unload data 91 | */ 92 | private case class RedshiftPushdownResult(data: Option[RDD[InternalRow]] = None, 93 | failure: Option[Exception] = None) 94 | extends Serializable 95 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/DecimalIntegrationSuite.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.github.spark_redshift_community.spark.redshift.test 18 | 19 | import io.github.spark_redshift_community.spark.redshift.Conversions 20 | import org.apache.spark.sql.Row 21 | import org.apache.spark.sql.types.DecimalType 22 | 23 | /** 24 | * Integration tests for decimal support. For a reference on Redshift's DECIMAL type, see 25 | * http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html 26 | */ 27 | class DecimalIntegrationSuite extends IntegrationSuiteBase { 28 | 29 | private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = { 30 | test(s"reading DECIMAL($precision, $scale)") { 31 | val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix" 32 | val expectedRows = decimalStrings.map { d => 33 | if (d == null) { 34 | Row(null) 35 | } else { 36 | Row(Conversions.createRedshiftDecimalFormat().parse(d).asInstanceOf[java.math.BigDecimal]) 37 | } 38 | } 39 | try { 40 | redshiftWrapper.executeUpdate( 41 | conn, s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))") 42 | for (x <- decimalStrings) { 43 | redshiftWrapper.executeUpdate(conn, s"INSERT INTO $tableName VALUES ($x)") 44 | } 45 | assert(redshiftWrapper.tableExists(conn, tableName)) 46 | val loadedDf = read.option("dbtable", tableName).load() 47 | checkAnswer(loadedDf, expectedRows) 48 | checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows) 49 | } finally { 50 | redshiftWrapper.executeUpdate(conn, s"drop table if exists $tableName") 51 | } 52 | } 53 | } 54 | 55 | testReadingDecimals(19, 0, Seq( 56 | // Max and min values of DECIMAL(19, 0) column according to Redshift docs: 57 | "9223372036854775807", // 2^63 - 1 58 | "-9223372036854775807", 59 | "0", 60 | "12345678910", 61 | null 62 | )) 63 | 64 | testReadingDecimals(19, 4, Seq( 65 | "922337203685477.5807", 66 | "-922337203685477.5807", 67 | "0", 68 | "1234567.8910", 69 | null 70 | )) 71 | 72 | testReadingDecimals(38, 4, Seq( 73 | "922337203685477.5808", 74 | "9999999999999999999999999999999999.0000", 75 | "-9999999999999999999999999999999999.0000", 76 | "0", 77 | "1234567.8910", 78 | null 79 | )) 80 | 81 | test("Decimal precision is preserved when reading from query (regression test for issue #203)") { 82 | withTempRedshiftTable("issue203") { tableName => 83 | redshiftWrapper.executeUpdate(conn, s"CREATE TABLE $tableName (foo BIGINT)") 84 | redshiftWrapper.executeUpdate(conn, s"INSERT INTO $tableName VALUES (91593373)") 85 | assert(redshiftWrapper.tableExists(conn, tableName)) 86 | val df = read 87 | .option("query", s"select foo / 1000000.0 from $tableName limit 1") 88 | .load() 89 | val res: Double = df.collect().toSeq.head.getDecimal(0).doubleValue() 90 | assert(res === (91593373L / 1000000.0) +- 0.01) 91 | assert(df.schema.fields.head.dataType === DecimalType(28, 8)) 92 | } 93 | } 94 
| } 95 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/NumericStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import io.github.spark_redshift_community.spark.redshift.pushdown.{ConstantString, RedshiftSQLStatement} 21 | import org.apache.spark.sql.catalyst.expressions.{Abs, Acos, Asin, Atan, Attribute, Ceil, CheckOverflow, Cos, Exp, Expression, Floor, Greatest, Least, Log10, Pi, Pow, Sin, Sqrt, Tan, UnaryMinus} 22 | 23 | import scala.language.postfixOps 24 | 25 | /** Extractor for numeric expressions. */ 26 | private[querygeneration] object NumericStatement { 27 | 28 | /** Used mainly by QueryGeneration.convertExpression. This matches 29 | * a tuple of (Expression, Seq[Attribute]) representing the expression to 30 | * be matched and the fields that define the valid fields in the current expression 31 | * scope, respectively. 32 | * 33 | * @param expAttr A pair-tuple representing the expression to be matched and the 34 | * attribute fields. 35 | * @return An option containing the translated SQL, if there is a match, or None if there 36 | * is no match. 37 | */ 38 | def unapply( 39 | expAttr: (Expression, Seq[Attribute]) 40 | ): Option[RedshiftSQLStatement] = { 41 | val expr = expAttr._1 42 | val fields = expAttr._2 43 | 44 | Option(expr match { 45 | case _: Abs | _: Acos | _: Cos | _: Tan | _: Atan | 46 | _: Floor | _: Sin | _: Asin | _: Sqrt | _: Ceil | 47 | _: Greatest | _: Least | _: Exp => 48 | ConstantString(expr.prettyName.toUpperCase) + 49 | blockStatement(convertStatements(fields, expr.children: _*)) 50 | 51 | case _: Log10 => 52 | ConstantString("LOG") + 53 | blockStatement(convertStatements(fields, expr.children: _*)) 54 | 55 | // From spark 3.1, UnaryMinus() has 2 parameters. 56 | case UnaryMinus(child, _) => 57 | ConstantString("-") + 58 | blockStatement(convertStatement(child, fields)) 59 | 60 | case Pow(left, right) => 61 | ConstantString("POWER") + 62 | blockStatement( 63 | convertStatement(left, fields) + "," + convertStatement( 64 | right, 65 | fields 66 | ) 67 | ) 68 | 69 | case PromotePrecisionExtractor(child) => convertStatement(child, fields) 70 | 71 | case CheckOverflow(child, t, _) => 72 | getCastType(t) match { 73 | case Some(cast) => 74 | ConstantString("CAST") + 75 | blockStatement(convertStatement(child, fields) + "AS" + cast) 76 | case _ => convertStatement(child, fields) 77 | } 78 | 79 | // Spark has already resolved PI() to the literal 3.141592653589793, 80 | // so the connector is not normally expected to see Pi(). 81 | case Pi() => ConstantString("PI()") !
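// ROUND below is pushed down only when ANSI mode is disabled, per the `if !ansiEnabled` guard on the extractor.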
82 | 83 | case RoundExtractor(child, scale, ansiEnabled) if !ansiEnabled => 84 | ConstantString("ROUND") + blockStatement( 85 | convertStatements(fields, child, scale) 86 | ) 87 | 88 | case _ => null 89 | }) 90 | } 91 | } -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/TableNameSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.test 19 | 20 | import io.github.spark_redshift_community.spark.redshift.TableName 21 | import org.scalatest.funsuite.AnyFunSuite 22 | 23 | class TableNameSuite extends AnyFunSuite { 24 | test("TableName.parseFromEscaped") { 25 | 26 | assert(TableName.parseFromEscaped("foo.bar") === TableName("", "foo", "bar")) 27 | assert(TableName.parseFromEscaped("foo") === TableName("", "PUBLIC", "foo")) 28 | assert(TableName.parseFromEscaped("\"foo\"") === TableName("", "PUBLIC", "foo")) 29 | assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("", "\"foo\"", "bar")) 30 | // Dots (.) can also appear inside of valid identifiers. 
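// For instance, "foo.bar".baz parses to schema foo.bar and table baz, as the assertions below verify.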
31 | assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("", "foo.bar", "baz")) 32 | assert(TableName.parseFromEscaped("\"foo\"\".bar\".baz") === TableName("", "foo\".bar", "baz")) 33 | assert(TableName.parseFromEscaped(""""foo"".bar".baz""") === TableName("", "foo\".bar", "baz")) 34 | 35 | // Test three-part names 36 | assert(TableName.parseFromEscaped("foo.bar.baz") === TableName("foo", "bar", "baz")) 37 | assert(TableName.parseFromEscaped("awsdatacatalog.glue_db.my_table") === 38 | TableName("awsdatacatalog", "glue_db", "my_table")) 39 | assert(TableName.parseFromEscaped("awsdatacatalog.glue_db.\"public.my_table\"") === 40 | TableName("awsdatacatalog", "glue_db", "public.my_table")) 41 | assert(TableName.parseFromEscaped(""""awsdatacatalog"."glue_db"."public.my_table"""") === 42 | TableName("awsdatacatalog", "glue_db", "public.my_table")) 43 | assert(TableName.parseFromEscaped("""awsdatacatalog.glue_db."public.my_table"""") === 44 | TableName("awsdatacatalog", "glue_db", "public.my_table")) 45 | assert(TableName.parseFromEscaped("""awsdatacatalog."glue.db"."pub.lic.my_.table"""") === 46 | TableName("awsdatacatalog", "glue.db", "pub.lic.my_.table")) 47 | assert(TableName.parseFromEscaped("\"awsdata\"\"catalog\".\"glue_db\".\"public.my_table\"") === 48 | TableName("awsdata\"catalog", "glue_db", "public.my_table")) 49 | assert(TableName.parseFromEscaped(""""awsdata""catalog".glue_db."public.my_table"""") === 50 | TableName("awsdata\"catalog", "glue_db", "public.my_table")) 51 | } 52 | 53 | test("TableName.toString") { 54 | assert(TableName("", "foo", "bar").toString === """"foo"."bar"""") 55 | assert(TableName("", "PUBLIC", "bar").toString === """"PUBLIC"."bar"""") 56 | assert(TableName("", "\"foo\"", "bar").toString === "\"\"\"foo\"\"\".\"bar\"") 57 | 58 | // Test three-part names 59 | assert(TableName("foo", "bar", "baz").toString === """"foo"."bar"."baz"""") 60 | assert(TableName("awsdatacatalog", "glue_db", "my_table").toString === 61 | """"awsdatacatalog"."glue_db"."my_table"""") 62 | assert(TableName("awsdatacatalog", "glue_db", "public.my_table").toString === 63 | """"awsdatacatalog"."glue_db"."public.my_table"""") 64 | assert(TableName("""aws.data"catalog""", """gl.ue"_db""", """pub"lic.my"_tab.le""").toString === 65 | """"aws.data""catalog"."gl.ue""_db"."pub""lic.my""_tab.le"""") 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/TableName.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Copyright 2015 Databricks 4 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | 19 | package io.github.spark_redshift_community.spark.redshift 20 | 21 | import io.github.spark_redshift_community.spark.redshift.pushdown.{ConstantString, Identifier} 22 | 23 | import scala.collection.mutable.ArrayBuffer 24 | 25 | /** 26 | * Wrapper class for representing the name of a Redshift table. 27 | */ 28 | private[redshift] case class TableName(unescapedDatabaseName: String, 29 | unescapedSchemaName: String, 30 | unescapedTableName: String) { 31 | private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"' 32 | def escapedDatabaseName: String = quote(unescapedDatabaseName) 33 | def escapedSchemaName: String = quote(unescapedSchemaName) 34 | def escapedTableName: String = quote(unescapedTableName) 35 | override def toString: String = { 36 | if (unescapedDatabaseName.isEmpty) { 37 | s"$escapedSchemaName.$escapedTableName" 38 | } else { 39 | s"$escapedDatabaseName.$escapedSchemaName.$escapedTableName" 40 | } 41 | } 42 | def toStatement: Identifier = Identifier(toString) 43 | def toConstantString: ConstantString = ConstantString(toString) 44 | } 45 | 46 | private[redshift] object TableName { 47 | /** 48 | * Parses a table name which is assumed to have been escaped according to Redshift's rules for 49 | * delimited identifiers. 50 | */ 51 | def parseFromEscaped(str: String): TableName = { 52 | def dropOuterQuotes(s: String) = 53 | if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s 54 | def unescapeQuotes(s: String) = s.replace("\"\"", "\"") 55 | def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s)) 56 | splitByDots(str) match { 57 | case Seq(tableName) => TableName("", "PUBLIC", unescape(tableName)) 58 | case Seq(schemaName, tableName) => TableName("", unescape(schemaName), unescape(tableName)) 59 | case Seq(databaseName, schemaName, tableName) => 60 | TableName(unescape(databaseName), unescape(schemaName), unescape(tableName)) 61 | case _ => throw new IllegalArgumentException(s"Could not parse table name from '$str'") 62 | } 63 | } 64 | 65 | /** 66 | * Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear 67 | * inside of quoted identifiers. 68 | */ 69 | private def splitByDots(str: String): Seq[String] = { 70 | val parts: ArrayBuffer[String] = ArrayBuffer.empty 71 | val sb = new StringBuilder 72 | var inQuotes: Boolean = false 73 | for (c <- str) c match { 74 | case '"' => 75 | // Note that double quotes are escaped by pairs of double quotes (""), so we don't need 76 | // any extra code to handle them; we'll be back in inQuotes=true after seeing the pair. 77 | sb.append('"') 78 | inQuotes = !inQuotes 79 | case '.' => 80 | if (!inQuotes) { 81 | parts.append(sb.toString()) 82 | sb.clear() 83 | } else { 84 | sb.append('.') 85 | } 86 | case other => 87 | sb.append(other) 88 | } 89 | if (sb.nonEmpty) { 90 | parts.append(sb.toString()) 91 | } 92 | parts 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/FilterPushdownSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.test 19 | 20 | import io.github.spark_redshift_community.spark.redshift.FilterPushdown._ 21 | import org.apache.spark.sql.sources._ 22 | import org.apache.spark.sql.types._ 23 | import org.scalatest.funsuite.AnyFunSuite 24 | 25 | 26 | class FilterPushdownSuite extends AnyFunSuite { 27 | test("buildWhereClause with empty list of filters") { 28 | assert(buildWhereClause(StructType(Nil), Seq.empty) === "") 29 | } 30 | 31 | test("buildWhereClause with no filters that can be pushed down") { 32 | assert(buildWhereClause(StructType(Nil), Seq(AlwaysTrue, AlwaysTrue)) === "") 33 | } 34 | 35 | test("buildWhereClause with some filters that cannot be pushed down") { 36 | val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_int", 1), AlwaysTrue)) 37 | assert(whereClause === """WHERE "test_int" = 1""") 38 | } 39 | 40 | test("buildWhereClause with string literals that contain Unicode characters") { 41 | // scalastyle:off 42 | val whereClause = buildWhereClause(testSchema, Seq(EqualTo("test_string", "Unicode's樂趣"))) 43 | // Here, the apostrophe in the string needs to be replaced with two single quotes, '', but we 44 | // also need to escape those quotes with backslashes because this WHERE clause is going to 45 | // eventually be embedded inside of a single-quoted string that's embedded inside of a larger 46 | // Redshift query.
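// For example, the apostrophe in Unicode's樂趣 is first doubled to '' and each of those quotes is then backslash-escaped, which is why the expected clause below contains \'\'.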
47 | assert(whereClause === """WHERE "test_string" = 'Unicode\'\'s樂趣'""") 48 | // scalastyle:on 49 | } 50 | 51 | test("buildWhereClause with multiple filters") { 52 | val filters = Seq( 53 | EqualTo("test_bool", true), 54 | // scalastyle:off 55 | EqualTo("test_string", "Unicode是樂趣"), 56 | // scalastyle:on 57 | GreaterThan("test_double", 1000.0), 58 | LessThan("test_double", Double.MaxValue), 59 | GreaterThanOrEqual("test_float", 1.0f), 60 | GreaterThanOrEqual("test_float", 1.0d), 61 | LessThanOrEqual("test_int", 43), 62 | IsNotNull("test_int"), 63 | IsNull("test_int")) 64 | val whereClause = buildWhereClause(testSchema, filters) 65 | // scalastyle:off 66 | val expectedWhereClause = 67 | """ 68 | |WHERE "test_bool" = true 69 | |AND "test_string" = 'Unicode是樂趣' 70 | |AND "test_double" > 1000.0 71 | |AND "test_double" < 1.7976931348623157E308 72 | |AND "test_float" >= 1.0::float4 73 | |AND "test_float" >= 1.0 74 | |AND "test_int" <= 43 75 | |AND "test_int" IS NOT NULL 76 | |AND "test_int" IS NULL 77 | """.stripMargin.lines.toArray.mkString(" ").trim 78 | // scalastyle:on 79 | assert(whereClause === expectedWhereClause) 80 | } 81 | 82 | private val testSchema: StructType = StructType(Seq( 83 | StructField("test_byte", ByteType), 84 | StructField("test_bool", BooleanType), 85 | StructField("test_date", DateType), 86 | StructField("test_double", DoubleType), 87 | StructField("test_float", FloatType), 88 | StructField("test_int", IntegerType), 89 | StructField("test_long", LongType), 90 | StructField("test_short", ShortType), 91 | StructField("test_string", StringType), 92 | StructField("test_timestamp", TimestampType))) 93 | 94 | } 95 | -------------------------------------------------------------------------------- /src/test/scala/io/github/spark_redshift_community/spark/redshift/QueryTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed to the Apache Software Foundation (ASF) under one or more 6 | * contributor license agreements. See the NOTICE file distributed with 7 | * this work for additional information regarding copyright ownership. 8 | * The ASF licenses this file to You under the Apache License, Version 2.0 9 | * (the "License"); you may not use this file except in compliance with 10 | * the License. You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | package io.github.spark_redshift_community.spark.redshift.test 22 | 23 | import org.apache.spark.sql.catalyst.plans.logical 24 | import org.apache.spark.sql.{DataFrame, Row} 25 | import org.scalatest.funsuite.AnyFunSuite 26 | 27 | /** 28 | * Copy of Spark SQL's `QueryTest` trait. 29 | */ 30 | trait QueryTest extends AnyFunSuite { 31 | /** 32 | * Runs the plan and makes sure the answer matches the expected result. 33 | * @param df the [[DataFrame]] to be executed 34 | * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
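* @param trim whether to trim string values in both the expected and actual rows before comparison.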
35 | */ 36 | def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], trim: Boolean = false): Unit = { 37 | val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty 38 | def prepareAnswer(answer: Seq[Row]): Seq[Row] = { 39 | // Converts data to types that we can do equality comparison using Scala collections. 40 | // For BigDecimal type, the Scala type has a better definition of equality test (similar to 41 | // Java's java.math.BigDecimal.compareTo). 42 | // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for 43 | // equality test. 44 | val converted: Seq[Row] = answer.map { s => 45 | Row.fromSeq(s.toSeq.map { 46 | case d: java.math.BigDecimal => BigDecimal(d) 47 | case b: Array[Byte] => b.toSeq 48 | case s: String if trim => s.trim() 49 | case o => o 50 | }) 51 | } 52 | if (!isSorted) converted.sortBy(_.toString()) else converted 53 | } 54 | val sparkAnswer = try df.collect().toSeq catch { 55 | case e: Exception => 56 | val errorMessage = 57 | s""" 58 | |Exception thrown while executing query: 59 | |${df.queryExecution} 60 | |== Exception == 61 | |$e 62 | |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} 63 | """.stripMargin 64 | fail(errorMessage) 65 | } 66 | 67 | if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { 68 | val errorMessage = 69 | s""" 70 | |Results do not match for query: 71 | |${df.queryExecution} 72 | |== Results == 73 | |${sideBySide( 74 | s"== Correct Answer - ${expectedAnswer.size} ==" +: 75 | prepareAnswer(expectedAnswer).map(_.toString()), 76 | s"== Spark Answer - ${sparkAnswer.size} ==" +: 77 | prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} 78 | """.stripMargin 79 | fail(errorMessage) 80 | } 81 | } 82 | 83 | private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { 84 | val maxLeftSize = left.map(_.length).max 85 | val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") 86 | val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") 87 | 88 | leftPadded.zip(rightPadded).map { 89 | case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/data/RedshiftResults.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.data 17 | 18 | import software.amazon.awssdk.services.redshiftdata.model.{Field, GetStatementResultResponse} 19 | 20 | import java.sql.ResultSet 21 | import scala.collection.JavaConverters._ 22 | 23 | private[redshift] abstract class RedshiftResults() { 24 | def next(): Boolean 25 | 26 | // 1-based indexing per JDBC convention 27 | def getInt(columnIndex: Int): Int 28 | def getLong(columnIndex: Int): Long 29 | def getString(columnIndex: Int): String 30 | 31 | def getInt(columnLabel: String): Int 32 | def getLong(columnLabel: String): Long 33 | def getString(columnLabel: String): String 34 | 35 | } 36 | 37 | private[redshift] case class DataApiResults(results: GetStatementResultResponse) 38 | extends RedshiftResults { 39 | 40 | private val iter = results.records().asScala.iterator 41 | private var curr: java.util.List[Field] = null 42 | 43 | override def next(): Boolean = { 44 | if (!iter.hasNext) { 45 | false 46 | } else { 47 | curr = iter.next() 48 | true 49 | } 50 | } 51 | 52 | override def getInt(columnIndex: Int): Int = { 53 | val longVal = curr.get(columnIndex - 1).longValue() 54 | if (longVal < Int.MinValue || longVal > Int.MaxValue) { 55 | throw new ArithmeticException(s"Long value $longVal cannot be converted to Int without data loss") 56 | } 57 | longVal.toInt 58 | } 59 | 60 | override def getLong(columnIndex: Int): Long = { 61 | curr.get(columnIndex - 1).longValue() 62 | } 63 | 64 | override def getString(columnIndex: Int): String = { 65 | curr.get(columnIndex - 1).stringValue() 66 | } 67 | 68 | override def getInt(columnLabel: String): Int = { 69 | val longVal = curr.get(getIndex(columnLabel)).longValue() 70 | if (longVal < Int.MinValue || longVal > Int.MaxValue) { 71 | throw new ArithmeticException(s"Long value $longVal cannot be converted to Int without data loss") 72 | } 73 | longVal.toInt 74 | } 75 | 76 | override def getLong(columnLabel: String): Long = { 77 | curr.get(getIndex(columnLabel)).longValue() 78 | } 79 | 80 | override def getString(columnLabel: String): String = { 81 | curr.get(getIndex(columnLabel)).stringValue() 82 | } 83 | 84 | private def getIndex(columnLabel: String): Int = { 85 | results.columnMetadata().asScala.indexWhere(col => col.label() == columnLabel) 86 | } 87 | } 88 | 89 | private[redshift] case class JDBCResults(results: ResultSet) extends RedshiftResults { 90 | override def next(): Boolean = { 91 | results.next() 92 | } 93 | 94 | override def getInt(columnIndex: Int): Int = { 95 | results.getInt(columnIndex) 96 | } 97 | 98 | override def getLong(columnIndex: Int): Long = { 99 | results.getLong(columnIndex) 100 | } 101 | 102 | override def getString(columnIndex: Int): String = { 103 | results.getString(columnIndex) 104 | } 105 | 106 | override def getInt(columnLabel: String): Int = { 107 | results.getInt(columnLabel) 108 | } 109 | 110 | override def getLong(columnLabel: String): Long = { 111 | results.getLong(columnLabel) 112 | } 113 | 114 | override def getString(columnLabel: String): String = { 115 | results.getString(columnLabel) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/AWSCredentialsUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Databricks 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift 19 | 20 | import java.net.URI 21 | import software.amazon.awssdk.auth.credentials._ 22 | import io.github.spark_redshift_community.spark.redshift.Parameters.MergedParameters 23 | import io.github.spark_redshift_community.spark.redshift.Utils.CONNECTOR_REDSHIFT_S3_CONNECTION_IAM_ROLE_ONLY 24 | import org.apache.hadoop.conf.Configuration 25 | 26 | private[redshift] object AWSCredentialsUtils { 27 | 28 | /** 29 | * Generates a credentials string for use in Redshift COPY and UNLOAD statements. 30 | * Favors a configured `aws_iam_role` if available in the parameters. 31 | */ 32 | def getRedshiftCredentialsString( 33 | params: MergedParameters, 34 | credentialsProvider: AwsCredentialsProvider): String = { 35 | 36 | def awsCredsToString(credentials: AwsCredentials): String = { 37 | credentials match { 38 | case creds: AwsSessionCredentials => 39 | s"aws_access_key_id=${creds.accessKeyId()};" + 40 | s"aws_secret_access_key=${creds.secretAccessKey()};token=${creds.sessionToken()}" 41 | case creds => 42 | s"aws_access_key_id=${creds.accessKeyId()};" + 43 | s"aws_secret_access_key=${creds.secretAccessKey()}" 44 | } 45 | } 46 | 47 | if (Utils.isRedshiftS3ConnectionViaIAMRoleOnly() && 48 | (params.temporaryAWSCredentials.isDefined || params.forwardSparkS3Credentials)) { 49 | throw new RedshiftConstraintViolationException("Only the aws_iam_role option for configuring " + 50 | "credentials is supported when configuration " + 51 | s"$CONNECTOR_REDSHIFT_S3_CONNECTION_IAM_ROLE_ONLY is set to true.") 52 | } 53 | 54 | if (params.iamRole.isDefined) { 55 | s"aws_iam_role=${params.iamRole.get}" 56 | } else if (params.temporaryAWSCredentials.isDefined) { 57 | awsCredsToString(params.temporaryAWSCredentials.get.resolveCredentials()) 58 | } else if (params.forwardSparkS3Credentials) { 59 | awsCredsToString(credentialsProvider.resolveCredentials()) 60 | } else { 61 | throw new IllegalStateException("No Redshift S3 authentication mechanism was specified") 62 | } 63 | } 64 | 65 | def staticCredentialsProvider(credentials: AwsCredentials): AwsCredentialsProvider = { 66 | StaticCredentialsProvider.create(credentials) 67 | } 68 | 69 | def load(params: MergedParameters, hadoopConfiguration: Configuration): AwsCredentialsProvider = { 70 | // Load the credentials. 
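// Explicitly supplied temporary credentials take precedence; otherwise fall back to a provider resolved from the tempdir URI scheme (the AWS default provider chain for s3, s3n, and s3a).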
71 | params.temporaryAWSCredentials.getOrElse(loadFromURI(params.rootTempDir, hadoopConfiguration)) 72 | } 73 | 74 | private def loadFromURI( 75 | tempPath: String, 76 | hadoopConfiguration: Configuration): AwsCredentialsProvider = { 77 | // scalastyle:off 78 | // A good reference on Hadoop's configuration loading / precedence is 79 | // https://github.com/apache/hadoop/blob/trunk/hadoop-tools/hadoop-aws/src/site/markdown/tools/hadoop-aws/index.md 80 | // scalastyle:on 81 | val uri = new URI(tempPath) 82 | val uriScheme = uri.getScheme 83 | 84 | uriScheme match { 85 | case "s3" | "s3n" | "s3a" => 86 | DefaultCredentialsProvider.create() 87 | case other => 88 | throw new IllegalArgumentException(s"Unrecognized scheme $other; expected s3, s3n, or s3a") 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/it/resources/lst/2_load_web_returns.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | INSERT INTO "PUBLIC"."web_returns" ( 19 | wr_returned_date_sk, 20 | wr_returned_time_sk, 21 | wr_item_sk, 22 | wr_refunded_customer_sk, 23 | wr_refunded_cdemo_sk, 24 | wr_refunded_hdemo_sk, 25 | wr_refunded_addr_sk, 26 | wr_returning_customer_sk, 27 | wr_returning_cdemo_sk, 28 | wr_returning_hdemo_sk, 29 | wr_returning_addr_sk, 30 | wr_web_page_sk, 31 | wr_reason_sk, 32 | wr_order_number, 33 | wr_return_quantity, 34 | wr_return_amt, 35 | wr_return_tax, 36 | wr_return_amt_inc_tax, 37 | wr_fee, 38 | wr_return_ship_cost, 39 | wr_refunded_cash, 40 | wr_reversed_charge, 41 | wr_account_credit, 42 | wr_net_loss 43 | ) 44 | VALUES 45 | (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 123456789, 1, 10.00, 1.00, 11.00, 0.50, 2.00, 8.00, 1.00, 0.50, 0.50), 46 | (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 234567890, 2, 20.00, 2.00, 22.00, 1.00, 3.00, 16.00, 2.00, 1.00, 1.00), 47 | (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 345678901, 3, 30.00, 3.00, 33.00, 1.50, 4.00, 24.00, 3.00, 1.50, 1.50), 48 | (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 456789012, 4, 40.00, 4.00, 44.00, 2.00, 5.00, 32.00, 4.00, 2.00, 2.00), 49 | (5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 567890123, 5, 50.00, 5.00, 55.00, 2.50, 6.00, 40.00, 5.00, 2.50, 2.50), 50 | (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 678901234, 6, 60.00, 6.00, 66.00, 3.00, 7.00, 48.00, 6.00, 3.00, 3.00), 51 | (7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 789012345, 7, 70.00, 7.00, 77.00, 3.50, 8.00, 56.00, 7.00, 3.50, 3.50), 52 | (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 890123456, 8, 80.00, 8.00, 88.00, 4.00, 9.00, 64.00, 8.00, 4.00, 4.00), 53 | (9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 901234567, 9, 90.00, 9.00, 99.00, 4.50, 10.00, 72.00, 9.00, 4.50, 4.50), 54 | (10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 
21, 22, 1012345678, 10, 100.00, 10.00, 110.00, 5.00, 11.00, 80.00, 10.00, 5.00, 5.00), 55 | (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 1123456789, 11, 110.00, 11.00, 121.00, 5.50, 12.00, 88.00, 11.00, 5.50, 5.50), 56 | (12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 1234567890, 12, 120.00, 12.00, 132.00, 6.00, 13.00, 96.00, 12.00, 6.00, 6.00), 57 | (13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 1345678901, 13, 130.00, 13.00, 143.00, 6.50, 14.00, 104.00, 13.00, 6.50, 6.50), 58 | (14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 1456789012, 14, 140.00, 14.00, 154.00, 7.00, 15.00, 112.00, 14.00, 7.00, 7.00), 59 | (15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 1567890123, 15, 150.00, 15.00, 165.00, 7.50, 16.00, 120.00, 15.00, 7.50, 7.50), 60 | (16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 1678901234, 16, 160.00, 16.00, 176.00, 8.00, 17.00, 128.00, 16.00, 8.00, 8.00), 61 | (17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1789012345, 17, 170.00, 17.00, 187.00, 8.50, 18.00, 136.00, 17.00, 8.50, 8.50), 62 | (18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 1890123456, 18, 180.00, 18.00, 198.00, 9.00, 19.00, 144.00, 18.00, 9.00, 9.00), 63 | (19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 1991234567, 19, 190.00, 19.00, 209.00, 9.50, 20.00, 152.00, 19.00, 9.50, 9.50), 64 | (20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 2092345678, 20, 200.00, 20.00, 220.00, 10.00, 21.00, 160.00, 20.00, 10.00, 10.00) 65 | ; -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/BooleanStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration.StringStatement.DEFAULT_LIKE_ESCAPE_CHAR 21 | import io.github.spark_redshift_community.spark.redshift.pushdown.{ConstantString, RedshiftSQLStatement} 22 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Concat, Contains, EndsWith, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Like, Literal, Not, StartsWith} 23 | import org.apache.spark.sql.types.StringType 24 | import org.apache.spark.unsafe.types.UTF8String 25 | 26 | /** Extractor for boolean expressions (return true or false). 
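* For example, IsNull(child) is rendered as a block containing <child> IS NULL, and Contains / StartsWith / EndsWith are rewritten into LIKE patterns with '%' wildcards before conversion.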
*/ 27 | private[querygeneration] object BooleanStatement { 28 | def unapply( 29 | expAttr: (Expression, Seq[Attribute]) 30 | ): Option[RedshiftSQLStatement] = { 31 | val expr = expAttr._1 32 | val fields = expAttr._2 33 | 34 | Option(expr match { 35 | case In(child, list) if list.forall(_.isInstanceOf[Literal]) => 36 | convertStatement(child, fields) + "IN" + 37 | blockStatement(convertStatements(fields, list: _*)) 38 | case IsNull(child) => 39 | blockStatement(convertStatement(child, fields) + "IS NULL") 40 | case IsNotNull(child) => 41 | blockStatement(convertStatement(child, fields) + "IS NOT NULL") 42 | case Not(child) => { 43 | child match { 44 | case EqualTo(left, right) => 45 | blockStatement( 46 | convertStatement(left, fields) + "!=" + 47 | convertStatement(right, fields) 48 | ) 49 | // NOT ( GreaterThanOrEqual, LessThanOrEqual, 50 | // GreaterThan and LessThan ) have been optimized by spark 51 | // and are handled by BinaryOperator in BasicStatement. 52 | case GreaterThanOrEqual(left, right) => 53 | convertStatement(LessThan(left, right), fields) 54 | case LessThanOrEqual(left, right) => 55 | convertStatement(GreaterThan(left, right), fields) 56 | case GreaterThan(left, right) => 57 | convertStatement(LessThanOrEqual(left, right), fields) 58 | case LessThan(left, right) => 59 | convertStatement(GreaterThanOrEqual(left, right), fields) 60 | case _ => 61 | ConstantString("NOT") + 62 | blockStatement(convertStatement(child, fields)) 63 | } 64 | } 65 | // Cast the left string into a varchar to ensure fixed-length strings are right-trimmed 66 | // since Redshift doesn't do this automatically for LIKE expressions. We want the push-down 67 | // behavior to always match the non-push-down behavior which trims fixed-length strings. 68 | case Contains(left, right) => 69 | blockStatement(convertStatement(Like(left, Concat(Seq(Literal("%"), right, Literal("%"))), 70 | DEFAULT_LIKE_ESCAPE_CHAR), fields)) 71 | case EndsWith(left, right) => 72 | blockStatement(convertStatement(Like(left, Concat(Seq(Literal("%"), right)), 73 | DEFAULT_LIKE_ESCAPE_CHAR), fields)) 74 | case StartsWith(left, right) => 75 | blockStatement(convertStatement(Like(left, Concat(Seq(right, Literal("%"))), 76 | DEFAULT_LIKE_ESCAPE_CHAR), fields)) 77 | case _ => null 78 | }) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/it/resources/lst/2_load_date_dim.sql: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Microsoft Corporation. 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | INSERT INTO "PUBLIC"."date_dim" ( 19 | d_date_sk, 20 | d_date_id, 21 | d_date, 22 | d_month_seq, 23 | d_week_seq, 24 | d_quarter_seq, 25 | d_year, 26 | d_dow, 27 | d_moy, 28 | d_dom, 29 | d_qoy, 30 | d_fy_year, 31 | d_fy_quarter_seq, 32 | d_fy_week_seq, 33 | d_day_name, 34 | d_quarter_name, 35 | d_holiday, 36 | d_weekend, 37 | d_following_holiday, 38 | d_first_dom, 39 | d_last_dom, 40 | d_same_day_ly, 41 | d_same_day_lq, 42 | d_current_day, 43 | d_current_week, 44 | d_current_month, 45 | d_current_quarter, 46 | d_current_year 47 | ) VALUES 48 | (1, '20240101', '2024-01-01', 1, 1, 1, 2024, 1, 1, 1, 1, 2024, 1, 1, 'Monday', 'Q1', 'N', 'N', 'Y', 1, 31, 1, 52, 'Y', 'Y', 'Y', 'Y', 'Y'), 49 | (2, '20240102', '2024-01-02', 1, 1, 1, 2024, 2, 1, 2, 1, 2024, 1, 1, 'Tuesday', 'Q1', 'N', 'N', 'N', 1, 31, 2, 52, 'N', 'Y', 'Y', 'Y', 'Y'), 50 | (3, '20240103', '2024-01-03', 1, 1, 1, 2024, 3, 1, 3, 1, 2024, 1, 1, 'Wednesday', 'Q1', 'N', 'N', 'N', 1, 31, 3, 52, 'N', 'N', 'Y', 'Y', 'Y'), 51 | (4, '20240104', '2024-01-04', 1, 1, 1, 2024, 4, 1, 4, 1, 2024, 1, 1, 'Thursday', 'Q1', 'N', 'N', 'N', 1, 31, 4, 52, 'N', 'N', 'N', 'Y', 'Y'), 52 | (5, '20240105', '2024-01-05', 1, 1, 1, 2024, 5, 1, 5, 1, 2024, 1, 1, 'Friday', 'Q1', 'N', 'Y', 'N', 1, 31, 5, 52, 'N', 'N', 'N', 'N', 'Y'), 53 | (6, '20240106', '2024-01-06', 1, 1, 1, 2024, 6, 1, 6, 1, 2024, 1, 1, 'Saturday', 'Q1', 'N', 'Y', 'Y', 1, 31, 6, 52, 'N', 'N', 'N', 'N', 'Y'), 54 | (7, '20240107', '2024-01-07', 1, 1, 1, 2024, 7, 1, 7, 1, 2024, 1, 1, 'Sunday', 'Q1', 'N', 'Y', 'Y', 1, 31, 7, 52, 'N', 'N', 'N', 'N', 'Y'), 55 | (8, '20240108', '2024-01-08', 2, 2, 1, 2024, 1, 1, 8, 1, 2024, 1, 2, 'Monday', 'Q1', 'N', 'N', 'Y', 1, 31, 8, 53, 'Y', 'Y', 'Y', 'Y', 'Y'), 56 | (9, '20240109', '2024-01-09', 2, 2, 1, 2024, 2, 1, 9, 1, 2024, 1, 2, 'Tuesday', 'Q1', 'N', 'N', 'N', 1, 31, 9, 53, 'N', 'Y', 'Y', 'Y', 'Y'), 57 | (10, '20240110', '2024-01-10', 2, 2, 1, 2024, 3, 1, 10, 1, 2024, 1, 2, 'Wednesday', 'Q1', 'N', 'N', 'N', 1, 31, 10, 53, 'N', 'N', 'Y', 'Y', 'Y'), 58 | (11, '20240111', '2024-01-11', 2, 2, 1, 2024, 4, 1, 11, 1, 2024, 1, 2, 'Thursday', 'Q1', 'N', 'N', 'N', 1, 31, 11, 53, 'N', 'N', 'N', 'Y', 'Y'), 59 | (12, '20240112', '2024-01-12', 2, 2, 1, 2024, 5, 1, 12, 1, 2024, 1, 2, 'Friday', 'Q1', 'N', 'Y', 'N', 1, 31, 12, 53, 'N', 'N', 'N', 'N', 'Y'), 60 | (13, '20240113', '2024-01-13', 2, 2, 1, 2024, 6, 1, 13, 1, 2024, 1, 2, 'Saturday', 'Q1', 'N', 'Y', 'Y', 1, 31, 13, 53, 'N', 'N', 'N', 'N', 'Y'), 61 | (14, '20240114', '2024-01-14', 2, 2, 1, 2024, 7, 1, 14, 1, 2024, 1, 2, 'Sunday', 'Q1', 'N', 'Y', 'Y', 1, 31, 14, 53, 'N', 'N', 'N', 'N', 'Y'), 62 | (15, '20240115', '2024-01-15', 3, 3, 1, 2024, 1, 2, 15, 1, 2024, 1, 3, 'Monday', 'Q1', 'N', 'N', 'Y', 1, 29, 15, 1, 'Y', 'Y', 'Y', 'Y', 'Y'), 63 | (16, '20240116', '2024-01-16', 3, 3, 1, 2024, 2, 2, 16, 1, 2024, 1, 3, 'Tuesday', 'Q1', 'N', 'N', 'N', 1, 29, 16, 1, 'N', 'Y', 'Y', 'Y', 'Y'), 64 | (17, '20240117', '2024-01-17', 3, 3, 1, 2024, 3, 2, 17, 1, 2024, 1, 3, 'Wednesday', 'Q1', 'N', 'N', 'N', 1, 29, 17, 1, 'N', 'N', 'Y', 'Y', 'Y'), 65 | (18, '20240118', '2024-01-18', 3, 3, 1, 2024, 4, 2, 18, 1, 2024, 1, 3, 'Thursday', 'Q1', 'N', 'N', 'N', 1, 29, 18, 1, 'N', 'N', 'N', 'Y', 'Y'), 66 | (19, '20240119', '2024-01-19', 3, 3, 1, 2024, 5, 2, 19, 1, 2024, 1, 3, 'Friday', 'Q1', 'N', 'Y', 'N', 1, 29, 19, 1, 'N', 'N', 'N', 'N', 'Y'), 67 | (20, '20240120', '2024-01-20', 3, 3, 1, 2024, 6, 2, 20, 1, 2024, 1, 3, 'Saturday', 'Q1', 'N', 'Y', 'Y', 1, 29, 20, 1, 'N', 'N', 'N', 'N', 'Y'); 68 | 
-------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/PushdownLocalRelationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.pushdown.test 17 | 18 | class PushdownLocalRelationSuite extends IntegrationPushdownSuiteBase { 19 | // These tests cannot disable pushdown since insert happens in pushdown 20 | override protected val auto_pushdown: String = "true" 21 | // These tests cannot use cache since they check the result changing 22 | override val s3_result_cache: String = "false" 23 | 24 | test("Push down insert literal values") { 25 | withTempRedshiftTable("insertTable") { tableName => 26 | redshiftWrapper.executeUpdate(conn, 27 | s"CREATE TABLE ${tableName} (a int, b int)" 28 | ) 29 | read.option("dbtable", tableName).load.createOrReplaceTempView(tableName) 30 | val pre = sqlContext.sql(s"SELECT * FROM ${tableName}").count 31 | 32 | sqlContext.sql(s"INSERT INTO TABLE ${tableName} VALUES (1, 100), (3,2000)") 33 | 34 | checkSqlStatement( 35 | s"""INSERT INTO "PUBLIC"."$tableName" 36 | | SELECT ( CAST ( "SQ_1"."COL1" AS INTEGER ) ) AS "SQ_2_COL_0" , 37 | | ( CAST ( "SQ_1"."COL2" AS INTEGER ) ) AS "SQ_2_COL_1" 38 | | FROM ( ( (SELECT 1 AS "col1", 100 AS "col2") 39 | | UNION ALL (SELECT 3 AS "col1", 2000 AS "col2") ) ) AS "SQ_1"""".stripMargin 40 | ) 41 | 42 | val post = sqlContext.sql(s"SELECT * FROM ${tableName}").collect().map(row => row.toSeq).toSeq 43 | 44 | assert(pre == 0) 45 | val expected = Array(Array(1, 100), Array(3, 2000)) 46 | post should contain theSameElementsAs expected 47 | } 48 | } 49 | 50 | // Can not push down because Local Relation can not generate a SourceQuery 51 | test("Not able to push down simple select query with literal values") { 52 | val post = sqlContext.sql( 53 | s"""SELECT * FROM (VALUES (1, '1'), (2, '2'), (3, '3')) 54 | | AS tests(id, name)""".stripMargin).collect().map(row => row.toSeq).toSeq 55 | 56 | val expected = Array(Array(1, "1"), Array(2, "2"), Array(3, "3")) 57 | post should contain theSameElementsAs expected 58 | } 59 | 60 | test("Push down simple select query join with literal values") { 61 | withTempRedshiftTable("insertTable") { tableName => 62 | redshiftWrapper.executeUpdate(conn, 63 | s"CREATE TABLE ${tableName} (a int, b int)" 64 | ) 65 | read.option("dbtable", tableName).load.createOrReplaceTempView(tableName) 66 | val pre = sqlContext.sql(s"SELECT * FROM ${tableName}").count 67 | 68 | sqlContext.sql(s"INSERT INTO TABLE ${tableName} VALUES (1, 100), (3,2000)") 69 | 70 | val post = sqlContext.sql( 71 | s"""SELECT * FROM 72 | |(VALUES (1, 1000), (2, 2000), (3, 3000)) AS v(id, name) 73 | | JOIN ${tableName} AS t 74 | | ON v.id = t.a""".stripMargin).collect().map(row => 
row.toSeq).toSeq 75 | 76 | checkSqlStatement( 77 | s"""SELECT ( "SQ_0"."ID" ) AS "SQ_3_COL_0" , 78 | | ( "SQ_0"."NAME" ) AS "SQ_3_COL_1" , 79 | | ( "SQ_2"."A" ) AS "SQ_3_COL_2" , ( "SQ_2"."B" ) AS "SQ_3_COL_3" 80 | | FROM ( ( (SELECT 1 AS "id", 1000 AS "name") UNION ALL (SELECT 2 AS "id", 2000 AS "name") 81 | | UNION ALL (SELECT 3 AS "id", 3000 AS "name") ) ) AS "SQ_0" INNER JOIN 82 | | ( SELECT * FROM ( SELECT * FROM "PUBLIC"."${tableName}" AS "RCQ_ALIAS" ) 83 | | AS "SQ_1" WHERE ( "SQ_1"."A" IS NOT NULL ) ) AS "SQ_2" ON 84 | | ( "SQ_0"."ID" = "SQ_2"."A" )""".stripMargin 85 | ) 86 | 87 | assert(pre == 0) 88 | val expected = Array(Array(3, 3000, 3, 2000), Array(1, 1000, 1, 100)) 89 | post should contain theSameElementsAs expected 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/it/scala/io/github/spark_redshift_community/spark/redshift/OverrideNullableSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package io.github.spark_redshift_community.spark.redshift.test 17 | 18 | import io.github.spark_redshift_community.spark.redshift.Parameters 19 | import org.apache.spark.sql._ 20 | import org.apache.spark.sql.types._ 21 | 22 | trait OverrideNullableSuite extends IntegrationSuiteBase { 23 | test("read empty strings as null when overridenullable is true") { 24 | withTempRedshiftTable("overrideNullable") { name => 25 | redshiftWrapper.executeUpdate( 26 | conn, s"create table $name (name text not null, nullable_name text)") 27 | redshiftWrapper.executeUpdate(conn, s"insert into $name values ('', '')") 28 | val df = read 29 | .option(Parameters.PARAM_OVERRIDE_NULLABLE, true) 30 | .option("dbtable", name) 31 | .load 32 | checkAnswer( 33 | df, 34 | Seq(Row(null, null)) 35 | ) 36 | assert(df.schema match { 37 | case StructType(Array(StructField(_, StringType, true, _), 38 | StructField(_, StringType, true, _))) => true 39 | case _ => false 40 | }) 41 | } 42 | } 43 | 44 | test("read empty strings as null when overridenullable is true and unload_s3_format is TEXT") { 45 | withTempRedshiftTable("overrideNullable") { name => 46 | redshiftWrapper.executeUpdate( 47 | conn, s"create table $name (name text not null, nullable_name text)") 48 | redshiftWrapper.executeUpdate(conn, s"insert into $name values ('', '')") 49 | val df = read 50 | .option(Parameters.PARAM_OVERRIDE_NULLABLE, true) 51 | .option("dbtable", name) 52 | .option("unload_s3_format", "TEXT") 53 | .load 54 | checkAnswer( 55 | df, 56 | Seq(Row(null, null)) 57 | ) 58 | assert(df.schema match { 59 | case StructType(Array(StructField(_, StringType, true, _), 60 | StructField(_, StringType, true, _))) => true 61 | case _ => false 62 | }) 63 | } 64 | } 65 | 66 | test("read empty strings as empty strings when overridenullable is false") { 67 | 
withTempRedshiftTable("overrideNullable") { name => 68 | redshiftWrapper.executeUpdate( 69 | conn, s"create table $name (name text not null, nullable_name text)") 70 | redshiftWrapper.executeUpdate(conn, s"insert into $name values ('', '')") 71 | val df = read 72 | .option(Parameters.PARAM_OVERRIDE_NULLABLE, false) 73 | .option("dbtable", name) 74 | .load 75 | checkAnswer( 76 | df, 77 | Seq(Row("", "")) 78 | ) 79 | assert(df.schema match { 80 | case StructType(Array(StructField(_, StringType, false, _), 81 | StructField(_, StringType, true, _))) => true 82 | case _ => false 83 | }) 84 | } 85 | } 86 | 87 | test("read empty strings as empty strings when overridenullable" + 88 | " is false and unload_s3_format is TEXT") { 89 | withTempRedshiftTable("overrideNullable") { name => 90 | redshiftWrapper.executeUpdate( 91 | conn, s"create table $name (name text not null, nullable_name text)") 92 | redshiftWrapper.executeUpdate(conn, s"insert into $name values ('', '')") 93 | val df = read 94 | .option(Parameters.PARAM_OVERRIDE_NULLABLE, false) 95 | .option("dbtable", name) 96 | .option("unload_s3_format", "TEXT") 97 | .load 98 | checkAnswer( 99 | df, 100 | Seq(Row("", "")) 101 | ) 102 | assert(df.schema match { 103 | case StructType(Array(StructField(_, StringType, false, _), 104 | StructField(_, StringType, true, _))) => true 105 | case _ => false 106 | }) 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/AggregationStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015-2018 Snowflake Computing 3 | * Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package io.github.spark_redshift_community.spark.redshift.pushdown.querygeneration 19 | 20 | import io.github.spark_redshift_community.spark.redshift.{RedshiftFailMessage, RedshiftPushdownUnsupportedException, Utils} 21 | import io.github.spark_redshift_community.spark.redshift.pushdown.{ConstantString, EmptyRedshiftSQLStatement, RedshiftSQLStatement} 22 | import org.apache.spark.sql.catalyst.expressions._ 23 | import org.apache.spark.sql.catalyst.expressions.aggregate._ 24 | import org.apache.spark.sql.types.{BooleanType, DecimalType, DoubleType, FloatType} 25 | 26 | import scala.language.postfixOps 27 | 28 | /** 29 | * Extractor for aggregate-style expressions. 30 | */ 31 | private[querygeneration] object AggregationStatement { 32 | def unapply( 33 | expAttr: (Expression, Seq[Attribute]) 34 | ): Option[RedshiftSQLStatement] = { 35 | val expr = expAttr._1 36 | val fields = expAttr._2 37 | 38 | expr match { 39 | case _: AggregateExpression => 40 | // Take only the first child, as all of the functions below have only one. 
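// DISTINCT is detected from the expression's SQL text (e.g. COUNT(DISTINCT col)) and re-emitted inside the pushed-down aggregate.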
41 | expr.children.headOption.flatMap(agg_fun => { 42 | val distinct: RedshiftSQLStatement = 43 | if (expr.sql contains "(DISTINCT ") ConstantString("DISTINCT") ! 44 | else EmptyRedshiftSQLStatement() 45 | Option(agg_fun match { 46 | case Max(child) if child.dataType == BooleanType => 47 | throw new RedshiftPushdownUnsupportedException( 48 | RedshiftFailMessage.FAIL_PUSHDOWN_AGGREGATE_EXPRESSION, 49 | s"${agg_fun.prettyName} @ AggregationStatement", 50 | "MAX(Boolean) is not defined in redshift", 51 | true 52 | ) 53 | case Min(child) if child.dataType == BooleanType => 54 | throw new RedshiftPushdownUnsupportedException( 55 | RedshiftFailMessage.FAIL_PUSHDOWN_AGGREGATE_EXPRESSION, 56 | s"${agg_fun.prettyName} @ AggregationStatement", 57 | "MIN(Boolean) is not defined in redshift", 58 | true 59 | ) 60 | case _: Count | _: Max | _: Min | _: Sum | _: StddevSamp | _: StddevPop | 61 | _: VariancePop | _: VarianceSamp => 62 | ConstantString(agg_fun.prettyName.toUpperCase) + 63 | blockStatement( 64 | distinct + convertStatements(fields, agg_fun.children: _*) 65 | ) 66 | case avg: Average => 67 | // Type casting is needed if column type is short, int, long or decimal with scale 0. 68 | // Because Redshift and Spark have different behavior on AVG on these types, type 69 | // should be casted to float to keep result numbers after decimal point. 70 | val doCast: Boolean = avg.child.dataType match { 71 | case _: FloatType | DoubleType => false 72 | case d: DecimalType if d.scale != 0 => false 73 | case _ => true 74 | } 75 | ConstantString(agg_fun.prettyName.toUpperCase) + 76 | blockStatement( 77 | distinct + convertStatements(fields, agg_fun.children: _*) + 78 | (if (doCast) ConstantString("::FLOAT") ! else EmptyRedshiftSQLStatement()) 79 | ) 80 | case _ => 81 | // This exception is not a real issue. It will be caught in 82 | // QueryBuilder.treeRoot. 83 | throw new RedshiftPushdownUnsupportedException( 84 | RedshiftFailMessage.FAIL_PUSHDOWN_AGGREGATE_EXPRESSION, 85 | s"${agg_fun.prettyName} @ AggregationStatement", 86 | agg_fun.sql, 87 | false 88 | ) 89 | }) 90 | }) 91 | case _ => None 92 | } 93 | } 94 | } 95 | --------------------------------------------------------------------------------