├── .gitignore
├── project
│   ├── build.properties
│   └── assembly.sbt
├── src
│   └── main
│       ├── scala
│       │   └── nl
│       │       └── mrooding
│       │           ├── data
│       │           │   ├── AvroGenericRecordWriter.scala
│       │           │   ├── AvroSerializable.scala
│       │           │   ├── ProductStock.scala
│       │           │   ├── ProductDescription.scala
│       │           │   ├── AvroSchema.scala
│       │           │   └── Product.scala
│       │           ├── state
│       │           │   ├── ProductSerializerSnapshot.scala
│       │           │   ├── ProductSerializer.scala
│       │           │   ├── CustomAvroSerializerSnapshot.scala
│       │           │   └── CustomAvroSerializer.scala
│       │           ├── source
│       │           │   ├── ProductStockSource.scala
│       │           │   └── ProductDescriptionSource.scala
│       │           ├── ProductAggregator.scala
│       │           └── ProductProcessor.scala
│       └── resources
│           └── avro
│               └── product.avsc
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
target/
.idea/
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
sbt.version=1.2.8
--------------------------------------------------------------------------------
/project/assembly.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/data/AvroGenericRecordWriter.scala:
--------------------------------------------------------------------------------
package nl.mrooding.data

import org.apache.avro.generic.GenericRecord

trait AvroGenericRecordWriter {
  def toGenericRecord: GenericRecord
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/data/AvroSerializable.scala:
--------------------------------------------------------------------------------
package nl.mrooding.data

import org.apache.flink.api.common.typeutils.TypeSerializer

trait AvroSerializable[T] {
  def serializer: TypeSerializer[T]
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/data/ProductStock.scala:
--------------------------------------------------------------------------------
package nl.mrooding.data

import java.time.Instant

case class ProductStock(id: String,
                        stock: Long,
                        updatedAt: Instant)
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/data/ProductDescription.scala:
--------------------------------------------------------------------------------
package nl.mrooding.data

import java.time.Instant

case class ProductDescription(id: String,
                              description: String,
                              updatedAt: Instant)
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/data/AvroSchema.scala:
--------------------------------------------------------------------------------
package nl.mrooding.data

import org.apache.avro.Schema

import scala.io.Source

trait AvroSchema {
  def schemaPath: String

  lazy val getCurrentSchema: Schema = {
    val content = Source.fromURL(getClass.getResource(schemaPath)).mkString

    new Schema.Parser().parse(content)
  }
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# flink-avro-state-serialization

Sample project showcasing how to use Apache Flink custom serializers to support state schema migration using Apache Avro.
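
The piece that makes migration possible is wiring the custom serializer into keyed state. `ProductProcessor` does this by constructing its `ValueStateDescriptor` with `Product.serializer` instead of letting Flink derive a serializer from type information:

```scala
import nl.mrooding.data.Product
import org.apache.flink.api.common.state.ValueStateDescriptor

// Passing a TypeSerializer explicitly makes Flink snapshot our Avro-aware
// serializer (and its schema) together with the state it describes.
val stateDescriptor = new ValueStateDescriptor[Product]("product-join", Product.serializer)
```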
## Avrohugger generation

There's a branch called `avro-generated` which contains the setup to generate case classes based on Avro schemas. The classes are generated during compilation (`sbt compile`), or you can force generation explicitly with `sbt avroScalaGenerateSpecific`.

Based on [sbt-avrohugger](https://github.com/julianpeeters/sbt-avrohugger)
--------------------------------------------------------------------------------
/src/main/resources/avro/product.avsc:
--------------------------------------------------------------------------------
{
  "type": "record",
  "name": "Product",
  "fields": [
    {
      "name": "id",
      "type": "string"
    },
    {
      "name": "description",
      "type": ["null", "string"],
      "default": null
    },
    {
      "name": "stock",
      "type": ["null", "long"],
      "default": null
    },
    {
      "name": "updatedAt",
      "type": {
        "type": "long",
        "logicalType": "timestamp-millis"
      },
      "default": 0
    }
  ]
}
--------------------------------------------------------------------------------
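
The `null` defaults on `description` and `stock` are what make this schema evolvable: Avro's schema resolution substitutes the default whenever a field is missing from the writer's schema. A self-contained sketch of that mechanism, using a hypothetical older schema that lacks the `stock` field (`SchemaResolutionDemo` is an illustrative name, not part of this repository):

```scala
import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

object SchemaResolutionDemo extends App {
  // Hypothetical previous schema version, before `stock` was added.
  val writerSchema = new Schema.Parser().parse(
    """{"type": "record", "name": "Product", "fields": [
      |  {"name": "id", "type": "string"}
      |]}""".stripMargin)

  // Newer schema: `stock` added with a null default.
  val readerSchema = new Schema.Parser().parse(
    """{"type": "record", "name": "Product", "fields": [
      |  {"name": "id", "type": "string"},
      |  {"name": "stock", "type": ["null", "long"], "default": null}
      |]}""".stripMargin)

  // Write a record in the old format...
  val record = new GenericData.Record(writerSchema)
  record.put("id", "42")
  val out = new ByteArrayOutputStream()
  val encoder = EncoderFactory.get().binaryEncoder(out, null)
  new GenericDatumWriter[GenericRecord](writerSchema).write(record, encoder)
  encoder.flush()

  // ...and read it back with both schemas: `stock` resolves to its default.
  val decoder = DecoderFactory.get().binaryDecoder(out.toByteArray, null)
  val migrated = new GenericDatumReader[GenericRecord](writerSchema, readerSchema).read(null, decoder)
  println(migrated) // {"id": "42", "stock": null}
}
```

This reader/writer-schema pairing is exactly what `CustomAvroSerializer.deserialize` relies on further down.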
/src/main/scala/nl/mrooding/state/ProductSerializerSnapshot.scala:
--------------------------------------------------------------------------------
package nl.mrooding.state

import nl.mrooding.data.Product
import org.apache.avro.Schema
import org.apache.flink.api.common.typeutils.TypeSerializer

class ProductSerializerSnapshot(var stateSchema: Option[Schema]) extends CustomAvroSerializerSnapshot[Product] {
  // Flink instantiates snapshots reflectively on restore, so a nullary constructor is required.
  def this() = {
    this(None)
  }

  override def getCurrentSchema: Schema = Product.getCurrentSchema

  override def restoreSerializer(): TypeSerializer[Product] = new ProductSerializer(stateSchema)
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/source/ProductStockSource.scala:
--------------------------------------------------------------------------------
package nl.mrooding.source

import java.time.Instant

import nl.mrooding.data.ProductStock
import org.apache.flink.streaming.api.functions.source.SourceFunction

class ProductStockSource(intervalMs: Long) extends SourceFunction[ProductStock] with Serializable {
  private var isRunning: Boolean = true

  private val r = new scala.util.Random

  override def run(ctx: SourceFunction.SourceContext[ProductStock]): Unit = {
    while (isRunning) {
      // Mark the source idle while sleeping so downstream watermarks are not held up.
      ctx.markAsTemporarilyIdle()
      Thread.sleep(intervalMs)
      ctx.collect(ProductStock(random(5).toString, random(10000), Instant.now))
    }
  }

  private def random(max: Int) = r.nextInt(max)

  override def cancel(): Unit = isRunning = false
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/source/ProductDescriptionSource.scala:
--------------------------------------------------------------------------------
package nl.mrooding.source

import java.time.Instant

import nl.mrooding.data.ProductDescription
import org.apache.flink.streaming.api.functions.source.SourceFunction

class ProductDescriptionSource(intervalMs: Long) extends SourceFunction[ProductDescription] with Serializable {
  private var isRunning: Boolean = true

  private val r = new scala.util.Random

  override def run(ctx: SourceFunction.SourceContext[ProductDescription]): Unit = {
    while (isRunning) {
      // Mark the source idle while sleeping so downstream watermarks are not held up.
      ctx.markAsTemporarilyIdle()
      Thread.sleep(intervalMs)
      // Draw the id once so the id and the description refer to the same product.
      val id = random
      ctx.collect(ProductDescription(
        id.toString,
        s"Product $id",
        Instant.now
      ))
    }
  }

  private def random = r.nextInt(5)

  override def cancel(): Unit = isRunning = false
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/state/ProductSerializer.scala:
--------------------------------------------------------------------------------
package nl.mrooding.state

import java.time.Instant

import nl.mrooding.data.Product
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.flink.api.common.typeutils.{TypeSerializer, TypeSerializerSnapshot}

class ProductSerializer(val stateSchema: Option[Schema]) extends CustomAvroSerializer[Product] {

  override def getCurrentSchema: Schema = Product.getCurrentSchema

  override def fromGenericRecord(genericRecord: GenericRecord): Product = Product.apply(genericRecord)

  override def duplicate(): TypeSerializer[Product] =
    new ProductSerializer(stateSchema)

  // Product is a case class without a nullary constructor, so reflective
  // instantiation would fail; return a placeholder instance instead.
  override def createInstance(): Product = Product("", None, None, Instant.EPOCH)

  override def snapshotConfiguration(): TypeSerializerSnapshot[Product] = new ProductSerializerSnapshot()
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/ProductAggregator.scala:
--------------------------------------------------------------------------------
package nl.mrooding

import nl.mrooding.data.{ProductDescription, ProductStock}
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.api.scala._
import nl.mrooding.source.{ProductDescriptionSource, ProductStockSource}

object ProductAggregator {
  private[this] val intervalMs = 1000

  def main(args: Array[String]): Unit = {
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    val productDescriptionStream: DataStream[ProductDescription] = env
      .addSource(new ProductDescriptionSource(intervalMs = intervalMs))
      .keyBy(_.id)
    val productStockStream: DataStream[ProductStock] = env
      .addSource(new ProductStockSource(intervalMs = intervalMs))
      .keyBy(_.id)

    productDescriptionStream
      .connect(productStockStream)
      .process(ProductProcessor())
      .print()
      .setParallelism(1)

    env.execute("Product aggregator")
  }
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/state/CustomAvroSerializerSnapshot.scala:
--------------------------------------------------------------------------------
package nl.mrooding.state

import org.apache.avro.Schema
import org.apache.flink.api.common.typeutils.{TypeSerializer, TypeSerializerSchemaCompatibility, TypeSerializerSnapshot}
import org.apache.flink.core.memory.{DataInputView, DataOutputView}

trait CustomAvroSerializerSnapshot[T] extends TypeSerializerSnapshot[T] {
  var stateSchema: Option[Schema]
  def getCurrentSchema: Schema

  override def getCurrentVersion: Int = 1

  // Persist the schema that the state was written with.
  override def writeSnapshot(out: DataOutputView): Unit = out.writeUTF(getCurrentSchema.toString(false))

  // On restore, the persisted schema becomes the writer schema for Avro schema resolution.
  override def readSnapshot(readVersion: Int, in: DataInputView, userCodeClassLoader: ClassLoader): Unit = {
    val previousSchemaDefinition = in.readUTF

    this.stateSchema = Some(parseAvroSchema(previousSchemaDefinition))
  }

  private def parseAvroSchema(previousSchemaDefinition: String): Schema = {
    new Schema.Parser().parse(previousSchemaDefinition)
  }

  // Schema differences are handled by the serializer itself via Avro schema
  // resolution, so the new serializer is always reported as compatible as-is.
  override def resolveSchemaCompatibility(newSerializer: TypeSerializer[T]): TypeSerializerSchemaCompatibility[T] =
    TypeSerializerSchemaCompatibility.compatibleAsIs()
}
--------------------------------------------------------------------------------
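
To see the snapshot's role in isolation: at savepoint time Flink calls `writeSnapshot`, and at restore time it calls `readSnapshot` followed by `restoreSerializer`. A rough sketch of that handshake, driven by hand with Flink's in-memory data views (normally Flink itself performs these calls, never user code):

```scala
import nl.mrooding.state.ProductSerializerSnapshot
import org.apache.flink.core.memory.{DataInputDeserializer, DataOutputSerializer}

// Savepoint: the current Avro schema is written into the snapshot bytes.
val out = new DataOutputSerializer(1024)
new ProductSerializerSnapshot().writeSnapshot(out)

// Restore: the persisted schema is read back and handed to a fresh serializer,
// which will use it as the writer schema when deserializing old state.
val restored = new ProductSerializerSnapshot()
restored.readSnapshot(1, new DataInputDeserializer(out.getCopyOfBuffer), Thread.currentThread().getContextClassLoader)
val serializerWithOldSchema = restored.restoreSerializer()
```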
/src/main/scala/nl/mrooding/data/Product.scala:
--------------------------------------------------------------------------------
package nl.mrooding.data

import java.time.Instant

import nl.mrooding.state.ProductSerializer
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.flink.api.common.typeutils.TypeSerializer

case class Product(
  id: String,
  description: Option[String],
  stock: Option[Long],
  updatedAt: Instant
) extends AvroGenericRecordWriter {

  def toGenericRecord: GenericRecord = {
    val genericRecord = new GenericData.Record(Product.getCurrentSchema)
    genericRecord.put("id", id)
    genericRecord.put("description", description.orNull)
    // Write null (not 0) for a missing stock so that None survives a round trip.
    genericRecord.put("stock", stock.map(Long.box).orNull)
    genericRecord.put("updatedAt", updatedAt.toEpochMilli)

    genericRecord
  }
}

object Product extends AvroSchema with AvroSerializable[Product] {
  val schemaPath: String = "/avro/product.avsc"

  val serializer: TypeSerializer[Product] = new ProductSerializer(None)

  def apply(record: GenericRecord): Product = {
    Product(
      id = record.get("id").toString,
      description = Option(record.get("description")).map(_.toString),
      stock = Option(record.get("stock")).map(_.asInstanceOf[Long]),
      updatedAt = Instant.ofEpochMilli(record.get("updatedAt").asInstanceOf[Long])
    )
  }
}
--------------------------------------------------------------------------------
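
`toGenericRecord` and `Product.apply(record)` are intended to be inverses; the serializer below depends on that. A quick sanity-check sketch (note the millisecond-precision timestamp, since `updatedAt` is stored as epoch millis):

```scala
import java.time.Instant
import nl.mrooding.data.Product

val original = Product(
  id = "1",
  description = Some("Gouda"),
  stock = None,
  updatedAt = Instant.ofEpochMilli(System.currentTimeMillis())
)

// Round-trip through the Avro GenericRecord representation.
assert(Product(original.toGenericRecord) == original)
```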
/src/main/scala/nl/mrooding/ProductProcessor.scala:
--------------------------------------------------------------------------------
package nl.mrooding

import nl.mrooding.data.{Product, ProductDescription, ProductStock}
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.util.Collector

case class ProductProcessor() extends CoProcessFunction[ProductDescription, ProductStock, Product] {
  private[this] lazy val stateDescriptor: ValueStateDescriptor[Product] =
    new ValueStateDescriptor[Product]("product-join", Product.serializer)
  private[this] lazy val state: ValueState[Product] = getRuntimeContext.getState(stateDescriptor)

  override def processElement1(value: ProductDescription, ctx: CoProcessFunction[ProductDescription, ProductStock, Product]#Context, out: Collector[Product]): Unit = {
    val product = Option(state.value()) match {
      case Some(stateProduct) =>
        stateProduct.copy(
          description = Some(value.description),
          updatedAt = value.updatedAt
        )
      case None =>
        Product(
          id = value.id,
          description = Some(value.description),
          stock = None,
          updatedAt = value.updatedAt)
    }

    state.update(product)
    out.collect(product)
  }

  override def processElement2(value: ProductStock, ctx: CoProcessFunction[ProductDescription, ProductStock, Product]#Context, out: Collector[Product]): Unit = {
    val product = Option(state.value()) match {
      case Some(stateProduct) =>
        stateProduct.copy(
          stock = Some(value.stock),
          updatedAt = value.updatedAt
        )
      case None =>
        Product(
          id = value.id,
          description = None,
          stock = Some(value.stock),
          updatedAt = value.updatedAt)
    }

    state.update(product)
    out.collect(product)
  }
}
--------------------------------------------------------------------------------
/src/main/scala/nl/mrooding/state/CustomAvroSerializer.scala:
--------------------------------------------------------------------------------
package nl.mrooding.state

import java.io.ByteArrayOutputStream

import nl.mrooding.data.AvroGenericRecordWriter
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}
import org.apache.flink.api.common.typeutils._
import org.apache.flink.core.memory.{DataInputView, DataOutputView}

trait CustomAvroSerializer[T <: AvroGenericRecordWriter] extends TypeSerializer[T] with Serializable {
  def stateSchema: Option[Schema]
  def getCurrentSchema: Schema

  def fromGenericRecord(genericRecord: GenericRecord): T

  override def serialize(instance: T, target: DataOutputView): Unit = {
    val genericRecord = instance.toGenericRecord

    val out = new ByteArrayOutputStream()
    val avroEncoder = EncoderFactory.get().binaryEncoder(out, null)
    new GenericDatumWriter[GenericRecord](genericRecord.getSchema).write(genericRecord, avroEncoder)
    avroEncoder.flush()

    val blob = out.toByteArray

    // Length-prefix the Avro blob so deserialize() knows how many bytes to consume.
    target.writeInt(blob.length)
    target.write(blob)
  }

  override def deserialize(source: DataInputView): T = {
    val blobSize = source.readInt()
    val blob = new Array[Byte](blobSize)
    source.readFully(blob)

    val decoder = DecoderFactory.get().binaryDecoder(blob, null)

    // If a previous schema was restored from a snapshot, let Avro resolve from
    // the old writer schema to the current reader schema; otherwise read as-is.
    val reader = stateSchema match {
      case Some(previousSchema) => new GenericDatumReader[GenericRecord](previousSchema, getCurrentSchema)
      case None                 => new GenericDatumReader[GenericRecord](getCurrentSchema)
    }
    val genericRecord = reader.read(null, decoder)

    fromGenericRecord(genericRecord)
  }

  /*
    Default functions required for TypeSerializer
   */

  override def equals(obj: scala.Any): Boolean = {
    obj match {
      case other: CustomAvroSerializer[_] => other.getClass == getClass && other.stateSchema == stateSchema
      case _                              => false
    }
  }

  override def hashCode(): Int = 31 * getClass.hashCode + stateSchema.hashCode

  override def isImmutableType: Boolean = false

  override def getLength: Int = -1

  override def copy(from: T): T = from

  override def copy(from: T, reuse: T): T = copy(from)

  override def copy(source: DataInputView, target: DataOutputView): Unit = serialize(deserialize(source), target)

  override def deserialize(reuse: T, source: DataInputView): T = deserialize(source)
}
--------------------------------------------------------------------------------
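
Putting it all together: the serializer writes state as a length-prefixed Avro binary blob and reads it back through the schema pair. A minimal round-trip sketch, driven by hand with Flink's in-memory data views (`SerializerRoundTrip` is a hypothetical harness, not part of this repository):

```scala
import java.time.Instant

import nl.mrooding.data.Product
import nl.mrooding.state.ProductSerializer
import org.apache.flink.core.memory.{DataInputDeserializer, DataOutputSerializer}

object SerializerRoundTrip extends App {
  // No restored schema: writer and reader schema are both the current one.
  val serializer = new ProductSerializer(None)
  val product = Product("1", Some("Gouda"), Some(3L), Instant.ofEpochMilli(System.currentTimeMillis()))

  val out = new DataOutputSerializer(256)
  serializer.serialize(product, out)

  val in = new DataInputDeserializer(out.getCopyOfBuffer)
  assert(serializer.deserialize(in) == product)

  // After a schema change, ProductSerializerSnapshot would instead build the
  // restoring serializer as new ProductSerializer(Some(oldSchema)).
}
```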