├── .gitignore
├── README.md
├── build.sbt
├── project
│   ├── assembly.sbt
│   ├── build.properties
│   └── plugins.sbt
└── src
    └── main
        └── scala
            └── org
                ├── apache
                │   └── spark
                │       ├── sql
                │       │   ├── MyLogging.scala
                │       │   ├── rzlabs
                │       │   │   └── DataModule.scala
                │       │   ├── sources
                │       │   │   └── druid
                │       │   │       ├── AggregateTransform.scala
                │       │   │       ├── CloseableIterator.scala
                │       │   │       ├── DruidPlanner.scala
                │       │   │       ├── DruidQueryResultIterator.scala
                │       │   │       ├── DruidScanResultIterator.scala
                │       │   │       ├── DruidSchema.scala
                │       │   │       ├── DruidStrategy.scala
                │       │   │       ├── DruidTransforms.scala
                │       │   │       ├── PostAggregate.scala
                │       │   │       └── ProjectFilterTransform.scala
                │       │   └── util
                │       │       └── ExprUtil.scala
                │       └── util
                │           └── MyThreadUtils.scala
                ├── fasterxml
                │   └── jackson
                │       └── databind
                │           └── ObjectMapper.scala
                └── rzlabs
                    └── druid
                        ├── DateTimeExtractor.scala
                        ├── DefaultSource.scala
                        ├── DruidDataSource.scala
                        ├── DruidExceptions.scala
                        ├── DruidQueryBuilder.scala
                        ├── DruidQueryGranularity.scala
                        ├── DruidQuerySpec.scala
                        ├── DruidRDD.scala
                        ├── DruidRelation.scala
                        ├── QueryIntervals.scala
                        ├── Utils.scala
                        ├── client
                        │   ├── CuratorConnection.scala
                        │   ├── DruidClient.scala
                        │   └── DruidMessages.scala
                        ├── jscodegen
                        │   ├── JSAggrGenerator.scala
                        │   ├── JSCast.scala
                        │   ├── JSCodeGenerator.scala
                        │   ├── JSDateTime.scala
                        │   └── JSExpr.scala
                        └── metadata
                            ├── DruidInfo.scala
                            ├── DruidMetadataCache.scala
                            └── DruidRelationColumn.scala
/.gitignore:
--------------------------------------------------------------------------------
target/
.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# spark-druid-connector

A library for querying Druid data sources with Apache Spark.

# Compatibility

This library is compatible with Spark 2.x and Druid 0.9.0+.

# Usage

## Compile

```
sbt clean assembly
```

## Using with spark-shell

```
bin/spark-shell --jars spark-druid-connector-assembly-0.1.0-SNAPSHOT.jar
```

In spark-shell, a temporary table can be created like this:

```
val df = spark.read.format("org.rzlabs.druid").
  option("druidDatasource", "ds1").
  option("zkHost", "localhost:2181").
28 | option("hyperUniqueColumnInfo", """[{"column":"city", "hllMetric": "unique_city"}]""").load 29 | df.createOrReplaceTempView("ds") 30 | spark.sql("select time, sum(event) from ds group by time").show 31 | ``` 32 | 33 | or you can create a hive table: 34 | 35 | ``` 36 | spark.sql(""" 37 | create table ds1 using org.rzlabs.druid options ( 38 | druidDatasource "ds1", 39 | zkHost "localhost:2181", 40 | hyperUniqueColumnInfo, "[{\"column\": \"city\", \"hllMetric\": \"unique_city\"}]" 41 | ) 42 | """) 43 | ``` 44 | 45 | # Options 46 | 47 | |option|required|default value|descrption| 48 | |-|-|-|-| 49 | |druidDatasource|yes|none|data source name in Druid| 50 | |zkHost|no|localhost|zookeeper server Druid use, e.g., localhost:2181| 51 | |zkSessionTimeout|no|30000|zk server connection timeout| 52 | |zkEnableCompression|no|true|zk enbale compression or not| 53 | |zkDruidPath|no|/druid|The druid metadata root path in zk| 54 | |zkQualifyDiscoveryNames|no|true|| 55 | |queryGranularity|no|all|The query granularity of the Druid datasource| 56 | |maxConnectionsPerRoute|no|20|The max simultaneous live connections per Druid server| 57 | |maxConnections|no|100|The max simultaneous live connnections of the Druid cluster| 58 | |loadMetadataFromAllSegments|no|true|Fetch metadata from all available segments or not| 59 | |debugTransformations|no|false|Log debug informations about the transformations or not| 60 | |timeZoneId|no|UTC|| 61 | |useV2GroupByEngine|no|false|Use V2 groupby engine or not| 62 | |useSmile|no|true|Use smile binary format as the data format exchanged between client and Druid servers| 63 | 64 | # Major features 65 | 66 | ## Currently 67 | 68 | * Direct table creating in Spark without requiring of base table. 69 | * Support Aggregate and Project & Filter operators pushing down and transform to GROUPBY and SCAN query against Druid accordingly. 70 | * Support majority of primitive filter specs, aggregation specs and extraction functions. 71 | * Lightweight datasource metadata updating. 72 | 73 | ## In the future 74 | 75 | * Support Join operator. 76 | * Support Limit and Having operators pushing down. 77 | * Suport more primitive specs and extraction functions. 78 | * Support more Druid query specs according to query details. 79 | * Suport datasource creating and metadata lookup. 80 | * ... 
81 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | 2 | val sparkVersion = "2.3.0" 3 | val json4sVersion = "3.6.0-M2" 4 | val jodaVersion = "2.9.9" 5 | val curatorVersion = "4.0.1" 6 | val jacksonVersion = "2.6.5" 7 | val apacheHttpVersion = "4.5.5" 8 | 9 | val myDependencies = Seq( 10 | "org.apache.spark" %% "spark-core" % sparkVersion % "provided", 11 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", 12 | "joda-time" % "joda-time" % jodaVersion, 13 | "org.apache.curator" % "curator-framework" % curatorVersion, 14 | "com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion, 15 | "com.fasterxml.jackson.core" % "jackson-annotations" % jacksonVersion, 16 | "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion, 17 | "com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonVersion, 18 | "com.fasterxml.jackson.dataformat" % "jackson-dataformat-smile" % jacksonVersion, 19 | "com.fasterxml.jackson.datatype" % "jackson-datatype-joda" % jacksonVersion, 20 | "com.fasterxml.jackson.jaxrs" % "jackson-jaxrs-smile-provider" % jacksonVersion, 21 | "org.apache.httpcomponents" % "httpclient" % apacheHttpVersion 22 | ) 23 | 24 | lazy val commonSettings = Seq( 25 | organization := "org.rzlabs", 26 | version := "0.1.0-SNAPSHOT", 27 | 28 | scalaVersion := "2.11.8" 29 | ) 30 | 31 | lazy val root = (project in file(".")) 32 | .settings( 33 | commonSettings, 34 | name := "spark-druid-connector", 35 | libraryDependencies ++= myDependencies 36 | ) 37 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6") 2 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | // This file should only contain the version of sbt to use. 2 | sbt.version=1.1.2 3 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | // You may use this file to add plugin dependencies for sbt. 
2 | 3 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") 4 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/MyLogging.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql 2 | 3 | import org.apache.spark.internal.Logging 4 | 5 | trait MyLogging extends Logging { 6 | 7 | override def logInfo(msg: => String): Unit = { 8 | super.logInfo(msg) 9 | } 10 | 11 | override def logDebug(msg: => String): Unit = { 12 | super.logDebug(msg) 13 | } 14 | 15 | override def logTrace(msg: => String): Unit = { 16 | super.logTrace(msg) 17 | } 18 | 19 | override def logWarning(msg: => String): Unit = { 20 | super.logWarning(msg) 21 | } 22 | 23 | override def logError(msg: => String): Unit = { 24 | super.logError(msg) 25 | } 26 | 27 | override def logInfo(msg: => String, throwable: Throwable): Unit = { 28 | super.logInfo(msg, throwable) 29 | } 30 | 31 | override def logDebug(msg: => String, throwable: Throwable): Unit = { 32 | super.logDebug(msg, throwable) 33 | } 34 | 35 | override def logTrace(msg: => String, throwable: Throwable): Unit = { 36 | super.logTrace(msg, throwable) 37 | } 38 | 39 | override def logWarning(msg: => String, throwable: Throwable): Unit = { 40 | super.logWarning(msg, throwable) 41 | } 42 | 43 | override def logError(msg: => String, throwable: Throwable): Unit = { 44 | super.logError(msg, throwable) 45 | } 46 | 47 | def logInfo(msg: => String, arg: Object): Unit = { 48 | super.log.info(msg, arg) 49 | } 50 | 51 | def logDebug(msg: => String, arg: Object): Unit = { 52 | super.log.debug(msg, arg) 53 | } 54 | 55 | def logTrace(msg: => String, arg: Object): Unit = { 56 | super.log.trace(msg, arg) 57 | } 58 | 59 | def logWarning(msg: => String, arg: Object): Unit = { 60 | super.log.warn(msg, arg) 61 | } 62 | 63 | def logError(msg: => String, arg: Object): Unit = { 64 | super.log.error(msg, arg) 65 | } 66 | 67 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/rzlabs/DataModule.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.rzlabs 2 | 3 | import org.apache.spark.sql.sources.druid.{DruidPlanner, DruidStrategy} 4 | import org.apache.spark.sql.{SQLContext, SparkSession, Strategy} 5 | import org.rzlabs.druid.metadata.DruidOptions 6 | 7 | trait DataModule { 8 | 9 | def physicalRules(sqlContext: SQLContext, druidOptions: DruidOptions): Seq[Strategy] = Nil 10 | 11 | } 12 | 13 | object DruidBaseModule extends DataModule { 14 | 15 | override def physicalRules(sqlContext: SQLContext, druidOptions: DruidOptions): Seq[Strategy] = { 16 | val druidPlanner = DruidPlanner(sqlContext, druidOptions) 17 | Seq(new DruidStrategy(druidPlanner)) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/AggregateTransform.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import org.apache.spark.sql.catalyst.analysis.TypeCoercion 4 | import org.apache.spark.sql.catalyst.expressions.aggregate._ 5 | import org.apache.spark.sql.catalyst.expressions._ 6 | import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand} 7 | import org.apache.spark.sql.types._ 8 | import org.rzlabs.druid._ 9 | import 
org.rzlabs.druid.jscodegen.{JSAggrGenerator, JSCodeGenerator} 10 | import org.rzlabs.druid.metadata.DruidRelationColumn 11 | 12 | trait AggregateTransform { 13 | self: DruidPlanner => 14 | 15 | /** 16 | * Collect the [[AggregateExpression]]s in ''aggregateExpressions'' 17 | * of [[Aggregate]] operator. 18 | * @param aggrExprs The aggregateExpressions of Aggregate. 19 | * @return The collected AggregateExpressions. 20 | */ 21 | def aggExpressions(aggrExprs: Seq[NamedExpression]): Seq[AggregateExpression] = { 22 | aggrExprs.flatMap(_ collect { case a: AggregateExpression => a }).distinct 23 | } 24 | 25 | def addCountAgg(dqb: DruidQueryBuilder, aggrExpr: AggregateExpression) = { 26 | val outputName = dqb.nextAlias 27 | 28 | // 'Count' is a implicit metric can be applied 'count' operator on. 29 | dqb.aggregationSpec(new CountAggregationSpec(outputName, "count")). 30 | outputAttribute(outputName, aggrExpr, aggrExpr.dataType, LongType) 31 | } 32 | 33 | private def setAggregationSpecs(dqb: DruidQueryBuilder, aggrExpr: AggregateExpression) = { 34 | 35 | (dqb, aggrExpr, aggrExpr.aggregateFunction) match { 36 | case (_, _, Count(_)) => 37 | Some(addCountAgg(dqb, aggrExpr)) 38 | //case (_, ae, fn) if JSAggrGenerator.jsAvgCandidate(dqb, fn) => 39 | case (_, _, fn @ Average(_)) => 40 | // Based on the same reason (cannot know the denominator metric) 41 | // we just throw a DruidDataSourceException. 42 | throw new DruidDataSourceException(s"${fn.toAggString(false)} calculation may " + 43 | s"not be finished correctly, because we do not know the metric specified as 'count' type " + 44 | s"at indexing time and the 'longSum' of which will be the denominator of the Average function.") 45 | case DruidNativeAggregator(dqb1) => Some(dqb1) 46 | case (_, _, fn) => 47 | for (jsdqb <- JSAggrGenerator.jsAggr(dqb, aggrExpr, fn, 48 | dqb.druidRelationInfo.options.timeZoneId)) yield 49 | jsdqb._1 50 | 51 | } 52 | } 53 | 54 | class PrimitiveExtractionFunction(dqb: DruidQueryBuilder) { 55 | 56 | self => 57 | 58 | def unapply(e: Expression): Option[(String, ExtractionFunctionSpec, DataType)] = e match { 59 | case Substring(AttributeReference(nm, _, _, _), Literal(pos, _), Literal(len, _)) => 60 | for (dc <- dqb.druidColumn(nm)) yield { 61 | val index = pos.toString.toInt 62 | val length = len.toString.toInt 63 | (if (dc.isTimeDimension) DruidDataSource.INNER_TIME_COLUMN_NAME else nm, 64 | new SubstringExtractionFunctionSpec(index, length), StringType) 65 | } 66 | case Substring(expr, Literal(pos, _), Literal(len, _)) => 67 | for ((dim, spec, dt) <- self.unapply(expr)) yield { 68 | val index = pos.toString.toInt 69 | val length = len.toString.toInt 70 | (dim, new CascadeExtractionFunctionSpec(List( 71 | spec, 72 | new SubstringExtractionFunctionSpec(index, length) 73 | )), StringType) 74 | } 75 | case Length(AttributeReference(nm, _, _, _)) => 76 | for (dc <- dqb.druidColumn(nm)) yield 77 | (if (dc.isTimeDimension) DruidDataSource.INNER_TIME_COLUMN_NAME else nm, 78 | new StrlenExtractionFunctionSpec(), IntegerType) 79 | case Length(expr) => 80 | for ((dim, spec, dt) <- self.unapply(expr)) yield 81 | (dim, new CascadeExtractionFunctionSpec(List( 82 | spec, 83 | new StrlenExtractionFunctionSpec() 84 | )), IntegerType) 85 | case Upper(AttributeReference(nm, _, _, _)) => 86 | for (dc <- dqb.druidColumn(nm)) yield 87 | (if (dc.isTimeDimension) DruidDataSource.INNER_TIME_COLUMN_NAME else nm, 88 | new UpperAndLowerExtractionFunctionSpec("upper"), StringType) 89 | case Upper(expr) => 90 | for ((dim, spec, dt) <- 
self.unapply(expr)) yield 91 | (dim, new CascadeExtractionFunctionSpec(List( 92 | spec, 93 | new UpperAndLowerExtractionFunctionSpec("upper") 94 | )), StringType) 95 | case Lower(AttributeReference(nm, _, _, _)) => 96 | for (dc <- dqb.druidColumn(nm)) yield 97 | (if (dc.isTimeDimension) DruidDataSource.INNER_TIME_COLUMN_NAME else nm, 98 | new UpperAndLowerExtractionFunctionSpec("lower"), StringType) 99 | case Lower(expr) => 100 | for ((dim, spec, dt) <- self.unapply(expr)) yield 101 | (dim, new CascadeExtractionFunctionSpec(List( 102 | spec, 103 | new UpperAndLowerExtractionFunctionSpec("lower") 104 | )), StringType) 105 | 106 | case Cast(expr, dt, _) => 107 | for ((dim, spec, _) <- self.unapply(expr)) yield { 108 | (dim, spec, dt) 109 | } 110 | //TODO: Add more extraction function check. 111 | case _ => None 112 | } 113 | } 114 | 115 | private def setDimensionSpecs(dqb: DruidQueryBuilder, 116 | timeElemExtractor: SparkNativeTimeElementExtractor, 117 | primitiveExtractionFunction: PrimitiveExtractionFunction, 118 | grpExpr: Expression 119 | ): Option[DruidQueryBuilder] = { 120 | 121 | grpExpr match { 122 | case AttributeReference(nm, dataType, _, _) if dqb.isNonTimeDimension(nm) => 123 | val dc = dqb.druidColumn(nm).get 124 | Some(dqb.dimensionSpec(new DefaultDimensionSpec(dc.name, nm)).outputAttribute(nm, 125 | grpExpr, dataType, DruidDataType.sparkDataType(dc.dataType))) 126 | case AttributeReference(nm, dataType, _, _) if dqb.isNotIndexedDimension(nm) => 127 | val dc = dqb.druidColumn(nm).get 128 | log.warn(s"Column '$nm' is not indexed into datasource.") 129 | Some(dqb.dimensionSpec(new DefaultDimensionSpec(dc.name, nm)).outputAttribute(nm, 130 | grpExpr, dataType, DruidDataType.sparkDataType(dc.dataType))) 131 | case timeElemExtractor(dtGrp) => 132 | val timeFmtExtractFunc: ExtractionFunctionSpec = { 133 | if (dtGrp.inputFormat.isDefined) { 134 | new TimeParsingExtractionFunctionSpec(dtGrp.inputFormat.get, dtGrp.formatToApply) 135 | } else { 136 | new TimeFormatExtractionFunctionSpec(dtGrp.formatToApply, dtGrp.timeZone.getOrElse(null)) 137 | } 138 | } 139 | // If the related column is time column, we should give it to the inner name "__time" to ensure 140 | // correctness in querySpec. 141 | val colName = if (dtGrp.druidColumn.isTimeDimension) { 142 | DruidDataSource.INNER_TIME_COLUMN_NAME 143 | } else dtGrp.druidColumn.name 144 | Some(dqb.dimensionSpec( 145 | new ExtractionDimensionSpec(colName, timeFmtExtractFunc, dtGrp.outputName)) 146 | .outputAttribute(dtGrp.outputName, grpExpr, grpExpr.dataType, 147 | DruidDataType.sparkDataType(dtGrp.druidColumn.dataType))) 148 | case primitiveExtractionFunction(dim, extractionFunctionSpec, dt) => 149 | val outDName = dqb.nextAlias 150 | Some(dqb.dimensionSpec(new ExtractionDimensionSpec(dim, extractionFunctionSpec, outDName)). 151 | outputAttribute(outDName, grpExpr, grpExpr.dataType, dt)) 152 | case _ => 153 | val codeGen = JSCodeGenerator(dqb, grpExpr, false, false, 154 | dqb.druidRelationInfo.options.timeZoneId) 155 | for (fn <- codeGen.fnCode) yield { 156 | val outDName = dqb.nextAlias 157 | dqb.dimensionSpec(new ExtractionDimensionSpec(codeGen.fnParams.last, 158 | new JavascriptExtractionFunctionSpec(fn), outDName)). 
159 | outputAttribute(outDName, grpExpr, grpExpr.dataType, StringType) 160 | } 161 | } 162 | } 163 | 164 | private def transformAggregation(dqb: DruidQueryBuilder, 165 | aggOp: Aggregate, 166 | grpExprs: Seq[Expression], 167 | aggrExprs: Seq[NamedExpression] 168 | ): Option[DruidQueryBuilder] = { 169 | 170 | val timeElemExtractor = new SparkNativeTimeElementExtractor()(dqb) 171 | val primitiveExtractionFunction = new PrimitiveExtractionFunction(dqb) 172 | 173 | val dqb1 = grpExprs.foldLeft(Some(dqb).asInstanceOf[Option[DruidQueryBuilder]]) { 174 | (odqb, e) => odqb.flatMap(setDimensionSpecs(_, timeElemExtractor, 175 | primitiveExtractionFunction, e)) 176 | } 177 | 178 | // all AggregateExpressions in agregateExpressions list. 179 | val allAggrExprs = aggExpressions(aggrExprs) 180 | 181 | val dqb2 = allAggrExprs.foldLeft(dqb1) { 182 | (dqb, ae) => dqb.flatMap(setAggregationSpecs(_, ae)) 183 | } 184 | 185 | dqb2.map(_.aggregateOp(aggOp)) 186 | } 187 | 188 | private def attrRefName(e: Expression): Option[String] = { 189 | e match { 190 | case AttributeReference(nm, _, _, _) => Some(nm) 191 | case Cast(AttributeReference(nm, _, _, _), _, _) => Some(nm) 192 | case Alias(AttributeReference(nm, _, _, _), _) => Some(nm) 193 | case _ => None 194 | } 195 | } 196 | 197 | private object DruidNativeAggregator { 198 | 199 | def unapply(t: (DruidQueryBuilder, AggregateExpression, AggregateFunction)): 200 | Option[DruidQueryBuilder] = { 201 | val dqb = t._1 202 | val aggrExpr = t._2 203 | val aggrFunc = t._3 204 | val outputName = dqb.nextAlias 205 | (dqb, aggrFunc, outputName) match { 206 | case ApproximateCountDistinctAggregate(aggrSpec) => 207 | Some(dqb.aggregationSpec(aggrSpec). 208 | outputAttribute(outputName, aggrExpr, aggrExpr.dataType, LongType)) 209 | 210 | case SumMinMaxFirstLastAggregate(dc, aggrSpec) => 211 | Some(dqb.aggregationSpec(aggrSpec). 212 | outputAttribute(outputName, aggrExpr, aggrExpr.dataType, 213 | DruidDataType.sparkDataType(dc.dataType))) 214 | 215 | case AvgAggregate(dqb1, sumAlias, countAlias) => 216 | Some(dqb1.avgExpression(aggrExpr, sumAlias, countAlias)) 217 | 218 | case _ => None 219 | } 220 | } 221 | } 222 | 223 | private object AvgAggregate { 224 | 225 | def unapply(t: (DruidQueryBuilder, AggregateFunction, String)): 226 | Option[(DruidQueryBuilder, String, String)] = { 227 | val dqb = t._1 228 | val aggrFunc = t._2 229 | val outputName = t._3 230 | val r = for (c <- aggrFunc.children.headOption if aggrFunc.children.size == 1; 231 | columnName <- attrRefName(c); 232 | dc <- dqb.druidColumn(columnName) if dc.isMetric; 233 | cdt <- Some(DruidDataType.sparkDataType(dc.dataType)); 234 | dt <- TypeCoercion.findTightestCommonType(aggrFunc.dataType, cdt) 235 | ) yield (aggrFunc, dt, dc, outputName) 236 | 237 | r.flatMap { 238 | // count may not be the count metric!!! 239 | // case (_: Average, dt, dc, outputName) 240 | // if (dqb.druidRelationInfo.druidColumns.exists(_ == "count")) => 241 | // val outputName2 = dqb.nextAlias 242 | // val druidAggrFunc = dc.dataType match { 243 | // case DruidDataType.Long => "longSum" 244 | // case _ => "doubleSum" 245 | // } 246 | // val aggrFuncDataType = DruidDataType.sparkDataType(dc.dataType) 247 | // Some((dqb.aggregationSpec(SumAggregationSpec(druidAggrFunc, outputName, dc.name)). 248 | // outputAttribute(outputName, null, aggrFuncDataType, aggrFuncDataType). 249 | // aggregationSpec(SumAggregationSpec("longSum", outputName2, "count")). 
250 | // outputAttribute(outputName2, null, LongType, LongType), outputName, outputName2)) 251 | case (fn: Average, _, _, _) => 252 | throw new DruidDataSourceException(s"${fn.toAggString(false)} calculation may " + 253 | s"not be finished correctly, because we do not know the metric specified as 'count' type " + 254 | s"at indexing time and the 'longSum' of which will be the denominator of the Average function.") 255 | case _ => None 256 | } 257 | } 258 | } 259 | 260 | private object SumMinMaxFirstLastAggregate { 261 | 262 | def unapply(t: (DruidQueryBuilder, AggregateFunction, String)): 263 | Option[(DruidRelationColumn, AggregationSpec)] = { 264 | val dqb = t._1 265 | val aggrFunc = t._2 266 | val outputName = t._3 267 | val r = for (c <- aggrFunc.children.headOption if aggrFunc.children.size == 1; 268 | columnName <- attrRefName(c); 269 | dc <- dqb.druidColumn(columnName) if dc.isMetric; 270 | cdt <- Some(DruidDataType.sparkDataType(dc.dataType)); 271 | dt <- TypeCoercion.findTightestCommonType(aggrFunc.dataType, cdt) 272 | ) yield 273 | (aggrFunc, dt, dc, outputName) 274 | 275 | r.flatMap { 276 | case (_: Sum, LongType, dc, outputName) => 277 | Some(dc -> SumAggregationSpec("longSum", outputName, dc.name)) 278 | case (_: Sum, FloatType, dc, outputName) => 279 | Some(dc -> SumAggregationSpec("floatSum", outputName, dc.name)) 280 | case (_: Sum, DoubleType, dc, outputName) => 281 | Some(dc -> SumAggregationSpec("doubleSum", outputName, dc.name)) 282 | case (_: Min, LongType, dc, outputName) => 283 | Some(dc -> MinAggregationSpec("longMin", outputName, dc.name)) 284 | case (_: Min, FloatType, dc, outputName) => 285 | Some(dc -> MinAggregationSpec("floatMin", outputName, dc.name)) 286 | case (_: Min, DoubleType, dc, outputName) => 287 | Some(dc -> MinAggregationSpec("doubleMin", outputName, dc.name)) 288 | case (_: Max, LongType, dc, outputName) => 289 | Some(dc -> MaxAggregationSpec("longMax", outputName, dc.name)) 290 | case (_: Max, FloatType, dc, outputName) => 291 | Some(dc -> MaxAggregationSpec("floatMax", outputName, dc.name)) 292 | case (_: Max, DoubleType, dc, outputName) => 293 | Some(dc -> MaxAggregationSpec("doubleMax", outputName, dc.name)) 294 | case (_: First, LongType, dc, outputName) => 295 | Some(dc -> FirstAggregationSpec("longFirst", outputName, dc.name)) 296 | case (_: First, FloatType, dc, outputName) => 297 | Some(dc -> FirstAggregationSpec("floatFirst", outputName, dc.name)) 298 | case (_: First, DoubleType, dc, outputName) => 299 | Some(dc -> FirstAggregationSpec("doubleFirst", outputName, dc.name)) 300 | case (_: Last, LongType, dc, outputName) => 301 | Some(dc -> LastAggregationSpec("longLast", outputName, dc.name)) 302 | case (_: Last, FloatType, dc, outputName) => 303 | Some(dc -> LastAggregationSpec("floatLast", outputName, dc.name)) 304 | case (_: Last, DoubleType, dc, outputName) => 305 | Some(dc -> LastAggregationSpec("doubleLast", outputName, dc.name)) 306 | case _ => None 307 | } 308 | } 309 | } 310 | 311 | private def isHyperUniqueAggregator(dqb: DruidQueryBuilder, dc: DruidRelationColumn): Boolean = { 312 | val aggregators = dqb.druidRelationInfo.druidDataSource.aggregators 313 | if (dc.hasHllMetric) { 314 | aggregators.map { aggrs => 315 | aggrs.find(_._1 == dc.hllMetric.get.name).map { aggr => 316 | DruidDataType.withName(aggr._2.`type`) == DruidDataType.HyperUnique 317 | }.getOrElse(false) 318 | }.getOrElse(true) // Have no aggregators info got from MetadataResponse. 
319 | } else false 320 | } 321 | 322 | private def isThetaSketchAggregator(dqb: DruidQueryBuilder, dc: DruidRelationColumn): Boolean = { 323 | val aggregators = dqb.druidRelationInfo.druidDataSource.aggregators 324 | if (dc.hasSketchMetric) { 325 | aggregators.map { aggrs => 326 | aggrs.find(_._1 == dc.sketchMetric.get.name).map { aggr => 327 | DruidDataType.withName(aggr._2.`type`) == DruidDataType.ThetaSketch 328 | }.getOrElse(false) 329 | }.getOrElse(true) // Have no aggregators info got from MetadataResponse. 330 | } else false 331 | } 332 | 333 | private object ApproximateCountDistinctAggregate { 334 | 335 | def unapply(t: (DruidQueryBuilder, AggregateFunction, String)): Option[AggregationSpec] = { 336 | val dqb = t._1 337 | val aggFunc = t._2 338 | val outputName = t._3 339 | // Druid's aggregators only accept one argument. 340 | val r = for (c <- aggFunc.children.headOption if aggFunc.children.size == 1; 341 | columnName <- attrRefName(c); 342 | dc <- dqb.druidColumn(columnName) 343 | if dc.isDimension(true) || dc.hasHllMetric) yield 344 | (aggFunc, dc, outputName) 345 | // TODO: Sketch supports. 346 | r.flatMap { 347 | case (_: HyperLogLogPlusPlus, dc, outputName) if isHyperUniqueAggregator(dqb, dc) => 348 | Some(new HyperUniqueAggregationSpec(outputName, dc.hllMetric.get.name)) 349 | case (_: HyperLogLogPlusPlus, dc, outputName) if isThetaSketchAggregator(dqb, dc) => 350 | Some(new SketchAggregationSpec(outputName, dc.sketchMetric.get.name)) 351 | case (_: HyperLogLogPlusPlus, dc, outputName) => 352 | Some(new CardinalityAggregationSpec(outputName, List(dc.name))) 353 | case _ => None // not approximate count distinct aggregation 354 | } 355 | } 356 | } 357 | 358 | val aggregateTransform: DruidTransform = { 359 | 360 | case (dqb, Aggregate(_, _, Aggregate(_, _, Expand(_, _, _)))) => 361 | // There are more than 1 distinct aggregate expressions. 362 | // Because Druid cannot handle accurate distinct operation, 363 | // so we do not push aggregation down to Druid. 364 | throw new DruidDataSourceException("Currently the DISTINCT operation is not permitted. " + 365 | "If you submit a COUNT(DISTINCT) aggregation function, " + 366 | "please use APPROX_COUNT_DISTINCT instead.") 367 | case (_, Aggregate(_, _, Aggregate(_, _, _))) => Nil 368 | case (dqb, agg @ Aggregate(grpExprs, aggrExprs, child)) => 369 | // There is 1 distinct aggregate expressions. 370 | // Because Druid cannot handle accurate distinct operation, 371 | // so we do not push aggregation down to Druid. 372 | if (aggrExprs.exists { 373 | case ne: NamedExpression => ne.find { 374 | case ae: AggregateExpression if ae.isDistinct => true 375 | case _ => false 376 | }.isDefined 377 | }) { 378 | throw new DruidDataSourceException("Currently the DISTINCT operation is not permitted. " + 379 | "If you submit a COUNT(DISTINCT) aggregation function, " + 380 | "please use APPROX_COUNT_DISTINCT instead.") 381 | } else { 382 | // There is no distinct aggregate expressions. 383 | // Returns Nil if plan returns Nil. 
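        // Illustrative note (not from the original source): a query such as
        //   SELECT city, SUM(event) FROM ds GROUP BY city
        // reaches this branch. The child plan is first translated into one or more
        // DruidQueryBuilders, and transformAggregation then folds the group-by
        // expressions into dimension specs and the aggregate expressions into
        // Druid aggregation specs.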
384 | plan(dqb, child).flatMap { dqb => 385 | transformAggregation(dqb, agg, grpExprs, aggrExprs) 386 | } 387 | } 388 | case _ => Nil 389 | } 390 | } 391 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/CloseableIterator.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | trait CloseableIterator[+A] extends Iterator[A] { 4 | def closeIfNeeded(): Unit 5 | } 6 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/DruidPlanner.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import org.apache.spark.sql.SQLContext 4 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 5 | import org.rzlabs.druid.DruidQueryBuilder 6 | import org.rzlabs.druid.client.ConnectionManager 7 | import org.rzlabs.druid.metadata.DruidOptions 8 | 9 | class DruidPlanner(val sqlContext: SQLContext, val druidOptions: DruidOptions) extends DruidTransforms 10 | with AggregateTransform with ProjectFilterTransform { 11 | 12 | val transforms: Seq[DruidTransform] = Seq( 13 | aggregateTransform.debug("aggregate"), 14 | druidRelationTransform.debug("druidRelationTransform") 15 | ) 16 | 17 | def plan(dqb: Seq[DruidQueryBuilder], plan: LogicalPlan): Seq[DruidQueryBuilder] = { 18 | transforms.view.flatMap(_(dqb, plan)) 19 | } 20 | } 21 | 22 | object DruidPlanner { 23 | 24 | def apply(sqlContext: SQLContext, druidOptions: DruidOptions) = { 25 | val planner = new DruidPlanner(sqlContext, druidOptions) 26 | ConnectionManager.init(druidOptions) 27 | planner 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/DruidQueryResultIterator.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import java.io.{ByteArrayInputStream, InputStream} 4 | 5 | import org.apache.spark.util.NextIterator 6 | import org.rzlabs.druid.client.{QueryResultRow, ResultRow} 7 | import org.fasterxml.jackson.databind.ObjectMapper._ 8 | import com.fasterxml.jackson.core._ 9 | import com.fasterxml.jackson.core.`type`.TypeReference 10 | import org.apache.commons.io.IOUtils 11 | 12 | private class DruidQueryResultStreamingIterator(useSmile: Boolean, 13 | is: InputStream, 14 | onDone: => Unit = () 15 | ) extends NextIterator[QueryResultRow] 16 | with CloseableIterator[QueryResultRow] { 17 | 18 | // In NextIterator the abstract `closeIfNeeded` 19 | // method declared in CloseableIterator is defined. 
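  // Illustrative sketch of the payload this iterator walks (not from the original
  // source): the response body is a JSON (or Smile-encoded) array of row objects,
  // e.g. [ { ...row 1... }, { ...row 2... } ]. The parser below is advanced token
  // by token so each object can be deserialized into a QueryResultRow without
  // materializing the whole response in memory.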
20 | 21 | private val (mapper, factory) = if (useSmile) { 22 | (smileMapper, smileMapper.getFactory) 23 | } else { 24 | (jsonMapper, jsonMapper.getFactory) 25 | } 26 | 27 | private val parser = factory.createParser(is) 28 | var token = parser.nextToken() // current token is START_ARRAY 29 | token = parser.nextToken() // current token is START_OBJECT or END_ARRAY (empty result set) 30 | 31 | override protected def getNext(): QueryResultRow = { 32 | if (token == JsonToken.END_ARRAY) { 33 | finished = true 34 | null 35 | } else if (token == JsonToken.START_OBJECT) { 36 | val r: QueryResultRow = mapper.readValue(parser, new TypeReference[QueryResultRow] {}) 37 | token = parser.nextToken() 38 | r 39 | } else null 40 | } 41 | 42 | override protected def close(): Unit = { 43 | parser.close() 44 | onDone 45 | } 46 | } 47 | 48 | private class DruidQueryResultStaticIterator(useSmile: Boolean, 49 | is: InputStream, 50 | onDone: => Unit = () 51 | ) extends NextIterator[QueryResultRow] 52 | with CloseableIterator[QueryResultRow] { 53 | 54 | val rowList: List[QueryResultRow] = if (useSmile) { 55 | val bais = new ByteArrayInputStream(IOUtils.toByteArray(is)) 56 | smileMapper.readValue(bais, new TypeReference[List[QueryResultRow]] {}) 57 | } else { 58 | jsonMapper.readValue(is, new TypeReference[List[QueryResultRow]] {}) 59 | } 60 | 61 | onDone 62 | 63 | val iter = rowList.toIterator 64 | 65 | override protected def getNext(): QueryResultRow = { 66 | if (iter.hasNext) { 67 | iter.next() 68 | } else { 69 | finished = true 70 | null 71 | } 72 | } 73 | 74 | override protected def close(): Unit = () // This because the onDone is called in constructor. 75 | } 76 | 77 | object DruidQueryResultIterator { 78 | 79 | def apply(useSmile: Boolean, 80 | is: InputStream, 81 | onDone: => Unit = (), 82 | fromList: Boolean = false): CloseableIterator[QueryResultRow] = { 83 | if(fromList) { 84 | new DruidQueryResultStaticIterator(useSmile, is, onDone) 85 | } else { 86 | new DruidQueryResultStreamingIterator(useSmile, is, onDone) 87 | } 88 | } 89 | } 90 | 91 | class DummyResultIterator extends NextIterator[ResultRow] with CloseableIterator[ResultRow] { 92 | override protected def getNext(): ResultRow = { 93 | finished = true 94 | null 95 | } 96 | 97 | override protected def close(): Unit = () 98 | } 99 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/DruidScanResultIterator.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import java.io.{ByteArrayInputStream, InputStream} 4 | 5 | import org.apache.spark.util.NextIterator 6 | import org.rzlabs.druid.client.{QueryResultRow, ResultRow, ScanResultRow} 7 | import org.fasterxml.jackson.databind.ObjectMapper._ 8 | import com.fasterxml.jackson.core._ 9 | import com.fasterxml.jackson.core.`type`.TypeReference 10 | import org.apache.commons.io.IOUtils 11 | 12 | private class DruidScanResultStreamingIterator(useSmile: Boolean, 13 | is: InputStream, 14 | onDone: => Unit = () 15 | ) extends NextIterator[ScanResultRow] 16 | with CloseableIterator[ScanResultRow] { 17 | 18 | // In NextIterator the abstract `closeIfNeeded` 19 | // method declared in CloseableIterator is defined. 
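  // Illustrative sketch of the scan-query payload this iterator walks (not from the
  // original source): an array of per-segment objects of roughly the form
  //   [ { "segmentId": "...", "columns": [ ... ], "events": [ { ... }, ... ] }, ... ]
  // Only the objects inside "events" become ScanResultRow instances; the surrounding
  // segmentId/columns tokens are skipped one by one below.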
20 | 21 | private val (mapper, factory) = if (useSmile) { 22 | (smileMapper, smileMapper.getFactory) 23 | } else { 24 | (jsonMapper, jsonMapper.getFactory) 25 | } 26 | private val parser = factory.createParser(is) 27 | private var token = parser.nextToken() // START_ARRAY 28 | token = parser.nextToken() // START_OBJECT or END_ARRAY (empty result set) 29 | 30 | var lastToken: JsonToken = JsonToken.START_ARRAY 31 | 32 | override protected def getNext(): ScanResultRow = { 33 | 34 | if ((lastToken == JsonToken.START_ARRAY || 35 | lastToken == JsonToken.END_OBJECT) && token != JsonToken.END_ARRAY) { 36 | token = parser.nextToken() // FIELD_NAME -- segmentId 37 | token = parser.nextToken() // VALUE_STRING -- segment name 38 | token = parser.nextToken() // FIELD_NAME -- columns 39 | token = parser.nextToken() // START_ARRAY -- column array 40 | val columns: List[String] = mapper.readValue(parser, classOf[List[String]]) 41 | token = parser.nextToken() // FIELD_NAME -- events 42 | token = parser.nextToken() // START_ARRAY -- event array 43 | token = parser.nextToken() // START_OBJECT -- event object 44 | } 45 | 46 | if (token == JsonToken.END_ARRAY && lastToken == JsonToken.END_OBJECT) { 47 | finished = true 48 | null 49 | } else if (token == JsonToken.START_OBJECT) { 50 | val r: ScanResultRow = ScanResultRow( 51 | mapper.readValue(parser, classOf[Map[String, Any]])) 52 | token = parser.nextToken() // START_OBJECT or END_ARRAY 53 | if (token == JsonToken.END_ARRAY) { 54 | lastToken = parser.nextToken() // END_OBJECT 55 | token = parser.nextToken() // END_ARRAY or START_OBJECT 56 | } else { 57 | lastToken = JsonToken.START_OBJECT 58 | } 59 | r 60 | } else null 61 | } 62 | 63 | override protected def close(): Unit = { 64 | parser.close() 65 | onDone 66 | } 67 | } 68 | 69 | private class DruidScanResultStaticIterator(useSmile: Boolean, 70 | is: InputStream, 71 | onDone: => Unit = () 72 | ) extends NextIterator[ScanResultRow] 73 | with CloseableIterator[ScanResultRow] { 74 | 75 | private var rowList: List[ScanResultRow] = List() 76 | 77 | private val (mapper, factory) = if (useSmile) { 78 | (smileMapper, smileMapper.getFactory) 79 | } else { 80 | (jsonMapper, jsonMapper.getFactory) 81 | } 82 | private val parser = factory.createParser(is) 83 | 84 | private var token = parser.nextToken() // START_ARRAY 85 | token = parser.nextToken() // START_OBJECT or END_ARRAY 86 | while (token == JsonToken.START_OBJECT) { 87 | token = parser.nextToken() // FIELD_NAME -- segmentId 88 | token = parser.nextToken() // VALUE_STRING -- segment name 89 | token = parser.nextToken() // FIELD_NAME -- columns 90 | token = parser.nextToken() // START_ARRAY -- column array 91 | val columns: List[String] = mapper.readValue(parser, classOf[List[String]]) 92 | token = parser.nextToken() // FIELD_NAME -- events 93 | token = parser.nextToken() // START_ARRAY -- event array 94 | val events: List[Map[String, Any]] = mapper.readValue(parser, classOf[List[Map[String, Any]]]) 95 | rowList = rowList ++ events.map(ScanResultRow(_)) 96 | token = parser.nextToken() // END_OBJECT 97 | token = parser.nextToken() // START_OBJECT or END_ARRAY 98 | } 99 | 100 | // After loop, here will be END_ARRAY token 101 | assert(parser.nextToken() == JsonToken.END_ARRAY) 102 | 103 | onDone 104 | parser.close() 105 | 106 | val iter = rowList.toIterator 107 | 108 | override protected def getNext(): ScanResultRow = { 109 | if (iter.hasNext) { 110 | iter.next() 111 | } else { 112 | finished = true 113 | null 114 | } 115 | } 116 | 117 | override protected def 
close(): Unit = () // This because the onDone is called in constructor. 118 | } 119 | 120 | object DruidScanResultIterator { 121 | 122 | def apply(useSmile: Boolean, 123 | is: InputStream, 124 | onDone: => Unit = (), 125 | fromList: Boolean = false): CloseableIterator[ScanResultRow] = { 126 | if(fromList) { 127 | new DruidScanResultStaticIterator(useSmile, is, onDone) 128 | } else { 129 | new DruidScanResultStreamingIterator(useSmile, is, onDone) 130 | } 131 | } 132 | } 133 | 134 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/DruidSchema.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import org.apache.spark.sql.catalyst.expressions._ 4 | import org.apache.spark.sql.types.DataType 5 | import org.rzlabs.druid.{DruidAttribute, DruidQueryBuilder} 6 | 7 | class DruidSchema(val dqb: DruidQueryBuilder) { 8 | 9 | def avgExpressions: Map[Expression, (String, String)] = dqb.avgExpressions 10 | 11 | lazy val druidAttributes: List[DruidAttribute] = druidAttrMap.values.toList 12 | 13 | lazy val druidAttrMap: Map[String, DruidAttribute] = buildDruidAttr 14 | 15 | lazy val schema: List[Attribute] = druidAttributes.map { 16 | case DruidAttribute(exprId, name, druidDT, tf) => 17 | AttributeReference(name, druidDT)(exprId) 18 | } 19 | 20 | lazy val pushedDownExprToDruidAttr: Map[Expression, DruidAttribute] = 21 | buildPushDownDruidAttrsMap 22 | 23 | private def buildPushDownDruidAttrsMap: Map[Expression, DruidAttribute] = { 24 | dqb.outputAttributeMap.map { 25 | case (name, (expr, _, _, _)) => expr -> druidAttrMap(name) 26 | } 27 | } 28 | 29 | private def buildDruidAttr: Map[String, DruidAttribute] = { 30 | 31 | dqb.outputAttributeMap.map { 32 | case (name, (expr, _, druidDT, tf)) => { 33 | val druidExprId = expr match { 34 | case null => NamedExpression.newExprId 35 | case ne: NamedExpression => ne.exprId 36 | case _ => NamedExpression.newExprId 37 | } 38 | (name -> DruidAttribute(druidExprId, name, druidDT, tf)) 39 | } 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/DruidStrategy.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} 4 | import org.apache.spark.sql.execution._ 5 | import org.apache.spark.sql._ 6 | import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Divide, Expression, NamedExpression} 7 | import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning 8 | import org.apache.spark.sql.sources.Filter 9 | import org.apache.spark.sql.types.{DoubleType, StringType} 10 | import org.apache.spark.sql.util.ExprUtil 11 | import org.rzlabs.druid._ 12 | import org.rzlabs.druid.metadata.DruidRelationColumn 13 | 14 | import scala.collection.mutable.{Map => MMap} 15 | 16 | private[sql] class DruidStrategy(planner: DruidPlanner) extends Strategy 17 | with MyLogging { 18 | 19 | override def apply(lp: LogicalPlan): Seq[SparkPlan] = { 20 | 21 | val plan: Seq[SparkPlan] = for (dqb <- planner.plan(null, lp)) yield { 22 | if (dqb.aggregateOper.isDefined) { 23 | aggregatePlan(dqb) 24 | } else { 25 | scanPlan(dqb, lp) 26 | } 27 | } 28 | 29 | plan.filter(_ != null).toList 30 | if (plan.size < 2) plan else 
Seq(UnionExec(plan)) 31 | } 32 | 33 | private def scanPlan(dqb: DruidQueryBuilder, lp: LogicalPlan): SparkPlan = { 34 | lp match { 35 | // Just in the case that Project operator as current logical plan 36 | // the query spec will be generated 37 | case Project(projectList, _) => scanPlan(dqb, projectList) 38 | case _ => null 39 | } 40 | } 41 | 42 | private def scanPlan(dqb: DruidQueryBuilder, 43 | projectList: Seq[NamedExpression]): SparkPlan = { 44 | 45 | // Replace __time with timestamp because Druid returns 'tiemstamp' field 46 | // to represent the timestamp string if specifying 'lagecy' in scanQuerySpec . 47 | val referredDruidColumn = dqb.referencedDruidColumns.mapValues { dc => 48 | if (dc.isTimeDimension) { 49 | // The dataType of druidColumn of type DruidColumn in 50 | // DruidRelationColumn has already been set as StringType. 51 | dc.copy(druidColumn = 52 | dc.druidColumn.map(d => 53 | d.asInstanceOf[DruidTimeDimension].copy(name = DruidDataSource.TIMESTAMP_KEY_NAME))) 54 | } else dc 55 | } 56 | 57 | var dqb1 = dqb.copy(referencedDruidColumns = MMap(referredDruidColumn.toSeq: _*)) 58 | 59 | def addOutputAttributes(exprs: Seq[Expression]) = { 60 | for (na <- exprs; 61 | attr <- na.references; 62 | dc <- dqb.druidColumn(attr.name)) { 63 | dqb1 = if (dc.isTimeDimension) { 64 | dqb1.outputAttribute(DruidDataSource.TIMESTAMP_KEY_NAME, attr, attr.dataType, 65 | DruidDataType.sparkDataType(dc.dataType), null) 66 | } else { 67 | dqb1.outputAttribute(attr.name, attr, attr.dataType, 68 | DruidDataType.sparkDataType(dc.dataType), null) 69 | } 70 | } 71 | } 72 | 73 | // Set outputAttrs with projectList 74 | addOutputAttributes(projectList) 75 | // Set outputAttrs with filters 76 | addOutputAttributes(dqb1.origFilter.toSeq) 77 | 78 | val druidSchema = new DruidSchema(dqb1) 79 | 80 | val intervals = dqb1.queryIntervals.get 81 | 82 | // Remove the time dimension in select list because we will 83 | // use 'timestamp' field instead. 84 | var columns = dqb1.referencedDruidColumns.values.collect { 85 | case col if !col.isTimeDimension => col.name 86 | }.toList 87 | 88 | // The empty column list indicates that the time field is removed, 89 | // In order to prevent scan query returning all fields we just add 90 | // a column. 91 | if (columns.isEmpty) { 92 | columns = columns :+ dqb1.druidRelationInfo.druidColumns.head._1 93 | } 94 | 95 | var qrySpec: QuerySpec = 96 | new ScanQuerySpec(dqb1.druidRelationInfo.druidDataSource.name, 97 | columns, 98 | dqb1.filterSpec, 99 | intervals.map(_.toString), 100 | None, 101 | None, 102 | true, 103 | Some(QuerySpecContext(s"query-${System.nanoTime()}")) 104 | ) 105 | 106 | val queryHistorical = false 107 | val numSegsPerQuery = -1 108 | 109 | val druidQuery = DruidQuery(qrySpec, 110 | dqb1.druidRelationInfo.options.useSmile, 111 | false, -1, 112 | intervals, 113 | Some(druidSchema.druidAttributes)) 114 | 115 | def postDruidStep(plan: SparkPlan): SparkPlan = { 116 | // TODO: always return arg currently 117 | plan 118 | } 119 | 120 | def buildProjectList(dqb: DruidQueryBuilder, druidSchema: DruidSchema): Seq[NamedExpression] = { 121 | buildProjectionList(projectList, druidSchema) 122 | } 123 | 124 | buildPlan(dqb1, druidSchema, druidQuery, planner, postDruidStep _, buildProjectList _) 125 | } 126 | 127 | // private def selectPlan(dqb: DruidQueryBuilder, lp: LogicalPlan): SparkPlan = { 128 | // lp match { 129 | // // Just in the case that Project operator as current logical plan 130 | // // the query spec will be generated. 
131 | // case Project(projectList, _) => selectPlan(dqb, projectList) 132 | // case _ => 133 | // } 134 | // } 135 | 136 | // private def selectPlan(dqb: DruidQueryBuilder, 137 | // projectList: Seq[NamedExpression]): SparkPlan = { 138 | // 139 | // 140 | // // Replace __time with timestamp because Druid returns 'tiemstamp' field 141 | // // by default to represent the timestamp string. 142 | // val referredDruidColumn = dqb.referencedDruidColumns.mapValues { dc => 143 | // if (dc.isTimeDimension) { 144 | // // The dataType of druidColumn of type DruidColumn in 145 | // // DruidRelationColumn has already been set as StringType. 146 | // dc.copy(druidColumn = 147 | // dc.druidColumn.map(d => 148 | // d.asInstanceOf[DruidTimeDimension].copy(name = DruidDataSource.TIMESTAMP_KEY_NAME))) 149 | // } else dc 150 | // } 151 | // 152 | // var dqb1 = dqb.copy(referencedDruidColumns = MMap(referredDruidColumn.toSeq: _*)) 153 | // 154 | // // Set outputAttrs with projectList 155 | // for (na <- projectList; 156 | // attr <- na.references; 157 | // dc <- dqb.druidColumn(attr.name)) { 158 | // dqb1 = dqb1.outputAttribute(attr.name, attr, attr.dataType, 159 | // DruidDataType.sparkDataType(dc.dataType), null) 160 | // } 161 | // 162 | // // Set outputAttrs with filters 163 | // for (e <- dqb.origFilter; 164 | // attr <- e.references; 165 | // dc <- dqb.druidColumn(attr.name)) { 166 | // dqb1 = dqb1.outputAttribute(attr.name, attr, attr.dataType, 167 | // DruidDataType.sparkDataType(dc.dataType), null) 168 | // } 169 | // 170 | // val druidSchema = new DruidSchema(dqb1) 171 | // 172 | // var (dims, metrics) = dqb1.referencedDruidColumns.values.partition(_.isDimension()) 173 | // 174 | // // Remove dimension with name 'timestamp' because Druid will return this field. 175 | // dims = dims.filterNot(_.druidColumn.filter(_ == DruidDataSource.TIMESTAMP_KEY_NAME).nonEmpty) 176 | // 177 | // /* 178 | // * If dimensions or metrics are empty, arbitrarily pick 1 dimension and metric. 179 | // * Otherwise Druid will return all dimensions/metrics. 180 | // */ 181 | // if (dims.isEmpty) { 182 | // dims = dqb1.druidRelationInfo.druidColumns.find(_._2.isDimension(true)).map(_._2) 183 | // } 184 | // if (metrics.isEmpty) { 185 | // metrics = dqb1.druidRelationInfo.druidColumns.find(_._2.isMetric).map(_._2) 186 | // } 187 | // 188 | // val intervals = dqb1.queryIntervals.get 189 | // 190 | // null 191 | // } 192 | 193 | private def aggregatePlan(dqb: DruidQueryBuilder): SparkPlan = { 194 | val druidSchema = new DruidSchema(dqb) 195 | val postAgg = new PostAggregate(druidSchema) 196 | 197 | val queryIntervals = dqb.queryIntervals.get 198 | 199 | val qrySpec: QuerySpec = new GroupByQuerySpec( 200 | dqb.druidRelationInfo.druidDataSource.name, 201 | dqb.dimensions, 202 | dqb.limitSpec, 203 | dqb.havingSpec, 204 | dqb.granularitySpec, 205 | dqb.filterSpec, 206 | dqb.aggregations, 207 | dqb.postAggregations, 208 | queryIntervals.map(_.toString), 209 | Some(QuerySpecContext(s"query-${System.nanoTime()}"))) 210 | 211 | // TODO: any necessary transformations. 212 | 213 | qrySpec.context.foreach { ctx => 214 | if (dqb.druidRelationInfo.options.useV2GroupByEngine) { 215 | ctx.groupByStrategy = Some("v2") 216 | } 217 | } 218 | 219 | // TODO: handle the cost of post aggregation at historical. 
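    // Note (interpretation, not from the original source): queryHistorical is kept
    // false and numSegsPerQuery is -1, mirroring the defaults used by scanPlan above,
    // so the groupBy query presumably goes to the Druid broker rather than being
    // split across historical servers.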
220 | val queryHistorical = false 221 | val numSegsPerQuery = -1 222 | 223 | val druidQuery = DruidQuery(qrySpec, 224 | dqb.druidRelationInfo.options.useSmile, 225 | queryHistorical, 226 | numSegsPerQuery, 227 | queryIntervals, 228 | Some(druidSchema.druidAttributes)) 229 | 230 | def postDruidStep(plan: SparkPlan): SparkPlan = { 231 | // TODO: always return arg currently 232 | plan 233 | } 234 | 235 | def buildProjectList(dqb: DruidQueryBuilder, druidSchema: DruidSchema): Seq[NamedExpression] = { 236 | buildProjectionList(dqb.aggregateOper.get.aggregateExpressions, druidSchema) 237 | } 238 | 239 | buildPlan(dqb, druidSchema, druidQuery, planner, postDruidStep _, buildProjectList _) 240 | 241 | } 242 | 243 | private def buildPlan(dqb: DruidQueryBuilder, 244 | druidSchema: DruidSchema, 245 | druidQuery: DruidQuery, 246 | planner: DruidPlanner, 247 | postDruidStep: SparkPlan => SparkPlan, 248 | buildProjectList: (DruidQueryBuilder, DruidSchema) => Seq[NamedExpression] 249 | ): SparkPlan = { 250 | 251 | val druidRelation = DruidRelation(dqb.druidRelationInfo, Some(druidQuery))(planner.sqlContext) 252 | 253 | val fullAttributes = druidSchema.schema 254 | val requiredColumnIndex = (0 until fullAttributes.size).toSeq 255 | 256 | val druidSparkPlan = postDruidStep( 257 | //DataSourceScanExec.create(druidSchema.schema, 258 | //druidRelation.buildInternalScan, druidRelation) 259 | //RowDataSourceScanExec(druidSchema.schema, 260 | // druidRelation.buildInternalScan, 261 | // druidRelation, 262 | // UnknownPartitioning(0), 263 | // Map(), 264 | // None) 265 | RowDataSourceScanExec(fullAttributes, 266 | requiredColumnIndex, 267 | Set(), 268 | Set(), 269 | druidRelation.buildInternalScan, 270 | druidRelation, 271 | None) 272 | ) 273 | 274 | if (druidSparkPlan != null) { 275 | val projections = buildProjectList(dqb, druidSchema) 276 | ProjectExec(projections, druidSparkPlan) 277 | } else null 278 | } 279 | 280 | private def buildProjectionList(origExpressions: Seq[NamedExpression], 281 | druidSchema: DruidSchema): Seq[NamedExpression] = { 282 | val druidPushDownExprMap: Map[Expression, DruidAttribute] = 283 | druidSchema.pushedDownExprToDruidAttr 284 | val avgExpressions = druidSchema.avgExpressions 285 | 286 | origExpressions.map { ne => ExprUtil.transformReplace(ne, { 287 | case e: Expression if avgExpressions.contains(e) => 288 | val (sumAlias, cntAlias) = avgExpressions(e) 289 | val sumAttr: DruidAttribute = druidSchema.druidAttrMap(sumAlias) 290 | val cntAttr: DruidAttribute = druidSchema.druidAttrMap(cntAlias) 291 | Cast(Divide( 292 | Cast(AttributeReference(sumAttr.name, sumAttr.dataType)(sumAttr.exprId), DoubleType), 293 | Cast(AttributeReference(cntAttr.name, cntAttr.dataType)(cntAttr.exprId), DoubleType) 294 | ), e.dataType) 295 | case ae: AttributeReference if druidPushDownExprMap.contains(ae) && 296 | druidPushDownExprMap(ae).dataType != ae.dataType => 297 | val da = druidPushDownExprMap(ae) 298 | Alias(Cast(AttributeReference(da.name, da.dataType)(da.exprId), ae.dataType), da.name)(da.exprId) 299 | case ae: AttributeReference if druidPushDownExprMap.contains(ae) && 300 | druidPushDownExprMap(ae).name != ae.name => 301 | val da = druidPushDownExprMap(ae) 302 | Alias(AttributeReference(da.name, da.dataType)(da.exprId), da.name)(da.exprId) 303 | case ae: AttributeReference if druidPushDownExprMap.contains(ae) => ae 304 | case e: Expression if druidPushDownExprMap.contains(e) && 305 | druidPushDownExprMap(e).dataType != e.dataType => 306 | val da = druidPushDownExprMap(e) 307 | 
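        // The pushed-down expression comes back from Druid under the alias `da.name`;
        // rebuild it as an attribute reference and cast it back to the data type the
        // original Spark expression had.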
Cast(AttributeReference(da.name, da.dataType)(da.exprId), e.dataType) 308 | case e: Expression if druidPushDownExprMap.contains(e) => 309 | val da = druidPushDownExprMap(e) 310 | AttributeReference(da.name, da.dataType)(da.exprId) 311 | }).asInstanceOf[NamedExpression]} 312 | } 313 | 314 | } 315 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/DruidTransforms.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import org.apache.spark.sql.MyLogging 4 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 5 | import org.rzlabs.druid.DruidQueryBuilder 6 | 7 | class DruidTransforms extends MyLogging { 8 | self: DruidPlanner => // DruidTransforms can only be inherited by DruidPlanner 9 | 10 | type DruidTransform = Function[(Seq[DruidQueryBuilder], LogicalPlan), Seq[DruidQueryBuilder]] 11 | 12 | case class ORTransform(t1: DruidTransform, t2: DruidTransform) extends DruidTransform { 13 | 14 | def apply(p: (Seq[DruidQueryBuilder], LogicalPlan)): Seq[DruidQueryBuilder] = { 15 | 16 | val r = t1(p._1, p._2) 17 | if (r.size > 0) { 18 | r 19 | } else { 20 | t2(p._1, p._2) 21 | } 22 | } 23 | } 24 | 25 | case class DebugTransform(transformName: String, t: DruidTransform) extends DruidTransform { 26 | 27 | def apply(p: (Seq[DruidQueryBuilder], LogicalPlan)): Seq[DruidQueryBuilder] = { 28 | 29 | val dqb = p._1 30 | val lp = p._2 31 | val rdqb = t((dqb, lp)) 32 | if (self.druidOptions.debugTransformations) { 33 | logInfo(s"$transformName transform invoked:\n" + 34 | s"Input DruidQueryBuilders: $dqb\n" + 35 | s"Input LogicalPlan: $lp\n" + 36 | s"Output DruidQueryBuilders: $rdqb\n") 37 | } 38 | rdqb 39 | } 40 | } 41 | 42 | case class TransfomHolder(t: DruidTransform) { 43 | 44 | def or(t2: DruidTransform): DruidTransform = ORTransform(t, t2) 45 | 46 | def debug(name: String): DruidTransform = DebugTransform(name, t) 47 | } 48 | 49 | /** 50 | * Convert an object's type from DruidTransform to TransformHolder implicitly. 51 | * So we can call "transform1.or(tramsfomr2)" or "transform1.debug("transformName")". 52 | * @param t The input DruidTransform object. 53 | * @return The converted TransformHoder object. 
54 | */ 55 | implicit def transformToHolder(t: DruidTransform) = TransfomHolder(t) 56 | 57 | def debugTransform(msg: => String): Unit = { 58 | if (self.druidOptions.debugTransformations) { 59 | logInfo(msg) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/PostAggregate.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | class PostAggregate(druidSchema: DruidSchema) { 4 | 5 | val dqb = druidSchema.dqb 6 | 7 | } 8 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sources/druid/ProjectFilterTransform.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.sources.druid 2 | 3 | import org.apache.spark.sql.catalyst.expressions._ 4 | import org.apache.spark.sql.catalyst.planning.PhysicalOperation 5 | import org.apache.spark.sql.execution.datasources.LogicalRelation 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.sql.util.ExprUtil 8 | import org.rzlabs.druid._ 9 | import org.rzlabs.druid.jscodegen.JSCodeGenerator 10 | import org.rzlabs.druid.metadata.DruidRelationColumn 11 | 12 | trait ProjectFilterTransform { 13 | self: DruidPlanner => 14 | 15 | def addUnpushedAttributes(dqb: DruidQueryBuilder, e: Expression, 16 | isProjection: Boolean): Option[DruidQueryBuilder] = { 17 | if (isProjection) { 18 | Some(dqb.copy(hasUnpushedProjections = true)) 19 | } else { 20 | Some(dqb.copy(hasUnpushedFilters = true)) 21 | } 22 | } 23 | 24 | def projectExpression(dqb: DruidQueryBuilder, projectExpr: Expression, 25 | joinAttrs: Set[String] = Set(), ignoreProjectList: Boolean = false 26 | ): Option[DruidQueryBuilder] = projectExpr match { 27 | 28 | case _ if ignoreProjectList => Some(dqb) 29 | case AttributeReference(nm, _, _, _) if dqb.druidColumn(nm).isDefined => Some(dqb) 30 | case AttributeReference(nm, _, _, _) if joinAttrs.contains(nm) => Some(dqb) 31 | case Alias(ar @ AttributeReference(nm1, _, _, _), nm) => { 32 | for (dqbc <- projectExpression(dqb, ar, joinAttrs, ignoreProjectList)) yield 33 | dqbc.addAlias(nm, nm1) 34 | } 35 | case _ => addUnpushedAttributes(dqb, projectExpr, true) 36 | } 37 | 38 | def translateProjectFilter(dqb: Option[DruidQueryBuilder], projectList: Seq[NamedExpression], 39 | filters: Seq[Expression], ignoreProjectList: Boolean = false, 40 | joinAttrs: Set[String] = Set()): Seq[DruidQueryBuilder] = { 41 | val dqb1 = if (ignoreProjectList) dqb else { 42 | projectList.foldLeft(dqb) { 43 | (ldqb, e) => ldqb.flatMap(projectExpression(_, e, joinAttrs, false)) 44 | } 45 | } 46 | 47 | if (dqb1.isDefined) { // dqb will never be None 48 | // A predicate on the time dimension will be rewrites to Interval constraint. 49 | val ice = new SparkIntervalConditionExtractor(dqb1.get) 50 | // For each filter generates a new DruidQueryBuilder. 51 | var odqb = filters.foldLeft(dqb1) { (lodqb, filter) => 52 | lodqb.flatMap { ldqb => 53 | intervalFilterExpression(ldqb, ice, filter).orElse { 54 | dimensionFilterExpression(ldqb, filter).map { spec => 55 | ldqb.filterSpecification(spec) 56 | } 57 | } 58 | } 59 | } 60 | odqb = odqb.map { dqb2 => 61 | dqb2.copy(origProjectList = dqb2.origProjectList.map(_ ++ projectList).orElse(Some(projectList))). 
62 | copy(origFilter = dqb2.origFilter.flatMap(f => 63 | ExprUtil.and(filters :+ f)).orElse(ExprUtil.and(filters))) 64 | } 65 | odqb.map(Seq(_)).getOrElse(Seq()) 66 | } else Seq() 67 | } 68 | 69 | def intervalFilterExpression(dqb: DruidQueryBuilder, ice: SparkIntervalConditionExtractor, 70 | filter: Expression): Option[DruidQueryBuilder] = filter match { 71 | case ice(ic) => dqb.queryInterval(ic) 72 | case _ => None 73 | } 74 | 75 | def dimensionFilterExpression(dqb: DruidQueryBuilder, filter: Expression): Option[FilterSpec] = { 76 | 77 | val timeExtractor = new SparkNativeTimeElementExtractor()(dqb) 78 | 79 | (dqb, filter) match { 80 | case ValidDruidNativeComparison(filterSpec) => Some(filterSpec) 81 | case (dqb, filter) => filter match { 82 | case Or(e1, e2) => 83 | Utils.sequence( 84 | List(dimensionFilterExpression(dqb, e1), 85 | dimensionFilterExpression(dqb, e2))).map { specs => 86 | LogicalExpressionFilterSpec("or", specs) 87 | } 88 | case And(e1, e2) => 89 | Utils.sequence( 90 | List(dimensionFilterExpression(dqb, e1), 91 | dimensionFilterExpression(dqb, e2))).map { specs => 92 | LogicalExpressionFilterSpec("and", specs) 93 | } 94 | case In(AttributeReference(nm, _, _, _), vl: Seq[Expression]) => 95 | for (dc <- dqb.druidColumn(nm) if dc.isDimension() && 96 | vl.forall(_.isInstanceOf[Literal])) yield 97 | new InFilterSpec(dc.name, vl.map(_.asInstanceOf[Literal].value.toString).toList) 98 | case InSet(AttributeReference(nm, _, _, _), vl: Set[Any]) => 99 | for (dc <- dqb.druidColumn(nm) if dc.isDimension()) yield 100 | new InFilterSpec(dc.name, vl.map(_.toString).toList) 101 | case IsNotNull(AttributeReference(nm, _, _, _)) => 102 | for (dc <- dqb.druidColumn(nm) if dc.isDimension()) yield 103 | new NotFilterSpec(new SelectorFilterSpec(dc.name, "")) 104 | case IsNull(AttributeReference(nm, _, _, _)) => 105 | for (dc <- dqb.druidColumn(nm) if dc.isDimension()) yield 106 | new SelectorFilterSpec(nm, "") 107 | case Not(e) => 108 | for (spec <- dimensionFilterExpression(dqb, e)) yield 109 | new NotFilterSpec(spec) 110 | // TODO: What is NULL SCAN ??? 111 | case Literal(null, _) => Some(new SelectorFilterSpec("__time", "")) 112 | case _ => { 113 | val jscodegen = JSCodeGenerator(dqb, filter, false, false, 114 | dqb.druidRelationInfo.options.timeZoneId, BooleanType) 115 | for (fn <- jscodegen.fnCode) yield { 116 | new JavascriptFilterSpec(jscodegen.fnParams.last, fn) 117 | } 118 | } 119 | } 120 | } 121 | } 122 | 123 | private def boundFilter(dqb: DruidQueryBuilder, e: Expression, dc: DruidRelationColumn, 124 | value: Any, sparkDt: DataType, op: String): FilterSpec = { 125 | 126 | val druidDs = dqb.druidRelationInfo.druidDataSource 127 | val ordering = sparkDt match { 128 | case ShortType | IntegerType | LongType | 129 | FloatType | DoubleType | DecimalType() => "numeric" 130 | case _ => "lexicographic" 131 | } 132 | if (druidDs.supportsBoundFilter) { // druid 0.9.0+ support bound filter. 
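      // Illustrative example (not from the original source): for a predicate like
      // `m < 10` on a numeric column, the spec built below corresponds to a Druid
      // bound filter of roughly this shape:
      //   { "type": "bound", "dimension": "m", "upper": "10",
      //     "upperStrict": true, "ordering": "numeric" }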
133 | op match { 134 | case "<" => new BoundFilterSpec(dc.name, null, value.toString, false, true, ordering) 135 | case "<=" => new BoundFilterSpec(dc.name, null, value.toString, false, false, ordering) 136 | case ">" => new BoundFilterSpec(dc.name, value.toString, null, true, false, ordering) 137 | case ">=" => new BoundFilterSpec(dc.name, value.toString, null, false, false, ordering) 138 | case _ => null 139 | } 140 | } else { 141 | val jscodegen = new JSCodeGenerator(dqb, e, false, false, 142 | dqb.druidRelationInfo.options.timeZoneId) 143 | val v = if (ordering == "numeric") value.toString else s""""${value.toString}"""" 144 | val ospec: Option[FilterSpec] = for (fn <- jscodegen.fnElements; 145 | body <- Some(fn._1); returnVal <- Some(fn._2)) yield { 146 | val jsFn = 147 | s"""function(${dc.name}) { 148 | | ${body}; 149 | | if (($returnVal) $op $v) { 150 | | return true; 151 | | } else { 152 | | return false; 153 | | } 154 | |}""".stripMargin 155 | new JavascriptFilterSpec(dc.name, jsFn) 156 | } 157 | ospec.getOrElse(null) 158 | } 159 | } 160 | 161 | object ValidDruidNativeComparison { 162 | 163 | def unapply(t: (DruidQueryBuilder, Expression)): Option[FilterSpec] = { 164 | import SparkNativeTimeElementExtractor._ 165 | val dqb = t._1 166 | val filter = t._2 167 | val timeExtractor = new SparkNativeTimeElementExtractor()(dqb) 168 | filter match { 169 | case EqualTo(AttributeReference(nm, dt, _, _), Literal(value, _)) => 170 | for (dc <- dqb.druidColumn(nm) 171 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 172 | yield new SelectorFilterSpec(dc.name, value.toString) 173 | case EqualTo(Literal(value, _), AttributeReference(nm, dt, _, _)) => 174 | for (dc <- dqb.druidColumn(nm) 175 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 176 | yield new SelectorFilterSpec(dc.name, value.toString) 177 | case EqualTo(AttributeReference(nm1, _, _, _), AttributeReference(nm2, _, _, _)) => 178 | for (dc1 <- dqb.druidColumn(nm1) if dc1.isDimension(); 179 | dc2 <- dqb.druidColumn(nm2) if dc2.isDimension()) yield 180 | new ColumnComparisonFilterSpec(List(dc1.name, dc2.name)) 181 | case LessThan(ar @ AttributeReference(nm, dt, _, _), Literal(value, _)) => 182 | for (dc <- dqb.druidColumn(nm) 183 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 184 | yield boundFilter(dqb, ar, dc, value, dt, "<") 185 | case LessThan(Literal(value, _), ar @ AttributeReference(nm, dt, _, _)) => 186 | for (dc <- dqb.druidColumn(nm) 187 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 188 | yield boundFilter(dqb, ar, dc, value, dt, ">") 189 | case LessThanOrEqual(ar @ AttributeReference(nm, dt, _, _), Literal(value, _)) => 190 | for (dc <- dqb.druidColumn(nm) 191 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 192 | yield boundFilter(dqb, ar, dc, value, dt, "<=") 193 | case LessThanOrEqual(Literal(value, _), ar @ AttributeReference(nm, dt, _, _)) => 194 | for (dc <- dqb.druidColumn(nm) 195 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 196 | yield boundFilter(dqb, ar, dc, value, dt, ">=") 197 | case GreaterThan(ar @ AttributeReference(nm, dt, _, _), Literal(value, _)) => 198 | for (dc <- dqb.druidColumn(nm) 199 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 200 | yield boundFilter(dqb, ar, dc, value, dt, ">") 201 | case GreaterThan(Literal(value, _), ar @ AttributeReference(nm, dt, _, _)) => 202 | for (dc <- dqb.druidColumn(nm) 203 | if dc.isDimension() 
&& DruidDataType.sparkDataType(dc.dataType) == dt) 204 | yield boundFilter(dqb, ar, dc, value, dt, "<") 205 | case GreaterThanOrEqual(ar @ AttributeReference(nm, dt, _, _), Literal(value, _)) => 206 | for (dc <- dqb.druidColumn(nm) 207 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 208 | yield boundFilter(dqb, ar, dc, value, dt, ">=") 209 | case GreaterThanOrEqual(Literal(value, _), ar @ AttributeReference(nm, dt, _, _)) => 210 | for (dc <- dqb.druidColumn(nm) 211 | if dc.isDimension() && DruidDataType.sparkDataType(dc.dataType) == dt) 212 | yield boundFilter(dqb, ar, dc, value, dt, "<=") 213 | case _ => None 214 | } 215 | } 216 | } 217 | 218 | val druidRelationTransform: DruidTransform = { 219 | case (_, PhysicalOperation(projectList, filters, 220 | l @ LogicalRelation(d @ DruidRelation(info, _), _, _, _))) => 221 | // This is the initial DruidQueryBuilder which all transformations 222 | // are based on. 223 | val dqb: Option[DruidQueryBuilder] = Some(DruidQueryBuilder(info)) 224 | val (newFilters, dqb1) = ExprUtil.simplifyConjPred(dqb.get, filters) 225 | translateProjectFilter(Some(dqb1), projectList, newFilters) 226 | case _ => Nil 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/util/ExprUtil.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.util 2 | 3 | import org.apache.spark.sql.catalyst.expressions._ 4 | import org.apache.spark.sql.catalyst.trees.CurrentOrigin 5 | import org.apache.spark.sql.types._ 6 | import org.rzlabs.druid.DruidQueryBuilder 7 | 8 | object ExprUtil { 9 | 10 | /** 11 | * If any input col/ref is null then expression will evaluate to null 12 | * and if no input col/ref is null then expression won't evaluate to null. 
13 | * 14 | * @param e Expression that needs to be checked 15 | * @return 16 | */ 17 | private[this] def nullPreserving(e: Expression): Boolean = e match { 18 | case Literal(v, _) if v == null => false 19 | case _ if e.isInstanceOf[LeafExpression] => true // LeafExpression except Literal(null) 20 | // TODO: Expand the case below 21 | case Cast(_, _, _) | BinaryArithmetic(_, _) | UnaryMinus(_) | UnaryPositive(_) | Abs(_) | 22 | Concat(_) => e.children.filter(_.isInstanceOf[Expression]).foldLeft(true) { 23 | (lb, ce) => if (nullPreserving(ce) && lb) true else false 24 | } 25 | case _ => false 26 | 27 | } 28 | 29 | private[this] def nullableAttributes(dqb: DruidQueryBuilder, 30 | references: AttributeSet): List[AttributeReference] = { 31 | references.foldLeft(List[AttributeReference]()) { 32 | (list, reference) => 33 | var arList = list 34 | val dc = dqb.druidColumn(reference.name) 35 | if (dc.nonEmpty) { 36 | dc.get match { 37 | case d if d.isDimension(excludeTime = true) && reference.isInstanceOf[AttributeReference] => 38 | arList = arList :+ reference.asInstanceOf[AttributeReference] 39 | case _ => None // metric and time columns cannot be pushed down to Druid as filters 40 | } 41 | } 42 | arList 43 | } 44 | } 45 | 46 | def simplifyConjPred(dqb: DruidQueryBuilder, filters: Seq[Expression]): 47 | (Seq[Expression], DruidQueryBuilder) = { 48 | var newFilters = Seq[Expression]() 49 | filters.foreach { filter => 50 | for (nf <- simplifyPred(dqb, filter)) { 51 | newFilters = newFilters :+ nf 52 | } 53 | } 54 | (newFilters, dqb) 55 | } 56 | 57 | def simplifyPred(dqb: DruidQueryBuilder, filter: Expression): Option[Expression] = filter match { 58 | case And(le, re) => simplifyBinaryPred(dqb, le, re, true) 59 | case Or(le, re) => simplifyBinaryPred(dqb, le, re, false) 60 | case SimplifyCast(e) => simplifyPred(dqb, e) 61 | case e => e match { 62 | case SimplifyNotNullFilter(se) => 63 | /* 64 | * nullFilter may equal Concat(a, "abc") after simplifying IsNotNull(Concat(a, "abc")). 65 | * This is also null preserving because a null child of Concat 66 | * leads to a null result. 67 | * Null-preserving expressions include: 68 | * 69 | * 1. Cast if its child is null preserving; 70 | * 2. BinaryArithmetic if its children are null preserving; 71 | * 3. UnaryMinus if its child is null preserving; 72 | * 4. UnaryPositive if its child is null preserving; 73 | * 5. Abs if its child is null preserving; 74 | * 6. Concat if its children are null preserving; 75 | * ... 76 | * 77 | */ 78 | if (se.nullable) { 79 | if (nullPreserving(se)) { 80 | // e.g., Concat(a, '123') will generate IsNotNull(a) here, and Concat(a, '123') 81 | // will translate to a JavascriptExtractionFunctionSpec. 82 | // Concat(a, b) will generate And(IsNotNull(a), IsNotNull(b)) here, and 83 | // Concat(a, b) will not be pushed down to Druid because there's no AggregateSpec 84 | // with more than 1 input dimension (just a select spec is generated). 85 | val nullableAttrRefs = nullableAttributes(dqb, se.references) 86 | nullableAttrRefs.foldLeft(Option.empty[Expression]) { 87 | (le, ar) => if (le.isEmpty) { 88 | Some(IsNotNull(ar)) 89 | } else { 90 | Some(And(le.get, IsNotNull(ar))) 91 | } 92 | } 93 | } else Some(se) // no IsNotNull predicates generated. 94 | } else None // Literal(true) because it's not nullable. 
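          // Hedged sketch (not part of the original source): assuming `a` and `b` resolve to
          // non-time dimensions in `dqb`, the branch above behaves roughly like:
          //   simplifyPred(dqb, IsNotNull(Concat(Seq(a, Literal("123")))))  ~>  Some(IsNotNull(a))
          //   simplifyPred(dqb, IsNotNull(Concat(Seq(a, b))))               ~>  Some(And(IsNotNull(a), IsNotNull(b)))
          // because Concat is null preserving, so only its nullable attribute references matter.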
95 | 96 | case fe @ IsNull(ce) => 97 | if (ce.nullable) { 98 | if (nullPreserving(ce)) { 99 | val nullableAttrRefs = nullableAttributes(dqb, ce.references) 100 | if (nullableAttrRefs.isEmpty) { 101 | Some(alwaysFalseExpr) 102 | } else Some(fe) 103 | } else Some(alwaysFalseExpr) // not null preserving expr means any input won't result null. 104 | } else Some(alwaysFalseExpr) // IsNull(not nullable expr) always false 105 | 106 | case _ => Some(e) 107 | } 108 | 109 | 110 | } 111 | 112 | def simplifyBinaryPred(dqb: DruidQueryBuilder, le: Expression, re: Expression, 113 | conj: Boolean): Option[Expression] = { 114 | val newLe = simplifyPred(dqb, le) 115 | val newRe = simplifyPred(dqb, re) 116 | val newFilter = if (newLe.nonEmpty) { 117 | if (newRe.nonEmpty) { 118 | if (conj) { 119 | Some(And(newLe.get, newRe.get)) 120 | } else { 121 | Some(Or(newLe.get, newRe.get)) 122 | } 123 | } else newLe 124 | } else { 125 | if (newRe.nonEmpty) { 126 | newRe 127 | } else None 128 | } 129 | 130 | newFilter 131 | } 132 | 133 | private[this] object SimplifyNotNullFilter { 134 | private[this] val trueFilter = Literal(true) 135 | 136 | def unapply(e: Expression): Option[Expression] = e match { 137 | case Not(IsNull(c)) if (c.nullable) => Some(IsNotNull(c)) 138 | case IsNotNull(c) if (c.nullable) => Some(c) // What if IsNotNull(Concat(a, "abc")) ??? => Concat(a, "abc") ??? 139 | case Not(IsNull(c)) if (!c.nullable) => Some(trueFilter) // e.g., Not(isNull(EqualTo(a, b))) always true 140 | case IsNotNull(c) if (!c.nullable) => Some(trueFilter) // e.g., IsNotNull(LessThan(a, b)) always true 141 | case _ => None 142 | } 143 | } 144 | 145 | private[this] object SimplifyCast { 146 | def unapply(e: Expression): Option[Expression] = e match { 147 | case Cast(Cast(_, _, _), dt, _) => 148 | val c = simplifyCast(e, dt) 149 | if (c == e) None else Some(c) 150 | case _ => None 151 | } 152 | } 153 | 154 | def escapeLikeRegex(v: String): String = { 155 | org.apache.spark.sql.catalyst.util.StringUtils.escapeLikeRegex(v) 156 | } 157 | 158 | /** 159 | * Simplify Cast expression by removing inner most cast if redundant. 160 | * @param oe 161 | * @param odt 162 | * @return 163 | */ 164 | def simplifyCast(oe: Expression, odt: DataType): Expression = oe match { 165 | case Cast(ie, idt, _) if odt.isInstanceOf[NumericType] && 166 | (idt.isInstanceOf[DoubleType] || idt.isInstanceOf[FloatType] || 167 | idt.isInstanceOf[DecimalType]) => Cast(ie, odt) 168 | case _ => oe 169 | } 170 | 171 | def and(exprs: Seq[Expression]): Option[Expression] = exprs.size match { 172 | case 0 => None 173 | case 1 => exprs.headOption 174 | case _ => Some(exprs.foldLeft[Expression](null) { (le, e) => 175 | if (le == null) e else And(le, e) 176 | }) 177 | } 178 | 179 | /** 180 | * This is different from transformDown because if rule transforms an Expression, 181 | * we don't try to apply any more transformations. 
182 | * @param e 183 | * @param rule 184 | * @return 185 | */ 186 | def transformReplace(e: Expression, 187 | rule: PartialFunction[Expression, Expression]): Expression = { 188 | val afterRule = CurrentOrigin.withOrigin(e.origin) { 189 | rule.applyOrElse(e, identity[Expression]) 190 | } 191 | 192 | if (e.fastEquals(afterRule)) { 193 | e.transformDown(rule) 194 | } else { 195 | afterRule 196 | } 197 | } 198 | 199 | private val alwaysFalseExpr = EqualTo(Literal(1), Literal(2)) 200 | } 201 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/util/MyThreadUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.util 2 | 3 | import java.util.concurrent.ThreadPoolExecutor 4 | 5 | import org.apache.spark.util 6 | 7 | object MyThreadUtils { 8 | 9 | def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int, 10 | keepAliveSeconds: Int = 60): ThreadPoolExecutor = { 11 | ThreadUtils.newDaemonCachedThreadPool(prefix, maxThreadNumber, keepAliveSeconds) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/org/fasterxml/jackson/databind/ObjectMapper.scala: -------------------------------------------------------------------------------- 1 | package org.fasterxml.jackson.databind 2 | 3 | import com.fasterxml.jackson.databind._ 4 | import com.fasterxml.jackson.dataformat.smile.SmileFactory 5 | import com.fasterxml.jackson.module.scala._ 6 | import com.fasterxml.jackson.datatype.joda._ 7 | 8 | 9 | object ObjectMapper { 10 | 11 | val jsonMapper = { 12 | val om = new ObjectMapper() 13 | om.registerModule(DefaultScalaModule) 14 | om.registerModule(new JodaModule) 15 | om.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS) 16 | om.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) 17 | om 18 | } 19 | 20 | val smileMapper = { 21 | val om = new ObjectMapper(new SmileFactory()) 22 | om.registerModule(DefaultScalaModule) 23 | om.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS) 24 | om.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) 25 | om 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DateTimeExtractor.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import org.apache.spark.sql.catalyst.expressions._ 4 | import org.apache.spark.sql.catalyst.util.DateTimeUtils 5 | import org.apache.spark.sql.types._ 6 | import org.joda.time.{DateTime, DateTimeZone} 7 | import org.rzlabs.druid.metadata.DruidRelationColumn 8 | 9 | /** 10 | * @param outputName The output name of the column. 11 | * @param druidColumn The druid column. 12 | * @param formatToApply The output date format. 13 | * @param timeZone The output date time zone. 14 | * @param pushedExpression This controls the expression evaluation that happens 15 | * on return from Druid. So for expression like 16 | * {{{to_date(cast(dateCol as DateType))}}} is evaluated 17 | * on the resultset of Druid. This is required because 18 | * Dates are Ints and Timestamps are Longs in Spark, whereas 19 | * the value coming out of Druid is an ISO DateTime String. 20 | * @param inputFormat Format to use to parse input value. 
21 | */ 22 | case class DateTimeGroupingElem(outputName: String, 23 | druidColumn: DruidRelationColumn, 24 | formatToApply: String, 25 | timeZone: Option[String], 26 | pushedExpression: Expression, 27 | inputFormat: Option[String] = None) 28 | 29 | object DruidColumnExtractor { 30 | 31 | def unapply(e: Expression)( 32 | implicit dqb: DruidQueryBuilder): Option[DruidRelationColumn] = e match { 33 | case AttributeReference(nm, _, _, _) => 34 | val druidColumn = dqb.druidColumn(nm) 35 | druidColumn.filter(_.isDimension()) 36 | case _ => None 37 | } 38 | } 39 | 40 | class SparkNativeTimeElementExtractor(implicit val dqb: DruidQueryBuilder) { 41 | 42 | self => 43 | 44 | import SparkNativeTimeElementExtractor._ 45 | 46 | def unapply(e: Expression): Option[DateTimeGroupingElem] = e match { 47 | case DruidColumnExtractor(dc) if e.dataType == DateType => 48 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, DATE_FORMAT, 49 | Some(dqb.druidRelationInfo.options.timeZoneId), e)) 50 | case Cast(c @ DruidColumnExtractor(dc), DateType, _) => 51 | // e.g., "cast(time as date)" 52 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, DATE_FORMAT, 53 | Some(dqb.druidRelationInfo.options.timeZoneId), c)) 54 | case Cast(self(dtGrp), DateType, _) => 55 | // e.g., "cast(from_utc_timestamp(time, 'GMT') as date)", include last case 56 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 57 | DATE_FORMAT, dtGrp.timeZone, dtGrp.pushedExpression)) 58 | case DruidColumnExtractor(dc) if e.dataType == StringType => 59 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, TIMESTAMP_FORMAT, 60 | Some(dqb.druidRelationInfo.options.timeZoneId), e)) 61 | case Cast(self(dtGrp), StringType, _) => 62 | // e.g., "cast(time as string)" 63 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 64 | dtGrp.formatToApply, dtGrp.timeZone, e)) 65 | case DruidColumnExtractor(dc) if e.dataType == TimestampType => 66 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, TIMESTAMP_FORMAT, 67 | Some(dqb.druidRelationInfo.options.timeZoneId), e)) 68 | case Cast(c @ DruidColumnExtractor(dc), TimestampType, _) => 69 | // e.g., "cast(time as timestamp)" 70 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, TIMESTAMP_FORMAT, 71 | Some(dqb.druidRelationInfo.options.timeZoneId), c)) 72 | case Cast(self(dtGrp), TimestampType, _) => 73 | // e.g., "cast(to_date(time) as timestamp)", include last case 74 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 75 | TIMESTAMP_FORMAT, dtGrp.timeZone, dtGrp.pushedExpression)) 76 | //case ToDate(self(dtGrp)) => 77 | // // e.g., "to_date(time)" 78 | // Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 79 | // DATE_FORMAT, dtGrp.timeZone,dtGrp.pushedExpression)) 80 | case ParseToDate(self(dtGrp), fmt, _) if fmt.isInstanceOf[Option[Literal]] => 81 | val fmtStr = if (fmt.nonEmpty) { 82 | fmt.map(_.asInstanceOf[Literal].value.toString).get 83 | } else DATE_FORMAT 84 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 85 | fmtStr, dtGrp.timeZone, dtGrp.pushedExpression)) 86 | case Year(self(dtGrp)) => 87 | // e.g., "year(cast(time as date))" 88 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 89 | YEAR_FORMAT, dtGrp.timeZone, e)) 90 | case DayOfMonth(self(dtGrp)) => 91 | // e.g., "dayofmonth(cast(time as date))" 92 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 93 | DAY_OF_MONTH_FORMAT, dtGrp.timeZone, e)) 94 | case DayOfYear(self(dtGrp)) => 95 | // e.g., "dayofyear(cast(time as date))" 96 | Some(DateTimeGroupingElem(dtGrp.outputName, 
dtGrp.druidColumn, 97 | DAY_OF_YEAR_FORMAT, dtGrp.timeZone, e)) 98 | case Month(self(dtGrp)) => 99 | // e.g., "month(cast(time as date))" 100 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 101 | MONTH_FORMAT, dtGrp.timeZone, e)) 102 | case WeekOfYear(self(dtGrp)) => 103 | // e.g., "weekofyear(cast(time as date))" 104 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 105 | WEEKOFYEAR_FORMAT, dtGrp.timeZone, e)) 106 | case Hour(self(dtGrp), _) => 107 | // e.g., "hour(cast(time as date))" 108 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 109 | HOUR_FORMAT, dtGrp.timeZone, e)) 110 | case Minute(self(dtGrp), _) => 111 | // e.g., "minute(cast(time as date))" 112 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 113 | MINUTE_FORMAT, dtGrp.timeZone, e)) 114 | case Second(self(dtGrp), _) => 115 | // e.g., "second(cast(time as date))" 116 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 117 | SECOND_FORMAT, dtGrp.timeZone, e)) 118 | case UnixTimestamp(self(dtGrp), Literal(inFmt, StringType), _) => 119 | // e.g., "unix_timestamp(cast(time as date), 'YYYY-MM-dd HH:mm:ss')" 120 | 121 | // TODO: UnixTimestamp should be handled by the JS generator, 122 | // because TimeFormatExtractionFunctionSpec returns a 123 | // string, not a bigint. 124 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 125 | TIMESTAMP_FORMAT, dtGrp.timeZone, dtGrp.pushedExpression, 126 | Some(inFmt.toString))) 127 | case UnixTimestamp(c @ DruidColumnExtractor(dc), Literal(inFmt, StringType), _) => 128 | // e.g., "unix_timestamp(time, 'YYYY-MM-dd HH:mm:ss')" 129 | 130 | // TODO: UnixTimestamp should be handled by the JS generator, 131 | // because TimeFormatExtractionFunctionSpec returns a 132 | // string, not a bigint. 133 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, 134 | TIMESTAMP_FORMAT, None, c, 135 | Some(inFmt.toString))) 136 | case FromUnixTime(self(dtGrp), Literal(outFmt, StringType), _) => 137 | // TODO: Remove this case because the TimeFormatExtractionFunctionSpec 138 | // cannot represent the bigint input. We should use 139 | // JavascriptExtractionFunctionSpec here instead. 140 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 141 | outFmt.toString, dtGrp.timeZone, e)) 142 | case FromUnixTime(c @ DruidColumnExtractor(dc), Literal(outFmt, StringType), _) => 143 | // TODO: Remove this case because the TimeFormatExtractionFunctionSpec 144 | // cannot represent the bigint input. We should use 145 | // JavascriptExtractionFunctionSpec here instead. 
146 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, outFmt.toString, 147 | Some(dqb.druidRelationInfo.options.timeZoneId), e)) 148 | case FromUTCTimestamp(self(dtGrp), Literal(tz, StringType)) => 149 | // e.g., "from_utc_timestamp(cast(time as timestamp), 'GMT')" 150 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 151 | TIMESTAMP_FORMAT, Some(tz.toString), dtGrp.pushedExpression)) 152 | case FromUTCTimestamp(c @ DruidColumnExtractor(dc), Literal(tz, StringType)) => 153 | // e.g., "from_utc_timestamp(time, 'GMT')" 154 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, 155 | TIMESTAMP_FORMAT, Some(tz.toString), e)) 156 | case ToUTCTimestamp(self(dtGrp), _) => 157 | // e.g., "to_utc_timestamp(cast(time as timestamp), 'GMT')" 158 | Some(DateTimeGroupingElem(dtGrp.outputName, dtGrp.druidColumn, 159 | TIMESTAMP_FORMAT, None, dtGrp.pushedExpression)) 160 | case ToUTCTimestamp(c @ DruidColumnExtractor(dc), _) => 161 | // e.g., "to_utc_timestamp(time, 'GMT')" 162 | Some(DateTimeGroupingElem(dqb.nextAlias, dc, 163 | TIMESTAMP_FORMAT, None, e)) 164 | case _ => None 165 | } 166 | } 167 | 168 | object SparkNativeTimeElementExtractor { 169 | 170 | val DATE_FORMAT = "YYYY-MM-dd" 171 | val TIMESTAMP_FORMAT = "YYYY-MM-dd HH:mm:ss" 172 | val TIMESTAMP_DATEZERO_FORMAT = "YYYY-MM-dd 00:00:00" 173 | 174 | val YEAR_FORMAT = "YYYY" 175 | val MONTH_FORMAT = "MM" 176 | val WEEKOFYEAR_FORMAT = "ww" 177 | val DAY_OF_MONTH_FORMAT = "dd" 178 | val DAY_OF_YEAR_FORMAT = "DD" 179 | 180 | val HOUR_FORMAT = "HH" 181 | val MINUTE_FORMAT = "mm" 182 | val SECOND_FORMAT = "ss" 183 | } 184 | 185 | object IntervalConditionType extends Enumeration { 186 | val GT = Value 187 | val GTE = Value 188 | val LT = Value 189 | val LTE = Value 190 | } 191 | 192 | case class IntervalCondition(`type`: IntervalConditionType.Value, dt: DateTime) 193 | 194 | class SparkIntervalConditionExtractor(dqb: DruidQueryBuilder) { 195 | 196 | import SparkNativeTimeElementExtractor._ 197 | 198 | val timeExtractor = new SparkNativeTimeElementExtractor()(dqb) 199 | 200 | private def literalToDateTime(value: Any, dataType: DataType): DateTime = dataType match { 201 | case TimestampType => 202 | // Timestamp Literal's value accurate to micro second 203 | new DateTime(value.toString.toLong / 1000, 204 | DateTimeZone.forID(dqb.druidRelationInfo.options.timeZoneId)) 205 | case DateType => 206 | new DateTime(DateTimeUtils.toJavaDate(value.toString.toInt), 207 | DateTimeZone.forID(dqb.druidRelationInfo.options.timeZoneId)) 208 | case StringType => new DateTime(value.toString, 209 | DateTimeZone.forID(dqb.druidRelationInfo.options.timeZoneId)) 210 | } 211 | 212 | private object DateTimeLiteralType { 213 | def unapply(dt: DataType): Option[DataType] = dt match { 214 | case StringType | DateType | TimestampType => Some(dt) 215 | case _ => None 216 | } 217 | } 218 | 219 | def unapply(e: Expression): Option[IntervalCondition] = e match { 220 | // TODO: Or(le, re) don't us javascript function 221 | case LessThan(timeExtractor(dtGrp), Literal(value, DateTimeLiteralType(dt))) 222 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 223 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 224 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 225 | Some(IntervalCondition(IntervalConditionType.LT, literalToDateTime(value, dt))) 226 | case LessThan(Literal(value, DateTimeLiteralType(dt)), timeExtractor(dtGrp)) 227 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 228 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 229 | 
dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 230 | Some(IntervalCondition(IntervalConditionType.GT, literalToDateTime(value, dt))) 231 | case LessThanOrEqual(timeExtractor(dtGrp), Literal(value, DateTimeLiteralType(dt))) 232 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 233 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 234 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 235 | Some(IntervalCondition(IntervalConditionType.LTE, literalToDateTime(value, dt))) 236 | case LessThanOrEqual(Literal(value, DateTimeLiteralType(dt)), timeExtractor(dtGrp)) 237 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 238 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 239 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 240 | Some(IntervalCondition(IntervalConditionType.GTE, literalToDateTime(value, dt))) 241 | case GreaterThan(timeExtractor(dtGrp), Literal(value, DateTimeLiteralType(dt))) 242 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 243 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 244 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 245 | Some(IntervalCondition(IntervalConditionType.GT, literalToDateTime(value, dt))) 246 | case GreaterThan(Literal(value, DateTimeLiteralType(dt)), timeExtractor(dtGrp)) 247 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 248 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 249 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 250 | Some(IntervalCondition(IntervalConditionType.LT, literalToDateTime(value, dt))) 251 | case GreaterThanOrEqual(timeExtractor(dtGrp), Literal(value, DateTimeLiteralType(dt))) 252 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 253 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 254 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 255 | Some(IntervalCondition(IntervalConditionType.GTE, literalToDateTime(value, dt))) 256 | case GreaterThanOrEqual(Literal(value, DateTimeLiteralType(dt)), timeExtractor(dtGrp)) 257 | if dtGrp.druidColumn.name == dqb.druidRelationInfo.timeDimensionCol && 258 | (dtGrp.formatToApply == TIMESTAMP_FORMAT || 259 | dtGrp.formatToApply == TIMESTAMP_DATEZERO_FORMAT) => 260 | Some(IntervalCondition(IntervalConditionType.LTE, literalToDateTime(value, dt))) 261 | case _ => None 262 | } 263 | } 264 | 265 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import com.fasterxml.jackson.core.`type`.TypeReference 4 | import org.apache.spark.sql.rzlabs.DruidBaseModule 5 | import org.apache.spark.sql.{MyLogging, SQLContext} 6 | import org.apache.spark.sql.sources.{BaseRelation, RelationProvider} 7 | import org.fasterxml.jackson.databind.ObjectMapper._ 8 | import org.rzlabs.druid.metadata.{DruidMetadataCache, DruidOptions, DruidRelationColumnInfo, DruidRelationInfo} 9 | 10 | import org.apache.spark.sql._ 11 | 12 | class DefaultSource extends RelationProvider with MyLogging { 13 | 14 | import DefaultSource._ 15 | 16 | override def createRelation(sqlContext: SQLContext, 17 | parameters: Map[String, String]): BaseRelation = { 18 | 19 | val dsName: String = parameters.getOrElse(DRUID_DS_NAME, 20 | throw new DruidDataSourceException( 21 | s"'$DRUID_DS_NAME' must be specified for Druid datasource.") 22 | ) 23 | 24 | val timeDimensionCol: String = 
parameters.getOrElse(TIME_DIMENSION_COLUMN_NAME, 25 | null) 26 | 27 | val hyperUniqueColumnInfos: List[DruidRelationColumnInfo] = 28 | parameters.get(HYPER_UNIQUE_COLUMN_INFO) 29 | .map(jsonMapper.readValue(_, 30 | new TypeReference[List[DruidRelationColumnInfo]] {}). 31 | asInstanceOf[List[DruidRelationColumnInfo]]).getOrElse(List()) 32 | 33 | val sketchColumnInfos: List[DruidRelationColumnInfo] = 34 | parameters.get(SKETCH_COLUMN_INFO) 35 | .map(jsonMapper.readValue(_, 36 | new TypeReference[List[DruidRelationColumnInfo]] {}). 37 | asInstanceOf[List[DruidRelationColumnInfo]]).getOrElse(List()) 38 | 39 | val zkHost: String = parameters.getOrElse(ZK_HOST, DEFAULT_ZK_HOST) 40 | 41 | val zkDruidPath: String = parameters.getOrElse(ZK_DRUID_PATH, DEFAULT_ZK_DRUID_PATH) 42 | 43 | val zkSessionTimeout: Int = parameters.getOrElse(ZK_SESSION_TIMEOUT, 44 | DEFAULT_ZK_SESSION_TIMEOUT).toInt 45 | 46 | val zkEnableCompression: Boolean = parameters.getOrElse(ZK_ENABLE_COMPRESSION, 47 | DEFAULT_ZK_ENABLE_COMPRESSION).toBoolean 48 | 49 | val zKQualifyDiscoveryNames: Boolean = parameters.getOrElse(ZK_QUALIFY_DISCOVERY_NAMES, 50 | DEFAULT_ZK_QUALIFY_DISCOVERY_NAMES).toBoolean 51 | 52 | val poolMaxConnectionsPerRoute: Int = parameters.getOrElse(CONN_POOL_MAX_CONNECTIONS_PER_ROUTE, 53 | DEFAULT_CONN_POOL_MAX_CONNECTIONS_PER_ROUTE).toInt 54 | 55 | val poolMaxConnections: Int = parameters.getOrElse(CONN_POOL_MAX_CONNECTIONS, 56 | DEFAULT_CONN_POOL_MAX_CONNECTIONS).toInt 57 | 58 | val loadMetadataFromAllSegments: Boolean = parameters.getOrElse(LOAD_METADATA_FROM_ALL_SEGMENTS, 59 | DEFAULT_LOAD_METADATA_FROM_ALL_SEGMENTS).toBoolean 60 | 61 | val debugTransformations: Boolean = parameters.getOrElse(DEBUG_TRANSFORMATIONS, 62 | DEFAULT_DEBUG_TRANSFORMATIONS).toBoolean 63 | 64 | val timeZoneId: String = parameters.getOrElse(TIME_ZONE_ID, DEFAULT_TIME_ZONE_ID) 65 | 66 | val useV2GroupByEngine = parameters.getOrElse(USE_V2_GROUPBY_ENGINE, 67 | DEFAULT_USE_V2_GROUPBY_ENGINE).toBoolean 68 | 69 | val useSmile = parameters.getOrElse(USE_SMILE, DEFAULT_USE_SMILE).toBoolean 70 | 71 | val queryGranularity = DruidQueryGranularity( 72 | parameters.getOrElse(QUERY_GRANULARITY, DEFAULT_QUERY_GRANULARITY)) 73 | 74 | 75 | val druidOptions = DruidOptions( 76 | zkHost, 77 | zkSessionTimeout, 78 | zkEnableCompression, 79 | zKQualifyDiscoveryNames, 80 | zkDruidPath, 81 | poolMaxConnectionsPerRoute, 82 | poolMaxConnections, 83 | loadMetadataFromAllSegments, 84 | debugTransformations, 85 | timeZoneId, 86 | useV2GroupByEngine, 87 | useSmile, 88 | queryGranularity 89 | ) 90 | 91 | val druidRelationInfo: DruidRelationInfo = 92 | DruidMetadataCache.druidRelation(dsName, 93 | timeDimensionCol, 94 | hyperUniqueColumnInfos ++ sketchColumnInfos, 95 | druidOptions) 96 | 97 | val druidRelation = DruidRelation(druidRelationInfo, None)(sqlContext) 98 | 99 | addPhysicalRules(sqlContext, druidOptions) 100 | 101 | druidRelation 102 | } 103 | 104 | /** 105 | * There are 3 places to initialize a [[BaseRelation]] object by calling 106 | * the `resolveRelation` method of [[org.apache.spark.sql.execution.datasources.DataSource]]: 107 | * 108 | * 1. In the `run(sparkSession: SparkSession)` method in 109 | * [[org.apache.spark.sql.execution.command.CreateDataSourceTableCommand]] 110 | * when executing sql "create table using ..."; 111 | * 2. In the `load(paths: String*)` method in [[org.apache.spark.sql.DataFrameReader]] 112 | * when calling "spark.read.format(org.rzlabs.druid).load()"; 113 | * 3. 
In the `load` method of the LoadingCache object "cachedDataSourceTables" in 114 | * [[org.apache.spark.sql.hive.HiveMetastoreCatalog]] which called from the root 115 | * method of `apply` in [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] 116 | * which belongs to "resolution" rule batch in the logical plan analyzing phase of the 117 | * execution of "select ...". 118 | * 119 | * None of the 3 cases generates [[org.apache.spark.sql.execution.SparkPlan]] in DataFrame's 120 | * `queryExecution` of [[org.apache.spark.sql.execution.QueryExecution]], so we can 121 | * add druid-related physical rules in the `resolveRelation` method in [[DefaultSource]]. 122 | * 123 | * @param sqlContext 124 | * @param druidOptions 125 | */ 126 | private def addPhysicalRules(sqlContext: SQLContext, druidOptions: DruidOptions) = { 127 | rulesLock.synchronized { 128 | if (!physicalRulesAdded) { 129 | sqlContext.sparkSession.experimental.extraStrategies ++= 130 | DruidBaseModule.physicalRules(sqlContext, druidOptions) 131 | physicalRulesAdded = true 132 | } 133 | } 134 | } 135 | 136 | } 137 | 138 | object DefaultSource { 139 | 140 | private val rulesLock = new Object 141 | 142 | private var physicalRulesAdded = false 143 | 144 | /** 145 | * Datasource name in Druid. 146 | */ 147 | val DRUID_DS_NAME = "druidDatasource" 148 | 149 | /** 150 | * Time dimension name in a Druid datasource. 151 | */ 152 | val TIME_DIMENSION_COLUMN_NAME = "timeDimensionColumn" 153 | 154 | val HYPER_UNIQUE_COLUMN_INFO = "hyperUniqueColumnInfos" 155 | 156 | val SKETCH_COLUMN_INFO ="sketchColumnInfos" 157 | 158 | /** 159 | * Zookeeper server host name with pattern "host:port". 160 | */ 161 | val ZK_HOST = "zkHost" 162 | val DEFAULT_ZK_HOST = "localhost" 163 | 164 | val ZK_SESSION_TIMEOUT = "zkSessionTimeoutMs" 165 | val DEFAULT_ZK_SESSION_TIMEOUT = "30000" 166 | 167 | val ZK_ENABLE_COMPRESSION ="zkEnableCompression" 168 | val DEFAULT_ZK_ENABLE_COMPRESSION = "true" 169 | 170 | /** 171 | * Druid cluster sync path on zk. 172 | */ 173 | val ZK_DRUID_PATH = "zkDruidPath" 174 | val DEFAULT_ZK_DRUID_PATH = "/druid" 175 | 176 | val ZK_QUALIFY_DISCOVERY_NAMES = "zkQualifyDiscoveryNames" 177 | val DEFAULT_ZK_QUALIFY_DISCOVERY_NAMES = "true" 178 | 179 | /** 180 | * The query granularity that should be told Druid. 181 | * The options include 'minute', 'hour', 'day' and etc. 182 | */ 183 | val QUERY_GRANULARITY = "queryGranularity" 184 | val DEFAULT_QUERY_GRANULARITY = "all" 185 | 186 | /** 187 | * The max simultaneous live connections per Druid server. 188 | */ 189 | val CONN_POOL_MAX_CONNECTIONS_PER_ROUTE = "maxConnectionsPerRoute" 190 | val DEFAULT_CONN_POOL_MAX_CONNECTIONS_PER_ROUTE = "20" 191 | 192 | /** 193 | * The max simultaneous live connections of the Druid cluster. 
194 | */ 195 | val CONN_POOL_MAX_CONNECTIONS = "maxConnections" 196 | val DEFAULT_CONN_POOL_MAX_CONNECTIONS = "100" 197 | 198 | val LOAD_METADATA_FROM_ALL_SEGMENTS = "loadMetadataFromAllSegments" 199 | val DEFAULT_LOAD_METADATA_FROM_ALL_SEGMENTS = "true" 200 | 201 | val DEBUG_TRANSFORMATIONS = "debugTransformations" 202 | val DEFAULT_DEBUG_TRANSFORMATIONS = "false" 203 | 204 | val TIME_ZONE_ID = "timeZoneId" 205 | val DEFAULT_TIME_ZONE_ID = "UTC" 206 | 207 | val USE_V2_GROUPBY_ENGINE = "useV2GroupByEngine" 208 | val DEFAULT_USE_V2_GROUPBY_ENGINE = "false" 209 | 210 | val USE_SMILE = "useSmile" 211 | val DEFAULT_USE_SMILE = "true" 212 | 213 | } 214 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DruidDataSource.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import com.clearspring.analytics.stream.cardinality.ICardinality 4 | import org.apache.spark.sql.types._ 5 | import org.joda.time.Interval 6 | import org.rzlabs.druid.client.{Aggregator, ColumnDetail, MetadataResponse, TimestampSpec} 7 | 8 | /** 9 | * Druid data type enum. All the value names come from the 10 | * `type` field of the column JSON returned by the `segmentMetadata` API. 11 | */ 12 | object DruidDataType extends Enumeration { 13 | val String = Value("STRING") 14 | val Long = Value("LONG") 15 | val Float = Value("FLOAT") 16 | val HyperUnique = Value("hyperUnique") 17 | val ThetaSketch = Value("thetaSketch") 18 | 19 | def sparkDataType(t: String): DataType = sparkDataType(DruidDataType.withName(t)) 20 | 21 | def sparkDataType(t: DruidDataType.Value): DataType = t match { 22 | case String => StringType 23 | case Long => LongType 24 | case Float => FloatType 25 | case HyperUnique => BinaryType 26 | case ThetaSketch => BinaryType 27 | } 28 | } 29 | 30 | sealed trait DruidColumn { 31 | val name: String 32 | val dataType: DruidDataType.Value 33 | val size: Long // in bytes 34 | 35 | /** 36 | * Comes from the segment metadata query. 37 | * Only meaningful for time and dimension fields. 38 | */ 39 | val cardinality: Long 40 | 41 | def isDimension(excludeTime: Boolean = false): Boolean 42 | } 43 | 44 | object DruidColumn { 45 | 46 | def apply(name: String, 47 | c: ColumnDetail, 48 | numRows: Long, 49 | timeTicks: Long): DruidColumn = { 50 | 51 | if (name == DruidDataSource.INNER_TIME_COLUMN_NAME) { 52 | DruidTimeDimension(name, DruidDataType.withName(c.`type`), c.size, Math.min(numRows, timeTicks)) 53 | } else if (c.cardinality.isDefined) { 54 | DruidDimension(name, DruidDataType.withName(c.`type`), c.size, c.cardinality.get) 55 | } else { 56 | // The metric's cardinality is considered the same as the datasource row number. 
57 | DruidMetric(name, DruidDataType.withName(c.`type`), c.size, numRows) 58 | } 59 | } 60 | } 61 | 62 | case class DruidDimension(name: String, 63 | dataType: DruidDataType.Value, 64 | size: Long, 65 | cardinality: Long) extends DruidColumn { 66 | def isDimension(excludeTime: Boolean = false) = true 67 | } 68 | 69 | case class DruidTimeDimension(name: String, 70 | dataType: DruidDataType.Value, 71 | size: Long, 72 | cardinality: Long) extends DruidColumn { 73 | def isDimension(excludeTime: Boolean = false): Boolean = !excludeTime 74 | } 75 | 76 | case class DruidMetric(name: String, 77 | dataType: DruidDataType.Value, 78 | size: Long, 79 | cardinality: Long) extends DruidColumn { 80 | def isDimension(excludeTime: Boolean) = false 81 | } 82 | 83 | trait DruidDataSourceCapability { 84 | def druidVersion: String 85 | def supportsBoundFilter: Boolean = false 86 | } 87 | 88 | object DruidDataSourceCapability { 89 | private def versionCompare(druidVersion: String, oldVersion: String): Int = { 90 | def compare(v1: Int, v2: Int) = { 91 | (v1 - v2) match { 92 | case 0 => 0 93 | case v if v > 0 => 1 94 | case _ => -1 95 | } 96 | } 97 | druidVersion.split("\\.").zip(oldVersion.split("\\.")).foldLeft(Option.empty[Int]) { 98 | (l, r) => l match { 99 | case None | Some(0) => Some(compare(r._1.toInt, r._2.toInt)) 100 | case res @ (Some(-1) | Some(1)) => res 101 | } 102 | }.get 103 | } 104 | def supportsBoundFilter(druidVersion: String): Boolean = 105 | versionCompare(druidVersion, "0.9.0") >= 0 106 | def supportsQueryGranularityMetadata(druidVersion: String): Boolean = 107 | versionCompare(druidVersion, "0.9.1") >= 0 108 | def supportsTimestampSpecMetadata(druidVersion: String): Boolean = 109 | versionCompare(druidVersion, "0.9.2") >= 0 110 | } 111 | 112 | case class DruidDataSource(name: String, 113 | var intervals: List[Interval], 114 | columns: Map[String, DruidColumn], 115 | size: Long, 116 | numRows: Long, 117 | timeTicks: Long, 118 | aggregators: Option[Map[String, Aggregator]] = None, 119 | timestampSpec: Option[TimestampSpec] = None, 120 | druidVersion: String = null) extends DruidDataSourceCapability { 121 | 122 | import DruidDataSource._ 123 | 124 | lazy val timeDimension: Option[DruidColumn] = columns.values.find { 125 | case c if c.name == INNER_TIME_COLUMN_NAME => true 126 | case c if timestampSpec.isDefined && c.name == timestampSpec.get.column => true 127 | case _ => false 128 | } 129 | 130 | lazy val dimensions: IndexedSeq[DruidDimension] = columns.values.filter { 131 | case d: DruidDimension => true 132 | case _ => false 133 | }.map{_.asInstanceOf[DruidDimension]}.toIndexedSeq 134 | 135 | lazy val metrics: Map[String, DruidMetric] = columns.values.filter { 136 | case m: DruidMetric => true 137 | case _ => false 138 | }.map(m => m.name -> m.asInstanceOf[DruidMetric]).toMap 139 | 140 | override def supportsBoundFilter: Boolean = DruidDataSourceCapability.supportsBoundFilter(druidVersion) 141 | 142 | def numDimensions = dimensions.size 143 | 144 | def indexOfDimension(d: String): Int = { 145 | dimensions.indexWhere(_.name == d) 146 | } 147 | 148 | def metric(name: String): Option[DruidMetric] = metrics.get(name) 149 | 150 | def timeDimensionColName(timeDimCol: String) = { 151 | if (timestampSpec.nonEmpty) { 152 | timestampSpec.get.column 153 | } else if (timeDimCol != null) { 154 | timeDimCol 155 | } else { 156 | DruidDataSource.INNER_TIME_COLUMN_NAME 157 | } 158 | } 159 | } 160 | 161 | object DruidDataSource { 162 | 163 | val INNER_TIME_COLUMN_NAME = "__time" 164 | val TIMESTAMP_KEY_NAME 
= "timestamp" 165 | 166 | def apply(dataSource: String, mr: MetadataResponse, 167 | ins: List[Interval]): DruidDataSource = { 168 | 169 | val numRows = mr.getNumRows 170 | // TODO just 1 interval in 'ins' List and it's the 171 | // time boundary of the datasource. So the time ticks may not be correct. 172 | val timeTicks = mr.timeTicks(ins) 173 | 174 | val columns = mr.columns.map { 175 | case (name, columnDetail) => 176 | name -> DruidColumn(name, columnDetail, numRows, timeTicks) 177 | } 178 | new DruidDataSource(dataSource, ins, columns, mr.size, numRows, timeTicks, 179 | mr.aggregators, mr.timestampSpec) 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DruidExceptions.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | class DruidDataSourceException(message: String, cause: Throwable) 4 | extends Exception(message, cause) { 5 | def this(message: String) = this(message, null) 6 | } 7 | 8 | class QueryGranularityException(message: String, cause: Throwable) 9 | extends Exception(message, cause) { 10 | def this(message: String) = this(message, null) 11 | } 12 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DruidQueryBuilder.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import java.util.concurrent.atomic.AtomicLong 4 | 5 | import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} 6 | import org.apache.spark.sql.catalyst.plans.logical.Aggregate 7 | import org.apache.spark.sql.types.DataType 8 | import org.rzlabs.druid.metadata.{DruidRelationColumn, DruidRelationInfo} 9 | 10 | import scala.collection.mutable.{Map => MMap} 11 | 12 | case class DruidQueryBuilder(druidRelationInfo: DruidRelationInfo, 13 | queryIntervals: QueryIntervals, 14 | referencedDruidColumns: MMap[String, DruidRelationColumn] = MMap(), 15 | dimensions: List[DimensionSpec] = Nil, 16 | limitSpec: Option[LimitSpec] = None, 17 | havingSpec: Option[HavingSpec] = None, 18 | granularitySpec: GranularitySpec = AllGranularitySpec("all"), 19 | filterSpec: Option[FilterSpec] = None, 20 | aggregations: List[AggregationSpec] = Nil, 21 | postAggregations: Option[List[PostAggregationSpec]] = None, 22 | projectionAliasMap: Map[String, String] = Map(), 23 | outputAttributeMap: Map[String, (Expression, DataType, DataType, String)] = Map(), 24 | // avg expressions to perform in the Project operator 25 | // on top of Druid PhysicalScan. 
26 | avgExpressions: Map[Expression, (String, String)] = Map(), 27 | aggregateOper: Option[Aggregate] = None, 28 | curId: AtomicLong = new AtomicLong(-1), 29 | origProjectList: Option[Seq[NamedExpression]] = None, 30 | origFilter: Option[Expression] = None, 31 | hasUnpushedProjections: Boolean = false, 32 | hasUnpushedFilters: Boolean = false 33 | ) { 34 | 35 | override def toString = { 36 | s""" 37 | queryIntervals: 38 | ${queryIntervals.intervals.mkString("\n")} 39 | dimensions: 40 | ${dimensions.mkString("\n")} 41 | aggregations: 42 | ${aggregations.mkString("\n")} 43 | """.stripMargin 44 | } 45 | 46 | def aggregateOp(oper: Aggregate) = this.copy(aggregateOper = Some(oper)) 47 | 48 | def isNonTimeDimension(name: String) = { 49 | druidColumn(name).map(_.isDimension(true)).getOrElse(false) 50 | } 51 | 52 | def isNotIndexedDimension(name: String) = { 53 | druidColumn(name).map(_.isNotIndexedDimension).getOrElse(false) 54 | } 55 | 56 | def dimensionSpec(d: DimensionSpec) = { 57 | this.copy(dimensions = dimensions :+ d) 58 | 59 | } 60 | 61 | def aggregationSpec(a: AggregationSpec) = { 62 | this.copy(aggregations = aggregations :+ a) 63 | } 64 | 65 | def filterSpecification(f: FilterSpec) = (filterSpec, f) match { 66 | case (Some(fs), _) => 67 | this.copy(filterSpec = Some(new LogicalExpressionFilterSpec("and", List(f, fs)))) 68 | case (None, _) => 69 | this.copy(filterSpec = Some(f)) 70 | } 71 | 72 | /** 73 | * Get the [[DruidRelationColumn]] by column name. 74 | * The column name may be alias name, so we should 75 | * find the real column name in ''projectionAliasMap''. 76 | * @param name The key to find DruidRelationColumn. 77 | * @return The found DruidRelationColumn or None. 78 | */ 79 | def druidColumn(name: String): Option[DruidRelationColumn] = { 80 | druidRelationInfo.druidColumns.get(projectionAliasMap.getOrElse(name, name)).map { 81 | druidColumn => 82 | referencedDruidColumns(name) = druidColumn 83 | druidColumn 84 | } 85 | } 86 | 87 | def queryInterval(ic: IntervalCondition): Option[DruidQueryBuilder] = ic.`type` match { 88 | case IntervalConditionType.LT => 89 | queryIntervals.ltCond(ic.dt).map(qi => this.copy(queryIntervals = qi)) 90 | case IntervalConditionType.LTE => 91 | queryIntervals.lteCond(ic.dt).map(qi => this.copy(queryIntervals = qi)) 92 | case IntervalConditionType.GT => 93 | queryIntervals.gtCond(ic.dt).map(qi => this.copy(queryIntervals = qi)) 94 | case IntervalConditionType.GTE => 95 | queryIntervals.gteCond(ic.dt).map(qi => this.copy(queryIntervals = qi)) 96 | } 97 | 98 | /** 99 | * From "alias-1" to "alias-N". 
100 | * @return 101 | */ 102 | def nextAlias: String = s"alias${curId.getAndDecrement()}" 103 | 104 | def outputAttribute(name: String, e: Expression, originalDT: DataType, 105 | druidDT: DataType, tfName: String = null) = { 106 | val tf = if (tfName == null) DruidValTransform.getTFName(druidDT) else tfName 107 | this.copy(outputAttributeMap = outputAttributeMap + (name -> (e, originalDT, druidDT, tf))) 108 | } 109 | 110 | def addAlias(alias: String, col: String) = { 111 | val colName = projectionAliasMap.getOrElse(col, col) 112 | this.copy(projectionAliasMap = projectionAliasMap + (alias -> colName)) 113 | } 114 | 115 | def avgExpression(e: Expression, sumAlias: String, countAlias: String) = { 116 | this.copy(avgExpressions = avgExpressions + (e -> (sumAlias, countAlias))) 117 | } 118 | 119 | } 120 | 121 | object DruidQueryBuilder { 122 | def apply(druidRelationInfo: DruidRelationInfo) = { 123 | new DruidQueryBuilder(druidRelationInfo, new QueryIntervals(druidRelationInfo)) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DruidQueryGranularity.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import org.joda.time.{DateTime, DateTimeZone, Interval, Period} 4 | import org.fasterxml.jackson.databind.ObjectMapper._ 5 | import com.fasterxml.jackson.annotation._ 6 | import com.fasterxml.jackson.databind.JsonNode 7 | import com.fasterxml.jackson.databind.node._ 8 | 9 | import scala.util.Try 10 | 11 | @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type") 12 | @JsonSubTypes(Array( 13 | new JsonSubTypes.Type(value = classOf[NoneGranularity], name = "none"), 14 | new JsonSubTypes.Type(value = classOf[AllGranularity], name = "all"), 15 | new JsonSubTypes.Type(value = classOf[DurationGranularity], name = "duration"), 16 | new JsonSubTypes.Type(value = classOf[PeriodGranularity], name = "period") 17 | )) 18 | sealed trait DruidQueryGranularity extends Serializable { 19 | 20 | /** 21 | * The cardinality of the time field amongst the intervals 22 | * according to the specific granularity. 23 | * @param ins The intervals specified. 24 | * @return The cardinality of the time field. 
25 | */ 26 | def ndv(ins: List[Interval]): Long 27 | } 28 | 29 | object DruidQueryGranularity { 30 | 31 | def apply(s: String): DruidQueryGranularity = s match { 32 | case n if n.toLowerCase().equals("none") => NoneGranularity() 33 | case a if a.toLowerCase().equals("all") => AllGranularity() 34 | case s if s.toLowerCase().equals("second") => DurationGranularity(1000L) 35 | case m if m.toLowerCase().equals("minute") => DurationGranularity(60 * 1000L) 36 | case fm if fm.toLowerCase().equals("fifteen_minute") => DurationGranularity(15 * 60 * 1000L) 37 | case tm if tm.toLowerCase().equals("thirty_minute") => DurationGranularity(30 * 60 * 1000L) 38 | case h if h.toLowerCase().equals("hour") => DurationGranularity(3600 * 1000L) 39 | case d if d.toLowerCase().equals("day") => DurationGranularity(24 * 3600 * 1000L) 40 | case w if w.toLowerCase().equals("week") => DurationGranularity(7 * 24 * 3600 * 1000L) 41 | case q if q.toLowerCase().equals("quarter") => DurationGranularity(91 * 24 * 3600 * 1000L) 42 | case y if y.toLowerCase().equals("year") => DurationGranularity(365 * 24 * 3600 * 1000L) 43 | case _ => { 44 | // val jV = parse(s) 45 | // Try { 46 | // jV.extract[DurationGranularity] 47 | // } recover { 48 | // case _ => jV.extract[PeriodGranularity] 49 | // } get 50 | Try { 51 | jsonMapper.readValue(s, classOf[DurationGranularity]) 52 | } recover { 53 | case _ => jsonMapper.readValue(s, classOf[PeriodGranularity]) 54 | } get 55 | } 56 | } 57 | 58 | def substitute(n: JsonNode): JsonNode = n.findValuesAsText("queryGranularity") match { 59 | case vl: java.util.List[String] if vl.size > 0 && vl.get(0).nonEmpty => 60 | val on = jsonMapper.createObjectNode() 61 | vl.get(0) match { 62 | case n if n.toLowerCase().equals("none") => on.put("type", "none") 63 | case a if a.toLowerCase().equals("all") => on.put("type", "all") 64 | case s if s.toLowerCase().equals("second") => 65 | on.put("type", "duration").put("duration", 1000L) 66 | case m if m.toLowerCase().equals("minute") => 67 | on.put("type", "duration").put("duration", 60 * 1000L) 68 | case fm if fm.toLowerCase().equals("fifteen_minute") => 69 | on.put("type", "duration").put("duration", 15 * 60 * 1000L) 70 | case tm if tm.toLowerCase().equals("thirty_minute") => 71 | on.put("type", "duration").put("duration", 30 * 60 * 1000L) 72 | case h if h.toLowerCase().equals("hour") => 73 | on.put("type", "duration").put("duration", 3600 * 1000L) 74 | case d if d.toLowerCase().equals("day") => 75 | on.put("type", "duration").put("duration", 24 * 3600 * 1000L) 76 | case w if w.toLowerCase().equals("week") => 77 | on.put("type", "duration").put("duration", 7 * 24 * 3600 * 1000L) 78 | case q if q.toLowerCase().equals("quarter") => 79 | on.put("type", "duration").put("duration", 91 * 24 * 3600 * 1000L) 80 | case y if y.toLowerCase().equals("year") => 81 | on.put("type", "duration").put("duration", 365 * 24 * 3600 * 1000L) 82 | case other => throw new DruidDataSourceException(s"Invalid query granularity '$other'") 83 | } 84 | n.asInstanceOf[ObjectNode].replace("queryGranularity", on) 85 | n 86 | case _ => n 87 | } 88 | } 89 | 90 | case class AllGranularity() extends DruidQueryGranularity { 91 | 92 | def ndv(ins: List[Interval]) = 1L 93 | } 94 | 95 | case class NoneGranularity() extends DruidQueryGranularity { 96 | 97 | def ndv(ins: List[Interval]) = Utils.intervalsMillis(ins) 98 | } 99 | 100 | case class DurationGranularity(duration: Long, origin: DateTime = null) 101 | extends DruidQueryGranularity { 102 | 103 | lazy val originMillis = if (origin == null) 0L 
else origin.getMillis 104 | 105 | def ndv(ins: List[Interval]) = { 106 | val boundedIns = ins.flatMap { in => 107 | try { 108 | Some(in.withStartMillis(Math.max(originMillis, in.getStartMillis)) 109 | .withEndMillis(Math.max(originMillis, in.getEndMillis))) 110 | } catch { 111 | case e: IllegalArgumentException => None 112 | } 113 | } 114 | Utils.intervalsMillis(boundedIns) / duration 115 | } 116 | } 117 | 118 | case class PeriodGranularity(period: Period, 119 | origin: DateTime = null, 120 | timeZone: DateTimeZone = null) extends DruidQueryGranularity { 121 | 122 | val tz = if (timeZone == null) DateTimeZone.UTC else timeZone 123 | lazy val originMillis = if (origin == null) { 124 | new DateTime(0, DateTimeZone.UTC).withZoneRetainFields(tz).getMillis 125 | } else { 126 | origin.getMillis 127 | } 128 | lazy val periodMillis = period.getValues.zipWithIndex.foldLeft(0L) { 129 | case (r, p) => r + p._1 * 130 | (p._2 match { 131 | case 0 => 365 * 24 * 3600 * 1000L // year 132 | case 1 => 30 * 24 * 3600 * 1000L // month 133 | case 2 => 7 * 24 * 3600 * 1000L // week 134 | case 3 => 24 * 3600 * 1000L // day 135 | case 4 => 3600 * 1000L // hour 136 | case 5 => 60 * 1000L // minute 137 | case 6 => 1000L // second 138 | case 7 => 1L // millisecond 139 | }) 140 | } 141 | 142 | def ndv(ins: List[Interval]) = { 143 | val boundedIns = ins.flatMap { in => 144 | try { 145 | Some(in.withStartMillis(Math.max(originMillis, in.getStartMillis)) 146 | .withEndMillis(Math.max(originMillis, in.getEndMillis))) 147 | } catch { 148 | case e: IllegalArgumentException => None 149 | } 150 | } 151 | Utils.intervalsMillis(boundedIns) / periodMillis 152 | } 153 | } 154 | 155 | //class DruidQueryGranularitySerializer extends CustomSerializer[DruidQueryGranularity](format => { 156 | // implicit val fmt = format 157 | // ( 158 | // { 159 | // // PartialFunction used in deserialize method. 160 | // case jsonObj: JObject => 161 | // val fieldMap = jsonObj.values 162 | // fieldMap.get("type") match { 163 | // case Some(typ) if typ == "period" => 164 | // val period = new Period((jsonObj \ "period").extract[String]) 165 | // //val timeZone: DateTimeZone = Try(DateTimeZone.forID((jsonObj \ "timeZone").extract[String])) 166 | // // .recover { case _ => null } 167 | // val timeZone: DateTimeZone = if (fieldMap.contains("timeZone")) { 168 | // DateTimeZone.forID((jsonObj \ "timeZone").extract[String]) 169 | // } else null 170 | // val origin: DateTime = if (fieldMap.contains("origin")) { 171 | // new DateTime((jsonObj \ "origin").extract[String], timeZone) 172 | // } else null 173 | // PeriodGranularity(period, origin, timeZone) 174 | // case Some(typ) if typ == "duration" => 175 | // val duration = (jsonObj \ "duration").extract[Long] 176 | // val origin: DateTime = if (fieldMap.contains("origin")) { 177 | // new DateTime((jsonObj \ "origin").extract[String]) 178 | // } else null 179 | // DurationGranularity(duration, origin) 180 | // } 181 | // }, 182 | // { 183 | // // PartialFunction used in serialize method. 
184 | // case x: DruidQueryGranularity => 185 | // throw new RuntimeException("DruidQueryGranularity serialization not supported.") 186 | // } 187 | // ) 188 | //}) -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DruidRDD.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | 4 | import com.fasterxml.jackson.core.{Base64Variant, Base64Variants} 5 | import org.apache.http.concurrent.Cancellable 6 | import org.apache.spark.{InterruptibleIterator, Partition, TaskContext} 7 | import org.apache.spark.rdd.RDD 8 | import org.apache.spark.sql.{MyLogging, SQLContext} 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.catalyst.expressions.GenericInternalRow 11 | import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp 12 | import org.apache.spark.sql.types._ 13 | import org.joda.time.{DateTime, DateTimeZone, Interval} 14 | import org.rzlabs.druid.client.{CancellableHolder, ConnectionManager, DruidQueryServerClient, ResultRow} 15 | import org.rzlabs.druid.metadata.{DruidMetadataCache, DruidRelationInfo} 16 | import org.apache.spark.sql.sources.druid._ 17 | import org.apache.spark.unsafe.types.UTF8String 18 | 19 | import scala.collection.concurrent.TrieMap 20 | 21 | abstract class DruidPartition extends Partition { 22 | 23 | def queryClient(useSmile: Boolean, httpMaxConnPerRoute: Int, 24 | httpMaxConnTotal: Int): DruidQueryServerClient 25 | 26 | def intervals: List[Interval] 27 | 28 | def setIntervalsOnQuerySpec(qrySpec: QuerySpec): QuerySpec = { 29 | qrySpec.setIntervals(intervals) 30 | } 31 | } 32 | 33 | case class BrokerPartition(idx: Int, 34 | broker: String, 35 | in: Interval) extends DruidPartition { 36 | 37 | override def index: Int = idx 38 | 39 | override def queryClient(useSmile: Boolean, httpMaxConnPerRoute: Int, 40 | httpMaxConnTotal: Int): DruidQueryServerClient = { 41 | ConnectionManager.init(httpMaxConnPerRoute, httpMaxConnTotal) 42 | new DruidQueryServerClient(broker, useSmile) 43 | } 44 | 45 | override def intervals: List[Interval] = List(in) 46 | } 47 | 48 | class DruidRDD(sqlContext: SQLContext, 49 | drInfo: DruidRelationInfo, 50 | val druidQuery: DruidQuery 51 | ) extends RDD[InternalRow](sqlContext.sparkContext, Nil) { 52 | 53 | val httpMaxConnPerRoute = drInfo.options.poolMaxConnectionsPerRoute 54 | val httpMaxConnTotal = drInfo.options.poolMaxConnections 55 | val useSmile = druidQuery.useSmile 56 | val schema: StructType = druidQuery.schema(drInfo) 57 | 58 | // TODO: add recording Druid query logic. 
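  // Hedged usage sketch (not part of the original source): a DruidRDD is consumed like any
  // other RDD[InternalRow]; partitioning and row layout are illustrated below.
  //   // val rdd = new DruidRDD(sqlContext, drInfo, druidQuery)
  //   // rdd.partitions.length == druidQuery.intervals.size   // one BrokerPartition per query interval
  //   // val rows = rdd.collect()                              // InternalRows laid out in `schema` field order
  // The partition count follows from getPartitions below; the variable names are illustrative only.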
59 | 60 | override def getPartitions: Array[Partition] = { 61 | val broker = DruidMetadataCache.getDruidClusterInfo(drInfo.fullName, 62 | drInfo.options).curatorConnection.getBroker 63 | druidQuery.intervals.zipWithIndex.map(t => new BrokerPartition(t._2, broker, t._1)).toArray 64 | } 65 | 66 | override def compute(split: Partition, context: TaskContext 67 | ): Iterator[InternalRow] = { 68 | val partition = split.asInstanceOf[DruidPartition] 69 | val qrySpec = partition.setIntervalsOnQuerySpec(druidQuery.qrySpec) 70 | //log.info("Druid querySpec: " + Utils.toPrettyJson(Left(qrySpec))) 71 | 72 | 73 | var cancelCallback: TaskCancelHandler.TaskCancelHolder = null 74 | var resultIter: CloseableIterator[ResultRow] = null 75 | var client: DruidQueryServerClient = null 76 | val queryId = qrySpec.context.map(_.queryId).getOrElse(s"query-${System.nanoTime()}") 77 | var queryStartTime = System.currentTimeMillis() 78 | var queryStartDT = (new DateTime()).toString 79 | try { 80 | cancelCallback = TaskCancelHandler.registerQueryId(queryId, context) 81 | client = partition.queryClient(useSmile, httpMaxConnPerRoute, httpMaxConnTotal) 82 | client.setCancellableHolder(cancelCallback) 83 | queryStartTime = System.currentTimeMillis() 84 | queryStartDT = (new DateTime()).toString() 85 | resultIter = qrySpec.executeQuery(client) 86 | } catch { 87 | case _ if cancelCallback.wasCancelTriggered && client != null => 88 | resultIter = new DummyResultIterator() 89 | case e: Throwable => throw e 90 | } finally { 91 | TaskCancelHandler.clearQueryId(queryId) 92 | } 93 | 94 | val druidExecTime = System.currentTimeMillis() - queryStartTime 95 | var numRows: Int = 0 96 | 97 | context.addTaskCompletionListener { taskContext => 98 | // TODO: add Druid query metrics. 99 | resultIter.closeIfNeeded() 100 | } 101 | 102 | val rIter = new InterruptibleIterator[ResultRow](context, resultIter) 103 | val nameToTF: Map[String, String] = druidQuery.getValTFMap() 104 | 105 | rIter.map { r => 106 | numRows += 1 107 | val row = new GenericInternalRow(schema.fields.map { field => 108 | DruidValTransform.sparkValue(field, r.event(field.name), 109 | nameToTF.get(field.name), drInfo.options.timeZoneId) 110 | }) 111 | row 112 | } 113 | } 114 | 115 | } 116 | 117 | /** 118 | * The "TaskCancel thread" tracks the Spark tasks that are executing Druid queries. 119 | * Periodically (current every 5 secs) it checks if any of the Spark tasks have been 120 | * cancelled and relays this to the current [[Cancellable]] associated with the 121 | * [[org.apache.http.client.methods.HttpExecutionAware]] connection handing the 122 | * Druid query. 
123 | */ 124 | object TaskCancelHandler extends MyLogging { 125 | 126 | private val taskMap = TrieMap[String, (Cancellable, TaskCancelHolder, TaskContext)]() 127 | 128 | class TaskCancelHolder(val queryId: String, val taskContext: TaskContext) extends CancellableHolder { 129 | 130 | override def setCancellable(c: Cancellable): Unit = { 131 | log.debug(s"set cancellable for query $queryId") 132 | taskMap(queryId) = (c, this, taskContext) 133 | } 134 | 135 | @volatile var wasCancelTriggered: Boolean = false 136 | } 137 | 138 | def registerQueryId(queryId: String, taskContext: TaskContext): TaskCancelHolder = { 139 | log.debug(s"register query $queryId") 140 | new TaskCancelHolder(queryId, taskContext) 141 | } 142 | 143 | def clearQueryId(queryId: String): Unit = taskMap.remove(queryId) 144 | 145 | val sec5: Long = 5 * 1000 146 | 147 | object cancelCheckThread extends Runnable with MyLogging { 148 | 149 | def run(): Unit = { 150 | while (true) { 151 | Thread.sleep(sec5) 152 | log.debug("cancelThread woke up") 153 | var canceledTasks: Seq[String] = Seq() // queryId list 154 | taskMap.foreach { 155 | case (queryId, (request, taskConcelHolder, taskContext)) => 156 | log.debug(s"checking task stageId = ${taskContext.stageId()}, " + 157 | s"partitionId = ${taskContext.partitionId()}, " + 158 | s"isInterrupted = ${taskContext.isInterrupted()}") 159 | if (taskContext.isInterrupted()) { 160 | try { 161 | taskConcelHolder.wasCancelTriggered = true 162 | request.cancel() 163 | log.info(s"aborted http request for query $queryId: $request") 164 | canceledTasks = canceledTasks :+ queryId 165 | } catch { 166 | case e: Throwable => log.warn(s"failed to abort http request: $request") 167 | } 168 | } 169 | } 170 | canceledTasks.foreach(clearQueryId) 171 | } 172 | } 173 | } 174 | 175 | val t = new Thread(cancelCheckThread) 176 | t.setName("DruidRDD-TaskCancelCheckThread") 177 | t.setDaemon(true) 178 | t.start() 179 | } 180 | 181 | object DruidValTransform { 182 | 183 | private[this] val toTSWithTZAdj = (druidVal: Any, tz: String) => { 184 | val dvLong = if (druidVal.isInstanceOf[Double]) { 185 | druidVal.asInstanceOf[Double].toLong 186 | } else if (druidVal.isInstanceOf[BigInt]) { 187 | druidVal.asInstanceOf[BigInt].toLong 188 | } else if (druidVal.isInstanceOf[String]) { 189 | druidVal.asInstanceOf[String].toLong 190 | } else if (druidVal.isInstanceOf[Integer]) { 191 | druidVal.asInstanceOf[Integer].toLong 192 | } else { 193 | druidVal 194 | } 195 | 196 | new DateTime(dvLong, DateTimeZone.forID(tz)).getMillis * 1000.asInstanceOf[SQLTimestamp] 197 | } 198 | 199 | private[this] val toTS = (druidVal: Any, tz: String) => { 200 | if (druidVal.isInstanceOf[Double]) { 201 | druidVal.asInstanceOf[Double].longValue() 202 | } else if (druidVal.isInstanceOf[BigInt]) { 203 | druidVal.asInstanceOf[BigInt].toLong 204 | } else if (druidVal.isInstanceOf[Integer]) { 205 | druidVal.asInstanceOf[Integer].toLong 206 | } else druidVal 207 | } 208 | 209 | private[this] val toString = (druidVal: Any, tz: String) => { 210 | UTF8String.fromString(druidVal.toString) 211 | } 212 | 213 | private[this] val toInt = (druidVal: Any, tz: String) => { 214 | if (druidVal.isInstanceOf[Double]) { 215 | druidVal.asInstanceOf[Double].toInt 216 | } else if (druidVal.isInstanceOf[BigInt]) { 217 | druidVal.asInstanceOf[BigInt].toInt 218 | } else if (druidVal.isInstanceOf[String]) { 219 | druidVal.asInstanceOf[String].toInt 220 | } else if (druidVal.isInstanceOf[Integer]) { 221 | druidVal.asInstanceOf[Integer].toInt 222 | } else druidVal 223 | } 224 | 
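These converters are dispatched by name at query time (see `tfMap` and `sparkValue` further down); the timestamp variants also bridge the unit mismatch between Druid, which returns the time column as epoch milliseconds, and Spark's internal `SQLTimestamp`, which is epoch microseconds. A self-contained sketch of that millis-to-micros adjustment, mirroring the branching in `toTSWithTZAdj` above (the value and zone are illustrative):

```
import org.joda.time.{DateTime, DateTimeZone}

// A Druid event value for __time may arrive in any of several boxed types.
val druidMillis: Any = BigInt(1514764800000L) // 2018-01-01T00:00:00Z

val asLong = druidMillis match {
  case d: Double  => d.toLong
  case b: BigInt  => b.toLong
  case s: String  => s.toLong
  case i: Integer => i.toLong
}

// Spark's TimestampType (SQLTimestamp) stores microseconds, hence * 1000.
val sparkMicros = new DateTime(asLong, DateTimeZone.forID("UTC")).getMillis * 1000L
assert(sparkMicros == 1514764800000000L)
```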
225 | private[this] val toLong = (druidVal: Any, tz: String) => { 226 | if (druidVal.isInstanceOf[Double]) { 227 | druidVal.asInstanceOf[Double].toLong 228 | } else if (druidVal.isInstanceOf[BigInt]) { 229 | druidVal.asInstanceOf[BigInt].toLong 230 | } else if (druidVal.isInstanceOf[String]) { 231 | druidVal.asInstanceOf[String].toLong 232 | } else if (druidVal.isInstanceOf[Integer]) { 233 | druidVal.asInstanceOf[Integer].toLong 234 | } else druidVal 235 | } 236 | 237 | private[this] val toFloat = (druidVal: Any, tz: String) => { 238 | if (druidVal.isInstanceOf[Double]) { 239 | druidVal.asInstanceOf[Double].toFloat 240 | } else if (druidVal.isInstanceOf[BigInt]) { 241 | druidVal.asInstanceOf[BigInt].toFloat 242 | } else if (druidVal.isInstanceOf[String]) { 243 | druidVal.asInstanceOf[String].toFloat 244 | } else if (druidVal.isInstanceOf[Integer]) { 245 | druidVal.asInstanceOf[Integer].toFloat 246 | } else druidVal 247 | } 248 | 249 | private[this] val tfMap: Map[String, (Any, String) => Any] = { 250 | Map[String, (Any, String) => Any]( 251 | "toTSWithTZAdj" -> toTSWithTZAdj, 252 | "toTS" -> toTS, 253 | "toString" -> toString, 254 | "toInt" -> toInt, 255 | "toLong" -> toLong, 256 | "toFloat" -> toFloat 257 | ) 258 | } 259 | 260 | def defaultValueConversion(f: StructField, druidVal: Any): Any = f.dataType match { 261 | case TimestampType if druidVal.isInstanceOf[Double] => 262 | druidVal.asInstanceOf[Double].longValue() 263 | case StringType if druidVal != null => UTF8String.fromString(druidVal.toString) 264 | case LongType if druidVal.isInstanceOf[BigInt] => 265 | druidVal.asInstanceOf[BigInt].longValue() 266 | case LongType if druidVal.isInstanceOf[Integer] => 267 | druidVal.asInstanceOf[Integer].longValue() 268 | case BinaryType if druidVal.isInstanceOf[String] => 269 | Base64Variants.getDefaultVariant.decode(druidVal.asInstanceOf[String]) 270 | case _ => druidVal 271 | 272 | } 273 | 274 | def sparkValue(f: StructField, druidVal: Any, tfName: Option[String], tz: String): Any = { 275 | tfName match { 276 | case Some(tf) if (tfMap.contains(tf) && druidVal != null) => tfMap(tf)(druidVal, tz) 277 | case _ => defaultValueConversion(f, druidVal) 278 | } 279 | } 280 | 281 | def getTFName(sparkDT: DataType, adjForTZ: Boolean = false): String = sparkDT match { 282 | case TimestampType if adjForTZ => "toTSWithTZAdj" 283 | case TimestampType => "toTS" 284 | case StringType if !adjForTZ => "toString" 285 | case ShortType if !adjForTZ => "toInt" 286 | case LongType => "toLong" 287 | case FloatType => "toFloat" 288 | case _ => "" 289 | } 290 | 291 | 292 | } 293 | 294 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/DruidRelation.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import org.apache.spark.rdd.RDD 4 | import org.apache.spark.sql.catalyst.InternalRow 5 | import org.apache.spark.sql.catalyst.expressions._ 6 | import org.apache.spark.sql.{Row, SQLContext} 7 | import org.apache.spark.sql.sources.{BaseRelation, TableScan} 8 | import org.apache.spark.sql.types._ 9 | import org.fasterxml.jackson.databind.ObjectMapper._ 10 | import org.joda.time.Interval 11 | import org.rzlabs.druid.metadata.{DruidRelationColumn, DruidRelationInfo} 12 | 13 | import scala.collection.mutable.ArrayBuffer 14 | 15 | case class DruidAttribute(exprId: ExprId, name: String, dataType: DataType, tf: String = null) 16 | 17 | /** 18 | * 19 | * @param qrySpec 20 | * @param useSmile 21 | 
* @param queryHistoricalServer currently this is always false. 22 | * @param numSegmentsPerQuery currently this is always -1. 23 | * @param intervals 24 | * @param outputAttrSpec Attributes to be output from the RawDataSourceScanExec. Each output attribute is 25 | * based on an Attribute in the original logical plan. The association is based on 26 | * the exprId of `NamedExpression`. 27 | */ 28 | case class DruidQuery(qrySpec: QuerySpec, 29 | useSmile: Boolean, 30 | queryHistoricalServer: Boolean, 31 | numSegmentsPerQuery: Int, 32 | intervals: List[Interval], 33 | outputAttrSpec: Option[List[DruidAttribute]] 34 | ) { 35 | 36 | def this(qrySpec: QuerySpec, useSmile: Boolean = true, 37 | queryHistoricalServer: Boolean = false, 38 | numSegmentPerQuery: Int = -1) = { 39 | this(qrySpec, useSmile, queryHistoricalServer, numSegmentPerQuery, 40 | qrySpec.intervalList.map(new Interval(_)), None) 41 | } 42 | 43 | private def schemaFromQuerySpec(drInfo: DruidRelationInfo): StructType = { 44 | qrySpec.schemaFromQuerySpec(drInfo) 45 | } 46 | 47 | private lazy val schemaFromOutputSpec: StructType = { 48 | StructType(outputAttrSpec.getOrElse(Nil).map { 49 | case DruidAttribute(_, name, dataType, _) => 50 | new StructField(name, dataType) 51 | }) 52 | } 53 | 54 | def schema(drInfo: DruidRelationInfo): StructType = { 55 | schemaFromOutputSpec.length match { 56 | case 0 => schemaFromQuerySpec(drInfo) 57 | case _ => schemaFromOutputSpec 58 | } 59 | } 60 | 61 | private def outputAttrsFromQuerySpec(drInfo: DruidRelationInfo): Seq[Attribute] = { 62 | schemaFromQuerySpec(drInfo).map { 63 | case StructField(name, dataType, _, _) => AttributeReference(name, dataType)() 64 | } 65 | } 66 | 67 | private lazy val outputAttrsFromOutputSpec: Seq[Attribute] = { 68 | outputAttrSpec.getOrElse(Nil).map { 69 | case DruidAttribute(exprId, name, dataType, _) => 70 | AttributeReference(name, dataType)(exprId) 71 | } 72 | } 73 | 74 | def outputAttrs(drInfo: DruidRelationInfo): Seq[Attribute] = { 75 | outputAttrsFromOutputSpec.size match { 76 | case 0 => outputAttrsFromQuerySpec(drInfo) 77 | case _ => outputAttrsFromOutputSpec 78 | } 79 | } 80 | 81 | def getValTFMap(): Map[String, String] = { 82 | outputAttrSpec.getOrElse(Nil).map { 83 | case DruidAttribute(_, name, _, tf) => 84 | name -> tf 85 | }.toMap 86 | } 87 | } 88 | 89 | case class DruidRelation(val info: DruidRelationInfo, val druidQuery: Option[DruidQuery])( 90 | @transient val sqlContext: SQLContext) 91 | extends BaseRelation with TableScan { 92 | 93 | override val needConversion: Boolean = false; 94 | 95 | override def schema: StructType = { 96 | var timeField: ArrayBuffer[StructField] = ArrayBuffer() 97 | val dimensionFields: ArrayBuffer[StructField] = ArrayBuffer() 98 | val metricFields: ArrayBuffer[StructField] = ArrayBuffer() 99 | info.druidColumns.map { 100 | case (columnName: String, relationColumn: DruidRelationColumn) => 101 | val (sparkType, isDimension) = relationColumn.druidColumn match { 102 | case Some(dc) => 103 | (DruidDataType.sparkDataType(dc.dataType), dc.isDimension()) 104 | case None => // Some hll metric's origin column may not be indexed 105 | if (relationColumn.hllMetric.isEmpty && 106 | relationColumn.sketchMetric.isEmpty) { 107 | throw new DruidDataSourceException(s"Illegal column $relationColumn") 108 | } 109 | (StringType, true) 110 | } 111 | if (columnName == info.timeDimensionCol) { 112 | // Here specifies time dimension's spark data type as StringType 113 | // instead TimestampType because the precision of timestamp type 114 | // is not 
enough (in Druid may the query granularity be millis) 115 | timeField += StructField(info.timeDimensionCol, StringType) 116 | } else if (isDimension) { 117 | dimensionFields += StructField(columnName, sparkType) 118 | } else { 119 | metricFields += StructField(columnName, sparkType) 120 | } 121 | } 122 | StructType(timeField ++ dimensionFields ++ metricFields) 123 | } 124 | 125 | def buildInternalScan: RDD[InternalRow] = { 126 | druidQuery.map(new DruidRDD(sqlContext, info, _)).getOrElse(null) 127 | } 128 | 129 | override def buildScan(): RDD[Row] = { 130 | buildInternalScan.asInstanceOf[RDD[Row]] 131 | } 132 | 133 | override def toString(): String = { 134 | druidQuery.map { dq => 135 | s"DruidQuery: ${Utils.toPrettyJson(scala.util.Left(dq))}" 136 | }.getOrElse { 137 | info.toString 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/QueryIntervals.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import org.joda.time.{DateTime, Interval} 4 | import org.rzlabs.druid.metadata.{DruidMetadataCache, DruidRelationInfo} 5 | 6 | case class QueryIntervals(druidRelationInfo: DruidRelationInfo, 7 | intervals: List[Interval] = Nil) { 8 | 9 | // The whole interval of Druid data source indexed data. 10 | var indexIntervals = druidRelationInfo.druidDataSource.intervals 11 | 12 | private def indexInterval(dt: DateTime): Option[Interval] = { 13 | indexIntervals.find(_.contains(dt)) 14 | } 15 | 16 | def get: List[Interval] = if (intervals.isEmpty) indexIntervals else intervals 17 | 18 | /** 19 | * - if this is the first queryInterval, add it. 20 | * - if the new Interval overlaps with the current QueryInterval set interval to the overlap. 21 | * - otherwise, the interval is an empty Interval. 22 | * @param in 23 | * @return 24 | */ 25 | private def add(in: Interval): QueryIntervals = { 26 | if (intervals.isEmpty) { 27 | // The first query interval. e.g., time < '2017-01-01T00:00:00' 28 | QueryIntervals(druidRelationInfo, List(in)) 29 | } else { 30 | // new interval overlaps old. 31 | val oldIn = intervals.head 32 | if (oldIn.overlaps(in)) { 33 | // two interval overlaps. 34 | // e.g., time > '2017-01-01T00:00:00' and time < '2018-01-01T00:00:00' 35 | QueryIntervals(druidRelationInfo, List(oldIn.overlap(in))) 36 | } else { 37 | // e.g., time > '2017-01-01T00:00:00' and time < '2016-01-01T00:00:00' 38 | val idxIn = indexIntervals.head 39 | QueryIntervals(druidRelationInfo, List(idxIn.withEnd(idxIn.getStart))) 40 | } 41 | } 42 | } 43 | 44 | private def outsideIndexRange(dt: DateTime, ict: IntervalConditionType.Value): Option[QueryIntervals] = { 45 | if (indexIntervals.size == 1) { // This seems always be. 
46 | ict match { 47 | case IntervalConditionType.LT | IntervalConditionType.LTE 48 | if indexIntervals(0).isBefore(dt) => 49 | Some(add(indexIntervals(0))) 50 | case IntervalConditionType.GT | IntervalConditionType.GTE 51 | if indexIntervals(0).isAfter(dt) => 52 | Some(add(indexIntervals(0))) 53 | case _ => None 54 | } 55 | } else None 56 | } 57 | 58 | private def updateIndexInterval: Boolean = { 59 | indexIntervals.synchronized { 60 | val newIn = DruidMetadataCache.brokerClient.timeBoundary(druidRelationInfo.druidDataSource.name) 61 | if (newIn != indexIntervals.head) { 62 | druidRelationInfo.druidDataSource.intervals = List(newIn) 63 | indexIntervals = druidRelationInfo.druidDataSource.intervals 64 | true 65 | } else false 66 | } 67 | } 68 | 69 | def ltCond(dt: DateTime): Option[QueryIntervals] = { 70 | indexInterval(dt).map { in => 71 | val newIn = in.withEnd(dt) 72 | add(newIn) 73 | } orElse { 74 | if (updateIndexInterval) { 75 | ltCond(dt) 76 | } else { 77 | outsideIndexRange(dt, IntervalConditionType.LT) 78 | } 79 | } 80 | } 81 | 82 | def lteCond(dt: DateTime): Option[QueryIntervals] = { 83 | indexInterval(dt).map { in => 84 | val newIn = in.withEnd(dt.plusMillis(1)) // This because interval's end is excluded. 85 | add(newIn) 86 | } orElse { 87 | if(updateIndexInterval) { 88 | lteCond(dt) 89 | } else { 90 | outsideIndexRange(dt, IntervalConditionType.LTE) 91 | } 92 | } 93 | } 94 | 95 | def gtCond(dt: DateTime): Option[QueryIntervals] = { 96 | indexInterval(dt).map { in => 97 | val newIn = in.withStart(dt.plusMillis(1)) // This because interval's end is included. 98 | add(newIn) 99 | } orElse { 100 | if (updateIndexInterval) { 101 | gtCond(dt) 102 | } else { 103 | outsideIndexRange(dt, IntervalConditionType.GT) 104 | } 105 | } 106 | } 107 | 108 | def gteCond(dt: DateTime): Option[QueryIntervals] = { 109 | indexInterval(dt).map { in => 110 | val newIn = in.withStart(dt) 111 | add(newIn) 112 | } orElse { 113 | if (updateIndexInterval) { 114 | gteCond(dt) 115 | } else { 116 | outsideIndexRange(dt, IntervalConditionType.GTE) 117 | } 118 | } 119 | } 120 | 121 | } 122 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/Utils.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid 2 | 3 | import com.fasterxml.jackson.databind.JsonNode 4 | import org.apache.spark.sql.MyLogging 5 | import org.joda.time.Interval 6 | import org.fasterxml.jackson.databind.ObjectMapper._ 7 | 8 | object Utils extends MyLogging { 9 | 10 | // implicit val jsonFormat = Serialization.formats( 11 | // ShortTypeHints( 12 | // List( 13 | // classOf[DruidRelationColumnInfo], 14 | // classOf[DurationGranularity], 15 | // classOf[PeriodGranularity] 16 | // ) 17 | // ) 18 | // ) + new DruidQueryGranularitySerializer 19 | 20 | def intervalsMillis(intervals: List[Interval]): Long = { 21 | intervals.foldLeft[Long](0L) { 22 | case (t, in) => t + (in.getEndMillis - in.getStartMillis) 23 | } 24 | } 25 | 26 | def updateInterval(interval: Interval, `with`: Interval, withType: String) = { 27 | interval 28 | .withStartMillis(Math.min(interval.getStartMillis, `with`.getStartMillis)) 29 | .withEndMillis(Math.max(interval.getEndMillis, `with`.getEndMillis)) 30 | } 31 | 32 | def filterSomes[A](a: List[Option[A]]): List[Option[A]] = { 33 | a.filter { case Some(x) => true; case None => false } 34 | } 35 | 36 | /** 37 | * transform List[Option] tp Option[List] 38 | * @param a 39 | * @tparam A 40 | * @return 41 | */ 42 | 
def sequence[A](a: List[Option[A]]): Option[List[A]] = a match { 43 | case Nil => Some(Nil) 44 | case head :: tail => head.flatMap (h => sequence(tail).map(h :: _)) 45 | } 46 | 47 | def toPrettyJson(obj: Either[AnyRef, JsonNode]) = { 48 | jsonMapper.writerWithDefaultPrettyPrinter().writeValueAsString( 49 | if (obj.isLeft) obj.left else obj.right) 50 | } 51 | 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/client/CuratorConnection.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.client 2 | 3 | import java.io.IOException 4 | import java.util.concurrent.ExecutorService 5 | 6 | import org.apache.curator.framework.api.CompressionProvider 7 | import org.apache.curator.framework.imps.GzipCompressionProvider 8 | import org.apache.curator.framework.recipes.cache.PathChildrenCache.StartMode 9 | import org.apache.curator.framework.recipes.cache.{ChildData, PathChildrenCache, PathChildrenCacheEvent, PathChildrenCacheListener} 10 | import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} 11 | import org.apache.curator.retry.BoundedExponentialBackoffRetry 12 | import org.apache.curator.utils.ZKPaths 13 | import org.apache.spark.sql.MyLogging 14 | import org.json4s._ 15 | import org.json4s.jackson.JsonMethods._ 16 | import org.rzlabs.druid.metadata.{DruidClusterInfo, DruidNode, DruidOptions} 17 | import org.rzlabs.druid.{DruidDataSourceException, Utils} 18 | import org.fasterxml.jackson.databind.ObjectMapper._ 19 | 20 | import scala.collection.mutable.{Map => MMap} 21 | 22 | class CuratorConnection(val zkHost: String, 23 | val options: DruidOptions, 24 | val cache: MMap[String, DruidClusterInfo], 25 | execSvc: ExecutorService, 26 | updateTimeBoundary: String => Unit 27 | ) extends MyLogging { 28 | 29 | // Cache the active historical servers. 
30 | val serverQueueCacheMap: MMap[String, PathChildrenCache] = MMap() 31 | private val serverQueueCacheLock = new Object 32 | // serverName -> serverList 33 | val discoveryServers: MMap[String, Seq[String]] = MMap() 34 | private val discoveryCacheLock = new Object 35 | 36 | val announcementsPath = ZKPaths.makePath(options.zkDruidPath, "announcements") 37 | val serverSegmentsPath = ZKPaths.makePath(options.zkDruidPath, "segments") 38 | val discoveryPath = ZKPaths.makePath(options.zkDruidPath, "discovery") 39 | val loadQueuePath = ZKPaths.makePath(options.zkDruidPath, "loadQueue") 40 | val brokersPath = ZKPaths.makePath(discoveryPath, getServiceName("broker")) 41 | val coordinatorsPath = ZKPaths.makePath(discoveryPath, getServiceName("coordinator")) 42 | 43 | val framework: CuratorFramework = CuratorFrameworkFactory.builder 44 | .connectString(zkHost) 45 | .sessionTimeoutMs(options.zkSessionTimeoutMs) 46 | .retryPolicy(new BoundedExponentialBackoffRetry(1000, 45000, 30)) 47 | .compressionProvider(new PotentiallyGzippedCompressionProvider(options.zkEnableCompression)) 48 | .build() 49 | 50 | val announcementsCache: PathChildrenCache = new PathChildrenCache( 51 | framework, 52 | announcementsPath, 53 | true, 54 | true, 55 | execSvc 56 | ) 57 | 58 | val brokersCache: PathChildrenCache = new PathChildrenCache( 59 | framework, 60 | brokersPath, 61 | true, 62 | true, 63 | execSvc 64 | ) 65 | 66 | val coordinatorsCache: PathChildrenCache = new PathChildrenCache( 67 | framework, 68 | coordinatorsPath, 69 | true, 70 | true, 71 | execSvc 72 | ) 73 | 74 | /** 75 | * A [[PathChildrenCacheListener]] which is used to monitor 76 | * coordinators/brokers/overlords' in and out. 77 | * In the case of multiple brokers all them will be active if necessary. 78 | * In the case of multiple coordinators just one is active while others are standby. 
79 | * 80 | */ 81 | val discoveryListener = new PathChildrenCacheListener { 82 | override def childEvent(client: CuratorFramework, event: PathChildrenCacheEvent): Unit = { 83 | event.getType match { 84 | case eventType @ (PathChildrenCacheEvent.Type.CHILD_ADDED | 85 | PathChildrenCacheEvent.Type.CHILD_REMOVED) => 86 | discoveryCacheLock.synchronized { 87 | //val data = getZkDataForNode(event.getData.getPath) 88 | val data = event.getData.getData 89 | val druidNode = jsonMapper.readValue(new String(data), classOf[DruidNode]) 90 | val host = s"${druidNode.address}:${druidNode.port}" 91 | val serviceName = druidNode.name 92 | val serverSeq = getServerSeq(serviceName) 93 | if (eventType == PathChildrenCacheEvent.Type.CHILD_ADDED) { 94 | logInfo(s"Server[$serviceName][$host] is added in the path ${event.getData.getPath}") 95 | if (serverSeq.contains(host)) { 96 | logWarning(s"New server[$serviceName][$host] but there was already one, ignoring new one.", host) 97 | } else { 98 | discoveryServers(serviceName) = serverSeq :+ host 99 | logDebug(s"New server[$host] is added to cache.") 100 | } 101 | } else { 102 | logInfo(s"Server[$serviceName][$host] is removed from the path ${event.getData.getPath}") 103 | if (serverSeq.contains(host)) { 104 | discoveryServers(serviceName) = serverSeq.filterNot(_ == host) 105 | logDebug(s"Server[$host] is offline, so remove it from the cache.") 106 | } else { 107 | logError(s"Server[$host] is not in the cache, how to remove it from cache?") 108 | } 109 | } 110 | } 111 | case _ => () 112 | } 113 | } 114 | } 115 | 116 | /** 117 | * A [[PathChildrenCacheListener]] which is used to monitor segments' in and out and 118 | * update time boundary for datasources. 119 | * The occurrence of a CHILD_ADDED event means there's a new segment of some datasource 120 | * is added to the `loadQueue`, we should do nothing because the segment may be queued 121 | * and it is not announced at that time. 122 | * The occurrence of a CHILD_REMOVED event means there's a segment is removed from `loadQueue` 123 | * and it will be announced then, we should update the datasource's time boundary. 124 | */ 125 | val segmentLoadQueueListener = new PathChildrenCacheListener { 126 | override def childEvent(client: CuratorFramework, event: PathChildrenCacheEvent): Unit = { 127 | event.getType match { 128 | case eventType @ PathChildrenCacheEvent.Type.CHILD_REMOVED => 129 | logDebug(s"event $eventType occurred.") 130 | // This will throw a KeeperException because the removed path is not existence. 131 | //val nodeData = getZkDataForNode(event.getData.getPath) 132 | val nodeData: Array[Byte] = event.getData.getData 133 | if (nodeData == null) { 134 | logWarning(s"Ignoring event: Type - ${event.getType}, " + 135 | s"Path - ${event.getData.getPath}, Version - ${event.getData.getStat.getVersion}") 136 | } else { 137 | updateTimeBoundary(new String(nodeData)) 138 | } 139 | case _ => () 140 | } 141 | } 142 | } 143 | 144 | /** 145 | * A [[PathChildrenCacheListener]] which is used to monitor historical servers' in and out 146 | * and manage the relationships of historical servers and their [[PathChildrenCache]]s. 147 | */ 148 | val announcementsListener = new PathChildrenCacheListener { 149 | override def childEvent(client: CuratorFramework, event: PathChildrenCacheEvent): Unit = { 150 | event.getType match { 151 | case PathChildrenCacheEvent.Type.CHILD_ADDED => 152 | // New historical server is added to the Druid cluster. 
153 | serverQueueCacheLock.synchronized { 154 | // Get the historical server addr from path child data. 155 | val key = getServerKey(event) 156 | logInfo(s"Historical server[$key] is added to the path ${event.getData.getPath}") 157 | if (serverQueueCacheMap.contains(key)) { 158 | logWarning(s"New historical[$key] but there was already one, ignoring new one.") 159 | } else if (key != null) { 160 | val queuePath = ZKPaths.makePath(loadQueuePath, key) 161 | val queueCache = new PathChildrenCache( 162 | framework, 163 | queuePath, 164 | true, 165 | true, 166 | execSvc 167 | ) 168 | queueCache.getListenable.addListener(segmentLoadQueueListener) 169 | serverQueueCacheMap(key) = queueCache 170 | logDebug(s"Starting inventory cache for $key, inventoryPath $queuePath", Array(key, queuePath)) 171 | // Start cache and trigger the CHILD_ADDED event. 172 | //segmentsCache.start(StartMode.POST_INITIALIZED_EVENT) 173 | // Start cache and do not trigger the CHILD_ADDED by default. 174 | queueCache.start(StartMode.BUILD_INITIAL_CACHE) 175 | } 176 | } 177 | case PathChildrenCacheEvent.Type.CHILD_REMOVED => 178 | // A historical server is offline. 179 | serverQueueCacheLock.synchronized { 180 | val key = getServerKey(event) 181 | logInfo(s"Historical server[$key] is removed from the path ${event.getData.getPath}") 182 | val segmentsCache: Option[PathChildrenCache] = serverQueueCacheMap.remove(key) 183 | if (segmentsCache.isDefined) { 184 | logInfo(s"Closing inventory for $key. Also removing listeners.") 185 | segmentsCache.get.getListenable.clear() 186 | segmentsCache.get.close() 187 | } else logWarning(s"Cache[$key] removed that wasn't cache!?") 188 | } 189 | case _ => () 190 | } 191 | } 192 | } 193 | 194 | announcementsCache.getListenable.addListener(announcementsListener) 195 | brokersCache.getListenable.addListener(discoveryListener) 196 | coordinatorsCache.getListenable.addListener(discoveryListener) 197 | 198 | framework.start() 199 | announcementsCache.start(StartMode.POST_INITIALIZED_EVENT) 200 | brokersCache.start(StartMode.POST_INITIALIZED_EVENT) 201 | coordinatorsCache.start(StartMode.POST_INITIALIZED_EVENT) 202 | 203 | // def getService(name: String): String = { 204 | // getServices(name).head 205 | // } 206 | 207 | def getService(name: String): String = { 208 | val serviceName = getServiceName(name) 209 | discoveryCacheLock.synchronized { 210 | var serverSeq = getServerSeq(serviceName) 211 | if (serverSeq.isEmpty) { 212 | serverSeq = getServices(name) 213 | if (serverSeq.isEmpty) { 214 | return null 215 | } else { 216 | discoveryServers(serviceName) = serverSeq 217 | } 218 | } 219 | val server = serverSeq.head 220 | discoveryServers(serviceName) = serverSeq.tail :+ server 221 | server 222 | } 223 | } 224 | 225 | def getServices(name: String): Seq[String] = { 226 | 227 | val serviceName = getServiceName(name) 228 | val servicePath = ZKPaths.makePath(discoveryPath, serviceName) 229 | val childrenNodes: java.util.List[String] = framework.getChildren.forPath(servicePath) 230 | var services: Seq[String] = Nil 231 | try { 232 | val iter = childrenNodes.iterator() 233 | while (iter.hasNext) { 234 | val childNode = iter.next() 235 | val childPath = ZKPaths.makePath(servicePath, childNode) 236 | val data: Array[Byte] = getZkDataForNode(childPath) 237 | if (data != null) { 238 | val druidNode = jsonMapper.readValue(new String(data), classOf[DruidNode]) 239 | services = services :+ s"${druidNode.address}:${druidNode.port}" 240 | } 241 | } 242 | } catch { 243 | case e: Exception => 244 | throw new 
DruidDataSourceException(s"Failed to get '$name' for zkHost '$zkHost'", e) 245 | } 246 | if (services.isEmpty) { 247 | throw new DruidDataSourceException(s"There's no '$name' for zkHost '$zkHost' in path '$servicePath'") 248 | } 249 | services 250 | } 251 | 252 | def getBroker: String = { 253 | getService("broker") 254 | } 255 | 256 | def getCoordinator: String = { 257 | getService("coordinator") 258 | } 259 | 260 | private def getServerSeq(name: String): Seq[String] = { 261 | discoveryServers.get(name).getOrElse { 262 | val l: List[String] = List() 263 | discoveryServers(name) = l 264 | l 265 | } 266 | } 267 | 268 | private def getServiceName(name: String): String = { 269 | if (options.zkQualifyDiscoveryNames) { 270 | s"${options.zkDruidPath}:$name".tail 271 | } else name 272 | } 273 | 274 | private def getServerKey(event: PathChildrenCacheEvent): String = { 275 | val child: ChildData = event.getData 276 | //val data: Array[Byte] = getZkDataForNode(child.getPath) 277 | val data: Array[Byte] = child.getData 278 | if (data == null) { 279 | logWarning(s"Ignoring event: Type - ${event.getType}, " + 280 | s"Path - ${child.getPath}, Version - ${child.getStat.getVersion}") 281 | null 282 | } else { 283 | ZKPaths.getNodeFromPath(child.getPath) 284 | } 285 | } 286 | 287 | private def getZkDataForNode(path: String): Array[Byte] = { 288 | try { 289 | framework.getData.decompressed().forPath(path) 290 | } catch { 291 | case e: Exception => { 292 | logError(s"Exception occurs while getting data fro node $path", e) 293 | null 294 | } 295 | } 296 | } 297 | } 298 | 299 | /* 300 | * copied from druid code base. 301 | */ 302 | class PotentiallyGzippedCompressionProvider(val compressOutput: Boolean) 303 | extends CompressionProvider { 304 | 305 | private val base: GzipCompressionProvider = new GzipCompressionProvider 306 | 307 | 308 | @throws[Exception] 309 | def compress(path: String, data: Array[Byte]): Array[Byte] = { 310 | return if (compressOutput) base.compress(path, data) 311 | else data 312 | } 313 | 314 | @throws[Exception] 315 | def decompress(path: String, data: Array[Byte]): Array[Byte] = { 316 | try { 317 | return base.decompress(path, data) 318 | } 319 | catch { 320 | case e: IOException => { 321 | return data 322 | } 323 | } 324 | } 325 | } -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/client/DruidClient.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.client 2 | 3 | import com.fasterxml.jackson.core.`type`.TypeReference 4 | import com.fasterxml.jackson.databind.node.ObjectNode 5 | import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes 6 | import org.apache.commons.io.IOUtils 7 | import org.apache.commons.lang.exception.ExceptionUtils 8 | import org.apache.http.{HttpEntity, HttpHeaders} 9 | import org.apache.http.client.methods._ 10 | import org.apache.spark.sql.MyLogging 11 | import org.apache.http.concurrent._ 12 | import org.apache.http.entity.{ByteArrayEntity, ContentType, StringEntity} 13 | import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} 14 | import org.apache.http.impl.conn.PoolingHttpClientConnectionManager 15 | import org.apache.http.util.EntityUtils 16 | import org.apache.spark.sql.sources.druid.CloseableIterator 17 | import org.fasterxml.jackson.databind.ObjectMapper._ 18 | import org.joda.time.{DateTime, Interval} 19 | import org.rzlabs.druid.metadata.DruidOptions 20 | import org.rzlabs.druid._ 21 | 22 | import 
scala.util.Try 23 | 24 | object ConnectionManager { 25 | 26 | @volatile private var initialized: Boolean = false 27 | 28 | lazy val pool = { 29 | val p = new PoolingHttpClientConnectionManager() 30 | p.setMaxTotal(40) 31 | p.setDefaultMaxPerRoute(8) 32 | p 33 | } 34 | 35 | def init(druidOptions: DruidOptions): Unit = { 36 | if (!initialized) { 37 | init(druidOptions.poolMaxConnectionsPerRoute, 38 | druidOptions.poolMaxConnections) 39 | initialized = true 40 | } 41 | } 42 | 43 | def init(maxPerRoute: Int, maxTotal: Int): Unit = { 44 | if (!initialized) { 45 | pool.setMaxTotal(maxTotal) 46 | pool.setDefaultMaxPerRoute(maxPerRoute) 47 | initialized = true 48 | } 49 | } 50 | } 51 | 52 | /** 53 | * A mechanism to relay [[org.apache.http.concurrent.Cancellable]] resources 54 | * associated with the "http connection" of a "DruidClient". This is used by 55 | * the [[org.rzlabs.druid.TaskCancelHandler]] to capture the association 56 | * between "Spark Tasks" and "Cancellable" resources (connections). 57 | */ 58 | trait CancellableHolder { 59 | def setCancellable(c: Cancellable) 60 | } 61 | 62 | /** 63 | * A mixin trait that relays [[Cancellable]] resources to 64 | * a [[CancellableHolder]]. 65 | */ 66 | trait DruidClientHttpExecutionAware extends HttpExecutionAware { 67 | 68 | val ch: CancellableHolder 69 | 70 | abstract override def isAborted = super.isAborted 71 | 72 | abstract override def setCancellable(cancellable: Cancellable): Unit = { 73 | if (ch != null) { 74 | ch.setCancellable(cancellable) 75 | } 76 | super.setCancellable(cancellable) 77 | } 78 | } 79 | 80 | /** 81 | * Configure [[HttpPost]] to have the [[DruidClientHttpExecutionAware]] trait, 82 | * so that [[Cancellable]] resources are relayed to the registered [[CancellableHolder]]. 83 | * @param url The url the request posted to. 84 | * @param ch The registered CancellableHolder. 85 | */ 86 | class DruidHttpPost(url: String, val ch: CancellableHolder) 87 | extends HttpPost(url) with DruidClientHttpExecutionAware 88 | 89 | /** 90 | * Configure [[HttpGet]] to have the [[DruidClientHttpExecutionAware]] trait, 91 | * so that [[Cancellable]] resources are relayed to the registered [[CancellableHolder]]. 92 | * @param url The url the request take data from. 93 | * @param ch The registered CancellableHolder. 94 | */ 95 | class DruidHttpGet(url: String, val ch: CancellableHolder) 96 | extends HttpGet(url) with DruidClientHttpExecutionAware 97 | 98 | /** 99 | * `DruidClient` is not thread-safe because `cancellableHolder` state is used to relay 100 | * cancellable resources information. 101 | * @param host Server host. 102 | * @param port Server port. 103 | * @param useSmile Use smile binary JSON format or not. 104 | */ 105 | abstract class DruidClient(val host: String, 106 | val port: Int, 107 | val useSmile: Boolean = false) extends MyLogging { 108 | 109 | private var cancellableHolder: CancellableHolder = null 110 | 111 | def this(t: (String, Int)) = { 112 | this(t._1, t._2) 113 | } 114 | 115 | def this(s: String) = { 116 | this(DruidClient.hostPort(s)) 117 | } 118 | 119 | def setCancellableHolder(c: CancellableHolder): Unit = { 120 | cancellableHolder = c 121 | } 122 | 123 | /** 124 | * A [[CloseableHttpClient]] is a [[org.apache.http.client.HttpClient]] 125 | * with a `close` method in [[java.io.Closeable]]. 
126 | * @return 127 | */ 128 | protected def httpClient: CloseableHttpClient = { 129 | val sTime = System.currentTimeMillis() 130 | val r = HttpClients.custom().setConnectionManager(ConnectionManager.pool).build() 131 | val eTime = System.currentTimeMillis() 132 | logDebug(s"Time to get httpClient: ${eTime - sTime}") 133 | logDebug("Pool Stats: {}", ConnectionManager.pool.getTotalStats) 134 | r 135 | } 136 | 137 | /** 138 | * Close the [[java.io.InputStream]] represented by the 139 | * `resp.getEntity.getContent()` to return a 140 | * [[org.apache.http.client.HttpClient]] to the 141 | * connection pool. 142 | * @param resp 143 | */ 144 | protected def release(resp: CloseableHttpResponse): Unit = { 145 | Try { 146 | if (resp != null) EntityUtils.consume(resp.getEntity) 147 | } recover { 148 | case e => logError("Error returning client to pool", 149 | ExceptionUtils.getStackTrace(e)) 150 | } 151 | } 152 | 153 | protected def getRequest(url: String) = new DruidHttpGet(url, cancellableHolder) 154 | protected def postRequest(url: String) = new DruidHttpPost(url, cancellableHolder) 155 | 156 | protected def addHeaders(req: HttpRequestBase, reqHeaders: Map[String, String]): Unit = { 157 | if (useSmile) { 158 | req.addHeader(HttpHeaders.CONTENT_TYPE, SmileMediaTypes.APPLICATION_JACKSON_SMILE) 159 | } 160 | if (reqHeaders != null) { 161 | reqHeaders.foreach(header => req.setHeader(header._1, header._2)) 162 | } 163 | } 164 | 165 | @throws[DruidDataSourceException] 166 | protected def perform(url: String, 167 | reqType: String => HttpRequestBase, 168 | payload: ObjectNode, 169 | reqHeaders: Map[String, String]): String = { 170 | var resp: CloseableHttpResponse = null 171 | 172 | val tis: Try[String] = for { 173 | r <- Try { 174 | val req: CloseableHttpClient = httpClient 175 | val request = reqType(url) 176 | // Just HttpPost extends HttpEntityEnclosingRequestBase. 177 | // HttpGet extends HttpRequestBase. 
178 | if (payload != null && request.isInstanceOf[HttpEntityEnclosingRequestBase]) { 179 | val input: HttpEntity = if (!useSmile) { 180 | new StringEntity(jsonMapper.writeValueAsString(payload), ContentType.APPLICATION_JSON) 181 | } else { 182 | new ByteArrayEntity(smileMapper.writeValueAsBytes(payload), null) 183 | } 184 | request.asInstanceOf[HttpEntityEnclosingRequestBase].setEntity(input) 185 | } 186 | addHeaders(request, reqHeaders) 187 | resp = req.execute(request) 188 | resp 189 | } 190 | is <- Try { 191 | val status = r.getStatusLine.getStatusCode 192 | if (status >= 200 && status < 300) { 193 | if (r.getEntity != null) { 194 | IOUtils.toString(r.getEntity.getContent) 195 | } else { 196 | throw new DruidDataSourceException(s"Unexpected response status: ${r.getStatusLine}") 197 | } 198 | } else { 199 | throw new DruidDataSourceException(s"Unexpected response status: ${r.getStatusLine}") 200 | } 201 | } 202 | } yield is 203 | 204 | release(resp) 205 | tis.getOrElse(tis.failed.get match { 206 | case de: DruidDataSourceException => throw de 207 | case e => throw new DruidDataSourceException("Failed in communication with Druid", e) 208 | }) 209 | } 210 | 211 | @throws[DruidDataSourceException] 212 | protected def performQuery(url: String, 213 | reqType: String => HttpRequestBase, 214 | qrySpec: QuerySpec, 215 | payload: ObjectNode, 216 | reqHeaders: Map[String, String]): CloseableIterator[ResultRow] = { 217 | 218 | var resp: CloseableHttpResponse = null 219 | 220 | val enterTime = System.currentTimeMillis() 221 | var beforeExecTime = System.currentTimeMillis() 222 | var afterExecTime = System.currentTimeMillis() 223 | 224 | val iter: Try[CloseableIterator[ResultRow]] = for { 225 | r <- Try { 226 | val req: CloseableHttpClient = httpClient 227 | val request: HttpRequestBase = reqType(url) 228 | if (payload != null && request.isInstanceOf[HttpEntityEnclosingRequestBase]) { 229 | // HttpPost 230 | val input: HttpEntity = if (!useSmile) { 231 | new StringEntity(jsonMapper.writeValueAsString(payload), ContentType.APPLICATION_JSON) 232 | } else { 233 | new ByteArrayEntity(smileMapper.writeValueAsBytes(payload), null) 234 | } 235 | request.asInstanceOf[HttpEntityEnclosingRequestBase].setEntity(input) 236 | } 237 | addHeaders(request, reqHeaders) 238 | beforeExecTime = System.currentTimeMillis() 239 | resp = req.execute(request) 240 | afterExecTime = System.currentTimeMillis() 241 | resp 242 | } 243 | iter <- Try { 244 | val status = r.getStatusLine.getStatusCode 245 | if (status >= 200 && status < 300) { 246 | qrySpec(useSmile, r.getEntity.getContent, this, release(r)) 247 | } else { 248 | throw new DruidDataSourceException(s"Unexpected response status: ${r.getStatusLine} " + 249 | s"on $url for query: " + 250 | s"\n ${Utils.toPrettyJson(Right(payload))}") 251 | } 252 | } 253 | } yield iter 254 | 255 | val afterIterBuildTime = System.currentTimeMillis() 256 | log.debug(s"request $url: beforeExecTime = ${beforeExecTime - enterTime}, " + 257 | s"execTime = ${afterExecTime - beforeExecTime}, " + 258 | s"iterBuildTime = ${afterIterBuildTime - afterExecTime}") 259 | iter.getOrElse { 260 | release(resp) 261 | iter.failed.get match { 262 | case de: DruidDataSourceException => throw de 263 | case e => throw new DruidDataSourceException("Failed in communication with Druid: ", e) 264 | } 265 | } 266 | } 267 | 268 | protected def post(url: String, 269 | payload: ObjectNode, 270 | reqHeaders: Map[String, String] = null): String = { 271 | perform(url, postRequest _, payload, reqHeaders) 272 | } 273 | 274 | 
def postQuery(url: String, qrySpec: QuerySpec, 275 | payload: ObjectNode, 276 | reqHeaders: Map[String, String] = null): CloseableIterator[ResultRow] = { 277 | performQuery(url, postRequest _, qrySpec, payload, reqHeaders) 278 | } 279 | 280 | protected def get(url: String, 281 | payload: ObjectNode = null, 282 | reqHeaders: Map[String, String] = null): String = { 283 | perform(url, getRequest _, payload, reqHeaders) 284 | } 285 | 286 | @throws[DruidDataSourceException] 287 | def executeQuery(url: String, qrySpec: QuerySpec): List[ResultRow] = { 288 | // Payload to be posted is the QuerySpec. 289 | val payload: ObjectNode = jsonMapper.valueToTree(qrySpec) 290 | val r = post(url, payload) 291 | jsonMapper.readValue(r, new TypeReference[List[ResultRow]] {}) 292 | } 293 | 294 | @throws[DruidDataSourceException] 295 | def executeQueryAsStream(url: String, qrySpec: QuerySpec): CloseableIterator[ResultRow] = { 296 | val payload: ObjectNode = jsonMapper.valueToTree(qrySpec) 297 | postQuery(url, qrySpec, payload) 298 | } 299 | 300 | def timeBoundary(dataSource: String): Interval 301 | 302 | @throws[DruidDataSourceException] 303 | def metadata(url: String, 304 | dataSource: String, 305 | fullIndex: Boolean, 306 | druidVersion: String): DruidDataSource = { 307 | 308 | val in: Interval = timeBoundary(dataSource) 309 | // TODO: we do not fetch intervals of all segments for performence considerations. 310 | val ins: String = 311 | if (fullIndex) in.toString else in.withEnd(in.getStart.plusMillis(1)).toString 312 | 313 | val payload: ObjectNode = if (!DruidDataSourceCapability.supportsQueryGranularityMetadata(druidVersion)) { 314 | jsonMapper.createObjectNode() 315 | .put("queryType", "segmentMetadata") 316 | .put("dataSource", dataSource) 317 | .set("intervals", jsonMapper.createArrayNode() 318 | .add(ins)).asInstanceOf[ObjectNode] 319 | .set("analysisTypes", jsonMapper.createArrayNode() 320 | .add("cardinality") 321 | .add("interval") 322 | .add("aggregators")).asInstanceOf[ObjectNode] 323 | .put("merge", "true") 324 | } else if (!DruidDataSourceCapability.supportsTimestampSpecMetadata(druidVersion)) { 325 | jsonMapper.createObjectNode() 326 | .put("queryType", "segmentMetadata") 327 | .put("dataSource", dataSource) 328 | .set("intervals", jsonMapper.createArrayNode() 329 | .add(ins)).asInstanceOf[ObjectNode] 330 | .set("analysisTypes", jsonMapper.createArrayNode() 331 | .add("cardinality") 332 | .add("interval") 333 | .add("aggregators") 334 | .add("queryGranularity")).asInstanceOf[ObjectNode] 335 | .put("merge", "true") 336 | } else {jsonMapper.createObjectNode() 337 | .put("queryType", "segmentMetadata") 338 | .put("dataSource", dataSource) 339 | .set("intervals", jsonMapper.createArrayNode() 340 | .add(ins)).asInstanceOf[ObjectNode] 341 | .set("analysisTypes", jsonMapper.createArrayNode() 342 | .add("cardinality") 343 | .add("interval") 344 | .add("aggregators") 345 | .add("queryGranularity") 346 | .add("timestampSpec")).asInstanceOf[ObjectNode] 347 | .put("merge", "true") 348 | } 349 | 350 | val resp: String = post(url, payload) 351 | logWarning(s"The json response of 'segmentMetadata' query: \n$resp") 352 | 353 | // substitute `queryGranularity` field value if needed. 354 | // TODO: The truth is that multiple paths may exist because different columns 355 | // set will be occurs for different intervals. 
356 | val resp1 = jsonMapper.writeValueAsString(DruidQueryGranularity.substitute( 357 | jsonMapper.readTree(resp).path(0))) 358 | logWarning(s"After substitution, the json: \n$resp1") 359 | 360 | val mr: MetadataResponse = 361 | jsonMapper.readValue(resp1, new TypeReference[MetadataResponse] {}) 362 | DruidDataSource(dataSource, mr, List(in)) 363 | } 364 | 365 | def serverStatus: ServerStatus = { 366 | val url = s"http://$host:$port/status" 367 | val is: String = get(url) 368 | jsonMapper.readValue(is, new TypeReference[ServerStatus] {}) 369 | } 370 | } 371 | 372 | 373 | object DruidClient { 374 | 375 | val HOST = """([^:]*):(\d*)""".r 376 | 377 | def hostPort(s: String) : (String, Int) = { 378 | val HOST(h, p) = s 379 | (h, p.toInt) 380 | } 381 | } 382 | 383 | class DruidQueryServerClient(host: String, port: Int, useSmile: Boolean = false) 384 | extends DruidClient(host, port, useSmile) { 385 | 386 | @transient val url = s"http://$host:$port/druid/v2/?pretty" 387 | 388 | def this(t: (String, Int), useSmile: Boolean) = { 389 | this(t._1, t._2, useSmile) 390 | } 391 | 392 | def this(s: String, useSmile: Boolean) = { 393 | this(DruidClient.hostPort(s), useSmile) 394 | } 395 | 396 | @throws[DruidDataSourceException] 397 | override def timeBoundary(dataSource: String): Interval = { 398 | val payload: ObjectNode = jsonMapper.createObjectNode() 399 | .put("queryType", "timeBoundary") 400 | .put("dataSource", dataSource) 401 | val resp: String = post(url, payload) 402 | val objectNode = jsonMapper.readTree(resp) 403 | val maxTime: java.util.List[String] = objectNode.findValuesAsText("maxTime") 404 | val minTime: java.util.List[String] = objectNode.findValuesAsText("minTime") 405 | if (!maxTime.isEmpty && !minTime.isEmpty) { 406 | new Interval( 407 | DateTime.parse(minTime.get(0)), 408 | DateTime.parse(maxTime.get(0)).plusMillis(1) 409 | ) 410 | } else { 411 | throw new DruidDataSourceException("Time boundary should include both the start time and the end time.") 412 | } 413 | } 414 | 415 | @throws[DruidDataSourceException] 416 | def metadata(dataSource: String, fullIndex: Boolean, druidVersion: String): DruidDataSource = { 417 | metadata(url, dataSource, fullIndex, druidVersion) 418 | } 419 | 420 | @throws[DruidDataSourceException] 421 | def executeQuery(qrySpec: QuerySpec): List[ResultRow] = { 422 | executeQuery(url, qrySpec) 423 | } 424 | 425 | @throws[DruidDataSourceException] 426 | def executeQueryAsStream(qrySpec: QuerySpec): CloseableIterator[ResultRow] = { 427 | executeQueryAsStream(url, qrySpec) 428 | } 429 | } 430 | 431 | class DruidCoordinatorClient(host: String, port: Int, useSmile: Boolean = false) 432 | extends DruidClient(host, port, useSmile) { 433 | 434 | @transient val urlPrefix = s"http://$host:$port/druid/coordinator/v1" 435 | 436 | def this(t: (String, Int)) = { 437 | this(t._1, t._2) 438 | } 439 | 440 | def this(s: String) = { 441 | this(DruidClient.hostPort(s)) 442 | } 443 | 444 | override def timeBoundary(dataSource: String): Interval = null 445 | } 446 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/client/DruidMessages.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.client 2 | 3 | import com.fasterxml.jackson.annotation._ 4 | import com.fasterxml.jackson.databind.annotation.JsonDeserialize 5 | import org.joda.time.Interval 6 | import org.rzlabs.druid.{DruidQueryGranularity, NoneGranularity} 7 | 8 | 9 | // All messages are coming from 
Druid API responses. 10 | 11 | /** 12 | * Constructed by the response of `segmentMetadata` query. 13 | * 14 | * @param `type` The column data type in Druid. 15 | * @param size Estimated byte size for the segment columns if they were stored in a flat format. 16 | * @param cardinality Time or dimension field's cardinality. 17 | * @param minValue Min value of string type column in the segment. 18 | * @param maxValue Max value of string type column in the segment. 19 | * @param errorMessage Error message of the column. 20 | */ 21 | @JsonIgnoreProperties(ignoreUnknown = true) 22 | case class ColumnDetail(`type`: String, size: Long, 23 | @JsonDeserialize(contentAs = classOf[java.lang.Long]) 24 | cardinality: Option[Long], 25 | minValue: Option[String], 26 | maxValue: Option[String], 27 | errorMessage: Option[String]) { 28 | 29 | /** 30 | * Metric column have no cardinality. 31 | */ 32 | def isDimension = cardinality.isDefined 33 | } 34 | 35 | @JsonIgnoreProperties(ignoreUnknown = true) 36 | case class Aggregator(`type`: String, 37 | name: String, 38 | fieldName: String, 39 | expression: Option[String]) 40 | 41 | @JsonIgnoreProperties(ignoreUnknown = true) 42 | case class TimestampSpec(column: String, 43 | format: String, 44 | missingValue: Option[String]) 45 | 46 | /** 47 | * Constructed by the response of `segmentMetadata` query. 48 | * 49 | * @param id 50 | * @param intervals Intervals of segments. 51 | * @param columns Column map which key is the column name in Druid. 52 | * @param size The Estimated byte size for the dataSource. 53 | * @param numRows Total row number of the dataSource. 54 | * @param queryGranularity query granularity specified in the ingestion spec. 55 | */ 56 | @JsonIgnoreProperties(ignoreUnknown = true) 57 | case class MetadataResponse(id: String, 58 | intervals: List[String], 59 | columns: Map[String, ColumnDetail], 60 | size: Long, 61 | @JsonDeserialize(contentAs = classOf[java.lang.Long]) 62 | numRows: Option[Long], 63 | aggregators: Option[Map[String, Aggregator]] = None, 64 | timestampSpec: Option[TimestampSpec] = None, 65 | queryGranularity: Option[DruidQueryGranularity] = None) { 66 | 67 | def getIntervals: List[Interval] = intervals.map(Interval.parse(_)) 68 | 69 | /** 70 | * All intervals' total time tick number. 71 | * According to different query granularities, 72 | * same intervals may have different time ticks. 73 | * 74 | * @param ins The input interval list. 75 | * @return The time tick number. 76 | */ 77 | def timeTicks(ins: List[Interval]): Long = 78 | queryGranularity.getOrElse(NoneGranularity()).ndv(ins) 79 | 80 | /** 81 | * Although all dimension columns have cardinalities, we 82 | * still call `getOrElse(1)` just in case. 
83 | */ 84 | def getNumRows: Long = numRows.getOrElse { 85 | val p = columns.values.filter(c => c.isDimension) 86 | .map(_.cardinality.getOrElse(1L)).map(_.toDouble).product 87 | if (p > Long.MaxValue) Long.MaxValue else p.toLong 88 | } 89 | } 90 | 91 | @JsonIgnoreProperties(ignoreUnknown = true) 92 | case class ModuleInfo(name: String, 93 | artifact: String, 94 | version: String) 95 | 96 | @JsonIgnoreProperties(ignoreUnknown = true) 97 | case class ServerMemory(maxMemory: Long, 98 | totalMemory: Long, 99 | freeMemory: Long, 100 | usedMemory: Long) 101 | 102 | @JsonIgnoreProperties(ignoreUnknown = true) 103 | case class ServerStatus(version: String, 104 | modules: List[ModuleInfo], 105 | memory: ServerMemory) 106 | 107 | 108 | sealed trait ResultRow { 109 | def event: Map[String, Any] 110 | } 111 | 112 | case class QueryResultRow(version: String, timestamp: String, 113 | event: Map[String, Any]) extends ResultRow 114 | 115 | case class SelectResultRow(segmentId: String, offset: Int, 116 | event: Map[String, Any]) extends ResultRow 117 | 118 | case class TopNResultRow(event: Map[String, Any]) extends ResultRow 119 | 120 | case class ScanResultRow(event: Map[String, Any]) extends ResultRow 121 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/jscodegen/JSAggrGenerator.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.jscodegen 2 | 3 | import org.apache.spark.sql.MyLogging 4 | import org.apache.spark.sql.catalyst.expressions._ 5 | import org.apache.spark.sql.catalyst.expressions.aggregate._ 6 | import org.apache.spark.sql.types._ 7 | import org.rzlabs.druid.{DruidQueryBuilder, JavascriptAggregationSpec} 8 | 9 | case class JSAggrGenerator(dqb: DruidQueryBuilder, aggrFunc: AggregateFunction, 10 | timeZone: String) extends MyLogging { 11 | 12 | import JSAggrGenerator._ 13 | 14 | private[this] def aggrJsFuncSkeleton(argA: String, argB: String, code: String) = 15 | s""" 16 | |if (($argA == null || isNaN($argA)) && ($argB == null || isNaN($argB))) { 17 | | return null; 18 | |} else if ($argA == null || isNaN($argA)) { 19 | | return $argB; 20 | |} else if ($argB == null || isNaN($argB)) { 21 | | return $argA; 22 | |} else { 23 | | return $code; 24 | |} 25 | """.stripMargin 26 | 27 | private[this] def getAggr(arg: String): Option[String] = aggrFunc match { 28 | case Sum(e) => Some(aggrJsFuncSkeleton("current", arg, s"(current + ($arg))")) 29 | case Min(e) => Some(aggrJsFuncSkeleton("current", arg, s"Math.min(current, $arg)")) 30 | case Max(e) => Some(aggrJsFuncSkeleton("current", arg, s"Math.max(current, $arg)")) 31 | case Count(e) => Some(s"return (current + 1);") 32 | case _ => None 33 | } 34 | 35 | private[this] def getCombine(partialA: String, partialB: String): Option[String] = aggrFunc match { 36 | case Sum(e) => Some(aggrJsFuncSkeleton(partialA, partialB, s"($partialA + $partialB)")) 37 | case Min(e) => Some(aggrJsFuncSkeleton(partialA, partialB, s"Math.min($partialA, $partialB)")) 38 | case Max(e) => Some(aggrJsFuncSkeleton(partialA, partialB, s"Math.max($partialA, $partialB)")) 39 | case Count(e) => Some(s"return ($partialA + $partialB);") 40 | case _ => None 41 | } 42 | 43 | private[this] def getReset: Option[String] = aggrFunc match { 44 | case Sum(e) => Some("return 0;") 45 | case Min(e) => Some("return Number.POSITIVE_INFINITY;") 46 | case Max(e) => Some("return Number.NEGATIVE_INFINITY;") 47 | case Count(e) => Some("return 0;") 48 | case _ => None 49 | } 
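To make the three builders above concrete: for a `Sum` over a single numeric column, `fnAggregate` folds each row into `current` using the null-safe skeleton, `fnCombine` merges two partials the same way, and `fnReset` returns the identity. A hedged sketch of the resulting spec (the alias and column name are illustrative, and the real aggregate body additionally embeds whatever expression code `JSCodeGenerator` produces):

```
// Sketch only: roughly the JavascriptAggregationSpec a SUM over one numeric
// column would yield, using the same null-safe skeleton as aggrJsFuncSkeleton.
val fnAggregate =
  """function(current, metric_col) {
    |  if ((current == null || isNaN(current)) && (metric_col == null || isNaN(metric_col))) {
    |    return null;
    |  } else if (current == null || isNaN(current)) {
    |    return metric_col;
    |  } else if (metric_col == null || isNaN(metric_col)) {
    |    return current;
    |  } else {
    |    return (current + (metric_col));
    |  }
    |}""".stripMargin

// The generated fnCombine uses the same null-safe skeleton; shortened here.
val fnCombine = "function(partialA, partialB) { return (partialA + partialB); }"
val fnReset   = "function() { return 0; }"

val sumSpec = new JavascriptAggregationSpec(
  "alias_1", List("metric_col"), fnAggregate, fnCombine, fnReset)
```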
50 | 51 | /** 52 | * Druid aggregator type could only be LONG, FLOAT or DOUBLE. 53 | */ 54 | private[this] val jsAggrType: Option[DataType] = aggrFunc match { 55 | case Count(a :: Nil) => Some(LongType) 56 | case _ => 57 | aggrFunc.dataType match { 58 | case ShortType | IntegerType | LongType | FloatType | DoubleType => Some(DoubleType) 59 | case TimestampType => Some(LongType) // What AggregateFunction's datatype is TimestampType??? 60 | case _ => None 61 | } 62 | } 63 | 64 | val druidType: Option[DataType] = aggrFunc match { 65 | case Count(_) => Some(LongType) 66 | case _ => 67 | aggrFunc.dataType match { 68 | case ShortType | IntegerType | LongType | FloatType | DoubleType => Some(DoubleType) 69 | case TimestampType => Some(TimestampType) // What AggregateFunction's datatype is TimestampType??? 70 | case _ => None 71 | } 72 | } 73 | 74 | private[this] val jscodegen: Option[(JSCodeGenerator, String)] = 75 | for (c <- aggrFunc.children.headOption 76 | if (aggrFunc.children.size == 1 && !aggrFunc.isInstanceOf[Average]); 77 | (ce, tf) <- simplifyExpr(dqb, c, timeZone); 78 | retType <- jsAggrType) yield 79 | (JSCodeGenerator(dqb, ce, true, true, timeZone, retType), tf) 80 | 81 | val fnAggregate: Option[String] = 82 | for (codegen <- jscodegen; fne <- codegen._1.fnElements; ret <- getAggr(fne._2)) yield 83 | s"""function(${("current" :: codegen._1.fnParams).mkString(", ")}) { ${fne._1} $ret }""" 84 | 85 | val fnCombine: Option[String] = 86 | for (ret <- getCombine("partialA", "partialB")) yield 87 | s"function(partialA, partialB) { $ret }" 88 | 89 | val fnReset: Option[String] = 90 | for (ret <- getReset) yield s"function() { $ret }" 91 | 92 | val aggrFnName: Option[String] = aggrFunc match { 93 | case Sum(_) => Some("SUM") 94 | case Min(_) => Some("MIN") 95 | case Max(_) => Some("MAX") 96 | case Count(_) => Some("COUNT") 97 | case _ => None 98 | } 99 | 100 | val fnParams: Option[List[String]] = for (codegen <- jscodegen) yield codegen._1.fnParams 101 | 102 | val valTransFormFn = if (jscodegen.nonEmpty) jscodegen.get._2 else null 103 | } 104 | 105 | object JSAggrGenerator { 106 | 107 | // def jsAvgCandidate(dqb: DruidQueryBuilder, af: AggregateFunction) = { 108 | // af match { 109 | // case Average(_) if (af.children.size == 1 && 110 | // !af.children.head.isInstanceOf[LeafExpression]) => true 111 | // case _ => false 112 | // } 113 | // } 114 | 115 | def simplifyExpr(dqb: DruidQueryBuilder, e: Expression, timeZone: String): 116 | Option[(Expression, String)] = { 117 | e match { 118 | case Cast(a @ AttributeReference(nm, _, _, _), TimestampType, _) 119 | if (dqb.druidColumn(nm).get.isTimeDimension) => 120 | Some((Cast(a, LongType), "toTSWithTZAdj")) 121 | case _ => Some(e, null) 122 | } 123 | } 124 | 125 | def jsAggr(dqb: DruidQueryBuilder, aggrExpr: Expression, af: AggregateFunction, 126 | tz: String): Option[(DruidQueryBuilder, String)] = { 127 | val jsAggrGen = JSAggrGenerator(dqb, af, tz) 128 | for (fnAggr <- jsAggrGen.fnAggregate; fnCbn <- jsAggrGen.fnCombine; 129 | fnRst <- jsAggrGen.fnReset; fnName <- jsAggrGen.aggrFnName; 130 | fnAlias <- Some(dqb.nextAlias); fnParams <- jsAggrGen.fnParams; 131 | druidDataType <- jsAggrGen.druidType) yield { 132 | (dqb.aggregationSpec( 133 | new JavascriptAggregationSpec(fnAlias, fnParams, fnAggr, fnCbn, fnRst)). 
134 | outputAttribute(fnAlias, aggrExpr, aggrExpr.dataType, druidDataType, 135 | jsAggrGen.valTransFormFn), fnAlias) 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/jscodegen/JSCast.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.jscodegen 2 | 3 | import org.apache.spark.sql.types._ 4 | 5 | case class JSCast(from: JSExpr, toDT: DataType, ctx: JSCodeGenerator, fmt: Option[String] = None) { 6 | 7 | import JSDateTimeCtx._ 8 | 9 | private[jscodegen] val castCode: Option[JSExpr] = toDT match { 10 | case _ if from.fnDT == toDT => Some(from) 11 | case BooleanType => castToBooleanType 12 | case ShortType => castToNumericCode(ShortType) 13 | case IntegerType => castToNumericCode(IntegerType) 14 | case LongType => castToNumericCode(LongType) 15 | case FloatType => castToNumericCode(FloatType) 16 | case DoubleType => castToNumericCode(DoubleType) 17 | case StringType => castToStringCode 18 | case DateType => castToDateCode 19 | case TimestampType => castToTimestampCode 20 | case _ => None 21 | } 22 | 23 | private[this] def castToBooleanType: Option[JSExpr] = from.fnDT match { 24 | case IntegerType | LongType | FloatType | DoubleType => 25 | Some(JSExpr(None, from.linesSoFar, s"Boolean(${from.getRef})", BooleanType)) 26 | case StringType | DateType => 27 | // Spark would return null when cast from date to boolean 28 | // Druid will return null value if the value is null or "". 29 | Some(JSExpr(None, from.linesSoFar, "null", BooleanType)) 30 | case TimestampType => 31 | // Boolean(TimestampType) should always returns true 32 | // which behaves the same as Spark. 33 | Some(JSExpr(None, from.linesSoFar, "true", BooleanType)) 34 | } 35 | 36 | private[this] def castToNumericCode(outDt: DataType): Option[JSExpr] = from.fnDT match { 37 | case BooleanType | StringType => 38 | Some(JSExpr(None, from.linesSoFar, s"Number(${from.getRef})", outDt)) 39 | case (FloatType | DoubleType) if ctx.isIntegralNumeric(outDt) => 40 | Some(JSExpr(None, from.linesSoFar, s"Math.floor(${from.getRef})", outDt)) 41 | case ShortType | IntegerType | LongType | FloatType | DoubleType => 42 | Some(JSExpr(None, from.linesSoFar, from.getRef, outDt)) 43 | case DateType => 44 | // Behave the same with Spark. (cast(cast('2018-01-01' as date) as int)) 45 | Some(JSExpr(None, from.linesSoFar, "null", outDt)) 46 | case TimestampType => 47 | // Behave the same with Spark. 
(cast(cast('2018-01-01' as timestamp) as long)) 48 | Some(JSExpr(None, from.linesSoFar, dtToLongCode(from.getRef), outDt)) 49 | case _ => None 50 | } 51 | 52 | private[this] def castToStringCode: Option[JSExpr] = from.fnDT match { 53 | case TimestampType if from.timeDim => 54 | // time dimension 55 | nullSafeCastToString(dtToStrCode(longToISODtCode(from.getRef, ctx.dateTimeCtx))) 56 | case BooleanType | ShortType | IntegerType | LongType | FloatType 57 | | DoubleType | DecimalType() => nullSafeCastToString(from.getRef) 58 | case DateType => nullSafeCastToString(dateToStrCode(from.getRef)) 59 | case TimestampType => nullSafeCastToString(dtToStrCode(from.getRef)) 60 | case _ => None 61 | } 62 | 63 | private[this] def castToDateCode: Option[JSExpr] = from.fnDT match { 64 | case StringType => 65 | if (fmt.nonEmpty) { 66 | Some(JSExpr(None, from.linesSoFar, stringToDateCode(from.getRef, ctx.dateTimeCtx, true, fmt), DateType)) 67 | } else { 68 | Some(JSExpr(None, from.linesSoFar, stringToDateCode(from.getRef, ctx.dateTimeCtx), DateType)) 69 | } 70 | case TimestampType => 71 | Some(JSExpr(None, from.linesSoFar, dtToDateCode(from.getRef), DateType)) 72 | case LongType if from.timeDim => 73 | Some(JSExpr(None, from.linesSoFar, longToDateCode(from.getRef, ctx.dateTimeCtx), DateType)) 74 | case _ => None 75 | } 76 | 77 | private[this] def castToTimestampCode: Option[JSExpr] = from.fnDT match { 78 | case StringType => 79 | if (fmt.nonEmpty) { 80 | Some(JSExpr(None, from.linesSoFar, stringToISODtCode(from.getRef, ctx.dateTimeCtx, true, fmt), TimestampType)) 81 | } else { 82 | Some(JSExpr(None, from.linesSoFar, stringToISODtCode(from.getRef, ctx.dateTimeCtx), TimestampType)) 83 | } 84 | case BooleanType => 85 | Some(JSExpr(None, from.linesSoFar, stringToISODtCode( 86 | s""" (${from.getRef}) == true ? "T00:00:01Z" : "T00:00:00Z"""", ctx.dateTimeCtx), TimestampType)) 87 | case FloatType | DoubleType | DecimalType() => 88 | for (lc <- castToNumericCode(LongType)) yield 89 | JSExpr(None, lc.linesSoFar, longToISODtCode(lc.getRef, ctx.dateTimeCtx), TimestampType) 90 | case ShortType | IntegerType | LongType => 91 | Some(JSExpr(None, from.linesSoFar, longToISODtCode(from.getRef, ctx.dateTimeCtx), TimestampType)) 92 | case DateType => 93 | Some(JSExpr(None, from.linesSoFar, localDateToDtCode(from.getRef, ctx.dateTimeCtx), TimestampType)) 94 | case _ => None 95 | } 96 | 97 | private[this] def nullSafeCastToString(valToCast: String): Option[JSExpr] = { 98 | if (from.fnVar.isEmpty) { 99 | val vn = ctx.makeUniqueVarName 100 | Some(JSExpr(None, from.linesSoFar + s"$vn = $valToCast;", 101 | s"""($vn != null && !isNaN($vn) ? $vn.toString() : "")""", StringType)) 102 | } else { 103 | Some(JSExpr(None, from.linesSoFar, 104 | s"""($valToCast != null && !isNaN($valToCast) ? 
$valToCast.toString() : "")""", 105 | StringType)) 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/jscodegen/JSDateTime.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.jscodegen 2 | 3 | private[jscodegen] case class JSDateTimeCtx(val timeZone: String, val ctx: JSCodeGenerator) { 4 | 5 | // v1 = org.joda.time.DateTimeZone.forID("${tz_id}") 6 | private[jscodegen] val tzVar = ctx.makeUniqueVarName 7 | // v2 = org.joda.time.format.ISODateTimeFormat.dateTimeParser() 8 | private[jscodegen] val isoFormatterVar = ctx.makeUniqueVarName 9 | // v3 = v2.withID(v1) 10 | private[jscodegen] val isoFormatterWIthTzVar = ctx.makeUniqueVarName 11 | 12 | private[jscodegen] var createJodaTz = false 13 | private[jscodegen] var createJodaISOFormatter = false 14 | private[jscodegen] var createJodaISOFormatterWithTz = false 15 | 16 | private[jscodegen] def dateTimeInitCode: String = { 17 | var dtInitCode = "" 18 | if (createJodaTz) { 19 | dtInitCode += 20 | s"""var $tzVar = org.joda.time.DateTimeZone.forID("$timeZone");""" 21 | } 22 | if (createJodaISOFormatter || createJodaISOFormatterWithTz) { 23 | dtInitCode += 24 | s"var $isoFormatterVar = org.joda.time.format.ISODateTimeFormat.dateTimeParser();" 25 | } 26 | if (createJodaISOFormatterWithTz) { 27 | dtInitCode += 28 | s"var $isoFormatterWIthTzVar = $isoFormatterVar.withZone($tzVar);" 29 | } 30 | 31 | dtInitCode 32 | } 33 | } 34 | 35 | private[jscodegen] object JSDateTimeCtx { 36 | private val dateFormat = "yyyy-MM-dd" 37 | private val timestampFormat = "yyyy-MM-dd HH:mm:ss" 38 | private val mSecsInDay = 86400000 39 | 40 | /** 41 | * The 'ts' param must be a [[org.joda.time.DateTime]] literal, 42 | * e.g., 'DateTime.parse("2018-01-01", DateTimeZone.forID("UTC"))'. 43 | * NOTE: The returned unit is second so divide 1000. 44 | * @param ts must be a [[org.joda.time.DateTime]] literal. 45 | * @return 46 | */ 47 | private[jscodegen] def dtToLongCode(ts: String) = s"Math.floor($ts.getMillis() / 1000)" 48 | 49 | /** 50 | * The 'dt' param may be a [[org.joda.time.LocalDate]] literal, 51 | * e.g., ''LocalDate.parse("2018-01-01", format.ISODateTimeFormat.dateTimeParser)''. 52 | * @param dt must be a [[org.joda.time.LocalDate]] literal. 53 | * @return 54 | */ 55 | private[jscodegen] def dateToStrCode(dt: String) = s"""$dt.toString("$dateFormat")""" 56 | 57 | /** 58 | * The ''ts'' param may be a [[org.joda.time.DateTime]] literal. 59 | * @param ts may be a [[org.joda.time.DateTime]] literal. 60 | * @param fmt The timestamp format specified, default is ''yyyy-MM-dd HH:mm:ss''. 
61 | * @return 62 | */ 63 | private[jscodegen] def dtToStrCode(ts: String, 64 | fmt: String = timestampFormat, 65 | litFmt: Boolean = true) = { 66 | if (litFmt) { 67 | s"""$ts.toString("$fmt")""" 68 | } else { 69 | s"""$ts.toString($fmt)""" 70 | } 71 | } 72 | 73 | private[jscodegen] def longToISODtCode(l: String, ctx: JSDateTimeCtx) = { 74 | ctx.createJodaTz = true 75 | s"new org.joda.time.DateTime($l, ${ctx.tzVar})" 76 | } 77 | 78 | // private[jscodegen] def stringToDateCode(s: String, ctx: JSDateTimeCtx) = { 79 | // ctx.createJodaISOFormatter = true 80 | // s"""org.joda.time.LocalDate.parse($s, ${ctx.isoFormatterVar})""" 81 | // } 82 | 83 | private[jscodegen] def stringToDateCode(s: String, ctx: JSDateTimeCtx, 84 | withFmt: Boolean = false, 85 | fmt: Option[String] = None): String = { 86 | ctx.createJodaTz = true 87 | ctx.createJodaISOFormatterWithTz = true 88 | if (!withFmt) { 89 | s"""org.joda.time.LocalDate.parse($s, ${ctx.isoFormatterVar})""" 90 | } else { 91 | s"""org.joda.time.LocalDate.parse($s, 92 | |org.joda.time.format.DateTimeFormat.forPattern("${fmt.get}").withZone(${ctx.tzVar}))""".stripMargin 93 | } 94 | } 95 | 96 | /** 97 | * The ''ts'' param must be a [[org.joda.time.DateTime]] literal. 98 | * @param ts 99 | * @return 100 | */ 101 | private[jscodegen] def dtToDateCode(ts: String) = { 102 | s"${ts}.toLocalDate()" 103 | } 104 | 105 | private[jscodegen] def longToDateCode(ts: String, ctx: JSDateTimeCtx) = { 106 | ctx.createJodaTz = true 107 | s"org.joda.time.LocalDate($ts, ${ctx.tzVar})" 108 | } 109 | 110 | 111 | private[jscodegen] def stringToISODtCode(s: String, ctx: JSDateTimeCtx, 112 | withFmt: Boolean = false, 113 | fmt: Option[String] = None) = { 114 | ctx.createJodaTz = true 115 | ctx.createJodaISOFormatterWithTz = true 116 | if (!withFmt) { 117 | s"""org.joda.time.DateTime.parse(($s).replace(" ", "T"), ${ctx.isoFormatterWIthTzVar})""" 118 | } else { 119 | s"""org.joda.time.DateTime.parse($s, 120 | |org.joda.time.format.DateTimeFormat.forPattern("$fmt").withZone(${ctx.tzVar}))""".stripMargin 121 | } 122 | } 123 | 124 | /** 125 | * The ''dt'' param must be a [[org.joda.time.LocalDate]] literal. 
126 | * @param dt 127 | * @param ctx 128 | * @return 129 | */ 130 | private[jscodegen] def localDateToDtCode(dt: String, ctx: JSDateTimeCtx) = { 131 | ctx.createJodaTz = true 132 | s"$dt.toDateTimeAsStartOfDay(${ctx.tzVar})" 133 | } 134 | 135 | private[jscodegen] def noDaysToDateCode(d: String) = { 136 | s"org.joda.time.LocalDate($d * $mSecsInDay)" 137 | } 138 | 139 | private[jscodegen] def dateComparisonCode(l: String, r: String, op: String) = { 140 | op match { 141 | case " < " => Some(s"$l.isBefore($r)") 142 | case " <= " => Some(s"$l.compareTo($r) <= 0") 143 | case " > " => Some(s"$l.isAfter($r)") 144 | case " >= " => Some(s"$l.compareTo($r) >= 0") 145 | case " = " => Some(s"$l.equals($r)") 146 | case _ => None 147 | } 148 | } 149 | 150 | private[jscodegen] def dateAdd(dt: String, nd: String) = s"$dt.plusDays($nd)" 151 | private[jscodegen] def dateSub(dt: String, nd: String) = s"$dt.minusDays($nd)" 152 | private[jscodegen] def dateDiff(ed: String, sd: String) = { 153 | s"org.joda.time.Days.daysBetween($sd, $ed).getDays()" 154 | } 155 | 156 | private[jscodegen] def year(dt: String) = s"$dt.getYear()" 157 | private[jscodegen] def quarter(dt: String) = s"(Math.floor(($dt.getMonthOfYear() - 1) / 3) + 1)" 158 | private[jscodegen] def month(dt: String) = s"$dt.getMonthOfYear()" 159 | private[jscodegen] def dayOfMonth(dt: String) = s"$dt.getDayOfMonth()" 160 | private[jscodegen] def dayOfYear(dt: String) = s"$dt.getDayOfYear()" 161 | private[jscodegen] def weekOfYear(dt: String) = s"$dt.getWeekOfYear()" 162 | private[jscodegen] def hourOfDay(dt: String) = s"$dt.getHourOfDay()" 163 | private[jscodegen] def minuteOfHour(dt: String) = s"$dt.getMinuteOfHour()" 164 | private[jscodegen] def secondOfMinute(dt: String) = s"$dt.getSecondOfMinute()" 165 | 166 | private[jscodegen] def truncate(dt: String, fmt: String)= fmt.toLowerCase() match { 167 | case "year" | "yyyy" | "yy" => s"$dt.withDayOfYear(1)" 168 | case "month" | "mon" | "mm" => s"$dt.withDayOfMonth(1)" 169 | case _ => "null" 170 | } 171 | } -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/jscodegen/JSExpr.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.jscodegen 2 | 3 | import org.apache.spark.sql.types.DataType 4 | 5 | /** 6 | * 7 | * @param fnVar 8 | * @param linesSoFar 9 | * @param curLine 10 | * @param fnDT 11 | * @param timeDim 12 | */ 13 | private[jscodegen] case class JSExpr(val fnVar: Option[String], val linesSoFar: String, 14 | val curLine: String, val fnDT: DataType, 15 | val timeDim: Boolean = false) { 16 | 17 | private[jscodegen] def this(curLine: String, fnDT: DataType, timeDim: Boolean) = { 18 | this(None, "", curLine, fnDT, timeDim) 19 | } 20 | 21 | private[jscodegen] def getRef: String = { 22 | if (fnVar.isDefined) fnVar.get else curLine 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/metadata/DruidInfo.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.metadata 2 | 3 | import org.apache.spark.sql.MyLogging 4 | import org.rzlabs.druid.{DruidDataSource, DruidQueryGranularity} 5 | 6 | case class DruidOptions(zkHost: String, 7 | zkSessionTimeoutMs: Int, 8 | zkEnableCompression: Boolean, 9 | zkQualifyDiscoveryNames: Boolean, 10 | zkDruidPath: String, 11 | poolMaxConnectionsPerRoute: Int, 12 | poolMaxConnections: Int, 13 | 
loadMetadataFromAllSegments: Boolean, 14 | debugTransformations: Boolean, 15 | timeZoneId: String, 16 | useV2GroupByEngine: Boolean, 17 | useSmile: Boolean, 18 | queryGranularity: DruidQueryGranularity) 19 | 20 | case class DruidRelationName(zkHost: String, druidDataSource: String) 21 | 22 | case class DruidRelationInfo(fullName: DruidRelationName, 23 | timeDimensionCol: String, 24 | druidDataSource: DruidDataSource, 25 | val druidColumns: Map[String, DruidRelationColumn], 26 | val options: DruidOptions) 27 | 28 | -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/metadata/DruidMetadataCache.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.metadata 2 | 3 | import org.apache.spark.sql.MyLogging 4 | import org.apache.spark.sql.types.TimestampType 5 | import org.apache.spark.util.MyThreadUtils 6 | import org.codehaus.jackson.annotate.JsonIgnoreProperties 7 | import org.rzlabs.druid._ 8 | import org.rzlabs.druid.client._ 9 | import org.fasterxml.jackson.databind.ObjectMapper._ 10 | import org.joda.time.Interval 11 | 12 | import scala.collection.mutable.{Map => MMap} 13 | import scala.util.Try 14 | 15 | @JsonIgnoreProperties(ignoreUnknown = true) 16 | case class DruidNode(name: String, 17 | id: String, 18 | address: String, 19 | port: Int) 20 | 21 | case class DruidClusterInfo(host: String, 22 | curatorConnection: CuratorConnection, 23 | serverStatus: ServerStatus, 24 | druidDataSources: MMap[String, DruidDataSource]) 25 | 26 | trait DruidMetadataCache { 27 | 28 | def getDruidClusterInfo(druidRelationName: DruidRelationName, 29 | options: DruidOptions): DruidClusterInfo 30 | 31 | def getDataSourceInfo(druidRelationName: DruidRelationName, 32 | options: DruidOptions): DruidDataSource 33 | } 34 | 35 | trait DruidRelationInfoCache { 36 | 37 | self: DruidMetadataCache => 38 | 39 | 40 | def buildColumnInfos(druidDataSource: DruidDataSource, 41 | userSpecifiedColumnInfos: List[DruidRelationColumnInfo] 42 | ): Map[String, DruidRelationColumn] = { 43 | val columns: Map[String, DruidColumn] = druidDataSource.columns 44 | 45 | def getDruidMetric(metricName: Option[String]): Option[DruidMetric] = { 46 | if (metricName.isDefined) { 47 | if (columns.contains(metricName.get) && 48 | columns(metricName.get).isInstanceOf[DruidMetric]) { 49 | Some(columns(metricName.get).asInstanceOf[DruidMetric]) 50 | } else None 51 | } else None 52 | } 53 | 54 | def isApproxMetric(druidColumn: DruidColumn) = { 55 | druidColumn.dataType == DruidDataType.HyperUnique || 56 | druidColumn.dataType == DruidDataType.ThetaSketch 57 | } 58 | 59 | //val normalColumns: Map[String, DruidRelationColumn] = columns.map { 60 | // case (columnName, druidColumn) if !isApproxMetric(druidColumn) => 61 | // val ci = userSpecifiedColumnInfos.find(_.column == columnName).getOrElse(null) 62 | // val druidRelationColumn = if (ci != null) { 63 | // val hllMetric = getDruidMetric(ci.hllMetric) 64 | // val sketchMetric = getDruidMetric(ci.sketchMetric) 65 | // DruidRelationColumn(columnName, Some(druidColumn), hllMetric, sketchMetric) 66 | // } else { 67 | // DruidRelationColumn(columnName, Some(druidColumn)) 68 | // } 69 | // val cardinality: Option[Long] = if (druidColumn.isInstanceOf[DruidTimeDimension]) { 70 | // Some(druidColumn.asInstanceOf[DruidTimeDimension].cardinality) 71 | // } else if (druidColumn.isInstanceOf[DruidDimension]) { 72 | // Some(druidColumn.asInstanceOf[DruidDimension].cardinality) 73 | // } else if 
(druidColumn.isInstanceOf[DruidMetric]) {
    // Some(druidColumn.asInstanceOf[DruidMetric].cardinality)
    // } else None
    // columnName -> druidRelationColumn.copy(cardinalityEstimate = cardinality)
    // // Approx metric information should be carried by related origin column.
    // case _ => (null, null)
    //}.filterNot(_._1 == null)

    val normalColumns: Map[String, DruidRelationColumn] = columns.map {
      case (columnName, druidColumn) =>
        val ci = userSpecifiedColumnInfos.find(_.column == columnName).getOrElse(null)
        val druidRelationColumn = if (ci != null) {
          val hllMetric = getDruidMetric(ci.hllMetric)
          val sketchMetric = getDruidMetric(ci.sketchMetric)
          DruidRelationColumn(columnName, Some(druidColumn), hllMetric, sketchMetric)
        } else {
          DruidRelationColumn(columnName, Some(druidColumn))
        }
        val cardinality: Option[Long] = if (druidColumn.isInstanceOf[DruidTimeDimension]) {
          Some(druidColumn.asInstanceOf[DruidTimeDimension].cardinality)
        } else if (druidColumn.isInstanceOf[DruidDimension]) {
          Some(druidColumn.asInstanceOf[DruidDimension].cardinality)
        } else if (druidColumn.isInstanceOf[DruidMetric]) {
          Some(druidColumn.asInstanceOf[DruidMetric].cardinality)
        } else None
        columnName -> druidRelationColumn.copy(cardinalityEstimate = cardinality)
    }

    // Dimensions the user specified that are not indexed in the Druid datasource.
    val notIndexedColumns: Map[String, DruidRelationColumn] = userSpecifiedColumnInfos.collect {
      case ci if !columns.contains(ci.column) => ci
    }.map {
      case ci: DruidRelationColumnInfo =>
        val hllMetric = getDruidMetric(ci.hllMetric)
        val sketchMetric = getDruidMetric(ci.sketchMetric)
        // The cardinality estimate comes from the backing approximate metric, if one is defined.
        val cardinality: Option[Long] = hllMetric.orElse(sketchMetric).map(_.cardinality)
        ci.column -> DruidRelationColumn(ci.column, None, hllMetric, sketchMetric, cardinality)
    }.toMap

    normalColumns ++ notIndexedColumns
  }

  def druidRelation(dataSourceName: String,
                    timeDimensionCol: String,
                    userSpecifiedColumnInfos: List[DruidRelationColumnInfo],
                    options: DruidOptions): DruidRelationInfo = {

    val name = DruidRelationName(options.zkHost, dataSourceName)
    val druidDS = getDataSourceInfo(name, options)
    val columnInfos = buildColumnInfos(druidDS, userSpecifiedColumnInfos)
    val timeDimCol = druidDS.timeDimensionColName(timeDimensionCol)
    // Change the internal time dimension name "__time" to the real time column name.
131 | val colInfos = columnInfos.map { 132 | case (colName, drCol) if colName == DruidDataSource.INNER_TIME_COLUMN_NAME => 133 | (timeDimCol, drCol.copy(column = timeDimCol, 134 | druidColumn = Some( 135 | // The time dimension column's datatype should be StringType but not LongType 136 | drCol.druidColumn.get.asInstanceOf[DruidTimeDimension].copy(name = timeDimCol, 137 | dataType = DruidDataType.withName("STRING")) 138 | ))) 139 | case other => other 140 | } 141 | DruidRelationInfo(name, timeDimCol, druidDS, colInfos, options) 142 | } 143 | } 144 | 145 | object DruidMetadataCache extends DruidMetadataCache with MyLogging with DruidRelationInfoCache { 146 | 147 | private[metadata] val cache: MMap[String, DruidClusterInfo] = MMap() // zkHost -> DruidClusterInfo 148 | private val curatorConnections: MMap[String, CuratorConnection] = MMap() 149 | private[druid] var brokerClient: DruidQueryServerClient = null 150 | val threadPool = MyThreadUtils.newDaemonCachedThreadPool("druidZkEventExec", 10) 151 | 152 | /** 153 | * 154 | * @param json 155 | */ 156 | private def updateTimePeriod(json: String): Unit = { 157 | val root = jsonMapper.readTree(json) 158 | val action = Try(root.get("action").asText).recover({ case _ => null }).get // "load" or "drop" 159 | val dataSource = Try(root.get("dataSource").asText).recover({ case _ => null }).get 160 | val interval = Try(root.get("interval").asText).recover({ case _ => null }).get 161 | if (action == null || dataSource == null || interval == null) return 162 | 163 | // Find datasource in `DruidClusterInfo` for each zkHost. 164 | logInfo(s"${action.toUpperCase} a segment of dataSource $dataSource with interval $interval.") 165 | cache.foreach { 166 | case (_, druidClusterInfo) => { 167 | val dDS: Option[DruidDataSource] = druidClusterInfo.druidDataSources.get(dataSource) 168 | if (dDS.isDefined) { // find the dataSource the interval should be updated. 169 | dDS.synchronized { 170 | val oldInterval: Interval = dDS.get.intervals(0) 171 | action.toUpperCase match { 172 | case "LOAD" => 173 | // Don't call `timeBoundary` to update interval (cost to much). 174 | val newInterval = Utils.updateInterval(oldInterval, new Interval(interval), action) 175 | dDS.get.intervals = List(newInterval) 176 | case "DROP" => 177 | // Call `timeBoundary` to update interval. 
178 | dDS.get.intervals = List(brokerClient.timeBoundary(dataSource)) 179 | case other => logWarning(s"Unkown segment action '$other'") 180 | } 181 | } 182 | logInfo(s"The new interval of dataSource $dataSource is ${dDS.get.intervals(0)}") 183 | } // else do nothing 184 | } 185 | } 186 | } 187 | 188 | private def curatorConnection(host: String, options: DruidOptions): CuratorConnection = { 189 | curatorConnections.getOrElse(host, { 190 | val cc = new CuratorConnection(host, options, cache, threadPool, updateTimePeriod _) 191 | curatorConnections(host) = cc 192 | cc 193 | }) 194 | } 195 | 196 | def getDruidClusterInfo(druidRelationName: DruidRelationName, 197 | options: DruidOptions): DruidClusterInfo = { 198 | cache.synchronized { 199 | if (cache.contains(druidRelationName.zkHost)) { 200 | cache(druidRelationName.zkHost) 201 | } else { 202 | val zkHost = druidRelationName.zkHost 203 | val cc = curatorConnection(zkHost, options) 204 | val coordClient = new DruidCoordinatorClient(cc.getCoordinator) 205 | val serverStatus = coordClient.serverStatus 206 | val druidClusterInfo = new DruidClusterInfo(zkHost, cc, serverStatus, 207 | MMap[String, DruidDataSource]()) 208 | cache(druidClusterInfo.host) = druidClusterInfo 209 | logInfo(s"Loading druid cluster info for $druidRelationName with zkHost $zkHost") 210 | druidClusterInfo 211 | } 212 | } 213 | } 214 | 215 | def getDataSourceInfo(druidRelationName: DruidRelationName, 216 | options: DruidOptions): DruidDataSource = { 217 | val druidClusterInfo = getDruidClusterInfo(druidRelationName, options) 218 | val dataSourceName: String = druidRelationName.druidDataSource 219 | druidClusterInfo.synchronized { 220 | if (druidClusterInfo.druidDataSources.contains(dataSourceName)) { 221 | druidClusterInfo.druidDataSources(dataSourceName) 222 | } else { 223 | val broker: String = druidClusterInfo.curatorConnection.getBroker 224 | brokerClient = new DruidQueryServerClient(broker, false) 225 | val fullIndex = options.loadMetadataFromAllSegments 226 | val druidDS = brokerClient.metadata(dataSourceName, fullIndex, 227 | druidClusterInfo.serverStatus.version) 228 | .copy(druidVersion = druidClusterInfo.serverStatus.version) 229 | druidClusterInfo.druidDataSources(dataSourceName) = druidDS 230 | logInfo(s"Druid datasource info for ${dataSourceName} is loaded.") 231 | druidDS 232 | } 233 | } 234 | } 235 | } -------------------------------------------------------------------------------- /src/main/scala/org/rzlabs/druid/metadata/DruidRelationColumn.scala: -------------------------------------------------------------------------------- 1 | package org.rzlabs.druid.metadata 2 | 3 | import org.rzlabs.druid._ 4 | 5 | 6 | case class DruidRelationColumnInfo(column: String, 7 | hllMetric: Option[String] = None, 8 | sketchMetric: Option[String] = None, 9 | cardinalityEstimate: Option[Long] = None) 10 | 11 | case class DruidRelationColumn(column: String, 12 | druidColumn: Option[DruidColumn], 13 | hllMetric: Option[DruidMetric] = None, 14 | sketchMetric: Option[DruidMetric] = None, 15 | cardinalityEstimate: Option[Long] = None) { 16 | 17 | private lazy val druidColumnToUse: DruidColumn = { 18 | Utils.filterSomes( 19 | Seq(druidColumn, hllMetric, sketchMetric).toList 20 | ).head.get 21 | } 22 | 23 | def hasDirectDruidColumn = druidColumn.isDefined 24 | 25 | def hasHllMetric = hllMetric.isDefined 26 | 27 | def hasSketchMetric = sketchMetric.isDefined 28 | 29 | // TODO: Not support spatial index yet. 
  def hasSpatialIndex = false

  //def name = druidColumnToUse.name
  def name = column

  //def dataType = if (hasSpatialIndex) DruidDataType.Float else druidColumnToUse.dataType
  def dataType = {
    if (hasSpatialIndex) {
      DruidDataType.Float
    } else if (isNotIndexedDimension) {
      // A dimension that is not indexed in Druid is treated as a string.
      DruidDataType.String
    } else {
      druidColumnToUse.dataType
    }
  }

  def size = druidColumnToUse.size

  val cardinality: Long = cardinalityEstimate.getOrElse(druidColumnToUse.cardinality)

  def isDimension(excludeTime: Boolean = false): Boolean = {
    // A column that only refers to an approximate metric (and is not indexed as a
    // dimension in the Druid datasource) is not treated as a dimension.
    hasDirectDruidColumn && druidColumnToUse.isDimension(excludeTime)
  }

  def isNotIndexedDimension = !hasDirectDruidColumn

  def isTimeDimension: Boolean = {
    hasDirectDruidColumn && druidColumnToUse.isInstanceOf[DruidTimeDimension]
  }

  def isMetric: Boolean = {
    hasDirectDruidColumn && !isDimension()
  }

  def metric = druidColumnToUse.asInstanceOf[DruidMetric]
}
--------------------------------------------------------------------------------
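As a quick orientation to the metadata classes above, here is a minimal, illustrative sketch (the column and metric names are hypothetical) of how a user-specified approximate-metric mapping is represented, and what `buildColumnInfos` produces for a dimension that is only covered by a hyperUnique metric:

```
import org.rzlabs.druid.metadata.DruidRelationColumnInfo

// A dimension "user_id" that is not indexed in Druid but is covered by a
// hyperUnique metric "unique_users".
val info = DruidRelationColumnInfo(column = "user_id", hllMetric = Some("unique_users"))

// buildColumnInfos resolves "unique_users" against the datasource's columns and, because
// "user_id" itself is not indexed, yields a DruidRelationColumn with druidColumn = None and
// hllMetric = Some(<resolved DruidMetric>). For such a column, isNotIndexedDimension is true,
// dataType falls back to DruidDataType.String, and cardinality is taken from the metric.
```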