├── project ├── build.properties └── plugins.sbt ├── version.sbt ├── src ├── main │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── org.apache.spark.sql.sources.DataSourceRegister │ └── scala │ │ ├── net │ │ └── heartsavior │ │ │ └── spark │ │ │ └── sql │ │ │ ├── util │ │ │ ├── HadoopPathUtil.scala │ │ │ └── SchemaUtil.scala │ │ │ ├── state │ │ │ ├── StateStoreRelation.scala │ │ │ ├── StateStoreReaderOperatorParamExtractor.scala │ │ │ ├── StateSchemaExtractor.scala │ │ │ ├── StateStoreWriter.scala │ │ │ ├── migration │ │ │ │ ├── StreamingAggregationMigrator.scala │ │ │ │ └── FlatMapGroupsWithStateMigrator.scala │ │ │ ├── StateStoreReaderRDD.scala │ │ │ ├── StateInformationInCheckpoint.scala │ │ │ └── StateStoreDataSourceProvider.scala │ │ │ └── checkpoint │ │ │ └── CheckpointUtil.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── sql │ │ └── hack │ │ ├── SerializableConfigurationWrapper.scala │ │ └── SparkSqlHack.scala └── test │ ├── resources │ └── log4j.properties │ └── scala │ └── net │ └── heartsavior │ └── spark │ └── sql │ └── state │ ├── StateInformationInCheckpointSuite.scala │ ├── StateSchemaExtractorSuite.scala │ ├── StateStoreReaderOperatorParamExtractorSuite.scala │ ├── StreamingAggregationMigratorSuite.scala │ ├── FlatMapGroupsWithStateMigratorSuite.scala │ ├── StateStoreStreamingAggregationReadSuite.scala │ ├── StateStoreTest.scala │ └── StateStoreStreamingAggregationWriteSuite.scala ├── .gitignore ├── .circleci └── config.yml ├── checkstyle.xml ├── README.md ├── LICENSE └── scalastyle-config.xml /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.3.10 2 | -------------------------------------------------------------------------------- /version.sbt: -------------------------------------------------------------------------------- 1 | version in ThisBuild := "0.5.1-spark-3.0-SNAPSHOT" 2 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | net.heartsavior.spark.sql.state.StateStoreDataSourceProvider -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") 2 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.0") 3 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.6.1") 4 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.13") -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *#*# 2 | *.#* 3 | *.iml 4 | *.ipr 5 | *.iws 6 | *.pyc 7 | *.pyo 8 | *.swp 9 | *~ 10 | .DS_Store 11 | .cache 12 | .classpath 13 | .idea/ 14 | .idea_modules/ 15 | .project 16 | .pydevproject 17 | .scala_dependencies 18 | .settings 19 | /lib/ 20 | build/*.jar 21 | build/apache-maven* 22 | build/scala* 23 | build/zinc* 24 | cache 25 | dependency-reduced-pom.xml 26 | derby.log 27 | lib_managed/ 28 | log/ 29 | logs/ 30 | out/ 31 | project/boot/ 32 | project/build/target/ 33 | project/plugins/lib_managed/ 34 | project/plugins/project/build.properties 35 | project/plugins/src_managed/ 36 | project/plugins/target/ 37 | scalastyle-on-compile.generated.xml 38 | scalastyle-output.xml 39 | scalastyle.txt 
40 | target/ 41 | unit-tests.log 42 | *.crc 43 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/util/HadoopPathUtil.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.util 18 | 19 | import org.apache.hadoop.conf.Configuration 20 | import org.apache.hadoop.fs.Path 21 | 22 | object HadoopPathUtil { 23 | def resolve(hadoopConf: Configuration, cpLocation: String): String = { 24 | val checkpointPath = new Path(cpLocation) 25 | val fs = checkpointPath.getFileSystem(hadoopConf) 26 | checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/util/SchemaUtil.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.util 18 | 19 | import org.apache.spark.sql.hack.SparkSqlHack 20 | import org.apache.spark.sql.types.{DataType, StructType} 21 | 22 | object SchemaUtil { 23 | def getSchemaAsDataType(schema: StructType, fieldName: String): DataType = { 24 | schema(SparkSqlHack.getFieldIndex(schema, fieldName).get).dataType 25 | } 26 | 27 | def keyValuePairSchema(keySchema: StructType, valueSchema: StructType): StructType = 28 | new StructType() 29 | .add("key", StructType(keySchema.fields), nullable = false) 30 | .add("value", StructType(valueSchema.fields), nullable = false) 31 | } 32 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | # Set everything to be logged to the file target/unit-tests.log 19 | test.appender=file 20 | log4j.rootCategory=INFO, ${test.appender} 21 | log4j.appender.file=org.apache.log4j.FileAppender 22 | log4j.appender.file.append=true 23 | log4j.appender.file.file=target/unit-tests.log 24 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 25 | log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n 26 | 27 | # Silence some noisy libraries. 28 | log4j.logger.org.apache.http=WARN 29 | log4j.logger.org.apache.spark=INFO 30 | log4j.logger.org.eclipse.jetty=WARN 31 | log4j.logger.org.spark-project.jetty=WARN 32 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | jobs: 3 | build: 4 | working_directory: ~/repo 5 | docker: 6 | - image: openjdk:8 7 | environment: 8 | SBT_VERSION: 1.3.10 9 | steps: 10 | - run: echo 'export ARTIFACT_BUILD=$CIRCLE_PROJECT_REPONAME-$CIRCLE_BUILD_NUM.zip' >> $BASH_ENV 11 | - run: 12 | name: Get sbt binary 13 | command: | 14 | apt update && apt install -y curl 15 | curl -L -o sbt-$SBT_VERSION.deb https://dl.bintray.com/sbt/debian/sbt-$SBT_VERSION.deb 16 | dpkg -i sbt-$SBT_VERSION.deb 17 | rm sbt-$SBT_VERSION.deb 18 | apt-get update 19 | apt-get install -y python-pip git 20 | pip install awscli 21 | apt-get clean && apt-get autoclean 22 | - checkout 23 | - restore_cache: 24 | # Read about caching dependencies: https://circleci.com/docs/2.0/caching/ 25 | key: sbt-cache 26 | - run: 27 | name: Compile spark-state-tools dist package 28 | command: cat /dev/null | sbt clean +update scalastyle +test +package 29 | - when: 30 | condition: true 31 | steps: 32 | - store_artifacts: 33 | path: ~/repo/target/scalastyle-output.xml 34 | destination: scalastyle-output.xml 35 | 36 | - store_artifacts: 37 | path: ~/repo/target/unit-tests.log 38 | destination: unit-tests.log 39 | - save_cache: 40 | key: sbt-cache 41 | paths: 42 | - "~/.ivy2/cache" 43 | - "~/.sbt" 44 | - "~/.m2" 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/hack/SerializableConfigurationWrapper.scala: -------------------------------------------------------------------------------- 1 | // scalastyle:off header 2 | /* 3 | * Licensed to the Apache Software Foundation (ASF) under one or more 4 | * contributor license agreements. See the NOTICE file distributed with 5 | * this work for additional information regarding copyright ownership. 6 | * The ASF licenses this file to You under the Apache License, Version 2.0 7 | * (the "License"); you may not use this file except in compliance with 8 | * the License. 
You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | package org.apache.spark.sql.hack 20 | 21 | import org.apache.spark.broadcast.Broadcast 22 | import org.apache.spark.sql.SparkSession 23 | import org.apache.spark.util.SerializableConfiguration 24 | 25 | /** 26 | * This class was added because without it a NullPointerException was thrown by 27 | * StateStore Providers as the hadoop configuration resulted to be null. 28 | */ 29 | class SerializableConfigurationWrapper(session: SparkSession) extends Serializable { 30 | val broadcastedConf: Broadcast[SerializableConfiguration] = { 31 | val conf = new SerializableConfiguration(session.sparkContext.hadoopConfiguration) 32 | session.sparkContext.broadcast(conf) 33 | } 34 | } 35 | // scalastyle:on header 36 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateStoreRelation.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import net.heartsavior.spark.sql.util.SchemaUtil 20 | import org.apache.hadoop.fs.Path 21 | 22 | import org.apache.spark.internal.Logging 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.spark.sql.{Row, SparkSession, SQLContext} 25 | import org.apache.spark.sql.execution.streaming.state.StateStoreId 26 | import org.apache.spark.sql.hack.SparkSqlHack 27 | import org.apache.spark.sql.sources.{BaseRelation, TableScan} 28 | import org.apache.spark.sql.types.StructType 29 | 30 | // TODO: read schema of key and value from metadata of state (requires SPARK-27237) 31 | class StateStoreRelation( 32 | session: SparkSession, 33 | keySchema: StructType, 34 | valueSchema: StructType, 35 | stateCheckpointLocation: String, 36 | batchId: Int, 37 | operatorId: Int, 38 | storeName: String = StateStoreId.DEFAULT_STORE_NAME) 39 | extends BaseRelation with TableScan with Logging { 40 | 41 | override def sqlContext: SQLContext = session.sqlContext 42 | 43 | override def schema: StructType = SchemaUtil.keyValuePairSchema(keySchema, valueSchema) 44 | 45 | override def buildScan(): RDD[Row] = { 46 | val resolvedCpLocation = { 47 | val checkpointPath = new Path(stateCheckpointLocation) 48 | val fs = checkpointPath.getFileSystem(SparkSqlHack.sessionState(sqlContext).newHadoopConf()) 49 | fs.mkdirs(checkpointPath) 50 | checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString 51 | } 52 | 53 | new StateStoreReaderRDD(session, keySchema, valueSchema, 54 | resolvedCpLocation, batchId, operatorId, storeName) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/hack/SparkSqlHack.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package org.apache.spark.sql.hack 18 | 19 | import java.io.File 20 | 21 | import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} 22 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 23 | import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, StateStoreSaveExec} 24 | import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateManager 25 | import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager 26 | import org.apache.spark.sql.internal.{SessionState, SQLConf} 27 | import org.apache.spark.sql.types.StructType 28 | import org.apache.spark.util.Utils 29 | 30 | object SparkSqlHack { 31 | def getFieldIndex(schema: StructType, fieldName: String): Option[Int] = { 32 | schema.getFieldIndex(fieldName) 33 | } 34 | 35 | def sessionState(sqlContext: SQLContext): SessionState = { 36 | sqlContext.sessionState 37 | } 38 | 39 | def sqlConf(sqlContext: SQLContext): SQLConf = { 40 | sqlContext.conf 41 | } 42 | 43 | def logicalPlan(query: DataFrame): LogicalPlan = query.logicalPlan 44 | 45 | def stateManager(exec: StateStoreSaveExec): StreamingAggregationStateManager = { 46 | exec.stateManager 47 | } 48 | 49 | def stateManager(exec: FlatMapGroupsWithStateExec): StateManager = { 50 | exec.stateManager 51 | } 52 | 53 | def analysisException(message: String): AnalysisException = { 54 | new AnalysisException(message) 55 | } 56 | 57 | def createTempDir( 58 | root: String = System.getProperty("java.io.tmpdir"), 59 | namePrefix: String = "spark"): File = { 60 | Utils.createTempDir(root, namePrefix) 61 | } 62 | 63 | def deleteRecursively(file: File): Unit = Utils.deleteRecursively(file) 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateStoreReaderOperatorParamExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import net.heartsavior.spark.sql.state.StateInformationInCheckpoint.StateInformation 20 | import net.heartsavior.spark.sql.state.StateSchemaExtractor.StateSchemaInfo 21 | 22 | import org.apache.spark.sql.execution.streaming.state.StateStoreId 23 | import org.apache.spark.sql.types.StructType 24 | 25 | /** 26 | * This class combines [[StateInformation]] and [[StateSchemaInfo]] to provide actual 27 | * parameters needed for state store read. 
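 *
 * A minimal usage sketch, assuming `stateInfo` and `schemaInfos` were already obtained via
 * [[StateInformationInCheckpoint]] and [[StateSchemaExtractor]] against the same checkpoint
 * and query (both value names are placeholders):
 * {{{
 *   val params = StateStoreReaderOperatorParamExtractor.extract(stateInfo, schemaInfos)
 *   params.foreach { p =>
 *     println(s"opId=${p.opId} store=${p.storeName} lastStateVersion=${p.lastStateVersion}")
 *   }
 * }}}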
28 | */ 29 | object StateStoreReaderOperatorParamExtractor { 30 | case class StateStoreReaderOperatorParam( 31 | lastStateVersion: Option[Long], 32 | opId: Int, 33 | storeName: String, 34 | stateSchema: Option[StructType]) 35 | 36 | def extract( 37 | stateInfo: StateInformation, 38 | schemaInfos: Seq[StateSchemaInfo]) 39 | : Seq[StateStoreReaderOperatorParam] = { 40 | 41 | val lastStateVer = stateInfo.lastCommittedBatchId.map(_ + 1) 42 | 43 | val stInfoGrouped = stateInfo.operators.groupBy(_.opId) 44 | val schemaInfoGrouped = schemaInfos.groupBy(_.opId) 45 | stInfoGrouped.flatMap { case (key, value) => 46 | if (value.length != 1) { 47 | throw new IllegalStateException("It should only have one state operator information " + 48 | "per operator ID") 49 | } 50 | 51 | value.head.storeNames.map { storeName => 52 | val stateSchema: Option[StructType] = { 53 | if (storeName == StateStoreId.DEFAULT_STORE_NAME) { 54 | schemaInfoGrouped.get(key).map { infoValue => 55 | if (infoValue.length != 1) { 56 | throw new IllegalStateException("StateSchemaInfo only supports one schema per " + 57 | "operator id - which uses DEFAULT_STORE_NAME as store name.") 58 | } 59 | val ret = infoValue.head 60 | new StructType().add("key", ret.keySchema).add("value", ret.valueSchema) 61 | } 62 | } else { 63 | None 64 | } 65 | } 66 | 67 | StateStoreReaderOperatorParam(lastStateVer, key, storeName, stateSchema) 68 | } 69 | }.toSeq 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateSchemaExtractor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.util.UUID 20 | 21 | import net.heartsavior.spark.sql.state.StateSchemaExtractor.{StateKind, StateSchemaInfo} 22 | 23 | import org.apache.spark.internal.Logging 24 | import org.apache.spark.sql.{DataFrame, SparkSession} 25 | import org.apache.spark.sql.execution.streaming._ 26 | import org.apache.spark.sql.hack.SparkSqlHack 27 | import org.apache.spark.sql.streaming.OutputMode 28 | import org.apache.spark.sql.types.StructType 29 | 30 | /** 31 | * This class enables extracting state schema and its format version by analyzing 32 | * the streaming query. The query should have its state operators but it should exclude sink(s). 33 | * 34 | * Note that it only returns what this class can extract, so the number of stateful 35 | * operators in the given query may not match the number of returned schema entries. 
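 *
 * A minimal usage sketch, assuming `spark` is an active [[SparkSession]] and `streamingDf`
 * is a streaming [[DataFrame]] built with the same stateful transformations as the query
 * whose state is being inspected (both names are placeholders):
 * {{{
 *   val schemaInfos = new StateSchemaExtractor(spark).extract(streamingDf)
 *   schemaInfos.foreach { info =>
 *     println(s"opId=${info.opId} kind=${info.stateKind} formatVersion=${info.formatVersion}")
 *   }
 * }}}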
36 | */ 37 | class StateSchemaExtractor(spark: SparkSession) extends Logging { 38 | 39 | def extract(query: DataFrame): Seq[StateSchemaInfo] = { 40 | require(query.isStreaming, "Given query is not a streaming query!") 41 | 42 | val queryExecution = new IncrementalExecution(spark, SparkSqlHack.logicalPlan(query), 43 | OutputMode.Update(), 44 | "", UUID.randomUUID(), UUID.randomUUID(), 0, OffsetSeqMetadata(0, 0)) 45 | 46 | // TODO: handle Streaming Join (if possible), etc. 47 | queryExecution.executedPlan.collect { 48 | case store: StateStoreSaveExec => 49 | val stateFormatVersion = store.stateFormatVersion 50 | val keySchema = store.keyExpressions.toStructType 51 | val valueSchema = SparkSqlHack.stateManager(store).getStateValueSchema 52 | store.stateInfo match { 53 | case Some(stInfo) => 54 | val operatorId = stInfo.operatorId 55 | StateSchemaInfo(operatorId, StateKind.StreamingAggregation, 56 | stateFormatVersion, keySchema, valueSchema) 57 | 58 | case None => throw new IllegalStateException("State information not set!") 59 | } 60 | 61 | case store: FlatMapGroupsWithStateExec => 62 | val stateFormatVersion = store.stateFormatVersion 63 | val keySchema = store.groupingAttributes.toStructType 64 | val valueSchema = SparkSqlHack.stateManager(store).stateSchema 65 | store.stateInfo match { 66 | case Some(stInfo) => 67 | val operatorId = stInfo.operatorId 68 | StateSchemaInfo(operatorId, StateKind.FlatMapGroupsWithState, 69 | stateFormatVersion, keySchema, valueSchema) 70 | 71 | case None => throw new IllegalStateException("State information not set!") 72 | } 73 | } 74 | } 75 | 76 | } 77 | 78 | object StateSchemaExtractor { 79 | object StateKind extends Enumeration { 80 | val StreamingAggregation, StreamingJoin, FlatMapGroupsWithState = Value 81 | } 82 | 83 | case class StateSchemaInfo( 84 | opId: Long, 85 | stateKind: StateKind.Value, 86 | formatVersion: Int, 87 | keySchema: StructType, 88 | valueSchema: StructType) 89 | } 90 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/checkpoint/CheckpointUtil.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.checkpoint 18 | 19 | import net.heartsavior.spark.sql.util.HadoopPathUtil 20 | import org.apache.hadoop.fs.{FileUtil, Path} 21 | 22 | import org.apache.spark.sql.SparkSession 23 | import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata} 24 | 25 | /** 26 | * Providing features to deal with checkpoint, like creating savepoint. 27 | */ 28 | object CheckpointUtil { 29 | 30 | /** 31 | * Create savepoint from existing checkpoint. 32 | * OffsetLog and CommitLog will be purged based on newLastBatchId. 
33 | * Use `additionalMetadataConf` to modify metadata configuration: you may want to modify it 34 | * when rescaling state, or migrate state format version. 35 | * e.g. when rescaling, pass Map(SQLConf.SHUFFLE_PARTITIONS.key -> newShufflePartitions.toString) 36 | * 37 | * @param sparkSession spark session 38 | * @param checkpointRoot the root path of existing checkpoint 39 | * @param newCheckpointRoot the root path of new savepoint - target directory should be empty 40 | * @param newLastBatchId the new last batch ID - it needs to be one of committed batch ID 41 | * @param additionalMetadataConf the configuration to add to existing metadata configuration 42 | * @param excludeState whether to exclude state directory 43 | */ 44 | def createSavePoint( 45 | sparkSession: SparkSession, 46 | checkpointRoot: String, 47 | newCheckpointRoot: String, 48 | newLastBatchId: Long, 49 | additionalMetadataConf: Map[String, String], 50 | excludeState: Boolean = false): Unit = { 51 | val hadoopConf = sparkSession.sessionState.newHadoopConf() 52 | 53 | val src = new Path(HadoopPathUtil.resolve(hadoopConf, checkpointRoot)) 54 | val srcFs = src.getFileSystem(hadoopConf) 55 | val dst = new Path(HadoopPathUtil.resolve(hadoopConf, newCheckpointRoot)) 56 | val dstFs = dst.getFileSystem(hadoopConf) 57 | 58 | if (dstFs.listFiles(dst, false).hasNext) { 59 | throw new IllegalArgumentException("Destination directory should be empty.") 60 | } 61 | 62 | dstFs.mkdirs(dst) 63 | 64 | // copy content of src directory to dst directory 65 | srcFs.listStatus(src).foreach { fs => 66 | val path = fs.getPath 67 | val fileName = path.getName 68 | if (fileName == "state" && excludeState) { 69 | // pass 70 | } else { 71 | FileUtil.copy(srcFs, path, dstFs, new Path(dst, fileName), 72 | false, false, hadoopConf) 73 | } 74 | } 75 | 76 | val offsetLog = new OffsetSeqLog(sparkSession, new Path(dst, "offsets").toString) 77 | val logForBatch = offsetLog.get(newLastBatchId) match { 78 | case Some(log) => log 79 | case None => throw new IllegalStateException("offset log for batch should exist") 80 | } 81 | 82 | val newMetadata = logForBatch.metadata match { 83 | case Some(md) => 84 | val newMap = md.conf ++ additionalMetadataConf 85 | Some(md.copy(conf = newMap)) 86 | case None => 87 | Some(OffsetSeqMetadata(conf = additionalMetadataConf)) 88 | } 89 | 90 | val newLogForBatch = logForBatch.copy(metadata = newMetadata) 91 | 92 | // we will restart from last batch + 1: overwrite the last batch with new configuration 93 | offsetLog.purgeAfter(newLastBatchId - 1) 94 | offsetLog.add(newLastBatchId, newLogForBatch) 95 | 96 | val commitLog = new CommitLog(sparkSession, new Path(dst, "commits").toString) 97 | commitLog.purgeAfter(newLastBatchId) 98 | 99 | // state doesn't expose purge mechanism as its interface 100 | // assuming state would work with overwriting batch files when it replays previous batch 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateStoreWriter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.util.UUID 20 | 21 | import org.apache.hadoop.fs.Path 22 | 23 | import org.apache.spark.TaskContext 24 | import org.apache.spark.broadcast.Broadcast 25 | import org.apache.spark.sql.{Column, DataFrame, SparkSession} 26 | import org.apache.spark.sql.catalyst.InternalRow 27 | import org.apache.spark.sql.catalyst.expressions.UnsafeRow 28 | import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} 29 | import org.apache.spark.sql.hack.SerializableConfigurationWrapper 30 | import org.apache.spark.sql.types.StructType 31 | 32 | class StateStoreWriter( 33 | session: SparkSession, 34 | data: DataFrame, 35 | keySchema: StructType, 36 | valueSchema: StructType, 37 | stateCheckpointLocation: String, 38 | version: Int, 39 | operatorId: Int, 40 | storeName: String, 41 | newPartitions: Int) { 42 | 43 | import StateStoreWriter._ 44 | 45 | private val storeConf = new StateStoreConf(session.sessionState.conf) 46 | 47 | // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it 48 | private val hadoopConfBroadcast = new SerializableConfigurationWrapper(session) 49 | 50 | def write(): Unit = { 51 | val resolvedCpLocation = { 52 | val checkpointPath = new Path(stateCheckpointLocation) 53 | val fs = checkpointPath.getFileSystem(session.sessionState.newHadoopConf()) 54 | if (fs.exists(checkpointPath)) { 55 | throw new IllegalStateException(s"Checkpoint location should not exist. " + 56 | s"Path: $checkpointPath") 57 | } 58 | fs.mkdirs(checkpointPath) 59 | checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString 60 | } 61 | 62 | // just provide dummy ID since it doesn't matter 63 | // if it really matters in future, convert it to parameter 64 | val queryId = UUID.randomUUID() 65 | 66 | // TODO: expand this to cover multi-depth (nested) columns (do we want to cover it?) 
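    // The rows are hash-partitioned by the key columns below, so all rows sharing a state key
    // land in the same partition index; writeFn then writes each partition's rows into the
    // state store files for that partition index.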
67 | val fullPathsForKeyColumns = keySchema.map(key => new Column(s"key.${key.name}")) 68 | data 69 | .repartition(newPartitions, fullPathsForKeyColumns: _*) 70 | .queryExecution 71 | .toRdd 72 | .foreachPartition( 73 | writeFn(resolvedCpLocation, version, operatorId, storeName, keySchema, valueSchema, 74 | storeConf, hadoopConfBroadcast, queryId)) 75 | } 76 | } 77 | 78 | object StateStoreWriter { 79 | 80 | def writeFn( 81 | resolvedCpLocation: String, 82 | version: Int, 83 | operatorId: Int, 84 | storeName: String, 85 | keySchema: StructType, 86 | valueSchema: StructType, 87 | storeConf: StateStoreConf, 88 | hadoopConfBroadcast: SerializableConfigurationWrapper, 89 | queryId: UUID): Iterator[InternalRow] => Unit = iter => { 90 | val taskContext = TaskContext.get() 91 | 92 | val partIdx = taskContext.partitionId() 93 | val hadoopConf = hadoopConfBroadcast.broadcastedConf.value.value 94 | 95 | val storeId = StateStoreId(resolvedCpLocation, operatorId, partIdx, storeName) 96 | val storeProviderId = StateStoreProviderId(storeId, queryId) 97 | 98 | // fill empty state until target version - 1 99 | (0 until version - 1).map { id => 100 | val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, id, 101 | storeConf, hadoopConf) 102 | store.commit() 103 | } 104 | 105 | // all states will be written at version 106 | val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, version - 1, 107 | storeConf, hadoopConf) 108 | iter.foreach { row => 109 | store.put( 110 | row.getStruct(0, keySchema.fields.length).asInstanceOf[UnsafeRow], 111 | row.getStruct(1, valueSchema.fields.length).asInstanceOf[UnsafeRow] 112 | ) 113 | } 114 | store.commit() 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StateInformationInCheckpointSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import org.apache.hadoop.fs.Path 20 | import org.scalatest.{Assertions, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId} 23 | import org.apache.spark.sql.hack.SparkSqlHack 24 | import org.apache.spark.sql.internal.SQLConf 25 | 26 | class StateInformationInCheckpointSuite 27 | extends StateStoreTest 28 | with BeforeAndAfterAll 29 | with Assertions { 30 | 31 | override def afterAll(): Unit = { 32 | super.afterAll() 33 | StateStore.stop() 34 | } 35 | 36 | test("Reading checkpoint from streaming aggregation") { 37 | withTempDir { cpDir => 38 | runCompositeKeyStreamingAggregationQuery(cpDir.getAbsolutePath) 39 | 40 | val stateInfo = new StateInformationInCheckpoint(spark) 41 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 42 | 43 | assert(stateInfo.lastCommittedBatchId === Some(2)) 44 | assert(stateInfo.operators.length === 1) 45 | 46 | val operator = stateInfo.operators.head 47 | assert(operator.opId === 0) 48 | assert(operator.partitions === SparkSqlHack.sqlConf(spark.sqlContext).numShufflePartitions) 49 | assert(operator.storeNames === Seq(StateStoreId.DEFAULT_STORE_NAME)) 50 | 51 | assert(stateInfo.confs.get(SQLConf.SHUFFLE_PARTITIONS.key) === 52 | Some(operator.partitions.toString)) 53 | } 54 | } 55 | 56 | test("Reading checkpoint from streaming deduplication") { 57 | withTempDir { cpDir => 58 | runStreamingDeduplicationQuery(cpDir.getAbsolutePath) 59 | 60 | val stateInfo = new StateInformationInCheckpoint(spark) 61 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 62 | 63 | assert(stateInfo.lastCommittedBatchId === Some(2)) 64 | assert(stateInfo.operators.length === 1) 65 | 66 | val operator = stateInfo.operators.head 67 | assert(operator.opId === 0) 68 | assert(operator.partitions === SparkSqlHack.sqlConf(spark.sqlContext).numShufflePartitions) 69 | assert(operator.storeNames === Seq(StateStoreId.DEFAULT_STORE_NAME)) 70 | 71 | assert(stateInfo.confs.get(SQLConf.SHUFFLE_PARTITIONS.key) === 72 | Some(operator.partitions.toString)) 73 | } 74 | } 75 | 76 | test("Reading checkpoint from streaming join") { 77 | withTempDir { cpDir => 78 | runStreamingJoinQuery(cpDir.getAbsolutePath) 79 | 80 | val stateInfo = new StateInformationInCheckpoint(spark) 81 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 82 | 83 | assert(stateInfo.lastCommittedBatchId === Some(1)) 84 | assert(stateInfo.operators.length === 1) 85 | 86 | val operator = stateInfo.operators.head 87 | assert(operator.opId === 0) 88 | assert(operator.partitions === SparkSqlHack.sqlConf(spark.sqlContext).numShufflePartitions) 89 | // NOTE: this verification couples with implementation details of streaming join 90 | assert(operator.storeNames.toSet === Set("left-keyToNumValues", "left-keyWithIndexToValue", 91 | "right-keyToNumValues", "right-keyWithIndexToValue")) 92 | 93 | assert(stateInfo.confs.get(SQLConf.SHUFFLE_PARTITIONS.key) === 94 | Some(operator.partitions.toString)) 95 | } 96 | } 97 | 98 | test("Reading checkpoint from flatMapGroupsWithState") { 99 | withTempDir { cpDir => 100 | runFlatMapGroupsWithStateQuery(cpDir.getAbsolutePath) 101 | 102 | val stateInfo = new StateInformationInCheckpoint(spark) 103 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 104 | 105 | assert(stateInfo.lastCommittedBatchId === Some(1)) 106 | assert(stateInfo.operators.length === 1) 107 | 108 | val operator = stateInfo.operators.head 109 | assert(operator.opId === 0) 110 | assert(operator.partitions === 
SparkSqlHack.sqlConf(spark.sqlContext).numShufflePartitions) 111 | assert(operator.storeNames === Seq(StateStoreId.DEFAULT_STORE_NAME)) 112 | 113 | assert(stateInfo.confs.get(SQLConf.SHUFFLE_PARTITIONS.key) === 114 | Some(operator.partitions.toString)) 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StateSchemaExtractorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import net.heartsavior.spark.sql.state.StateSchemaExtractor.StateKind 20 | import net.heartsavior.spark.sql.util.SchemaUtil 21 | import org.scalatest.{Assertions, BeforeAndAfterAll} 22 | 23 | import org.apache.spark.sql.Encoders 24 | import org.apache.spark.sql.execution.streaming.state.StateStore 25 | import org.apache.spark.sql.internal.SQLConf 26 | import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} 27 | 28 | class StateSchemaExtractorSuite 29 | extends StateStoreTest 30 | with BeforeAndAfterAll 31 | with Assertions { 32 | 33 | override def afterAll(): Unit = { 34 | super.afterAll() 35 | StateStore.stop() 36 | } 37 | 38 | Seq(1, 2).foreach { ver => 39 | test(s"extract schema from streaming aggregation query - state format v$ver") { 40 | withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> ver.toString) { 41 | val aggregated = getCompositeKeyStreamingAggregationQuery 42 | 43 | val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(ver) 44 | val expectedKeySchema = SchemaUtil.getSchemaAsDataType(stateSchema, "key") 45 | .asInstanceOf[StructType] 46 | val expectedValueSchema = SchemaUtil.getSchemaAsDataType(stateSchema, "value") 47 | .asInstanceOf[StructType] 48 | 49 | val schemaInfos = new StateSchemaExtractor(spark).extract(aggregated.toDF()) 50 | assert(schemaInfos.length === 1) 51 | val schemaInfo = schemaInfos.head 52 | assert(schemaInfo.opId === 0) 53 | assert(schemaInfo.formatVersion === ver) 54 | assert(schemaInfo.stateKind === StateKind.StreamingAggregation) 55 | 56 | assert(compareSchemaWithoutName(schemaInfo.keySchema, expectedKeySchema), 57 | s"Even without column names, ${schemaInfo.keySchema} did not equal $expectedKeySchema") 58 | assert(compareSchemaWithoutName(schemaInfo.valueSchema, expectedValueSchema), 59 | s"Even without column names, ${schemaInfo.valueSchema} did not equal " + 60 | s"$expectedValueSchema") 61 | } 62 | } 63 | } 64 | 65 | Seq(1, 2).foreach { ver => 66 | test(s"extract schema from flatMapGroupsWithState query - state format v$ver") { 67 | withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> ver.toString) { 68 | // This is borrowed from StateStoreTest, runFlatMapGroupsWithStateQuery 69 | val aggregated = getFlatMapGroupsWithStateQuery 70 | 71 | val expectedKeySchema = 
new StructType().add("value", StringType, nullable = true) 72 | 73 | val expectedValueSchema = if (ver == 1) { 74 | Encoders.product[SessionInfo].schema 75 | .add("timeoutTimestamp", IntegerType, nullable = false) 76 | } else { 77 | // ver == 2 78 | new StructType() 79 | .add("groupState", Encoders.product[SessionInfo].schema) 80 | .add("timeoutTimestamp", LongType, nullable = false) 81 | } 82 | 83 | val schemaInfos = new StateSchemaExtractor(spark).extract(aggregated.toDF()) 84 | assert(schemaInfos.length === 1) 85 | val schemaInfo = schemaInfos.head 86 | assert(schemaInfo.opId === 0) 87 | assert(schemaInfo.stateKind === StateKind.FlatMapGroupsWithState) 88 | assert(schemaInfo.formatVersion === ver) 89 | 90 | assert(compareSchemaWithoutName(schemaInfo.keySchema, expectedKeySchema), 91 | s"Even without column names, ${schemaInfo.keySchema} did not equal $expectedKeySchema") 92 | assert(compareSchemaWithoutName(schemaInfo.valueSchema, expectedValueSchema), 93 | s"Even without column names, ${schemaInfo.valueSchema} did not equal " + 94 | s"$expectedValueSchema") 95 | } 96 | } 97 | } 98 | 99 | private def compareSchemaWithoutName(s1: StructType, s2: StructType): Boolean = { 100 | if (s1.length != s2.length) { 101 | false 102 | } else { 103 | s1.zip(s2).forall { case (column1, column2) => 104 | column1.dataType == column2.dataType && column1.nullable == column2.nullable 105 | } 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StateStoreReaderOperatorParamExtractorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import org.apache.hadoop.fs.Path 20 | import org.scalatest.{Assertions, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId} 23 | import org.apache.spark.sql.types.StructType 24 | 25 | class StateStoreReaderOperatorParamExtractorSuite 26 | extends StateStoreTest 27 | with BeforeAndAfterAll 28 | with Assertions { 29 | 30 | override def afterAll(): Unit = { 31 | super.afterAll() 32 | StateStore.stop() 33 | } 34 | 35 | test("combine state info and schema info from streaming aggregation query") { 36 | withTempDir { cpDir => 37 | runCompositeKeyStreamingAggregationQuery(cpDir.getAbsolutePath) 38 | 39 | val stateInfo = new StateInformationInCheckpoint(spark) 40 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 41 | assert(stateInfo.operators.length === 1) 42 | // other validation of stateInfo is covered by StateInformationCheckpointSuite 43 | 44 | val query = getCompositeKeyStreamingAggregationQuery 45 | 46 | val schemaInfo = new StateSchemaExtractor(spark).extract(query.toDF) 47 | assert(schemaInfo.length === 1) 48 | // other validation of schemaInfo is covered by StateSchemaExtractorSuite 49 | 50 | val opParams = StateStoreReaderOperatorParamExtractor.extract(stateInfo, schemaInfo) 51 | 52 | // expecting only one state operator and only one store name 53 | assert(opParams.size == 1) 54 | val opParam = opParams.head 55 | assert(opParam.lastStateVersion === Some(3)) 56 | assert(opParam.storeName === StateStoreId.DEFAULT_STORE_NAME) 57 | assert(opParam.stateSchema.isDefined) 58 | val expectedStateSchema = new StructType() 59 | .add("key", schemaInfo.head.keySchema) 60 | .add("value", schemaInfo.head.valueSchema) 61 | assert(opParam.stateSchema.get === expectedStateSchema) 62 | } 63 | } 64 | 65 | test("combine state info and schema info from flatMapGroupsWithState") { 66 | withTempDir { cpDir => 67 | runFlatMapGroupsWithStateQuery(cpDir.getAbsolutePath) 68 | 69 | val stateInfo = new StateInformationInCheckpoint(spark) 70 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 71 | assert(stateInfo.operators.length === 1) 72 | // other validation of stateInfo is covered by StateInformationCheckpointSuite 73 | 74 | val query = getFlatMapGroupsWithStateQuery 75 | 76 | val schemaInfo = new StateSchemaExtractor(spark).extract(query.toDF) 77 | assert(schemaInfo.length === 1) 78 | // other validation of schemaInfo is covered by StateSchemaExtractorSuite 79 | 80 | val opParams = StateStoreReaderOperatorParamExtractor.extract(stateInfo, schemaInfo) 81 | 82 | // expecting only one state operator and only one store name 83 | assert(opParams.size == 1) 84 | val opParam = opParams.head 85 | assert(opParam.lastStateVersion === Some(2)) 86 | assert(opParam.storeName === StateStoreId.DEFAULT_STORE_NAME) 87 | assert(opParam.stateSchema.isDefined) 88 | val expectedStateSchema = new StructType() 89 | .add("key", schemaInfo.head.keySchema) 90 | .add("value", schemaInfo.head.valueSchema) 91 | assert(opParam.stateSchema.get === expectedStateSchema) 92 | } 93 | } 94 | 95 | test("combine state info and schema info from streaming join - schema not supported") { 96 | withTempDir { cpDir => 97 | runStreamingJoinQuery(cpDir.getAbsolutePath) 98 | 99 | val stateInfo = new StateInformationInCheckpoint(spark) 100 | .gatherInformation(new Path(cpDir.getAbsolutePath)) 101 | assert(stateInfo.operators.length === 1) 102 | // other validation of stateInfo is covered by StateInformationCheckpointSuite 103 | 
104 | val query = getStreamingJoinQuery 105 | 106 | val schemaInfo = new StateSchemaExtractor(spark).extract(query.toDF) 107 | 108 | val opParams = StateStoreReaderOperatorParamExtractor.extract(stateInfo, schemaInfo) 109 | 110 | // expecting only one state operator which has 4 store names 111 | // NOTE: this verification couples with implementation details of streaming join 112 | assert(opParams.size == 4) 113 | opParams.forall { opParam => 114 | opParam.lastStateVersion.contains(2) && opParam.opId == 0 && opParam.stateSchema.isEmpty 115 | } 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StreamingAggregationMigratorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.io.File 20 | 21 | import net.heartsavior.spark.sql.state.migration.StreamingAggregationMigrator 22 | import net.heartsavior.spark.sql.util.SchemaUtil 23 | import org.apache.hadoop.fs.Path 24 | import org.scalatest.{Assertions, BeforeAndAfterAll} 25 | 26 | import org.apache.spark.sql.Row 27 | import org.apache.spark.sql.execution.streaming.MemoryStream 28 | import org.apache.spark.sql.execution.streaming.state.StateStore 29 | import org.apache.spark.sql.internal.SQLConf 30 | import org.apache.spark.sql.streaming.OutputMode 31 | import org.apache.spark.sql.types.StructType 32 | 33 | class StreamingAggregationMigratorSuite 34 | extends StateStoreTest 35 | with BeforeAndAfterAll 36 | with Assertions { 37 | 38 | override def afterAll(): Unit = { 39 | super.afterAll() 40 | StateStore.stop() 41 | } 42 | 43 | test("migrate streaming aggregation state format version 1 to 2") { 44 | withTempCheckpoints { case (oldCpDir, newCpDir) => 45 | val oldCpPath = new Path(oldCpDir.getAbsolutePath) 46 | val newCpPath = new Path(newCpDir.getAbsolutePath) 47 | 48 | // run streaming aggregation query to state format version 1 49 | withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { 50 | runCompositeKeyStreamingAggregationQuery(oldCpDir.getAbsolutePath) 51 | } 52 | 53 | val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(1) 54 | 55 | val migrator = new StreamingAggregationMigrator(spark) 56 | migrator.convertVersion1To2( 57 | oldCpPath, 58 | newCpPath, 59 | SchemaUtil.getSchemaAsDataType(stateSchema, "key").asInstanceOf[StructType], 60 | SchemaUtil.getSchemaAsDataType(stateSchema, "value").asInstanceOf[StructType]) 61 | 62 | val newStateInfo = new StateInformationInCheckpoint(spark).gatherInformation(newCpPath) 63 | assert(newStateInfo.lastCommittedBatchId.isDefined, 64 | "The checkpoint directory should contain committed batch!") 65 | 66 | // check whether it's running well with new checkpoint 67 | 68 | // read state with new expected state schema (state 
format version 2) 69 | val newStateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(2) 70 | 71 | // we assume operator id = 0, store_name = default 72 | val stateReadDf = spark.read 73 | .format("state") 74 | .schema(newStateSchema) 75 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 76 | new File(newCpDir, "state").getAbsolutePath) 77 | .option(StateStoreDataSourceProvider.PARAM_VERSION, 78 | newStateInfo.lastCommittedBatchId.get + 1) 79 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, 0) 80 | .load() 81 | 82 | checkAnswer( 83 | stateReadDf 84 | .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", 85 | "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", 86 | "value.min AS value_min"), 87 | Seq( 88 | Row(0, "Apple", 2, 6, 6, 0), 89 | Row(1, "Banana", 3, 9, 7, 1), 90 | Row(0, "Strawberry", 3, 12, 8, 2), 91 | Row(1, "Apple", 3, 15, 9, 3), 92 | Row(0, "Banana", 2, 14, 10, 4), 93 | Row(1, "Strawberry", 1, 5, 5, 5) 94 | ) 95 | ) 96 | 97 | // rerun streaming query from migrated checkpoint 98 | verifyContinueRunCompositeKeyStreamingAggregationQuery(newCpPath.toString) 99 | } 100 | } 101 | 102 | 103 | private def verifyContinueRunCompositeKeyStreamingAggregationQuery( 104 | checkpointRoot: String): Unit = { 105 | import testImplicits._ 106 | 107 | val inputData = MemoryStream[Int] 108 | val aggregated = getCompositeKeyStreamingAggregationQuery(inputData) 109 | 110 | // batch 0 111 | inputData.addData(0 to 5) 112 | // batch 1 113 | inputData.addData(6 to 10) 114 | // batch 2 115 | inputData.addData(3, 2, 1) 116 | 117 | testStream(aggregated, OutputMode.Update)( 118 | StartStream(checkpointLocation = checkpointRoot), 119 | // batch 3 120 | AddData(inputData, 3, 2, 1), 121 | CheckLastBatch( 122 | (1, "Banana", 4, 10, 7, 1), // 1, 7, 1, 1 123 | (0, "Strawberry", 4, 14, 8, 2), // 2, 8, 2, 2 124 | (1, "Apple", 4, 18, 9, 3) // 3, 9, 3, 3 125 | ) 126 | ) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/migration/StreamingAggregationMigrator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state.migration 18 | 19 | import net.heartsavior.spark.sql.checkpoint.CheckpointUtil 20 | import net.heartsavior.spark.sql.state.{StateInformationInCheckpoint, StateStoreDataSourceProvider} 21 | import org.apache.hadoop.fs.Path 22 | 23 | import org.apache.spark.internal.Logging 24 | import org.apache.spark.sql.SparkSession 25 | import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager 26 | import org.apache.spark.sql.internal.SQLConf 27 | import org.apache.spark.sql.types.StructType 28 | 29 | /** 30 | * This class enables migration functionality for query using streaming aggregation (e.g. agg()). 
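 *
 * A minimal usage sketch, assuming `spark` is an active SparkSession and the checkpoint paths
 * and schemas below are placeholders (the key/value schemas of the existing state can be
 * derived, e.g., via StateSchemaExtractor):
 * {{{
 *   val migrator = new StreamingAggregationMigrator(spark)
 *   migrator.convertVersion1To2(
 *     new Path("/path/to/old/checkpoint"),
 *     new Path("/path/to/new/checkpoint"),
 *     keySchema,
 *     valueSchema)
 * }}}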
31 | */ 32 | class StreamingAggregationMigrator(spark: SparkSession) extends Logging { 33 | 34 | /** 35 | * Migrate state being written as format version 1 to format version 2. 36 | * 37 | * @param checkpointRoot the root path of existing checkpoint 38 | * @param newCheckpointRoot the root path where the savepoint with migrated state will be stored 39 | * @param keySchema key schema of existing state 40 | * @param valueSchema value schema of existing state 41 | */ 42 | def convertVersion1To2( 43 | checkpointRoot: Path, 44 | newCheckpointRoot: Path, 45 | keySchema: StructType, 46 | valueSchema: StructType): Unit = { 47 | val stateInfo = new StateInformationInCheckpoint(spark).gatherInformation(checkpointRoot) 48 | 49 | val stateVer = stateInfo.confs.getOrElse(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key, 50 | StreamingAggregationStateManager.legacyVersion.toString).toInt 51 | 52 | if (stateVer != 1) { 53 | throw new IllegalArgumentException("Given checkpoint doesn't use state format ver. 1 " + 54 | s"for streaming aggregation! version: $stateVer") 55 | } 56 | 57 | val lastCommittedBatchId = stateInfo.lastCommittedBatchId match { 58 | case Some(bid) => bid 59 | case None => throw new IllegalArgumentException("No committed batch in given checkpoint.") 60 | } 61 | 62 | val addConf = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") 63 | CheckpointUtil.createSavePoint(spark, checkpointRoot.toString, newCheckpointRoot.toString, 64 | lastCommittedBatchId, addConf, excludeState = true) 65 | 66 | val stateSchema = new StructType() 67 | .add("key", keySchema) 68 | .add("value", valueSchema) 69 | 70 | val stateVersion = lastCommittedBatchId + 1 71 | stateInfo.operators.foreach { op => 72 | val partitions = op.partitions 73 | op.storeNames.map { storeName => 74 | val stateReadDf = spark.read 75 | .format("state") 76 | .schema(stateSchema) 77 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 78 | new Path(checkpointRoot, "state").toString) 79 | .option(StateStoreDataSourceProvider.PARAM_VERSION, stateVersion) 80 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, op.opId) 81 | .option(StateStoreDataSourceProvider.PARAM_STORE_NAME, storeName) 82 | .load() 83 | 84 | logInfo(s"Schema of state format 1 (current): ${stateReadDf.schema.treeString}") 85 | 86 | // This assumes only columns in key part are duplicated between key and value schema 87 | val newValueSchema = valueSchema.filterNot(field => keySchema.contains(field)) 88 | 89 | val newValueColumns = newValueSchema.map("value." 
+ _.name).mkString(",") 90 | val selectExprs = Seq("key", s"struct($newValueColumns) AS value") 91 | 92 | val modifiedDf = stateReadDf.selectExpr(selectExprs: _*) 93 | 94 | logInfo(s"Schema of state format 2 (new): ${modifiedDf.schema.treeString}") 95 | 96 | modifiedDf.write 97 | .format("state") 98 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 99 | new Path(newCheckpointRoot, "state").toString) 100 | .option(StateStoreDataSourceProvider.PARAM_VERSION, stateVersion) 101 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, op.opId) 102 | .option(StateStoreDataSourceProvider.PARAM_STORE_NAME, storeName) 103 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, partitions) 104 | .save 105 | 106 | logInfo(s"Migrated state (opId: ${op.opId}, storeName: ${storeName}, " + 107 | s"partitions: ${partitions} from format 1 to 2") 108 | } 109 | } 110 | } 111 | 112 | } 113 | 114 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/migration/FlatMapGroupsWithStateMigrator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state.migration 18 | 19 | import net.heartsavior.spark.sql.checkpoint.CheckpointUtil 20 | import net.heartsavior.spark.sql.state.{StateInformationInCheckpoint, StateStoreDataSourceProvider} 21 | import org.apache.hadoop.fs.Path 22 | 23 | import org.apache.spark.internal.Logging 24 | import org.apache.spark.sql.SparkSession 25 | import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper 26 | import org.apache.spark.sql.internal.SQLConf 27 | import org.apache.spark.sql.types.StructType 28 | 29 | /** 30 | * This class enables migration functionality for query using (flat)MapGroupsWithState. 31 | */ 32 | class FlatMapGroupsWithStateMigrator(spark: SparkSession) extends Logging { 33 | 34 | /** 35 | * Migrate state being written as format version 1 to format version 2. 36 | * 37 | * @param checkpointRoot the root path of existing checkpoint 38 | * @param newCheckpointRoot the root path where the savepoint with migrated state will be stored 39 | * @param keySchema key schema of existing state 40 | * @param valueSchema value schema of existing state 41 | */ 42 | def convertVersion1To2( 43 | checkpointRoot: Path, 44 | newCheckpointRoot: Path, 45 | keySchema: StructType, 46 | valueSchema: StructType): Unit = { 47 | val stateInfo = new StateInformationInCheckpoint(spark).gatherInformation(checkpointRoot) 48 | 49 | val stateVer = stateInfo.confs.getOrElse( 50 | SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key, 51 | FlatMapGroupsWithStateExecHelper.legacyVersion.toString).toInt 52 | 53 | if (stateVer != 1) { 54 | throw new IllegalArgumentException("Given checkpoint doesn't use state format ver. 1 " + 55 | s"for flatMapGroupsWithState! 
version: $stateVer") 56 | } 57 | 58 | val lastCommittedBatchId = stateInfo.lastCommittedBatchId match { 59 | case Some(bid) => bid 60 | case None => throw new IllegalArgumentException("No committed batch in given checkpoint.") 61 | } 62 | 63 | val addConf = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2") 64 | CheckpointUtil.createSavePoint(spark, checkpointRoot.toString, newCheckpointRoot.toString, 65 | lastCommittedBatchId, addConf, excludeState = true) 66 | 67 | val stateSchema = new StructType() 68 | .add("key", keySchema) 69 | .add("value", valueSchema) 70 | 71 | val stateVersion = lastCommittedBatchId + 1 72 | stateInfo.operators.foreach { op => 73 | val partitions = op.partitions 74 | op.storeNames.map { storeName => 75 | val stateReadDf = spark.read 76 | .format("state") 77 | .schema(stateSchema) 78 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 79 | new Path(checkpointRoot, "state").toString) 80 | .option(StateStoreDataSourceProvider.PARAM_VERSION, stateVersion) 81 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, op.opId) 82 | .option(StateStoreDataSourceProvider.PARAM_STORE_NAME, storeName) 83 | .load() 84 | 85 | logInfo(s"Schema of state format 1 (current): ${stateReadDf.schema.treeString}") 86 | 87 | val valueFieldsWithoutTimestamp = valueSchema.filterNot(_.name == "timeoutTimestamp") 88 | 89 | val newValueColumns = valueFieldsWithoutTimestamp.map("value." + _.name).mkString(",") 90 | val selectExprs = Seq("key", s"struct(struct($newValueColumns) AS groupState, " + 91 | "CAST(value.timeoutTimestamp AS LONG) AS timeoutTimestamp) AS value") 92 | 93 | val modifiedDf = stateReadDf.selectExpr(selectExprs: _*) 94 | 95 | logInfo(s"Schema of state format 2 (new): ${modifiedDf.schema.treeString}") 96 | 97 | modifiedDf.write 98 | .format("state") 99 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 100 | new Path(newCheckpointRoot, "state").toString) 101 | .option(StateStoreDataSourceProvider.PARAM_VERSION, stateVersion) 102 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, op.opId) 103 | .option(StateStoreDataSourceProvider.PARAM_STORE_NAME, storeName) 104 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, partitions) 105 | .save 106 | 107 | logInfo(s"Migrated state (opId: ${op.opId}, storeName: ${storeName}, " + 108 | s"partitions: ${partitions} from format 1 to 2") 109 | } 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateStoreReaderRDD.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.util.UUID 20 | 21 | import scala.util.Try 22 | 23 | import net.heartsavior.spark.sql.util.SchemaUtil 24 | import org.apache.hadoop.fs.{Path, PathFilter} 25 | 26 | import org.apache.spark.{Partition, TaskContext} 27 | import org.apache.spark.rdd.RDD 28 | import org.apache.spark.sql.{Row, SparkSession} 29 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 30 | import org.apache.spark.sql.catalyst.expressions.GenericInternalRow 31 | import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} 32 | import org.apache.spark.sql.hack.SerializableConfigurationWrapper 33 | import org.apache.spark.sql.types.StructType 34 | 35 | class StateStorePartition( 36 | val partition: Int, 37 | val queryId: UUID) extends Partition { 38 | override def index: Int = partition 39 | } 40 | 41 | /** 42 | * An RDD that reads (key, value) pairs of state and provides rows having columns (key, value). 43 | */ 44 | class StateStoreReaderRDD( 45 | session: SparkSession, 46 | keySchema: StructType, 47 | valueSchema: StructType, 48 | stateCheckpointRootLocation: String, 49 | batchId: Long, 50 | operatorId: Long, 51 | storeName: String) 52 | extends RDD[Row](session.sparkContext, Nil) { 53 | 54 | private val storeConf = new StateStoreConf(session.sessionState.conf) 55 | 56 | // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it 57 | private val hadoopConfBroadcastWrapper = new SerializableConfigurationWrapper(session) 58 | 59 | override def compute(split: Partition, context: TaskContext): Iterator[Row] = { 60 | split match { 61 | case p: StateStorePartition => 62 | val stateStoreId = StateStoreId(stateCheckpointRootLocation, operatorId, 63 | p.partition, storeName) 64 | val stateStoreProviderId = StateStoreProviderId(stateStoreId, p.queryId) 65 | 66 | val store = StateStore.get(stateStoreProviderId, keySchema, valueSchema, 67 | indexOrdinal = None, version = batchId, storeConf = storeConf, 68 | hadoopConf = hadoopConfBroadcastWrapper.broadcastedConf.value.value) 69 | 70 | val encoder = RowEncoder(SchemaUtil.keyValuePairSchema(keySchema, valueSchema)) 71 | .resolveAndBind() 72 | val fromRow = encoder.createDeserializer() 73 | val iter = store.iterator().map { pair => 74 | val row = new GenericInternalRow(Array(pair.key, pair.value).asInstanceOf[Array[Any]]) 75 | fromRow(row) 76 | } 77 | 78 | // close state store provider after using 79 | StateStore.unload(stateStoreProviderId) 80 | 81 | iter 82 | 83 | case e => throw new IllegalStateException("Expected StateStorePartition but other type of " + 84 | s"partition passed - $e") 85 | } 86 | } 87 | 88 | override protected def getPartitions: Array[Partition] = { 89 | val fs = stateCheckpointPartitionsLocation.getFileSystem( 90 | hadoopConfBroadcastWrapper.broadcastedConf.value.value) 91 | val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() { 92 | override def accept(path: Path): Boolean = { 93 | fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 94 | } 95 | }) 96 | 97 | if (partitions.headOption.isEmpty) { 98 | Array.empty[Partition] 99 | } else { 100 | // just a dummy query id because we are actually not running streaming query 101 | val queryId = UUID.randomUUID() 102 | 103 | val partitionsSorted = partitions.sortBy(fs => fs.getPath.getName.toInt) 104 | val partitionNums = partitionsSorted.map(_.getPath.getName.toInt) 105 | // assuming no same 
number - they're directories hence no same name 106 | val head = partitionNums.head 107 | val tail = partitionNums(partitionNums.length - 1) 108 | assert(head == 0, "Partition should start with 0") 109 | assert((tail - head + 1) == partitionNums.length, 110 | s"No continuous partitions in state: $partitionNums") 111 | 112 | partitionNums.map(pn => new StateStorePartition(pn, queryId)).toArray 113 | } 114 | } 115 | 116 | def stateCheckpointPartitionsLocation: Path = { 117 | new Path(stateCheckpointRootLocation, s"$operatorId") 118 | } 119 | 120 | def stateCheckpointLocation(partitionId: Int): Path = { 121 | val partitionsLocation = stateCheckpointPartitionsLocation 122 | if (storeName == StateStoreId.DEFAULT_STORE_NAME) { 123 | // For reading state store data that was generated before store names were used (Spark <= 2.2) 124 | new Path(partitionsLocation, s"$partitionId") 125 | } else { 126 | new Path(partitionsLocation, s"$partitionId/$storeName") 127 | } 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateInformationInCheckpoint.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import scala.util.Try 20 | 21 | import net.heartsavior.spark.sql.state 22 | import net.heartsavior.spark.sql.state.StateInformationInCheckpoint.{StateInformation, StateOperatorInformation} 23 | import net.heartsavior.spark.sql.util.HadoopPathUtil 24 | import org.apache.hadoop.fs.{FileStatus, Path, PathFilter} 25 | 26 | import org.apache.spark.sql.SparkSession 27 | import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog} 28 | import org.apache.spark.sql.execution.streaming.state.StateStoreId 29 | 30 | /** 31 | * This class enables retrieving 32 | * [[state.StateInformationInCheckpoint.StateInformation]] 33 | * via reading checkpoint. 
34 | */ 35 | class StateInformationInCheckpoint(spark: SparkSession) { 36 | 37 | val hadoopConf = spark.sessionState.newHadoopConf() 38 | 39 | def gatherInformation(checkpointPath: Path): StateInformation = { 40 | val offsetSeq = new OffsetSeqLog(spark, new Path(checkpointPath, "offsets").toString) 41 | val confMap: Map[String, String] = offsetSeq.getLatest() match { 42 | case Some((_, offset)) => offset.metadata match { 43 | case Some(md) => md.conf 44 | case None => Map.empty[String, String] 45 | } 46 | case None => Map.empty[String, String] 47 | } 48 | 49 | val commitLog = new CommitLog(spark, new Path(checkpointPath, "commits").toString) 50 | val lastCommittedBatchId = commitLog.getLatest() match { 51 | case Some((lastId, _)) => lastId 52 | case None => -1 53 | } 54 | 55 | if (lastCommittedBatchId < 0) { 56 | return StateInformation(None, Seq.empty[StateOperatorInformation], confMap) 57 | } 58 | 59 | val fs = checkpointPath.getFileSystem(hadoopConf) 60 | val numericDirectories = new PathFilter() { 61 | override def accept(path: Path): Boolean = { 62 | fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 63 | } 64 | } 65 | 66 | val statePath = new Path(checkpointPath, "state") 67 | val operatorDirs = fs.listStatus(statePath, numericDirectories) 68 | 69 | val opInfos = operatorDirs.map { operatorDir => 70 | val opPath = operatorDir.getPath 71 | val opId = opPath.getName.toInt 72 | 73 | val partitions = fs.listStatus(opPath, numericDirectories) 74 | if (partitions.nonEmpty) { 75 | validateCorrectPartitions(partitions) 76 | 77 | // assuming information is same across partitions 78 | val partitionDir = partitions.head 79 | 80 | val statuses = fs.listStatus(partitionDir.getPath) 81 | val dirs = statuses.filter(status => fs.isDirectory(status.getPath)) 82 | val storeNames = if (dirs.nonEmpty) { 83 | dirs.map(_.getPath.getName).toList 84 | } else { 85 | // assuming default store name 86 | List(StateStoreId.DEFAULT_STORE_NAME) 87 | } 88 | 89 | StateOperatorInformation(opId, partitions.length, storeNames) 90 | } else { 91 | StateOperatorInformation(opId, 0, Seq.empty) 92 | } 93 | } 94 | 95 | StateInformation(Some(lastCommittedBatchId), opInfos, confMap) 96 | } 97 | 98 | private def validateCorrectPartitions(partitions: Array[FileStatus]): Unit = { 99 | val partitionsSorted = partitions.sortBy(fs => fs.getPath.getName.toInt) 100 | val partitionNums = partitionsSorted.map(_.getPath.getName.toInt) 101 | // assuming no same number - they're directories hence no same name 102 | val head = partitionNums.head 103 | val tail = partitionNums(partitionNums.length - 1) 104 | assert(head == 0, "Partition should start with 0") 105 | assert((tail - head + 1) == partitionNums.length, 106 | s"No continuous partitions in state: $partitionNums") 107 | } 108 | } 109 | 110 | object StateInformationInCheckpoint { 111 | 112 | case class StateOperatorInformation(opId: Int, partitions: Int, storeNames: Seq[String]) 113 | case class StateInformation( 114 | lastCommittedBatchId: Option[Long], 115 | operators: Seq[StateOperatorInformation], 116 | confs: Map[String, String]) 117 | 118 | // scalastyle:off println 119 | def main(args: Array[String]): Unit = { 120 | val spark = SparkSession 121 | .builder 122 | .appName("StateInformationInCheckpoint") 123 | .getOrCreate() 124 | 125 | if (args.length < 1) { 126 | System.err.println("Usage: StateInformationInCheckpoint [checkpoint path]") 127 | sys.exit(1) 128 | } 129 | 130 | val checkpointRoot = args(0) 131 | 132 | println(s"Checkpoint path: 
$checkpointRoot") 133 | 134 | val hadoopConf = spark.sessionState.newHadoopConf() 135 | val checkpointPath = new Path(HadoopPathUtil.resolve(hadoopConf, checkpointRoot)) 136 | val fs = checkpointPath.getFileSystem(hadoopConf) 137 | 138 | if (!fs.exists(checkpointPath) || !fs.isDirectory(checkpointPath)) { 139 | System.err.println("Checkpoint path doesn't exist or not a directory.") 140 | sys.exit(2) 141 | } 142 | 143 | val stateInfo = new StateInformationInCheckpoint(spark).gatherInformation(checkpointPath) 144 | 145 | stateInfo.lastCommittedBatchId match { 146 | case Some(lastId) => 147 | println(s"Last committed batch ID: $lastId") 148 | stateInfo.operators.foreach { op => 149 | println(s"Operator ID: ${op.opId}, partitions: ${op.partitions}, " + 150 | s"storeNames: ${op.storeNames}") 151 | } 152 | 153 | case None => println("No batch has been committed.") 154 | } 155 | } 156 | // scalastyle:on println 157 | } 158 | -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/FlatMapGroupsWithStateMigratorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.io.File 20 | 21 | import net.heartsavior.spark.sql.state.migration.FlatMapGroupsWithStateMigrator 22 | import org.apache.hadoop.fs.Path 23 | import org.scalatest.{Assertions, BeforeAndAfterAll} 24 | import org.scalatest.time.Span 25 | import org.scalatest.time.SpanSugar._ 26 | 27 | import org.apache.spark.sql.{Encoders, Row} 28 | import org.apache.spark.sql.execution.streaming.MemoryStream 29 | import org.apache.spark.sql.execution.streaming.state.StateStore 30 | import org.apache.spark.sql.internal.SQLConf 31 | import org.apache.spark.sql.streaming.{OutputMode, Trigger} 32 | import org.apache.spark.sql.streaming.util.StreamManualClock 33 | import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} 34 | 35 | class FlatMapGroupsWithStateMigratorSuite 36 | extends StateStoreTest 37 | with BeforeAndAfterAll 38 | with Assertions { 39 | 40 | override val streamingTimeout: Span = 30.seconds 41 | 42 | override def afterAll(): Unit = { 43 | super.afterAll() 44 | StateStore.stop() 45 | } 46 | 47 | test("migrate flatMapGroupsWithState state format version 1 to 2") { 48 | withTempCheckpoints { case (oldCpDir, newCpDir) => 49 | val oldCpPath = new Path(oldCpDir.getAbsolutePath) 50 | val newCpPath = new Path(newCpDir.getAbsolutePath) 51 | 52 | // run flatMapGroupsWithState query to state format version 1 53 | withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "1") { 54 | runFlatMapGroupsWithStateQuery(oldCpDir.getAbsolutePath) 55 | } 56 | 57 | val keySchema = new StructType().add("value", StringType, nullable = true) 58 | 59 | val valueSchema = Encoders.product[SessionInfo].schema 60 | .add("timeoutTimestamp", IntegerType, nullable = false) 61 | 62 | val migrator = new FlatMapGroupsWithStateMigrator(spark) 63 | migrator.convertVersion1To2(oldCpPath, newCpPath, keySchema, valueSchema) 64 | 65 | val newStateInfo = new StateInformationInCheckpoint(spark).gatherInformation(newCpPath) 66 | assert(newStateInfo.lastCommittedBatchId.isDefined, 67 | "The checkpoint directory should contain committed batch!") 68 | 69 | // check whether it's running well with new checkpoint 70 | 71 | // read state with new expected state schema (state format version 2) 72 | val newValueSchema = new StructType() 73 | .add("groupState", Encoders.product[SessionInfo].schema) 74 | .add("timeoutTimestamp", LongType, nullable = false) 75 | 76 | val newStateSchema = new StructType() 77 | .add("key", keySchema) 78 | .add("value", newValueSchema) 79 | 80 | // we assume operator id = 0, store_name = default 81 | val stateReadDf = spark.read 82 | .format("state") 83 | .schema(newStateSchema) 84 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 85 | new File(newCpDir, "state").getAbsolutePath) 86 | .option(StateStoreDataSourceProvider.PARAM_VERSION, 87 | newStateInfo.lastCommittedBatchId.get + 1) 88 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, 0) 89 | .load() 90 | 91 | checkAnswer( 92 | stateReadDf 93 | .selectExpr("key.value AS key_value", "value.groupState.numEvents AS value_numEvents", 94 | "value.groupState.startTimestampMs AS value_startTimestampMs", 95 | "value.groupState.endTimestampMs AS value_endTimestampMs", 96 | "value.timeoutTimestamp AS value_timeoutTimestamp"), 97 | Seq( 98 | Row("hello", 4, 1000, 4000, 12000), 99 | Row("world", 2, 1000, 3000, 12000), 100 | Row("scala", 2, 2000, 4000, 12000) 101 | ) 102 | ) 103 | 104 | // rerun streaming query from migrated checkpoint 105 | 
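// (this re-runs the same flatMapGroupsWithState query against the migrated checkpoint,
// so it should resume from the format version 2 state written above)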
verifyFlatMapGroupsWithStateQuery(newCpPath.toString) 106 | } 107 | } 108 | 109 | private def verifyFlatMapGroupsWithStateQuery(checkpointRoot: String): Unit = { 110 | // scalastyle:off line.size.limit 111 | // This test code is borrowed from sessionization example of Apache Spark, 112 | // with modification a bit to run with testStream 113 | // https://github.com/apache/spark/blob/v2.4.1/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala 114 | // scalastyle:on 115 | import testImplicits._ 116 | 117 | val clock = new StreamManualClock 118 | 119 | val inputData = MemoryStream[(String, Long)] 120 | val remapped = getFlatMapGroupsWithStateQuery(inputData) 121 | 122 | // batch 0 123 | inputData.addData(("hello world", 1L), ("hello scala", 2L)) 124 | clock.advance(1 * 1000) 125 | 126 | // batch 1 127 | inputData.addData(("hello world", 3L), ("hello scala", 4L)) 128 | clock.advance(1 * 1000) 129 | 130 | testStream(remapped, OutputMode.Update)( 131 | StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, 132 | checkpointLocation = checkpointRoot), 133 | 134 | // batch 2 135 | AddData(inputData, ("spark scala", 20L)), 136 | AdvanceManualClock(15 * 1000), 137 | CheckNewAnswer( 138 | ("hello", 4, 3000, true), 139 | ("world", 2, 2000, true), 140 | ("spark", 1, 0, false), 141 | ("scala", 3, 18000, false) 142 | ), 143 | // batch 3 144 | AddData(inputData, ("hello world", 30L), ("hello spark scala", 32L)), 145 | AdvanceManualClock(15 * 1000), 146 | CheckNewAnswer( 147 | ("hello", 2, 2000, false), 148 | ("world", 1, 0, false), 149 | ("spark", 2, 12000, false), 150 | ("scala", 4, 30000, false) 151 | ) 152 | ) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/main/scala/net/heartsavior/spark/sql/state/StateStoreDataSourceProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import net.heartsavior.spark.sql.util.SchemaUtil 20 | 21 | import org.apache.spark.sql._ 22 | import org.apache.spark.sql.execution.streaming.state.StateStoreId 23 | import org.apache.spark.sql.hack.SparkSqlHack 24 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, SchemaRelationProvider} 25 | import org.apache.spark.sql.types.StructType 26 | 27 | // TODO: read schema of key and value from metadata of state (requires SPARK-27237) 28 | // and change SchemaRelationProvider to RelationProvider to receive schema optionally 29 | /** 30 | * Data Source Provider for state store to enable read to/write from state. 
31 | */ 32 | class StateStoreDataSourceProvider 33 | extends DataSourceRegister 34 | with SchemaRelationProvider 35 | with CreatableRelationProvider { 36 | 37 | import StateStoreDataSourceProvider._ 38 | 39 | override def shortName(): String = "state" 40 | 41 | override def createRelation( 42 | sqlContext: SQLContext, 43 | parameters: Map[String, String], 44 | schema: StructType): BaseRelation = { 45 | if (!isValidSchema(schema)) { 46 | throw SparkSqlHack.analysisException("The fields of schema should be 'key' and 'value', " + 47 | "and each field should have corresponding fields (they should be a StructType)") 48 | } 49 | 50 | val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] 51 | val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] 52 | 53 | val checkpointLocation = parameters.get(PARAM_CHECKPOINT_LOCATION) match { 54 | case Some(cpLocation) => cpLocation 55 | case None => throw SparkSqlHack.analysisException( 56 | s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") 57 | } 58 | 59 | val version = parameters.get(PARAM_VERSION) match { 60 | case Some(ver) => ver.toInt 61 | case None => throw SparkSqlHack.analysisException(s"'$PARAM_VERSION' must be specified.") 62 | } 63 | 64 | val operatorId = parameters.get(PARAM_OPERATOR_ID) match { 65 | case Some(opId) => opId.toInt 66 | case None => throw SparkSqlHack.analysisException(s"'$PARAM_OPERATOR_ID' must be specified.") 67 | } 68 | 69 | val storeName = parameters.get(PARAM_STORE_NAME) match { 70 | case Some(stName) => stName 71 | case None => StateStoreId.DEFAULT_STORE_NAME 72 | } 73 | 74 | new StateStoreRelation(sqlContext.sparkSession, keySchema, 75 | valueSchema, checkpointLocation, version, operatorId, 76 | storeName) 77 | } 78 | 79 | override def createRelation( 80 | sqlContext: SQLContext, 81 | mode: SaveMode, 82 | parameters: Map[String, String], 83 | data: DataFrame): BaseRelation = { 84 | mode match { 85 | case SaveMode.Overwrite | SaveMode.ErrorIfExists => // good 86 | case _ => throw SparkSqlHack.analysisException(s"Save mode $mode not allowed for state. 
" + 87 | s"Allowed save modes are ${SaveMode.Overwrite} and ${SaveMode.ErrorIfExists}.") 88 | } 89 | 90 | val checkpointLocation = parameters.get(PARAM_CHECKPOINT_LOCATION) match { 91 | case Some(cpLocation) => cpLocation 92 | case None => throw SparkSqlHack.analysisException( 93 | s"'$PARAM_CHECKPOINT_LOCATION' must be specified.") 94 | } 95 | 96 | val version = parameters.get(PARAM_VERSION) match { 97 | case Some(ver) => ver.toInt 98 | case None => throw SparkSqlHack.analysisException(s"'$PARAM_VERSION' must be specified.") 99 | } 100 | 101 | val operatorId = parameters.get(PARAM_OPERATOR_ID) match { 102 | case Some(opId) => opId.toInt 103 | case None => throw SparkSqlHack.analysisException(s"'$PARAM_OPERATOR_ID' must be specified.") 104 | } 105 | 106 | val storeName = parameters.get(PARAM_STORE_NAME) match { 107 | case Some(stName) => stName 108 | case None => StateStoreId.DEFAULT_STORE_NAME 109 | } 110 | 111 | val newPartitions = parameters.get(PARAM_NEW_PARTITIONS) match { 112 | case Some(partitions) => partitions.toInt 113 | case None => throw SparkSqlHack.analysisException( 114 | s"'$PARAM_NEW_PARTITIONS' must be specified.") 115 | } 116 | 117 | if (!isValidSchema(data.schema)) { 118 | throw SparkSqlHack.analysisException("The fields of schema should be 'key' and 'value', " + 119 | "and each field should have corresponding fields (they should be a StructType)") 120 | } 121 | 122 | val keySchema = SchemaUtil.getSchemaAsDataType(data.schema, "key").asInstanceOf[StructType] 123 | val valueSchema = SchemaUtil.getSchemaAsDataType(data.schema, "value").asInstanceOf[StructType] 124 | 125 | new StateStoreWriter(sqlContext.sparkSession, data, keySchema, valueSchema, checkpointLocation, 126 | version, operatorId, storeName, newPartitions).write() 127 | 128 | // just return the same as we just update it 129 | createRelation(sqlContext, parameters, data.schema) 130 | } 131 | 132 | private def isValidSchema(schema: StructType): Boolean = { 133 | if (schema.fieldNames.toSeq != Seq("key", "value")) { 134 | false 135 | } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { 136 | false 137 | } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { 138 | false 139 | } else { 140 | true 141 | } 142 | } 143 | } 144 | 145 | object StateStoreDataSourceProvider { 146 | val PARAM_CHECKPOINT_LOCATION = "checkpointLocation" 147 | val PARAM_VERSION = "version" 148 | val PARAM_OPERATOR_ID = "operatorId" 149 | val PARAM_STORE_NAME = "storeName" 150 | val PARAM_NEW_PARTITIONS = "newPartitions" 151 | } 152 | -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StateStoreStreamingAggregationReadSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.io.File 20 | 21 | import org.scalatest.{Assertions, BeforeAndAfterAll} 22 | 23 | import org.apache.spark.sql.Row 24 | import org.apache.spark.sql.execution.streaming.state.StateStore 25 | import org.apache.spark.sql.internal.SQLConf 26 | 27 | class StateStoreStreamingAggregationReadSuite 28 | extends StateStoreTest 29 | with BeforeAndAfterAll 30 | with Assertions { 31 | 32 | override def afterAll(): Unit = { 33 | super.afterAll() 34 | StateStore.stop() 35 | } 36 | 37 | test("reading state from simple aggregation - state format version 1") { 38 | withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1"): _*) { 39 | withTempDir { tempDir => 40 | runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) 41 | 42 | val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(1) 43 | 44 | val operatorId = 0 45 | val batchId = 1 46 | 47 | val stateReadDf = spark.read 48 | .format("state") 49 | .schema(stateSchema) 50 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 51 | new File(tempDir, "state").getAbsolutePath) 52 | .option(StateStoreDataSourceProvider.PARAM_VERSION, batchId + 1) 53 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 54 | .load() 55 | 56 | logInfo(s"Schema: ${stateReadDf.schema.treeString}") 57 | 58 | checkAnswer( 59 | stateReadDf 60 | .selectExpr("key.groupKey AS key_groupKey", "value.groupKey AS value_groupKey", 61 | "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", 62 | "value.min AS value_min"), 63 | Seq( 64 | Row(0, 0, 4, 60, 30, 0), // 0, 10, 20, 30 65 | Row(1, 1, 4, 64, 31, 1), // 1, 11, 21, 31 66 | Row(2, 2, 4, 68, 32, 2), // 2, 12, 22, 32 67 | Row(3, 3, 4, 72, 33, 3), // 3, 13, 23, 33 68 | Row(4, 4, 4, 76, 34, 4), // 4, 14, 24, 34 69 | Row(5, 5, 4, 80, 35, 5), // 5, 15, 25, 35 70 | Row(6, 6, 4, 84, 36, 6), // 6, 16, 26, 36 71 | Row(7, 7, 4, 88, 37, 7), // 7, 17, 27, 37 72 | Row(8, 8, 4, 92, 38, 8), // 8, 18, 28, 38 73 | Row(9, 9, 4, 96, 39, 9) // 9, 19, 29, 39 74 | ) 75 | ) 76 | } 77 | } 78 | } 79 | 80 | test("reading state from simple aggregation - state format version 2") { 81 | withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { 82 | withTempDir { tempDir => 83 | runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) 84 | 85 | val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(2) 86 | 87 | val operatorId = 0 88 | val batchId = 1 89 | 90 | val stateReadDf = spark.read 91 | .format("state") 92 | .schema(stateSchema) 93 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 94 | new File(tempDir, "state").getAbsolutePath) 95 | .option(StateStoreDataSourceProvider.PARAM_VERSION, batchId + 1) 96 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 97 | .load() 98 | 99 | logInfo(s"Schema: ${stateReadDf.schema.treeString}") 100 | 101 | checkAnswer( 102 | stateReadDf 103 | .selectExpr("key.groupKey AS key_groupKey", "value.cnt AS value_cnt", 104 | "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min"), 105 | Seq( 106 | Row(0, 4, 60, 30, 0), // 0, 10, 20, 30 107 | Row(1, 4, 64, 31, 1), // 1, 11, 21, 31 108 | Row(2, 4, 68, 32, 2), // 2, 12, 22, 32 109 | Row(3, 4, 72, 33, 3), // 3, 13, 23, 33 110 | Row(4, 4, 76, 34, 4), // 4, 14, 24, 34 111 | Row(5, 4, 80, 35, 5), // 5, 15, 25, 35 112 | Row(6, 4, 84, 36, 6), // 6, 16, 26, 36 113 | Row(7, 4, 88, 37, 7), // 7, 17, 27, 37 114 | Row(8, 4, 92, 38, 8), // 8, 
18, 28, 38 115 | Row(9, 4, 96, 39, 9) // 9, 19, 29, 39 116 | ) 117 | ) 118 | } 119 | } 120 | } 121 | 122 | test("reading state from simple aggregation - composite key") { 123 | withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { 124 | withTempDir { tempDir => 125 | runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) 126 | 127 | val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(2) 128 | 129 | val operatorId = 0 130 | val batchId = 1 131 | 132 | val stateReadDf = spark.read 133 | .format("state") 134 | .schema(stateSchema) 135 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 136 | new File(tempDir, "state").getAbsolutePath) 137 | .option(StateStoreDataSourceProvider.PARAM_VERSION, batchId + 1) 138 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 139 | .load() 140 | 141 | logInfo(s"Schema: ${stateReadDf.schema.treeString}") 142 | 143 | checkAnswer( 144 | stateReadDf 145 | .selectExpr("key.groupKey AS key_groupKey", "key.fruit AS key_fruit", 146 | "value.cnt AS value_cnt", "value.sum AS value_sum", "value.max AS value_max", 147 | "value.min AS value_min"), 148 | Seq( 149 | Row(0, "Apple", 2, 6, 6, 0), 150 | Row(1, "Banana", 2, 8, 7, 1), 151 | Row(0, "Strawberry", 2, 10, 8, 2), 152 | Row(1, "Apple", 2, 12, 9, 3), 153 | Row(0, "Banana", 2, 14, 10, 4), 154 | Row(1, "Strawberry", 1, 5, 5, 5) 155 | ) 156 | ) 157 | } 158 | } 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /checkstyle.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 21 | 22 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 123 | 124 | 125 | 126 | 128 | 129 | 130 | 131 | 133 | 134 | 135 | 137 | 139 | 141 | 143 | 144 | 145 | 155 | 156 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spark State Tools 2 | 3 | [![CircleCI](https://circleci.com/gh/HeartSaVioR/spark-state-tools/tree/master.svg?style=svg)](https://circleci.com/gh/HeartSaVioR/spark-state-tools/tree/master) 4 | 5 | Spark State Tools provides features about offline manipulation of Structured Streaming state on existing query. 6 | 7 | The features we provide as of now are: 8 | 9 | * Show some state information which you'll need to provide to enjoy below features 10 | * state operator information from checkpoint 11 | * state schema from streaming query 12 | * Create savepoint from existing checkpoint of Structured Streaming query 13 | * You can pick specific batch (if it exists on metadata) to create savepoint 14 | * Read state as batch source of Spark SQL 15 | * Write DataFrame to state as batch sink of Spark SQL 16 | * With feature of writing state, you can achieve rescaling state (repartition), simple schema evolution, etc. 
17 | * Migrate state format from old to new 18 | * migrating Streaming Aggregation from ver 1 to 2 19 | * migrating FlatMapGroupsWithState from ver 1 to 2 20 | 21 | As this project leverages Spark Structured Streaming's interfaces and doesn't deal with internals 22 | (e.g. the structure of state files for the HDFS state store), the performance may be suboptimal. 23 | 24 | For now, for the most part, states from Streaming Aggregation queries (`groupBy().agg()`) and (Flat)MapGroupsWithState are supported. 25 | 26 | ## Disclaimer 27 | 28 | This is more of a proof-of-concept implementation and might not be production ready. 29 | When you deal with writing state, you may want to back up your checkpoint with CheckpointUtil and do the work against a savepoint. 30 | 31 | The project is intended to deal with offline state, not the state of a streaming query that is currently running. 32 | It may work in practice, but the state store provider in the running query can purge old batches, which would produce errors here. 33 | 34 | ## Supported versions 35 | 36 | Both Spark 3.0.x and 2.4.x are supported: this only means you should use one of these versions when using this project. 37 | 38 | The project is cross-compiled for Scala 2.11 and 2.12 (thanks [@redsk](https://github.com/redsk)!); please pick the right artifact for your Scala version. 39 | 40 | Spark version | Scala versions | artifact version 41 | ------------- | -------------- | ---------------- 42 | 2.4.x | 2.11 / 2.12 | 0.5.0-spark-2.4 43 | 3.0.x | 2.12 | 0.5.0-spark-3.0 44 | 45 | ## Pulling artifacts 46 | 47 | You may use this library in your applications with the following dependency information: 48 | 49 | ``` 50 | groupId: net.heartsavior.spark 51 | artifactId: spark-state-tools_ 52 | ``` 53 | 54 | You are encouraged to always use the latest version that is compatible with your Apache Spark version. 55 | 56 | e.g. For Maven: 57 | 58 | (Please replace `{{...}}` with the content in the matrix above.) 59 | 60 | ``` 61 | 62 | net.heartsavior.spark 63 | spark-state-tools_{{scala_version}} 64 | {{artifact_version}} 65 | 66 | ``` 67 | 68 | For other dependency management tools, you can refer to the pages below for guidance: 69 | 70 | https://search.maven.org/artifact/net.heartsavior.spark/spark-state-tools_2.11/ 71 | https://search.maven.org/artifact/net.heartsavior.spark/spark-state-tools_2.12/ 72 | 73 | (NOTE: Use 0.4.0 or higher, as previous versions have a critical performance issue on the read path.) 74 | 75 | 76 | ## How to use 77 | 78 | First of all, you may want to get the state and last batch information to provide them as parameters. 79 | You can get it from `StateInformationInCheckpoint`, either by calling it from your codebase or by running it with `spark-submit`. 80 | Here we assume you have the artifact jar of this project and want to run it from the CLI (leveraging `spark-submit`). 81 | 82 | ```text 83 | /bin/spark-submit --master "local[*]" \ 84 | --class net.heartsavior.spark.sql.state.StateInformationInCheckpoint \ 85 | spark-state-tool-0.0.1-SNAPSHOT.jar 86 | ``` 87 | 88 | The command will print checkpoint information like below: 89 | 90 | ```text 91 | Last committed batch ID: 2 92 | Operator ID: 0, partitions: 5, storeNames: List(default) 93 | ``` 94 | 95 | This output means the query has batch ID 2 as the last committed batch (NOTE: the corresponding state version is 3, not 2), and 96 | that there's only one stateful operator, with ID 0 and 5 partitions, and only one kind of store, named "default".
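The numbers printed above map directly onto the options you will pass to the batch reader later on. A minimal sketch of that mapping, using the (assumed) values from the example output; adjust them to your own checkpoint:

```scala
// Values taken from the example CLI output above (assumed); adjust to your own checkpoint.
val lastCommittedBatchId = 2L
val operatorId = 0
// The state version to read is always the last committed batch ID plus 1.
val stateVersionToRead = lastCommittedBatchId + 1 // == 3
```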
97 | 98 | You can also get this by calling `StateInformationInCheckpoint.gatherInformation` against the checkpoint directory. 99 | 100 | ```scala 101 | // Here we assume 'spark' as SparkSession. 102 | // Here the class of Path is `org.apache.hadoop.fs.Path` 103 | val stateInfo = new StateInformationInCheckpoint(spark).gatherInformation(new Path(cpDir.getAbsolutePath)) 104 | // Here stateInfo is a `StateInformation`, from which you can extract the same information as the CLI app prints 105 | ``` 106 | 107 | To read state from your existing query, you can either provide the state schema manually or extract it from the query: 108 | 109 | * Read schema from existing query 110 | 111 | (supported: `streaming aggregation`, `flatMapGroupsWithState`) 112 | 113 | ```scala 114 | // Here we assume 'spark' as SparkSession. 115 | // the query shouldn't have a sink - get rid of the writeStream part and pass the DataFrame 116 | val schemaInfos = new StateSchemaExtractor(spark).extract(streamingQueryDf) 117 | // Here schemaInfos is a `Seq[StateSchemaInfo]`, from which you can extract the keySchema 118 | // and valueSchema and finally define the state schema. Please refer to "Manual schema" 119 | // below on how to define the state schema from the key schema and value schema 120 | ``` 121 | 122 | * Manual schema 123 | 124 | ```scala 125 | val stateKeySchema = new StructType() 126 | .add("groupKey", IntegerType) 127 | 128 | val stateValueSchema = new StructType() 129 | .add("cnt", LongType) 130 | .add("sum", LongType) 131 | .add("max", IntegerType) 132 | .add("min", IntegerType) 133 | 134 | val stateFormat = new StructType() 135 | .add("key", stateKeySchema) 136 | .add("value", stateValueSchema) 137 | ``` 138 | 139 | You can also combine the state operator information (from the state information above) with the state schema via `StateStoreReaderOperatorParamExtractor` 140 | to get the necessary parameters for a state batch read: 141 | 142 | ```scala 143 | // Here we assume 'spark' as SparkSession.
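// `cpDir` and `streamingQueryDf` are assumed to exist: `cpDir` is the checkpoint directory as a `java.io.File`,
// and `streamingQueryDf` is the sink-less streaming DataFrame of the query (see the schema extraction example above).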
144 | val stateInfo = new StateInformationInCheckpoint(spark).gatherInformation(new Path(cpDir.getAbsolutePath)) 145 | val schemaInfos = new StateSchemaExtractor(spark).extract(streamingQueryDf) 146 | val stateReadParams = StateStoreReaderOperatorParamExtractor.extract(stateInfo, schemaInfos) 147 | // from `stateReadParams` you can get last committed state version, operatorId, storeName, state schema per each (operatorId, storeName) group 148 | ``` 149 | 150 | Then you can start your batch query like: 151 | 152 | ```scala 153 | val operatorId = 0 154 | val batchId = 1 // the version of state for the output of batch is batchId + 1 155 | 156 | // Here we assume 'spark' as SparkSession 157 | val stateReadDf = spark.read 158 | .format("state") 159 | .schema(stateSchema) 160 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 161 | new Path(checkpointRoot, "state").getAbsolutePath) 162 | .option(StateStoreDataSourceProvider.PARAM_VERSION, batchId + 1) 163 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 164 | .load() 165 | 166 | 167 | // The schema of stateReadDf follows: 168 | // For streaming aggregation state format v1 169 | // (query ran with lower than Spark 2.4.0 for the first time) 170 | /* 171 | root 172 | |-- key: struct (nullable = false) 173 | | |-- groupKey: integer (nullable = true) 174 | |-- value: struct (nullable = false) 175 | | |-- groupKey: integer (nullable = true) 176 | | |-- cnt: long (nullable = true) 177 | | |-- sum: long (nullable = true) 178 | | |-- max: integer (nullable = true) 179 | | |-- min: integer (nullable = true) 180 | */ 181 | 182 | // For streaming aggregation state format v2 183 | // (query ran with Spark 2.4.0 or higher for the first time) 184 | /* 185 | root 186 | |-- key: struct (nullable = false) 187 | | |-- groupKey: integer (nullable = true) 188 | |-- value: struct (nullable = false) 189 | | |-- cnt: long (nullable = true) 190 | | |-- sum: long (nullable = true) 191 | | |-- max: integer (nullable = true) 192 | | |-- min: integer (nullable = true) 193 | */ 194 | ``` 195 | 196 | To write Dataset as state of Structured Streaming, you can transform your Dataset as having schema as follows: 197 | 198 | ```text 199 | root 200 | |-- key: struct (nullable = false) 201 | | |-- ...key fields... 202 | |-- value: struct (nullable = false) 203 | | |-- ...value fields... 204 | ``` 205 | 206 | and add state batch output as follow: 207 | 208 | ```scala 209 | val operatorId = 0 210 | val batchId = 1 // the version of state for the output of batch is batchId + 1 211 | val newShufflePartitions = 10 212 | 213 | df.write 214 | .format("state") 215 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 216 | new Path(newCheckpointRoot, "state").getAbsolutePath) 217 | .option(StateStoreDataSourceProvider.PARAM_VERSION, batchId + 1) 218 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 219 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, newShufflePartitions) 220 | .save() // saveAsTable() also supported 221 | ``` 222 | 223 | Before that, you may want to create a savepoint from existing checkpoint to another path, so that you can simply 224 | run new Structured Streaming query with modified state. 225 | 226 | ```scala 227 | // Here we assume 'spark' as SparkSession. 228 | // If you just want to create a savepoint without modifying state, provide `additionalMetadataConf` as `Map.empty`, 229 | // and `excludeState` as `false`. 
230 | // That said, if you want to prepare state modification, it would be good to create a savepoint with providing 231 | // addConf to new shuffle partition (like below), and `excludeState` as `true` (to avoid unnecessary copy for state) 232 | val addConf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> newShufflePartitions.toString) 233 | CheckpointUtil.createSavePoint(spark, oldCpPath, newCpPath, newLastBatchId, addConf, excludeState = true) 234 | ``` 235 | 236 | If you ran streaming aggregation query before Spark 2.4.0 and want to upgrade (or already upgraded) to Spark 2.4.0 or higher, 237 | you may also want to migrate your state from state format 1 to 2 (Spark 2.4.0 introduces it) to reduce overall state size, 238 | and get some speedup from most of cases. 239 | 240 | Please refer [SPARK-24763](https://issues.apache.org/jira/browse/SPARK-24763) for more details. 241 | 242 | ```scala 243 | // Here we assume 'spark' as SparkSession. 244 | 245 | // Please refer above to see how to construct `stateSchema` 246 | // (manually, or reading from existing query) 247 | // Here we already construct `stateSchema` as state schema. 248 | 249 | val migrator = new StreamingAggregationMigrator(spark) 250 | migrator.convertVersion1To2(oldCpPath, newCpPath, stateKeySchema, stateValueSchema) 251 | ``` 252 | 253 | Similarly, if you ran flatMapGroupsWithState query before Spark 2.4.0 and want to upgrade (or already upgraded) to Spark 2.4.0 or higher, 254 | you may also want to migrate your state from state format 1 to 2 (Spark 2.4.0 introduces it) to enable setting timeout even when state is null. 255 | (This also changes timeout timestamp type from int to long.) 256 | 257 | Please refer [SPARK-22187](https://issues.apache.org/jira/browse/SPARK-22187) for more details. 258 | 259 | ```scala 260 | // Here we assume 'spark' as SparkSession. 261 | 262 | // Please refer above to see how to construct `stateSchema` 263 | // (manually, or reading from existing query) 264 | // Here we already construct `stateSchema` as state schema. 265 | 266 | val migrator = new FlatMapGroupsWithStateMigrator(spark) 267 | migrator.convertVersion1To2(oldCpPath, newCpPath, stateKeySchema, stateValueSchema) 268 | ``` 269 | 270 | Please refer the [test codes](https://github.com/HeartSaVioR/spark-state-tools/tree/master/src/test/scala/net/heartsavior/spark/sql/state) to see more examples on how to use. 271 | 272 | ## License 273 | 274 | Copyright 2019-2020 Jungtaek Lim "" 275 | 276 | Licensed under the Apache License, Version 2.0 (the "License"); 277 | you may not use this file except in compliance with the License. 278 | You may obtain a copy of the License at 279 | 280 | http://www.apache.org/licenses/LICENSE-2.0 281 | 282 | Unless required by applicable law or agreed to in writing, software 283 | distributed under the License is distributed on an "AS IS" BASIS, 284 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 285 | See the License for the specific language governing permissions and 286 | limitations under the License. 287 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | 204 | ------------------------------------------------------------------------------------ 205 | This product bundles various third-party components under other open source licenses. 206 | This section summarizes those components and their licenses. See licenses/ 207 | for text of these licenses. 208 | 209 | 210 | Apache Software Foundation License 2.0 211 | -------------------------------------- 212 | 213 | common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java 214 | core/src/main/java/org/apache/spark/util/collection/TimSort.java 215 | core/src/main/resources/org/apache/spark/ui/static/bootstrap* 216 | core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* 217 | core/src/main/resources/org/apache/spark/ui/static/vis* 218 | docs/js/vendor/bootstrap.js 219 | 220 | 221 | Python Software Foundation License 222 | ---------------------------------- 223 | 224 | pyspark/heapq3.py 225 | 226 | 227 | BSD 3-Clause 228 | ------------ 229 | 230 | python/lib/py4j-*-src.zip 231 | python/pyspark/cloudpickle.py 232 | python/pyspark/join.py 233 | core/src/main/resources/org/apache/spark/ui/static/d3.min.js 234 | 235 | The CSS style for the navigation sidebar of the documentation was originally 236 | submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project 237 | is distributed under the 3-Clause BSD license. 
238 | 239 | 240 | MIT License 241 | ----------- 242 | 243 | core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js 244 | core/src/main/resources/org/apache/spark/ui/static/*dataTables* 245 | core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js 246 | core/src/main/resources/org/apache/spark/ui/static/jquery* 247 | core/src/main/resources/org/apache/spark/ui/static/sorttable.js 248 | docs/js/vendor/anchor.min.js 249 | docs/js/vendor/jquery* 250 | docs/js/vendor/modernizer* 251 | 252 | 253 | Creative Commons CC0 1.0 Universal Public Domain Dedication 254 | ----------------------------------------------------------- 255 | (see LICENSE-CC0.txt) 256 | 257 | data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg 258 | data/mllib/images/kittens/54893.jpg 259 | data/mllib/images/kittens/DP153539.jpg 260 | data/mllib/images/kittens/DP802813.jpg 261 | data/mllib/images/multi-channel/chr30.4.184.jpg -------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StateStoreTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.io.File 20 | import java.sql.Timestamp 21 | 22 | import org.apache.spark.sql.Dataset 23 | import org.apache.spark.sql.execution.streaming.MemoryStream 24 | import org.apache.spark.sql.execution.streaming.state.StateStore 25 | import org.apache.spark.sql.functions._ 26 | import org.apache.spark.sql.hack.SparkSqlHack 27 | import org.apache.spark.sql.streaming._ 28 | import org.apache.spark.sql.streaming.util.StreamManualClock 29 | import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} 30 | 31 | trait StateStoreTest extends StreamTest { 32 | import testImplicits._ 33 | 34 | override def afterAll(): Unit = { 35 | super.afterAll() 36 | StateStore.stop() 37 | } 38 | 39 | protected def withTempCheckpoints(body: (File, File) => Unit) { 40 | val src = SparkSqlHack.createTempDir(namePrefix = "streaming.old") 41 | val tmp = SparkSqlHack.createTempDir(namePrefix = "streaming.new") 42 | try { 43 | body(src, tmp) 44 | } finally { 45 | SparkSqlHack.deleteRecursively(src) 46 | SparkSqlHack.deleteRecursively(tmp) 47 | } 48 | } 49 | 50 | protected def runCompositeKeyStreamingAggregationQuery( 51 | checkpointRoot: String): Unit = { 52 | val inputData = MemoryStream[Int] 53 | val aggregated = getCompositeKeyStreamingAggregationQuery(inputData) 54 | 55 | testStream(aggregated, OutputMode.Update)( 56 | StartStream(checkpointLocation = checkpointRoot), 57 | // batch 0 58 | AddData(inputData, 0 to 5: _*), 59 | CheckLastBatch( 60 | (0, "Apple", 1, 0, 0, 0), 61 | (1, "Banana", 1, 1, 1, 1), 62 | (0, "Strawberry", 1, 2, 2, 2), 63 | (1, "Apple", 1, 3, 3, 3), 64 | (0, "Banana", 1, 4, 4, 4), 65 | (1, "Strawberry", 1, 5, 5, 5) 66 | ), 67 | // batch 1 68 | AddData(inputData, 6 to 10: _*), 69 | // state also contains (1, "Strawberry", 1, 5, 5, 5) but not updated here 70 | CheckLastBatch( 71 | (0, "Apple", 2, 6, 6, 0), // 0, 6 72 | (1, "Banana", 2, 8, 7, 1), // 1, 7 73 | (0, "Strawberry", 2, 10, 8, 2), // 2, 8 74 | (1, "Apple", 2, 12, 9, 3), // 3, 9 75 | (0, "Banana", 2, 14, 10, 4) // 4, 10 76 | ), 77 | StopStream, 78 | StartStream(checkpointLocation = checkpointRoot), 79 | // batch 2 80 | AddData(inputData, 3, 2, 1), 81 | CheckLastBatch( 82 | (1, "Banana", 3, 9, 7, 1), // 1, 7, 1 83 | (0, "Strawberry", 3, 12, 8, 2), // 2, 8, 2 84 | (1, "Apple", 3, 15, 9, 3) // 3, 9, 3 85 | ) 86 | ) 87 | } 88 | 89 | protected def getCompositeKeyStreamingAggregationQuery 90 | : Dataset[(Int, String, Long, Long, Int, Int)] = { 91 | getCompositeKeyStreamingAggregationQuery(MemoryStream[Int]) 92 | } 93 | 94 | protected def getCompositeKeyStreamingAggregationQuery( 95 | inputData: MemoryStream[Int]): Dataset[(Int, String, Long, Long, Int, Int)] = { 96 | inputData.toDF() 97 | .selectExpr("value", "value % 2 AS groupKey", 98 | "(CASE value % 3 WHEN 0 THEN 'Apple' WHEN 1 THEN 'Banana' ELSE 'Strawberry' END) AS fruit") 99 | .groupBy($"groupKey", $"fruit") 100 | .agg( 101 | count("*").as("cnt"), 102 | sum("value").as("sum"), 103 | max("value").as("max"), 104 | min("value").as("min") 105 | ) 106 | .as[(Int, String, Long, Long, Int, Int)] 107 | } 108 | 109 | protected def getSchemaForCompositeKeyStreamingAggregationQuery( 110 | formatVersion: Int): StructType = { 111 | val stateKeySchema = new StructType() 112 | .add("groupKey", IntegerType) 113 | .add("fruit", StringType, nullable = false) 114 | 115 | var stateValueSchema = formatVersion match { 116 | case 1 => 117 | new StructType().add("groupKey", IntegerType).add("fruit", StringType, nullable 
= false) 118 | case 2 => new StructType() 119 | case v => throw new IllegalArgumentException(s"Not valid format version $v") 120 | } 121 | 122 | stateValueSchema = stateValueSchema 123 | .add("cnt", LongType, nullable = false) 124 | .add("sum", LongType) 125 | .add("max", IntegerType) 126 | .add("min", IntegerType) 127 | 128 | new StructType() 129 | .add("key", stateKeySchema) 130 | .add("value", stateValueSchema) 131 | } 132 | 133 | protected def runLargeDataStreamingAggregationQuery( 134 | checkpointRoot: String): Unit = { 135 | val inputData = MemoryStream[Int] 136 | val aggregated = getLargeDataStreamingAggregationQuery(inputData) 137 | 138 | // check with more data - leverage full partitions 139 | testStream(aggregated, OutputMode.Update)( 140 | StartStream(checkpointLocation = checkpointRoot), 141 | // batch 0 142 | AddData(inputData, 0 until 20: _*), 143 | CheckLastBatch( 144 | (0, 2, 10, 10, 0), // 0, 10 145 | (1, 2, 12, 11, 1), // 1, 11 146 | (2, 2, 14, 12, 2), // 2, 12 147 | (3, 2, 16, 13, 3), // 3, 13 148 | (4, 2, 18, 14, 4), // 4, 14 149 | (5, 2, 20, 15, 5), // 5, 15 150 | (6, 2, 22, 16, 6), // 6, 16 151 | (7, 2, 24, 17, 7), // 7, 17 152 | (8, 2, 26, 18, 8), // 8, 18 153 | (9, 2, 28, 19, 9) // 9, 19 154 | ), 155 | // batch 1 156 | AddData(inputData, 20 until 40: _*), 157 | CheckLastBatch( 158 | (0, 4, 60, 30, 0), // 0, 10, 20, 30 159 | (1, 4, 64, 31, 1), // 1, 11, 21, 31 160 | (2, 4, 68, 32, 2), // 2, 12, 22, 32 161 | (3, 4, 72, 33, 3), // 3, 13, 23, 33 162 | (4, 4, 76, 34, 4), // 4, 14, 24, 34 163 | (5, 4, 80, 35, 5), // 5, 15, 25, 35 164 | (6, 4, 84, 36, 6), // 6, 16, 26, 36 165 | (7, 4, 88, 37, 7), // 7, 17, 27, 37 166 | (8, 4, 92, 38, 8), // 8, 18, 28, 38 167 | (9, 4, 96, 39, 9) // 9, 19, 29, 39 168 | ), 169 | StopStream, 170 | StartStream(checkpointLocation = checkpointRoot), 171 | // batch 2 172 | AddData(inputData, 0, 1, 2), 173 | CheckLastBatch( 174 | (0, 5, 60, 30, 0), // 0, 10, 20, 30, 0 175 | (1, 5, 65, 31, 1), // 1, 11, 21, 31, 1 176 | (2, 5, 70, 32, 2) // 2, 12, 22, 32, 2 177 | ) 178 | ) 179 | } 180 | 181 | protected def getLargeDataStreamingAggregationQuery: Dataset[(Int, Long, Long, Int, Int)] = { 182 | getLargeDataStreamingAggregationQuery(MemoryStream[Int]) 183 | } 184 | 185 | protected def getLargeDataStreamingAggregationQuery( 186 | inputData: MemoryStream[Int]): Dataset[(Int, Long, Long, Int, Int)] = { 187 | inputData.toDF() 188 | .selectExpr("value", "value % 10 AS groupKey") 189 | .groupBy($"groupKey") 190 | .agg( 191 | count("*").as("cnt"), 192 | sum("value").as("sum"), 193 | max("value").as("max"), 194 | min("value").as("min") 195 | ) 196 | .as[(Int, Long, Long, Int, Int)] 197 | } 198 | 199 | protected def getSchemaForLargeDataStreamingAggregationQuery(formatVersion: Int): StructType = { 200 | val stateKeySchema = new StructType() 201 | .add("groupKey", IntegerType) 202 | 203 | var stateValueSchema = formatVersion match { 204 | case 1 => new StructType().add("groupKey", IntegerType) 205 | case 2 => new StructType() 206 | case v => throw new IllegalArgumentException(s"Not valid format version $v") 207 | } 208 | 209 | stateValueSchema = stateValueSchema 210 | .add("cnt", LongType) 211 | .add("sum", LongType) 212 | .add("max", IntegerType) 213 | .add("min", IntegerType) 214 | 215 | new StructType() 216 | .add("key", stateKeySchema) 217 | .add("value", stateValueSchema) 218 | } 219 | 220 | protected def runStreamingDeduplicationQuery(checkpointRoot: String): Unit = { 221 | val inputData = MemoryStream[Int] 222 | 223 | val aggregated = inputData.toDF() 224 | 
.selectExpr("value", "value % 10 AS groupKey") 225 | .dropDuplicates(Seq("groupKey")) 226 | .as[(Int, Int)] 227 | 228 | testStream(aggregated, OutputMode.Update)( 229 | StartStream(checkpointLocation = checkpointRoot), 230 | // batch 0 231 | AddData(inputData, 0 until 20: _*), 232 | CheckLastBatch( 233 | (0, 0), 234 | (1, 1), 235 | (2, 2), 236 | (3, 3), 237 | (4, 4), 238 | (5, 5), 239 | (6, 6), 240 | (7, 7), 241 | (8, 8), 242 | (9, 9) 243 | ), 244 | // batch 1 245 | AddData(inputData, 20 until 40: _*), 246 | // no new update 247 | CheckLastBatch(), 248 | StopStream, 249 | StartStream(checkpointLocation = checkpointRoot), 250 | // batch 2 251 | AddData(inputData, 0, 1, 2), 252 | // no new update 253 | CheckLastBatch() 254 | ) 255 | } 256 | 257 | protected def runStreamingJoinQuery(checkpointRoot: String): Unit = { 258 | val inputData = MemoryStream[Int] 259 | val joined = getStreamingJoinQuery(inputData) 260 | 261 | testStream(joined, OutputMode.Append)( 262 | StartStream(checkpointLocation = checkpointRoot), 263 | // batch 0 264 | AddData(inputData, 0 until 5: _*), 265 | // 0 and 3 don't exist on df2 266 | CheckLastBatch( 267 | (1, "odd", 1, "odd"), 268 | (2, "even", 2, "even"), 269 | (4, "even", 4, "even") 270 | ), 271 | // batch 1 272 | AddData(inputData, 5 until 10: _*), 273 | CheckLastBatch( 274 | (5, "odd", 5, "odd"), 275 | (7, "odd", 7, "odd"), 276 | (8, "even", 8, "even") 277 | ) 278 | ) 279 | } 280 | 281 | protected def getStreamingJoinQuery: Dataset[(Int, String, Int, String)] = { 282 | getStreamingJoinQuery(MemoryStream[Int]) 283 | } 284 | 285 | protected def getStreamingJoinQuery( 286 | inputData: MemoryStream[Int]): Dataset[(Int, String, Int, String)] = { 287 | val df = inputData.toDF() 288 | .selectExpr("value", "CASE value % 2 WHEN 0 THEN 'even' ELSE 'odd' END AS isEven") 289 | val df2 = df.selectExpr("value AS value2", "iseven AS isEven2") 290 | .where("value % 3 != 0") 291 | 292 | df.join(df2, expr("value == value2")) 293 | .selectExpr("value", "iseven", "value2", "iseven2") 294 | .as[(Int, String, Int, String)] 295 | } 296 | 297 | protected def runFlatMapGroupsWithStateQuery(checkpointRoot: String): Unit = { 298 | val clock = new StreamManualClock 299 | 300 | val inputData = MemoryStream[(String, Long)] 301 | val remapped = getFlatMapGroupsWithStateQuery(inputData) 302 | 303 | testStream(remapped, OutputMode.Update)( 304 | // batch 0 305 | StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, 306 | checkpointLocation = checkpointRoot), 307 | AddData(inputData, ("hello world", 1L), ("hello scala", 2L)), 308 | AdvanceManualClock(1 * 1000), 309 | CheckNewAnswer( 310 | ("hello", 2, 1000, false), 311 | ("world", 1, 0, false), 312 | ("scala", 1, 0, false) 313 | ), 314 | // batch 1 315 | AddData(inputData, ("hello world", 3L), ("hello scala", 4L)), 316 | AdvanceManualClock(1 * 1000), 317 | CheckNewAnswer( 318 | ("hello", 4, 3000, false), 319 | ("world", 2, 2000, false), 320 | ("scala", 2, 2000, false) 321 | ) 322 | ) 323 | } 324 | 325 | protected def getFlatMapGroupsWithStateQuery: Dataset[(String, Int, Long, Boolean)] = { 326 | getFlatMapGroupsWithStateQuery(MemoryStream[(String, Long)]) 327 | } 328 | 329 | protected def getFlatMapGroupsWithStateQuery( 330 | inputData: MemoryStream[(String, Long)]): Dataset[(String, Int, Long, Boolean)] = { 331 | // scalastyle:off line.size.limit 332 | // This test code is borrowed from sessionization example of Apache Spark, 333 | // with modification a bit to run with testStream 334 | // 
https://github.com/apache/spark/blob/v2.4.1/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala 335 | // scalastyle:on 336 | 337 | val events = inputData.toDF() 338 | .as[(String, Timestamp)] 339 | .flatMap { case (line, timestamp) => 340 | line.split(" ").map(word => Event(sessionId = word, timestamp)) 341 | } 342 | 343 | val sessionUpdates = events 344 | .groupByKey(event => event.sessionId) 345 | .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { 346 | 347 | case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => 348 | if (state.hasTimedOut) { 349 | val finalUpdate = 350 | SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) 351 | state.remove() 352 | finalUpdate 353 | } else { 354 | val timestamps = events.map(_.timestamp.getTime).toSeq 355 | val updatedSession = if (state.exists) { 356 | val oldSession = state.get 357 | SessionInfo( 358 | oldSession.numEvents + timestamps.size, 359 | oldSession.startTimestampMs, 360 | math.max(oldSession.endTimestampMs, timestamps.max)) 361 | } else { 362 | SessionInfo(timestamps.size, timestamps.min, timestamps.max) 363 | } 364 | state.update(updatedSession) 365 | 366 | state.setTimeoutDuration("10 seconds") 367 | SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) 368 | } 369 | } 370 | 371 | sessionUpdates.map(si => (si.id, si.numEvents, si.durationMs, si.expired)) 372 | } 373 | 374 | } 375 | 376 | case class Event(sessionId: String, timestamp: Timestamp) 377 | 378 | case class SessionInfo( 379 | numEvents: Int, 380 | startTimestampMs: Long, 381 | endTimestampMs: Long) { 382 | def durationMs: Long = endTimestampMs - startTimestampMs 383 | } 384 | 385 | case class SessionUpdate( 386 | id: String, 387 | durationMs: Long, 388 | numEvents: Int, 389 | expired: Boolean) 390 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- [The XML markup of this file was lost in extraction, leaving only element text, so the recoverable rules are summarized rather than reproduced. The file is the "Scalastyle standard configuration", mirroring Apache Spark's Scalastyle rules: it checks for the Apache License header on Scala sources; requires tests to extend org.apache.spark.SparkFunSuite instead of FunSuite; flags println, direct access to sparkContext.hadoopConfiguration, @VisibleForTesting, Runtime.getRuntime.addShutdownHook, mutable.SynchronizedBuffer, Class.forName, Await.result, Await.ready, scala.collection.JavaConversions (import scala.collection.JavaConverters._ and use .asScala / .asJava), org.apache.commons.lang (use Commons Lang 3, org.apache.commons.lang3), and extractOpt (use jsonOption(x).map(.extract[T]) instead); orders imports as java, scala, 3rdParty, spark; enforces whitespace around arrows, equals signs, commas, and keywords, Javadoc-style indentation for multiline comments, and omission of braces in case clauses; its remaining rules carry numeric limits of 800 (file length), 30, 10, and 50 for type and method sizes, plus an allowed-number list of -1, 0, 1, 2, 3.]
-------------------------------------------------------------------------------- /src/test/scala/net/heartsavior/spark/sql/state/StateStoreStreamingAggregationWriteSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 Jungtaek Lim "" 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | 17 | package net.heartsavior.spark.sql.state 18 | 19 | import java.io.File 20 | 21 | import net.heartsavior.spark.sql.checkpoint.CheckpointUtil 22 | import org.scalatest.{Assertions, BeforeAndAfterAll} 23 | 24 | import org.apache.spark.sql.Row 25 | import org.apache.spark.sql.execution.streaming._ 26 | import org.apache.spark.sql.execution.streaming.state.StateStore 27 | import org.apache.spark.sql.internal.SQLConf 28 | import org.apache.spark.sql.streaming.OutputMode 29 | import org.apache.spark.sql.types._ 30 | 31 | class StateStoreStreamingAggregationWriteSuite 32 | extends StateStoreTest 33 | with BeforeAndAfterAll 34 | with Assertions { 35 | 36 | override def afterAll(): Unit = { 37 | super.afterAll() 38 | StateStore.stop() 39 | } 40 | 41 | override protected def beforeEach(): Unit = { 42 | super.beforeEach() 43 | sql("drop table if exists tbl") 44 | } 45 | 46 | test("rescale state from streaming aggregation - state format version 1") { 47 | withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1"): _*) { 48 | withTempCheckpoints { case (oldCpDir, newCpDir) => 49 | runLargeDataStreamingAggregationQuery(oldCpDir.getAbsolutePath) 50 | 51 | val operatorId = 0 52 | val newLastBatchId = 1 53 | val newShufflePartitions = 20 54 | 55 | val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(1) 56 | 57 | val stateReadDf = spark.read 58 | .format("state") 59 | .schema(stateSchema) 60 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 61 | new File(oldCpDir, "state").getAbsolutePath) 62 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 63 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 64 | .load() 65 | 66 | val expectedRows = stateReadDf.collect() 67 | 68 | // copy all contents except state to new checkpoint root directory 69 | // adjust number of shuffle partitions in prior to migrate state 70 | val addConf = getAdditionalConfForMetadata(newShufflePartitions) 71 | CheckpointUtil.createSavePoint(spark, oldCpDir.getAbsolutePath, 72 | newCpDir.getAbsolutePath, newLastBatchId, addConf, excludeState = true) 73 | 74 | stateReadDf.write 75 | .format("state") 76 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 77 | new File(newCpDir, "state").getAbsolutePath) 78 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 79 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 80 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, newShufflePartitions) 81 | .saveAsTable("tbl") 82 | 83 | // verify write-and-read works 84 | checkAnswer(spark.sql("select * from tbl"), expectedRows) 85 | 86 | // read again 87 | val stateReadDf2 = spark.read 88 | .format("state") 89 | .schema(stateSchema) 90 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 91 | new File(newCpDir, "state").getAbsolutePath) 92 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 93 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 94 | .load() 95 | 96 | checkAnswer(stateReadDf2, expectedRows) 97 | 98 | verifyContinueRunLargeDataStreamingAggregationQuery(newCpDir.getAbsolutePath, 99 | newShufflePartitions) 100 | } 101 | } 102 | } 103 | 104 | private def getAdditionalConfForMetadata(newShufflePartitions: Int) = { 105 | Map(SQLConf.SHUFFLE_PARTITIONS.key -> newShufflePartitions.toString) 106 | } 107 | 108 | test("rescale state from streaming aggregation - state format version 2") { 109 | 
withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { 110 | withTempCheckpoints { case (oldCpDir, newCpDir) => 111 | runLargeDataStreamingAggregationQuery(oldCpDir.getAbsolutePath) 112 | 113 | val operatorId = 0 114 | val newLastBatchId = 1 115 | val newShufflePartitions = 20 116 | 117 | val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(2) 118 | 119 | val stateReadDf = spark.read 120 | .format("state") 121 | .schema(stateSchema) 122 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 123 | new File(oldCpDir, "state").getAbsolutePath) 124 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 125 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 126 | .load() 127 | 128 | val expectedRows = stateReadDf.collect() 129 | 130 | // copy all contents except state to new checkpoint root directory 131 | // adjust number of shuffle partitions in prior to migrate state 132 | val addConf = getAdditionalConfForMetadata(newShufflePartitions) 133 | CheckpointUtil.createSavePoint(spark, oldCpDir.getAbsolutePath, 134 | newCpDir.getAbsolutePath, newLastBatchId, addConf, excludeState = true) 135 | 136 | stateReadDf.write 137 | .format("state") 138 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 139 | new File(newCpDir, "state").getAbsolutePath) 140 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 141 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 142 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, newShufflePartitions) 143 | .saveAsTable("tbl") 144 | 145 | // verify write-and-read works 146 | checkAnswer(spark.sql("select * from tbl"), expectedRows) 147 | 148 | // read again 149 | val stateReadDf2 = spark.read 150 | .format("state") 151 | .schema(stateSchema) 152 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 153 | new File(newCpDir, "state").getAbsolutePath) 154 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 155 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 156 | .load() 157 | 158 | checkAnswer(stateReadDf2, expectedRows) 159 | 160 | verifyContinueRunLargeDataStreamingAggregationQuery(newCpDir.getAbsolutePath, 161 | newShufflePartitions) 162 | } 163 | } 164 | } 165 | 166 | test("simple state schema evolution from streaming aggregation - state format version 2") { 167 | withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { 168 | withTempCheckpoints { case (oldCpDir, newCpDir) => 169 | runLargeDataStreamingAggregationQuery(oldCpDir.getAbsolutePath) 170 | 171 | val operatorId = 0 172 | val newLastBatchId = 1 173 | val newShufflePartitions = 20 174 | 175 | val stateSchema = getSchemaForLargeDataStreamingAggregationQuery(2) 176 | 177 | val stateReadDf = spark.read 178 | .format("state") 179 | .schema(stateSchema) 180 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 181 | new File(oldCpDir, "state").getAbsolutePath) 182 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 183 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 184 | .load() 185 | 186 | // rows: 187 | // (0, 4, 60, 30, 0) 188 | // (1, 4, 64, 31, 1) 189 | // (2, 4, 68, 32, 2) 190 | // (3, 4, 72, 33, 3) 191 | // (4, 4, 76, 34, 4) 192 | // (5, 4, 80, 35, 5) 193 | // (6, 4, 84, 36, 6) 194 | // (7, 4, 88, 37, 7) 195 | // (8, 4, 92, 38, 8) 196 | // (9, 4, 96, 39, 9) 197 | 198 | val evolutionDf = stateReadDf 199 | 
.selectExpr("key", "value", "(key.groupKey * value.cnt) AS groupKeySum") 200 | .selectExpr("key", "struct(value.*, groupKeySum) AS value") 201 | 202 | logInfo(s"Schema: ${evolutionDf.schema.treeString}") 203 | 204 | // new rows 205 | val expectedRows = Seq( 206 | Row(Row(0), Row(4, 60, 30, 0, 0)), 207 | Row(Row(1), Row(4, 64, 31, 1, 4)), 208 | Row(Row(2), Row(4, 68, 32, 2, 8)), 209 | Row(Row(3), Row(4, 72, 33, 3, 12)), 210 | Row(Row(4), Row(4, 76, 34, 4, 16)), 211 | Row(Row(5), Row(4, 80, 35, 5, 20)), 212 | Row(Row(6), Row(4, 84, 36, 6, 24)), 213 | Row(Row(7), Row(4, 88, 37, 7, 28)), 214 | Row(Row(8), Row(4, 92, 38, 8, 32)), 215 | Row(Row(9), Row(4, 96, 39, 9, 36)) 216 | ) 217 | 218 | // copy all contents except state to new checkpoint root directory 219 | // adjust number of shuffle partitions in prior to migrate state 220 | val addConf = getAdditionalConfForMetadata(newShufflePartitions) 221 | CheckpointUtil.createSavePoint(spark, oldCpDir.getAbsolutePath, 222 | newCpDir.getAbsolutePath, newLastBatchId, addConf, excludeState = true) 223 | 224 | evolutionDf.write 225 | .format("state") 226 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 227 | new File(newCpDir, "state").getAbsolutePath) 228 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 229 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 230 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, newShufflePartitions) 231 | .saveAsTable("tbl") 232 | 233 | // verify write-and-read works 234 | checkAnswer(spark.sql("select * from tbl"), expectedRows) 235 | 236 | val newStateSchema = new StructType(stateSchema.fields.map { field => 237 | if (field.name == "value") { 238 | StructField("value", field.dataType.asInstanceOf[StructType] 239 | .add("groupKeySum", LongType)) 240 | } else { 241 | field 242 | } 243 | }) 244 | 245 | // read again 246 | val stateReadDf2 = spark.read 247 | .format("state") 248 | .schema(newStateSchema) 249 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 250 | new File(newCpDir, "state").getAbsolutePath) 251 | .option(StateStoreDataSourceProvider.PARAM_VERSION, "2") 252 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, "0") 253 | .load() 254 | 255 | checkAnswer(stateReadDf2, expectedRows) 256 | 257 | verifyContinueRunLargeDataStreamingAggregationQueryWithSchemaEvolution( 258 | newCpDir.getAbsolutePath, newShufflePartitions) 259 | } 260 | } 261 | } 262 | 263 | test("simple state schema evolution from streaming aggregation - composite key") { 264 | withSQLConf(Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2"): _*) { 265 | withTempCheckpoints { case (oldCpDir, newCpDir) => 266 | runCompositeKeyStreamingAggregationQuery(oldCpDir.getAbsolutePath) 267 | 268 | val operatorId = 0 269 | val newLastBatchId = 1 270 | val newShufflePartitions = 20 271 | 272 | val stateSchema = getSchemaForCompositeKeyStreamingAggregationQuery(2) 273 | 274 | val stateReadDf = spark.read 275 | .format("state") 276 | .schema(stateSchema) 277 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 278 | new File(oldCpDir, "state").getAbsolutePath) 279 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 280 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 281 | .load() 282 | 283 | // rows: 284 | // (0, "Apple", 2, 6, 6, 0) 285 | // (1, "Banana", 2, 8, 7, 1) 286 | // (0, "Strawberry", 2, 10, 8, 2) 287 | // (1, "Apple", 2, 12, 9, 3) 288 | // (0, "Banana", 2, 14, 10, 4) 289 | // (1, 
"Strawberry", 1, 5, 5, 5) 290 | 291 | val evolutionDf = stateReadDf 292 | .selectExpr("key", "value", "(key.groupKey * value.cnt) AS groupKeySum") 293 | .selectExpr("key", "struct(value.*, groupKeySum) AS value") 294 | 295 | logInfo(s"Schema: ${evolutionDf.schema.treeString}") 296 | 297 | // new rows 298 | val expectedRows = Seq( 299 | Row(Row(0, "Apple"), Row(2, 6, 6, 0, 0)), 300 | Row(Row(1, "Banana"), Row(2, 8, 7, 1, 2)), 301 | Row(Row(0, "Strawberry"), Row(2, 10, 8, 2, 0)), 302 | Row(Row(1, "Apple"), Row(2, 12, 9, 3, 2)), 303 | Row(Row(0, "Banana"), Row(2, 14, 10, 4, 0)), 304 | Row(Row(1, "Strawberry"), Row(1, 5, 5, 5, 1)) 305 | ) 306 | 307 | // copy all contents except state to new checkpoint root directory 308 | // adjust number of shuffle partitions in prior to migrate state 309 | val addConf = getAdditionalConfForMetadata(newShufflePartitions) 310 | CheckpointUtil.createSavePoint(spark, oldCpDir.getAbsolutePath, 311 | newCpDir.getAbsolutePath, newLastBatchId, addConf, excludeState = true) 312 | 313 | evolutionDf.write 314 | .format("state") 315 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 316 | new File(newCpDir, "state").getAbsolutePath) 317 | .option(StateStoreDataSourceProvider.PARAM_VERSION, newLastBatchId + 1) 318 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, operatorId) 319 | .option(StateStoreDataSourceProvider.PARAM_NEW_PARTITIONS, newShufflePartitions) 320 | .saveAsTable("tbl") 321 | 322 | // verify write-and-read works 323 | checkAnswer(spark.sql("select * from tbl"), expectedRows) 324 | 325 | val newStateSchema = new StructType(stateSchema.fields.map { field => 326 | if (field.name == "value") { 327 | StructField("value", field.dataType.asInstanceOf[StructType] 328 | .add("groupKeySum", LongType)) 329 | } else { 330 | field 331 | } 332 | }) 333 | 334 | // read again 335 | val stateReadDf2 = spark.read 336 | .format("state") 337 | .schema(newStateSchema) 338 | .option(StateStoreDataSourceProvider.PARAM_CHECKPOINT_LOCATION, 339 | new File(newCpDir, "state").getAbsolutePath) 340 | .option(StateStoreDataSourceProvider.PARAM_VERSION, "2") 341 | .option(StateStoreDataSourceProvider.PARAM_OPERATOR_ID, "0") 342 | .load() 343 | 344 | checkAnswer(stateReadDf2, expectedRows) 345 | 346 | verifyContinueRunCompositeKeyStreamingAggregationQueryWithSchemaEvolution( 347 | newCpDir.getAbsolutePath, newShufflePartitions) 348 | } 349 | } 350 | } 351 | 352 | private def verifyContinueRunLargeDataStreamingAggregationQuery( 353 | checkpointRoot: String, 354 | newShufflePartitions: Int): Unit = { 355 | import testImplicits._ 356 | 357 | val inputData = MemoryStream[Int] 358 | val aggregated = getLargeDataStreamingAggregationQuery(inputData) 359 | 360 | // batch 0 361 | inputData.addData(0 until 20) 362 | // batch 1 363 | inputData.addData(20 until 40) 364 | 365 | testStream(aggregated, OutputMode.Update)( 366 | StartStream(checkpointLocation = checkpointRoot), 367 | // batch 2 368 | AddData(inputData, 0, 1, 2), 369 | CheckLastBatch( 370 | (0, 5, 60, 30, 0), // 0, 10, 20, 30, 0 371 | (1, 5, 65, 31, 1), // 1, 11, 21, 31, 1 372 | (2, 5, 70, 32, 2) // 2, 12, 22, 32, 2 373 | ), 374 | AssertOnQuery { query => 375 | val operators = query.lastExecution.executedPlan.collect { 376 | case p: StateStoreSaveExec => p 377 | } 378 | operators.forall(_.stateInfo.get.numPartitions === newShufflePartitions) 379 | } 380 | ) 381 | } 382 | 383 | private def verifyContinueRunLargeDataStreamingAggregationQueryWithSchemaEvolution( 384 | checkpointRoot: String, 385 | 
newShufflePartitions: Int): Unit = { 386 | import org.apache.spark.sql.functions._ 387 | import testImplicits._ 388 | 389 | val inputData = MemoryStream[Int] 390 | 391 | val aggregated = inputData.toDF() 392 | .selectExpr("value", "value % 10 AS groupKey") 393 | .groupBy($"groupKey") 394 | .agg( 395 | count("*").as("cnt"), 396 | sum("value").as("sum"), 397 | max("value").as("max"), 398 | min("value").as("min"), 399 | // NOTE: this query is modified after the query is checkpointed, and we are modifying state 400 | sum("groupKey").as("groupKeySum") 401 | ) 402 | .as[(Int, Long, Long, Int, Int, Long)] 403 | 404 | // batch 0 405 | inputData.addData(0 until 20) 406 | // batch 1 407 | inputData.addData(20 until 40) 408 | 409 | testStream(aggregated, OutputMode.Update)( 410 | StartStream(checkpointLocation = checkpointRoot), 411 | // batch 2 412 | AddData(inputData, 0, 1, 2), 413 | CheckLastBatch( 414 | (0, 5, 60, 30, 0, 0), // 0, 10, 20, 30, 0 415 | (1, 5, 65, 31, 1, 5), // 1, 11, 21, 31, 1 416 | (2, 5, 70, 32, 2, 10) // 2, 12, 22, 32, 2 417 | ), 418 | AssertOnQuery { query => 419 | val operators = query.lastExecution.executedPlan.collect { 420 | case p: StateStoreSaveExec => p 421 | } 422 | operators.forall(_.stateInfo.get.numPartitions === newShufflePartitions) 423 | } 424 | ) 425 | } 426 | 427 | private def verifyContinueRunCompositeKeyStreamingAggregationQueryWithSchemaEvolution( 428 | checkpointRoot: String, 429 | newShufflePartitions: Int): Unit = { 430 | import org.apache.spark.sql.functions._ 431 | import testImplicits._ 432 | 433 | val inputData = MemoryStream[Int] 434 | 435 | val aggregated = inputData.toDF() 436 | .selectExpr("value", "value % 2 AS groupKey", 437 | "(CASE value % 3 WHEN 0 THEN 'Apple' WHEN 1 THEN 'Banana' ELSE 'Strawberry' END) AS fruit") 438 | .groupBy($"groupKey", $"fruit") 439 | .agg( 440 | count("*").as("cnt"), 441 | sum("value").as("sum"), 442 | max("value").as("max"), 443 | min("value").as("min"), 444 | // NOTE: this query is modified after the query is checkpointed, and we are modifying state 445 | sum("groupKey").as("groupKeySum") 446 | ) 447 | .as[(Int, String, Long, Long, Int, Int, Long)] 448 | 449 | // batch 0 450 | inputData.addData(0 to 5) 451 | // batch 1 452 | inputData.addData(6 to 10) 453 | 454 | testStream(aggregated, OutputMode.Update)( 455 | StartStream(checkpointLocation = checkpointRoot), 456 | // batch 2 457 | AddData(inputData, 3, 2, 1), 458 | CheckLastBatch( 459 | (1, "Banana", 3, 9, 7, 1, 3), // 1, 7, 1 460 | (0, "Strawberry", 3, 12, 8, 2, 0), // 2, 8, 2 461 | (1, "Apple", 3, 15, 9, 3, 3) // 3, 9, 3 462 | ), 463 | AssertOnQuery { query => 464 | val operators = query.lastExecution.executedPlan.collect { 465 | case p: StateStoreSaveExec => p 466 | } 467 | operators.forall(_.stateInfo.get.numPartitions === newShufflePartitions) 468 | } 469 | ) 470 | } 471 | } 472 | --------------------------------------------------------------------------------