├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY-LICENSES ├── docs └── images │ └── s3-connector-overview.png ├── pom.xml ├── scalastyle-config.xml └── src ├── main ├── resources │ ├── META-INF │ │ └── services │ │ │ └── org.apache.spark.sql.sources.DataSourceRegister │ └── log4j.properties └── scala │ ├── com │ └── amazonaws │ │ └── spark │ │ └── sql │ │ └── streaming │ │ └── connector │ │ ├── ConnectorAwsCredentialsProvider.scala │ │ ├── S3ConnectorException.scala │ │ ├── S3ConnectorFileCache.scala │ │ ├── S3ConnectorFileValidator.scala │ │ ├── S3ConnectorModel.scala │ │ ├── S3ConnectorSource.scala │ │ ├── S3ConnectorSourceOptions.scala │ │ ├── S3ConnectorSourceProvider.scala │ │ ├── Utils.scala │ │ ├── client │ │ ├── AsyncClientBuilder.scala │ │ ├── AsyncClientMetrics.scala │ │ ├── AsyncQueueClient.scala │ │ ├── AsyncSqsClientBuilder.scala │ │ ├── AsyncSqsClientImpl.scala │ │ └── AsyncSqsClientMetricsImpl.scala │ │ └── metadataLog │ │ ├── RocksDBS3SourceLog.scala │ │ └── S3MetadataLog.scala │ └── org │ └── apache │ └── spark │ └── sql │ └── streaming │ └── connector │ └── s3 │ ├── RocksDB.scala │ ├── RocksDBFileManager.scala │ ├── RocksDBLoader.scala │ ├── RocksDBStateEncoder.scala │ └── S3SparkUtils.scala └── test ├── java └── it │ └── spark │ └── sql │ └── streaming │ └── connector │ └── IntegrationTestSuite.java ├── resources └── log4j.properties └── scala ├── com └── amazonaws │ └── spark │ └── sql │ └── streaming │ └── connector │ ├── S3ConnectorFileCacheSuite.scala │ ├── S3ConnectorFileValidatorSuite.scala │ ├── S3ConnectorSourceOptionsSuite.scala │ ├── S3ConnectorTestBase.scala │ ├── TestUtils.scala │ ├── client │ └── AsyncSqsClientSuite.scala │ └── metadataLog │ └── RocksDBS3SourceLogSuite.scala ├── it └── spark │ └── sql │ └── streaming │ └── connector │ ├── ItTestUtils.scala │ ├── QueueTestBase.scala │ ├── S3ConnectorItBase.scala │ ├── S3ConnectorSourceCrossAccountItSuite.scala │ ├── S3ConnectorSourceItSuite.scala │ ├── TestForeachWriter.scala │ └── client │ └── AsyncSqsClientItSuite.scala └── pt └── spark └── sql └── streaming └── connector ├── DataConsumer.scala ├── DataGenerator.scala ├── DataValidator.scala ├── FileSourceConsumer.scala └── TestTool.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | .DS_Store 3 | .idea/ 4 | .idea_modules/ 5 | spark-warehouse/ 6 | target/ 7 | checkpoints/ 8 | s3checkpoints/ 9 | dependency-reduced-pom.xml 10 | testDir/ 11 | ptcheckpoint*/ 12 | ptconsumer*/ 13 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Apache Spark Structured Streaming S3 Connector 2 | 3 | An Apache Spark Structured Streaming S3 connector for reading S3 files using Amazon S3 event notifications delivered to AWS SQS. 4 | 5 | ## Architecture Overview 6 | 7 | ![s3-connector](./docs/images/s3-connector-overview.png) 8 | 9 | 1. Configure [Amazon S3 Event Notifications](https://docs.aws.amazon.com/AmazonS3/latest/userguide/NotificationHowTo.html) to send `s3:ObjectCreated:*` events with the specified prefix to SQS 10 | 2. The S3 connector discovers new files via `ObjectCreated` S3 events in AWS SQS. 11 | 3. The files' metadata is persisted in RocksDB in the checkpoint location, together with the offset maintained by the Spark Structured Streaming engine. This ensures that data is ingested exactly once. (End-to-end exactly-once delivery requires the data sink to be idempotent.) 12 | 4. The driver distributes the S3 file list to executors 13 | 5. Executors read the S3 files 14 | 6. After the data sink processes the batch successfully, the Spark Structured Streaming engine commits the batch 15 | 16 | The RocksDB instance used by this connector is self-contained. A Spark Structured Streaming application using this connector is free to use any state store backend. 17 | 18 | ## How to build 19 | **Prerequisite**: [install RocksDB](https://github.com/facebook/rocksdb/blob/main/INSTALL.md) 20 | 21 | Clone `spark-streaming-sql-s3-connector` from the source repository on GitHub. 22 | 23 | ``` 24 | git clone https://github.com/aws-samples/spark-streaming-sql-s3-connector.git 25 | mvn clean install -DskipTests 26 | ``` 27 | This creates the *target/spark-streaming-sql-s3-connector-.jar* file, which contains the connector code and its dependencies. The jar file is also installed into the local Maven repository. 28 | 29 | The jar file can also be downloaded from https://awslabs-code-us-east-1.s3.amazonaws.com/spark-streaming-sql-s3-connector/spark-streaming-sql-s3-connector-0.0.1.jar. Adjust the jar file name to match the version. 30 | 31 | The current version is compatible with Spark 3.2 and above. 32 | 33 | ## How to test 34 | 35 | * To run the unit tests 36 | ``` 37 | mvn test 38 | ``` 39 | 40 | * To run the integration tests 41 | * Export the following environment variables with your values: 42 | ``` 43 | export AWS_ACCESS_KEY_ID="" 44 | export AWS_SECRET_ACCESS_KEY="" 45 | export AWS_REGION= 46 | 47 | export TEST_UPLOAD_S3_PATH= 48 | export TEST_REGION= 49 | export TEST_QUEUE_URL=/> 50 | 51 | export CROSS_ACCOUNT_TEST_UPLOAD_S3_PATH= 52 | export CROSS_ACCOUNT_TEST_REGION= 53 | export CROSS_ACCOUNT_TEST_QUEUE_URL= 54 | ``` 55 | * Run `mvn test -Pintegration-test`, or `mvn test -Pintegration-test -Dsuites='it.spark.sql.streaming.connector.S3ConnectorSourceSqsRocksDBItSuite'` to run S3ConnectorSourceSqsRocksDBItSuite only 56 | 57 | To set up cross-account access (the following assumes the S3 connector runs in account A and accesses S3 and SQS in account B): 58 | 59 | 1. 
Add the following to account B's S3 bucket policy 60 | ```json 61 | { 62 | "Version": "2012-10-17", 63 | "Statement": [ 64 | { 65 | "Sid": "cross account bucket", 66 | "Effect": "Allow", 67 | "Principal": { 68 | "AWS": [ 69 | "arn:aws:iam:::user/" 70 | ] 71 | }, 72 | "Action": [ 73 | "s3:GetLifecycleConfiguration", 74 | "s3:ListBucket" 75 | ], 76 | "Resource": "arn:aws:s3:::" 77 | }, 78 | { 79 | "Sid": "cross account object", 80 | "Effect": "Allow", 81 | "Principal": { 82 | "AWS": [ 83 | "arn:aws:iam:::user/" 84 | ] 85 | }, 86 | "Action": [ 87 | "s3:*" 88 | ], 89 | "Resource": "arn:aws:s3:::/*" 90 | } 91 | ] 92 | } 93 | ``` 94 | 95 | 2. Create a new SQS queue and add the following to the SQS access policy 96 | ```json 97 | { 98 | "Sid": "__crossaccount_statement", 99 | "Effect": "Allow", 100 | "Principal": { 101 | "AWS": [ 102 | "arn:aws:iam:::user/" 103 | ] 104 | }, 105 | "Action": [ 106 | "sqs:ChangeMessageVisibility", 107 | "sqs:DeleteMessage", 108 | "sqs:GetQueueAttributes", 109 | "sqs:PurgeQueue", 110 | "sqs:ReceiveMessage", 111 | "sqs:SendMessage" 112 | ], 113 | "Resource": "" 114 | } 115 | ``` 116 | 117 | 3. Configure account B's S3 bucket to send event notifications to the new SQS queue 118 | 119 | ## How to use 120 | 121 | After the connector jar is installed in the local Maven repository, configure your project's pom.xml (Maven is used as an example): 122 | 123 | ```xml 124 | <dependency> 125 |   <groupId>com.amazonaws</groupId> 126 |   <artifactId>spark-streaming-sql-s3-connector</artifactId> 127 |   <version>{version}</version> 128 | </dependency> 129 | ``` 130 | 131 | Code example: 132 | 133 | ```scala 134 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions._ 135 | 136 | val connectorOptions = spark.sqlContext.getAllConfs ++ Map( 137 | QUEUE_REGION -> "", 138 | S3_FILE_FORMAT -> "csv", 139 | MAX_FILES_PER_TRIGGER -> "500", 140 | MAX_FILE_AGE -> "15d", 141 | QUEUE_URL -> "/ >", 142 | QUEUE_FETCH_WAIT_TIMEOUT_SECONDS -> "10", 143 | SQS_LONG_POLLING_WAIT_TIME_SECONDS -> "5", 144 | SQS_VISIBILITY_TIMEOUT_SECONDS -> "60", 145 | PATH_GLOB_FILTER -> "*.csv", 146 | PARTITION_COLUMNS -> "valPartition", 147 | BASE_PATH -> "root path of S3 files" 148 | ) 149 | 150 | val testSchemaWithPartition: StructType = StructType(Array( 151 | StructField("valString", StringType, nullable = true), 152 | StructField("valBoolean", BooleanType, nullable = true), 153 | StructField("valDouble", DoubleType, nullable = true), 154 | StructField("valInt", IntegerType, nullable = true), 155 | StructField("valPartition", StringType, nullable = false), 156 | )) 157 | 158 | val inputDf = spark 159 | .readStream 160 | .format(SOURCE_SHORT_NAME) 161 | .schema(testSchemaWithPartition) 162 | .options(connectorOptions) 163 | .load() 164 | ``` 165 | 166 | A full running example is at `src/test/scala/pt/spark/sql/streaming/connector/DataConsumer.scala`, which is included in `spark-streaming-sql-s3-connector--tests.jar`. 167 | 168 | Use the following command to submit to Spark on Amazon EMR (this assumes `spark-streaming-sql-s3-connector-.jar` and `spark-streaming-sql-s3-connector--tests.jar` have been copied to the EMR primary node and are in the current directory). 169 | 170 | ```bash 171 | spark-submit --class pt.spark.sql.streaming.connector.DataConsumer --deploy-mode cluster --jars spark-streaming-sql-s3-connector-.jar spark-streaming-sql-s3-connector--tests.jar csv 172 | ``` 173 | This application receives S3 event notifications from ``, reads the new files from ``, and saves the result to ``. 174 | 175 | If run with OSS Spark, `spark-submit` needs additional s3a-related configurations, e.g. 
176 | 177 | ``` 178 | --conf spark.hadoop.fs.s3.impl=org.apache.hadoop.fs.s3a.S3AFileSystem --conf spark.hadoop.fs.s3a.aws.credentials.provider=com.amazonaws.auth.EnvironmentVariableCredentialsProvider 179 | ``` 180 | 181 | Note: `spark.hadoop.fs.s3.impl` (instead of `spark.hadoop.fs.s3a.impl`) is used so that s3a can read `s3://`-prefixed file paths. 182 | 183 | Run the command below to generate the S3 test files to be consumed by `pt.spark.sql.streaming.connector.DataConsumer`: 184 | 185 | ```bash 186 | spark-submit --class pt.spark.sql.streaming.connector.DataGenerator --jars ~/spark-streaming-sql-s3-connector--SNAPSHOT.jar ~/spark-streaming-sql-s3-connector--SNAPSHOT-tests.jar 187 | ``` 188 | 189 | ## How to configure 190 | The Spark Structured Streaming S3 connector supports the following settings. 191 | 192 | Name | Default | Description 193 | --- |:----------------------------------------| --- 194 | spark.s3conn.fileFormat| required, no default value |file format of the files stored on Amazon S3 195 | spark.s3conn.queueRegion| required, no default value |AWS region where the queue is created 196 | spark.s3conn.queueUrl| required, no default value |SQS queue URL, e.g. https://sqs.us-east-1.amazonaws.com// 197 | spark.s3conn.queueType| required, SQS |only SQS is supported 198 | spark.s3conn.queueFetchWaitTimeoutSeconds| required, 2 * longPollingWaitTimeSeconds |wait time (in seconds) for fetching messages from SQS at each trigger. Message fetching finishes when either the number of fetched messages exceeds maxFilesPerTrigger or queueFetchWaitTimeoutSeconds expires. 199 | spark.s3conn.maxFilesPerTrigger| required, 100 |maximum number of files to process in a microbatch. -1 for unlimited 200 | spark.s3conn.maxFileAge| required, 15d |maximum age of a file that can be stored in RocksDB. Files older than this will be ignored. 201 | spark.s3conn.pathGlobFilter| optional |only include S3 files with file names matching the pattern. 202 | spark.s3conn.partitionColumns| optional |comma-separated partition columns. Partition columns must be defined in the schema. Use together with the "basePath" option to read from an S3 folder with partitions. For example, for the file s3:///testdatarootpath/part1=A/part2=B/testdata.csv, set "spark.s3conn.partitionColumns" -> "part1,part2", "basePath" -> "s3:///testdatarootpath/" 203 | spark.s3conn.reprocessStartBatchId| optional |start batch id for a reprocess run (inclusive). Note: a reprocess run will not consume new messages from SQS. 204 | spark.s3conn.reprocessEndBatchId| optional |end batch id for a reprocess run (inclusive) 205 | spark.s3conn.reprocessDryRun| optional, true |a dry run that lists all the files to be reprocessed. 206 | spark.s3conn.sqs.longPollingWaitTimeSeconds| 10 |wait time (in seconds) for SQS client long polling 207 | spark.s3conn.sqs.maxConcurrency| 50 |number of parallel connections to the Amazon SQS queue 208 | spark.s3conn.sqs.maxRetries| 10 |maximum number of consecutive retries in case of an SQS client connection failure before giving up 209 | spark.s3conn.sqs.visibilityTimeoutSeconds| 60 |SQS message visibility timeout (in seconds) 210 | spark.s3conn.sqs.keepMessageForConsumerError| false |when set to true, messages that are invalid for the following reasons are kept in SQS: 1. the file is expired; 2. the file doesn't match the glob pattern; 3. the file is already processed and persisted in RocksDB. 
This can be used with DLQ when for debug purpose 211 | 212 | ## How to use S3 event notifications for multiple applications 213 | 214 | If one S3 path's event notifications need to be consumed by multiple Spark Structured Streaming applications, SNS can be used to fanout to Amazon SQS queues. The message flow is S3 event notifications -> SNS -> SQS. When an S3 event notification is published to the SNS topic, Amazon SNS sends the notification to each of the subscribed SQS queues. 215 | 216 | ## Security 217 | 218 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 219 | 220 | ## Acknowledgement 221 | 222 | Reference implementation [Apache Bahir Spark SQL Streaming Amazon SQS Data Source](https://github.com/apache/bahir/tree/master/sql-streaming-sqs). 223 | 224 | RocksDB related code is reusing the work done by [Apache Spark](https://github.com/apache/spark). 225 | -------------------------------------------------------------------------------- /docs/images/s3-connector-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/spark-streaming-sql-s3-connector/fbc6de2cee2017e2a72b22e13bcf699be6298bba/docs/images/s3-connector-overview.png -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 18 | 19 | 4.0.0 20 | 21 | com.amazonaws 22 | spark-streaming-sql-s3-connector 23 | 0.0.2 24 | jar 25 | Spark Structure Streaming S3 Connector 26 | 27 | 28 | s3-sqs-connector 29 | 3.2.1 30 | 3.0.1 31 | 2.12 32 | 4.6.3 33 | 2.0.2 34 | UTF-8 35 | 2.19.21 36 | 4.1.36 37 | 8 38 | 8 39 | 40 | 41 | 42 | 43 | 44 | software.amazon.awssdk 45 | bom 46 | ${aws.java.sdk2.version} 47 | pom 48 | import 49 | 50 | 51 | 52 | 53 | 54 | 55 | software.amazon.awssdk 56 | sqs 57 | ${aws.java.sdk2.version} 58 | 59 | 60 | software.amazon.awssdk 61 | sts 62 | ${aws.java.sdk2.version} 63 | 64 | 65 | software.amazon.awssdk 66 | netty-nio-client 67 | ${aws.java.sdk2.version} 68 | 69 | 70 | io.dropwizard.metrics 71 | metrics-core 72 | ${metrics.version} 73 | 74 | 75 | org.apache.spark 76 | spark-sql_${scala.binary.version} 77 | ${spark.version} 78 | provided 79 | 80 | 81 | org.apache.spark 82 | spark-tags_${scala.binary.version} 83 | ${spark.version} 84 | provided 85 | 86 | 87 | org.apache.spark 88 | spark-core_${scala.binary.version} 89 | ${spark.version} 90 | test-jar 91 | test 92 | 93 | 94 | org.apache.spark 95 | spark-catalyst_${scala.binary.version} 96 | ${spark.version} 97 | test-jar 98 | test 99 | 100 | 101 | org.apache.spark 102 | spark-sql_${scala.binary.version} 103 | ${spark.version} 104 | test-jar 105 | test 106 | 107 | 108 | org.apache.hadoop 109 | hadoop-aws 110 | ${hadoop.version} 111 | 112 | 113 | com.amazonaws 114 | aws-java-sdk-core 115 | 116 | 117 | com.amazonaws 118 | aws-java-sdk-sqs 119 | 120 | 121 | test 122 | 123 | 124 | org.scalatest 125 | scalatest_${scala.binary.version} 126 | 3.2.9 127 | test 128 | 129 | 130 | org.scalatestplus 131 | scalacheck-1-15_${scala.binary.version} 132 | 3.2.9.0 133 | test 134 | 135 | 136 | org.scalatestplus 137 | mockito-3-4_${scala.binary.version} 138 | 3.2.9.0 139 | test 140 | 141 | 142 | software.amazon.awssdk 143 | s3 144 | 145 | 146 | software.amazon.awssdk 147 | netty-nio-client 148 | 149 | 150 | software.amazon.awssdk 151 | apache-client 152 | 153 | 154 | ${aws.java.sdk2.version} 155 | test 156 | 157 | 158 | software.amazon.awssdk 
159 | apache-client 160 | 161 | 162 | commons-logging 163 | commons-logging 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | net.alchim31.maven 173 | scala-maven-plugin 174 | ${scala.maven.version} 175 | 176 | 177 | compile 178 | 179 | compile 180 | add-source 181 | doc-jar 182 | testCompile 183 | 184 | compile 185 | 186 | 187 | 188 | 189 | org.apache.maven.plugins 190 | maven-shade-plugin 191 | 3.2.1 192 | 193 | 194 | package 195 | 196 | shade 197 | 198 | 199 | 200 | 201 | software.amazon.awssdk:sqs:* 202 | software.amazon.awssdk:sdk-core:* 203 | software.amazon.awssdk:utils:* 204 | software.amazon.awssdk:annotations:* 205 | software.amazon.awssdk:apache-client:* 206 | software.amazon.awssdk:arns:* 207 | software.amazon.awssdk:auth:* 208 | software.amazon.awssdk:sts:* 209 | software.amazon.awssdk:netty-nio-client:* 210 | software.amazon.awssdk:http-auth-spi:* 211 | software.amazon.awssdk:http-auth-aws:* 212 | software.amazon.awssdk:http-auth:* 213 | software.amazon.awssdk:aws-core:* 214 | software.amazon.awssdk:aws-query-protocol:* 215 | software.amazon.awssdk:aws-xml-protocol:* 216 | software.amazon.awssdk:endpoints-spi:* 217 | software.amazon.awssdk:http-client-spi:* 218 | software.amazon.awssdk:json-utils:* 219 | software.amazon.awssdk:metrics-spi:* 220 | software.amazon.awssdk:netty-nio-client:* 221 | software.amazon.awssdk:profiles:* 222 | software.amazon.awssdk:protocol-core:* 223 | software.amazon.awssdk:regions:* 224 | software.amazon.awssdk:third-party-jackson-core:* 225 | io.dropwizard.metrics:metrics-core:* 226 | org.reactivestreams:reactive-streams:* 227 | io.netty:netty-common:* 228 | io.netty:netty-buffer:* 229 | io.netty:netty-codec:* 230 | io.netty:netty-codec-http:* 231 | io.netty:netty-codec-http2:* 232 | io.netty:netty-handler:* 233 | io.netty:netty-resolver:* 234 | io.netty:netty-transport:* 235 | io.netty:netty-transport-classes-epoll:* 236 | io.netty:netty-transport-native-unix-common:* 237 | 238 | 239 | 240 | 241 | *:* 242 | 243 | META-INF/maven/** 244 | META-INF/MANIFEST.MF 245 | 246 | 247 | 248 | 249 | 250 | software.amazon.awssdk 251 | s3connector.software.amazon.awssdk 252 | 253 | 254 | com.codahale.metrics 255 | s3connector.com.codahale.metrics 256 | 257 | 258 | org.reactivestreams 259 | s3connector.org.reactivestreams 260 | 261 | 262 | io.netty 263 | s3connector.io.netty 264 | 265 | 266 | 267 | 268 | 269 | log4j.properties 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | org.apache.maven.plugins 280 | maven-jar-plugin 281 | 3.1.0 282 | 283 | 284 | 285 | test-jar 286 | 287 | 288 | 289 | 290 | 291 | org.scalatest 292 | scalatest-maven-plugin 293 | ${scalatest-maven-plugin.version} 294 | 295 | 296 | 297 | 298 | 299 | org.apache.maven.plugins 300 | maven-shade-plugin 301 | 302 | 303 | org.apache.maven.plugins 304 | maven-jar-plugin 305 | 306 | 307 | target/scala-${scala.binary.version}/classes 308 | target/scala-${scala.binary.version}/test-classes 309 | 310 | 311 | 312 | 313 | default-test 314 | 315 | true 316 | 317 | 318 | 319 | 320 | net.alchim31.maven 321 | scala-maven-plugin 322 | 323 | 324 | org.scalatest 325 | scalatest-maven-plugin 326 | 327 | ${project.build.directory}/unittest-reports 328 | . 
329 | WDF UnitTestSuite.txt 330 | it.spark.sql.streaming.connector.IntegrationTestSuite 331 | 332 | 333 | 334 | test 335 | 336 | test 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | integration-test 346 | 347 | false 348 | 349 | 350 | 351 | 352 | net.alchim31.maven 353 | scala-maven-plugin 354 | 355 | 356 | org.scalatest 357 | scalatest-maven-plugin 358 | 359 | ${project.build.directory}/integrationtest-reports 360 | . 361 | WDF IntegrationTestSuite.txt 362 | it.spark.sql.streaming.connector.IntegrationTestSuite 363 | 364 | 365 | 366 | test 367 | 368 | test 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 17 | 39 | 40 | 41 | Scalastyle standard configuration 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | true 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW 126 | 127 | 128 | 129 | 130 | 131 | ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | ^FunSuite[A-Za-z]*$ 141 | Tests must extend org.apache.spark.SparkFunSuite instead. 142 | 143 | 144 | 145 | 146 | ^println$ 147 | 151 | 152 | 153 | 154 | @VisibleForTesting 155 | 158 | 159 | 160 | 161 | Runtime\.getRuntime\.addShutdownHook 162 | 170 | 171 | 172 | 173 | mutable\.SynchronizedBuffer 174 | 182 | 183 | 184 | 185 | Class\.forName 186 | 193 | 194 | 195 | 196 | Await\.result 197 | 204 | 205 | 206 | 207 | Await\.ready 208 | 215 | 216 | 217 | 218 | 219 | JavaConversions 220 | Instead of importing implicits in scala.collection.JavaConversions._, import 221 | scala.collection.JavaConverters._ and use .asScala / .asJava methods 222 | 223 | 224 | 225 | org\.apache\.commons\.lang\. 226 | Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead 227 | of Commons Lang 2 (package org.apache.commons.lang.*) 228 | 229 | 230 | 231 | extractOpt 232 | Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter 233 | is slower. 234 | 235 | 236 | 237 | 238 | java,scala,3rdParty,spark 239 | javax?\..* 240 | scala\..* 241 | (?!org\.apache\.spark\.).* 242 | org\.apache\.spark\..* 243 | 244 | 245 | 246 | 247 | 248 | COMMA 249 | 250 | 251 | 252 | 253 | 254 | \)\{ 255 | 258 | 259 | 260 | 261 | (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] 262 | Use Javadoc style indentation for multiline comments 263 | 264 | 265 | 266 | case[^\n>]*=>\s*\{ 267 | Omit braces in case clauses. 
268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 800> 321 | 322 | 323 | 324 | 325 | 30 326 | 327 | 328 | 329 | 330 | 30 331 | 332 | 333 | 334 | 335 | 50 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | -1,0,1,2,3 347 | 348 | 349 | 350 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceProvider -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | log4j.rootCategory=INFO, console 19 | 20 | # File appender 21 | log4j.appender.file=org.apache.log4j.FileAppender 22 | log4j.appender.file.append=false 23 | log4j.appender.file.file=target/unit-tests.log 24 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 25 | log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n 26 | 27 | # Console appender 28 | log4j.appender.console=org.apache.log4j.ConsoleAppender 29 | log4j.appender.console.target=System.out 30 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 31 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 32 | 33 | # Settings to quiet third party logs that are too verbose 34 | log4j.logger.org.sparkproject.jetty=WARN 35 | log4j.logger.org.sparkproject.jetty.util.component.AbstractLifeCycle=ERROR 36 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 37 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 38 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/ConnectorAwsCredentialsProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector 18 | 19 | import java.io.Closeable 20 | 21 | import scala.annotation.tailrec 22 | import scala.util.{Failure, Success, Try} 23 | 24 | import software.amazon.awssdk.auth.credentials.{AwsCredentials, AwsCredentialsProvider, DefaultCredentialsProvider} 25 | 26 | /** 27 | * Serializable interface providing a method executors can call to obtain an 28 | * AWSCredentialsProvider instance for authenticating to AWS services. 
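 *
 * Usage sketch (illustrative only, based on the builder and trait defined below):
 * {{{
 *   val connectorProvider = ConnectorAwsCredentialsProvider.builder.build()
 *   val awsCredentials = connectorProvider.provider.resolveCredentials()
 *   // pass connectorProvider.provider to an AWS SDK v2 client builder, then
 *   connectorProvider.close()
 * }}}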
29 | */ 30 | sealed trait ConnectorAwsCredentialsProvider extends Serializable with Closeable { 31 | def provider: AwsCredentialsProvider 32 | override def close(): Unit = {} 33 | } 34 | 35 | case class RetryableDefaultCredentialsProvider() extends AwsCredentialsProvider with Closeable { 36 | // private val provider = DefaultCredentialsProvider.create() 37 | 38 | private val provider = DefaultCredentialsProvider.builder() 39 | .asyncCredentialUpdateEnabled(true) 40 | .build() 41 | 42 | private val MAX_ATTEMPT = 10 43 | private val SLEEP_TIME = 1000 44 | 45 | override def resolveCredentials(): AwsCredentials = { 46 | @tailrec 47 | def getCredentialsWithRetry(retries: Int): AwsCredentials = { 48 | Try { 49 | provider.resolveCredentials() 50 | } match { 51 | case Success(credentials) => 52 | credentials 53 | case Failure(_) if retries > 0 => 54 | Thread.sleep(SLEEP_TIME) 55 | getCredentialsWithRetry(retries - 1) // Recursive call to retry 56 | case Failure(exception) => 57 | throw exception 58 | } 59 | } 60 | 61 | getCredentialsWithRetry(MAX_ATTEMPT) 62 | } 63 | 64 | override def close(): Unit = { 65 | provider.close() 66 | } 67 | } 68 | 69 | case class ConnectorDefaultCredentialsProvider() extends ConnectorAwsCredentialsProvider { 70 | 71 | private var providerOpt: Option[RetryableDefaultCredentialsProvider] = None 72 | override def provider: AwsCredentialsProvider = { 73 | if (providerOpt.isEmpty) { 74 | providerOpt = Some(RetryableDefaultCredentialsProvider()) 75 | } 76 | providerOpt.get 77 | } 78 | 79 | override def close(): Unit = { 80 | providerOpt.foreach(_.close()) 81 | } 82 | } 83 | 84 | 85 | class Builder { 86 | def build(): ConnectorAwsCredentialsProvider = { 87 | ConnectorDefaultCredentialsProvider() 88 | } 89 | } 90 | 91 | object ConnectorAwsCredentialsProvider { 92 | def builder: Builder = new Builder 93 | } 94 | 95 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorException.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | class S3ConnectorException (msg: String) extends RuntimeException(msg) {} 21 | class S3ConnectorNoSchemaException (msg: String = "Schema not defined.") extends S3ConnectorException(msg) {} 22 | class S3ConnectorUnsupportedQueueTypeException (msg: String) extends S3ConnectorException(msg) {} 23 | class S3ConnectorMetalogAddException (msg: String) extends S3ConnectorException(msg) {} 24 | class S3ConnectorReprocessException (msg: String) extends S3ConnectorException(msg) {} 25 | class S3ConnectorReprocessDryRunException (msg: String) extends S3ConnectorReprocessException(msg) {} 26 | class S3ConnectorReprocessLockExistsException (msg: String) extends S3ConnectorReprocessDryRunException(msg) {} -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorFileCache.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | import java.util.concurrent.ConcurrentHashMap 21 | import java.util.concurrent.atomic.AtomicLong 22 | 23 | import scala.collection.JavaConverters._ 24 | import scala.collection.mutable.ArrayBuffer 25 | import scala.util.control.Breaks.{break, breakable} 26 | 27 | import com.amazonaws.spark.sql.streaming.connector.FileCacheNewFileResults.FileCacheNewFileResult 28 | import com.amazonaws.spark.sql.streaming.connector.Utils.reportTimeTaken 29 | 30 | import org.apache.spark.internal.Logging 31 | 32 | /** 33 | * A custom hash map used to track the list of files not processed yet. This map is thread-safe. 34 | */ 35 | 36 | class S3ConnectorFileCache[T](maxFileAgeMs: Long) extends Logging { 37 | 38 | require(maxFileAgeMs >= 0) 39 | 40 | /** Mapping from file path to its message description. */ 41 | private val fileMap = new ConcurrentHashMap[String, QueueMessageDesc[T]] 42 | 43 | /** Timestamp for the last purge operation. */ 44 | def lastPurgeTimestamp: Long = _lastPurgeTimestamp.get() 45 | private val _lastPurgeTimestamp: AtomicLong = new AtomicLong(0L) 46 | 47 | /** Timestamp of the latest file. 
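 * (i.e. the maximum timestampMs seen across add/addIfAbsent calls; purge() uses it to derive the expiry threshold latestTimestamp - maxFileAgeMs)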
*/ 48 | private def latestTimestamp: Long = _latestTimestamp.get() 49 | private val _latestTimestamp: AtomicLong = new AtomicLong(0L) 50 | 51 | private def setAtomicTimestamp(atomicTs: AtomicLong, newTs: Long): Unit = { 52 | breakable { 53 | while (true) { 54 | val oldTs = atomicTs.get() 55 | if (newTs > oldTs) { 56 | val success = atomicTs.compareAndSet(oldTs, newTs) 57 | if (success) break 58 | } 59 | else { 60 | break 61 | } 62 | } 63 | } 64 | } 65 | 66 | /** Add a new file to the map. */ 67 | def add(path: String, fileStatus: QueueMessageDesc[T]): Unit = { 68 | fileMap.put(path, fileStatus) 69 | setAtomicTimestamp(_latestTimestamp, fileStatus.timestampMs) 70 | } 71 | 72 | def addIfAbsent(path: String, fileStatus: QueueMessageDesc[T]): QueueMessageDesc[T] = { 73 | val ret = fileMap.computeIfAbsent(path, _ => fileStatus) 74 | setAtomicTimestamp(_latestTimestamp, fileStatus.timestampMs) 75 | 76 | ret 77 | } 78 | 79 | def isNewFile(path: String): FileCacheNewFileResult = { 80 | val fileMsg = fileMap.get(path) 81 | val isNew = if (fileMsg == null) FileCacheNewFileResults.Ok 82 | else if (fileMsg.isProcessed) FileCacheNewFileResults.ExistInCacheProcessed 83 | else FileCacheNewFileResults.ExistInCacheNotProcessed 84 | 85 | logDebug(s"fileCache isNewFile for ${path}: ${isNew}") 86 | isNew 87 | } 88 | 89 | /** 90 | * Returns all the new files found - ignore processed and aged files. 91 | */ 92 | def getUnprocessedFiles(maxFilesPerTrigger: Option[Int], 93 | shouldSortFiles: Boolean = false): Seq[FileMetadata[T]] = { 94 | reportTimeTaken("File cache getUnprocessedFiles") { 95 | if (shouldSortFiles) { 96 | val uncommittedFiles = filterAllUnprocessedFiles() 97 | val sortedFiles = reportTimeTaken("Sorting Files") { 98 | uncommittedFiles.sortWith(_.timestampMs < _.timestampMs) 99 | } 100 | 101 | maxFilesPerTrigger match { 102 | case Some(maxFiles) => 103 | sortedFiles 104 | .filter(file => file.timestampMs >= lastPurgeTimestamp) 105 | .take(maxFiles) 106 | case None => sortedFiles 107 | } 108 | } else { 109 | maxFilesPerTrigger match { 110 | case Some(maxFiles) => filterTopUnprocessedFiles(maxFiles) 111 | case None => filterAllUnprocessedFiles() 112 | } 113 | } 114 | } 115 | } 116 | 117 | private def filterTopUnprocessedFiles(maxFilesPerTrigger: Int): List[FileMetadata[T]] = { 118 | val iterator = fileMap.asScala.iterator 119 | val uncommittedFiles = ArrayBuffer[FileMetadata[T]]() 120 | while (uncommittedFiles.length < maxFilesPerTrigger && iterator.hasNext) { 121 | val file = iterator.next() 122 | 123 | if (!file._2.isProcessed && file._2.timestampMs >= lastPurgeTimestamp) { 124 | uncommittedFiles += FileMetadata(file._1, file._2.timestampMs, file._2.messageId) 125 | } 126 | } 127 | uncommittedFiles.toList 128 | } 129 | 130 | private def filterAllUnprocessedFiles(): List[FileMetadata[T]] = { 131 | fileMap.asScala.foldLeft(List[FileMetadata[T]]()) { 132 | (list, file) => 133 | if (!file._2.isProcessed && file._2.timestampMs >= lastPurgeTimestamp) { 134 | list :+ FileMetadata[T](file._1, file._2.timestampMs, file._2.messageId) 135 | } else { 136 | list 137 | } 138 | } 139 | } 140 | 141 | /** Removes aged and processed entries and returns the number of files removed. 
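 * An entry is removed when its timestampMs is older than latestTimestamp - maxFileAgeMs, or when it is already marked as processed.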
*/ 142 | def purge(): Int = { 143 | setAtomicTimestamp(_lastPurgeTimestamp, latestTimestamp - maxFileAgeMs) 144 | var count = 0 145 | fileMap.asScala.foreach { fileEntry => 146 | if (fileEntry._2.timestampMs < lastPurgeTimestamp 147 | || fileEntry._2.isProcessed 148 | ) { 149 | fileMap.remove(fileEntry._1) 150 | count += 1 151 | } 152 | } 153 | count 154 | } 155 | 156 | /** Mark file entry as committed or already processed */ 157 | def markProcessed(path: String): Unit = { 158 | fileMap.replace(path, QueueMessageDesc( 159 | fileMap.get(path).timestampMs, isProcessed = true, fileMap.get(path).messageId)) 160 | } 161 | 162 | def size: Int = fileMap.size() 163 | 164 | } 165 | 166 | object FileCacheNewFileResults extends Enumeration { 167 | type FileCacheNewFileResult = Value 168 | val Ok, ExistInCacheNotProcessed, ExistInCacheProcessed = Value 169 | } 170 | 171 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorFileValidator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | import com.amazonaws.spark.sql.streaming.connector.FileValidResults.FileValidResult 21 | import com.amazonaws.spark.sql.streaming.connector.metadataLog.S3MetadataLog 22 | import org.apache.hadoop.fs.GlobPattern 23 | 24 | import org.apache.spark.internal.Logging 25 | 26 | class S3ConnectorFileValidator (sourceOptions: S3ConnectorSourceOptions, 27 | fileCache: S3ConnectorFileCache[_], 28 | metadataLog: S3MetadataLog) extends Logging { 29 | 30 | private val globber = sourceOptions.pathGlobFilter.map(new GlobPattern(_)) 31 | 32 | def isValidNewFile(filePath: String, timestamp: Long) : FileValidResult = { 33 | val lastPurgeTimestamp = fileCache.lastPurgeTimestamp 34 | globber.map( p => p.matches(filePath) ) match { 35 | case Some(true) | None => 36 | if (timestamp < lastPurgeTimestamp) { 37 | logInfo(s"isValidNewFile ${filePath} has ts ${timestamp} " + 38 | s"is older than ${lastPurgeTimestamp}") 39 | FileValidResults.FileExpired 40 | } else { 41 | val cacheResult = fileCache.isNewFile(filePath) 42 | 43 | if (cacheResult == FileCacheNewFileResults.ExistInCacheProcessed) { 44 | FileValidResults.ExistInCacheProcessed 45 | } else if (cacheResult == FileCacheNewFileResults.ExistInCacheNotProcessed) { 46 | FileValidResults.ExistInCacheNotProcessed 47 | } else if (!metadataLog.isNewFile(filePath, lastPurgeTimestamp)) { 48 | FileValidResults.PersistedInMetadataLog 49 | } else { 50 | FileValidResults.Ok 51 | } 52 | } 53 | case Some(false) => FileValidResults.PatternNotMatch 54 | } 55 | } 56 | } 57 | 58 | object FileValidResults extends Enumeration { 59 | type FileValidResult = Value 60 | val Ok, PatternNotMatch, FileExpired, ExistInCacheProcessed, ExistInCacheNotProcessed, PersistedInMetadataLog = Value 61 | } 62 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorModel.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector 18 | 19 | /** 20 | * A case class to store queue message description. 21 | * 22 | */ 23 | case class QueueMessageDesc[T](timestampMs: Long, 24 | isProcessed: Boolean = false, 25 | messageId: Option[T]) 26 | 27 | 28 | /** 29 | * A case class to store file metadata. 
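 * For example (illustrative values only): FileMetadata("s3://my-bucket/data/part-0001.csv", 1690000000000L, Some(queueMessageId)), where messageId carries the queue message handle used later to delete the message or adjust its visibility timeout.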
30 | * 31 | */ 32 | case class FileMetadata[T](filePath: String, 33 | timestampMs: Long, 34 | messageId: Option[T]) 35 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorSource.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector 18 | 19 | import java.net.URI 20 | 21 | import scala.util.{Failure, Success, Try} 22 | import scala.util.control.NonFatal 23 | 24 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSource.REPROCESS_LOCK_FILE 25 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions.SQS_QUEUE 26 | import com.amazonaws.spark.sql.streaming.connector.Utils.reportTimeTaken 27 | import com.amazonaws.spark.sql.streaming.connector.client.{AsyncQueueClient, AsyncSqsClientBuilder} 28 | import com.amazonaws.spark.sql.streaming.connector.metadataLog.{RocksDBS3SourceLog, S3MetadataLog} 29 | import org.apache.hadoop.fs.{FSDataOutputStream, GlobPattern, Path} 30 | 31 | import org.apache.spark.internal.Logging 32 | import org.apache.spark.sql.{DataFrame, SparkSession} 33 | import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} 34 | import org.apache.spark.sql.execution.streaming._ 35 | import org.apache.spark.sql.execution.streaming.FileStreamSource._ 36 | import org.apache.spark.sql.streaming.connector.s3.S3SparkUtils 37 | import org.apache.spark.sql.types.StructType 38 | 39 | class S3ConnectorSource[T](sparkSession: SparkSession, 40 | metadataPath: String, 41 | options: Map[String, String], 42 | userSchema: Option[StructType]) 43 | extends Source with Logging { 44 | 45 | private val sourceOptions = S3ConnectorSourceOptions(options) 46 | 47 | override def schema: StructType = userSchema.getOrElse( 48 | throw new S3ConnectorNoSchemaException 49 | ) 50 | 51 | private val fileCache = new S3ConnectorFileCache[T](sourceOptions.maxFileAgeMs) 52 | 53 | private val metadataLog: S3MetadataLog = sourceOptions.queueType match { 54 | case SQS_QUEUE => new RocksDBS3SourceLog() 55 | case _ => throw new S3ConnectorUnsupportedQueueTypeException( 56 | s"Unsupported queue type: ${sourceOptions.queueType}") 57 | } 58 | 59 | metadataLog.init(sparkSession, metadataPath, fileCache) 60 | 61 | private var metadataLogCurrentOffset = metadataLog.getLatest().map(_._1).getOrElse(-1L) 62 | 63 | private val partitionColumns: Seq[String] = sourceOptions.partitionColumns match { 64 | case Some(columns) => columns.split(",").map(_.trim) 65 | case None => Seq.empty 66 | } 67 | 68 | private val fileValidator = new 
S3ConnectorFileValidator(sourceOptions, fileCache, metadataLog) 69 | 70 | private val queueClient: AsyncQueueClient[T] = sourceOptions.queueType match { 71 | case SQS_QUEUE => 72 | var sqsClient: AsyncQueueClient[T] = null 73 | 74 | sqsClient = new AsyncSqsClientBuilder() 75 | .sourceOptions(sourceOptions) 76 | .consumer( 77 | (msg: FileMetadata[T]) => { 78 | val validateResult = fileValidator.isValidNewFile(msg.filePath, msg.timestampMs) 79 | if (validateResult == FileValidResults.Ok) { 80 | logDebug(s"SQS message consumer file added to cache: ${msg}") 81 | val msgDesc = QueueMessageDesc[T](msg.timestampMs, isProcessed = false, msg.messageId) 82 | val result = fileCache.addIfAbsent(msg.filePath, msgDesc) 83 | if (result != msgDesc) { 84 | // This could happen as the isValidNewFile check and adding to fileCache are not atomic. 85 | // Don't delete the message, let it retry instead. 86 | logWarning(s"SQS message consumer: the message was not added to cache: ${msg}, " + 87 | s"as the same path already exists in cache: ${result}. The message will be retried.") 88 | msg.messageId.map(sqsClient.setMessageVisibility(_, sourceOptions.sqsVisibilityTimeoutSeconds)) 89 | } 90 | } else if (validateResult == FileValidResults.ExistInCacheNotProcessed) { 91 | // This can happen when the filePath is not processed/persisted yet but the visibility timeout has expired. 92 | // Do not delete the message. Let it retry until the filePath is persisted. 93 | logWarning(s"SQS message consumer file already exists in cache: ${msg}.") 94 | msg.messageId.map(sqsClient.setMessageVisibility(_, sourceOptions.sqsVisibilityTimeoutSeconds)) 95 | } else { 96 | logWarning(s"SQS message consumer delete msg of invalid file: ${msg}." + 97 | s" Reason: ${validateResult}") 98 | msg.messageId.map(sqsClient.deleteInvalidMessageIfNecessary(_)) 99 | } 100 | } 101 | ) 102 | .build() 103 | sqsClient 104 | case _ => throw new S3ConnectorUnsupportedQueueTypeException( 105 | s"Unsupported queue type: ${sourceOptions.queueType}") 106 | } 107 | 108 | purgeCache() 109 | 110 | logInfo(s"maxFilesPerTrigger = ${sourceOptions.maxFilesPerTrigger}, maxFileAgeMs = ${sourceOptions.maxFileAgeMs}") 111 | 112 | for (reprocessStartId <- sourceOptions.reprocessStartBatchId; 113 | reprocessEndId <- sourceOptions.reprocessEndBatchId 114 | ) yield handleReprocessing(reprocessStartId, reprocessEndId) 115 | 116 | private def handleReprocessing(startLogId: Int, endLogId: Int): Unit = { 117 | sourceOptions.reprocessState match { 118 | case ReprocessStates.DryRun => 119 | logInfo(s"Reprocess dry run batch start ${startLogId}, end ${endLogId}. " + 120 | s"Following files to be reprocessed") 121 | val files = getMetadataLogByRange(startLogId, endLogId) 122 | files.foreach( file => logInfo(file.productIterator.mkString("\t"))) 123 | throw new S3ConnectorReprocessDryRunException(s"Got ${endLogId - startLogId + 1} batches," + 124 | s" ${files.length} files. Reprocess dry run completed. S3ConnectorReprocessDryRunException to exit.") 125 | case ReprocessStates.InAction => 126 | val fs = new Path(metadataPath).getFileSystem(sparkSession.sparkContext.hadoopConfiguration) 127 | var os: FSDataOutputStream = null 128 | try { 129 | val lockFile = new Path(metadataPath + REPROCESS_LOCK_FILE) 130 | if (fs.exists(lockFile)) { 131 | throw new S3ConnectorReprocessLockExistsException(s"${lockFile} already exists. "
+ 132 | s"Remove it and rerun the reprocessing.") 133 | } 134 | os = fs.create(lockFile) 135 | 136 | val files = getMetadataLogByRange(startLogId, endLogId) 137 | files.foreach { file => 138 | fileCache.add(file.path, QueueMessageDesc(file.timestamp, isProcessed = false, None)) 139 | } 140 | } catch { 141 | case le: S3ConnectorReprocessLockExistsException => throw le 142 | case NonFatal(e) => 143 | val reprocessException = new S3ConnectorReprocessException("Error in reprocessing") 144 | reprocessException.addSuppressed(e) 145 | throw reprocessException 146 | } 147 | finally { 148 | if (os != null) os.close() 149 | } 150 | 151 | case _ => logWarning("Reprocess skipped") 152 | 153 | } 154 | } 155 | 156 | /** 157 | * Returns the data that is between the offsets (`start`, `end`]. 158 | */ 159 | override def getBatch(start: Option[Offset], end: Offset): DataFrame = { 160 | val startOffset = start.map(FileStreamSourceOffset(_).logOffset).getOrElse(-1L) 161 | val endOffset = FileStreamSourceOffset(end).logOffset 162 | 163 | assert(startOffset <= endOffset) 164 | 165 | val files = getMetadataLogByRange(startOffset + 1, endOffset) // startOffset is exclusive 166 | 167 | logInfo(s"getBatch processing ${files.length} files from ${startOffset + 1}:$endOffset") 168 | logTrace(s"Files are:\n\t" + files.mkString("\n\t")) 169 | 170 | val newDataSource = 171 | DataSource( 172 | sparkSession, 173 | paths = files.map(f => new Path(new URI(f.path)).toString), 174 | userSpecifiedSchema = Some(schema), 175 | partitionColumns = partitionColumns, 176 | className = sourceOptions.fileFormat, 177 | options = options) 178 | 179 | S3SparkUtils.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( 180 | checkFilesExist = false), isStreaming = true)) 181 | } 182 | 183 | /* 184 | * both startId and endId are inclusive 185 | */ 186 | private def getMetadataLogByRange(startId: Timestamp, endId: Timestamp) = { 187 | val globber = sourceOptions.pathGlobFilter.map(new GlobPattern(_)) 188 | metadataLog.get(Some(startId), Some(endId)) 189 | .flatMap(_._2) 190 | .filter { file => globber.forall(p => p.matches(file.path)) } 191 | } 192 | 193 | private def fetchMaxOffset(): FileStreamSourceOffset = { 194 | 195 | // only fetch new messages from SQS when not reprocessing 196 | if (sourceOptions.reprocessState == ReprocessStates.NoReprocess) { 197 | // If asyncFetch can't finish in time, it continues in a separate thread. 
198 | // And here proceeds with what's available in the fileCache 199 | queueClient.asyncFetch(sourceOptions.queueFetchWaitTimeoutSeconds) 200 | } 201 | 202 | 203 | val batchFiles = fileCache.getUnprocessedFiles(sourceOptions.maxFilesPerTrigger) 204 | 205 | if (batchFiles.nonEmpty) { 206 | metadataLogCurrentOffset += 1 207 | val addSuccess = metadataLog.add(metadataLogCurrentOffset, batchFiles.map { 208 | case FileMetadata(path, timestamp, _) => 209 | FileEntry(path = path, timestamp = timestamp, batchId = metadataLogCurrentOffset) 210 | }.toArray) 211 | 212 | if (addSuccess) { 213 | logInfo(s"Log offset set to $metadataLogCurrentOffset with ${batchFiles.size} new files") 214 | val messageIds = batchFiles.map { 215 | case FileMetadata(path, _, messageId) => 216 | fileCache.markProcessed(path) 217 | logDebug(s"New file in fetchMaxOffset: $path") 218 | messageId 219 | }.toList 220 | 221 | queueClient.handleProcessedMessageBatch(messageIds.flatten) 222 | } 223 | else { 224 | throw new S3ConnectorMetalogAddException(s"BatchId ${metadataLogCurrentOffset} already exists.") 225 | } 226 | 227 | } 228 | else if (sourceOptions.reprocessState != ReprocessStates.NoReprocess) { 229 | logWarning("This is a reprocessing run. No new data are fetched from the queue." 230 | + " To resume the new data processing, restart the application without reprocessing parameters.") 231 | } 232 | 233 | val numPurged = purgeCache() 234 | 235 | logDebug( 236 | s""" 237 | |Number of files selected for batch = ${batchFiles.size} 238 | |Number of files purged from tracking map = ${numPurged.getOrElse(0)} 239 | """.stripMargin) 240 | 241 | FileStreamSourceOffset(metadataLogCurrentOffset) 242 | } 243 | 244 | override def getOffset: Option[Offset] = reportTimeTaken ("getOffset") { 245 | Some(fetchMaxOffset()).filterNot(_.logOffset == -1) 246 | } 247 | 248 | override def commit(end: Offset): Unit = { 249 | purgeCache() 250 | metadataLog.commit() 251 | logQueueClientMetrics() 252 | } 253 | 254 | override def stop(): Unit = { 255 | try { 256 | queueClient.close() 257 | } 258 | finally { 259 | metadataLog.close() 260 | } 261 | 262 | logQueueClientMetrics() 263 | } 264 | 265 | override def toString: String = s"S3ConnectorSource[${queueClient.queueUrl}]" 266 | 267 | private def purgeCache(): Option[Int] = { 268 | Try(fileCache.purge()) match { 269 | case Success(cnt) => 270 | logDebug(s"Successfully purged ${cnt} entries in sqsFileCache") 271 | Some(cnt) 272 | case Failure(e) => 273 | logError("failed to purge fileCache", e) 274 | None 275 | } 276 | } 277 | 278 | private def logQueueClientMetrics(): Unit = { 279 | Try(queueClient.metrics.json) match { 280 | case Success(metricsString) => 281 | logInfo(s"queueClient metrics: ${metricsString}") 282 | case Failure(e) => 283 | logError("failed to get queueClient.metrics", e) 284 | } 285 | } 286 | 287 | } 288 | 289 | object S3ConnectorSource { 290 | val REPROCESS_LOCK_FILE = "/reprocess.lock" 291 | } 292 | 293 | 294 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorSourceOptions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | import scala.util.Try 21 | 22 | import com.amazonaws.spark.sql.streaming.connector.ReprocessStates.ReprocessState 23 | 24 | import org.apache.spark.network.util.JavaUtils 25 | 26 | case class S3ConnectorSourceOptions ( 27 | maxFilesPerTrigger: Option[Int], 28 | pathGlobFilter: Option[String], 29 | fileFormat: String, 30 | maxFileAgeMs: Long, 31 | partitionColumns: Option[String], 32 | queueRegion: String, 33 | queueUrl: String, 34 | queueType: String, 35 | queueFetchWaitTimeoutSeconds: Long, 36 | 37 | // reprocess 38 | reprocessStartBatchId: Option[Int], // inclusive 39 | reprocessEndBatchId: Option[Int], // inclusive 40 | reprocessState: ReprocessState, 41 | 42 | // SQS parameters 43 | sqsMaxRetries: Int, 44 | sqsMaxConcurrency: Int, 45 | sqsLongPollWaitTimeSeconds: Int, 46 | sqsVisibilityTimeoutSeconds: Int, 47 | sqsKeepMessageForConsumerError: Boolean, 48 | ) 49 | 50 | object S3ConnectorSourceOptions { 51 | val SQS_QUEUE = "SQS" 52 | 53 | private val PREFIX = "spark.s3conn." 54 | private val SQS_PREFIX = PREFIX + ".sqs." 55 | 56 | val BASE_PATH: String = "basePath" 57 | 58 | val MAX_FILES_PER_TRIGGER: String = PREFIX + "maxFilesPerTrigger" 59 | val PATH_GLOB_FILTER: String = PREFIX + "pathGlobFilter" 60 | val S3_FILE_FORMAT: String = PREFIX + "fileFormat" 61 | val MAX_FILE_AGE: String = PREFIX + "maxFileAge" 62 | val PARTITION_COLUMNS: String = PREFIX + "partitionColumns" 63 | val QUEUE_URL: String = PREFIX + "queueUrl" 64 | val QUEUE_REGION: String = PREFIX + "queueRegion" 65 | val QUEUE_TYPE: String = PREFIX + "queueType" 66 | val QUEUE_FETCH_WAIT_TIMEOUT_SECONDS: String = PREFIX + "queueFetchWaitTimeoutSeconds" 67 | 68 | val REPROCESS_START_BATCH_ID: String = PREFIX + "reprocessStartBatchId" 69 | val REPROCESS_END_BATCH_ID: String = PREFIX + "reprocessEndBatchId" 70 | val REPROCESS_DRY_RUN: String = PREFIX + "reprocessDryRun" 71 | 72 | val SQS_LONG_POLLING_WAIT_TIME_SECONDS: String = SQS_PREFIX + "longPollingWaitTimeSeconds" 73 | val SQS_VISIBILITY_TIMEOUT_SECONDS: String = SQS_PREFIX + "visibilityTimeoutSeconds" 74 | val SQS_KEEP_MESSAGE_FOR_CONSUMER_ERROR: String = SQS_PREFIX + "keepMessageForConsumerError" 75 | val SQS_MAX_RETRIES: String = SQS_PREFIX + "maxRetries" 76 | val SQS_MAX_CONCURRENCY: String = SQS_PREFIX + "maxConcurrency" 77 | 78 | val MAX_FILES_PER_TRIGGER_DEFAULT_VALUE: Int = 100 79 | val MAX_FILE_AGE_DEFAULT_VALUE: String = "15d" 80 | val REPROCESS_DRY_RUN_DEFAULT_VALUE: Boolean = true 81 | val SQS_LONG_POLLING_WAIT_TIME_SECONDS_MIN_VALUE: Int = 0 82 | val SQS_LONG_POLLING_WAIT_TIME_SECONDS_MAX_VALUE: Int = 20 83 | val SQS_LONG_POLLING_WAIT_TIME_SECONDS_DEFAULT_VALUE: Int = 10 84 | val SQS_MAX_RETRIES_DEFAULT_VALUE: Int = 10 85 | val SQS_KEEP_MESSAGE_FOR_CONSUMER_ERROR_DEFAULT_VALUE: Boolean = false 86 | val SQS_MAX_CONCURRENCY_DEFAULT_VALUE: Int = 50 87 | val 
SQS_VISIBILITY_TIMEOUT_DEFAULT_VALUE: Int = 60 88 | 89 | 90 | def apply(parameters: Map[String, String]): S3ConnectorSourceOptions = { 91 | 92 | val maxFilesPerTrigger: Option[Int] = parameters.get(MAX_FILES_PER_TRIGGER) match { 93 | case Some(str) => Try(str.toInt).toOption.filter(_ > 0).orElse(None) 94 | case None => Some(MAX_FILES_PER_TRIGGER_DEFAULT_VALUE) 95 | } 96 | 97 | val pathGlobFilter = parameters.get(PATH_GLOB_FILTER) 98 | 99 | val fileFormat = parameters.getOrElse(S3_FILE_FORMAT, 100 | throw new IllegalArgumentException(s"Specifying ${S3_FILE_FORMAT} is mandatory with s3 connector source")) 101 | 102 | val maxFileAgeMs = JavaUtils.timeStringAsMs(parameters.getOrElse(MAX_FILE_AGE, MAX_FILE_AGE_DEFAULT_VALUE)) 103 | 104 | val partitionColumns = parameters.get(PARTITION_COLUMNS) 105 | 106 | val queueUrl: String = parameters.getOrElse(QUEUE_URL, 107 | throw new IllegalArgumentException(s"${QUEUE_URL} is not specified")) 108 | 109 | val queueRegion: String = parameters.getOrElse(QUEUE_REGION, 110 | throw new IllegalArgumentException(s"${QUEUE_REGION} is not specified")) 111 | 112 | val queueType: String = parameters.getOrElse(QUEUE_TYPE, SQS_QUEUE) 113 | 114 | val REPROCESS_START_LOG_ID: String = PREFIX + "reprocessStartBatchId" 115 | val REPROCESS_END_LOG_ID: String = PREFIX + "reprocessEndBatchId" 116 | val REPROCESS_DRY_RUN: String = PREFIX + "reprocessDryRun" 117 | 118 | val reprocessStartBatchId: Option[Int] = parameters.get(REPROCESS_START_LOG_ID).map { str => 119 | Try(str.toInt).toOption.filter(_ >= 0).getOrElse { 120 | throw new IllegalArgumentException( 121 | s"Invalid value '$str' for option '${REPROCESS_START_LOG_ID}', must be zero or a positive integer") 122 | } 123 | } 124 | 125 | val reprocessEndBatchId: Option[Int] = parameters.get(REPROCESS_END_LOG_ID).map { str => 126 | Try(str.toInt).toOption.filter(_ >= 0).getOrElse { 127 | throw new IllegalArgumentException( 128 | s"Invalid value '$str' for option '${REPROCESS_END_LOG_ID}', must be zero or a positive integer") 129 | } 130 | } 131 | 132 | val reprocessDryRun: Boolean = withBooleanParameter(parameters, REPROCESS_DRY_RUN, 133 | REPROCESS_DRY_RUN_DEFAULT_VALUE) 134 | 135 | 136 | val reprocessState: ReprocessState = { 137 | for ( startId <- reprocessStartBatchId; 138 | endId <- reprocessEndBatchId 139 | ) yield { 140 | if (startId > endId) { 141 | throw new IllegalArgumentException( 142 | s"reprocessStartBatchId must be less than or equal to reprocessEndBatchId: start ${startId}, end ${endId}") 143 | } 144 | 145 | if (reprocessDryRun) ReprocessStates.DryRun else ReprocessStates.InAction 146 | } 147 | } getOrElse ReprocessStates.NoReprocess 148 | 149 | val sqsKeepMessageForConsumerError: Boolean = withBooleanParameter( parameters, 150 | SQS_KEEP_MESSAGE_FOR_CONSUMER_ERROR, SQS_KEEP_MESSAGE_FOR_CONSUMER_ERROR_DEFAULT_VALUE) 151 | 152 | val sqsLongPollWaitTimeSeconds: Int = parameters.get(SQS_LONG_POLLING_WAIT_TIME_SECONDS).map { str => 153 | Try(str.toInt).toOption.filter{ x => 154 | x >= SQS_LONG_POLLING_WAIT_TIME_SECONDS_MIN_VALUE && x <= SQS_LONG_POLLING_WAIT_TIME_SECONDS_MAX_VALUE 155 | } 156 | .getOrElse { 157 | throw new IllegalArgumentException( 158 | s"Invalid value '$str' for option ${SQS_LONG_POLLING_WAIT_TIME_SECONDS}," + 159 | s"must be an integer between ${SQS_LONG_POLLING_WAIT_TIME_SECONDS_MIN_VALUE}" + 160 | s" and ${SQS_LONG_POLLING_WAIT_TIME_SECONDS_MAX_VALUE}") 161 | } 162 | }.getOrElse(SQS_LONG_POLLING_WAIT_TIME_SECONDS_DEFAULT_VALUE) 163 | 164 | val sqsMaxRetries: Int = 
withPositiveIntegerParameter(parameters, SQS_MAX_RETRIES, SQS_MAX_RETRIES_DEFAULT_VALUE) 165 | 166 | val sqsMaxConcurrency: Int = withPositiveIntegerParameter(parameters, SQS_MAX_CONCURRENCY, 167 | SQS_MAX_CONCURRENCY_DEFAULT_VALUE) 168 | 169 | val sqsVisibilityTimeoutSeconds: Int = withPositiveIntegerParameter(parameters, SQS_VISIBILITY_TIMEOUT_SECONDS, 170 | SQS_VISIBILITY_TIMEOUT_DEFAULT_VALUE) 171 | 172 | val queueFetchWaitTimeoutSeconds: Long = withPositiveIntegerParameter(parameters, QUEUE_FETCH_WAIT_TIMEOUT_SECONDS, 173 | 2 * sqsLongPollWaitTimeSeconds) 174 | 175 | new S3ConnectorSourceOptions( 176 | maxFilesPerTrigger = maxFilesPerTrigger, 177 | pathGlobFilter = pathGlobFilter, 178 | fileFormat = fileFormat, 179 | maxFileAgeMs = maxFileAgeMs, 180 | partitionColumns = partitionColumns, 181 | queueRegion = queueRegion, 182 | queueUrl = queueUrl, 183 | queueType = queueType, 184 | queueFetchWaitTimeoutSeconds = queueFetchWaitTimeoutSeconds, 185 | reprocessStartBatchId = reprocessStartBatchId, 186 | reprocessEndBatchId = reprocessEndBatchId, 187 | reprocessState = reprocessState, 188 | sqsLongPollWaitTimeSeconds = sqsLongPollWaitTimeSeconds, 189 | sqsMaxRetries = sqsMaxRetries, 190 | sqsMaxConcurrency = sqsMaxConcurrency, 191 | sqsVisibilityTimeoutSeconds = sqsVisibilityTimeoutSeconds, 192 | sqsKeepMessageForConsumerError = sqsKeepMessageForConsumerError 193 | ) 194 | } 195 | 196 | private def withPositiveIntegerParameter(parameters: Map[String, String], name: String, default: Int) = { 197 | parameters.get(name).map { str => 198 | Try(str.toInt).toOption.filter(_ > 0).getOrElse { 199 | throw new IllegalArgumentException( 200 | s"Invalid value '$str' for option '${name}', must be a positive integer") 201 | } 202 | }.getOrElse(default) 203 | } 204 | 205 | private def withBooleanParameter(parameters: Map[String, String], name: String, default: Boolean) = { 206 | parameters.get(name).map { str => 207 | try { 208 | str.toBoolean 209 | } catch { 210 | case _: IllegalArgumentException => 211 | throw new IllegalArgumentException( 212 | s"Invalid value '$str' for option '$name', must be true or false") 213 | } 214 | }.getOrElse(default) 215 | } 216 | } 217 | 218 | object ReprocessStates extends Enumeration { 219 | type ReprocessState = Value 220 | val NoReprocess, DryRun, InAction = Value 221 | } -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorSourceProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | import org.apache.spark.internal.Logging 21 | import org.apache.spark.sql.SQLContext 22 | import org.apache.spark.sql.execution.streaming.Source 23 | import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} 24 | import org.apache.spark.sql.types.StructType 25 | 26 | class S3ConnectorSourceProvider extends DataSourceRegister 27 | with StreamSourceProvider 28 | with Logging { 29 | 30 | override def shortName(): String = "s3-connector" 31 | 32 | override def sourceSchema(sqlContext: SQLContext, 33 | schema: Option[StructType], 34 | providerName: String, 35 | parameters: Map[String, String]): (String, StructType) = { 36 | 37 | require(schema.isDefined, "S3-connector source doesn't support empty schema") 38 | (shortName(), schema.get) 39 | } 40 | 41 | override def createSource(sqlContext: SQLContext, 42 | metadataPath: String, 43 | schema: Option[StructType], 44 | providerName: String, 45 | parameters: Map[String, String]): Source = { 46 | 47 | new S3ConnectorSource[String]( 48 | sqlContext.sparkSession, 49 | metadataPath, 50 | parameters, 51 | schema) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | 19 | package com.amazonaws.spark.sql.streaming.connector 20 | 21 | import java.text.SimpleDateFormat 22 | import java.util.TimeZone 23 | import java.util.concurrent.{ExecutorService, TimeUnit} 24 | 25 | import org.apache.spark.internal.Logging 26 | 27 | object Utils extends Logging{ 28 | 29 | // TODO: default values configurable 30 | val DEFAULT_CONNECTION_ACQUIRE_TIMEOUT = 60 // seconds 31 | val DEFAULT_SHUTDOWN_WAIT_TIMEOUT = 180 // seconds 32 | 33 | def reportTimeTaken[T](operation: String)(body: => T): T = { 34 | val startTime = System.currentTimeMillis() 35 | val result = body 36 | val endTime = System.currentTimeMillis() 37 | val timeTaken = math.max(endTime - startTime, 0) 38 | 39 | logInfo(s"reportTimeTaken $operation took $timeTaken ms") 40 | result 41 | } 42 | 43 | def shutdownAndAwaitTermination(pool: ExecutorService, await_timeout: Int = DEFAULT_SHUTDOWN_WAIT_TIMEOUT): Unit = { 44 | if (! 
pool.isTerminated) { 45 | pool.shutdown() // Disable new tasks from being submitted 46 | 47 | try { 48 | // Wait a while for existing tasks to terminate 49 | if (!pool.awaitTermination(await_timeout, TimeUnit.SECONDS)) { 50 | pool.shutdownNow // Cancel currently executing tasks 51 | 52 | // Wait a while for tasks to respond to being cancelled 53 | if (!pool.awaitTermination(await_timeout, TimeUnit.SECONDS)) { 54 | logError(s"Thread pool did not stop properly: ${pool.toString}.") 55 | } 56 | } 57 | } catch { 58 | case _: InterruptedException => 59 | // (Re-)Cancel if current thread also interrupted 60 | pool.shutdownNow 61 | // Preserve interrupt status 62 | Thread.currentThread.interrupt() 63 | } 64 | } 65 | } 66 | 67 | def convertTimestampToMills(timestamp: String): Long = { 68 | val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 69 | timestampFormat.setTimeZone(TimeZone.getTimeZone("UTC")) 70 | val timeInMillis = timestampFormat.parse(timestamp).getTime 71 | timeInMillis 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/client/AsyncClientBuilder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector.client 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.{FileMetadata, S3ConnectorSourceOptions} 20 | 21 | trait AsyncClientBuilder[T] { 22 | def sourceOptions(options: S3ConnectorSourceOptions): AsyncClientBuilder[T] 23 | def consumer(function: FileMetadata[T] => Unit): AsyncClientBuilder[T] 24 | def build(): AsyncQueueClient[T] 25 | } 26 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/client/AsyncClientMetrics.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector.client 19 | 20 | import com.codahale.metrics.Counter 21 | 22 | trait AsyncClientMetrics { 23 | def receiveMessageCounter: Counter 24 | def receiveMessageFailedCounter: Counter 25 | def parseMessageCounter: Counter 26 | def parseMessageFailedCounter: Counter 27 | def discardedMessageCounter: Counter 28 | def consumeMessageCounter: Counter 29 | def consumeMessageFailedCounter: Counter 30 | def deleteMessageCounter: Counter 31 | def deleteMessageFailedCounter: Counter 32 | def setMessageVisibilityCounter: Counter 33 | def setMessageVisibilityFailedCounter: Counter 34 | def fetchThreadConsumeMessageCounter: Counter 35 | def fetchThreadConsumeMessageFailedCounter: Counter 36 | def fetchThreadUncaughtExceptionCounter: Counter 37 | def json: String 38 | } 39 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/client/AsyncQueueClient.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector.client 18 | 19 | import java.io.Closeable 20 | import java.util.concurrent.{CompletableFuture, Future} 21 | import java.util.function.Consumer 22 | 23 | import com.amazonaws.spark.sql.streaming.connector.FileMetadata 24 | import com.amazonaws.spark.sql.streaming.connector.client.AsyncQueueConsumerResults.AsyncQueueConsumerResult 25 | 26 | trait AsyncQueueClient[T] extends Closeable { 27 | def queueUrl: String 28 | 29 | def deleteInvalidMessageIfNecessary(messageId: T): CompletableFuture[Boolean] 30 | /** 31 | * @param messageId message id to delete 32 | * @return the returned future will be completed with true if the message is successfully deleted 33 | */ 34 | def deleteMessage(messageId: T): CompletableFuture[Boolean] 35 | /** 36 | * the input messageIds can be divided into several sub batches. 37 | * 38 | * @param messageIds List of message ids to delete 39 | * @return a list of futures for sub batches. 
The returned future will be completed with true only 40 | * when all messages are successfully deleted in the sub batch 41 | */ 42 | def deleteMessageBatch(messageIds: Seq[T]): Seq[CompletableFuture[Boolean]] 43 | 44 | def handleProcessedMessage(messageId: T): CompletableFuture[Boolean] 45 | def handleProcessedMessageBatch(messageIds: Seq[T]): Seq[CompletableFuture[Boolean]] 46 | def setMessageVisibility(messageId: T, 47 | visibilityTimeoutSeconds: Int): CompletableFuture[Boolean] 48 | def consume(consumer: Consumer[FileMetadata[T]]): CompletableFuture[Seq[AsyncQueueConsumerResult]] 49 | def metrics: AsyncClientMetrics 50 | def asyncFetch(waitTimeoutSecond: Long): Future[_] 51 | def awaitFetchReady(future: Future[_], timeoutSecond: Long): Unit 52 | } 53 | 54 | object AsyncQueueConsumerResults extends Enumeration { 55 | type AsyncQueueConsumerResult = Value 56 | val Ok, ReceiveEmpty, ParseNone, ReceiveException, ConsumerException = Value 57 | } -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/client/AsyncSqsClientBuilder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector.client 18 | 19 | import java.time.Duration 20 | import java.util.function.Consumer 21 | 22 | import scala.language.implicitConversions 23 | 24 | import com.amazonaws.spark.sql.streaming.connector.{ConnectorAwsCredentialsProvider, FileMetadata, S3ConnectorSourceOptions} 25 | import com.amazonaws.spark.sql.streaming.connector.Utils.DEFAULT_CONNECTION_ACQUIRE_TIMEOUT 26 | import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider 27 | import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration 28 | import software.amazon.awssdk.core.retry.RetryPolicy 29 | import software.amazon.awssdk.core.retry.backoff.BackoffStrategy 30 | import software.amazon.awssdk.core.retry.conditions.RetryCondition 31 | import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient 32 | import software.amazon.awssdk.regions.Region 33 | import software.amazon.awssdk.services.sqs.SqsAsyncClient 34 | 35 | class AsyncSqsClientBuilder[T] extends AsyncClientBuilder[T] { 36 | 37 | var options: S3ConnectorSourceOptions = _ 38 | var consumer: (FileMetadata[T]) => Unit = _ 39 | 40 | private val credentialsProvider: AwsCredentialsProvider = 41 | ConnectorAwsCredentialsProvider.builder.build().provider 42 | 43 | implicit def toConsumer[A](function: A => Unit): Consumer[A] = new Consumer[A]() { 44 | override def accept(arg: A): Unit = function.apply(arg) 45 | } 46 | 47 | override def sourceOptions(options: S3ConnectorSourceOptions): AsyncClientBuilder[T] = { 48 | this.options = options 49 | this 50 | } 51 | 52 | override def consumer(function: FileMetadata[T] => Unit): AsyncClientBuilder[T] = { 53 | this.consumer = function 54 | this 55 | } 56 | 57 | def build(): AsyncQueueClient[T] = { 58 | require(options!=null, "sourceOptions can't be null") 59 | require(consumer!=null, "sqs message consumer can't be null") 60 | 61 | val asyncSqsClient = getAsyncSQSClient(options) 62 | new AsyncSqsClientImpl(asyncSqsClient, options, Some(consumer)) 63 | } 64 | 65 | private def getAsyncSQSClient(options: S3ConnectorSourceOptions): SqsAsyncClient = { 66 | val retryPolicy = RetryPolicy.builder 67 | .numRetries(options.sqsMaxRetries) 68 | .retryCondition(RetryCondition.defaultRetryCondition) 69 | .backoffStrategy(BackoffStrategy.defaultThrottlingStrategy) 70 | .build 71 | 72 | val clientOverrideConfiguration = ClientOverrideConfiguration.builder 73 | .retryPolicy(retryPolicy) 74 | .build 75 | 76 | SqsAsyncClient.builder 77 | .httpClient( 78 | NettyNioAsyncHttpClient 79 | .builder() 80 | .maxConcurrency(options.sqsMaxConcurrency) 81 | .connectionAcquisitionTimeout(Duration.ofSeconds(DEFAULT_CONNECTION_ACQUIRE_TIMEOUT)) 82 | .build() 83 | ) 84 | .region(Region.of(options.queueRegion)) 85 | .overrideConfiguration(clientOverrideConfiguration) 86 | .credentialsProvider(credentialsProvider) 87 | .build 88 | 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/client/AsyncSqsClientMetricsImpl.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector.client 18 | 19 | import scala.collection.JavaConverters._ 20 | 21 | import com.codahale.metrics.{Counter, MetricRegistry} 22 | import org.json4s.NoTypeHints 23 | import org.json4s.jackson.Serialization 24 | 25 | class AsyncSqsClientMetricsImpl extends AsyncClientMetrics{ 26 | 27 | private val metricRegistry = new MetricRegistry 28 | 29 | private def getCounter(name: String): Counter = { 30 | metricRegistry.counter(MetricRegistry.name("AsyncSqsClient", name)) 31 | } 32 | 33 | override val receiveMessageCounter: Counter = getCounter("receiveMessageCounter") 34 | 35 | override val receiveMessageFailedCounter: Counter = getCounter("receiveMessageFailedCounter") 36 | 37 | override val parseMessageCounter: Counter = getCounter("parseMessageCounter") 38 | 39 | override val parseMessageFailedCounter: Counter = getCounter("parseMessageFailedCounter") 40 | 41 | override val discardedMessageCounter: Counter = getCounter("discardedMessageCounter") 42 | 43 | override val consumeMessageCounter: Counter = getCounter("consumeMessageCounter") 44 | 45 | override val consumeMessageFailedCounter: Counter = getCounter("consumeMessageFailedCounter") 46 | 47 | override val deleteMessageCounter: Counter = getCounter("deleteMessageCounter") 48 | 49 | override val deleteMessageFailedCounter: Counter = getCounter("deleteMessageFailedCounter") 50 | 51 | override val setMessageVisibilityCounter: Counter = getCounter("setMessageVisibilityCounter") 52 | 53 | override val setMessageVisibilityFailedCounter: Counter = getCounter("setMessageVisibilityFailedCounter") 54 | 55 | override val fetchThreadConsumeMessageCounter: Counter = getCounter("fetchThreadConsumeMessageCounter") 56 | 57 | override val fetchThreadConsumeMessageFailedCounter: Counter = getCounter("fetchThreadConsumeMessageFailedCounter") 58 | 59 | override val fetchThreadUncaughtExceptionCounter: Counter = getCounter("fetchThreadUncaughtExceptionCounter") 60 | 61 | override def json: String = { 62 | Serialization.write( 63 | metricRegistry.getCounters.asScala.map { kv => 64 | (kv._1, kv._2.getCount) 65 | } 66 | )(AsyncSqsClientMetricsImpl.format) 67 | } 68 | } 69 | 70 | object AsyncSqsClientMetricsImpl { 71 | val format = Serialization.formats(NoTypeHints) 72 | 73 | def apply(): AsyncClientMetrics = { 74 | new AsyncSqsClientMetricsImpl() 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/com/amazonaws/spark/sql/streaming/connector/metadataLog/S3MetadataLog.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector.metadataLog 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorFileCache 20 | 21 | import org.apache.spark.sql.SparkSession 22 | import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry 23 | import org.apache.spark.sql.execution.streaming.MetadataLog 24 | 25 | trait S3MetadataLog extends MetadataLog[Array[FileEntry]]{ 26 | /** 27 | * Returns true if we should consider this file a new file. The file is only considered "new" 28 | * if it is new enough that we are still tracking, and we have not seen it before. 29 | */ 30 | def isNewFile(path: String, lastPurgeTimestamp: Long): Boolean 31 | 32 | def getFile(path: String): Option[Long] 33 | 34 | def init(sparkSession: SparkSession, 35 | checkpointPath: String, 36 | fileCache: S3ConnectorFileCache[_] 37 | ): Unit 38 | 39 | def add(batchId: Long, fileEntries: Array[FileEntry], timestamp: Option[Long]): Boolean 40 | 41 | def commit(): Unit 42 | 43 | def close(): Unit 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/streaming/connector/s3/RocksDBLoader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.streaming.connector.s3 19 | 20 | import org.rocksdb.{RocksDB => NativeRocksDB} 21 | 22 | import org.apache.spark.internal.Logging 23 | import org.apache.spark.util.UninterruptibleThread 24 | 25 | // This file is copied from Spark 3.2.1 to decouple the dependencies 26 | 27 | /** 28 | * A wrapper for RocksDB library loading using an uninterruptible thread, as the native RocksDB 29 | * code will throw an error when interrupted. 30 | */ 31 | object RocksDBLoader extends Logging { 32 | /** 33 | * Keep tracks of the exception thrown from the loading thread, if any. 
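   * (A null value means loading has not been attempted yet, None means the native library
   * loaded successfully, and Some(t) records the failure so that later callers rethrow it.)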
34 | */ 35 | private var exception: Option[Throwable] = null 36 | 37 | private val loadLibraryThread = new UninterruptibleThread("RocksDBLoader") { 38 | override def run(): Unit = { 39 | try { 40 | runUninterruptibly { 41 | NativeRocksDB.loadLibrary() 42 | exception = None 43 | } 44 | } catch { 45 | case e: Throwable => 46 | exception = Some(e) 47 | } 48 | } 49 | } 50 | 51 | def loadLibrary(): Unit = synchronized { 52 | if (exception == null) { 53 | loadLibraryThread.start() 54 | logInfo("RocksDB library loading thread started") 55 | loadLibraryThread.join() 56 | exception.foreach(throw _) 57 | logInfo("RocksDB library loading thread finished successfully") 58 | } else { 59 | exception.foreach(throw _) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/streaming/connector/s3/RocksDBStateEncoder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.streaming.connector.s3 19 | 20 | import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} 21 | import org.apache.spark.sql.types.{StructField, StructType} 22 | import org.apache.spark.unsafe.Platform 23 | 24 | // This file is copied from Spark 3.2.1 with minor adaption to S3 connector 25 | 26 | sealed trait RocksDBStateEncoder { 27 | def supportPrefixKeyScan: Boolean 28 | def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] 29 | def extractPrefixKey(key: UnsafeRow): UnsafeRow 30 | 31 | def encodeKey(row: UnsafeRow): Array[Byte] 32 | def encodeValue(row: UnsafeRow): Array[Byte] 33 | 34 | def decodeKey(keyBytes: Array[Byte]): UnsafeRow 35 | def decodeValue(valueBytes: Array[Byte]): UnsafeRow 36 | def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair 37 | } 38 | 39 | /** Mutable, and reusable class for representing a pair of UnsafeRows. 
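 * withRows only re-points the key/value references and returns this, so a single instance
 * can be reused across an iterator without extra allocation.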
*/ 40 | class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { 41 | def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = { 42 | this.key = key 43 | this.value = value 44 | this 45 | } 46 | } 47 | 48 | object RocksDBStateEncoder { 49 | 50 | // Version as a single byte that specifies the encoding of the row data in RocksDB 51 | val STATE_ENCODING_NUM_VERSION_BYTES = 1 52 | val STATE_ENCODING_VERSION: Byte = 0 53 | 54 | def getEncoder( 55 | keySchema: StructType, 56 | valueSchema: StructType, 57 | numColsPrefixKey: Int): RocksDBStateEncoder = { 58 | if (numColsPrefixKey > 0) { 59 | new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey) 60 | } else { 61 | new NoPrefixKeyStateEncoder(keySchema, valueSchema) 62 | } 63 | } 64 | 65 | /** 66 | * Encode the UnsafeRow of N bytes as a N+1 byte array. 67 | * @note This creates a new byte array and memcopies the UnsafeRow to the new array. 68 | */ 69 | def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = { 70 | val bytesToEncode = row.getBytes 71 | val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES) 72 | Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION) 73 | // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform. 74 | Platform.copyMemory( 75 | bytesToEncode, Platform.BYTE_ARRAY_OFFSET, 76 | encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, 77 | bytesToEncode.length) 78 | encodedBytes 79 | } 80 | 81 | def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = { 82 | if (bytes != null) { 83 | val row = new UnsafeRow(numFields) 84 | decodeToUnsafeRow(bytes, row) 85 | } else { 86 | null 87 | } 88 | } 89 | 90 | def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = { 91 | if (bytes != null) { 92 | // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform. 93 | reusedRow.pointTo( 94 | bytes, 95 | Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES, 96 | bytes.length - STATE_ENCODING_NUM_VERSION_BYTES) 97 | reusedRow 98 | } else { 99 | null 100 | } 101 | } 102 | } 103 | 104 | class PrefixKeyScanStateEncoder( 105 | keySchema: StructType, 106 | valueSchema: StructType, 107 | numColsPrefixKey: Int) extends RocksDBStateEncoder { 108 | 109 | import RocksDBStateEncoder._ 110 | 111 | require(keySchema.length > numColsPrefixKey, "The number of columns in the key must be " + 112 | "greater than the number of columns for prefix key!") 113 | 114 | private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = { 115 | keySchema.zipWithIndex.take(numColsPrefixKey) 116 | } 117 | 118 | private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = { 119 | keySchema.zipWithIndex.drop(numColsPrefixKey) 120 | } 121 | 122 | private val prefixKeyProjection: UnsafeProjection = { 123 | val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable)) 124 | UnsafeProjection.create(refs) 125 | } 126 | 127 | private val remainingKeyProjection: UnsafeProjection = { 128 | val refs = remainingKeyFieldsWithIdx.map(x => 129 | BoundReference(x._2, x._1.dataType, x._1.nullable)) 130 | UnsafeProjection.create(refs) 131 | } 132 | 133 | // This is quite simple to do - just bind sequentially, as we don't change the order. 
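  // (restoreKeyProjection rebuilds the full key in decodeKey from the JoinedRow of prefix and
  // remaining columns.) For reference, encodeKey below stores the key as
  //   [ prefix length : 4-byte int | encoded prefix key | encoded remaining key ]
  // and decodeKey derives the remaining key's length from keyBytes.length, so only the
  // prefix length needs to be written explicitly.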
134 | private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema) 135 | 136 | // Reusable objects 137 | private val joinedRowOnKey = new JoinedRow() 138 | private val valueRow = new UnsafeRow(valueSchema.size) 139 | private val rowTuple = new UnsafeRowPair() 140 | 141 | override def encodeKey(row: UnsafeRow): Array[Byte] = { 142 | val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row)) 143 | val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row)) 144 | 145 | val encodedBytes = new Array[Byte](prefixKeyEncoded.length + remainingEncoded.length + 4) 146 | Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length) 147 | Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, 148 | encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length) 149 | // NOTE: We don't put the length of remainingEncoded as we can calculate later 150 | // on deserialization. 151 | Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET, 152 | encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length, 153 | remainingEncoded.length) 154 | 155 | encodedBytes 156 | } 157 | 158 | override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) 159 | 160 | override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { 161 | val prefixKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET) 162 | val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen) 163 | Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded, 164 | Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen) 165 | 166 | // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes 167 | val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen 168 | 169 | val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen) 170 | Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 + 171 | prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, 172 | remainingKeyEncodedLen) 173 | 174 | val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey) 175 | val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, 176 | numFields = keySchema.length - numColsPrefixKey) 177 | 178 | restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)) 179 | } 180 | 181 | override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { 182 | decodeToUnsafeRow(valueBytes, valueRow) 183 | } 184 | 185 | override def extractPrefixKey(key: UnsafeRow): UnsafeRow = { 186 | prefixKeyProjection(key) 187 | } 188 | 189 | override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { 190 | val prefixKeyEncoded = encodeUnsafeRow(prefixKey) 191 | val prefix = new Array[Byte](prefixKeyEncoded.length + 4) 192 | Platform.putInt(prefix, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length) 193 | Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefix, 194 | Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length) 195 | prefix 196 | } 197 | 198 | override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = { 199 | rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value)) 200 | } 201 | 202 | override def supportPrefixKeyScan: Boolean = true 203 | } 204 | 205 | /** 206 | * Encodes/decodes UnsafeRows to versioned byte arrays. 
207 | * It uses the first byte of the generated byte array to store the version the describes how the 208 | * row is encoded in the rest of the byte array. Currently, the default version is 0, 209 | * 210 | * VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ] 211 | * The bytes of a UnsafeRow is written unmodified to starting from offset 1 212 | * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, 213 | * then the generated array byte will be N+1 bytes. 214 | */ 215 | class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType) 216 | extends RocksDBStateEncoder { 217 | 218 | import RocksDBStateEncoder._ 219 | 220 | // Reusable objects 221 | private val keyRow = new UnsafeRow(keySchema.size) 222 | private val valueRow = new UnsafeRow(valueSchema.size) 223 | private val rowTuple = new UnsafeRowPair() 224 | 225 | override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) 226 | 227 | override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) 228 | 229 | /** 230 | * Decode byte array for a key to a UnsafeRow. 231 | * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to 232 | * the given byte array. 233 | */ 234 | override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = { 235 | decodeToUnsafeRow(keyBytes, keyRow) 236 | } 237 | 238 | /** 239 | * Decode byte array for a value to a UnsafeRow. 240 | * 241 | * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to 242 | * the given byte array. 243 | */ 244 | override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = { 245 | decodeToUnsafeRow(valueBytes, valueRow) 246 | } 247 | 248 | /** 249 | * Decode pair of key-value byte arrays in a pair of key-value UnsafeRows. 250 | * 251 | * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to 252 | * the given byte array. 253 | */ 254 | override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = { 255 | rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value)) 256 | } 257 | 258 | override def supportPrefixKeyScan: Boolean = false 259 | 260 | override def extractPrefixKey(key: UnsafeRow): UnsafeRow = { 261 | throw new IllegalStateException("This encoder doesn't support prefix key!") 262 | } 263 | 264 | override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { 265 | throw new IllegalStateException("This encoder doesn't support prefix key!") 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/streaming/connector/s3/S3SparkUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.streaming.connector.s3 19 | 20 | import java.io.File 21 | import java.util.concurrent.{ExecutorService, ScheduledExecutorService, ThreadPoolExecutor} 22 | 23 | import org.apache.spark.SparkConf 24 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} 25 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 26 | import org.apache.spark.util.{ThreadUtils, Utils} 27 | 28 | 29 | 30 | 31 | 32 | object S3SparkUtils { 33 | def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { 34 | ThreadUtils.newDaemonSingleThreadScheduledExecutor(threadName) 35 | } 36 | 37 | def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = { 38 | ThreadUtils.newDaemonSingleThreadExecutor(threadName) 39 | } 40 | 41 | def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { 42 | ThreadUtils.newDaemonFixedThreadPool(nThreads, prefix) 43 | } 44 | 45 | 46 | def createTempDir(root: String, namePrefix: String): File = { 47 | Utils.createTempDir(root, namePrefix) 48 | } 49 | 50 | def deleteRecursively(file: File): Unit = { 51 | Utils.deleteRecursively(file) 52 | } 53 | 54 | def getLocalDir(conf: SparkConf): String = { 55 | Utils.getLocalDir(conf) 56 | } 57 | 58 | def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = 59 | Dataset.ofRows(sparkSession, logicalPlan) 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/test/java/it/spark/sql/streaming/connector/IntegrationTestSuite.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package it.spark.sql.streaming.connector; 18 | 19 | import org.scalatest.TagAnnotation; 20 | 21 | import java.lang.annotation.ElementType; 22 | import java.lang.annotation.Retention; 23 | import java.lang.annotation.RetentionPolicy; 24 | import java.lang.annotation.Target; 25 | 26 | @TagAnnotation 27 | @Retention(RetentionPolicy.RUNTIME) 28 | @Target({ElementType.METHOD, ElementType.TYPE}) 29 | public @interface IntegrationTestSuite {} -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | # Set everything to be logged to the file target/unit-tests.log 19 | #log4j.rootCategory=INFO, file 20 | 21 | log4j.rootCategory=INFO, console 22 | log4j.appender.file=org.apache.log4j.FileAppender 23 | log4j.appender.file.append=true 24 | #log4j.appender.file.file=target/tests.log 25 | log4j.appender.file.file=ptconsumer_csv/tests.log 26 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 27 | log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n 28 | 29 | log4j.appender.console=org.apache.log4j.ConsoleAppender 30 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 31 | 32 | # Ignore messages below warning level from Jetty, because it's a bit verbose 33 | log4j.logger.org.sparkproject.jetty=WARN 34 | -------------------------------------------------------------------------------- /src/test/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorFileCacheSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
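// --- Reading guide for S3ConnectorFileCacheSuite below (illustrative, not repository code) ---
// A compact sketch of the cache lifecycle the tests exercise; 5000 is the same max-file-age window
// (in ms) the suite uses, and the note on purge() is inferred from the assertions below rather
// than from the cache implementation itself.
val cacheSketch = new S3ConnectorFileCache[String](5000)
cacheSketch.add("/bucket/key-1", QueueMessageDesc[String](1000L, isProcessed = false, Some("receipt-1")))
cacheSketch.getUnprocessedFiles(None).foreach(file => cacheSketch.markProcessed(file.filePath))
cacheSketch.purge() // per the tests, purge advances lastPurgeTimestamp so that older files are ignored later
// --- End of reading guide ---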
16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector 18 | 19 | import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper 20 | 21 | class S3ConnectorFileCacheSuite extends S3ConnectorTestBase { 22 | 23 | test("isNewFile result") { 24 | val cache = new S3ConnectorFileCache[String](5000) 25 | cache.add("/test/path1", QueueMessageDesc[String](1000, isProcessed = true, Some("handle1"))) 26 | cache.add("/test/path2", QueueMessageDesc[String](1000, isProcessed = false, Some("handle2"))) 27 | 28 | cache.isNewFile("/test/path0") shouldBe FileCacheNewFileResults.Ok 29 | cache.isNewFile("/test/path1") shouldBe FileCacheNewFileResults.ExistInCacheProcessed 30 | cache.isNewFile("/test/path2") shouldBe FileCacheNewFileResults.ExistInCacheNotProcessed 31 | 32 | } 33 | 34 | test("getUnprocessedFiles without sort") { 35 | val cache = new S3ConnectorFileCache[String](5000) 36 | 37 | cache.add("/test/path1", QueueMessageDesc[String](1000, isProcessed = false, Some("handle1"))) 38 | cache.add("/test/path2", QueueMessageDesc[String](0, isProcessed = false, Some("handle2"))) 39 | cache.add("/test/path3", QueueMessageDesc[String](200, isProcessed = false, Some("handle3"))) 40 | 41 | 42 | val unprocessedFiles1 = cache.getUnprocessedFiles(None) 43 | 44 | unprocessedFiles1.toSet shouldEqual Set( 45 | FileMetadata("/test/path1", 1000, Some("handle1")), 46 | FileMetadata("/test/path2", 0, Some("handle2")), 47 | FileMetadata("/test/path3", 200, Some("handle3"))) 48 | 49 | cache.markProcessed("/test/path2") 50 | val unprocessedFiles2 = cache.getUnprocessedFiles(None) 51 | 52 | unprocessedFiles2.toSet shouldEqual Set( 53 | FileMetadata("/test/path1", 1000, Some("handle1")), 54 | FileMetadata ("/test/path3", 200, Some("handle3"))) 55 | 56 | } 57 | 58 | test("getUnprocessedFiles with sort") { 59 | val cache = new S3ConnectorFileCache[String](5000) 60 | 61 | cache.add("/test/path1", QueueMessageDesc[String](1000, isProcessed = false, Some("handle1"))) 62 | cache.add("/test/path2", QueueMessageDesc[String](0, isProcessed = false, Some("handle2"))) 63 | cache.add("/test/path3", QueueMessageDesc[String](200, isProcessed = false, Some("handle3"))) 64 | 65 | 66 | val unprocessedFiles1 = cache.getUnprocessedFiles(None, shouldSortFiles = true) 67 | 68 | unprocessedFiles1 shouldEqual List( 69 | FileMetadata("/test/path2", 0, Some("handle2")), 70 | FileMetadata("/test/path3", 200, Some("handle3")), 71 | FileMetadata("/test/path1", 1000, Some("handle1"))) 72 | } 73 | 74 | test("getUnprocessedFiles ignore old files") { 75 | val cache = new S3ConnectorFileCache[String](5000) 76 | 77 | cache.add("/test/path1", QueueMessageDesc[String](0, isProcessed = false, Some("handle1"))) 78 | cache.add("/test/path2", QueueMessageDesc[String](1000, isProcessed = false, Some("handle2"))) 79 | cache.add("/test/path3", QueueMessageDesc[String](6000, isProcessed = false, Some("handle3"))) 80 | 81 | 82 | val unprocessedFiles = cache.getUnprocessedFiles(None) 83 | 84 | unprocessedFiles.toSet shouldEqual Set( 85 | FileMetadata("/test/path1", 0, Some("handle1")), 86 | FileMetadata("/test/path2", 1000, Some("handle2")), 87 | FileMetadata("/test/path3", 6000, Some("handle3"))) 88 | 89 | cache.purge() // purge will move the latest timestamp 90 | 91 | cache.add("/test/path4", QueueMessageDesc[String](100, isProcessed = false, Some("handle4"))) 92 | cache.add("/test/path5", QueueMessageDesc[String](200, isProcessed = false, Some("handle5"))) 93 | 94 | val unprocessedFiles2 = cache.getUnprocessedFiles(None) 95 | 96 | 
unprocessedFiles2.toSet shouldEqual Set( 97 | FileMetadata("/test/path2", 1000, Some("handle2")), 98 | FileMetadata("/test/path3", 6000, Some("handle3"))) 99 | } 100 | 101 | test("getUnprocessedFiles maxFilesPerTrigger without sort") { 102 | val cache = new S3ConnectorFileCache[String](5000) 103 | 104 | cache.add("/test/path1", QueueMessageDesc[String](1000, isProcessed = false, Some("handle1"))) 105 | cache.add("/test/path2", QueueMessageDesc[String](0, isProcessed = false, Some("handle2"))) 106 | cache.add("/test/path3", QueueMessageDesc[String](200, isProcessed = false, Some("handle3"))) 107 | 108 | 109 | val unprocessedFiles1 = cache.getUnprocessedFiles(Some(2)) 110 | 111 | unprocessedFiles1.toSet shouldEqual Set( 112 | FileMetadata("/test/path1", 1000, Some("handle1")), 113 | FileMetadata("/test/path2", 0, Some("handle2"))) 114 | 115 | cache.markProcessed("/test/path2") 116 | val unprocessedFiles2 = cache.getUnprocessedFiles(Some(2)) 117 | 118 | unprocessedFiles2.toSet shouldEqual Set( 119 | FileMetadata("/test/path1", 1000, Some("handle1")), 120 | FileMetadata("/test/path3", 200, Some("handle3"))) 121 | } 122 | 123 | test("getUnprocessedFiles maxFilesPerTrigger with sort") { 124 | val cache = new S3ConnectorFileCache[String](5000) 125 | 126 | cache.add("/test/path1", QueueMessageDesc[String](1000, isProcessed = false, Some("handle1"))) 127 | cache.add("/test/path2", QueueMessageDesc[String](0, isProcessed = false, Some("handle2"))) 128 | cache.add("/test/path3", QueueMessageDesc[String](200, isProcessed = false, Some("handle3"))) 129 | 130 | 131 | val unprocessedFiles1 = cache.getUnprocessedFiles(Some(2), shouldSortFiles = true) 132 | 133 | unprocessedFiles1 shouldEqual List( 134 | FileMetadata("/test/path2", 0, Some("handle2")), 135 | FileMetadata("/test/path3", 200, Some("handle3"))) 136 | } 137 | 138 | } 139 | 140 | -------------------------------------------------------------------------------- /src/test/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorFileValidatorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions.PATH_GLOB_FILTER 21 | import com.amazonaws.spark.sql.streaming.connector.TestUtils.doReturnMock 22 | import com.amazonaws.spark.sql.streaming.connector.metadataLog.S3MetadataLog 23 | import org.mockito.ArgumentMatchers.{any, eq => meq} 24 | import org.mockito.Mockito.mock 25 | import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper 26 | 27 | class S3ConnectorFileValidatorSuite extends S3ConnectorTestBase { 28 | 29 | 30 | val defaultTs: Long = 1000L 31 | 32 | val mockFileCache = mock(classOf[S3ConnectorFileCache[String]]) 33 | val mockMetadataLog = mock(classOf[S3MetadataLog]) 34 | 35 | doReturnMock(FileCacheNewFileResults.Ok).when(mockFileCache).isNewFile(any()) 36 | doReturnMock(100L).when(mockFileCache).lastPurgeTimestamp 37 | doReturnMock(true).when(mockMetadataLog).isNewFile(any(), any()) 38 | 39 | test("isValidNewFile fileFilter not set") { 40 | doReturnMock(FileCacheNewFileResults.ExistInCacheNotProcessed).when(mockFileCache).isNewFile(meq("/path/test1")) 41 | doReturnMock(FileCacheNewFileResults.ExistInCacheProcessed).when(mockFileCache).isNewFile(meq("/path/test2")) 42 | doReturnMock(false).when(mockMetadataLog).isNewFile(meq("/path/test3"), any()) 43 | 44 | val validator = new S3ConnectorFileValidator( 45 | S3ConnectorSourceOptions(defaultOptionMap), 46 | mockFileCache, mockMetadataLog 47 | ) 48 | 49 | validator.isValidNewFile("/path/test0", defaultTs) shouldBe FileValidResults.Ok 50 | validator.isValidNewFile("/path/test1", defaultTs) shouldBe FileValidResults.ExistInCacheNotProcessed 51 | validator.isValidNewFile("/path/test2", defaultTs) shouldBe FileValidResults.ExistInCacheProcessed 52 | validator.isValidNewFile("/path/test3", defaultTs) shouldBe FileValidResults.PersistedInMetadataLog 53 | validator.isValidNewFile("/path/test1", 10L) shouldBe FileValidResults.FileExpired 54 | } 55 | 56 | 57 | test("isValidNewFile fileFilter set to *.csv") { 58 | val validator = new S3ConnectorFileValidator( 59 | S3ConnectorSourceOptions(defaultOptionMap + (PATH_GLOB_FILTER -> "*1.csv")), 60 | mockFileCache, mockMetadataLog 61 | ) 62 | 63 | validator.isValidNewFile("/path/test1", defaultTs) shouldBe FileValidResults.PatternNotMatch 64 | validator.isValidNewFile("/path/test1.csv", defaultTs) shouldBe FileValidResults.Ok 65 | validator.isValidNewFile("/path/test21.csv", defaultTs) shouldBe FileValidResults.Ok 66 | validator.isValidNewFile("/path/test2.csv", defaultTs) shouldBe FileValidResults.PatternNotMatch 67 | validator.isValidNewFile("/path/test1.json", defaultTs) shouldBe FileValidResults.PatternNotMatch 68 | 69 | } 70 | 71 | 72 | test("isValidNewFile fileFilter set to partition=1*") { 73 | val validator = new S3ConnectorFileValidator( 74 | S3ConnectorSourceOptions(defaultOptionMap + (PATH_GLOB_FILTER -> "*/partition=1*/*")), 75 | mockFileCache, mockMetadataLog 76 | ) 77 | 78 | validator.isValidNewFile("/path/partition=1/test1", defaultTs) shouldBe FileValidResults.Ok 79 | validator.isValidNewFile("/path/partition=11/test1.csv", defaultTs) shouldBe FileValidResults.Ok 80 | validator.isValidNewFile("/path/partition=2/test1.csv", defaultTs) shouldBe FileValidResults.PatternNotMatch 81 | 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/test/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorSourceOptionsSuite.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.amazonaws.spark.sql.streaming.connector 19 | 20 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions.{MAX_FILES_PER_TRIGGER, QUEUE_REGION, QUEUE_URL, REPROCESS_DRY_RUN, REPROCESS_END_BATCH_ID, REPROCESS_START_BATCH_ID, S3_FILE_FORMAT} 21 | import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper 22 | 23 | import org.apache.spark.network.util.JavaUtils 24 | 25 | class S3ConnectorSourceOptionsSuite extends S3ConnectorTestBase { 26 | test("option uses default values") { 27 | val option = S3ConnectorSourceOptions( Map( 28 | S3_FILE_FORMAT -> TESTBASE_DEFAULT_FILE_FORMAT, 29 | QUEUE_URL -> TESTBASE_DEFAULT_QUEUE_URL, 30 | QUEUE_REGION -> TESTBASE_DEFAULT_QUEUE_REGION 31 | )) 32 | 33 | option.fileFormat shouldBe TESTBASE_DEFAULT_FILE_FORMAT 34 | option.queueUrl shouldBe TESTBASE_DEFAULT_QUEUE_URL 35 | option.queueRegion shouldBe TESTBASE_DEFAULT_QUEUE_REGION 36 | option.queueType shouldBe S3ConnectorSourceOptions.SQS_QUEUE 37 | option.queueFetchWaitTimeoutSeconds shouldBe 2 * option.sqsLongPollWaitTimeSeconds 38 | option.pathGlobFilter shouldBe None 39 | option.partitionColumns shouldBe None 40 | option.maxFileAgeMs shouldBe JavaUtils.timeStringAsMs(S3ConnectorSourceOptions.MAX_FILE_AGE_DEFAULT_VALUE) 41 | option.maxFilesPerTrigger shouldBe Some(S3ConnectorSourceOptions.MAX_FILES_PER_TRIGGER_DEFAULT_VALUE) 42 | option.reprocessStartBatchId shouldBe None 43 | option.reprocessEndBatchId shouldBe None 44 | option.reprocessState shouldBe ReprocessStates.NoReprocess 45 | option.sqsMaxRetries shouldBe S3ConnectorSourceOptions.SQS_MAX_RETRIES_DEFAULT_VALUE 46 | option.sqsMaxConcurrency shouldBe S3ConnectorSourceOptions.SQS_MAX_CONCURRENCY_DEFAULT_VALUE 47 | option.sqsLongPollWaitTimeSeconds shouldBe 48 | S3ConnectorSourceOptions.SQS_LONG_POLLING_WAIT_TIME_SECONDS_DEFAULT_VALUE 49 | option.sqsVisibilityTimeoutSeconds shouldBe S3ConnectorSourceOptions.SQS_VISIBILITY_TIMEOUT_DEFAULT_VALUE 50 | option.sqsKeepMessageForConsumerError shouldBe 51 | S3ConnectorSourceOptions.SQS_KEEP_MESSAGE_FOR_CONSUMER_ERROR_DEFAULT_VALUE 52 | 53 | } 54 | 55 | test("maxFilesPerTrigger given a value") { 56 | val option = S3ConnectorSourceOptions(defaultOptionMap + (MAX_FILES_PER_TRIGGER -> "50")) 57 | option.maxFilesPerTrigger shouldBe Some(50) 58 | } 59 | 60 | test("maxFilesPerTrigger set to -1") { 61 | val option = S3ConnectorSourceOptions(defaultOptionMap + (MAX_FILES_PER_TRIGGER -> "-1")) 62 | option.maxFilesPerTrigger shouldBe None 63 | } 64 | 65 | test("only reprocessStartBatchId set") { 66 | val option = 
S3ConnectorSourceOptions(defaultOptionMap 67 | + (REPROCESS_START_BATCH_ID -> "50") 68 | ) 69 | option.reprocessState shouldBe ReprocessStates.NoReprocess 70 | } 71 | 72 | test("only reprocessEndBatchId set") { 73 | val option = S3ConnectorSourceOptions(defaultOptionMap 74 | + (REPROCESS_END_BATCH_ID -> "50") 75 | ) 76 | option.reprocessState shouldBe ReprocessStates.NoReprocess 77 | } 78 | 79 | test("Both reprocessStartBatchId and reprocessEndBatchId set") { 80 | val option = S3ConnectorSourceOptions(defaultOptionMap 81 | + (REPROCESS_START_BATCH_ID -> "50", 82 | REPROCESS_END_BATCH_ID -> "60") 83 | ) 84 | option.reprocessState shouldBe ReprocessStates.DryRun 85 | } 86 | 87 | test("Both reprocessDryRun set") { 88 | val option = S3ConnectorSourceOptions(defaultOptionMap 89 | + (REPROCESS_START_BATCH_ID -> "50", 90 | REPROCESS_END_BATCH_ID -> "60", 91 | REPROCESS_DRY_RUN -> "false" 92 | ) 93 | ) 94 | option.reprocessState shouldBe ReprocessStates.InAction 95 | } 96 | 97 | test("reprocessStartBatchId larger than reprocessEndBatchId") { 98 | 99 | val exception = intercept[IllegalArgumentException] { 100 | S3ConnectorSourceOptions(defaultOptionMap 101 | + (REPROCESS_START_BATCH_ID -> "70", 102 | REPROCESS_END_BATCH_ID -> "60") 103 | ) 104 | } 105 | 106 | exception.getMessage should include("reprocessStartBatchId must be less than or equal to reprocessEndBatchId") 107 | } 108 | 109 | } 110 | -------------------------------------------------------------------------------- /src/test/scala/com/amazonaws/spark/sql/streaming/connector/S3ConnectorTestBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions.{QUEUE_FETCH_WAIT_TIMEOUT_SECONDS, QUEUE_REGION, QUEUE_URL, S3_FILE_FORMAT} 20 | import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} 21 | import org.scalatest.matchers.must.Matchers 22 | 23 | import org.apache.spark.SparkFunSuite 24 | import org.apache.spark.internal.Logging 25 | 26 | 27 | trait S3ConnectorTestBase extends SparkFunSuite 28 | with Matchers 29 | with BeforeAndAfterAll 30 | with BeforeAndAfter with Logging{ 31 | 32 | val TESTBASE_DEFAULT_FILE_FORMAT = "csv" 33 | val TESTBASE_DEFAULT_QUEUE_URL = "testurl" 34 | val TESTBASE_DEFAULT_QUEUE_REGION = "testregion" 35 | val TESTBASE_DEFAULT_QUEUE_FETCH_WAIT_TIMEOUT = "5" 36 | 37 | val defaultOptionMap = Map( 38 | S3_FILE_FORMAT -> TESTBASE_DEFAULT_FILE_FORMAT, 39 | QUEUE_URL -> TESTBASE_DEFAULT_QUEUE_URL, 40 | QUEUE_REGION -> TESTBASE_DEFAULT_QUEUE_REGION, 41 | QUEUE_FETCH_WAIT_TIMEOUT_SECONDS -> TESTBASE_DEFAULT_QUEUE_FETCH_WAIT_TIMEOUT 42 | ) 43 | } 44 | -------------------------------------------------------------------------------- /src/test/scala/com/amazonaws/spark/sql/streaming/connector/TestUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.client.{AsyncClientMetrics, AsyncQueueClient} 20 | import com.codahale.metrics.Counter 21 | import org.apache.spark.internal.Logging 22 | import org.apache.spark.sql.streaming.connector.s3.S3SparkUtils 23 | import org.mockito.stubbing.Stubber 24 | import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper 25 | 26 | import java.io.{File, IOException} 27 | import java.util.UUID 28 | 29 | object TestUtils extends Logging { 30 | 31 | // Set this parameter to false to keep files in local temp dir 32 | val TEMP_DIR_CLEAN_UP: Boolean = false 33 | 34 | def doReturnMock(value: Any): Stubber = org.mockito.Mockito.doReturn(value, Seq.empty: _*) 35 | 36 | // scalastyle:off argcount 37 | def verifyMetrics(metrics: AsyncClientMetrics, 38 | expectedReceiveMessageCount: Int = 0, 39 | expectedReceiveMessageFailedCount: Int = 0, 40 | expectedParseMessageCount: Int = 0, 41 | expectedParseMessageFailedCount: Int = 0, 42 | expectedDiscardedMessageCount: Int = 0, 43 | expectedConsumeMessageCount: Int = 0, 44 | expectedConsumeMessageFailedCount: Int = 0, 45 | expectedDeleteMessageCount: Int = 0, 46 | expectedDeleteMessageFailedCount: Int = 0, 47 | expectedSetMessageVisibilityCount: Int = 0, 48 | expectedSetMessageVisibilityFailedCount: Int = 0, 49 | expectedFetchThreadConsumeMessageCount: Int = 0, 50 | expectedFetchThreadConsumeMessageFailedCount: Int = 0, 51 | expectedFetchThreadUncaughtExceptionCount: Int = 0 52 | ): Unit = { 53 | 54 | metrics.receiveMessageCounter.getCount shouldBe expectedReceiveMessageCount 55 | metrics.receiveMessageFailedCounter.getCount shouldBe expectedReceiveMessageFailedCount 56 | metrics.parseMessageCounter.getCount shouldBe expectedParseMessageCount 57 | metrics.parseMessageFailedCounter.getCount shouldBe expectedParseMessageFailedCount 58 | metrics.discardedMessageCounter.getCount shouldBe expectedDiscardedMessageCount 59 | metrics.consumeMessageCounter.getCount shouldBe expectedConsumeMessageCount 60 | metrics.consumeMessageFailedCounter.getCount shouldBe expectedConsumeMessageFailedCount 61 | metrics.deleteMessageCounter.getCount shouldBe expectedDeleteMessageCount 62 | metrics.deleteMessageFailedCounter.getCount shouldBe expectedDeleteMessageFailedCount 63 | metrics.setMessageVisibilityCounter.getCount shouldBe expectedSetMessageVisibilityCount 64 | metrics.setMessageVisibilityFailedCounter.getCount shouldBe expectedSetMessageVisibilityFailedCount 65 | metrics.fetchThreadConsumeMessageCounter.getCount shouldBe expectedFetchThreadConsumeMessageCount 66 | metrics.fetchThreadConsumeMessageFailedCounter.getCount shouldBe expectedFetchThreadConsumeMessageFailedCount 67 | metrics.fetchThreadUncaughtExceptionCounter.getCount shouldBe expectedFetchThreadUncaughtExceptionCount 68 | 69 | } 70 | // scalastyle:on argcount 71 | 72 | def withTestTempDir(f: File => Unit): Unit = { 73 | val dir = createDirectory("testDir", "S3-connector") 74 | try f(dir) finally { 75 | if (TEMP_DIR_CLEAN_UP) S3SparkUtils.deleteRecursively(dir) 76 | } 77 | } 78 | 79 | def createDirectory(root: String, namePrefix: String = "spark"): File = { 80 | var attempts = 0 81 | val maxAttempts = 15 82 | var dir: File = null 83 | while (dir == null) { 84 | attempts += 1 85 | if (attempts > maxAttempts) { 86 | throw new IOException("Failed to create a temp directory (under " + root + ") after " + 87 | maxAttempts + " attempts!") 88 | } 89 | try { 90 | dir = new File(root, namePrefix + "-" 
+ UUID.randomUUID.toString) 91 | if (dir.exists() || !dir.mkdirs()) { 92 | dir = null 93 | } 94 | } catch { 95 | case e: SecurityException => dir = null; 96 | } 97 | } 98 | 99 | dir.getCanonicalFile 100 | } 101 | 102 | def recursiveList(f: File): Array[File] = { 103 | require(f.isDirectory) 104 | val current = f.listFiles 105 | current ++ current.filter(_.isDirectory).flatMap(recursiveList) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/test/scala/com/amazonaws/spark/sql/streaming/connector/metadataLog/RocksDBS3SourceLogSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.amazonaws.spark.sql.streaming.connector.metadataLog 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.TestUtils.withTestTempDir 20 | import com.amazonaws.spark.sql.streaming.connector.{QueueMessageDesc, S3ConnectorFileCache, S3ConnectorTestBase} 21 | import org.apache.hadoop.fs.FileSystem 22 | import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry 23 | import org.apache.spark.sql.internal.SQLConf 24 | import org.apache.spark.sql.test.SharedSparkSession 25 | import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper 26 | import org.scalatest.time.SpanSugar.convertLongToGrainOfTime 27 | 28 | import java.io.File 29 | 30 | 31 | class RocksDBS3SourceLogSuite extends S3ConnectorTestBase with SharedSparkSession{ 32 | 33 | var testS3SourceLog: RocksDBS3SourceLog = _ 34 | var fileCache: S3ConnectorFileCache[String] = _ 35 | 36 | before { 37 | testS3SourceLog = new RocksDBS3SourceLog() 38 | } 39 | 40 | after { 41 | logInfo("close testS3SourceLog") 42 | testS3SourceLog.close() 43 | } 44 | 45 | private def listBatchFiles(fs: FileSystem, sourceLog: RocksDBS3SourceLog): Set[String] = { 46 | fs.listStatus(sourceLog.metadataPath).map(_.getPath).filter { fileName => 47 | try { 48 | com.amazonaws.spark.sql.streaming.connector.metadataLog.RocksDBS3SourceLog.isBatchFile(fileName) 49 | } catch { 50 | case _: NumberFormatException => false 51 | } 52 | }.map(_.getName).toSet 53 | } 54 | 55 | def initTestS3SourceLog(parentDir: File): Unit = { 56 | val checkpointDir = s"${parentDir}/checkpoint/" 57 | fileCache = new S3ConnectorFileCache(1000) 58 | testS3SourceLog.init(spark, checkpointDir, fileCache) 59 | } 60 | 61 | test("delete expired logs") { 62 | withSQLConf( 63 | SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2") { 64 | withTestTempDir { tempDir => 65 | initTestS3SourceLog(tempDir) 66 | val fs = testS3SourceLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) 67 | val waitTime = 3 * testS3SourceLog.maintenanceInterval 68 | def 
listBatchFiles(): Set[String] = this.listBatchFiles(fs, testS3SourceLog) 69 | 70 | testS3SourceLog.add(0, Array(FileEntry("/test/path0", 1000, 0))) 71 | eventually(timeout(waitTime.milliseconds)) { 72 | assert(Set("1.zip") === listBatchFiles()) 73 | } 74 | 75 | testS3SourceLog.add(1, Array(FileEntry("/test/path1", 2000, 1))) 76 | eventually(timeout(waitTime.milliseconds)) { 77 | assert(Set("1.zip", "2.zip") === listBatchFiles()) 78 | } 79 | 80 | testS3SourceLog.add(2, Array(FileEntry("/test/path2", 3000, 2))) 81 | eventually(timeout(waitTime.milliseconds)) { 82 | assert(Set("2.zip", "3.zip") === listBatchFiles()) 83 | } 84 | } 85 | } 86 | } 87 | 88 | test("add S3 Source logs successfully") { 89 | withTestTempDir { tempDir => 90 | initTestS3SourceLog(tempDir) 91 | 92 | val fileArray = Array( 93 | FileEntry("/test/path1", 1000, 0), 94 | FileEntry("/test/path2", 2000, 0), 95 | ) 96 | 97 | val addResult = testS3SourceLog.add(0, fileArray) 98 | addResult shouldBe true 99 | 100 | val getResult = testS3SourceLog.get(0) 101 | getResult.get shouldEqual fileArray 102 | } 103 | } 104 | 105 | test("add duplicate batch should return false") { 106 | withTestTempDir { tempDir => 107 | initTestS3SourceLog(tempDir) 108 | 109 | val fileArray = Array( 110 | FileEntry("/test/path1", 1000, 0), 111 | FileEntry("/test/path2", 2000, 0), 112 | ) 113 | 114 | val addResult = testS3SourceLog.add(0, fileArray) 115 | addResult shouldBe true 116 | 117 | val addResult2 = testS3SourceLog.add(0, fileArray) 118 | addResult2 shouldBe false 119 | } 120 | 121 | 122 | } 123 | test("get range success") { 124 | withTestTempDir { tempDir => 125 | initTestS3SourceLog(tempDir) 126 | 127 | val fileArray1 = Array( 128 | FileEntry("/test/path1", 1000, 0), 129 | FileEntry("/test/path2", 1000, 0), 130 | ) 131 | 132 | val fileArray2 = Array( 133 | FileEntry("/test2/path1", 2000, 1), 134 | FileEntry("/test2/path2", 2000, 1), 135 | ) 136 | 137 | val fileArray3 = Array( 138 | FileEntry("/test3/path1", 3000, 2), 139 | FileEntry("/test3/path2", 3000, 2), 140 | ) 141 | 142 | testS3SourceLog.add(0, fileArray1) 143 | testS3SourceLog.add(1, fileArray2) 144 | testS3SourceLog.add(2, fileArray3) 145 | 146 | val getResult = testS3SourceLog.get(Some(0), Some(2)) 147 | getResult(0)._1 shouldEqual 0 148 | getResult(0)._2 shouldEqual fileArray1 149 | getResult(1)._1 shouldEqual 1 150 | getResult(1)._2 shouldEqual fileArray2 151 | getResult(2)._1 shouldEqual 2 152 | getResult(2)._2 shouldEqual fileArray3 153 | 154 | val getResult2 = testS3SourceLog.get(None, None) 155 | getResult2(0)._1 shouldEqual 0 156 | getResult2(0)._2 shouldEqual fileArray1 157 | getResult2(1)._1 shouldEqual 1 158 | getResult2(1)._2 shouldEqual fileArray2 159 | getResult2(2)._1 shouldEqual 2 160 | getResult2(2)._2 shouldEqual fileArray3 161 | } 162 | } 163 | 164 | test("get range exception") { 165 | withTestTempDir { tempDir => 166 | initTestS3SourceLog(tempDir) 167 | 168 | val fileArray1 = Array( 169 | FileEntry("/test/path1", 1000, 0), 170 | FileEntry("/test/path2", 1000, 0), 171 | ) 172 | 173 | testS3SourceLog.add(0, fileArray1) 174 | val exception = intercept[IllegalArgumentException] { 175 | testS3SourceLog.get(Some(0), Some(1)) 176 | } 177 | 178 | exception.getMessage should include ("batch 1 not found") 179 | } 180 | } 181 | 182 | test("check isNewFile or not") { 183 | withTestTempDir { tempDir => 184 | initTestS3SourceLog(tempDir) 185 | 186 | val fileArray = Array( 187 | FileEntry("/test/path1", 1000, 0), 188 | FileEntry("/test/path2", 2000, 0), 189 | ) 190 | 191 | val newFile 
= FileEntry("/test/path3", 3000, 1) 192 | 193 | testS3SourceLog.add(fileArray(0).batchId, fileArray) 194 | 195 | val result1 = testS3SourceLog.isNewFile(fileArray(0).path, fileArray(0).timestamp - 5000) 196 | val result2 = testS3SourceLog.isNewFile(newFile.path, newFile.timestamp - 5000) 197 | 198 | result1 shouldBe false 199 | result2 shouldBe true 200 | 201 | testS3SourceLog.add(newFile.batchId, Array(newFile)) 202 | 203 | val result3 = testS3SourceLog.isNewFile(newFile.path, newFile.timestamp - 5000) 204 | result3 shouldBe false 205 | 206 | val result4 = testS3SourceLog.isNewFile(fileArray(0).path, 3000) 207 | result4 shouldBe true 208 | 209 | val result5 = testS3SourceLog.isNewFile(newFile.path, 3000) 210 | result5 shouldBe false 211 | } 212 | } 213 | 214 | test("verify old logs are cleaned") { 215 | withTestTempDir { tempDir => 216 | initTestS3SourceLog(tempDir) 217 | 218 | val fileArray1 = Array( 219 | FileEntry("/test/path11", 100, 1), 220 | FileEntry("/test/path12", 100, 1), 221 | ) 222 | 223 | val fileArray2 = Array( 224 | FileEntry("/test/path21", 200, 2), 225 | FileEntry("/test/path22", 200, 2), 226 | ) 227 | 228 | val fileArray6 = Array( 229 | FileEntry("/test/path61", 300, 6), 230 | FileEntry("/test/path62", 600, 6), 231 | ) 232 | 233 | val fileArray15 = Array( 234 | FileEntry("/test/path151", 1500, 15), 235 | FileEntry("/test/path152", 1500, 15), 236 | ) 237 | 238 | val fileArray35 = Array( 239 | FileEntry("/test/path351", 3500, 35), 240 | FileEntry("/test/path352", 3500, 35), 241 | ) 242 | 243 | val fileArray50 = Array( 244 | FileEntry("/test/path501", 8000, 50), 245 | FileEntry("/test/path502", 8000, 50), 246 | ) 247 | 248 | 249 | testS3SourceLog.add(1, fileArray1, Some(300)) 250 | testS3SourceLog.add(2, fileArray2, Some(300)) 251 | testS3SourceLog.add(6, fileArray6, Some(600)) 252 | 253 | // move lastPurgeTimestamp 254 | fileCache.add( 255 | "/test/path151", 256 | QueueMessageDesc(1500, isProcessed = false, Some("id151")) 257 | ) 258 | fileCache.purge() 259 | 260 | testS3SourceLog.add(15, fileArray15, Some(5000)) 261 | testS3SourceLog.add(35, fileArray35, Some(5000)) 262 | 263 | 264 | testS3SourceLog.get(1) shouldEqual None 265 | testS3SourceLog.get(2) shouldEqual None 266 | testS3SourceLog.get(6).get shouldEqual fileArray6 267 | testS3SourceLog.get(15).get shouldEqual fileArray15 268 | testS3SourceLog.get(35).get shouldEqual fileArray35 269 | 270 | testS3SourceLog.getFile("/test/path11") shouldEqual None 271 | testS3SourceLog.getFile("/test/path12") shouldEqual None 272 | testS3SourceLog.getFile("/test/path21") shouldEqual None 273 | testS3SourceLog.getFile("/test/path22") shouldEqual None 274 | testS3SourceLog.getFile("/test/path61") shouldEqual Some(300) 275 | testS3SourceLog.getFile("/test/path62") shouldEqual Some(600) 276 | testS3SourceLog.getFile("/test/path151") shouldEqual Some(1500) 277 | testS3SourceLog.getFile("/test/path152") shouldEqual Some(1500) 278 | testS3SourceLog.getFile("/test/path351") shouldEqual Some(3500) 279 | testS3SourceLog.getFile("/test/path352") shouldEqual Some(3500) 280 | 281 | // move lastPurgeTimestamp 282 | fileCache.add( 283 | "/test/path501", 284 | QueueMessageDesc(8000, isProcessed = false, Some("id501")) 285 | ) 286 | fileCache.purge() 287 | 288 | testS3SourceLog.add(50, fileArray50, Some(9000)) 289 | 290 | testS3SourceLog.get(1) shouldEqual None 291 | testS3SourceLog.get(2) shouldEqual None 292 | testS3SourceLog.get(6) shouldEqual None 293 | testS3SourceLog.get(15) shouldEqual None 294 | testS3SourceLog.get(35) shouldEqual None 295 
| testS3SourceLog.get(50).get shouldEqual fileArray50 296 | 297 | testS3SourceLog.getFile("/test/path11") shouldEqual None 298 | testS3SourceLog.getFile("/test/path12") shouldEqual None 299 | testS3SourceLog.getFile("/test/path21") shouldEqual None 300 | testS3SourceLog.getFile("/test/path22") shouldEqual None 301 | testS3SourceLog.getFile("/test/path61") shouldEqual None 302 | testS3SourceLog.getFile("/test/path62") shouldEqual None 303 | testS3SourceLog.getFile("/test/path151") shouldEqual None 304 | testS3SourceLog.getFile("/test/path152") shouldEqual None 305 | testS3SourceLog.getFile("/test/path351") shouldEqual None 306 | testS3SourceLog.getFile("/test/path352") shouldEqual None 307 | testS3SourceLog.getFile("/test/path501") shouldEqual Some(8000) 308 | testS3SourceLog.getFile("/test/path502") shouldEqual Some(8000) 309 | } 310 | } 311 | } 312 | -------------------------------------------------------------------------------- /src/test/scala/it/spark/sql/streaming/connector/ItTestUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
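// --- Integration-test prerequisite (illustrative pre-flight check, not repository code) ---
// The ItTestUtils getters below fail fast when their environment variables are missing; the names
// listed here mirror the variables read in that file, while any values you export are
// deployment-specific.
val requiredItEnv = Seq(
  "TEST_UPLOAD_S3_PATH", "TEST_REGION", "TEST_QUEUE_URL",
  "CROSS_ACCOUNT_TEST_UPLOAD_S3_PATH", "CROSS_ACCOUNT_TEST_REGION", "CROSS_ACCOUNT_TEST_QUEUE_URL")
requiredItEnv.filterNot(sys.env.contains)
  .foreach(name => println(s"missing environment variable: $name"))
// --- End of pre-flight check ---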
16 | */ 17 | package it.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.services.s3.AmazonS3ClientBuilder 20 | import com.amazonaws.services.s3.transfer.TransferManagerBuilder 21 | import com.amazonaws.spark.sql.streaming.connector.TestUtils.recursiveList 22 | import org.apache.spark.internal.Logging 23 | import software.amazon.awssdk.core.sync.RequestBody 24 | import software.amazon.awssdk.http.apache.ApacheHttpClient 25 | import software.amazon.awssdk.regions.Region 26 | import software.amazon.awssdk.services.s3.S3Client 27 | import software.amazon.awssdk.services.s3.model.{DeleteObjectRequest, ListObjectsV2Request, ListObjectsV2Response, PutObjectRequest} 28 | 29 | import java.io.{File, FileInputStream} 30 | import java.net.URI 31 | import java.nio.file.{FileSystems, Files, Path} 32 | import scala.collection.JavaConversions._ 33 | import scala.collection.JavaConverters.asScalaIteratorConverter 34 | 35 | 36 | 37 | object ItTestUtils extends Logging { 38 | // Set this parameter to false to keep files in S3 39 | val TEMP_S3_CLEAN_UP: Boolean = false 40 | 41 | def getUploadS3Path: String = { 42 | Option(System.getenv("TEST_UPLOAD_S3_PATH")) 43 | .map{timebasedSubDir} 44 | .getOrElse{ 45 | throw new IllegalArgumentException("Env variable TEST_UPLOAD_S3_PATH " + 46 | "for uploading test new files must be defined") 47 | } 48 | } 49 | 50 | def getTestRegion: String = { 51 | Option(System.getenv("TEST_REGION")).getOrElse{ 52 | throw new IllegalArgumentException("Env variable TEST_REGION must be defined") 53 | } 54 | } 55 | 56 | def getQueueUrl: String = { 57 | Option(System.getenv("TEST_QUEUE_URL")) 58 | .getOrElse{ 59 | throw new IllegalArgumentException("Env variable TEST_QUEUE_URL for queue must be defined") 60 | } 61 | } 62 | 63 | def getCrossAccountUploadS3Path: String = { 64 | Option(System.getenv("CROSS_ACCOUNT_TEST_UPLOAD_S3_PATH")) 65 | .map{timebasedSubDir} 66 | .getOrElse{ 67 | throw new IllegalArgumentException("Env variable CROSS_ACCOUNT_TEST_UPLOAD_S3_PATH " + 68 | "for uploading test new files must be defined") 69 | } 70 | } 71 | 72 | def getCrossAccountTestRegion: String = { 73 | Option(System.getenv("CROSS_ACCOUNT_TEST_REGION")).getOrElse{ 74 | throw new IllegalArgumentException("Env variable CROSS_ACCOUNT_TEST_REGION must be defined") 75 | } 76 | } 77 | 78 | def getCrossAccountQueueUrl: String = { 79 | Option(System.getenv("CROSS_ACCOUNT_TEST_QUEUE_URL")) 80 | .getOrElse{ 81 | throw new IllegalArgumentException("Env variable CROSS_ACCOUNT_TEST_QUEUE_URL for queue must be defined") 82 | } 83 | } 84 | 85 | def timebasedSubDir(dir: String): String = { 86 | if (dir.endsWith("/")) { 87 | s"${dir}s3dir_${System.currentTimeMillis}/" 88 | } 89 | else { 90 | s"${dir}/s3dir_${System.currentTimeMillis}/" 91 | } 92 | } 93 | 94 | def recursiveUploadNewFilesToS3(uploadS3Path: String, 95 | testDataPath: String, 96 | suffix: String, 97 | region: String 98 | ): Unit = { 99 | 100 | val dir = FileSystems.getDefault.getPath(testDataPath).toFile 101 | val files: java.util.List[File] = recursiveList(dir) 102 | .filter { file => 103 | if (suffix.nonEmpty) file.getName.endsWith(suffix) 104 | else true 105 | } 106 | .toList 107 | 108 | val s3Uri = URI.create(uploadS3Path) 109 | val bucketName = s3Uri.getHost; 110 | val prefix = s3Uri.getPath.substring(1); 111 | 112 | val s3Client = AmazonS3ClientBuilder.standard().withRegion(region).build(); 113 | val xfer_mgr = TransferManagerBuilder.standard().withS3Client(s3Client).build() 114 | 115 | try { 116 | val xfer = 
xfer_mgr.uploadFileList(bucketName, 117 | prefix, dir, files); 118 | xfer.waitForCompletion() 119 | val progress = xfer.getProgress 120 | val so_far = progress.getBytesTransferred 121 | val total = progress.getTotalBytesToTransfer 122 | val pct = progress.getPercentTransferred 123 | val xfer_state = xfer.getState 124 | logInfo(s"xfer_state: ${xfer_state}, so_far: ${so_far}, total: ${total}, pct: ${pct}") 125 | } 126 | finally { 127 | xfer_mgr.shutdownNow() 128 | } 129 | } 130 | 131 | def uploadSingleNewFileToS3(uploadS3Path: String, testDataPath: String, suffix: String): Unit = { 132 | 133 | val s3Client: S3Client = S3Client.builder.httpClientBuilder(ApacheHttpClient.builder).build 134 | try { 135 | val (bucketName, prefix) = getS3URI(uploadS3Path) 136 | 137 | val dir = FileSystems.getDefault.getPath(testDataPath) 138 | 139 | Files.list(dir).iterator().asScala 140 | .filter(_.getFileName.toString.endsWith(suffix)) 141 | .foreach { f => 142 | s3Client.putObject(PutObjectRequest.builder.bucket(bucketName) 143 | .key(prefix + f.getFileName).build, 144 | RequestBody.fromBytes(getObjectFile(f))) 145 | 146 | logInfo(s"Upload completed for ${f}") 147 | } 148 | } 149 | finally { 150 | s3Client.close() 151 | } 152 | } 153 | 154 | def removeTempS3FolderIfEnabled(s3Path: String, region: String): Unit = { 155 | if (TEMP_S3_CLEAN_UP) { 156 | logInfo(s"Removing files in ${s3Path}") 157 | val s3Client: S3Client = S3Client.builder 158 | .httpClientBuilder(ApacheHttpClient.builder) 159 | .region(Region.of(region)) 160 | .build 161 | try { 162 | val (bucketName, prefix) = getS3URI(s3Path) 163 | var req = ListObjectsV2Request.builder 164 | .bucket(bucketName) 165 | .prefix(prefix) 166 | .build 167 | 168 | var result: ListObjectsV2Response = null 169 | do { 170 | result = s3Client.listObjectsV2(req) 171 | for (s3Object <- result.contents) { 172 | val deleteObjectRequest = DeleteObjectRequest.builder 173 | .bucket(bucketName) 174 | .key(s3Object.key()) 175 | .build 176 | s3Client.deleteObject(deleteObjectRequest) 177 | } 178 | // If there are more than maxKeys keys in the bucket, get a continuation token 179 | // and list the next objects. 180 | val token = result.nextContinuationToken 181 | req = req.toBuilder.continuationToken(token).build 182 | } while (result.isTruncated) 183 | } 184 | finally { 185 | s3Client.close() 186 | } 187 | } 188 | 189 | } 190 | 191 | def getS3URI(s3Path: String): (String, String) = { 192 | val s3Uri = URI.create(s3Path) 193 | val bucketName = s3Uri.getHost 194 | val prefix = s3Uri.getPath.substring(1) 195 | (bucketName, prefix) 196 | } 197 | 198 | // Return a byte array. 199 | private def getObjectFile(filePath: Path) = { 200 | 201 | val file = filePath.toFile 202 | val bytesArray: Array[Byte] = new Array[Byte](file.length.asInstanceOf[Int]) 203 | var fileInputStream: FileInputStream = null 204 | 205 | try { 206 | fileInputStream = new FileInputStream(file) 207 | fileInputStream.read(bytesArray) 208 | } 209 | finally if (fileInputStream != null) fileInputStream.close() 210 | bytesArray 211 | } 212 | 213 | } 214 | -------------------------------------------------------------------------------- /src/test/scala/it/spark/sql/streaming/connector/QueueTestBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package it.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.ClientConfiguration 20 | import com.amazonaws.auth.DefaultAWSCredentialsProviderChain 21 | import com.amazonaws.services.sqs.model.{GetQueueAttributesRequest, PurgeQueueRequest} 22 | import com.amazonaws.services.sqs.{AmazonSQS, AmazonSQSClientBuilder} 23 | import it.spark.sql.streaming.connector.SqsTest.{PURGE_RETRY_INTERVAL_MS, PURGE_WAIT_TIME_SECONDS, RETRY_INTERVAL_MS, SLEEP_TIME_MID, WAIT_TIME_SECONDS} 24 | import org.apache.spark.internal.Logging 25 | import org.scalatest.concurrent.Eventually.{eventually, interval} 26 | import org.scalatest.concurrent.Futures.timeout 27 | import org.scalatest.time.SpanSugar.convertIntToGrainOfTime 28 | 29 | import scala.util.{Failure, Try} 30 | 31 | trait QueueTestBase { 32 | def testQueueUrl: String 33 | 34 | protected def purgeQueue(url: String): Unit 35 | protected def waitForQueueReady(url: String, msgCount: Int): Unit 36 | } 37 | 38 | trait QueueTestClientBase[T] { 39 | def testClient: T 40 | protected def getClient(region: String): T 41 | } 42 | 43 | trait SqsTest extends QueueTestBase with QueueTestClientBase[AmazonSQS] with Logging { 44 | 45 | override def getClient(region: String): AmazonSQS = { 46 | AmazonSQSClientBuilder 47 | .standard() 48 | .withClientConfiguration(new ClientConfiguration().withMaxConnections(1)) 49 | .withCredentials(new DefaultAWSCredentialsProviderChain()) 50 | .withRegion(region) 51 | .build() 52 | } 53 | 54 | override def purgeQueue(url: String): Unit = { 55 | 56 | val purgeRequest = new PurgeQueueRequest(url) 57 | Thread.sleep(SLEEP_TIME_MID) 58 | 59 | eventually(timeout(PURGE_WAIT_TIME_SECONDS.seconds), 60 | interval((PURGE_RETRY_INTERVAL_MS*10).milliseconds)) { 61 | Try(testClient.purgeQueue(purgeRequest)) match { 62 | case Failure(e) => 63 | logInfo("purgeQueue failed", e) 64 | throw e 65 | case _ => 66 | Thread.sleep(SLEEP_TIME_MID) 67 | logInfo("purgeQueue success. 
Checking queue status.") 68 | waitForQueueReady(url, 0) 69 | } 70 | } 71 | } 72 | 73 | override def waitForQueueReady(url: String, msgCount: Int): Unit = { 74 | eventually(timeout(WAIT_TIME_SECONDS.seconds), interval((RETRY_INTERVAL_MS*10).milliseconds)) { 75 | val attrs = testClient.getQueueAttributes(new GetQueueAttributesRequest(url) 76 | .withAttributeNames("ApproximateNumberOfMessages", 77 | "ApproximateNumberOfMessagesNotVisible")) 78 | val approximateNumberOfMessages = attrs.getAttributes.get("ApproximateNumberOfMessages").toInt 79 | val approximateNumberOfMessagesNotVisible = attrs.getAttributes.get("ApproximateNumberOfMessagesNotVisible").toInt 80 | assert( approximateNumberOfMessages == msgCount, 81 | s"ApproximateNumberOfMessages is ${approximateNumberOfMessages}, expected: ${msgCount}") 82 | assert( approximateNumberOfMessagesNotVisible == 0, 83 | s"ApproximateNumberOfMessagesNotVisible is ${approximateNumberOfMessagesNotVisible}, expected 0") 84 | } 85 | } 86 | } 87 | 88 | object SqsTest { 89 | val SLEEP_TIME_SHORT = 5000 90 | val SLEEP_TIME_MID = 10000 91 | 92 | val WAIT_TIME_SECONDS = 60 93 | val RETRY_INTERVAL_MS = 500 94 | 95 | val PURGE_WAIT_TIME_SECONDS = WAIT_TIME_SECONDS * 3 // must > 60, as SQS purge only allow once per minute 96 | val PURGE_RETRY_INTERVAL_MS = 5000 97 | } -------------------------------------------------------------------------------- /src/test/scala/it/spark/sql/streaming/connector/S3ConnectorItBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package it.spark.sql.streaming.connector 19 | 20 | import org.apache.spark.SparkConf 21 | import org.apache.spark.sql.catalyst.plans.logical 22 | import org.apache.spark.sql.functions.col 23 | import org.apache.spark.sql.streaming.{StreamTest, StreamingQuery} 24 | import org.apache.spark.sql.types._ 25 | import org.apache.spark.sql.{DataFrame, Row, SparkSession} 26 | 27 | import java.io.File 28 | 29 | trait S3ConnectorItBase extends StreamTest{ 30 | val SOURCE_SHORT_NAME: String = "s3-connector" 31 | val TEST_SPARK_PARTITION = 2 32 | val TEST_MAX_FILES_PER_TRIGGER = 5 33 | 34 | def uploadS3Path: String 35 | def testRegion: String 36 | 37 | val testSchema: StructType = StructType(Array( 38 | StructField("testString", StringType, nullable = true), 39 | StructField("testBoolean", BooleanType, nullable = true), 40 | StructField("testInt", IntegerType, nullable = true) 41 | )) 42 | 43 | val testRawData: Seq[Row] = Seq( 44 | Row("James", true, 3000), 45 | Row("Michael", false, 5000), 46 | Row("Robert", false, 5000) 47 | ) 48 | 49 | val testSchemaWithPartition: StructType = StructType(Array( 50 | StructField("testString", StringType, nullable = true), 51 | StructField("testBoolean", BooleanType, nullable = true), 52 | StructField("testInt", IntegerType, nullable = true), 53 | StructField("testPart1", StringType, nullable = false), 54 | StructField("testPart2", IntegerType, nullable = false) 55 | )) 56 | 57 | val testRawDataWithPartition: Seq[Row] = Seq( 58 | Row("James", true, 3000, "p1", 1), 59 | Row("Michael", false, 5000, "p1", 1), 60 | Row("Robert", false, 5000, "p1", 2), 61 | Row("James2", true, 3000, "p2", 1), 62 | Row("Michael2", false, 5000, "p2", 1), 63 | Row("Robert2", false, 5000, "p2", 3), 64 | ) 65 | 66 | override protected def sparkConf: SparkConf = { 67 | val conf = super.sparkConf 68 | conf.set("spark.sql.ui.explainMode", "extended") 69 | .set("spark.hadoop.fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") 70 | .set("spark.hadoop.fs.s3a.aws.credentials.provider", 71 | "com.amazonaws.auth.EnvironmentVariableCredentialsProvider") 72 | } 73 | 74 | 75 | protected def getDefaultOptions(fileFormat: String): Map[String, String] 76 | 77 | protected def getTestDataFrame(data: Seq[Row], schema: StructType): DataFrame = { 78 | val df = spark.createDataFrame( 79 | spark.sparkContext.parallelize(data), schema) 80 | 81 | df 82 | } 83 | 84 | def createTestCSVFiles(parentDir: File): String = { 85 | val df = getTestDataFrame(testRawData, testSchema) 86 | 87 | val saveTo = s"${parentDir}/datacsv" 88 | df.coalesce(TEST_SPARK_PARTITION).write.csv(saveTo) 89 | saveTo 90 | } 91 | 92 | def createTestCSVFilesWithHeader(parentDir: File, sep: String): String = { 93 | val df = getTestDataFrame(testRawData, testSchema) 94 | 95 | val saveTo = s"${parentDir}/datacsv" 96 | df.coalesce(TEST_SPARK_PARTITION) 97 | .write 98 | .option("header", true) 99 | .option("sep", sep) 100 | .csv(saveTo) 101 | saveTo 102 | } 103 | 104 | def createTestCSVFilesWithPartition(parentDir: File): String = { 105 | val df = getTestDataFrame(testRawDataWithPartition, testSchemaWithPartition) 106 | 107 | val saveTo = s"${parentDir}/datacsv" 108 | df.repartition(TEST_SPARK_PARTITION, col("testPart1"), col("testPart2")) 109 | .write 110 | .partitionBy("testPart1", "testPart2") 111 | .csv(saveTo) 112 | saveTo 113 | } 114 | 115 | def createTestParquetFiles(parentDir: File): String = { 116 | val df = getTestDataFrame(testRawData, testSchema) 117 | 118 | val saveTo = s"${parentDir}/datacsv" 119 | 
df.coalesce(TEST_SPARK_PARTITION).write.parquet(saveTo) 120 | saveTo 121 | } 122 | 123 | def createTestJsonFiles(parentDir: File): String = { 124 | val df = getTestDataFrame(testRawData, testSchema) 125 | 126 | val saveTo = s"${parentDir}/datacsv" 127 | df.coalesce(TEST_SPARK_PARTITION).write.json(saveTo) 128 | saveTo 129 | } 130 | 131 | /** 132 | * Runs the plan and makes sure the answer matches the expected result. 133 | * @param df the [[DataFrame]] to be executed 134 | * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 135 | */ 136 | def assertDF(df: DataFrame, expectedAnswer: Seq[Row], trim: Boolean = false): Unit = { 137 | val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty 138 | def prepareAnswer(answer: Seq[Row]): Seq[Row] = { 139 | // Converts data to types that we can do equality comparison using Scala collections. 140 | // For BigDecimal type, the Scala type has a better definition of equality test (similar to 141 | // Java's java.math.BigDecimal.compareTo). 142 | // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for 143 | // equality test. 144 | val converted: Seq[Row] = answer.map { s => 145 | Row.fromSeq(s.toSeq.map { 146 | case d: java.math.BigDecimal => BigDecimal(d) 147 | case b: Array[Byte] => b.toSeq 148 | case s: String if trim => s.trim() 149 | case o => o 150 | }) 151 | } 152 | if (!isSorted) converted.sortBy(_.toString()) else converted 153 | } 154 | val sparkAnswer = try df.collect().toSeq catch { 155 | case e: Exception => 156 | val errorMessage = 157 | s""" 158 | |Exception thrown while executing query: 159 | |${df.queryExecution} 160 | |== Exception == 161 | |$e 162 | |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} 163 | """.stripMargin 164 | fail(errorMessage) 165 | } 166 | 167 | val prepExpectedAnswer = prepareAnswer(expectedAnswer) 168 | val prepSparkAnswer = prepareAnswer(sparkAnswer) 169 | 170 | if (prepExpectedAnswer != prepSparkAnswer) { 171 | val errorMessage = 172 | s""" 173 | |Results do not match for query: 174 | |${df.queryExecution} 175 | |== Results == 176 | |${sideBySide( 177 | s"== Correct Answer - ${expectedAnswer.size} ==" +: 178 | prepExpectedAnswer.map(_.toString()), 179 | s"== Spark Answer - ${sparkAnswer.size} ==" +: 180 | prepSparkAnswer.map(_.toString())).mkString("\n")} 181 | """.stripMargin 182 | fail(errorMessage) 183 | } 184 | } 185 | 186 | private def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { 187 | val maxLeftSize = left.map(_.length).max 188 | val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") 189 | val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") 190 | 191 | leftPadded.zip(rightPadded).map { 192 | case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r 193 | } 194 | } 195 | 196 | def waitForQueryStarted(query: StreamingQuery): Unit = { 197 | 198 | while (!query.isActive) { 199 | Thread.sleep(1000) 200 | } 201 | Thread.sleep(3000) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /src/test/scala/it/spark/sql/streaming/connector/S3ConnectorSourceCrossAccountItSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package it.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions._ 20 | import com.amazonaws.spark.sql.streaming.connector.TestUtils.withTestTempDir 21 | import it.spark.sql.streaming.connector.ItTestUtils.{getCrossAccountQueueUrl, getCrossAccountTestRegion, getCrossAccountUploadS3Path, recursiveUploadNewFilesToS3, removeTempS3FolderIfEnabled} 22 | import it.spark.sql.streaming.connector.SqsTest.SLEEP_TIME_MID 23 | import org.scalatest.matchers.must.Matchers 24 | 25 | abstract class S3ConnectorSourceCrossAccountIntegrationTest extends S3ConnectorItBase 26 | with QueueTestBase 27 | with Matchers { 28 | 29 | override val testQueueUrl: String = getCrossAccountQueueUrl 30 | override val uploadS3Path: String = getCrossAccountUploadS3Path 31 | override val testRegion: String = getCrossAccountTestRegion 32 | 33 | override def beforeAll(): Unit = { 34 | super.beforeAll() 35 | 36 | spark.sparkContext.setLogLevel("INFO") 37 | } 38 | 39 | override def beforeEach(): Unit = { 40 | super.beforeEach() 41 | waitForQueueReady(testQueueUrl, 0) 42 | } 43 | 44 | override def afterAll(): Unit = { 45 | try { 46 | super.afterAll() 47 | Thread.sleep(SLEEP_TIME_MID) 48 | waitForQueueReady(testQueueUrl, 0) 49 | } finally { 50 | removeTempS3FolderIfEnabled(uploadS3Path, testRegion) 51 | } 52 | } 53 | 54 | test("Cross account and region new S3 CSV files loaded to memory") { 55 | withTestTempDir { tempDir => 56 | val TEST_FILE_FORMAT = "csv" 57 | val testDataPath: String = createTestCSVFiles(tempDir) 58 | 59 | recursiveUploadNewFilesToS3(uploadS3Path, 60 | testDataPath, 61 | TEST_FILE_FORMAT, 62 | getCrossAccountTestRegion) 63 | 64 | waitForQueueReady(testQueueUrl, TEST_SPARK_PARTITION) 65 | 66 | val inputDf = spark 67 | .readStream 68 | .format(SOURCE_SHORT_NAME) 69 | .schema(testSchema) 70 | .options(getDefaultOptions(TEST_FILE_FORMAT)) 71 | .load() 72 | 73 | testStream(inputDf)( 74 | StartStream(), 75 | AssertOnQuery { q => 76 | q.processAllAvailable() 77 | true 78 | }, 79 | CheckAnswer(testRawData: _*), 80 | StopStream 81 | ) 82 | } 83 | } 84 | } 85 | 86 | @IntegrationTestSuite 87 | class S3ConnectorSourceCrossAccountSqsRocksDBItSuite extends S3ConnectorSourceCrossAccountIntegrationTest 88 | with SqsTest { 89 | 90 | override val testClient = getClient(testRegion) 91 | 92 | override def beforeAll(): Unit = { 93 | super.beforeAll() 94 | purgeQueue(testQueueUrl) 95 | } 96 | 97 | override def afterAll(): Unit = { 98 | super.afterAll() 99 | purgeQueue(testQueueUrl) 100 | } 101 | 102 | override def getDefaultOptions(fileFormat: String): Map[String, String] = { 103 | Map( 104 | QUEUE_REGION -> testRegion, 105 | S3_FILE_FORMAT -> fileFormat, 106 | MAX_FILES_PER_TRIGGER -> TEST_MAX_FILES_PER_TRIGGER.toString, 107 | QUEUE_URL -> 
testQueueUrl, 108 | SQS_LONG_POLLING_WAIT_TIME_SECONDS -> "15", 109 | SQS_VISIBILITY_TIMEOUT_SECONDS -> "120" 110 | ) 111 | } 112 | } -------------------------------------------------------------------------------- /src/test/scala/it/spark/sql/streaming/connector/TestForeachWriter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package it.spark.sql.streaming.connector 18 | 19 | import scala.collection.mutable 20 | import scala.collection.mutable.ArrayBuffer 21 | 22 | import org.apache.spark.sql.ForeachWriter 23 | 24 | class TestForeachWriter[T](key: String, rowToString: (T) => String) extends ForeachWriter[T] { 25 | 26 | override def open(partitionId: Long, version: Long): Boolean = true 27 | 28 | override def process(value: T): Unit = { 29 | TestForeachWriter.addValue(key, rowToString(value)) 30 | } 31 | 32 | override def close(errorOrNull: Throwable): Unit = {} 33 | 34 | } 35 | 36 | object TestForeachWriter { 37 | private val internalHashMap = new mutable.HashMap[String, ArrayBuffer[String]]() 38 | 39 | def addValue(key: String, value: String): Option[ArrayBuffer[String]] = { 40 | internalHashMap.synchronized { 41 | val values = internalHashMap.getOrElse(key, new ArrayBuffer[String]()) 42 | values.append(value) 43 | internalHashMap.put(key, values) 44 | } 45 | } 46 | 47 | def getValues(key: String): ArrayBuffer[String] = internalHashMap.getOrElse(key, ArrayBuffer.empty) 48 | 49 | def allValues: mutable.HashMap[String, ArrayBuffer[String]] = internalHashMap 50 | 51 | def clearAll(): Unit = { 52 | internalHashMap.synchronized { 53 | internalHashMap.clear() 54 | } 55 | } 56 | } -------------------------------------------------------------------------------- /src/test/scala/it/spark/sql/streaming/connector/client/AsyncSqsClientItSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package it.spark.sql.streaming.connector.client 19 | 20 | import java.util.concurrent.CopyOnWriteArrayList 21 | 22 | import scala.language.implicitConversions 23 | 24 | import com.amazonaws.services.sqs.AmazonSQS 25 | import com.amazonaws.spark.sql.streaming.connector.{FileMetadata, S3ConnectorSourceOptions} 26 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions.{QUEUE_REGION, QUEUE_URL, S3_FILE_FORMAT} 27 | import com.amazonaws.spark.sql.streaming.connector.TestUtils.withTestTempDir 28 | import com.amazonaws.spark.sql.streaming.connector.client.{AsyncQueueClient, AsyncSqsClientBuilder} 29 | import it.spark.sql.streaming.connector.{IntegrationTestSuite, S3ConnectorItBase, SqsTest} 30 | import it.spark.sql.streaming.connector.ItTestUtils._ 31 | import it.spark.sql.streaming.connector.SqsTest.SLEEP_TIME_MID 32 | import org.scalatest.matchers.must.Matchers.include 33 | import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper 34 | 35 | @IntegrationTestSuite 36 | class AsyncSqsClientItSuite extends S3ConnectorItBase with SqsTest { 37 | 38 | override val uploadS3Path: String = getUploadS3Path 39 | override val testRegion: String = getTestRegion 40 | override val testClient: AmazonSQS = getClient(testRegion) 41 | override val testQueueUrl: String = getQueueUrl 42 | 43 | override def beforeAll(): Unit = { 44 | super.beforeAll() 45 | spark.sparkContext.setLogLevel("INFO") 46 | 47 | purgeQueue(testQueueUrl) 48 | } 49 | 50 | override def beforeEach(): Unit = { 51 | super.beforeEach() 52 | waitForQueueReady(testQueueUrl, 0) 53 | } 54 | 55 | override def afterAll(): Unit = { 56 | 57 | try { 58 | Thread.sleep(SLEEP_TIME_MID) 59 | waitForQueueReady(testQueueUrl, 0) 60 | super.afterAll() 61 | } finally { 62 | removeTempS3FolderIfEnabled(uploadS3Path, testRegion) 63 | purgeQueue(testQueueUrl) 64 | } 65 | } 66 | 67 | override def getDefaultOptions(fileFormat: String): Map[String, String] = { 68 | Map( 69 | S3_FILE_FORMAT -> fileFormat, 70 | QUEUE_URL -> testQueueUrl, 71 | QUEUE_REGION -> testRegion 72 | ) 73 | } 74 | 75 | val DEFAULT_TEST_FILE_FORMAT = "csv" 76 | 77 | test("read message from SQS") { 78 | withTestTempDir { tempDir => 79 | val testDataPath: String = createTestCSVFiles(tempDir) 80 | 81 | recursiveUploadNewFilesToS3(uploadS3Path, testDataPath, DEFAULT_TEST_FILE_FORMAT, testRegion) 82 | 83 | var asyncClient: AsyncQueueClient[String] = null 84 | val pathList = new CopyOnWriteArrayList[String] 85 | 86 | try { 87 | val sourceOptions = S3ConnectorSourceOptions(getDefaultOptions(DEFAULT_TEST_FILE_FORMAT)) 88 | asyncClient = new AsyncSqsClientBuilder().sourceOptions( 89 | sourceOptions 90 | ) 91 | .consumer( 92 | (data: FileMetadata[String]) => { 93 | logInfo(s"consumer data: ${data}") 94 | data.messageId.map(asyncClient.handleProcessedMessage) 95 | pathList.add(data.filePath) 96 | } 97 | ) 98 | .build() 99 | 100 | asyncClient.asyncFetch(sourceOptions.queueFetchWaitTimeoutSeconds) 101 | } finally { 102 | if (asyncClient != null) { 103 | asyncClient.close() 104 | } 105 | } 106 | 107 | asyncClient.metrics.fetchThreadUncaughtExceptionCounter.getCount shouldBe 0 108 | var part0Cnt = 0 109 | var part1Cnt = 0 110 | pathList.forEach( 111 | path => { 112 | if (path.contains("part-00000")) part0Cnt = part0Cnt + 1 113 | else if (path.contains("part-00001")) part1Cnt = part1Cnt + 1 114 | } 115 | ) 116 | 117 | pathList.size shouldBe 2 118 | 
part0Cnt shouldBe 1 119 | part1Cnt shouldBe 1 120 | 121 | waitForQueueReady(testQueueUrl, 0) 122 | } 123 | 124 | } 125 | 126 | test("read message from SQS with consumer exception") { 127 | withTestTempDir { tempDir => 128 | val testDataPath: String = createTestCSVFiles(tempDir) 129 | 130 | 131 | recursiveUploadNewFilesToS3(uploadS3Path, testDataPath, DEFAULT_TEST_FILE_FORMAT, testRegion) 132 | 133 | var asyncClient: AsyncQueueClient[String] = null 134 | val pathList = new CopyOnWriteArrayList[String] 135 | 136 | try { 137 | val sourceOptions = S3ConnectorSourceOptions(getDefaultOptions(DEFAULT_TEST_FILE_FORMAT)) 138 | asyncClient = new AsyncSqsClientBuilder().sourceOptions( 139 | sourceOptions 140 | ) 141 | .consumer( 142 | (data: FileMetadata[String]) => { 143 | logInfo(s"consumer data: ${data}") 144 | try{ 145 | val path = data.filePath 146 | if (path.contains("part-00001")) { 147 | throw new RuntimeException("exception in consumer") 148 | } 149 | 150 | pathList.add(path) 151 | } 152 | finally { 153 | data.messageId.map(asyncClient.deleteInvalidMessageIfNecessary) 154 | } 155 | 156 | } 157 | ) 158 | .build() 159 | 160 | asyncClient.asyncFetch(sourceOptions.queueFetchWaitTimeoutSeconds) 161 | } finally { 162 | if (asyncClient != null) { 163 | asyncClient.close() 164 | } 165 | } 166 | 167 | asyncClient.metrics.fetchThreadUncaughtExceptionCounter.getCount shouldBe 0 168 | 169 | pathList.size shouldBe 1 170 | pathList.get(0) should include ("part-00000") 171 | 172 | waitForQueueReady(testQueueUrl, 0) 173 | } 174 | 175 | } 176 | 177 | } -------------------------------------------------------------------------------- /src/test/scala/pt/spark/sql/streaming/connector/DataConsumer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package pt.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions._ 20 | import pt.spark.sql.streaming.connector.DataGenerator.{testSchemaWithPartition, DATA_PREFIX} 21 | 22 | import org.apache.spark.sql.SparkSession 23 | import org.apache.spark.sql.streaming.Trigger 24 | 25 | object DataConsumer { 26 | val SOURCE_SHORT_NAME: String = "s3-connector" 27 | val QUEUE_NAME: String = "PTConsume" 28 | val CONSUMER_MAX_FILES_PER_TRIGGER: String = "5000" 29 | 30 | val localTest: Boolean = System.getProperty("os.name").toLowerCase().startsWith("mac os") 31 | 32 | def main(args: Array[String]): Unit = { 33 | 34 | val dataSrc = args(0) 35 | val queueUrl = args(1) 36 | val fileFormat = args(2) 37 | val checkpointDir = args(3) 38 | val writeToDir = args(4) 39 | val readFrom = s"${dataSrc}/${DATA_PREFIX}" 40 | 41 | val sparkBuilder = SparkSession.builder() 42 | .appName("S3ConnectorPTDataConsumer") 43 | .config("spark.sql.ui.explainMode", "extended") 44 | 45 | addConfigForLocalTest(sparkBuilder) 46 | 47 | val spark = sparkBuilder.getOrCreate() 48 | 49 | spark.sparkContext.setLogLevel("INFO") 50 | 51 | val connectorOptions = spark.sqlContext.getAllConfs ++ Map( 52 | QUEUE_REGION -> "us-east-2", 53 | S3_FILE_FORMAT -> fileFormat, 54 | MAX_FILES_PER_TRIGGER -> CONSUMER_MAX_FILES_PER_TRIGGER, 55 | MAX_FILE_AGE->"15d", 56 | QUEUE_URL -> queueUrl, 57 | QUEUE_FETCH_WAIT_TIMEOUT_SECONDS -> "10", 58 | SQS_LONG_POLLING_WAIT_TIME_SECONDS -> "5", 59 | SQS_VISIBILITY_TIMEOUT_SECONDS -> "60", 60 | PATH_GLOB_FILTER -> s"*.${fileFormat}", 61 | BASE_PATH -> readFrom 62 | ) 63 | 64 | try { 65 | val inputDf = spark 66 | .readStream 67 | .format(SOURCE_SHORT_NAME) 68 | .schema(testSchemaWithPartition) 69 | .options(connectorOptions) 70 | .load() 71 | 72 | val query = inputDf 73 | .writeStream 74 | .queryName(QUEUE_NAME) 75 | .format("csv") 76 | .option("path", writeToDir) 77 | .option("checkpointLocation", checkpointDir) 78 | .trigger(Trigger.ProcessingTime("15 seconds")) 79 | .start() 80 | 81 | query.awaitTermination() 82 | } finally { 83 | spark.stop() 84 | } 85 | 86 | } 87 | 88 | def addConfigForLocalTest(sparkBuilder: SparkSession.Builder): Unit = { 89 | if (localTest) { 90 | sparkBuilder 91 | .master("local[2]") 92 | .config("spark.sql.debug.maxToStringFields", "100") 93 | .config("spark.hadoop.fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") 94 | .config("spark.hadoop.fs.s3a.aws.credentials.provider", 95 | "com.amazonaws.auth.EnvironmentVariableCredentialsProvider") 96 | // .config("spark.driver.bindAddress", "127.0.0.1") // VPN env. 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/test/scala/pt/spark/sql/streaming/connector/DataGenerator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package pt.spark.sql.streaming.connector 18 | 19 | import scala.util.Random 20 | 21 | import pt.spark.sql.streaming.connector.DataConsumer.addConfigForLocalTest 22 | 23 | import org.apache.spark.sql.{DataFrame, Row, SparkSession} 24 | import org.apache.spark.sql.functions.col 25 | import org.apache.spark.sql.types._ 26 | 27 | object DataGenerator { 28 | 29 | val testSchemaWithPartition: StructType = StructType(Array( 30 | StructField("valString", StringType, nullable = true), 31 | StructField("valBoolean", BooleanType, nullable = true), 32 | StructField("valDouble", DoubleType, nullable = true), 33 | StructField("valInt", IntegerType, nullable = true), 34 | StructField("valPartition", StringType, nullable = false), 35 | )) 36 | 37 | val DATA_PREFIX = "datacsv" 38 | 39 | def main(args: Array[String]): Unit = { 40 | 41 | val dataDest = args(0) 42 | val rowCount: Int = args(1).toInt 43 | val sparkPartitionCount: Int = args(2).toInt 44 | val partitionPrefix: String = args(3) 45 | 46 | val sparkBuilder = SparkSession.builder() 47 | .appName("S3ConnectorPTDataGenerator") 48 | .config("spark.sql.ui.explainMode", "extended") 49 | addConfigForLocalTest(sparkBuilder) 50 | 51 | val spark = sparkBuilder.getOrCreate() 52 | 53 | spark.sparkContext.setLogLevel("INFO") 54 | 55 | try { 56 | val df = generateDataFrame(spark, rowCount, sparkPartitionCount, 57 | partitionPrefix, testSchemaWithPartition) 58 | 59 | val saveTo = s"${dataDest}/${DATA_PREFIX}" 60 | df.repartition(sparkPartitionCount, col("valPartition")) 61 | .write 62 | .mode("append") 63 | .partitionBy("valPartition") 64 | .csv(saveTo) 65 | } finally { 66 | spark.stop() 67 | } 68 | 69 | 70 | } 71 | 72 | def generateDataFrame(spark: SparkSession, 73 | rowCount: Int, 74 | partitionCount: Int, 75 | partitionPrefix: String, 76 | schema: StructType): DataFrame = { 77 | def randomString(len: Int) = Random.alphanumeric.take(len).mkString 78 | 79 | val df = spark.createDataFrame( 80 | spark.sparkContext.parallelize( 81 | Seq.fill(rowCount) { 82 | Row( 83 | randomString(10), 84 | Random.nextBoolean(), 85 | Random.nextDouble(), 86 | Random.nextInt(), 87 | partitionPrefix + "_" + Random.nextInt(partitionCount).toString) 88 | } 89 | ), 90 | schema 91 | ) 92 | 93 | df 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /src/test/scala/pt/spark/sql/streaming/connector/DataValidator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package pt.spark.sql.streaming.connector 19 | 20 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorSourceOptions.{MAX_FILES_PER_TRIGGER, QUEUE_REGION, QUEUE_URL, S3_FILE_FORMAT, SQS_LONG_POLLING_WAIT_TIME_SECONDS, SQS_VISIBILITY_TIMEOUT_SECONDS} 21 | import org.apache.log4j.LogManager 22 | import org.apache.spark.sql.SparkSession 23 | import org.apache.spark.sql.streaming.Trigger 24 | import pt.spark.sql.streaming.connector.DataGenerator.testSchemaWithPartition 25 | 26 | object DataValidator { 27 | 28 | def main(args: Array[String]) { 29 | val dataSrc = args(0) 30 | val expectedRows: Long = args(1).toLong 31 | val log = LogManager.getRootLogger 32 | 33 | val spark = SparkSession.builder() 34 | .appName("S3ConnectorPTDataValidator") 35 | .config("spark.sql.ui.explainMode", "extended") 36 | .getOrCreate() 37 | 38 | spark.sparkContext.setLogLevel("INFO") 39 | 40 | val df = spark.read 41 | .schema(testSchemaWithPartition) 42 | .format("csv") 43 | .option("header", false) 44 | .load(dataSrc + "/*.csv") 45 | 46 | val totalRows = df.count() 47 | 48 | log.info(s"totalRows: ${totalRows}") 49 | assert(totalRows==expectedRows, s"totalRows ${totalRows} doesn't match expectedRows ${expectedRows}") 50 | } 51 | } 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /src/test/scala/pt/spark/sql/streaming/connector/FileSourceConsumer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package pt.spark.sql.streaming.connector 18 | 19 | import org.apache.spark.sql.SparkSession 20 | import org.apache.spark.sql.streaming.Trigger 21 | import pt.spark.sql.streaming.connector.DataConsumer.CONSUMER_MAX_FILES_PER_TRIGGER 22 | import pt.spark.sql.streaming.connector.DataGenerator.{DATA_PREFIX, testSchemaWithPartition} 23 | 24 | // A test consumer using Spark's default file readStream source 25 | object FileSourceConsumer { 26 | 27 | def main(args: Array[String]): Unit = { 28 | val dataSrc = args(0) 29 | val fileFormat = args(1) 30 | val checkpointDir = args(2) 31 | val writeToDir = args(3) 32 | val readFrom = s"${dataSrc}/${DATA_PREFIX}" 33 | 34 | val spark = SparkSession.builder() 35 | .appName("FileSourceConsumer") 36 | .config("spark.sql.ui.explainMode", "extended") 37 | .getOrCreate() 38 | 39 | spark.sparkContext.setLogLevel("INFO") 40 | 41 | val inputDf = spark 42 | .readStream 43 | .format(fileFormat) 44 | .schema(testSchemaWithPartition) 45 | .option("maxFilesPerTrigger", CONSUMER_MAX_FILES_PER_TRIGGER) 46 | .option("region", "us-east-2") 47 | .load(readFrom) 48 | 49 | val query = inputDf 50 | .writeStream 51 | .queryName("FileSourceDataConsumer") 52 | .format("csv") 53 | .option("path", writeToDir) 54 | .option("checkpointLocation", checkpointDir) 55 | .trigger(Trigger.ProcessingTime("15 seconds")) 56 | .start() 57 | 58 | query.awaitTermination() 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/test/scala/pt/spark/sql/streaming/connector/TestTool.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | package pt.spark.sql.streaming.connector 18 | 19 | import com.amazonaws.spark.sql.streaming.connector.S3ConnectorFileCache 20 | import com.amazonaws.spark.sql.streaming.connector.metadataLog.RocksDBS3SourceLog 21 | import org.apache.log4j.LogManager 22 | import pt.spark.sql.streaming.connector.DataConsumer.addConfigForLocalTest 23 | 24 | import org.apache.spark.sql.SparkSession 25 | 26 | object TestTool { 27 | 28 | def main(args: Array[String]): Unit = { 29 | val command = args(0) 30 | 31 | val log = LogManager.getRootLogger 32 | 33 | val sparkBuilder = SparkSession.builder() 34 | .appName("TestTool") 35 | .config("spark.sql.ui.explainMode", "extended") 36 | 37 | addConfigForLocalTest(sparkBuilder) 38 | 39 | val spark = sparkBuilder.getOrCreate() 40 | spark.sparkContext.setLogLevel("INFO") 41 | 42 | try { 43 | command match { 44 | case "PrintRocksDB" => 45 | val metadataPath = args(1) 46 | val fileCache = new S3ConnectorFileCache(Int.MaxValue) 47 | 48 | val metadataLog = new RocksDBS3SourceLog() 49 | metadataLog.init(spark, metadataPath, fileCache) 50 | metadataLog.printAllBatchesInRocksDB() 51 | metadataLog.close() 52 | 53 | case _ => log.error(s"Unknown command: ${command}") 54 | 55 | } 56 | } finally { 57 | spark.stop() 58 | } 59 | } 60 | } 61 | --------------------------------------------------------------------------------
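The objects under src/test/scala/pt/spark/sql/streaming/connector above are stand-alone performance-test drivers configured only through positional arguments. The summary below is a reference sketch inferred from their main() methods; every path, bucket, queue URL, and count shown in angle brackets is a hypothetical placeholder, and the tools are assumed to be launched (for example via spark-submit) with this project's test classes and a suitable Spark/S3 setup available.

// Positional arguments expected by the performance-test entry points (inferred from the code above):
//   DataGenerator      <dataDest> <rowCount> <sparkPartitionCount> <partitionPrefix>
//                        writes <rowCount> random rows as partitioned CSV under <dataDest>/datacsv
//   DataConsumer       <dataSrc> <queueUrl> <fileFormat> <checkpointDir> <writeToDir>
//                        streams files announced on SQS queue <queueUrl> through the s3-connector source
//                        (base path <dataSrc>/datacsv) and writes the rows as CSV to <writeToDir>
//   FileSourceConsumer <dataSrc> <fileFormat> <checkpointDir> <writeToDir>
//                        same consumption path, but through Spark's built-in file stream source, as a baseline
//   DataValidator      <dataSrc> <expectedRows>
//                        asserts that the CSV files under <dataSrc> contain exactly <expectedRows> rows
//   TestTool           PrintRocksDB <metadataPath>
//                        prints every batch stored in the RocksDB-backed S3 source metadata log at <metadataPath>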