├── .gitignore ├── .travis.yml ├── LICENSE ├── Makefile ├── NOTICE ├── README.md ├── avro └── src │ ├── main │ └── scala │ │ ├── AvroRandomExtractor.scala │ │ ├── AvroRandomGenerator.scala │ │ ├── AvroToRow.scala │ │ ├── AvroToSchema.scala │ │ └── AvroTransformer.scala │ └── test │ └── scala │ └── test │ ├── AvroExtractorSpec.scala │ ├── AvroRandomGeneratorSpec.scala │ ├── AvroToRowSpec.scala │ ├── AvroToSchemaSpec.scala │ ├── AvroTransformerSpec.scala │ └── util │ ├── Fixtures.scala │ ├── LocalSparkContext.scala │ ├── TestLogger.scala │ └── UnitSpec.scala ├── build.sbt ├── kafka └── src │ └── main │ └── scala │ ├── CheckpointedDirectKafkaInputDStream.scala │ ├── CheckpointedKafkaUtils.scala │ └── CheckpointingKafkaExtractor.scala ├── project └── assembly.sbt ├── src ├── main │ └── scala │ │ └── com │ │ └── memsql │ │ └── streamliner │ │ └── examples │ │ ├── Extractors.scala │ │ ├── FileExtractor.scala │ │ ├── S3AccessLogsTransformer.scala │ │ └── Transformers.scala └── test │ ├── resources │ ├── log4j.properties │ └── tweets │ └── scala │ └── test │ ├── ExtractorsSpec.scala │ ├── TransformersSpec.scala │ └── util │ ├── LocalSparkContext.scala │ ├── TestLogger.scala │ └── UnitSpec.scala └── thrift └── src ├── main └── scala │ ├── ThriftRandomExtractor.scala │ ├── ThriftRandomGenerator.scala │ ├── ThriftToRow.scala │ ├── ThriftToRowSerializer.scala │ ├── ThriftToSchema.scala │ └── ThriftTransformer.scala └── test ├── scala ├── ThriftRandomGeneratorSpec.scala ├── ThriftToRowSpec.scala └── ThriftToSchemaSpec.scala └── thrift └── TestClass.thrift /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | .idea/ 3 | thrift/src/test/java/* 4 | avro/src/test/java/* 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.10.5 4 | 5 | # we need to install thrift and compile the test schema 6 | # from 7 | before_install: 8 | - sudo apt-get update -qq 9 | - sudo apt-get install libboost-dev libboost-test-dev libboost-program-options-dev libevent-dev automake libtool flex bison pkg-config g++ libssl-dev ant 10 | - wget http://www.us.apache.org/dist/thrift/0.9.3/thrift-0.9.3.tar.gz 11 | - tar xfz thrift-0.9.3.tar.gz 12 | - cd thrift-0.9.3 && ./configure --without-ruby && sudo make install 13 | - cd $TRAVIS_BUILD_DIR && mkdir -p thrift/src/test/java/ && thrift -o thrift/src/test/java/ --gen java thrift/src/test/thrift/TestClass.thrift 14 | 15 | script: 16 | - sbt test 17 | - sbt "project avro" test 18 | - sbt "project thrift" test 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2015 MemSQL (http://www.memsql.com) 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | VERSION := $(shell sbt 'export version' | tail -n 1) 4 | export VERSION 5 | 6 | .PHONY: default 7 | default: build 8 | 9 | .PHONY: version 10 | version: 11 | @echo $(VERSION) 12 | 13 | .PHONY: clean 14 | clean: 15 | sbt clean \ 16 | "project thrift" clean \ 17 | "project avro" clean 18 | 19 | ############################## 20 | # PROJECT BUILD RULES 21 | # 22 | .PHONY: build 23 | build: clean 24 | sbt assembly 25 | 26 | .PHONY: build-thrift 27 | build-thrift: clean 28 | sbt "project thrift" assembly 29 | 30 | .PHONY: build-avro 31 | build-avro: clean 32 | sbt "project avro" assembly 33 | 34 | ############################## 35 | # PROJECT TEST RULES 36 | # 37 | .PHONY: test 38 | test: 39 | sbt test 40 | 41 | .PHONY: thrift-test-deps-compile 42 | thrift-test-deps-compile: 43 | mkdir -p thrift/src/test/java 44 | thrift -o thrift/src/test/java/ --gen java thrift/src/test/thrift/TestClass.thrift 45 | 46 | .PHONY: test-thrift 47 | test-thrift: thrift-test-deps-compile 48 | sbt "project thrift" test 49 | 50 | .PHONY: test-avro 51 | test-avro: 52 | sbt "project avro" test 53 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | MemSQL Spark Streamliner Examples 2 | Copyright 2015 MemSQL (http://www.memsql.com). 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | MemSQL Spark Streamliner Examples 2 | ================================= 3 | **MemSQL Streamliner is a deprecated feature, and will be removed in MemSQL 6.0** 4 | 5 | [![Build Status](https://travis-ci.org/memsql/streamliner-examples.svg?branch=master)](https://travis-ci.org/memsql/streamliner-examples) 6 | 7 | This is a repository featuring example code for the [MemSQL Spark Streamliner](http://docs.memsql.com/latest/spark/). 8 | 9 | MemSQL Spark Streamliner lets you build custom Spark pipelines to: 10 | 1. extract from real-time data sources such as Kafka, 11 | 2. transform data structures such as CSV, JSON, or Thrift into table rows, 12 | 3. load your data into MemSQL. 13 | 14 | Check out: 15 | 16 | - [Examples of Extractors](./src/main/scala/com/memsql/streamliner/examples/Extractors.scala) 17 | - [Examples of Transformers](./src/main/scala/com/memsql/streamliner/examples/Transformers.scala) 18 | - ... and browse the code for more 19 | 20 | 21 | Get Started with MemSQL Spark Streamliner 22 | ----------------------------------------- 23 | 24 | Check out the [MemSQL Spark Streamliner Starter](https://github.com/memsql/streamliner-starter) repository. 25 | 26 | Or read more on how to [create custom Spark Interface JARs](http://docs.memsql.com/latest/spark/memsql-spark-interface/) in our docs.
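For a quick look at what a custom phase looks like, here is the simplest Extractor in this repository, reproduced from [Extractors.scala](./src/main/scala/com/memsql/streamliner/examples/Extractors.scala) (imports omitted); it emits the same five-row, single-column DataFrame on every batch interval:

```scala
// The simplest implementation of an Extractor just provides a next method.
// This is useful for prototyping and debugging.
class ConstantExtractor extends Extractor {
  override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long,
      logger: PhaseLogger): Option[DataFrame] = {
    logger.info("extracting a constant sequence DataFrame")

    // A one-column schema: a non-nullable IntegerType column named "number".
    val schema = StructType(StructField("number", IntegerType, false) :: Nil)

    // Parallelize a small constant sequence and wrap it in a DataFrame.
    val sampleData = List(1, 2, 3, 4, 5)
    val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_))

    val df = sqlContext.createDataFrame(rowRDD, schema)
    Some(df)
  }
}
```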
27 | 28 | 29 | Contribute to MemSQL Spark Streamliner Examples 30 | ----------------------------------------------- 31 | 32 | Please submit a pull request with new Extractors and Transformers. 33 | 34 | When you contribute code, you affirm that the contribution is your original work and that you license the work to the project under the project's open source license. Whether or not you state this explicitly, by submitting any copyrighted material via pull request, email, or other means you agree to license the material under the project's open source license and warrant that you have the legal authority to do so. 35 | 36 | 37 | Build the Examples 38 | ------------------ 39 | 40 | Clone the repository, then run: 41 | 42 | ```bash 43 | make build 44 | ``` 45 | 46 | The JAR will be placed in `target/scala-2.10/`. You can upload the JAR to MemSQL Ops and create a pipeline using this or your custom code. 47 | 48 | 49 | Run Tests 50 | --------- 51 | 52 | Run: 53 | 54 | ```bash 55 | make test 56 | ``` 57 | -------------------------------------------------------------------------------- /avro/src/main/scala/AvroRandomExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import com.memsql.spark.etl.api._ 4 | import com.memsql.spark.etl.utils.PhaseLogger 5 | import org.apache.spark.streaming.StreamingContext 6 | import org.apache.spark.sql.{SQLContext, DataFrame, Row} 7 | import org.apache.spark.sql.types._ 8 | import org.apache.avro.Schema 9 | import org.apache.avro.generic.GenericData 10 | import org.apache.avro.io.{DatumWriter, EncoderFactory} 11 | import org.apache.avro.specific.SpecificDatumWriter 12 | 13 | import java.io.ByteArrayOutputStream 14 | 15 | // Generates an RDD of byte arrays, where each is a serialized Avro record.
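// The extractor reads two values from the pipeline config: "avroSchema", the Avro schema for the
// generated records (as a JSON value), and "count", the number of records to emit per batch (default 1).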
16 | class AvroRandomExtractor extends Extractor { 17 | var count: Int = 1 18 | var generator: AvroRandomGenerator = null 19 | var writer: DatumWriter[GenericData.Record] = null 20 | var avroSchema: Schema = null 21 | 22 | def schema: StructType = StructType(StructField("bytes", BinaryType, false) :: Nil) 23 | 24 | val parser: Schema.Parser = new Schema.Parser() 25 | 26 | override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = { 27 | val userConfig = config.asInstanceOf[UserExtractConfig] 28 | val avroSchemaJson = userConfig.getConfigJsValue("avroSchema") match { 29 | case Some(s) => s 30 | case None => throw new IllegalArgumentException("avroSchema must be set in the config") 31 | } 32 | count = userConfig.getConfigInt("count").getOrElse(1) 33 | avroSchema = parser.parse(avroSchemaJson.toString) 34 | 35 | writer = new SpecificDatumWriter(avroSchema) 36 | generator = new AvroRandomGenerator(avroSchema) 37 | } 38 | 39 | override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Option[DataFrame] = { 40 | val rdd = sqlContext.sparkContext.parallelize((1 to count).map(_ => Row({ 41 | val out = new ByteArrayOutputStream 42 | val encoder = EncoderFactory.get().binaryEncoder(out, null) 43 | val avroRecord: GenericData.Record = generator.next().asInstanceOf[GenericData.Record] 44 | 45 | writer.write(avroRecord, encoder) 46 | encoder.flush 47 | out.close 48 | out.toByteArray 49 | }))) 50 | 51 | Some(sqlContext.createDataFrame(rdd, schema)) 52 | } 53 | } 54 | 55 | -------------------------------------------------------------------------------- /avro/src/main/scala/AvroRandomGenerator.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import org.apache.avro.Schema 4 | import org.apache.avro.generic.GenericData 5 | 6 | import scala.collection.JavaConversions._ 7 | import scala.util.Random 8 | 9 | class AvroRandomGenerator(inSchema: Schema) { 10 | // Avoid nested Records, since our destination is a DataFrame. 11 | val MAX_RECURSION_LEVEL: Int = 1 12 | 13 | val topSchema: Schema = inSchema 14 | val random = new Random 15 | 16 | def next(schema: Schema = this.topSchema, level: Int = 0): Any = { 17 | if (level <= MAX_RECURSION_LEVEL) { 18 | 19 | schema.getType match { 20 | case Schema.Type.RECORD => { 21 | val datum = new GenericData.Record(schema) 22 | schema.getFields.foreach { 23 | x => datum.put(x.pos, next(x.schema, level + 1)) 24 | } 25 | datum 26 | } 27 | 28 | case Schema.Type.UNION => { 29 | val types = schema.getTypes 30 | // Generate a value using the first type in the union. 31 | // "Random type" is also a valid option. 
32 | next(types(0), level) 33 | } 34 | 35 | case _ => generateValue(schema.getType) 36 | } 37 | 38 | } else { 39 | null 40 | } 41 | } 42 | 43 | def generateValue(avroType: Schema.Type): Any = avroType match { 44 | case Schema.Type.BOOLEAN => random.nextBoolean 45 | case Schema.Type.DOUBLE => random.nextDouble 46 | case Schema.Type.FLOAT => random.nextFloat 47 | case Schema.Type.INT => random.nextInt 48 | case Schema.Type.LONG => random.nextLong 49 | case Schema.Type.NULL => null 50 | case Schema.Type.STRING => getRandomString 51 | case _ => null 52 | } 53 | 54 | def getRandomString(): String = { 55 | val length: Int = 5 + random.nextInt(5) 56 | (1 to length).map(x => ('a'.toInt + random.nextInt(26)).toChar).mkString 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /avro/src/main/scala/AvroToRow.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import collection.JavaConversions._ 4 | import org.apache.spark.sql.Row 5 | import org.apache.avro.Schema 6 | import org.apache.avro.generic.GenericData 7 | 8 | // Converts an Avro record to a Spark DataFrame row. 9 | // 10 | // This assumes that the Avro schema is "flat", i.e. a Record that includes primitive types 11 | // or unions of primitive types. Unions, and Avro types that don't directly map to Scala types, 12 | // are converted to Strings and put in a Spark SQL StringType column. 13 | private class AvroToRow { 14 | def getRow(record: GenericData.Record): Row = { 15 | Row.fromSeq(record.getSchema.getFields().map(f => { 16 | val schema = f.schema() 17 | val obj = record.get(f.pos) 18 | 19 | schema.getType match { 20 | case Schema.Type.BOOLEAN => obj.asInstanceOf[Boolean] 21 | case Schema.Type.DOUBLE => obj.asInstanceOf[Double] 22 | case Schema.Type.FLOAT => obj.asInstanceOf[Float] 23 | case Schema.Type.INT => obj.asInstanceOf[Int] 24 | case Schema.Type.LONG => obj.asInstanceOf[Long] 25 | case Schema.Type.NULL => null 26 | 27 | case _ => obj.toString 28 | } 29 | })) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /avro/src/main/scala/AvroToSchema.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import collection.JavaConversions._ 4 | import org.apache.spark.sql.types._ 5 | import org.apache.avro.Schema 6 | 7 | // Converts an Avro schema to a Spark DataFrame schema. 8 | // 9 | // This assumes that the Avro schema is "flat", i.e. a Record that includes primitive types 10 | // or unions of primitive types. Unions, and Avro types that don't directly map to Scala types, 11 | // are converted to Strings and put in a Spark SQL StringType column. 
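// For example, under this assumption an Avro field declared as "int" maps to a nullable IntegerType
// column, while a union such as ["int", "string", "null"] maps to a nullable StringType column.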
12 | private object AvroToSchema { 13 | def getSchema(schema: Schema): StructType = { 14 | StructType(schema.getFields.map(field => { 15 | val fieldName = field.name 16 | val fieldSchema = field.schema 17 | val fieldType = fieldSchema.getType match { 18 | case Schema.Type.BOOLEAN => BooleanType 19 | case Schema.Type.DOUBLE => DoubleType 20 | case Schema.Type.FLOAT => FloatType 21 | case Schema.Type.INT => IntegerType 22 | case Schema.Type.LONG => LongType 23 | case Schema.Type.NULL => NullType 24 | case Schema.Type.STRING => StringType 25 | case _ => StringType 26 | } 27 | StructField(fieldName, fieldType.asInstanceOf[DataType], true) 28 | })) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /avro/src/main/scala/AvroTransformer.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import com.memsql.spark.etl.api.{UserTransformConfig, Transformer, PhaseConfig} 4 | import com.memsql.spark.etl.utils.PhaseLogger 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.sql.{SQLContext, DataFrame, Row} 7 | import org.apache.spark.sql.types.StructType 8 | 9 | import org.apache.avro.Schema 10 | import org.apache.avro.generic.GenericData 11 | import org.apache.avro.io.DecoderFactory 12 | import org.apache.avro.specific.SpecificDatumReader 13 | 14 | // Takes DataFrames of byte arrays, where each row is a serialized Avro record. 15 | // Returns DataFrames of deserialized data, where each field has its own column. 16 | class AvroTransformer extends Transformer { 17 | var avroSchemaStr: String = null 18 | var sparkSqlSchema: StructType = null 19 | 20 | def AvroRDDToDataFrame(sqlContext: SQLContext, rdd: RDD[Row]): DataFrame = { 21 | 22 | val rowRDD: RDD[Row] = rdd.mapPartitions({ partition => { 23 | // Create per-partition copies of non-serializable objects 24 | val parser: Schema.Parser = new Schema.Parser() 25 | val avroSchema = parser.parse(avroSchemaStr) 26 | val reader = new SpecificDatumReader[GenericData.Record](avroSchema) 27 | 28 | partition.map({ rowOfBytes => 29 | val bytes = rowOfBytes(0).asInstanceOf[Array[Byte]] 30 | val decoder = DecoderFactory.get().binaryDecoder(bytes, null) 31 | val record = reader.read(null, decoder) 32 | val avroToRow = new AvroToRow() 33 | 34 | avroToRow.getRow(record) 35 | }) 36 | }}) 37 | sqlContext.createDataFrame(rowRDD, sparkSqlSchema) 38 | } 39 | 40 | override def initialize(sqlContext: SQLContext, config: PhaseConfig, logger: PhaseLogger): Unit = { 41 | val userConfig = config.asInstanceOf[UserTransformConfig] 42 | 43 | val avroSchemaJson = userConfig.getConfigJsValue("avroSchema") match { 44 | case Some(s) => s 45 | case None => throw new IllegalArgumentException("avroSchema must be set in the config") 46 | } 47 | avroSchemaStr = avroSchemaJson.toString 48 | 49 | val parser = new Schema.Parser() 50 | val avroSchema = parser.parse(avroSchemaJson.toString) 51 | sparkSqlSchema = AvroToSchema.getSchema(avroSchema) 52 | } 53 | 54 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 55 | AvroRDDToDataFrame(sqlContext, df.rdd) 56 | } 57 | } 58 | 59 | 60 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/AvroExtractorSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import 
com.memsql.spark.etl.api.UserExtractConfig 4 | import org.apache.spark.streaming._ 5 | import org.apache.spark.sql.SQLContext 6 | import test.util.{Fixtures, UnitSpec, TestLogger, LocalSparkContext} 7 | import spray.json._ 8 | 9 | class ExtractorsSpec extends UnitSpec with LocalSparkContext { 10 | var ssc: StreamingContext = _ 11 | var sqlContext: SQLContext = _ 12 | 13 | override def beforeEach(): Unit = { 14 | super.beforeEach() 15 | ssc = new StreamingContext(sc, Seconds(1)) 16 | sqlContext = new SQLContext(sc) 17 | } 18 | 19 | val avroConfig = Fixtures.avroConfig.parseJson 20 | val extractConfig = UserExtractConfig(class_name = "Test", value = avroConfig) 21 | val logger = new TestLogger("test") 22 | 23 | "AvroRandomExtractor" should "emit a random DF" in { 24 | val extract = new AvroRandomExtractor 25 | extract.initialize(ssc, sqlContext, extractConfig, 1, logger) 26 | 27 | val maybeDf = extract.next(ssc, 1, sqlContext, extractConfig, 1, logger) 28 | assert(maybeDf.isDefined) 29 | assert(maybeDf.get.count == 5) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/AvroRandomGeneratorSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import org.scalatest._ 4 | import org.apache.avro.Schema 5 | import org.apache.avro.generic.GenericData 6 | import test.util.Fixtures 7 | 8 | class AvroRandomGeneratorSpec extends FlatSpec { 9 | "AvroRandomGenerator" should "create Avro objects with random values" in { 10 | val schema = new Schema.Parser().parse(Fixtures.avroSchema) 11 | val avroRecord:GenericData.Record = new AvroRandomGenerator(schema).next().asInstanceOf[GenericData.Record] 12 | 13 | assert(avroRecord.get("testBool").isInstanceOf[Boolean]) 14 | assert(avroRecord.get("testDouble").isInstanceOf[Double]) 15 | assert(avroRecord.get("testFloat").isInstanceOf[Float]) 16 | assert(avroRecord.get("testInt").isInstanceOf[Int]) 17 | assert(avroRecord.get("testLong").isInstanceOf[Long]) 18 | assert(avroRecord.get("testNull") == null) 19 | assert(avroRecord.get("testString").isInstanceOf[String]) 20 | assert(avroRecord.get("testUnion").isInstanceOf[Int]) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/AvroToRowSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import com.memsql.spark.connector.dataframe.JsonValue 4 | 5 | import org.apache.avro.Schema 6 | import org.apache.avro.generic.GenericData 7 | import org.apache.spark.sql.Row 8 | import test.util.Fixtures 9 | 10 | import collection.JavaConversions._ 11 | import java.nio.ByteBuffer 12 | import org.scalatest._ 13 | 14 | class AvroToRowSpec extends FlatSpec { 15 | "AvroToRow" should "create Spark SQL Rows from Avro objects" in { 16 | val parser: Schema.Parser = new Schema.Parser() 17 | val avroTestSchema: Schema = parser.parse(Fixtures.avroSchema) 18 | 19 | val record: GenericData.Record = new GenericData.Record(avroTestSchema) 20 | 21 | record.put("testBool", true) 22 | record.put("testDouble", 19.88) 23 | record.put("testFloat", 3.19f) 24 | record.put("testInt", 1123) 25 | record.put("testLong", 2147483648L) 26 | record.put("testNull", null) 27 | record.put("testString", "Conor") 28 | record.put("testUnion", 17) 29 | 30 | val row: Row = new AvroToRow().getRow(record) 31 | 32 | assert(row.getAs[Boolean](0)) 33 | 
assert(row.getAs[Double](1) == 19.88) 34 | assert(row.getAs[Float](2) == 3.19f) 35 | assert(row.getAs[Int](3) == 1123) 36 | assert(row.getAs[Long](4) == 2147483648L) 37 | assert(row.getAs[Null](5) == null) 38 | assert(row.getAs[String](6) == "Conor") 39 | assert(row.getAs[String](7) == "17") 40 | } 41 | } 42 | 43 | 44 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/AvroToSchemaSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.avro 2 | 3 | import com.memsql.spark.connector.dataframe.JsonType 4 | import org.apache.spark.sql.types._ 5 | import org.apache.avro.Schema 6 | import org.scalatest._ 7 | import test.util.Fixtures 8 | 9 | class AvroToSchemaSpec extends FlatSpec { 10 | "AvroToSchema" should "create a Spark SQL schema from an Avro schema" in { 11 | val parser = new Schema.Parser() 12 | val avroSchema = parser.parse(Fixtures.avroSchema) 13 | val sparkSchema = AvroToSchema.getSchema(avroSchema) 14 | val fields = sparkSchema.fields 15 | 16 | assert(fields.forall(field => field.nullable)) 17 | assert(fields(0).name == "testBool") 18 | assert(fields(0).dataType == BooleanType) 19 | 20 | assert(fields(1).name == "testDouble") 21 | assert(fields(1).dataType == DoubleType) 22 | 23 | assert(fields(2).name == "testFloat") 24 | assert(fields(2).dataType == FloatType) 25 | 26 | assert(fields(3).name == "testInt") 27 | assert(fields(3).dataType == IntegerType) 28 | 29 | assert(fields(4).name == "testLong") 30 | assert(fields(4).dataType == LongType) 31 | 32 | assert(fields(5).name == "testNull") 33 | assert(fields(5).dataType == NullType) 34 | 35 | assert(fields(6).name == "testString") 36 | assert(fields(6).dataType == StringType) 37 | 38 | assert(fields(7).name == "testUnion") 39 | assert(fields(7).dataType == StringType) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/AvroTransformerSpec.scala: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import com.memsql.spark.connector.MemSQLContext 4 | import com.memsql.spark.etl.api.{UserTransformConfig, UserExtractConfig} 5 | import com.memsql.spark.examples.avro.{AvroTransformer, AvroRandomExtractor} 6 | import org.apache.spark.streaming.{StreamingContext, Seconds} 7 | import test.util.{Fixtures, UnitSpec, LocalSparkContext} 8 | import spray.json._ 9 | 10 | class AvroTransformerSpec extends UnitSpec with LocalSparkContext { 11 | var ssc: StreamingContext = _ 12 | var msc: MemSQLContext = _ 13 | 14 | override def beforeEach(): Unit = { 15 | super.beforeEach() 16 | ssc = new StreamingContext(sc, Seconds(1)) 17 | msc = new MemSQLContext(sc) 18 | } 19 | 20 | val avroConfig = Fixtures.avroConfig.parseJson 21 | val extractConfig = UserExtractConfig(class_name = "Test", value = avroConfig) 22 | val transformConfig = UserTransformConfig(class_name = "Test", value = avroConfig) 23 | 24 | "AvroRandomTransformer" should "emit a dataframe of properly deserialized data" in { 25 | val extractor = new AvroRandomExtractor 26 | val transformer = new AvroTransformer 27 | 28 | extractor.initialize(null, null, extractConfig, 0, null) 29 | transformer.initialize(null, transformConfig, null) 30 | 31 | val maybeDf = extractor.next(null, 0, msc, null, 0, null) 32 | assert(maybeDf.isDefined) 33 | val extractedDf = maybeDf.get 34 | 35 | val transformedDf = transformer.transform(msc, extractedDf, null, 
null) 36 | 37 | val rows = transformedDf.collect() 38 | for (row <- rows) { 39 | assert(row(0).isInstanceOf[Boolean]) 40 | assert(row(1).isInstanceOf[Double]) 41 | assert(row(2).isInstanceOf[Float]) 42 | assert(row(3).isInstanceOf[Int]) 43 | assert(row(4).isInstanceOf[Long]) 44 | assert(row(5) === null) 45 | assert(row(6).isInstanceOf[String]) 46 | assert(row(7).isInstanceOf[String]) 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/util/Fixtures.scala: -------------------------------------------------------------------------------- 1 | package test.util 2 | 3 | object Fixtures { 4 | 5 | val avroSchema = s""" 6 | { 7 | "namespace": "com.memsql.spark.examples.avro", 8 | "type": "record", 9 | "name": "TestSchema", 10 | "fields": [ 11 | { 12 | "name": "testBool", 13 | "type": "boolean" 14 | }, 15 | { 16 | "name": "testDouble", 17 | "type": "double" 18 | }, 19 | { 20 | "name": "testFloat", 21 | "type": "float" 22 | }, 23 | { 24 | "name": "testInt", 25 | "type": "int" 26 | }, 27 | { 28 | "name": "testLong", 29 | "type": "long" 30 | }, 31 | { 32 | "name": "testNull", 33 | "type": "null" 34 | }, 35 | { 36 | "name": "testString", 37 | "type": "string" 38 | }, 39 | { 40 | "name": "testUnion", 41 | "type": [ 42 | "int", 43 | "string", 44 | "null" 45 | ] 46 | } 47 | ] 48 | } 49 | """ 50 | 51 | 52 | val avroConfig = s""" 53 | { 54 | "count": 5, 55 | "avroSchema": $avroSchema 56 | } 57 | """ 58 | } 59 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/util/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | * 17 | * With small modifications by MemSQL 18 | */ 19 | 20 | package test.util 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | import org.scalatest.BeforeAndAfterEach 24 | import org.scalatest._ 25 | 26 | trait LocalSparkContext extends BeforeAndAfterEach { self: Suite => 27 | 28 | @transient private var _sc: SparkContext = _ 29 | 30 | val _sparkConf = new SparkConf(false) 31 | .set("spark.ui.showConsoleProgress", "false") 32 | 33 | def sc: SparkContext = _sc 34 | 35 | override def beforeEach() { 36 | _sc = new SparkContext("local[4]", "test", _sparkConf) 37 | super.beforeEach() 38 | } 39 | 40 | override def afterEach() { 41 | resetSparkContext() 42 | super.afterEach() 43 | } 44 | 45 | def resetSparkContext(): Unit = { 46 | LocalSparkContext.stop(_sc) 47 | _sc = null 48 | } 49 | 50 | } 51 | 52 | object LocalSparkContext { 53 | def stop(sc: SparkContext) { 54 | if (sc != null) { 55 | sc.stop() 56 | } 57 | // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown 58 | System.clearProperty("spark.driver.port") 59 | } 60 | 61 | /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ 62 | def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { 63 | try { 64 | f(sc) 65 | } finally { 66 | stop(sc) 67 | } 68 | } 69 | 70 | } 71 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/util/TestLogger.scala: -------------------------------------------------------------------------------- 1 | package test.util 2 | 3 | import com.memsql.spark.etl.utils.PhaseLogger 4 | import org.apache.log4j.Logger 5 | 6 | class TestLogger(override val name: String) extends PhaseLogger { 7 | override protected val logger: Logger = Logger.getRootLogger 8 | } 9 | 10 | -------------------------------------------------------------------------------- /avro/src/test/scala/test/util/UnitSpec.scala: -------------------------------------------------------------------------------- 1 | package test.util 2 | 3 | import org.scalatest._ 4 | 5 | abstract class UnitSpec 6 | extends FlatSpec 7 | with Matchers 8 | with OptionValues 9 | with Inside 10 | with Inspectors 11 | with BeforeAndAfter 12 | with BeforeAndAfterEach 13 | with BeforeAndAfterAll 14 | with OneInstancePerTest { 15 | } 16 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | lazy val commonSettings = Seq( 2 | organization := "com.memsql", 3 | version := "0.0.1", 4 | scalaVersion := "2.10.5" 5 | ) 6 | 7 | lazy val avro = (project in file("avro")). 8 | settings(commonSettings: _*). 9 | settings( 10 | name := "memsql-spark-streamliner-avro-examples", 11 | parallelExecution in Test := false, 12 | libraryDependencies ++= { 13 | Seq( 14 | "org.apache.spark" %% "spark-core" % "1.5.2" % "provided", 15 | "org.apache.spark" %% "spark-streaming" % "1.5.2" % "provided", 16 | "org.apache.spark" %% "spark-sql" % "1.5.2" % "provided", 17 | "org.apache.avro" % "avro" % "1.7.7", 18 | "org.scalatest" %% "scalatest" % "2.2.5" % "test", 19 | "com.memsql" %% "memsql-etl" % "1.3.3" 20 | ) 21 | } 22 | ) 23 | 24 | lazy val thrift = (project in file("thrift")). 25 | settings(commonSettings: _*). 
26 | settings( 27 | name := "memsql-spark-streamliner-thrift-examples", 28 | parallelExecution in Test := false, 29 | libraryDependencies ++= { 30 | Seq( 31 | "org.apache.spark" %% "spark-core" % "1.5.2" % "provided", 32 | "org.apache.spark" %% "spark-streaming" % "1.5.2" % "provided", 33 | "org.apache.spark" %% "spark-sql" % "1.5.2" % "provided", 34 | "org.apache.thrift" % "libthrift" % "0.9.2", 35 | "org.scalatest" %% "scalatest" % "2.2.5" % "test", 36 | "com.memsql" %% "memsql-etl" % "1.3.3" 37 | ) 38 | } 39 | ) 40 | 41 | lazy val kafka = (project in file("kafka")). 42 | settings(commonSettings: _*). 43 | settings( 44 | name := "memsql-spark-streamliner-kafka-examples", 45 | parallelExecution in Test := false, 46 | libraryDependencies ++= { 47 | Seq( 48 | "org.apache.spark" %% "spark-core" % "1.5.2" % "provided", 49 | "org.apache.spark" %% "spark-streaming" % "1.5.2" % "provided", 50 | "org.apache.spark" %% "spark-sql" % "1.5.2" % "provided", 51 | "org.apache.spark" %% "spark-streaming-kafka" % "1.5.2" exclude("org.spark-project.spark", "unused"), 52 | "org.scalatest" %% "scalatest" % "2.2.5" % "test", 53 | "com.memsql" %% "memsql-etl" % "1.3.3" 54 | ) 55 | } 56 | ) 57 | 58 | lazy val root = (project in file(".")). 59 | dependsOn(kafka). 60 | dependsOn(avro). 61 | dependsOn(thrift). 62 | settings(commonSettings: _*). 63 | settings( 64 | name := "memsql-spark-streamliner-examples", 65 | parallelExecution in Test := false, 66 | libraryDependencies ++= Seq( 67 | "org.apache.spark" %% "spark-core" % "1.5.2" % "provided", 68 | "org.apache.spark" %% "spark-sql" % "1.5.2" % "provided", 69 | "org.apache.spark" %% "spark-streaming" % "1.5.2" % "provided", 70 | "org.scalatest" %% "scalatest" % "2.2.5" % "test", 71 | "com.memsql" %% "memsql-etl" % "1.3.3" 72 | ) 73 | ) 74 | -------------------------------------------------------------------------------- /kafka/src/main/scala/CheckpointedDirectKafkaInputDStream.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.streaming.kafka 2 | 3 | /* 4 | * Licensed to the Apache Software Foundation (ASF) under one or more 5 | * contributor license agreements. See the NOTICE file distributed with 6 | * this work for additional information regarding copyright ownership. 7 | * The ASF licenses this file to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 
18 | * 19 | * Modified for MemSQL Streamliner 20 | */ 21 | 22 | import kafka.common.TopicAndPartition 23 | import kafka.message.MessageAndMetadata 24 | import kafka.serializer.Decoder 25 | import org.apache.spark.streaming.{Time, StreamingContext} 26 | 27 | import scala.reflect.ClassTag 28 | 29 | class CheckpointedDirectKafkaInputDStream[K: ClassTag, V: ClassTag, U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( 30 | @transient ssc_ : StreamingContext, 31 | override val kafkaParams: Map[String, String], 32 | override val fromOffsets: Map[TopicAndPartition, Long], 33 | messageHandler: MessageAndMetadata[K, V] => R, 34 | batchDuration: Long) extends DirectKafkaInputDStream[K, V, U, T, R](ssc_, kafkaParams, fromOffsets, messageHandler) { 35 | 36 | override val checkpointData = null 37 | 38 | //NOTE: We override this to use the pipeline specific batch duration 39 | override val maxMessagesPerPartition: Option[Long] = { 40 | val ratePerSec = context.sparkContext.getConf.getInt( 41 | "spark.streaming.kafka.maxRatePerPartition", 0) 42 | if (ratePerSec > 0) { 43 | val secsPerBatch = batchDuration / 1000 44 | Some(secsPerBatch * ratePerSec) 45 | } else { 46 | None 47 | } 48 | } 49 | 50 | //Track the previous batch's offsets so we can retry the batch if it fails 51 | var prevOffsets: Map[TopicAndPartition, Long] = null 52 | 53 | //NOTE: We override this to suppress input info tracking because the StreamingContext has not been started. 54 | override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = { 55 | val untilOffsets = clamp(latestLeaderOffsets(maxRetries)) 56 | val rdd = KafkaRDD[K, V, U, T, R]( 57 | context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) 58 | 59 | // The vanilla implementation calls the inputInfoTracker here. This code is left as a comment here for clarity. 60 | /* 61 | * Report the record number of this batch interval to InputInfoTracker. 62 | * val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum 63 | * val inputInfo = InputInfo(id, numRecords) 64 | * ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) 65 | */ 66 | 67 | prevOffsets = currentOffsets 68 | currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) 69 | 70 | prevOffsets == currentOffsets match { 71 | case false => Some(rdd) 72 | case true => None 73 | } 74 | } 75 | 76 | def getCurrentOffsets(): Map[TopicAndPartition, Long] = currentOffsets 77 | def setCurrentOffsets(offsets: Map[TopicAndPartition, Long]): Unit = { 78 | currentOffsets = offsets 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /kafka/src/main/scala/CheckpointedKafkaUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.streaming.kafka 2 | 3 | /* 4 | * Licensed to the Apache Software Foundation (ASF) under one or more 5 | * contributor license agreements. See the NOTICE file distributed with 6 | * this work for additional information regarding copyright ownership. 7 | * The ASF licenses this file to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. 
You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | * 19 | * Modified for MemSQL Streamliner examples 20 | */ 21 | 22 | import com.memsql.spark.etl.utils.Logging 23 | import kafka.common.TopicAndPartition 24 | import kafka.utils.{ZkUtils, ZKStringSerializer} 25 | import org.I0Itec.zkclient.ZkClient 26 | 27 | import scala.reflect.ClassTag 28 | 29 | import kafka.message.MessageAndMetadata 30 | import kafka.serializer.Decoder 31 | 32 | import org.apache.spark.SparkException 33 | import org.apache.spark.annotation.Experimental 34 | import org.apache.spark.streaming.StreamingContext 35 | 36 | import scala.util.control.NonFatal 37 | 38 | class KafkaException(message: String) extends Exception(message) 39 | 40 | object CheckpointedKafkaUtils extends Logging { 41 | val ZK_SESSION_TIMEOUT = Int.MaxValue //milliseconds 42 | val ZK_CONNECT_TIMEOUT = 10000 //milliseconds 43 | 44 | /** 45 | * :: Experimental :: 46 | * Create an input stream that directly pulls messages from Kafka Brokers 47 | * without using any receiver. This stream can guarantee that each message 48 | * from Kafka is included in transformations exactly once (see points below). 49 | * NOTE: Modified to use Zookeeper quorum for MemSQL Streamliner. 50 | * 51 | * Points to note: 52 | * - No receivers: This stream does not use any receiver. It directly queries Kafka 53 | * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked 54 | * by the stream itself. For interoperability with Kafka monitoring tools that depend on 55 | * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. 56 | * You can access the offsets used in each batch from the generated RDDs (see 57 | * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). 58 | * - Failure Recovery: To recover from driver failures, you have to enable checkpointing 59 | * in the [[StreamingContext]]. The information on consumed offset can be 60 | * recovered from the checkpoint. See the programming guide for details (constraints, etc.). 61 | * - End-to-end semantics: This stream ensures that every records is effectively received and 62 | * transformed exactly once, but gives no guarantees on whether the transformed data are 63 | * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure 64 | * that the output operation is idempotent, or use transactions to output records atomically. 65 | * See the programming guide for more details. 66 | * 67 | * @param ssc StreamingContext object 68 | * @param kafkaParams Kafka 69 | * configuration parameters. Requires "memsql.zookeeper.connect" to be set with Zookeeper servers, 70 | * specified in host1:port1,host2:port2/chroot2 form. 71 | * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" 72 | * to determine where the stream starts (defaults to "largest") 73 | * @param topics Names of the topics to consume 74 | * @param batchInterval Batch interval for this pipeline. NOTE: Modified for MemSQL Streamliner 75 | * @param lastCheckpoint Offsets to use when initializing the consumer. 
If the topic, partition count, 76 | * or offsets from the checkpoint are invalid, fall back to the offsets specified by "auto.offset.reset". 77 | * NOTE: Modified for MemSQL Streamliner 78 | * 79 | */ 80 | @Experimental 81 | def createDirectStreamFromZookeeper[K: ClassTag, V: ClassTag, KD <: Decoder[K]: ClassTag, VD <: Decoder[V]: ClassTag] ( 82 | ssc: StreamingContext, 83 | kafkaParams: Map[String, String], 84 | topics: Set[String], 85 | batchInterval: Long, 86 | lastCheckpoint: Option[Map[String, Any]]): CheckpointedDirectKafkaInputDStream[K, V, KD, VD, V] = { 87 | val messageHandler = (mmd: MessageAndMetadata[K, V]) => mmd.message 88 | val brokers = getKafkaBrokersFromZookeeper(kafkaParams) 89 | val kafkaParamsWithBrokers = kafkaParams + ("metadata.broker.list" -> brokers.mkString(",")) 90 | 91 | val initialOffsets = getInitialOffsetsFromZookeeper(kafkaParamsWithBrokers, topics, lastCheckpoint) 92 | new CheckpointedDirectKafkaInputDStream[K, V, KD, VD, V]( 93 | ssc, kafkaParamsWithBrokers, initialOffsets, messageHandler, batchInterval) 94 | } 95 | 96 | private def getInitialOffsets(kafkaParams: Map[String, String], topics: Set[String], 97 | lastCheckpoint: Option[Map[String, Any]]): Map[TopicAndPartition, Long] = { 98 | val kc = new KafkaCluster(kafkaParams) 99 | val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) 100 | val broker = kafkaParams("metadata.broker.list") 101 | 102 | val checkpointBroker = lastCheckpoint.flatMap { x => x.get("broker") } 103 | val checkpointOffsets = lastCheckpoint.flatMap(getCheckpointOffsets) 104 | val zookeeperOffsets = getZookeeperOffsets(kc, topics, reset) 105 | 106 | (checkpointBroker, checkpointOffsets, zookeeperOffsets) match { 107 | case (None, _, _) => zookeeperOffsets 108 | case (_, None, _) => zookeeperOffsets 109 | case (Some(checkpointBroker), Some(offsets), _) => { 110 | if (checkpointBroker != broker) { 111 | logWarn("Kafka broker has changed since the last checkpoint, falling back to Zookeeper offsets") 112 | zookeeperOffsets 113 | } else if (offsets.size != zookeeperOffsets.size) { 114 | logWarn("Kafka partition count has changed since the last checkpoint, falling back to Zookeeper offsets") 115 | zookeeperOffsets 116 | } else if (offsets.nonEmpty && offsets.keys.head.topic != zookeeperOffsets.keys.head.topic) { 117 | logWarn("Kafka topic has changed since the last checkpoint, falling back to Zookeeper offsets") 118 | zookeeperOffsets 119 | } else { 120 | offsets 121 | } 122 | } 123 | } 124 | } 125 | 126 | private def getKafkaBrokersFromZookeeper(kafkaParams: Map[String, String]): Seq[String] = { 127 | val zkServerString = kafkaParams("memsql.zookeeper.connect") 128 | val zkClient = new ZkClient(zkServerString, ZK_SESSION_TIMEOUT, ZK_CONNECT_TIMEOUT, ZKStringSerializer) 129 | ZkUtils.getAllBrokersInCluster(zkClient).map { b => s"${b.host}:${b.port}" }.sorted 130 | } 131 | 132 | private def getInitialOffsetsFromZookeeper(kafkaParams: Map[String, String], topics: Set[String], 133 | lastCheckpoint: Option[Map[String, Any]]): Map[TopicAndPartition, Long] = { 134 | val zkServerString = kafkaParams("memsql.zookeeper.connect") 135 | val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) 136 | 137 | val kc = new KafkaCluster(kafkaParams) 138 | 139 | val checkpointZkServers = lastCheckpoint.flatMap { x => x.get("zookeeper") }.flatMap { 140 | case x: String => Some(x.split(",").sorted.mkString(",")) 141 | case default => None 142 | } 143 | val checkpointOffsets = lastCheckpoint.flatMap(getCheckpointOffsets) 144 | val 
zookeeperOffsets = getZookeeperOffsets(kc, topics, reset) 145 | 146 | (checkpointZkServers, checkpointOffsets, zookeeperOffsets) match { 147 | case (None, _, _) => zookeeperOffsets 148 | case (_, None, _) => zookeeperOffsets 149 | case (Some(checkpointZkServerString), Some(offsets), _) => { 150 | if (checkpointZkServerString != zkServerString) { 151 | logWarn("Zookeeper quorum list for this extractor has changed since the last checkpoint, falling back to default offsets") 152 | zookeeperOffsets 153 | } else if (offsets.size != zookeeperOffsets.size) { 154 | logWarn("Kafka partition count has changed since the last checkpoint, falling back to Zookeeper offsets") 155 | zookeeperOffsets 156 | } else if (offsets.nonEmpty && offsets.keys.head.topic != zookeeperOffsets.keys.head.topic) { 157 | logWarn("Kafka topic has changed since the last checkpoint, falling back to Zookeeper offsets") 158 | zookeeperOffsets 159 | } else { 160 | offsets 161 | } 162 | } 163 | } 164 | } 165 | 166 | // Serializes the checkpoint data into the format expected by KafkaDirectInputDStream 167 | private def getCheckpointOffsets(checkpoint: Map[String, Any]): Option[Map[TopicAndPartition, Long]] = { 168 | try { 169 | val offsets = checkpoint("offsets").asInstanceOf[List[Map[String, Any]]] 170 | val checkpointOffsets = offsets.map { partitionInfo => 171 | val topic = partitionInfo("topic").asInstanceOf[String] 172 | val partition = partitionInfo("partition").asInstanceOf[Int] 173 | val offset = partitionInfo("offset").asInstanceOf[Number].longValue 174 | 175 | (TopicAndPartition(topic, partition), offset) 176 | }.toMap 177 | Some(checkpointOffsets) 178 | } catch { 179 | case NonFatal(e) => { 180 | logWarn("Kafka checkpoint data is invalid, it will be ignored", e) 181 | None 182 | } 183 | } 184 | } 185 | 186 | private def getZookeeperOffsets(kc: KafkaCluster, topics: Set[String], reset: Option[String]): Map[TopicAndPartition, Long] = { 187 | (for { 188 | topicPartitions <- kc.getPartitions(topics).right 189 | leaderOffsets <- (if (reset == Some("smallest")) { 190 | kc.getEarliestLeaderOffsets(topicPartitions) 191 | } else { 192 | kc.getLatestLeaderOffsets(topicPartitions) 193 | }).right 194 | } yield { 195 | leaderOffsets.map { case (tp, lo) => 196 | (tp, lo.offset) 197 | } 198 | }).fold( 199 | errs => { 200 | val wrappedErrs = errs.map { 201 | case err: java.nio.channels.ClosedChannelException => { 202 | val broker = kc.config.seedBrokers.toList(0) 203 | new KafkaException(s"Could not connect to Kafka broker(s) at ${broker._1}:${broker._2}: $err") 204 | } 205 | case default => default 206 | } 207 | throw new SparkException(wrappedErrs.mkString("\n")) 208 | }, 209 | ok => ok 210 | ) 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /kafka/src/main/scala/CheckpointingKafkaExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.kafka 2 | 3 | import com.memsql.spark.etl.api.{UserExtractConfig, PhaseConfig, ByteArrayExtractor} 4 | import com.memsql.spark.etl.utils.PhaseLogger 5 | import org.apache.spark.sql.SQLContext 6 | import org.apache.spark.streaming.StreamingContext 7 | 8 | import kafka.serializer.{DefaultDecoder, StringDecoder} 9 | import org.apache.spark.streaming.kafka.{CheckpointedDirectKafkaInputDStream, CheckpointedKafkaUtils} 10 | import org.apache.spark.streaming.dstream.InputDStream 11 | 12 | /** 13 | * A checkpointing enabled Kafka extractor configured by a Zookeeper quorum. 
14 | * The configuration has 2 required fields: 15 | * zk_quorum: a comma delimited string of host1:port,host2:port,host3:port denoting the Zookeeper quorum. 16 | * topic: the Kafka topic to extract. 17 | */ 18 | class CheckpointingKafkaExtractor extends ByteArrayExtractor { 19 | var CHECKPOINT_DATA_VERSION = 1 20 | 21 | var dstream: CheckpointedDirectKafkaInputDStream[String, Array[Byte], StringDecoder, DefaultDecoder, Array[Byte]] = null 22 | 23 | var zkQuorum: String = null 24 | var topic: String = null 25 | 26 | override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = { 27 | val kafkaConfig = config.asInstanceOf[UserExtractConfig] 28 | zkQuorum = kafkaConfig.getConfigString("zk_quorum").getOrElse { 29 | throw new IllegalArgumentException("\"zk_quorum\" must be set in the config") 30 | } 31 | topic = kafkaConfig.getConfigString("topic").getOrElse { 32 | throw new IllegalArgumentException("\"topic\" must be set in the config") 33 | } 34 | } 35 | 36 | def extract(ssc: StreamingContext, extractConfig: PhaseConfig, batchDuration: Long, logger: PhaseLogger): InputDStream[Array[Byte]] = { 37 | val kafkaParams = Map[String, String]( 38 | "memsql.zookeeper.connect" -> zkQuorum 39 | ) 40 | val topics = Set(topic) 41 | 42 | dstream = CheckpointedKafkaUtils.createDirectStreamFromZookeeper[String, Array[Byte], StringDecoder, DefaultDecoder]( 43 | ssc, kafkaParams, topics, batchDuration, lastCheckpoint) 44 | dstream 45 | } 46 | 47 | override def batchCheckpoint: Option[Map[String, Any]] = { 48 | dstream match { 49 | case null => None 50 | case default => { 51 | val currentOffsets = dstream.getCurrentOffsets.map { case (tp, offset) => 52 | Map("topic" -> tp.topic, "partition" -> tp.partition, "offset" -> offset) 53 | } 54 | Some(Map("offsets" -> currentOffsets, "zookeeper" -> zkQuorum, "version" -> CHECKPOINT_DATA_VERSION)) 55 | } 56 | } 57 | } 58 | 59 | override def batchRetry: Unit = { 60 | if (dstream.prevOffsets != null) { 61 | dstream.setCurrentOffsets(dstream.prevOffsets) 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.12.0-M1") 2 | -------------------------------------------------------------------------------- /src/main/scala/com/memsql/streamliner/examples/Extractors.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.streamliner.examples 2 | 3 | import org.apache.spark._ 4 | import org.apache.spark.rdd._ 5 | import org.apache.spark.streaming._ 6 | import org.apache.spark.streaming.dstream._ 7 | import org.apache.spark.sql._ 8 | import org.apache.spark.sql.types._ 9 | import com.memsql.spark.connector._ 10 | import com.memsql.spark.etl.api._ 11 | import com.memsql.spark.etl.utils._ 12 | import com.memsql.spark.etl.utils.PhaseLogger 13 | import org.apache.hadoop.io.{LongWritable, Text} 14 | import org.apache.hadoop.mapreduce.lib.input.TextInputFormat 15 | 16 | // The simplest implementation of an Extractor just provides a next method. 17 | // This is useful for prototyping and debugging. 
18 | class ConstantExtractor extends Extractor { 19 | override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, 20 | logger: PhaseLogger): Option[DataFrame] = { 21 | logger.info("extracting a constant sequence DataFrame") 22 | 23 | val schema = StructType(StructField("number", IntegerType, false) :: Nil) 24 | 25 | val sampleData = List(1,2,3,4,5) 26 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 27 | 28 | val df = sqlContext.createDataFrame(rowRDD, schema) 29 | Some(df) 30 | } 31 | } 32 | 33 | // An Extractor can also be configured with the config blob that is provided in 34 | // MemSQL Ops. 35 | class ConfigurableConstantExtractor extends Extractor { 36 | override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, 37 | logger: PhaseLogger): Option[DataFrame] = { 38 | val userConfig = config.asInstanceOf[UserExtractConfig] 39 | val start = userConfig.getConfigInt("start").getOrElse(1) 40 | val end = userConfig.getConfigInt("end").getOrElse(5) 41 | val columnName = userConfig.getConfigString("column_name").getOrElse("number") 42 | 43 | logger.info(s"extracting a sequence DataFrame from $start to $end") 44 | 45 | val schema = StructType(StructField(columnName, IntegerType, false) :: Nil) 46 | 47 | val sampleData = List.range(start, end + 1) 48 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 49 | 50 | val df = sqlContext.createDataFrame(rowRDD, schema) 51 | Some(df) 52 | } 53 | 54 | } 55 | 56 | // A more complex Extractor which maintains some state can be implemented using 57 | // the initialize and cleanup methods. 58 | class SequenceExtractor extends Extractor { 59 | var i: Int = Int.MinValue 60 | 61 | override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = { 62 | val userConfig = config.asInstanceOf[UserExtractConfig] 63 | i = userConfig.getConfigInt("sequence", "initial_value").getOrElse(0) 64 | 65 | logger.info(s"initializing the sequence at $i") 66 | } 67 | 68 | override def cleanup(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = { 69 | logger.info("cleaning up the sequence") 70 | } 71 | 72 | override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Option[DataFrame] = { 73 | val userConfig = config.asInstanceOf[UserExtractConfig] 74 | val sequenceSize = userConfig.getConfigInt("sequence", "size").getOrElse(5) 75 | 76 | logger.info(s"emitting a sequence RDD from $i to ${i + sequenceSize}") 77 | 78 | val schema = StructType(StructField("number", IntegerType, false) :: Nil) 79 | 80 | i += sequenceSize 81 | val sampleData = List.range(i - sequenceSize, i) 82 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 83 | 84 | val df = sqlContext.createDataFrame(rowRDD, schema) 85 | Some(df) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/main/scala/com/memsql/streamliner/examples/FileExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.streamliner.examples 2 | 3 | import org.apache.spark._ 4 | import org.apache.spark.rdd._ 5 | import org.apache.spark.streaming._ 6 | import org.apache.spark.streaming.dstream._ 7 | import org.apache.spark.sql._ 8 |
import org.apache.spark.sql.types._ 9 | import com.memsql.spark.connector._ 10 | import com.memsql.spark.etl.api._ 11 | import com.memsql.spark.etl.utils._ 12 | 13 | // TODO: 14 | // 1) Figure out how to keep track of file state. we can keep track of names -> md5s in a bucket and periodically check whether any files 15 | // have changed. 16 | // 2) We need to hook into a lower level API in spark to figure out how to expand a file path to an actual list of files (for the above map) 17 | // 3) We need to persist the map somewhere 18 | // 4) Support non-'\n' line terminators 19 | class FileExtractor extends Extractor { 20 | var first = true 21 | var repeat = false 22 | var filePath = "" 23 | def schema: StructType = StructType(StructField("data", StringType, false) :: Nil) 24 | 25 | override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, 26 | logger: PhaseLogger): Unit = { 27 | val userConfig = config.asInstanceOf[UserExtractConfig] 28 | filePath = userConfig.getConfigString("path") match { 29 | case Some(s) => s 30 | case None => throw new IllegalArgumentException("Missing required argument 'path'") 31 | } 32 | 33 | repeat = userConfig.getConfigBoolean("repeat").getOrElse(false) 34 | 35 | if (filePath.startsWith("s3n://")) { 36 | val access_key = userConfig.getConfigString("aws_access_key") match { 37 | case Some(s) => s 38 | case None => throw new IllegalArgumentException("Missing required argument 'aws_access_key' (because path starts with s3n)") 39 | } 40 | val secret_key = userConfig.getConfigString("aws_secret_key") match { 41 | case Some(s) => s 42 | case None => throw new IllegalArgumentException("Missing required argument 'aws_secret_key' (because path starts with s3n)") 43 | } 44 | 45 | // Ideally would like to put this in the URL (i.e. 
"s3n://ACCESS_KEY:SECRET_KEY@path"), but am blocked 46 | // https://issues.apache.org/jira/browse/HADOOP-3733 47 | val sparkContext = sqlContext.sparkContext 48 | val current_access_key = sparkContext.hadoopConfiguration.get("fs.s3.awsAccessKeyId") 49 | val current_secret_key = sparkContext.hadoopConfiguration.get("fs.s3n.awsSecretAccessKey") 50 | if (current_access_key != null && current_access_key != access_key) { 51 | throw new IllegalArgumentException("Access key is already set and is different") 52 | } 53 | if (current_secret_key != null && current_secret_key != secret_key) { 54 | throw new IllegalArgumentException("Secret key is already set and is different") 55 | } 56 | 57 | sparkContext.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", access_key) 58 | sparkContext.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", secret_key) 59 | } 60 | 61 | } 62 | 63 | override def cleanup(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, 64 | logger: PhaseLogger): Unit = { 65 | // TODO: should "flush" state here 66 | } 67 | 68 | override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, 69 | logger: PhaseLogger): Option[DataFrame] = { 70 | if (!first) { 71 | logger.info("There is nothing left to serve because first=false") 72 | return None 73 | } 74 | if (!repeat) { 75 | first = false 76 | } 77 | logger.info("Grabbing files now for path: [" + filePath + "]") 78 | val rowRDD = sqlContext.sparkContext.textFile(filePath).map(x => Row(x)) 79 | val df = sqlContext.createDataFrame(rowRDD, schema) 80 | Some(df) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/main/scala/com/memsql/streamliner/examples/S3AccessLogsTransformer.scala: -------------------------------------------------------------------------------- 1 | // This file contains transformers for processing log files generated by S3 2 | package com.memsql.streamliner.examples 3 | 4 | import com.memsql.spark.etl.api._ 5 | import com.memsql.spark.etl.utils.PhaseLogger 6 | import org.apache.spark.rdd._ 7 | import org.apache.spark.sql._ 8 | import org.apache.spark.sql.types._ 9 | 10 | class S3AccessLogsTransformer extends Transformer { 11 | val S3_LINE_LOGPATS = """(\S+) (\S+) \[(.*?)\] (\S+) (\S+) (\S+) (\S+) (\S+) (?:"([^"]+)"|-) (\S+) (\S+) (\S+) (\S+) (\S+) (\S+) (?:"([^"]+)"|-) (?:"([^"]+)"|-) (\S+)\s*""".r 12 | 13 | val SCHEMA = StructType(Array( 14 | StructField("bucket_owner", StringType, true), 15 | StructField("bucket", StringType, true), 16 | StructField("datetime", TimestampType, true), 17 | StructField("ip", StringType, true), 18 | StructField("requestor_id", StringType, true), 19 | StructField("request_id", StringType, true), 20 | StructField("operation", StringType, true), 21 | StructField("path", StringType, true), 22 | StructField("http_method_uri_proto", StringType, true), 23 | StructField("http_status", IntegerType, true), 24 | StructField("s3_error", StringType, true), 25 | StructField("bytes_sent", IntegerType, true), 26 | StructField("object_size", IntegerType, true), 27 | StructField("total_time", IntegerType, true), 28 | StructField("turn_around_time", IntegerType, true), 29 | StructField("referer", StringType, true), 30 | StructField("user_agent", StringType, true), 31 | StructField("version_id", StringType, true) 32 | )) 33 | 34 | val PARSER = new java.text.SimpleDateFormat("dd/MMM/yyyy:HH:mm:ss +SSSS") 35 | 36 | def parseDateTime(s: String) = { 37 | new 
java.sql.Timestamp(PARSER.parse(s).getTime()) 38 | } 39 | 40 | def parseIntOrDash(s: String) = s match { 41 | case "-" => null 42 | case str => str.toInt 43 | } 44 | 45 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 46 | val stringRDD = df.rdd 47 | val parsedRDD = stringRDD.flatMap(x => S3_LINE_LOGPATS.findAllMatchIn(x(0).asInstanceOf[String]).map(y => y.subgroups)) 48 | 49 | val typeConvertedRdd = parsedRDD.map(r => 50 | r.zipWithIndex.map({case (x, i) => 51 | if (i == 2) { 52 | parseDateTime(x) 53 | } else if (Set(9, 11, 12, 13, 14) contains i) { 54 | parseIntOrDash(x) 55 | } else { 56 | x 57 | } 58 | })) 59 | 60 | val rowRDD = typeConvertedRdd.map(x => Row.fromSeq(x)) 61 | sqlContext.createDataFrame(rowRDD, SCHEMA) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/scala/com/memsql/streamliner/examples/Transformers.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.streamliner.examples 2 | 3 | import org.apache.spark.rdd._ 4 | import org.apache.spark.sql._ 5 | import org.apache.spark.sql.types._ 6 | import com.memsql.spark.connector.dataframe.{JsonType, JsonValue} 7 | import com.memsql.spark.etl.api._ 8 | import com.memsql.spark.etl.utils.{ByteUtils, PhaseLogger} 9 | import com.memsql.spark.etl.utils.{JSONPath, JSONUtils} 10 | 11 | import com.fasterxml.jackson.databind.ObjectMapper 12 | import com.fasterxml.jackson.module.scala.DefaultScalaModule 13 | 14 | import scala.collection.JavaConversions._ 15 | 16 | // A helper object to extract the first column of a schema 17 | object ExtractFirstStructField { 18 | def unapply(schema: StructType): Option[(String, DataType, Boolean, Metadata)] = schema.fields match { 19 | case Array(first: StructField, _*) => Some((first.name, first.dataType, first.nullable, first.metadata)) 20 | } 21 | } 22 | 23 | // A Transformer implements the transform method which allows inspecting and transforming the DataFrame. 24 | class EvenNumbersOnlyTransformer extends Transformer { 25 | def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 26 | logger.info("transforming the DataFrame") 27 | 28 | // check that the first column is of type IntegerType and return its name 29 | val column = df.schema match { 30 | case ExtractFirstStructField(name: String, dataType: IntegerType, _, _) => name 31 | case _ => throw new IllegalArgumentException("The first column of the input DataFrame should be IntegerType") 32 | } 33 | 34 | // filter the dataframe, returning only even numbers 35 | df.filter(s"$column % 2 = 0") 36 | } 37 | } 38 | 39 | // A Transformer can also be configured with the config blob that is provided in MemSQL Ops. 
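// For example (as exercised in TransformersSpec), a config blob of {"filter": {"odd": false}}
// makes the transformer below keep only the even numbers, while {"filter": {"even": false}}
// keeps only the odd ones; an empty "filter" object keeps everything.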
40 | class ConfigurableNumberParityTransformer extends Transformer { 41 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 42 | val userConfig = config.asInstanceOf[UserTransformConfig] 43 | val keepEvenNumbers = userConfig.getConfigBoolean("filter", "even").getOrElse(true) 44 | val keepOddNumbers = userConfig.getConfigBoolean("filter", "odd").getOrElse(true) 45 | 46 | // check that the first column is of type IntegerType and return its name 47 | val column = df.schema match { 48 | case ExtractFirstStructField(name: String, dataType: IntegerType, _, _) => name 49 | case _ => throw new IllegalArgumentException("The first column of the input DataFrame should be IntegerType") 50 | } 51 | 52 | // filter the dataframe, returning resp. nothing, odd, even or all numbers 53 | logger.info(s"transforming the DataFrame: $keepEvenNumbers, $keepOddNumbers") 54 | if (!keepEvenNumbers && !keepOddNumbers) { 55 | df.filter("1 = 0") 56 | } 57 | else if (!keepEvenNumbers) { 58 | df.filter(s"$column % 2 = 1") 59 | } 60 | else if (!keepOddNumbers) { 61 | df.filter(s"$column % 2 = 0") 62 | } else { 63 | df 64 | } 65 | } 66 | } 67 | 68 | // A Transformer that extracts some fields from a JSON object 69 | // It supports both RDD[String] or RDD[Array[Byte]] 70 | // This saves into MemSQL 2 columns of type TEXT 71 | class JSONMultiColsTransformer extends Transformer { 72 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 73 | // define a JSON path 74 | val paths = Array[JSONPath]( 75 | JSONPath("id"), 76 | JSONPath("txt")) 77 | 78 | // check that the first column is of type StringType or BinaryType and convert the RDD accordingly 79 | val rdd = df.schema match { 80 | case ExtractFirstStructField(_, dataType: StringType, _, _) => df.map(r => r(0).asInstanceOf[String]) 81 | case ExtractFirstStructField(_, dataType: BinaryType, _, _) => df.map(r => ByteUtils.bytesToUTF8String(r(0).asInstanceOf[Array[Byte]])) 82 | case _ => throw new IllegalArgumentException("The first column of the input DataFrame should be either StringType or BinaryType") 83 | } 84 | 85 | // return a DataFrame with schema based on the JSON path 86 | JSONUtils.JSONRDDToDataFrame(paths, sqlContext, rdd) 87 | } 88 | } 89 | 90 | // A Transformer that parses a JSON object (using jackson) and filters only objects containing an "id" field 91 | // This saves into MemSQL a single column of type JSON 92 | class JSONCheckIdTransformer extends Transformer { 93 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 94 | val userConfig = config.asInstanceOf[UserTransformConfig] 95 | val columnName = userConfig.getConfigString("column_name").getOrElse("data") 96 | 97 | // check that the first column is of type StringType or BinaryType and convert the RDD accordingly 98 | val rdd = df.schema match { 99 | case ExtractFirstStructField(_, dataType: StringType, _, _) => df.map(r => r(0).asInstanceOf[String]) 100 | case ExtractFirstStructField(_, dataType: BinaryType, _, _) => df.map(r => ByteUtils.bytesToUTF8String(r(0).asInstanceOf[Array[Byte]])) 101 | case _ => throw new IllegalArgumentException("The first column of the input DataFrame should be either StringType or BinaryType") 102 | } 103 | 104 | // convert each input element to a JsonValue 105 | val jsonRDD = rdd.map(r => new JsonValue(r)) 106 | 107 | // filters only objects that contain an "id" field 108 | val 
filteredRDD = jsonRDD.mapPartitions(r => { 109 | // register jackson mapper (this needs to be executed on each partition) 110 | val mapper = new ObjectMapper() 111 | mapper.registerModule(DefaultScalaModule) 112 | 113 | // filter the partition for only the objects that contain an "id" field 114 | r.filter(x => mapper.readValue(x.value, classOf[Map[String,Any]]).contains("id")) 115 | }) 116 | val rowRDD = filteredRDD.map(x => Row(x)) 117 | 118 | // create a schema with a single nullable JSON column using the configured column name 119 | val schema = StructType(Array(StructField(columnName, JsonType, true))) 120 | 121 | sqlContext.createDataFrame(rowRDD, schema) 122 | } 123 | } 124 | 125 | class TwitterHashtagTransformer extends Transformer { 126 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 127 | val userConfig = config.asInstanceOf[UserTransformConfig] 128 | val columnName = userConfig.getConfigString("column_name").getOrElse("hashtags") 129 | 130 | // check that the first column is of type StringType or BinaryType and convert the RDD accordingly 131 | val rdd = df.schema match { 132 | case ExtractFirstStructField(_, dataType: StringType, _, _) => df.map(r => r(0).asInstanceOf[String]) 133 | case ExtractFirstStructField(_, dataType: BinaryType, _, _) => df.map(r => ByteUtils.bytesToUTF8String(r(0).asInstanceOf[Array[Byte]])) 134 | case _ => throw new IllegalArgumentException("The first column of the input DataFrame should be either StringType or BinaryType") 135 | } 136 | 137 | // use the string RDD prepared above (this handles both StringType and BinaryType input) 138 | val jsonRDD = rdd 139 | 140 | val hashtagsRDD: RDD[String] = jsonRDD.mapPartitions(r => { 141 | // register jackson mapper (this needs to be instantiated per partition 142 | // since it is not serializable) 143 | val mapper = new ObjectMapper() 144 | mapper.registerModule(DefaultScalaModule) 145 | 146 | r.flatMap(tweet => { 147 | val rootNode = mapper.readTree(tweet) 148 | val hashtags = rootNode.path("entities").path("hashtags") 149 | if (!hashtags.isMissingNode) { 150 | hashtags.elements 151 | .filter(n => n.has("text")) 152 | .map(n => n.get("text").asText) 153 | } else { 154 | Nil 155 | } 156 | }) 157 | }) 158 | 159 | val rowRDD: RDD[Row] = hashtagsRDD.map(x => Row(x)) 160 | val schema = StructType(Array(StructField(columnName, StringType, true))) 161 | sqlContext.createDataFrame(rowRDD, schema) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.logger.org.apache.spark=WARN 2 | log4j.logger.Remoting=WARN 3 | log4j.logger.org.eclipse.jetty=WARN 4 | log4j.logger.akka.remote=WARN 5 | log4j.logger.akka.event.slf4j=WARN 6 | -------------------------------------------------------------------------------- /src/test/scala/test/ExtractorsSpec.scala: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import com.memsql.spark.etl.api.UserExtractConfig 4 | import com.memsql.spark.etl.utils.ByteUtils 5 | import com.memsql.streamliner.examples._ 6 | import org.apache.spark.streaming.{Time, Duration, StreamingContext} 7 | import java.io._ 8 | import spray.json._ 9 | import test.util.{UnitSpec, TestLogger, LocalSparkContext} 10 | import org.apache.spark.streaming._ 11 | import org.apache.spark.sql.{SQLContext, Row} 12 | 13 | class
ExtractorsSpec extends UnitSpec with LocalSparkContext { 14 | val emptyConfig = UserExtractConfig(class_name = "Test", value = new JsString("empty")) 15 | val logger = new TestLogger("test") 16 | 17 | var ssc: StreamingContext = _ 18 | var sqlContext: SQLContext = _ 19 | 20 | override def beforeEach(): Unit = { 21 | super.beforeEach() 22 | ssc = new StreamingContext(sc, Seconds(1)) 23 | sqlContext = new SQLContext(sc) 24 | } 25 | 26 | "ConstantExtractor" should "emit a constant DataFrame" in { 27 | val extract = new ConstantExtractor 28 | 29 | val maybeDf = extract.next(ssc, 1, sqlContext, emptyConfig, 1, logger) 30 | assert(maybeDf.isDefined) 31 | 32 | val total = maybeDf.get.select("number").map(r => r(0).asInstanceOf[Int]).sum() 33 | assert(total == 15) 34 | } 35 | 36 | "ConfigurableConstantExtractor" should "emit what the user specifies" in { 37 | val extract = new ConfigurableConstantExtractor 38 | 39 | val columnName = "mycolumn" 40 | val config = UserExtractConfig( 41 | class_name="test", 42 | value=JsObject( 43 | "start" -> JsNumber(1), 44 | "end" -> JsNumber(3), 45 | "column_name" -> JsString(columnName) 46 | ) 47 | ) 48 | 49 | val maybeDf = extract.next(ssc, 1, sqlContext, config, 1, logger) 50 | assert(maybeDf.isDefined) 51 | 52 | val total = maybeDf.get.select(columnName).map(r => r(0).asInstanceOf[Int]).sum() 53 | assert(total == 6) 54 | } 55 | 56 | "SequenceExtractor" should "maintain sequence state" in { 57 | val extract = new SequenceExtractor 58 | val config = UserExtractConfig( 59 | class_name = "test", 60 | value = JsObject( 61 | "sequence" -> JsObject( 62 | "initial_value" -> JsNumber(1), 63 | "size" -> JsNumber(1) 64 | ) 65 | ) 66 | ) 67 | 68 | extract.initialize(ssc, sqlContext, config, 1, logger) 69 | 70 | var i = 0 71 | for (i <- 1 to 3) { 72 | val maybeDf = extract.next(ssc, 1, sqlContext, config, 1, logger) 73 | assert(maybeDf.isDefined) 74 | 75 | val rdd = maybeDf.get.select("number").map(r => r(0).asInstanceOf[Int]) 76 | assert(rdd.count == 1) 77 | assert(rdd.first == i) 78 | } 79 | 80 | extract.cleanup(ssc, sqlContext, config, 1, logger) 81 | } 82 | 83 | "FileExtractor" should "produce DataFrame from files" in { 84 | val extract = new FileExtractor 85 | 86 | // initialize the extractor 87 | val tweetsURI = getClass.getResource("/tweets").toURI 88 | val config = UserExtractConfig( 89 | class_name = "test", 90 | value = JsObject( 91 | "path" -> JsString(tweetsURI.toURL.toString) 92 | ) 93 | ) 94 | extract.initialize(ssc, sqlContext, config, 1, logger) 95 | 96 | // extract data 97 | val maybeDf = extract.next(ssc, 1, sqlContext, config, 1, logger) 98 | assert(maybeDf.isDefined) 99 | 100 | val df = maybeDf.get 101 | assert(df.count == 312) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/test/scala/test/TransformersSpec.scala: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import com.memsql.spark.etl.api.UserTransformConfig 4 | import com.memsql.spark.etl.utils.ByteUtils 5 | import com.memsql.streamliner.examples._ 6 | import com.memsql.spark.connector.dataframe.JsonType 7 | import org.apache.spark.sql.{Row, SQLContext} 8 | import org.apache.spark.sql.types._ 9 | import spray.json.{JsBoolean, JsObject, JsString} 10 | import test.util.{UnitSpec, TestLogger, LocalSparkContext} 11 | 12 | class TransformersSpec extends UnitSpec with LocalSparkContext { 13 | val emptyConfig = UserTransformConfig(class_name = "Test", value = JsString("empty")) 14 | val 
logger = new TestLogger("test") 15 | 16 | var sqlContext: SQLContext = _ 17 | 18 | override def beforeEach(): Unit = { 19 | super.beforeEach() 20 | sqlContext = new SQLContext(sc) 21 | } 22 | 23 | "EvenNumbersOnlyTransformer" should "only emit even numbers" in { 24 | val transform = new EvenNumbersOnlyTransformer 25 | 26 | val schema = StructType(StructField("number", IntegerType, false) :: Nil) 27 | val sampleData = List(1,2,3) 28 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 29 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 30 | 31 | val df = transform.transform(sqlContext, dfIn, emptyConfig, logger) 32 | assert(df.schema == schema) 33 | assert(df.first == Row(2)) 34 | assert(df.count == 1) 35 | } 36 | 37 | "EvenNumbersOnlyTransformer" should "only accept IntegerType fields" in { 38 | val transform = new EvenNumbersOnlyTransformer 39 | 40 | val schema = StructType(StructField("column", StringType, false) :: Nil) 41 | val sampleData = List(1,2,3) 42 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 43 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 44 | 45 | val e = intercept[IllegalArgumentException] { 46 | transform.transform(sqlContext, dfIn, emptyConfig, logger) 47 | } 48 | assert(e.getMessage() == "The first column of the input DataFrame should be IntegerType") 49 | } 50 | 51 | "ConfigurableNumberParityTransformer" should "support skipping odd numbers" in { 52 | val transform = new ConfigurableNumberParityTransformer 53 | val schema = StructType(StructField("number", IntegerType, false) :: Nil) 54 | val sampleData = List(1,2,3) 55 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 56 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 57 | 58 | val config = UserTransformConfig( 59 | class_name="test", 60 | value=JsObject("filter" -> JsObject("odd" -> JsBoolean(false))) 61 | ) 62 | 63 | val df = transform.transform(sqlContext, dfIn, config, logger) 64 | assert(df.first == Row(2)) 65 | assert(df.count == 1) 66 | } 67 | 68 | it should "support skipping even numbers" in { 69 | val transform = new ConfigurableNumberParityTransformer 70 | val schema = StructType(StructField("number", IntegerType, false) :: Nil) 71 | val sampleData = List(1,2,3) 72 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 73 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 74 | 75 | val config = UserTransformConfig( 76 | class_name="test", 77 | value=JsObject("filter" -> JsObject( "even" -> JsBoolean(false) )) 78 | ) 79 | 80 | val df = transform.transform(sqlContext, dfIn, config, logger) 81 | assert(df.first == Row(1)) 82 | assert(df.count == 2) 83 | } 84 | 85 | it should "handle an empty filter" in { 86 | val transform = new ConfigurableNumberParityTransformer 87 | val schema = StructType(StructField("number", IntegerType, false) :: Nil) 88 | val sampleData = List(1,2,3) 89 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 90 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 91 | 92 | val config = UserTransformConfig( 93 | class_name="test", 94 | value=JsObject("filter" -> JsObject()) 95 | ) 96 | 97 | val df = transform.transform(sqlContext, dfIn, config, logger) 98 | assert(df.first == Row(1)) 99 | assert(df.count == 3) 100 | } 101 | 102 | "JSONMultiColsTransformer" should "insert rows with 2 fields id, txt" in { 103 | val transform = new JSONMultiColsTransformer 104 | val schema = StructType(StructField("data", StringType, false) :: Nil) 105 | val 
sampleData = List( 106 | """{"id": "a001", "txt": "hello"}""", 107 | """{"id": "b002", "txt": "world", "foo": "bar"}""", // foo field ignored 108 | """{"xid": "c001", "txt": "text"}""" // xid ignored, id NULL 109 | )//.map(line => ByteUtils.utf8StringToBytes(line)) 110 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 111 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 112 | 113 | val df = transform.transform(sqlContext, dfIn, emptyConfig, logger) 114 | assert(df.schema == StructType(Array( 115 | StructField("id", StringType, true), 116 | StructField("txt", StringType, true) 117 | ))) 118 | assert(df.count == 3) 119 | assert(df.first == Row("a001", "hello")) 120 | for ( (a, b) <- df.head(3).zip(Array( 121 | Row("a001", "hello"), 122 | Row("b002", "world"), 123 | Row(null, "text") 124 | ))) { 125 | assert(a == b) 126 | } 127 | } 128 | 129 | "JSONCheckIdTransformer" should "insert rows with 1 field of type JSON" in { 130 | val transform = new JSONCheckIdTransformer 131 | val schema = StructType(StructField("testdata", StringType, false) :: Nil) 132 | val sampleData = List( 133 | """{"id": "a001", "txt": "hello"}""" 134 | ) 135 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 136 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 137 | 138 | val columnName = "test" 139 | val config = UserTransformConfig( 140 | class_name = "test", 141 | value = JsObject("column_name" -> JsString(columnName)) 142 | ) 143 | 144 | val df = transform.transform(sqlContext, dfIn, config, logger) 145 | assert(df.schema == StructType(Array(StructField(columnName, JsonType, true)))) 146 | assert(df.count == 1) 147 | assert(df.first.toString == """[{"id": "a001", "txt": "hello"}]""") 148 | } 149 | 150 | "JSONCheckIdTransformer" should "insert rows with 1 field of type JSON on input Array[Byte]" in { 151 | val transform = new JSONCheckIdTransformer 152 | val schema = StructType(StructField("testdata", BinaryType, false) :: Nil) 153 | val sampleData = List( 154 | """{"id": "a001", "txt": "hello"}""" 155 | ).map(ByteUtils.utf8StringToBytes) 156 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 157 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 158 | 159 | val columnName = "test" 160 | val config = UserTransformConfig( 161 | class_name = "test", 162 | value = JsObject("column_name" -> JsString(columnName)) 163 | ) 164 | 165 | val df = transform.transform(sqlContext, dfIn, config, logger) 166 | assert(df.schema == StructType(Array(StructField(columnName, JsonType, true)))) 167 | assert(df.count == 1) 168 | assert(df.first.toString == """[{"id": "a001", "txt": "hello"}]""") 169 | } 170 | 171 | it should "skip rows with no id field" in { 172 | val transform = new JSONCheckIdTransformer 173 | val schema = StructType(StructField("testdata", StringType, false) :: Nil) 174 | val sampleData = List( 175 | """{"id": "a001", "txt": "hello"}""", 176 | """{"id": "b002", "txt": "world", "foo": "bar"}""", // foo field 177 | """{"xid": "c001", "txt": "text"}""" // id not available, row skipped 178 | ) 179 | val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_)) 180 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 181 | 182 | val df = transform.transform(sqlContext, dfIn, emptyConfig, logger) 183 | assert(df.schema == StructType(Array(StructField("data", JsonType, true)))) 184 | assert(df.count == 2) 185 | } 186 | 187 | "TwitterHashtagTransformer" should "should extract all the hashtags from the tweets resource" in { 
188 | val transform = new TwitterHashtagTransformer 189 | val schema = StructType(StructField("testdata", StringType, false) :: Nil) 190 | val tweetsURI = getClass.getResource("/tweets").toURI 191 | val tweets = sc.textFile(tweetsURI.toURL.toString) 192 | val rowRDD = tweets.map(Row(_)) 193 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 194 | 195 | val columnName = "test" 196 | val config = UserTransformConfig( 197 | class_name = "TwitterHashtagTransformer", 198 | value = JsObject("column_name" -> JsString(columnName)) 199 | ) 200 | 201 | val df = transform.transform(sqlContext, dfIn, config, logger) 202 | 203 | assert(df.schema == StructType(Array(StructField(columnName, StringType, true)))) 204 | // looked in the tweets file (see resources/tweets) and found the 205 | // first hashtag in the first tweet 206 | assert(df.first == Row("MTVFANWARSArianators")) 207 | assert(df.count == 141) 208 | } 209 | 210 | "S3AccessLogsTransformer" should "correctly parse S3 logs" in { 211 | // example from AWS website 212 | val logOutput = List( 213 | """79a59df900b949e55d96a1e698fbacedfd6e09d98eacf8f8d5218e7cd47ef2be mybucket [06/Feb/2014:00:00:38 +0000] 192.0.2.3 214 | | 79a59df900b949e55d96a1e698fbacedfd6e09d98eacf8f8d5218e7cd47ef2be 3E57427F3EXAMPLE REST.GET.VERSIONING - "GET/mybucket?versioning HTTP/1.1" 200 215 | | - 113 - 7 - "-" "S3Console/0.4" -""".stripMargin.replaceAll("\n", "") 216 | ) 217 | 218 | val transform = new S3AccessLogsTransformer 219 | val schema = StructType(StructField("testdata", StringType, false) :: Nil) 220 | val rowRDD = sqlContext.sparkContext.parallelize(logOutput).map(Row(_)) 221 | val dfIn = sqlContext.createDataFrame(rowRDD, schema) 222 | 223 | val df = transform.transform(sqlContext, dfIn, emptyConfig, logger) 224 | assert(df.count == 1) 225 | 226 | val first = df.first 227 | 228 | assert(first.getAs[String]("bucket") == "mybucket") 229 | assert(first.getAs[String]("ip") == "192.0.2.3") 230 | assert(first.getAs[String]("user_agent") == "S3Console/0.4") 231 | assert(first.getAs[String]("version_id") == "-") 232 | assert(first.getAs[Int]("http_status") == 200) 233 | assert(first.getAs[Int]("object_size") === null) 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /src/test/scala/test/util/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | * 17 | * With small modifications by MemSQL 18 | */ 19 | 20 | package test.util 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | import org.scalatest.BeforeAndAfterEach 24 | import org.scalatest._ 25 | 26 | trait LocalSparkContext extends BeforeAndAfterEach { self: Suite => 27 | 28 | @transient private var _sc: SparkContext = _ 29 | 30 | val _sparkConf = new SparkConf(false) 31 | .set("spark.ui.showConsoleProgress", "false") 32 | 33 | def sc: SparkContext = _sc 34 | 35 | override def beforeEach() { 36 | _sc = new SparkContext("local[4]", "test", _sparkConf) 37 | super.beforeEach() 38 | } 39 | 40 | override def afterEach() { 41 | resetSparkContext() 42 | super.afterEach() 43 | } 44 | 45 | def resetSparkContext(): Unit = { 46 | LocalSparkContext.stop(_sc) 47 | _sc = null 48 | } 49 | 50 | } 51 | 52 | object LocalSparkContext { 53 | def stop(sc: SparkContext) { 54 | if (sc != null) { 55 | sc.stop() 56 | } 57 | // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown 58 | System.clearProperty("spark.driver.port") 59 | } 60 | 61 | /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ 62 | def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { 63 | try { 64 | f(sc) 65 | } finally { 66 | stop(sc) 67 | } 68 | } 69 | 70 | } 71 | -------------------------------------------------------------------------------- /src/test/scala/test/util/TestLogger.scala: -------------------------------------------------------------------------------- 1 | package test.util 2 | 3 | import com.memsql.spark.etl.utils.PhaseLogger 4 | import org.apache.log4j.Logger 5 | 6 | class TestLogger(override val name: String) extends PhaseLogger { 7 | override protected val logger: Logger = Logger.getRootLogger 8 | } 9 | 10 | -------------------------------------------------------------------------------- /src/test/scala/test/util/UnitSpec.scala: -------------------------------------------------------------------------------- 1 | package test.util 2 | 3 | import org.scalatest._ 4 | 5 | abstract class UnitSpec 6 | extends FlatSpec 7 | with Matchers 8 | with OptionValues 9 | with Inside 10 | with Inspectors 11 | with BeforeAndAfter 12 | with BeforeAndAfterEach 13 | with BeforeAndAfterAll 14 | with OneInstancePerTest { 15 | } 16 | -------------------------------------------------------------------------------- /thrift/src/main/scala/ThriftRandomExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import com.memsql.spark.etl.api._ 4 | import com.memsql.spark.etl.utils.PhaseLogger 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.sql.{SQLContext, DataFrame, Row} 7 | import org.apache.spark.sql.types._ 8 | import org.apache.spark.streaming.StreamingContext 9 | import org.apache.thrift.protocol.TBinaryProtocol 10 | import org.apache.thrift.{TBase, TFieldIdEnum, TSerializer} 11 | 12 | class ThriftRandomExtractor extends Extractor { 13 | var count: Int = 1 14 | var thriftType: Class[_] = null 15 | var serializer: TSerializer = null 16 | 17 | def schema: StructType = StructType(StructField("bytes", BinaryType, false) :: Nil) 18 | 19 | override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = { 20 | val userConfig = config.asInstanceOf[UserExtractConfig] 21 | val className = userConfig.getConfigString("className") match { 22 | case Some(s) => s 23 | case None => throw 
new IllegalArgumentException("className must be set in the config") 24 | } 25 | thriftType = Class.forName(className) 26 | serializer = new TSerializer(new TBinaryProtocol.Factory()) 27 | count = userConfig.getConfigInt("count").getOrElse(1) 28 | } 29 | 30 | override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Option[DataFrame] = { 31 | val rdd = sqlContext.sparkContext.parallelize((1 to count).map(_ => Row({ 32 | val thriftObject = ThriftRandomGenerator.next(thriftType).asInstanceOf[TBase[_ <: TBase[_, _], _ <: TFieldIdEnum]] 33 | serializer.serialize(thriftObject) 34 | }))) 35 | Some(sqlContext.createDataFrame(rdd, schema)) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /thrift/src/main/scala/ThriftRandomGenerator.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import collection.JavaConversions._ 4 | import java.lang.reflect.Method 5 | import java.nio.ByteBuffer 6 | import org.apache.thrift.{TBase, TFieldIdEnum} 7 | import org.apache.thrift.protocol.{TField, TType} 8 | import org.apache.thrift.meta_data._ 9 | 10 | import scala.util.Random 11 | 12 | object ThriftRandomGenerator { 13 | val random = new Random 14 | val MAX_RECURSION_LEVEL = 5 15 | 16 | def next[F <: TFieldIdEnum](c: Class[_], level: Int = 0): Any = { 17 | if (level > MAX_RECURSION_LEVEL) { 18 | return null 19 | } 20 | val className = c.getName 21 | try { 22 | val tBaseClass = c.asInstanceOf[Class[TBase[_ <: TBase[_, _], F]]] 23 | val instance = tBaseClass.newInstance() 24 | val metaDataMap: Map[_ <: TFieldIdEnum, FieldMetaData] = FieldMetaData.getStructMetaDataMap(tBaseClass).toMap 25 | metaDataMap.foreach({ case (field, fieldMetaData) => 26 | val valueMetaData = fieldMetaData.valueMetaData 27 | val value = getValue(valueMetaData, level) 28 | instance.setFieldValue(instance.fieldForId(field.getThriftFieldId), value) 29 | }) 30 | instance 31 | } catch { 32 | case e: ClassCastException => throw new IllegalArgumentException(s"Class $className is not a subclass of org.apache.thrift.TBase") 33 | } 34 | } 35 | 36 | def getValue(valueMetaData: FieldValueMetaData, level: Int): Any = { 37 | if (level > MAX_RECURSION_LEVEL) { 38 | return null 39 | } 40 | valueMetaData.`type` match { 41 | case TType.BOOL => random.nextBoolean 42 | case TType.BYTE => random.nextInt.toByte 43 | case TType.I16 => random.nextInt.toShort 44 | case TType.I32 => random.nextInt 45 | case TType.I64 => random.nextLong 46 | case TType.DOUBLE => random.nextInt(5) * 0.25 47 | case TType.ENUM => { 48 | val enumClass = valueMetaData.asInstanceOf[EnumMetaData].enumClass 49 | getEnumValue(enumClass) 50 | } 51 | case TType.STRING => { 52 | val length: Int = 5 + random.nextInt(5) 53 | val s = (1 to length).map(x => ('a'.toInt + random.nextInt(26)).toChar).mkString 54 | if (valueMetaData.isBinary) { 55 | ByteBuffer.wrap(s.getBytes) 56 | } else { 57 | s 58 | } 59 | } 60 | case TType.LIST => { 61 | val elemMetaData = valueMetaData.asInstanceOf[ListMetaData].elemMetaData 62 | val length: Int = 5 + random.nextInt(5) 63 | val ret: java.util.List[Any] = (1 to length).map(x => getValue(elemMetaData, level + 1)) 64 | ret 65 | } 66 | case TType.SET => { 67 | val elemMetaData = valueMetaData.asInstanceOf[SetMetaData].elemMetaData 68 | val length: Int = 5 + random.nextInt(5) 69 | val ret: Set[Any] = (1 to length).map(x => getValue(elemMetaData, level + 
1)).toSet 70 | val javaSet: java.util.Set[Any] = ret 71 | javaSet 72 | } 73 | case TType.MAP => { 74 | val mapMetaData = valueMetaData.asInstanceOf[MapMetaData] 75 | val keyMetaData = mapMetaData.keyMetaData 76 | val mapValueMetaData = mapMetaData.valueMetaData 77 | val length: Int = 5 + random.nextInt(5) 78 | val ret: Map[Any, Any] = (1 to length).map(_ => { 79 | val mapKey = getValue(keyMetaData, level + 1) 80 | val mapValue = getValue(mapValueMetaData, level + 1) 81 | mapKey -> mapValue 82 | }).toMap 83 | val javaMap: java.util.Map[Any, Any] = ret 84 | javaMap 85 | } 86 | case TType.STRUCT => { 87 | val structClass = valueMetaData.asInstanceOf[StructMetaData].structClass 88 | next(structClass, level = level + 1) 89 | } 90 | case _ => null 91 | } 92 | } 93 | 94 | def getEnumValue(enumType: Class[_]): Any = { 95 | val enumConstants = enumType.getEnumConstants 96 | enumConstants(random.nextInt(enumConstants.length)) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /thrift/src/main/scala/ThriftToRow.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import org.apache.thrift.TBase 4 | import org.apache.thrift.protocol.TField 5 | import org.apache.spark.sql.Row 6 | 7 | private class ThriftToRow(c: Class[_]) { 8 | val fieldMembers = c.getDeclaredFields.filter(_.getType() == classOf[TField]) 9 | val thriftIds = fieldMembers.map({ x => 10 | x.setAccessible(true) 11 | x.get(null).asInstanceOf[TField].id 12 | }) 13 | val protocol = new ThriftToRowSerializer(thriftIds) 14 | 15 | def getRow(t: TBase[_,_]): Row = { 16 | t.write(protocol) 17 | return protocol.getRow() 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /thrift/src/main/scala/ThriftToRowSerializer.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import java.io.ByteArrayOutputStream 4 | import java.nio.ByteBuffer 5 | 6 | import com.memsql.spark.connector.dataframe.JsonValue 7 | import org.apache.spark.sql.Row 8 | import org.apache.thrift.TSerializer 9 | import org.apache.thrift.protocol._ 10 | import org.apache.thrift.transport.TIOStreamTransport 11 | 12 | import scala.collection.mutable.ListBuffer 13 | 14 | private class ThriftToRowSerializer(val thriftIds: Array[Short]) extends TProtocol(null) { 15 | val values = new ListBuffer[Any] 16 | var nestingLevel: Int = 0 17 | 18 | var jsonProtocol: TSimpleJSONProtocol = null 19 | var outputStream: ByteArrayOutputStream = null 20 | var transport: TIOStreamTransport = null 21 | var serializer: TSerializer = null 22 | var valueWritten = true 23 | var currentFieldId = 0 24 | 25 | def getRow(): Row = { 26 | while (currentFieldId < thriftIds.length) { 27 | values.append(null) 28 | currentFieldId += 1 29 | } 30 | return Row.fromSeq(values) 31 | } 32 | 33 | def startJSON(): Unit = { 34 | this.outputStream = new ByteArrayOutputStream() 35 | this.transport = new TIOStreamTransport(outputStream) 36 | this.jsonProtocol = new TSimpleJSONProtocol(this.transport) 37 | } 38 | 39 | def saveJSON(): Unit = { 40 | values.append(new JsonValue(outputStream.toString())) 41 | jsonProtocol = null 42 | outputStream = null 43 | transport = null 44 | } 45 | 46 | override def writeStructBegin(struct: TStruct): Unit = { 47 | if (nestingLevel == 0) { 48 | currentFieldId = 0 49 | } else { 50 | jsonProtocol.writeStructBegin(struct) 51 | } 52 | 
nestingLevel += 1 53 | } 54 | 55 | override def writeStructEnd(): Unit = { 56 | nestingLevel -= 1 57 | if (nestingLevel > 0) { 58 | jsonProtocol.writeStructEnd() 59 | } 60 | if (nestingLevel == 1) { 61 | saveJSON() 62 | } 63 | } 64 | 65 | override def writeMapBegin(m: TMap): Unit = { 66 | nestingLevel += 1 67 | jsonProtocol.writeMapBegin(m) 68 | } 69 | 70 | override def writeMapEnd(): Unit = { 71 | nestingLevel -= 1 72 | jsonProtocol.writeMapEnd() 73 | if (nestingLevel == 1) { 74 | saveJSON() 75 | } 76 | } 77 | 78 | override def writeSetBegin(s: TSet): Unit = { 79 | nestingLevel += 1 80 | jsonProtocol.writeSetBegin(s) 81 | } 82 | 83 | override def writeSetEnd(): Unit = { 84 | nestingLevel -= 1 85 | jsonProtocol.writeSetEnd() 86 | if (nestingLevel == 1) { 87 | saveJSON() 88 | } 89 | } 90 | 91 | override def writeListBegin(l: TList): Unit = { 92 | nestingLevel += 1 93 | jsonProtocol.writeListBegin(l) 94 | } 95 | 96 | override def writeListEnd(): Unit = { 97 | nestingLevel -= 1 98 | jsonProtocol.writeListEnd() 99 | if (nestingLevel == 1) { 100 | saveJSON() 101 | } 102 | } 103 | 104 | override def writeFieldBegin(field: TField): Unit = { 105 | if (nestingLevel != 1) { 106 | jsonProtocol.writeFieldBegin(field) 107 | } else { 108 | while (thriftIds(currentFieldId) != field.id) { 109 | values.append(null) 110 | currentFieldId += 1 111 | } 112 | field.`type` match { 113 | case TType.LIST | TType.MAP | TType.SET | TType.STRUCT => 114 | startJSON() 115 | case _ => { } 116 | } 117 | } 118 | } 119 | 120 | override def writeFieldStop(): Unit = { 121 | if (nestingLevel != 1) { 122 | jsonProtocol.writeFieldStop() 123 | } 124 | } 125 | 126 | override def writeFieldEnd(): Unit = { 127 | if (nestingLevel != 1) { 128 | jsonProtocol.writeFieldStop() 129 | } else { 130 | currentFieldId += 1 131 | } 132 | } 133 | 134 | override def writeString(str: String): Unit = { 135 | if (nestingLevel == 1) { 136 | values.append(str) 137 | } else { 138 | jsonProtocol.writeString(str) 139 | } 140 | } 141 | 142 | override def writeBool(b: Boolean): Unit = { 143 | if (nestingLevel == 1) { 144 | values.append(b) 145 | } else { 146 | jsonProtocol.writeBool(b) 147 | } 148 | } 149 | 150 | override def writeMessageBegin(tMessage: TMessage): Unit = { 151 | if (nestingLevel != 1) { 152 | jsonProtocol.writeMessageBegin(tMessage) 153 | } 154 | } 155 | 156 | override def writeMessageEnd(): Unit = { 157 | if (nestingLevel != 1) { 158 | jsonProtocol.writeMessageEnd() 159 | } 160 | } 161 | 162 | override def writeByte(i: Byte): Unit = { 163 | if (nestingLevel == 1) { 164 | values.append(i) 165 | } else { 166 | jsonProtocol.writeByte(i) 167 | } 168 | } 169 | 170 | override def writeI16(i: Short): Unit = { 171 | if (nestingLevel == 1) { 172 | values.append(i) 173 | } else { 174 | jsonProtocol.writeI16(i) 175 | } 176 | } 177 | 178 | override def writeI32(i: Int): Unit = { 179 | if (nestingLevel == 1) { 180 | values.append(i) 181 | } else { 182 | jsonProtocol.writeI32(i) 183 | } 184 | } 185 | 186 | override def writeI64(i: Long): Unit = { 187 | if (nestingLevel == 1) { 188 | values.append(i) 189 | } else { 190 | jsonProtocol.writeI64(i) 191 | } 192 | } 193 | 194 | override def writeDouble(v: Double): Unit = { 195 | if (nestingLevel == 1) { 196 | values.append(v) 197 | } else { 198 | jsonProtocol.writeDouble(v) 199 | } 200 | } 201 | 202 | override def writeBinary(byteBuffer: ByteBuffer): Unit = { 203 | if (nestingLevel == 1) { 204 | values.append(byteBuffer) 205 | } else { 206 | jsonProtocol.writeBinary(byteBuffer) 207 | } 208 | } 209 | 210 | 
override def toString(): String = { 211 | val sb = new StringBuilder 212 | values.foreach({x => sb.append(x.toString())}) 213 | return sb.toString() 214 | } 215 | 216 | override def readBool(): Boolean = ??? 217 | 218 | override def readSetBegin(): TSet = ??? 219 | 220 | override def readByte(): Byte = ??? 221 | 222 | override def readStructBegin(): TStruct = ??? 223 | 224 | override def readStructEnd(): Unit = ??? 225 | 226 | override def readListEnd(): Unit = ??? 227 | 228 | override def readI32(): Int = ??? 229 | 230 | override def readI64(): Long = ??? 231 | 232 | override def readI16(): Short = ??? 233 | 234 | override def readMessageBegin(): TMessage = ??? 235 | 236 | override def readFieldBegin(): TField = ??? 237 | 238 | override def readListBegin(): TList = ??? 239 | 240 | override def readMapEnd(): Unit = ??? 241 | 242 | override def readFieldEnd(): Unit = ??? 243 | 244 | override def readString(): String = ??? 245 | 246 | override def readMessageEnd(): Unit = ??? 247 | 248 | override def readDouble(): Double = ??? 249 | 250 | override def readBinary(): ByteBuffer = ??? 251 | 252 | override def readSetEnd(): Unit = ??? 253 | 254 | override def readMapBegin(): TMap = ??? 255 | } 256 | -------------------------------------------------------------------------------- /thrift/src/main/scala/ThriftToSchema.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import collection.JavaConversions._ 4 | import com.memsql.spark.connector.dataframe.JsonType 5 | import org.apache.spark.sql.types._ 6 | import org.apache.thrift.{TBase, TFieldIdEnum} 7 | import org.apache.thrift.protocol.{TField, TType} 8 | import org.apache.thrift.meta_data._ 9 | 10 | private object ThriftToSchema { 11 | def getSchema(c: Class[_]) : StructType = { 12 | val className = c.getName 13 | var tBaseClass: Class[TBase[_ <: TBase[_, _], _ <: TFieldIdEnum]] = null 14 | try { 15 | tBaseClass = c.asInstanceOf[Class[TBase[_ <: TBase[_, _], _ <: TFieldIdEnum]]] 16 | } catch { 17 | case e: ClassCastException => throw new IllegalArgumentException(s"Class $className is not a subclass of org.apache.thrift.TBase") 18 | } 19 | val metaDataMap: Map[_ <: TFieldIdEnum, FieldMetaData] = FieldMetaData.getStructMetaDataMap(tBaseClass).toMap 20 | // Sort the fields by their thrift ID so that they're in a consistent 21 | // order. 
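// (Keeping the columns in thrift-field-id order also lines the schema up with the Row values
// produced by ThriftToRow/ThriftToRowSerializer, which match values to columns by thrift field id.)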
22 | val metaDataSeq = metaDataMap.toSeq.sortBy(_._1.getThriftFieldId) 23 | StructType(metaDataSeq.map({ case (field, fieldMetaData) => 24 | val fieldName = fieldMetaData.fieldName 25 | val fieldType = fieldMetaData.valueMetaData.`type` match { 26 | case TType.BOOL => BooleanType 27 | case TType.BYTE => ByteType 28 | case TType.I16 => ShortType 29 | case TType.I32 => IntegerType 30 | case TType.I64 => LongType 31 | case TType.DOUBLE => DoubleType 32 | case TType.STRING => { 33 | if (fieldMetaData.valueMetaData.isBinary) { 34 | BinaryType 35 | } else { 36 | StringType 37 | } 38 | } 39 | case TType.ENUM => IntegerType 40 | case _ => JsonType 41 | } 42 | StructField(fieldName, fieldType, true) 43 | })) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /thrift/src/main/scala/ThriftTransformer.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import com.memsql.spark.etl.api._ 4 | import com.memsql.spark.etl.utils.PhaseLogger 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.sql.{DataFrame, Row, SQLContext} 7 | import org.apache.spark.sql.types._ 8 | import org.apache.thrift.{TBase, TDeserializer, TFieldIdEnum} 9 | 10 | class ThriftTransformer extends Transformer { 11 | private var classObj: Class[_] = null 12 | private var thriftToRow: ThriftToRow = null 13 | private var deserializer: TDeserializer = null 14 | private var schema: StructType = null 15 | 16 | def thriftRDDToDataFrame(sqlContext: SQLContext, rdd: RDD[Row]): DataFrame = { 17 | val rowRDD: RDD[Row] = rdd.map({ record => 18 | val recordAsBytes = record(0).asInstanceOf[Array[Byte]] 19 | val i = classObj.newInstance().asInstanceOf[TBase[_ <: TBase[_, _], _ <: TFieldIdEnum]] 20 | deserializer.deserialize(i, recordAsBytes) 21 | thriftToRow.getRow(i) 22 | }) 23 | sqlContext.createDataFrame(rowRDD, schema) 24 | } 25 | 26 | override def initialize(sqlContext: SQLContext, config: PhaseConfig, logger: PhaseLogger): Unit = { 27 | val userConfig = config.asInstanceOf[UserTransformConfig] 28 | val className = userConfig.getConfigString("className") match { 29 | case Some(s) => s 30 | case None => throw new IllegalArgumentException("className must be set in the config") 31 | } 32 | 33 | classObj = Class.forName(className) 34 | thriftToRow = new ThriftToRow(classObj) 35 | deserializer = new TDeserializer() 36 | 37 | schema = ThriftToSchema.getSchema(classObj) 38 | } 39 | 40 | override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = { 41 | thriftRDDToDataFrame(sqlContext, df.rdd) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /thrift/src/test/scala/ThriftRandomGeneratorSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import org.scalatest._ 4 | import org.apache.thrift.{TBase, TFieldIdEnum} 5 | 6 | class ThriftRandomGeneratorSpec extends FlatSpec { 7 | "ThriftRandomGenerator" should "create Thrift objects with random values" in { 8 | val thriftType = classOf[TestClass] 9 | val thriftObject = ThriftRandomGenerator.next(thriftType).asInstanceOf[TestClass] 10 | assert(thriftObject.string_value != null) 11 | assert(thriftObject.binary_value != null) 12 | assert(thriftObject.list_value.size > 0) 13 | assert(thriftObject.list_value.get(0).isInstanceOf[String]) 14 | 
assert(thriftObject.set_value.size > 0) 15 | assert(thriftObject.set_value.toArray()(0).isInstanceOf[String]) 16 | assert(thriftObject.map_value.size > 0) 17 | val keys = thriftObject.map_value.keySet.toArray 18 | assert(keys(0).isInstanceOf[String]) 19 | assert(thriftObject.map_value.get(keys(0)).isInstanceOf[String]) 20 | val testEnumValues = Set(TestEnum.FIRST_VALUE, TestEnum.SECOND_VALUE) 21 | assert(testEnumValues.contains(thriftObject.enum_value)) 22 | assert(thriftObject.sub_class_value != null) 23 | assert(thriftObject.sub_class_value.string_value != null) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /thrift/src/test/scala/ThriftToRowSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import com.memsql.spark.connector.dataframe.JsonValue 4 | import collection.JavaConversions._ 5 | import java.nio.ByteBuffer 6 | import org.scalatest._ 7 | 8 | class ThriftToRowSpec extends FlatSpec { 9 | "ThriftToRow" should "create Spark SQL Rows from Thrift objects" in { 10 | val testClassInstance = new TestClass( 11 | true, 12 | 42.toByte, 13 | 128.toShort, 14 | 1024, 15 | 2048.toLong, 16 | 2.5, 17 | "test1", 18 | ByteBuffer.wrap("test2".getBytes), 19 | mapAsJavaMap(Map("value" -> "test3")).asInstanceOf[java.util.Map[String, String]], 20 | List("test4"), 21 | Set("test5"), 22 | TestEnum.FIRST_VALUE, 23 | new SubClass("test6") 24 | ) 25 | val thriftToRow = new ThriftToRow(classOf[TestClass]) 26 | val row = thriftToRow.getRow(testClassInstance) 27 | assert(row.getAs[Boolean](0)) 28 | assert(row.getAs[Byte](1) == 42.toByte) 29 | assert(row.getAs[Short](2) == 128.toShort) 30 | assert(row.getAs[Int](3) == 1024) 31 | assert(row.getAs[Long](4) == 2048) 32 | assert(row.getAs[Double](5) == 2.5) 33 | assert(row.getAs[String](6) == "test1") 34 | assert(row.getAs[ByteBuffer](7) == ByteBuffer.wrap("test2".getBytes)) 35 | val mapValue = row.getAs[JsonValue](8) 36 | assert(mapValue.value == "{\"value\":\"test3\"}") 37 | val listValue = row.getAs[JsonValue](9) 38 | assert(listValue.value == "[\"test4\"]") 39 | val setValue = row.getAs[JsonValue](10) 40 | assert(setValue.value == "[\"test5\"]") 41 | assert(row.getAs[Int](11) == TestEnum.FIRST_VALUE.getValue) 42 | val subClassValue = row.getAs[JsonValue](12) 43 | assert(subClassValue.value == "{\"string_value\":\"test6\"}") 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /thrift/src/test/scala/ThriftToSchemaSpec.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark.examples.thrift 2 | 3 | import com.memsql.spark.connector.dataframe.JsonType 4 | import org.apache.spark.sql.types._ 5 | import org.scalatest._ 6 | 7 | class ThriftToSchemaSpec extends FlatSpec { 8 | "ThriftToSchema" should "create a schema from Thrift classes" in { 9 | val schema = ThriftToSchema.getSchema(classOf[TestClass]) 10 | val fields = schema.fields 11 | assert(fields.forall(field => field.nullable)) 12 | assert(fields(0).name == "bool_value") 13 | assert(fields(0).dataType == BooleanType) 14 | assert(fields(1).name == "byte_value") 15 | assert(fields(1).dataType == ByteType) 16 | assert(fields(2).name == "i16_value") 17 | assert(fields(2).dataType == ShortType) 18 | assert(fields(3).name == "i32_value") 19 | assert(fields(3).dataType == IntegerType) 20 | assert(fields(4).name == "i64_value") 21 | assert(fields(4).dataType == LongType) 
22 | assert(fields(5).name == "double_value") 23 | assert(fields(5).dataType == DoubleType) 24 | assert(fields(6).name == "string_value") 25 | assert(fields(6).dataType == StringType) 26 | assert(fields(7).name == "binary_value") 27 | assert(fields(7).dataType == BinaryType) 28 | assert(fields(8).name == "map_value") 29 | assert(fields(8).dataType == JsonType) 30 | assert(fields(9).name == "list_value") 31 | assert(fields(9).dataType == JsonType) 32 | assert(fields(10).name == "set_value") 33 | assert(fields(10).dataType == JsonType) 34 | assert(fields(11).name == "enum_value") 35 | assert(fields(11).dataType == IntegerType) 36 | assert(fields(12).name == "sub_class_value") 37 | assert(fields(12).dataType == JsonType) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /thrift/src/test/thrift/TestClass.thrift: -------------------------------------------------------------------------------- 1 | namespace java com.memsql.spark.examples.thrift 2 | 3 | struct SubClass { 4 | 1: string string_value 5 | } 6 | 7 | enum TestEnum { 8 | FIRST_VALUE = 1, 9 | SECOND_VALUE = 2 10 | } 11 | 12 | struct TestClass { 13 | 1: bool bool_value, 14 | 2: byte byte_value, 15 | 3: i16 i16_value, 16 | 4: i32 i32_value, 17 | 5: i64 i64_value, 18 | 6: double double_value, 19 | 7: string string_value, 20 | 8: binary binary_value, 21 | 9: map map_value, 22 | 10: list list_value, 23 | 11: set set_value, 24 | 12: TestEnum enum_value, 25 | 13: SubClass sub_class_value 26 | } 27 | --------------------------------------------------------------------------------
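As a quick end-to-end illustration of the Thrift pieces above, here is a minimal usage sketch. It is not part of the repository: the object name ThriftExampleSketch is hypothetical, it assumes the generated TestClass from thrift/src/test/thrift/TestClass.thrift is on the classpath, and it must live in the com.memsql.spark.examples.thrift package because ThriftToRow and ThriftToSchema are package-private.

package com.memsql.spark.examples.thrift

object ThriftExampleSketch {
  def main(args: Array[String]): Unit = {
    // Build a TestClass instance populated with random values, as ThriftRandomExtractor does.
    val obj = ThriftRandomGenerator.next(classOf[TestClass]).asInstanceOf[TestClass]

    // Derive the Spark SQL schema for TestClass (bool_value, byte_value, ..., sub_class_value).
    val schema = ThriftToSchema.getSchema(classOf[TestClass])

    // Convert the Thrift object into a Spark SQL Row whose columns follow that schema.
    val row = new ThriftToRow(classOf[TestClass]).getRow(obj)

    println(schema.fieldNames.mkString(", "))
    println(row)
  }
}

In the pipeline itself, ThriftRandomExtractor emits serialized TBase bytes and ThriftTransformer performs the deserialize-and-convert step sketched here for every record.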