├── .github ├── FUNDING.yml └── workflows │ ├── greetings.yml │ └── stale.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── _config.yml ├── build.sbt ├── project ├── Versions.scala ├── build.properties └── plugins.sbt └── src ├── main └── scala │ └── ru │ └── chermenin │ └── spark │ └── sql │ └── execution │ └── streaming │ └── state │ ├── RocksDbStateStoreProvider.scala │ └── implicits.scala └── test ├── resources └── log4j.properties └── scala └── ru └── chermenin └── spark └── sql └── execution └── streaming └── state ├── RocksDbStateStoreHelper.scala ├── RocksDbStateStoreProviderSuite.scala └── RocksDbStateTimeoutSuite.scala /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: chermenin 4 | liberapay: chermenin 5 | issuehunt: chermenin 6 | -------------------------------------------------------------------------------- /.github/workflows/greetings.yml: -------------------------------------------------------------------------------- 1 | name: Greetings 2 | 3 | on: [pull_request, issues] 4 | 5 | jobs: 6 | greeting: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/first-interaction@v1 10 | with: 11 | repo-token: ${{ secrets.GITHUB_TOKEN }} 12 | issue-message: 'Automated response: Thanks for reporting your first issue to us! One of the committers will respond to you shortly. Cheers!' 13 | pr-message: 'Automated response: Congratulations on raising your first Pull Request! One of the committers will review this at the earliest. Your contributions are greatly appreciated.' 14 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: "Close stale PRs" 2 | on: 3 | schedule: 4 | - cron: "0 0 * * *" 5 | 6 | jobs: 7 | stale: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/stale@v3 11 | with: 12 | repo-token: ${{ secrets.GITHUB_TOKEN }} 13 | stale-issue-message: > 14 | Automated Message: We're closing this issue because it hasn't been updated in a while. 15 | This isn't a judgement on the merit of the issue in any way. It's just 16 | a way of keeping the issue queue manageable. 17 | If you'd like to revive this issue, please reopen it and ask a 18 | committer to remove the Stale tag! 19 | stale-pr-message: > 20 | Automated Message: We're closing this PR because it hasn't been updated in a while. 21 | This isn't a judgement on the merit of the PR in any way. It's just 22 | a way of keeping the PR queue manageable. 23 | If you'd like to revive this PR, please reopen it and ask a 24 | committer to remove the Stale tag! 
25 | days-before-stale: 30 26 | days-before-close: 15 27 | stale-issue-label: 'stale-issue' 28 | stale-pr-label: 'stale-pr' 29 | exempt-issue-labels: 'awaiting-approval,work-in-progress' 30 | exempt-pr-labels: 'awaiting-approval,work-in-progress' 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | **/target/ 4 | .idea/ 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | 3 | language: scala 4 | scala: 5 | - 2.11.12 6 | - 2.12.11 7 | 8 | jdk: 9 | - oraclejdk8 10 | - oraclejdk9 11 | - openjdk8 12 | - openjdk9 13 | - openjdk10 14 | 15 | script: 16 | - sbt clean coverage test coverageReport 17 | 18 | after_success: 19 | - bash <(curl -s https://codecov.io/bash) 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Custom state store providers for Apache Spark 2 | 3 | [![Build Status](https://travis-ci.org/chermenin/spark-states.svg?branch=master)](https://travis-ci.org/chermenin/spark-states) 4 | [![CodeFactor](https://www.codefactor.io/repository/github/chermenin/spark-states/badge)](https://www.codefactor.io/repository/github/chermenin/spark-states) 5 | [![codecov](https://codecov.io/gh/chermenin/spark-states/branch/master/graph/badge.svg)](https://codecov.io/gh/chermenin/spark-states) 6 | [![Maven Central](https://img.shields.io/maven-central/v/ru.chermenin/spark-states_2.12.svg)](https://central.sonatype.com/search?q=g%3Aru.chermenin++spark-states_*) 7 | [![javadoc](https://javadoc.io/badge2/ru.chermenin/spark-states_2.12/javadoc.svg)](https://javadoc.io/doc/ru.chermenin/spark-states_2.12/latest/ru/chermenin/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.html) 8 | 9 | State management extensions for Apache Spark to keep data across micro-batches during stateful stream processing. 10 | 11 | ### Motivation 12 | 13 | Out of the box, Apache Spark has only one implementation of state store providers. It's `HDFSBackedStateStoreProvider`, which stores all of the data in memory, a very memory-consuming approach. To avoid `OutOfMemory` errors, this repository and its custom state store providers were created. 14 | 15 | ### Usage 16 | 17 | To use the custom state store provider for your pipelines, add the following configuration to the submit script / SparkConf: 18 | 19 | --conf spark.sql.streaming.stateStore.providerClass="ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider" 20 | 21 | Here is some more information about it: https://docs.databricks.com/spark/latest/structured-streaming/production.html 22 | 23 | Alternatively, you can use the `useRocksDBStateStore()` helper method in your application while creating the SparkSession: 24 | 25 | ``` 26 | import ru.chermenin.spark.sql.execution.streaming.state.implicits._ 27 | 28 | val spark = SparkSession.builder().master(...).useRocksDBStateStore().getOrCreate() 29 | ``` 30 | 31 | Note: For the helper methods to be available, you must import the implicits as shown above. 32 | 33 | 34 | ### State Timeout 35 | 36 | With semantics similar to those of `GroupState`/`flatMapGroupsWithState`, state timeout features have been built directly into the custom state store. 37 | 38 | Important points to note when using state timeouts: 39 | 40 | * Timeouts can be set differently for each streaming query. This relies on `queryName` and its `checkpointLocation`. 41 | * The trigger interval set on a streaming query does not need to match the state expiration. 42 | * Timeouts are currently based on processing time. 43 | * The timeout will occur once a fixed duration has elapsed after 44 | 1) the entry's creation, or 45 | 2) the most recent replacement (update) of its value, or 46 | 3) its last access. 47 | * Unlike `GroupState`, the timeout **is not** eventual, as it is independent of query progress. 48 | * Since the processing time timeout is based on clock time, it is affected by variations in the system clock (e.g. time zone changes, clock skew, etc.). 49 | * Timeouts may optionally be set to strict expiration at a slight cost in memory. More info [here](https://github.com/chermenin/spark-states/issues/1).
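To make this concrete before going into the configuration options below, here is a minimal end-to-end sketch: a streaming word count whose aggregation state is kept in RocksDB and expires 60 seconds after it was last touched. It combines the `useRocksDBStateStore()` and `stateTimeout()` helpers shown in this README; the socket source, console sink, query name, and checkpoint path are illustrative assumptions only, not part of this project.

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger
import ru.chermenin.spark.sql.execution.streaming.state.implicits._

// Register the RocksDB state store provider while building the session.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("rocksdb-state-example")
  .useRocksDBStateStore()
  .getOrCreate()

import spark.implicits._

// A simple stateful aggregation: count words arriving on a socket (illustrative source).
val counts = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()
  .as[String]
  .flatMap(_.split("\\s+"))
  .groupBy("value")
  .count()

// Name the query, point it at a checkpoint location, and expire its state entries after 60 seconds.
counts.writeStream
  .format("console")
  .outputMode("update")
  .trigger(Trigger.ProcessingTime(1000L))
  .stateTimeout(spark.conf, queryName = "wordCount", expirySecs = 60, checkpointLocation = "/tmp/wordCount-chkpnt")
  .start()

spark.streams.awaitAnyTermination()
```

The `stateTimeout()` helper sets the query name and checkpoint location on the writer and records the per-query expiration in the session configuration; the two ways to configure the timeout itself are described next.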
50 | 51 | There are two different ways to configure the state timeout: 52 | 53 | 1. Via additional configuration on SparkConf: 54 | 55 | To set a processing-time timeout for all streaming queries in strict mode: 56 | ``` 57 | --conf spark.sql.streaming.stateStore.stateExpirySecs=5 58 | --conf spark.sql.streaming.stateStore.strictExpire=true 59 | ``` 60 | 61 | To configure the state timeout differently for each query, the above configs can be modified as follows: 62 | ``` 63 | --conf spark.sql.streaming.stateStore.stateExpirySecs.queryName1=5 64 | --conf spark.sql.streaming.stateStore.stateExpirySecs.queryName2=10 65 | ... 66 | ... 67 | --conf spark.sql.streaming.stateStore.strictExpire=true 68 | ``` 69 | 70 | 2. Via the `stateTimeout()` helper method _(recommended way)_: 71 | 72 | ``` 73 | import ru.chermenin.spark.sql.execution.streaming.state.implicits._ 74 | 75 | val spark: SparkSession = ... 76 | val streamingDF: DataFrame = ... 77 | 78 | streamingDF.writeStream 79 | .format(...) 80 | .outputMode(...) 81 | .trigger(Trigger.ProcessingTime(1000L)) 82 | .queryName("myQuery1") 83 | .option("checkpointLocation", "chkpntloc") 84 | .stateTimeout(spark.conf, expirySecs = 5) 85 | .start() 86 | 87 | spark.streams.awaitAnyTermination() 88 | ``` 89 | 90 | Preferably, the `queryName` and `checkpointLocation` can be set directly via the `stateTimeout()` method, as below: 91 | ``` 92 | streamingDF.writeStream 93 | .format(...) 94 | .outputMode(...) 95 | .trigger(Trigger.ProcessingTime(1000L)) 96 | .stateTimeout(spark.conf, queryName = "myQuery1", expirySecs = 5, checkpointLocation = "chkpntloc") 97 | .start() 98 | ``` 99 | 100 | Note: If `queryName` is invalid or unavailable, the streaming query will be tagged as `UNNAMED_QUERY` and the applicable timeout will be the value of `spark.sql.streaming.stateStore.stateExpirySecs` (which defaults to -1 but can be overridden via SparkConf). 101 | 102 | Other state timeout related points (applicable at both the global and query level): 103 | * For no timeout, i.e. infinite state, set `spark.sql.streaming.stateStore.stateExpirySecs=-1` 104 | * For stateless processing, i.e. no state, set `spark.sql.streaming.stateStore.stateExpirySecs=0` 105 | 106 | ### Contributing 107 | 108 | You're welcome to submit pull requests with any changes for this repository at any time. I'll be very glad to see any contributions. 109 | 110 | ### License 111 | 112 | The standard [Apache 2.0](LICENSE) license is used for this project. 113 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Aleksandr Chermenin 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | 17 | organization := "ru.chermenin" 18 | name := "spark-states" 19 | version := "0.3-SNAPSHOT" 20 | 21 | description := "Custom state store providers for Apache Spark" 22 | homepage := Some(url("http://code.chermenin.ru/spark-states/")) 23 | 24 | crossScalaVersions := Seq(Versions.Scala_2_11, Versions.Scala_2_12) 25 | 26 | libraryDependencies ++= Seq( 27 | 28 | // general dependencies 29 | "org.apache.spark" %% "spark-sql" % Versions.Spark % "provided", 30 | "org.apache.spark" %% "spark-streaming" % Versions.Spark % "provided", 31 | "org.rocksdb" % "rocksdbjni" % Versions.RocksDb, 32 | 33 | // test dependencies 34 | "org.scalatest" %% "scalatest" % "3.0.5" % "test", 35 | "org.apache.spark" %% "spark-sql" % Versions.Spark % "test" classifier "tests", 36 | "com.google.guava" % "guava-testlib" % "14.0.1" % "test" 37 | ) 38 | 39 | scmInfo := Some( 40 | ScmInfo( 41 | url("https://github.com/chermenin/spark-states"), 42 | "git@github.com:chermenin/spark-states.git" 43 | ) 44 | ) 45 | 46 | developers := List( 47 | Developer( 48 | "chermenin", 49 | "Alex Chermenin", 50 | "alex@chermenin.ru", 51 | url("https://chermenin.ru") 52 | ) 53 | ) 54 | 55 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0")) 56 | 57 | publishMavenStyle := true 58 | 59 | publishTo := Some( 60 | if (isSnapshot.value) 61 | Opts.resolver.sonatypeSnapshots 62 | else 63 | Opts.resolver.sonatypeStaging 64 | ) 65 | -------------------------------------------------------------------------------- /project/Versions.scala: -------------------------------------------------------------------------------- 1 | object Versions { 2 | val Scala_2_11 = "2.11.12" 3 | val Scala_2_12 = "2.12.11" 4 | val Spark = "2.4.5" 5 | val RocksDb = "6.7.3" 6 | } 7 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.3.9 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.2") 2 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1") 3 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.1") 4 | -------------------------------------------------------------------------------- /src/main/scala/ru/chermenin/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Aleksandr Chermenin 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package ru.chermenin.spark.sql.execution.streaming.state 18 | 19 | import java.io.{File, FileInputStream, FileOutputStream, IOException} 20 | import java.nio.file.attribute.BasicFileAttributes 21 | import java.nio.file.{Path => LocalPath, _} 22 | import java.util.concurrent.{ConcurrentHashMap, TimeUnit} 23 | import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream} 24 | 25 | import com.google.common.cache.{CacheBuilder, CacheLoader} 26 | import org.apache.hadoop.conf.Configuration 27 | import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} 28 | import org.apache.spark.SparkConf 29 | import org.apache.spark.internal.Logging 30 | import org.apache.spark.sql.catalyst.expressions.UnsafeRow 31 | import org.apache.spark.sql.execution.streaming.state._ 32 | import org.apache.spark.sql.types.StructType 33 | import org.rocksdb._ 34 | import org.rocksdb.util.SizeUnit 35 | import ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider._ 36 | 37 | import scala.collection.JavaConverters._ 38 | import scala.util.control.NonFatal 39 | import scala.util.{Failure, Success, Try} 40 | 41 | /** 42 | * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed 43 | * by files in an HDFS-compatible file system using the RocksDB key-value storage format. 44 | * All updates to the store have to be done in sets transactionally, and each set of updates 45 | * increments the store's version. These versions can be used to re-execute the updates 46 | * (by retries in RDD operations) on the correct version of the store, and regenerate 47 | * the store version. 48 | * 49 | * Usage: 50 | * To update the data in the state store, the following order of operations is needed. 51 | * 52 | * // get the right store 53 | * - val store = StateStore.get( 54 | * StateStoreId(checkpointLocation, operatorId, partitionId), ..., version, ...) 55 | * - store.put(...) 56 | * - store.remove(...) 57 | * - store.commit() // commits all the updates made; the new version will be returned 58 | * - store.iterator() // key-value data after last commit as an iterator 59 | * - store.updates() // updates made in the last commit as an iterator 60 | * 61 | * Fault-tolerance model: 62 | * - Every set of updates is written to a snapshot file before committing. 63 | * - The state store is responsible for cleaning up old snapshot files. 64 | * - Multiple attempts to commit the same version of updates may overwrite each other. 65 | * Consistency guarantees depend on whether multiple attempts have the same updates and 66 | * the overwrite semantics of the underlying file system. 67 | * - Background maintenance of files ensures that the last versions of the store are always 68 | * recoverable to ensure re-executed RDD operations re-apply updates on the correct 69 | * past version of the store. 70 | * 71 | * Description of State Timeout API 72 | * -------------------------------- 73 | * 74 | * This API should be used to open the db when key-values inserted are 75 | * meant to be removed from the db in 'ttl' amount of time provided in seconds. 76 | * The timeouts can optionally be set to strict expiration by setting 77 | * spark.sql.streaming.stateStore.strictExpire = true on `SparkConf` 78 | * 79 | * Timeout Modes: 80 | * - In non-strict mode (default), this guarantees that key-values inserted will remain in the db 81 | * for >= ttl amount of time and the db will make efforts to remove the key-values 82 | * as soon as possible after ttl seconds of their insertion.
83 | * - In strict mode, the key-values inserted will remain in the db for exactly ttl amount of time. 84 | * To ensure exact expiration, a separate cache of keys is maintained in memory with 85 | * their respective deadlines and is used for reference during operations. 86 | * 87 | * The timeouts may be set on global (for all queries) for differently for each streaming query. 88 | * This can be done be appending the query name to [[STATE_EXPIRY_SECS]], like below, 89 | * 90 | * spark.sql.streaming.stateStore.stateExpirySecs.queryName1 = 5 91 | * 92 | * This API can also be used to allow, 93 | * - Stateless Processing - set timeout to 0 94 | * - Infinite State (no timeout) - set timeout to -1, which is set by default. 95 | */ 96 | class RocksDbStateStoreProvider extends StateStoreProvider with Logging { 97 | 98 | /** Load native RocksDb library */ 99 | RocksDB.loadLibrary() 100 | 101 | private val options: Options = new Options() 102 | .setCreateIfMissing(true) 103 | .setWriteBufferSize(RocksDbStateStoreProvider.DEFAULT_WRITE_BUFFER_SIZE_MB * SizeUnit.MB) 104 | .setMaxWriteBufferNumber(RocksDbStateStoreProvider.DEFAULT_WRITE_BUFFER_NUMBER) 105 | .setMaxBackgroundCompactions(RocksDbStateStoreProvider.DEFAULT_BACKGROUND_COMPACTIONS) 106 | .setCompressionType(CompressionType.SNAPPY_COMPRESSION) 107 | .setCompactionStyle(CompactionStyle.UNIVERSAL) 108 | 109 | /** Implementation of [[StateStore]] API which is backed by RocksDB */ 110 | class RocksDbStateStore(val version: Long, 111 | val dbPath: String, 112 | val keySchema: StructType, 113 | val valueSchema: StructType, 114 | val localSnapshots: ConcurrentHashMap[Long, String], 115 | val keyCache: MapType) extends StateStore { 116 | 117 | /** New state version */ 118 | private val newVersion = version + 1 119 | 120 | /** RocksDb database to keep state */ 121 | private val store: RocksDB = TtlDB.open(options, dbPath, ttlSec, false) 122 | 123 | /** Enumeration representing the internal state of the store */ 124 | object State extends Enumeration { 125 | val Updating, Committed, Aborted = Value 126 | } 127 | 128 | @volatile private var keysNumber: Long = 0 129 | @volatile private var state: State.Value = State.Updating 130 | 131 | /** Unique identifier of the store */ 132 | override def id: StateStoreId = RocksDbStateStoreProvider.this.stateStoreId 133 | 134 | /** 135 | * Get the current value of a non-null key. 136 | * 137 | * @return a non-null row if the key exists in the store, otherwise null. 138 | */ 139 | override def get(key: UnsafeRow): UnsafeRow = { 140 | if (isStrictExpire) { 141 | Option(keyCache.getIfPresent(key)) match { 142 | case Some(_) => getValue(key) 143 | case None => null 144 | } 145 | } else getValue(key) 146 | } 147 | 148 | /** 149 | * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in 150 | * the params can be reused, and must make copies of the data as needed for persistence. 151 | */ 152 | override def put(key: UnsafeRow, value: UnsafeRow): Unit = { 153 | verify(state == State.Updating, "Cannot put entry into already committed or aborted state") 154 | val keyCopy = key.copy() 155 | val valueCopy = value.copy() 156 | synchronized { 157 | store.put(keyCopy.getBytes, valueCopy.getBytes) 158 | if (isStrictExpire) { 159 | keyCache.put(keyCopy, DUMMY_VALUE) 160 | } 161 | } 162 | } 163 | 164 | /** 165 | * Remove a single non-null key. 
166 | */ 167 | override def remove(key: UnsafeRow): Unit = { 168 | verify(state == State.Updating, "Cannot remove entry from already committed or aborted state") 169 | synchronized { 170 | store.delete(key.getBytes) 171 | if (isStrictExpire) { 172 | keyCache.invalidate(key.getBytes) 173 | } 174 | } 175 | } 176 | 177 | /** 178 | * Get key value pairs with optional approximate `start` and `end` extents. 179 | * If the State Store implementation maintains indices for the data based on the optional 180 | * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use 181 | * `start` and `end` to make a best-effort scan over the data. Default implementation returns 182 | * the full data scan iterator, which is correct but inefficient. Custom implementations must 183 | * ensure that updates (puts, removes) can be made while iterating over this iterator. 184 | * 185 | * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value. 186 | * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value. 187 | * @return An iterator of key-value pairs that is guaranteed not miss any key between start and 188 | * end, both inclusive. 189 | */ 190 | override def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { 191 | verify(state == State.Updating, "Cannot getRange from already committed or aborted state") 192 | iterator() 193 | } 194 | 195 | /** 196 | * Commit all the updates that have been made to the store, and return the new version. 197 | */ 198 | override def commit(): Long = { 199 | verify(state == State.Updating, "Cannot commit already committed or aborted state") 200 | 201 | try { 202 | state = State.Committed 203 | keysNumber = if (isStrictExpire) { 204 | keyCache.size 205 | } else { 206 | store.getLongProperty(ROCKSDB_ESTIMATE_KEYS_NUMBER_PROPERTY) 207 | } 208 | store.close() 209 | putLocalSnapshot(newVersion, dbPath) 210 | snapshot(newVersion, dbPath) 211 | logInfo(s"Committed version $newVersion for $this") 212 | newVersion 213 | } catch { 214 | case NonFatal(e) => 215 | throw new IllegalStateException(s"Error committing version $newVersion into $this", e) 216 | } 217 | } 218 | 219 | /** 220 | * Abort all the updates made on this store. This store will not be usable any more. 221 | */ 222 | override def abort(): Unit = { 223 | verify(state != State.Committed, "Cannot abort already committed state") 224 | try { 225 | state = State.Aborted 226 | keysNumber = if (isStrictExpire) { 227 | keyCache.size 228 | } else { 229 | store.getLongProperty(ROCKSDB_ESTIMATE_KEYS_NUMBER_PROPERTY) 230 | } 231 | store.close() 232 | putLocalSnapshot(newVersion + 1, dbPath) 233 | logInfo(s"Aborted version $newVersion for $this") 234 | } catch { 235 | case e: Exception => 236 | logWarning(s"Error aborting version $newVersion into $this", e) 237 | } 238 | } 239 | 240 | /** 241 | * Get an iterator of all the store data. 242 | * This can be called only after committing all the updates made in the current thread. 
243 | */ 244 | override def iterator(): Iterator[UnsafeRowPair] = { 245 | val stateFromRocksIter: Iterator[UnsafeRowPair] = new Iterator[UnsafeRowPair] { 246 | 247 | /** Internal RocksDb iterator */ 248 | private val iterator = store.newIterator() 249 | iterator.seekToFirst() 250 | 251 | /** Check if has some data */ 252 | override def hasNext: Boolean = iterator.isValid 253 | 254 | /** Get next data from RocksDb */ 255 | override def next(): UnsafeRowPair = { 256 | iterator.status() 257 | 258 | val key = new UnsafeRow(keySchema.fields.length) 259 | val keyBytes = iterator.key() 260 | key.pointTo(keyBytes, keyBytes.length) 261 | 262 | val value = new UnsafeRow(valueSchema.fields.length) 263 | val valueBytes = iterator.value() 264 | value.pointTo(valueBytes, valueBytes.length) 265 | 266 | iterator.next() 267 | 268 | new UnsafeRowPair(key, value) 269 | } 270 | } 271 | 272 | if (isStrictExpire) { 273 | stateFromRocksIter.filter(x => keyCache.asMap().keySet().contains(x.key)) 274 | } else { 275 | stateFromRocksIter 276 | } 277 | } 278 | 279 | /** 280 | * Returns current metrics of the state store 281 | */ 282 | override def metrics: StateStoreMetrics = 283 | StateStoreMetrics(keysNumber, keysNumber * (keySchema.defaultSize + valueSchema.defaultSize), Map.empty) 284 | 285 | /** 286 | * Whether all updates have been committed 287 | */ 288 | override def hasCommitted: Boolean = state == State.Committed 289 | 290 | /** 291 | * Custom toString implementation for this state store class. 292 | */ 293 | override def toString: String = 294 | s"RocksDbStateStore[id=(op=${id.operatorId},part=${id.partitionId}),localDir=$dbPath,snapshotsDir=$baseDir]" 295 | 296 | /** 297 | * Method to put current DB path to local snapshots list. 298 | */ 299 | private def putLocalSnapshot(version: Long, dbPath: String): Unit = { 300 | localSnapshots.keys().asScala 301 | .filter(_ < version - storeConf.minVersionsToRetain) 302 | .foreach(version => deleteFile(localSnapshots.get(version))) 303 | localSnapshots.put(version, dbPath) 304 | } 305 | 306 | private def getValue(key: UnsafeRow): UnsafeRow = { 307 | val valueBytes = store.get(key.getBytes) 308 | if (valueBytes == null) return null 309 | val value = new UnsafeRow(valueSchema.fields.length) 310 | value.pointTo(valueBytes, valueBytes.length) 311 | value 312 | } 313 | } 314 | 315 | /* Internal fields and methods */ 316 | private val localSnapshots: ConcurrentHashMap[Long, String] = new ConcurrentHashMap[Long, String]() 317 | 318 | @volatile private var stateStoreId_ : StateStoreId = _ 319 | @volatile private var keySchema: StructType = _ 320 | @volatile private var valueSchema: StructType = _ 321 | @volatile private var storeConf: StateStoreConf = _ 322 | @volatile private var hadoopConf: Configuration = _ 323 | @volatile private var tempDir: String = _ 324 | @volatile private var ttlSec: Int = _ 325 | @volatile private var isStrictExpire: Boolean = _ 326 | @volatile private var expirationByQuery: Map[String, Int] = _ 327 | @volatile private var actualCheckpointRoot: String = _ 328 | @volatile private var queryName: String = _ 329 | 330 | private def baseDir: Path = stateStoreId.storeCheckpointLocation() 331 | 332 | private def fs: FileSystem = baseDir.getFileSystem(hadoopConf) 333 | 334 | /** 335 | * Initialize the provide with more contextual information from the SQL operator. 336 | * This method will be called first after creating an instance of the StateStoreProvider by 337 | * reflection. 
338 | * 339 | * @param stateStoreId Id of the versioned StateStores that this provider will generate 340 | * @param keySchema Schema of keys to be stored 341 | * @param valueSchema Schema of value to be stored 342 | * @param indexOrdinal Optional column (represent as the ordinal of the field in keySchema) by 343 | * which the StateStore implementation could index the data. 344 | * @param storeConf Configurations used by the StateStores 345 | * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data 346 | */ 347 | override def init(stateStoreId: StateStoreId, 348 | keySchema: StructType, 349 | valueSchema: StructType, 350 | indexOrdinal: Option[Int], 351 | storeConf: StateStoreConf, 352 | hadoopConf: Configuration): Unit = { 353 | this.stateStoreId_ = stateStoreId 354 | this.keySchema = keySchema 355 | this.valueSchema = valueSchema 356 | this.storeConf = storeConf 357 | this.hadoopConf = hadoopConf 358 | this.tempDir = getTempDir(getTempPrefix(hadoopConf.get("spark.app.name")), "") 359 | this.expirationByQuery = getExpirationByQuery(storeConf.confs) 360 | this.actualCheckpointRoot = 361 | stateStoreId.checkpointRootLocation.replaceAll("/state$", "") 362 | this.queryName = { 363 | val value = actualCheckpointRoot.split("/").last 364 | if (expirationByQuery.contains(value)) value 365 | else { 366 | logWarning( 367 | "An Unnamed Query encountered, default expiration will be applicable. " + 368 | s"Default Expiration is '$DEFAULT_STATE_EXPIRY_SECS' i.e no timeout. " + 369 | s"This can be overridden by setting SparkSession.conf.set($STATE_EXPIRY_SECS, ...)" 370 | ) 371 | UNNAMED_QUERY 372 | } 373 | } 374 | 375 | this.ttlSec = expirationByQuery(queryName) 376 | this.isStrictExpire = setExpireMode(storeConf.confs) 377 | 378 | fs.mkdirs(baseDir) 379 | } 380 | 381 | /** 382 | * Get the state store for making updates to create a new `version` of the store. 383 | */ 384 | override def getStore(version: Long): StateStore = synchronized { 385 | require(version >= 0, "Version cannot be less than 0") 386 | 387 | val snapshotVersions = fetchVersions() 388 | val localVersions = localSnapshots.keySet().asScala 389 | val versions = (snapshotVersions ++ localVersions).filter(_ <= version) 390 | 391 | def initStateStore(path: String): StateStore = 392 | new RocksDbStateStore(version, path, keySchema, valueSchema, localSnapshots, createCache(ttlSec)) 393 | 394 | val stateStore = versions.sorted(Ordering.Long.reverse).toStream 395 | .map(version => Try(loadDb(version)).map(initStateStore)) 396 | .find(_.isSuccess).map(_.get) 397 | .getOrElse(initStateStore(getTempDir(getTempPrefix(), s".$version"))) 398 | 399 | logInfo(s"Retrieved $stateStore for version $version of ${RocksDbStateStoreProvider.this} for update") 400 | stateStore 401 | } 402 | 403 | /** 404 | * Return the id of the StateStores this provider will generate. 405 | */ 406 | override def stateStoreId: StateStoreId = stateStoreId_ 407 | 408 | /** 409 | * Do maintenance backing data files, including cleaning up old files 410 | */ 411 | override def doMaintenance(): Unit = { 412 | try { 413 | cleanup() 414 | } catch { 415 | case NonFatal(ex) => 416 | logWarning(s"Error cleaning up $this", ex) 417 | } 418 | } 419 | 420 | /** 421 | * Called when the provider instance is unloaded from the executor. 422 | */ 423 | override def close(): Unit = { 424 | deleteFile(tempDir) 425 | } 426 | 427 | /** 428 | * Custom toString implementation for this state store provider class. 
429 | */ 430 | override def toString: String = { 431 | s"RocksDbStateStoreProvider[" + 432 | s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" 433 | } 434 | 435 | /** 436 | * Remove CRC files and old logs before uploading. 437 | */ 438 | private def removeCrcAndLogs(dbPath: String): Boolean = { 439 | new File(dbPath).listFiles().filter(file => { 440 | val name = file.getName.toLowerCase() 441 | name.endsWith(".crc") || name.startsWith("log.old") 442 | }).forall(_.delete()) 443 | } 444 | 445 | /** 446 | * Finalize snapshot by moving local RocksDb files into HDFS as zipped file. 447 | */ 448 | private def snapshot(version: Long, dbPath: String): Path = synchronized { 449 | val snapshotFile = getSnapshotFile(version) 450 | try { 451 | if (removeCrcAndLogs(dbPath)) { 452 | compress(new File(dbPath).listFiles(), snapshotFile) 453 | logInfo(s"Saved snapshot for version $version in $snapshotFile") 454 | } else { 455 | throw new IOException(s"Failed to delete CRC files or old logs before moving $dbPath to $snapshotFile") 456 | } 457 | } catch { 458 | case ex: Exception => 459 | throw new IOException(s"Failed to move $dbPath to $snapshotFile", ex) 460 | } 461 | snapshotFile 462 | } 463 | 464 | /** 465 | * Load the required version of the RocksDb files from HDFS. 466 | */ 467 | private def loadDb(version: Long): String = { 468 | val dbPath = getTempDir(getTempPrefix(), s".$version") 469 | if (hasLocalSnapshot(version) && loadLocalSnapshot(version, dbPath) || loadHdfsSnapshot(version, dbPath)) { 470 | dbPath 471 | } else { 472 | throw new IOException(s"Failed to load state snapshot for version $version") 473 | } 474 | } 475 | 476 | /** 477 | * Returns true if there is a local snapshot directory for the version. 478 | */ 479 | private def hasLocalSnapshot(version: Long): Boolean = 480 | localSnapshots.containsKey(version) 481 | 482 | /** 483 | * Move files from local snapshot to reuse in new version. 484 | */ 485 | private def loadLocalSnapshot(version: Long, dbPath: String): Boolean = { 486 | val localSnapshotDir = localSnapshots.remove(version) 487 | try { 488 | Files.move( 489 | Paths.get(localSnapshotDir), Paths.get(dbPath), 490 | StandardCopyOption.REPLACE_EXISTING 491 | ) 492 | true 493 | } catch { 494 | case ex: Exception => 495 | logWarning(s"Failed to reuse $localSnapshotDir as $dbPath", ex) 496 | false 497 | } 498 | } 499 | 500 | /** 501 | * Load files from distributed file system to use in new version of state. 502 | */ 503 | private def loadHdfsSnapshot(version: Long, dbPath: String): Boolean = { 504 | val snapshotFile = getSnapshotFile(version) 505 | try { 506 | decompress(snapshotFile, dbPath) 507 | true 508 | } catch { 509 | case ex: Exception => 510 | throw new IOException(s"Failed to load from $snapshotFile to $dbPath", ex) 511 | } 512 | } 513 | 514 | /** 515 | * Save RocksDB files as ZIP archive in HDFS as snapshot. 
516 | */ 517 | private def compress(files: Array[File], snapshotFile: Path): Unit = { 518 | val buffer = new Array[Byte](hadoopConf.getInt("io.file.buffer.size", 4096)) 519 | val output = new ZipOutputStream(fs.create(snapshotFile)) 520 | try { 521 | files.foreach(file => { 522 | val input = new FileInputStream(file) 523 | try { 524 | output.putNextEntry(new ZipEntry(file.getName)) 525 | Iterator.continually(input.read(buffer)) 526 | .takeWhile(_ != -1) 527 | .filter(_ > 0) 528 | .foreach(read => 529 | output.write(buffer, 0, read) 530 | ) 531 | output.closeEntry() 532 | } finally { 533 | input.close() 534 | } 535 | }) 536 | } finally { 537 | output.close() 538 | } 539 | } 540 | 541 | /** 542 | * Load archive from HDFS and unzip RocksDB files. 543 | */ 544 | private def decompress(snapshotFile: Path, dbPath: String): Unit = { 545 | val buffer = new Array[Byte](hadoopConf.getInt("io.file.buffer.size", 4096)) 546 | val input = new ZipInputStream(fs.open(snapshotFile)) 547 | try { 548 | Iterator.continually(input.getNextEntry) 549 | .takeWhile(_ != null) 550 | .foreach(entry => { 551 | val output = new FileOutputStream(s"$dbPath${File.separator}${entry.getName}") 552 | try { 553 | Iterator.continually(input.read(buffer)) 554 | .takeWhile(_ != -1) 555 | .filter(_ > 0) 556 | .foreach(read => 557 | output.write(buffer, 0, read) 558 | ) 559 | } finally { 560 | output.close() 561 | } 562 | }) 563 | } finally { 564 | input.close() 565 | } 566 | } 567 | 568 | /** 569 | * Clean up old snapshots that are not needed any more. It ensures that last 570 | * few versions of the store can be recovered from the files, so re-executed RDD operations 571 | * can re-apply updates on the past versions of the store. 572 | */ 573 | private def cleanup(): Unit = { 574 | try { 575 | val versions = fetchVersions() 576 | if (versions.nonEmpty) { 577 | val earliestVersionToRetain = versions.max - storeConf.minVersionsToRetain + 1 578 | 579 | val filesToDelete = versions 580 | .filter(_ < earliestVersionToRetain) 581 | .map(getSnapshotFile) 582 | 583 | if (filesToDelete.nonEmpty) { 584 | filesToDelete.foreach(fs.delete(_, true)) 585 | logInfo(s"Deleted files older than $earliestVersionToRetain for $this: ${filesToDelete.mkString(", ")}") 586 | } 587 | } 588 | } catch { 589 | case NonFatal(e) => 590 | logWarning(s"Error cleaning up files for $this", e) 591 | } 592 | } 593 | 594 | /** 595 | * Fetch all versions that back the store. 596 | */ 597 | private def fetchVersions(): Seq[Long] = { 598 | val files: Seq[FileStatus] = try { 599 | fs.listStatus(baseDir) 600 | } catch { 601 | case _: java.io.FileNotFoundException => 602 | Seq.empty 603 | } 604 | files.flatMap { status => 605 | val path = status.getPath 606 | val nameParts = path.getName.split("\\.") 607 | if (nameParts.size == 3) { 608 | Seq(nameParts(2).toLong) 609 | } else { 610 | Seq() 611 | } 612 | } 613 | } 614 | 615 | /** 616 | * Get path to snapshot file for the version. 617 | */ 618 | private def getSnapshotFile(version: Long): Path = 619 | new Path(baseDir, s"state.snapshot.$version") 620 | 621 | /** 622 | * Get full prefix for local temp directory. 623 | * 624 | * @return 625 | */ 626 | private def getTempPrefix(prefix: String = "state"): String = 627 | s"$prefix-${stateStoreId_.operatorId}-${stateStoreId_.partitionId}-${stateStoreId_.storeName}-" 628 | 629 | /** 630 | * Create local temporary directory. 
631 | */ 632 | private def getTempDir(prefix: String, suffix: String): String = { 633 | val file = if (tempDir != null) { 634 | File.createTempFile(prefix, suffix, new File(tempDir)).getAbsoluteFile 635 | } else { 636 | File.createTempFile(prefix, suffix).getAbsoluteFile 637 | } 638 | if (file.delete() && file.mkdirs()) { 639 | file.getAbsolutePath 640 | } else { 641 | throw new IOException(s"Failed to create temp directory ${file.getAbsolutePath}") 642 | } 643 | } 644 | 645 | /** 646 | * Verify the condition and rise an exception if the condition is failed. 647 | */ 648 | private def verify(condition: => Boolean, msg: String): Unit = 649 | if (!condition) throw new IllegalStateException(msg) 650 | 651 | /** 652 | * Get iterator of all the data of the latest version of the store. 653 | * Note that this will look up the files to determined the latest known version. 654 | */ 655 | private[state] def latestIterator(): Iterator[UnsafeRowPair] = { 656 | val versions = fetchVersions() 657 | if (versions.nonEmpty) { 658 | getStore(versions.max).iterator() 659 | } else Iterator.empty 660 | } 661 | 662 | /** 663 | * Method to delete directory or file. 664 | */ 665 | private def deleteFile(path: String): Unit = { 666 | Files.walkFileTree(Paths.get(path), new SimpleFileVisitor[LocalPath] { 667 | 668 | override def visitFile(visitedFile: LocalPath, attrs: BasicFileAttributes): FileVisitResult = { 669 | Files.delete(visitedFile) 670 | FileVisitResult.CONTINUE 671 | } 672 | 673 | override def postVisitDirectory(visitedDirectory: LocalPath, exc: IOException): FileVisitResult = { 674 | Files.delete(visitedDirectory) 675 | FileVisitResult.CONTINUE 676 | } 677 | }) 678 | } 679 | 680 | } 681 | 682 | /** 683 | * Companion object with constants. 684 | */ 685 | object RocksDbStateStoreProvider { 686 | type MapType = com.google.common.cache.LoadingCache[UnsafeRow, String] 687 | 688 | /** Default write buffer size for RocksDb in megabytes */ 689 | val DEFAULT_WRITE_BUFFER_SIZE_MB = 200 690 | 691 | /** Default number of write buffers for RocksDb */ 692 | val DEFAULT_WRITE_BUFFER_NUMBER = 3 693 | 694 | /** Default background compactions value for RocksDb */ 695 | val DEFAULT_BACKGROUND_COMPACTIONS = 10 696 | 697 | val ROCKSDB_ESTIMATE_KEYS_NUMBER_PROPERTY = "rocksdb.estimate-num-keys" 698 | 699 | final val STATE_EXPIRY_SECS: String = "spark.sql.streaming.stateStore.stateExpirySecs" 700 | 701 | final val DEFAULT_STATE_EXPIRY_SECS: String = "-1" 702 | 703 | final val STATE_EXPIRY_STRICT_MODE: String = "spark.sql.streaming.stateStore.strictExpire" 704 | 705 | final val UNNAMED_QUERY: String = "UNNAMED_QUERY" 706 | 707 | final val DEFAULT_STATE_EXPIRY_METHOD: String = "false" 708 | 709 | final val DUMMY_VALUE: String = "" 710 | 711 | private def createCache(stateTtlSecs: Long): MapType = { 712 | val loader = new CacheLoader[UnsafeRow, String] { 713 | override def load(key: UnsafeRow): String = DUMMY_VALUE 714 | } 715 | 716 | val cacheBuilder = CacheBuilder.newBuilder() 717 | 718 | val cacheBuilderWithOptions = { 719 | if (stateTtlSecs >= 0) 720 | cacheBuilder.expireAfterAccess(stateTtlSecs, TimeUnit.SECONDS) 721 | else 722 | cacheBuilder 723 | } 724 | 725 | cacheBuilderWithOptions.build[UnsafeRow, String](loader) 726 | } 727 | 728 | /** 729 | * Creates a mapping of Streaming Query and its expiry timeout (seconds). 
730 | * For backward compatibility, an additional entry is done for [[UNNAMED_QUERY]]'s 731 | * 732 | * The timeout value for [[UNNAMED_QUERY]]'s is set by the value of [[STATE_EXPIRY_SECS]] 733 | * 734 | * @param stateStoreConf state store config map set on [[SparkConf]] 735 | * @return mapping of queryName -> expirySecs 736 | */ 737 | 738 | private def getExpirationByQuery(stateStoreConf: Map[String, String]): Map[String, Int] = 739 | stateStoreConf 740 | .filterKeys(_.startsWith(s"$STATE_EXPIRY_SECS.")) 741 | .map { case (key, value) => key.replace(s"$STATE_EXPIRY_SECS.", "") -> getTTL(value) } 742 | .+(UNNAMED_QUERY -> stateStoreConf.getOrElse(STATE_EXPIRY_SECS, DEFAULT_STATE_EXPIRY_SECS).toInt) 743 | 744 | /** 745 | * Helper method to check if the given string is an integer 746 | * 747 | * @param value [[String]] value to be checked 748 | * @return [[Option]] if value is not an integer returns [[None]] else [[Some]] 749 | */ 750 | private def toInt(value: String): Option[Int] = { 751 | Try(value.toInt) match { 752 | case Success(v) => Some(v) 753 | case Failure(_) => None 754 | } 755 | } 756 | 757 | private def getTTL(expirySecs: String): Int = toInt(expirySecs) match { 758 | case Some(value) => value 759 | case None => 760 | throw new IllegalArgumentException( 761 | s"Provided value '$expirySecs' is invalid. Expiry Secs must be an Integer." 762 | ) 763 | } 764 | 765 | private def setExpireMode(conf: Map[String, String]): Boolean = 766 | Try(conf.getOrElse(STATE_EXPIRY_STRICT_MODE, DEFAULT_STATE_EXPIRY_METHOD).toBoolean) match { 767 | case Success(value) => value 768 | case Failure(e) => throw new IllegalArgumentException(e) 769 | } 770 | 771 | } 772 | -------------------------------------------------------------------------------- /src/main/scala/ru/chermenin/spark/sql/execution/streaming/state/implicits.scala: -------------------------------------------------------------------------------- 1 | package ru.chermenin.spark.sql.execution.streaming.state 2 | 3 | import org.apache.hadoop.fs.Path 4 | import org.apache.spark.sql.RuntimeConfig 5 | import org.apache.spark.sql.SparkSession.Builder 6 | import org.apache.spark.sql.internal.SQLConf 7 | import org.apache.spark.sql.streaming.DataStreamWriter 8 | import ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider._ 9 | 10 | import scala.collection.mutable 11 | 12 | /** 13 | * Implicits aka helper methods 14 | * 15 | * The can be imported into scope with , 16 | * import ru.chermenin.spark.sql.execution.streaming.state.implicits._ 17 | * 18 | * SessionImplicits: 19 | * - Makes the `useRocksDBStateStore` method available on [[Builder]] 20 | * - Sets provider to [[RocksDbStateStoreProvider]] 21 | * 22 | * WriterImplicits: 23 | * - Makes the `stateTimeout` method available on [[DataStreamWriter]] 24 | * - Precedence is given to the provided arguments (if any), then previously set value, 25 | * and finally the value set on [[RuntimeConfig]] (in case of checkpoint location) 26 | * - Makes Checkpoint mandatory for all query on which applied 27 | * - Expiry Seconds less than 0 are treated as -1 (no timeout) 28 | */ 29 | 30 | object implicits extends Serializable { 31 | 32 | implicit class SessionImplicits(sparkSessionBuilder: Builder) { 33 | 34 | def useRocksDBStateStore(): Builder = 35 | sparkSessionBuilder.config(SQLConf.STATE_STORE_PROVIDER_CLASS.key, 36 | classOf[RocksDbStateStoreProvider].getCanonicalName) 37 | 38 | } 39 | 40 | implicit class WriterImplicits[T](dsw: DataStreamWriter[T]) { 41 | 42 | def 
stateTimeout(runtimeConfig: RuntimeConfig, 43 | queryName: String = "", 44 | expirySecs: Int = DEFAULT_STATE_EXPIRY_SECS.toInt, 45 | checkpointLocation: String = ""): DataStreamWriter[T] = { 46 | 47 | val extraOptions = getExtraOptions 48 | val name = queryName match { 49 | case "" | null => extraOptions.getOrElse("queryName", UNNAMED_QUERY) 50 | case _ => queryName 51 | } 52 | 53 | val location = new Path(checkpointLocation match { 54 | case "" | null => 55 | extraOptions.getOrElse("checkpointLocation", 56 | runtimeConfig.getOption(SQLConf.CHECKPOINT_LOCATION.key 57 | ).getOrElse(throw new IllegalStateException( 58 | "Checkpoint Location must be specified for State Expiry either " + 59 | """through option("checkpointLocation", ...) or """ + 60 | s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")) 61 | ) 62 | case _ => checkpointLocation 63 | }, name) 64 | .toUri.toString 65 | 66 | runtimeConfig.set(s"$STATE_EXPIRY_SECS.$name", if (expirySecs < 0) -1 else expirySecs) 67 | 68 | dsw 69 | .queryName(name) 70 | .option("checkpointLocation", location) 71 | } 72 | 73 | private def getExtraOptions: mutable.HashMap[String, String] = { 74 | val className = classOf[DataStreamWriter[T]] 75 | val field = className.getDeclaredField("extraOptions") 76 | field.setAccessible(true) 77 | 78 | field.get(dsw).asInstanceOf[mutable.HashMap[String, String]] 79 | } 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger = ERROR, CONSOLE 2 | log4j.appender.CONSOLE = org.apache.log4j.ConsoleAppender 3 | log4j.appender.CONSOLE.target = System.err 4 | log4j.appender.CONSOLE.layout = org.apache.log4j.PatternLayout 5 | log4j.appender.CONSOLE.layout.ConversionPattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 6 | -------------------------------------------------------------------------------- /src/test/scala/ru/chermenin/spark/sql/execution/streaming/state/RocksDbStateStoreHelper.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Aleksandr Chermenin 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package ru.chermenin.spark.sql.execution.streaming.state 18 | 19 | import java.io.File 20 | 21 | import org.apache.hadoop.conf.Configuration 22 | import org.apache.hadoop.fs.Path 23 | import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper.{newDir, rowsToStringInt, stringToRow} 24 | import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId} 25 | import org.apache.spark.sql.internal.SQLConf 26 | import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} 27 | import org.scalatest.PrivateMethodTester 28 | 29 | import scala.reflect.io.Path._ 30 | import scala.util.Random 31 | 32 | object RocksDbStateStoreHelper extends PrivateMethodTester { 33 | 34 | val keySchema: StructType = StructType(Seq(StructField("key", StringType, nullable = true))) 35 | val valueSchema: StructType = StructType(Seq(StructField("value", IntegerType, nullable = true))) 36 | 37 | val key: String = "a" 38 | val batchesToRetain: Int = 3 39 | 40 | def newStoreProvider(): RocksDbStateStoreProvider = { 41 | createStoreProvider(opId = Random.nextInt(), partition = 0, keySchema = keySchema, valueSchema = valueSchema) 42 | } 43 | 44 | def newStoreProvider(storeId: StateStoreId, 45 | keySchema: StructType = keySchema, 46 | valueSchema: StructType = valueSchema): RocksDbStateStoreProvider = { 47 | createStoreProvider( 48 | storeId.operatorId.toInt, 49 | storeId.partitionId, 50 | dir = storeId.checkpointRootLocation, 51 | keySchema = keySchema, 52 | valueSchema = valueSchema) 53 | } 54 | 55 | def getData(provider: RocksDbStateStoreProvider, version: Int = -1): Set[(String, Int)] = { 56 | val reloadedProvider = newStoreProvider(provider.stateStoreId) 57 | if (version < 0) { 58 | reloadedProvider.latestIterator().map(rowsToStringInt).toSet 59 | } else { 60 | reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet 61 | } 62 | } 63 | 64 | def createStoreProvider(opId: Int, 65 | partition: Int, 66 | dir: String = newDir(), 67 | hadoopConf: Configuration = new Configuration, 68 | sqlConf: SQLConf = new SQLConf(), 69 | keySchema: StructType = keySchema, 70 | valueSchema: StructType = valueSchema): RocksDbStateStoreProvider = { 71 | sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, batchesToRetain) 72 | val provider = new RocksDbStateStoreProvider() 73 | provider.init( 74 | StateStoreId(dir, opId, partition), 75 | keySchema, 76 | valueSchema, 77 | indexOrdinal = None, 78 | new StateStoreConf(sqlConf), 79 | hadoopConf 80 | ) 81 | provider 82 | } 83 | 84 | def snapshot(version: Int): String = s"state.snapshot.$version" 85 | 86 | def fileExists(provider: RocksDbStateStoreProvider, version: Int): Boolean = { 87 | val method = PrivateMethod[Path]('baseDir) 88 | val basePath = provider invokePrivate method() 89 | val fileName = snapshot(version) 90 | val filePath = new File(basePath.toString, fileName) 91 | filePath.exists 92 | } 93 | 94 | def corruptSnapshot(provider: RocksDbStateStoreProvider, version: Int): Unit = { 95 | val method = PrivateMethod[Path]('baseDir) 96 | val basePath = provider invokePrivate method() 97 | val fileName = snapshot(version) 98 | new File(basePath.toString, fileName).delete() 99 | } 100 | 101 | def minSnapshotToRetain(version: Int): Int = version - batchesToRetain + 1 102 | 103 | def performCleanUp(pathSlice: String): Unit = { 104 | ".".toDirectory.dirs 105 | .filter(_.name.contains(pathSlice)) 106 | .foreach(x => clearDB(x.jfile)) 107 | } 108 | 109 | private def clearDB(file: File): Unit = { 110 | if 
(file.isDirectory) 111 | file.listFiles.foreach(clearDB) 112 | if (file.exists && !file.delete) 113 | throw new Exception(s"Unable to delete ${file.getAbsolutePath}") 114 | } 115 | 116 | def contains(store: StateStore, key: String): Boolean = 117 | store.iterator.toSeq.map(_.key).contains(stringToRow(key)) 118 | 119 | def size(store: StateStore): Long = store.iterator.size 120 | 121 | def createSQLConf(defaultTTL: Long = -1, 122 | isStrict: Boolean, 123 | configs: Map[String, String] = Map.empty): SQLConf = { 124 | val sqlConf: SQLConf = new SQLConf() 125 | 126 | sqlConf.setConfString("spark.sql.streaming.stateStore.providerClass", 127 | "ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider") 128 | 129 | sqlConf.setConfString(RocksDbStateStoreProvider.STATE_EXPIRY_SECS, defaultTTL.toString) 130 | sqlConf.setConfString(RocksDbStateStoreProvider.STATE_EXPIRY_STRICT_MODE, isStrict.toString) 131 | 132 | configs.foreach { 133 | case (key, value) => sqlConf.setConfString(key, value) 134 | } 135 | 136 | sqlConf 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/test/scala/ru/chermenin/spark/sql/execution/streaming/state/RocksDbStateStoreProviderSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Aleksandr Chermenin 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package ru.chermenin.spark.sql.execution.streaming.state 18 | 19 | import java.util.UUID 20 | 21 | import org.apache.hadoop.conf.Configuration 22 | import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper._ 23 | import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} 24 | import org.scalatest.{BeforeAndAfter, FunSuite} 25 | import ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreHelper._ 26 | 27 | import scala.util.Random 28 | 29 | class RocksDbStateStoreProviderSuite extends FunSuite with BeforeAndAfter { 30 | 31 | before { 32 | StateStore.stop() 33 | require(!StateStore.isMaintenanceRunning) 34 | } 35 | 36 | after { 37 | StateStore.stop() 38 | require(!StateStore.isMaintenanceRunning) 39 | } 40 | 41 | test("Snapshotting") { 42 | val provider = createStoreProvider(opId = Random.nextInt, partition = 0) 43 | 44 | var currentVersion = 0 45 | 46 | def updateVersionTo(targetVersion: Int): Unit = { 47 | for (i <- currentVersion + 1 to targetVersion) { 48 | val store = provider.getStore(currentVersion) 49 | put(store, key, i) 50 | store.commit() 51 | currentVersion += 1 52 | } 53 | require(currentVersion === targetVersion) 54 | } 55 | 56 | updateVersionTo(2) 57 | assert(getData(provider) === Set(key -> 2)) 58 | 59 | assert(fileExists(provider, 1)) 60 | assert(fileExists(provider, 2)) 61 | 62 | def verifySnapshot(version: Int): Unit = { 63 | updateVersionTo(version) 64 | provider.doMaintenance() 65 | require(getData(provider) === Set(key -> version), "store not updated correctly") 66 | 67 | val snapshotVersion = (0 to version).filter(version => fileExists(provider, version)).min 68 | assert(snapshotVersion >= minSnapshotToRetain(version), "no snapshot files cleaned up") 69 | 70 | assert( 71 | getData(provider, snapshotVersion) === Set(key -> snapshotVersion), 72 | "cleaning messed up the data of the snapshotted version" 73 | ) 74 | 75 | assert( 76 | getData(provider) === Set(key -> version), 77 | "cleaning messed up the data of the final version" 78 | ) 79 | } 80 | 81 | verifySnapshot(version = 6) 82 | verifySnapshot(version = 20) 83 | } 84 | 85 | test("Cleaning up") { 86 | val provider = createStoreProvider(opId = Random.nextInt, partition = 0) 87 | val maxVersion = 20 88 | 89 | for (i <- 1 to maxVersion) { 90 | val store = provider.getStore(i - 1) 91 | put(store, key, i) 92 | store.commit() 93 | provider.doMaintenance() // do cleanup 94 | } 95 | require(rowsToSet(provider.latestIterator()) === Set(key -> maxVersion), "store not updated correctly") 96 | for (version <- 1 until minSnapshotToRetain(maxVersion)) { 97 | assert(!fileExists(provider, version)) // first snapshots should be deleted 98 | } 99 | 100 | // last couple of versions should be retrievable 101 | for (version <- minSnapshotToRetain(maxVersion) to maxVersion) { 102 | assert(getData(provider, version) === Set(key -> version)) 103 | } 104 | } 105 | 106 | test("Corrupted snapshots") { 107 | val provider = createStoreProvider(opId = Random.nextInt, partition = 0) 108 | for (i <- 1 to 6) { 109 | val store = provider.getStore(i - 1) 110 | put(store, key, i) 111 | store.commit() 112 | } 113 | 114 | // clean up 115 | provider.doMaintenance() 116 | 117 | val snapshotVersion = (0 to 10).filter(version => fileExists(provider, version)).max 118 | assert(snapshotVersion === 6) 119 | 120 | // Corrupt snapshot file 121 | assert(getData(provider, snapshotVersion) === Set(key -> snapshotVersion)) 122 | corruptSnapshot(provider, 
snapshotVersion) 123 | 124 | // Load data from previous correct snapshot 125 | assert(getData(provider, snapshotVersion) === Set(key -> (snapshotVersion - 1))) 126 | 127 | // Do cleanup and corrupt some more snapshots 128 | corruptSnapshot(provider, snapshotVersion - 1) 129 | corruptSnapshot(provider, snapshotVersion - 2) 130 | 131 | // If no correct snapshots, create empty state 132 | assert(getData(provider, snapshotVersion) === Set()) 133 | } 134 | 135 | test("Reports metrics") { 136 | val provider = newStoreProvider() 137 | val store = provider.getStore(0) 138 | val noDataMemoryUsed = store.metrics.memoryUsedBytes 139 | put(store, key, 1) 140 | store.commit() 141 | assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) 142 | } 143 | 144 | test("StateStore.get") { 145 | val dir = newDir() 146 | val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) 147 | val storeConf = StateStoreConf.empty 148 | val hadoopConf = new Configuration() 149 | 150 | // Verify that trying to get incorrect versions throws errors 151 | intercept[IllegalArgumentException] { 152 | StateStore.get( 153 | storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf) 154 | } 155 | assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store 156 | 157 | intercept[IllegalStateException] { 158 | StateStore.get( 159 | storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) 160 | } 161 | 162 | // Increase the version of the store and try to get it again 163 | val store0 = StateStore.get( 164 | storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) 165 | assert(store0.version === 0) 166 | put(store0, key, 1) 167 | store0.commit() 168 | 169 | val store1 = StateStore.get( 170 | storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) 171 | assert(StateStore.isLoaded(storeId)) 172 | assert(store1.version === 1) 173 | assert(rowsToSet(store1.iterator()) === Set(key -> 1)) 174 | 175 | // Verify that you can also load an older version 176 | val store0reloaded = StateStore.get( 177 | storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) 178 | assert(store0reloaded.version === 0) 179 | assert(rowsToSet(store0reloaded.iterator()) === Set.empty) 180 | 181 | // Verify that you can remove the store and still reload and use it 182 | StateStore.unload(storeId) 183 | assert(!StateStore.isLoaded(storeId)) 184 | 185 | val store1reloaded = StateStore.get( 186 | storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) 187 | assert(StateStore.isLoaded(storeId)) 188 | assert(store1reloaded.version === 1) 189 | put(store1reloaded, key, 2) 190 | assert(store1reloaded.commit() === 2) 191 | assert(rowsToSet(store1reloaded.iterator()) === Set(key -> 2)) 192 | } 193 | 194 | } 195 | -------------------------------------------------------------------------------- /src/test/scala/ru/chermenin/spark/sql/execution/streaming/state/RocksDbStateTimeoutSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Aleksandr Chermenin 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package ru.chermenin.spark.sql.execution.streaming.state 18 | 19 | import java.util.concurrent.{ConcurrentHashMap, TimeUnit} 20 | 21 | import com.google.common.base.Ticker 22 | import com.google.common.cache.{CacheBuilder, CacheLoader} 23 | import com.google.common.testing.FakeTicker 24 | import org.apache.spark.sql.catalyst.expressions.UnsafeRow 25 | import org.apache.spark.sql.execution.streaming.state._ 26 | import org.apache.spark.sql.internal.SQLConf 27 | import org.scalatest.{BeforeAndAfter, FunSuite} 28 | import ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider.{DUMMY_VALUE, MapType} 29 | 30 | import scala.util.Random 31 | 32 | /** 33 | * @author Chitral Verma 34 | * @since 10/30/18 35 | */ 36 | class RocksDbStateTimeoutSuite extends FunSuite with BeforeAndAfter { 37 | 38 | import RocksDbStateStoreHelper._ 39 | import StateStoreTestsHelper._ 40 | 41 | final val testDBLocation: String = "testdb" 42 | 43 | private def withTTLStore(ttl: Long, sqlConf: SQLConf, pathSuffix: String = "") 44 | (f: (FakeTicker, StateStore) => Unit): Unit = { 45 | val (ticker, stateStore) = createTTLStore(ttl, sqlConf, testDBLocation + pathSuffix) 46 | 47 | f(ticker, stateStore) 48 | stateStore.commit() 49 | } 50 | 51 | private def stopAndCleanUp(): Unit = { 52 | StateStore.stop() 53 | require(!StateStore.isMaintenanceRunning) 54 | performCleanUp(testDBLocation) 55 | } 56 | 57 | before { 58 | stopAndCleanUp() 59 | } 60 | 61 | after { 62 | stopAndCleanUp() 63 | } 64 | 65 | test("no timeout") { 66 | val expireTime = -1 67 | val sqlConf = createSQLConf(expireTime, isStrict = true) 68 | 69 | withTTLStore(expireTime, sqlConf)((ticker, store) => { 70 | put(store, "k1", 1) 71 | ticker.advance(20, TimeUnit.SECONDS) 72 | 73 | assert(size(store) === 1) 74 | assert(contains(store, "k1")) 75 | 76 | ticker.advance(Long.MaxValue, TimeUnit.SECONDS) 77 | 78 | assert(size(store) === 1) 79 | assert(contains(store, "k1")) 80 | }) 81 | } 82 | 83 | test("statelessness") { 84 | val expireTime = 0 85 | val sqlConf = createSQLConf(expireTime, isStrict = true) 86 | 87 | withTTLStore(expireTime, sqlConf)((_, store) => { 88 | put(store, "k1", 1) 89 | 90 | assert(size(store) === 0) 91 | assert(!contains(store, "k1")) 92 | 93 | put(store, "k1", 1) 94 | put(store, "k2", 1) 95 | put(store, "k3", 1) 96 | 97 | assert(size(store) === 0) 98 | assert(!contains(store, "k1")) 99 | assert(!contains(store, "k2")) 100 | assert(!contains(store, "k3")) 101 | }) 102 | } 103 | 104 | test("processing timeout") { 105 | val expireTime = 5 106 | val sqlConf = createSQLConf(expireTime, isStrict = true) 107 | 108 | withTTLStore(expireTime, sqlConf)((ticker, store) => { 109 | put(store, "k1", 1) 110 | 111 | ticker.advance(3, TimeUnit.SECONDS) 112 | 113 | assert(size(store) === 1) 114 | assert(contains(store, "k1")) 115 | 116 | ticker.advance(expireTime - 3, TimeUnit.SECONDS) 117 | 118 | assert(size(store) === 0) 119 | assert(!contains(store, "k1")) 120 | }) 121 | } 122 | 123 | test("ttl should reset on get, set and update") { 124 | val expireTime = 5 125 | val sqlConf = createSQLConf(expireTime, isStrict = true) 126 | 127 | withTTLStore(expireTime, sqlConf)((ticker, store) => { 128 | put(store, "k1", 1) 129 | put(store, "k2", 1) 130 | ticker.advance(3, TimeUnit.SECONDS) 131 | 132 | assert(size(store) === 2) 133 | assert(contains(store, "k1")) 134 | 135 | put(store, "k1", 2) // reset timeout for k1 136 | 
ticker.advance(2, TimeUnit.SECONDS) // deadline met for k2 137 | 138 | assert(size(store) === 1) 139 | assert(!contains(store, "k2")) 140 | 141 | ticker.advance(2, TimeUnit.SECONDS) 142 | 143 | assert(size(store) === 1) // 1 second remains for k1 here 144 | assert(contains(store, "k1")) 145 | 146 | ticker.advance(1, TimeUnit.SECONDS) // deadline met for k1 147 | 148 | put(store, "k3", 3) 149 | 150 | assert(size(store) === 1) 151 | assert(!contains(store, "k1")) 152 | 153 | ticker.advance(4, TimeUnit.SECONDS) // 1 second remains for k3 here 154 | 155 | assert(size(store) === 1) 156 | assert(contains(store, "k3")) 157 | 158 | get(store, "k3") // reset timeout for k3 159 | 160 | ticker.advance(1, TimeUnit.SECONDS) 161 | 162 | assert(size(store) === 1) 163 | assert(contains(store, "k3")) 164 | 165 | ticker.advance(4, TimeUnit.SECONDS) // deadline met for k3 166 | 167 | assert(size(store) === 0) 168 | assert(!contains(store, "k3")) 169 | }) 170 | } 171 | 172 | test("different timeouts for each streaming query (states)") { 173 | // Each query creates its own state store, the SQLConf is the same 174 | import RocksDbStateStoreProvider.STATE_EXPIRY_SECS 175 | val query1 = "query1" 176 | val timeout1 = 3 177 | 178 | val query2 = "query2" 179 | val timeout2 = 5 180 | 181 | val sqlConf = createSQLConf(isStrict = true, configs = Map( 182 | s"$STATE_EXPIRY_SECS.$query1" -> s"$timeout1", 183 | s"$STATE_EXPIRY_SECS.$query2" -> s"$timeout2" 184 | )) 185 | 186 | withTTLStore(timeout1, sqlConf, "1")((ticker1, store1) => { 187 | withTTLStore(timeout2, sqlConf, "2")((ticker2, store2) => { 188 | 189 | // Same data is read by both queries 190 | put(store1, "k1", 1) 191 | put(store1, "k2", 1) 192 | put(store2, "k1", 1) 193 | put(store2, "k2", 1) 194 | 195 | assert(size(store1) === 2) 196 | assert(contains(store1, "k1")) 197 | assert(contains(store1, "k2")) 198 | 199 | assert(size(store2) === 2) 200 | assert(contains(store2, "k1")) 201 | assert(contains(store2, "k2")) 202 | 203 | // Clock progression is the same for both queries 204 | ticker1.advance(2, TimeUnit.SECONDS) 205 | ticker2.advance(2, TimeUnit.SECONDS) 206 | 207 | assert(size(store1) === 2) 208 | assert(contains(store1, "k1")) 209 | assert(contains(store1, "k2")) 210 | 211 | assert(size(store2) === 2) 212 | assert(contains(store2, "k1")) 213 | assert(contains(store2, "k2")) 214 | 215 | ticker1.advance(1, TimeUnit.SECONDS) // deadline met for query1 216 | ticker2.advance(1, TimeUnit.SECONDS) 217 | 218 | assert(size(store1) === 0) 219 | assert(!contains(store1, "k1")) 220 | assert(!contains(store1, "k2")) 221 | 222 | assert(size(store2) === 2) 223 | assert(contains(store2, "k1")) 224 | assert(contains(store2, "k2")) 225 | 226 | ticker1.advance(2, TimeUnit.SECONDS) 227 | ticker2.advance(2, TimeUnit.SECONDS) // deadline met for query2 228 | 229 | assert(size(store1) === 0) 230 | assert(!contains(store1, "k1")) 231 | assert(!contains(store1, "k2")) 232 | 233 | assert(size(store2) === 0) 234 | assert(!contains(store2, "k1")) 235 | assert(!contains(store2, "k2")) 236 | 237 | 238 | }) 239 | }) 240 | } 241 | 242 | private def createTTLStore(ttl: Long, sqlConf: SQLConf, dbPath: String): (FakeTicker, StateStore) = { 243 | 244 | def createMockCache(ttl: Long, ticker: Ticker): MapType = { 245 | val loader = new CacheLoader[UnsafeRow, String] { 246 | override def load(key: UnsafeRow): String = DUMMY_VALUE 247 | } 248 | 249 | val cacheBuilder = CacheBuilder.newBuilder() 250 | 251 | val cacheBuilderWithOptions = { 252 | if (ttl >= 0) { 253 | cacheBuilder 254 | 
.expireAfterAccess(ttl, TimeUnit.SECONDS) 255 | .ticker(ticker) 256 | } else 257 | cacheBuilder 258 | } 259 | 260 | cacheBuilderWithOptions.build[UnsafeRow, String](loader) 261 | } 262 | 263 | val ticker = new FakeTicker 264 | val cache = createMockCache(ttl, ticker) 265 | 266 | val provider = createStoreProvider(opId = Random.nextInt(), partition = Random.nextInt(), sqlConf = sqlConf) 267 | val store = new provider.RocksDbStateStore(0, dbPath, keySchema, valueSchema, new ConcurrentHashMap, cache) 268 | 269 | (ticker, store) 270 | } 271 | 272 | } 273 | --------------------------------------------------------------------------------