├── .github └── workflows │ └── scala.yml ├── .gitignore ├── LICENSE ├── README.md ├── build.sbt ├── macros └── src │ ├── main │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── sql │ │ └── sqlmacros │ │ ├── Arithmetics.scala │ │ ├── CollectionUtils.scala │ │ ├── Collections.scala │ │ ├── Conditionals.scala │ │ ├── DateTime.scala │ │ ├── DateTimeUtils.scala │ │ ├── ExprBuilders.scala │ │ ├── ExprOptimize.scala │ │ ├── ExprTranslator.scala │ │ ├── MacrosEnv.scala │ │ ├── MacrosScalaReflection.scala │ │ ├── Options.scala │ │ ├── PredicateUtils.scala │ │ ├── RecursiveSparkApply.scala │ │ ├── SQLMacro.scala │ │ ├── SQLMacroExpressions.scala │ │ ├── SparkSQLMacroUtils.scala │ │ ├── StringUtils.scala │ │ ├── Strings.scala │ │ ├── Structs.scala │ │ ├── Tuples.scala │ │ ├── package.scala │ │ └── registered_macros.scala │ └── test │ └── scala │ └── macrotest │ └── ExampleStructs.scala ├── project ├── Assembly.scala ├── Dependencies.scala ├── Versions.scala ├── build.properties └── plugins.sbt ├── scalastyle-config.xml └── sql └── src ├── main └── scala │ └── org │ └── apache │ └── spark │ └── sql │ └── defineMacros.scala └── test └── scala └── org └── apache └── spark └── sql ├── AbstractTest.scala ├── CollectionMacrosTest.scala ├── MacrosTest.scala ├── StringMacrosTest.scala └── hive └── test └── sqlmacros └── TestSQLMacrosHive.scala /.github/workflows/scala.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up JDK 1.8 17 | uses: actions/setup-java@v1 18 | with: 19 | java-version: 1.8 20 | - name: Run tests 21 | run: sbt clean compile test 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | *.pyc 4 | sbt/*.jar 5 | 6 | # sbt specific 7 | .cache/ 8 | .history/ 9 | .lib/ 10 | dist/* 11 | target/ 12 | lib_managed/ 13 | src_managed/ 14 | project/boot/ 15 | project/plugins/project/ 16 | 17 | # idea 18 | .idea/ 19 | *.iml 20 | 21 | # Mac 22 | .DS_Store 23 | 24 | # emacs 25 | *.*~ 26 | 27 | # docs 28 | docs/*.html 29 | docs/*.tex 30 | docs/auto/ 31 | *-blx.bib 32 | *.bbl 33 | *.blg 34 | *.fdb_latexmk 35 | *.fls 36 | docs/*.xml 37 | *.tex 38 | *.tiff 39 | docs/benchmark/*.html 40 | **/auto/ 41 | docs/notes/*_ImplNotes.org 42 | docs/notes/mynotes 43 | spark-sql-macros.wiki 44 | jdbc-examples 45 | **/Experiments.* 46 | 47 | 48 | # docker 49 | docker/driver 50 | docker/worker 51 | docker/.env 52 | 53 | # leveldb 54 | **/metadata_cache/*.log 55 | **/metadata_cache/LOG* 56 | 57 | # data files 58 | **/_SUCCESS 59 | **/._SUCCESS.crc 60 | **/.part*.crc 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import Dependencies._ 2 | import sbt.Keys.test 3 | import sbt._ 4 | 5 | ThisBuild / scalaVersion := Versions.scalaVersion 6 | ThisBuild / crossScalaVersions := Seq(Versions.scalaVersion) 7 | 8 | ThisBuild / homepage := Some(url("https://orahub.oci.oraclecorp.com/harish_butani/spark-oracle")) 9 | ThisBuild / licenses := List("Apache 2" -> url("http://www.apache.org/licenses/LICENSE-2.0.txt")) 10 | ThisBuild / organization := "org.rhbutani" 11 | ThisBuild / version := Versions.sparksqlMacrosVersion 12 | 13 | // from https://www.scala-sbt.org/1.x/docs/Cached-Resolution.html 14 | // added to commonSettings 15 | // ThisBuild / updateOptions := updateOptions.value.withLatestSnapshots(false) 16 | // ThisBuild / updateOptions := updateOptions.value.withCachedResolution(true) 17 | 18 | Global / resolvers ++= Seq( 19 | DefaultMavenRepository, 20 | Resolver.sonatypeRepo("public"), 21 | "Apache snapshots repo" at "https://repository.apache.org/content/groups/snapshots/") 22 | 23 | lazy val commonSettings = Seq( 24 | updateOptions := updateOptions.value.withLatestSnapshots(false), 25 | updateOptions := updateOptions.value.withCachedResolution(true), 26 | javaOptions := Seq( 27 | "-Xms1g", 28 | "-Xmx3g", 29 | "-Duser.timezone=UTC", 30 | "-Dscalac.patmat.analysisBudget=512", 31 | "-XX:MaxPermSize=256M", 32 | "-Xrunjdwp:transport=dt_socket,address=5005,server=y,suspend=n"), 33 | scalacOptions ++= Seq("-target:jvm-1.8", "-feature", "-deprecation"), 34 | licenses := Seq("Apache License, Version 2.0" -> 35 | url("http://www.apache.org/licenses/LICENSE-2.0") 36 | ), 37 | homepage := Some(url("https://github.com/hbutani/spark-sql-macros")), 38 | test in assembly := {}, 39 | fork in Test := true, 40 | parallelExecution in Test := false, 41 | libraryDependencies ++= (scala.dependencies ++ 42 | spark.dependencies ++ 43 | utils.dependencies ++ 44 | test_infra.dependencies), 45 | excludeDependencies ++= Seq(ExclusionRule("org.apache.calcite.avatica")) 46 | ) 47 | 48 | lazy val macros = project 49 | .in(file("macros")) 50 | .disablePlugins(AssemblyPlugin) 51 | .settings(commonSettings: _*) 52 | .settings(libraryDependencies ++= scala.dependencies) 53 | 54 | lazy val sql = project 55 | .in(file("sql")) 56 | .aggregate(macros) 57 | .settings(commonSettings: _*) 58 | .settings(Assembly.assemblySettings: _*) 59 | .settings( 60 | name := "spark-sql-macros", 61 | assemblyJarName in assembly := s"${name.value}_${scalaVersion.value}_${version.value}.jar" 62 | ). 63 | dependsOn(macros % "compile->compile;test->test") 64 | 65 | 66 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Arithmetics.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 21 | import org.apache.spark.sql.types.IntegralType 22 | 23 | trait Arithmetics { self : ExprTranslator => 24 | 25 | import macroUniverse._ 26 | 27 | object BasicArith { 28 | def unapply(t: mTree): Option[sparkexpr.Expression] = 29 | t match { 30 | /** Unfortunately cannot pattern match like this, see note on [[ExprBuilders]] 31 | case q"${l : sparkexpr.Expression} + ${r: sparkexpr.Expression }" => 32 | Some(sparkexpr.Add(l, r)) 33 | */ 34 | case q"$lT + $rT" => 35 | for ((l, r) <- binaryArgs(lT, rT)) 36 | yield sparkexpr.Add(l, r) 37 | case q"$lT - $rT" => 38 | for ((l, r) <- binaryArgs(lT, rT)) 39 | yield sparkexpr.Subtract(l, r) 40 | case q"$lT * $rT" => 41 | for ((l, r) <- binaryArgs(lT, rT)) 42 | yield sparkexpr.Multiply(l, r) 43 | case q"$lT / $rT" => 44 | for ((l, r) <- binaryArgs(lT, rT)) 45 | yield 46 | if ( l.dataType.isInstanceOf[IntegralType] ) { 47 | sparkexpr.IntegralDivide(l, r) 48 | } else { 49 | sparkexpr.Divide(l, r) 50 | } 51 | case q"$lT % $rT" => 52 | for ((l, r) <- binaryArgs(lT, rT)) 53 | yield sparkexpr.Remainder(l, r) 54 | case _ => None 55 | } 56 | } 57 | 58 | object JavaMathFuncs { 59 | val mathCompanion = macroUniverse.typeOf[java.lang.Math].companion 60 | val absFuncs = mathCompanion.decl(TermName("abs")) 61 | 62 | def unapply(t: mTree): Option[sparkexpr.Expression] = 63 | t match { 64 | case q"$id(..$args)" if args.size == 1 && absFuncs.alternatives.contains(id.symbol) => 65 | for ( 66 | c <- CatalystExpression.unapply(args(0).asInstanceOf[mTree]) 67 | ) yield sparkexpr.Abs(c) 68 | case _ => None 69 | } 70 | } 71 | 72 | } 73 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/CollectionUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import java.sql.{Date, Timestamp} 20 | 21 | import org.apache.spark.unsafe.types.CalendarInterval 22 | 23 | // scalastyle:off 24 | 25 | object CollectionUtils { 26 | 27 | def mapEntries[K,V](m : Map[K, V]) : Array[(K,V)] = ??? 28 | def mapFromEntries[K,V](arr : Array[(K,V)]) : Map[K, V] = ??? 29 | 30 | def sortArray[T](arr : Array[T], asc : Boolean) : Array[T] = ??? 31 | def shuffleArray[T](arr : Array[T]) : Array[T] = ??? 32 | def shuffleArray[T](arr : Array[T], randomSeed: Long) : Array[T] = ??? 33 | def overlapArrays[T](lArr : Array[T], rArr : Array[T]) : Array[T] = ??? 34 | def positionArray[T](lArr : Array[T], elem : T) : Long = ??? 35 | def sequence[T : Integral](start : T, stop : T, step : T) : Array[T] = ??? 36 | def date_sequence(start : Date, stop : Date, step : CalendarInterval) : Array[Date] = ??? 37 | def timestamp_sequence(start : Timestamp, stop : Timestamp, step : CalendarInterval) : Array[Timestamp] = ??? 38 | def removeArray[T](arr : Array[T], elem: T) : Array[T] = ??? 39 | def exceptArray[T](lArr : Array[T], rArr : Array[T]) : Array[T] = ??? 40 | def mapKeys[K, V](map : Map[K,V]) : Array[K] = ??? 41 | def mapValues[K, V](map : Map[K,V]) : Array[V] = ??? 42 | } 43 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Collections.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 21 | import org.apache.spark.sql.types.{ArrayType, MapType} 22 | 23 | trait Collections { 24 | self: ExprBuilders with ExprTranslator => 25 | 26 | import macroUniverse._ 27 | 28 | object CollectionTrees { 29 | val arrObjTyp = typeOf[Array.type] 30 | val arrOpsTyp = typeOf[scala.collection.mutable.ArrayOps[_]] 31 | val mapObjTyp = typeOf[Map.type] 32 | val mapTyp = typeOf[Map[_, _]] 33 | 34 | def unapply(t: mTree): Option[sparkexpr.Expression] = 35 | Option(t match { 36 | case CollectionConstruct(e) => e 37 | case CollectionApply(e) => e 38 | case ArrayUnref(e) => e 39 | case ArrayUnwrap(e) => e 40 | case CollectionUtilsFunctions(e) => e 41 | case CollectionFunctions(e) => e 42 | case _ => null 43 | }) 44 | 45 | object ArrayUnwrap { 46 | val wrapRefArraySym = predefTyp.member(TermName("wrapRefArray")).alternatives 47 | val wrapIntArraySym = predefTyp.member(TermName("wrapIntArray")).alternatives 48 | val wrapDoubleArraySym = predefTyp.member(TermName("wrapDoubleArray")).alternatives 49 | val wrapLongArraySym = predefTyp.member(TermName("wrapLongArray")).alternatives 50 | val wrapFloatArraySym = predefTyp.member(TermName("wrapFloatArray")).alternatives 51 | val wrapCharArraySym = predefTyp.member(TermName("wrapCharArray")).alternatives 52 | val wrapByteArraySym = predefTyp.member(TermName("wrapByteArray")).alternatives 53 | val wrapShortArraySym = predefTyp.member(TermName("wrapShortArray")).alternatives 54 | val wrapBooleanArraySym = predefTyp.member(TermName("wrapBooleanArray")).alternatives 55 | 56 | val wrapArraySyms = (wrapRefArraySym ++ wrapIntArraySym ++ wrapDoubleArraySym ++ 57 | wrapLongArraySym ++ wrapFloatArraySym ++ wrapCharArraySym ++ 58 | wrapByteArraySym ++ wrapShortArraySym ++ wrapBooleanArraySym).toSet 59 | 60 | def unapply(t: mTree): Option[sparkexpr.Expression] = t match { 61 | case ModuleMethodCall(id, args, args2) if wrapArraySyms.contains(t.symbol) => 62 | CatalystExpression.unapply(args(0)) 63 | case _ => None 64 | } 65 | } 66 | 67 | object ArrayUnref { 68 | val refArrayOpsSym = predefTyp.member(TermName("refArrayOps")).alternatives 69 | val intArrayOpsSym = predefTyp.member(TermName("intArrayOps")).alternatives 70 | val doubleArrayOpsSym = predefTyp.member(TermName("doubleArrayOps")).alternatives 71 | val longArrayOpsSym = predefTyp.member(TermName("longArrayOps")).alternatives 72 | val floatArrayOpsSym = predefTyp.member(TermName("floatArrayOps")).alternatives 73 | val charArrayOpsSym = predefTyp.member(TermName("charArrayOps")).alternatives 74 | val byteArrayOpsSym = predefTyp.member(TermName("byteArrayOps")).alternatives 75 | val shortArrayOpsSym = predefTyp.member(TermName("shortArrayOps")).alternatives 76 | val booleanArrayOpsSym = predefTyp.member(TermName("booleanArrayOps")).alternatives 77 | 78 | val arrayOpsSyms = (refArrayOpsSym ++ intArrayOpsSym ++ doubleArrayOpsSym ++ 79 | longArrayOpsSym ++ floatArrayOpsSym ++ charArrayOpsSym ++ 80 | byteArrayOpsSym ++ shortArrayOpsSym ++ booleanArrayOpsSym).toSet 81 | 82 | def unapply(t: mTree): Option[sparkexpr.Expression] = t match { 83 | case ModuleMethodCall(id, args, args2) if arrayOpsSyms.contains(t.symbol) => 84 | CatalystExpression.unapply(args(0)) 85 | case _ => None 86 | } 87 | } 88 | 89 | object CollectionFunctions { 90 | val arrSizeSym = arrOpsTyp.member(TermName("size")).alternatives 91 | val arrZipSym = arrOpsTyp.member(TermName("zip")).alternatives 92 | val arrMkStringSym = 
arrOpsTyp.member(TermName("mkString")).alternatives 93 | val arrMinSym = arrOpsTyp.member(TermName("min")).alternatives 94 | val arrMaxSym = arrOpsTyp.member(TermName("max")).alternatives 95 | val arrReverseSym = arrOpsTyp.member(TermName("reverse")).alternatives 96 | val arrContainsSym = arrOpsTyp.member(TermName("contains")).alternatives 97 | val arrSliceSym = arrOpsTyp.member(TermName("slice")).alternatives 98 | val arrPlusPlusSym = arrOpsTyp.member(TermName("$plus$plus")).alternatives 99 | val arrFlattenSym = arrOpsTyp.member(TermName("flatten")).alternatives 100 | val arrIntersectSym = arrOpsTyp.member(TermName("intersect")).alternatives 101 | val arrDistinctSym = arrOpsTyp.member(TermName("distinct")).alternatives 102 | val arrFillSym = arrObjTyp.member(TermName("fill")).alternatives 103 | val mapPlusPlusSym = mapTyp.member(TermName("$plus$plus")).alternatives 104 | 105 | def unapply(t: mTree): Option[sparkexpr.Expression] = 106 | t match { 107 | case InstanceMethodCall(elem, args1, args2) => 108 | if (arrSizeSym.contains(t.symbol)) { 109 | for (arrE <- CatalystExpression.unapply(elem)) yield sparkexpr.Size(arrE) 110 | } else if (arrZipSym.contains(t.symbol)) { 111 | for (lArr <- CatalystExpression.unapply(elem); 112 | rArr <- CatalystExpression.unapply(args1(0)) 113 | ) yield sparkexpr.ArraysZip(Seq(lArr, rArr)) 114 | } else if (arrMkStringSym.contains(t.symbol) && args1.size == 0) { 115 | for (arrE <- CatalystExpression.unapply(elem)) yield 116 | sparkexpr.ArrayJoin(arrE, sparkexpr.Literal(" "), None) 117 | } else if (arrMkStringSym.contains(t.symbol) && args1.size == 1) { 118 | for (arrE <- CatalystExpression.unapply(elem); 119 | sepE <- CatalystExpression.unapply(args1(0))) yield 120 | sparkexpr.ArrayJoin(arrE, sepE, None) 121 | } else if (arrMinSym.contains(t.symbol)) { 122 | for (arrE <- CatalystExpression.unapply(elem)) yield sparkexpr.ArrayMin(arrE) 123 | } else if (arrMaxSym.contains(t.symbol)) { 124 | for (arrE <- CatalystExpression.unapply(elem)) yield sparkexpr.ArrayMax(arrE) 125 | } else if (arrReverseSym.contains(t.symbol)) { 126 | for (arrE <- CatalystExpression.unapply(elem)) yield sparkexpr.Reverse(arrE) 127 | } else if (arrContainsSym.contains(t.symbol)) { 128 | for (arrE <- CatalystExpression.unapply(elem); 129 | elemE <- CatalystExpression.unapply(args1(0))) yield 130 | sparkexpr.ArrayContains(arrE, elemE) 131 | } else if (arrSliceSym.contains(t.symbol)) { 132 | for (arrE <- CatalystExpression.unapply(elem); 133 | fromE <- CatalystExpression.unapply(args1(0)); 134 | untilE <- CatalystExpression.unapply(args1(0)) 135 | ) yield 136 | sparkexpr.Slice( 137 | arrE, 138 | fromE, 139 | sparkexpr.Add(sparkexpr.Subtract(untilE, fromE), sparkexpr.Literal(1)) 140 | ) 141 | } else if (arrPlusPlusSym.contains(t.symbol)) { 142 | for (lArr <- CatalystExpression.unapply(elem); 143 | rArr <- CatalystExpression.unapply(args1(0)) 144 | ) yield sparkexpr.Concat(Seq(lArr, rArr)) 145 | } else if (arrFlattenSym.contains(t.symbol)) { 146 | for (arrE <- CatalystExpression.unapply(elem)) yield sparkexpr.Flatten(arrE) 147 | } else if (arrIntersectSym.contains(t.symbol)) { 148 | for (lArr <- CatalystExpression.unapply(elem); 149 | rArr <- CatalystExpression.unapply(args1(0)) 150 | ) yield sparkexpr.ArrayIntersect(lArr, rArr) 151 | } else if (arrDistinctSym.contains(t.symbol)) { 152 | for (arrE <- CatalystExpression.unapply(elem)) yield sparkexpr.ArrayDistinct(arrE) 153 | } else if (mapPlusPlusSym.contains(t.symbol) && args1.size == 1) { 154 | for (mapE <- CatalystExpression.unapply(elem); 
155 | oMapE <- CatalystExpression.unapply(args1(0))) yield 156 | sparkexpr.MapConcat(Seq(mapE, oMapE)) 157 | } else None 158 | case ModuleMethodCall(id, args, args2) => 159 | if (arrFillSym.contains(id)) { 160 | for (countE <- CatalystExpression.unapply(args(0)); 161 | elemE <- CatalystExpression.unapply(args2(0)) 162 | ) yield sparkexpr.ArrayRepeat(elemE, countE) 163 | } else None 164 | case _ => None 165 | } 166 | } 167 | 168 | object CollectionUtilsFunctions { 169 | val collUtilsTyp = typeOf[CollectionUtils.type] 170 | 171 | // TODO: MapKeys, MapValues <- CollectUtils; MapConcat 172 | val mapEntriesSym = collUtilsTyp.member(TermName("mapEntries")).alternatives 173 | val mapFromEntriesSym = collUtilsTyp.member(TermName("mapFromEntries")).alternatives 174 | 175 | val sortArraySym = collUtilsTyp.member(TermName("sortArray")).alternatives 176 | val shuffleSym = collUtilsTyp.member(TermName("shuffleArray")).alternatives 177 | val overlapSym = collUtilsTyp.member(TermName("overlapArrays")).alternatives 178 | val positionSym = collUtilsTyp.member(TermName("positionArray")).alternatives 179 | val sequenceSym = collUtilsTyp.member(TermName("sequence")).alternatives 180 | val date_sequenceSym = collUtilsTyp.member(TermName("date_sequence")).alternatives 181 | val timestamp_sequenceSym = collUtilsTyp.member(TermName("timestamp_sequence")).alternatives 182 | val removeSym = collUtilsTyp.member(TermName("removeArray")).alternatives 183 | val exceptSym = collUtilsTyp.member(TermName("exceptArray")).alternatives 184 | val mapKeysSym = collUtilsTyp.member(TermName("mapKeys")).alternatives 185 | val mapValuesSym = collUtilsTyp.member(TermName("mapValues")).alternatives 186 | 187 | def unapply(t: mTree): Option[sparkexpr.Expression] = 188 | t match { 189 | case ModuleMethodCall(id, args, args2) => 190 | if (mapEntriesSym.contains(id.symbol)) { 191 | for ( 192 | mE <- CatalystExpression.unapply(args(0)) 193 | ) yield sparkexpr.MapEntries(mE) 194 | } else if (mapFromEntriesSym.contains(id.symbol)) { 195 | for ( 196 | aE <- CatalystExpression.unapply(args(0)) 197 | ) yield sparkexpr.MapFromEntries(aE) 198 | } else if (sortArraySym.contains(id.symbol)) { 199 | for ( 200 | aE <- CatalystExpression.unapply(args(0)); 201 | sE <- CatalystExpression.unapply(args(1)) 202 | ) yield sparkexpr.SortArray(aE, sE) 203 | } else if (shuffleSym.contains(id.symbol)) { 204 | (for ( 205 | argEs <- CatalystExpressions.unapplySeq(args) 206 | ) yield if (args.size == 1) { 207 | Some(sparkexpr.Shuffle(argEs.head)) 208 | } else if (args.last.isInstanceOf[sparkexpr.Literal]) { 209 | Some(sparkexpr.Shuffle(argEs.head, 210 | Some(argEs.last.asInstanceOf[sparkexpr.Literal].value.asInstanceOf[Long]) 211 | )) 212 | } else None).flatten 213 | } else if (overlapSym.contains(id.symbol)) { 214 | for ( 215 | lArr <- CatalystExpression.unapply(args(0)); 216 | rArr <- CatalystExpression.unapply(args(1)) 217 | ) yield sparkexpr.ArraysOverlap(lArr, rArr) 218 | } else if (positionSym.contains(id.symbol)) { 219 | for ( 220 | arrE <- CatalystExpression.unapply(args(0)); 221 | elemE <- CatalystExpression.unapply(args(1)) 222 | ) yield sparkexpr.ArrayPosition(arrE, elemE) 223 | } else if (sequenceSym.contains(id.symbol)) { 224 | for ( 225 | startE <- CatalystExpression.unapply(args(0)); 226 | stopE <- CatalystExpression.unapply(args(1)); 227 | stepE <- CatalystExpression.unapply(args(2)) 228 | ) yield sparkexpr.Sequence(startE, stopE, Some(stepE), None) 229 | } else if (date_sequenceSym.contains(id.symbol)) { 230 | for ( 231 | startE <- 
CatalystExpression.unapply(args(0)); 232 | stopE <- CatalystExpression.unapply(args(1)); 233 | stepE <- CatalystExpression.unapply(args(2)) 234 | ) yield sparkexpr.Sequence(startE, stopE, Some(stepE), None) 235 | } else if (timestamp_sequenceSym.contains(id.symbol)) { 236 | for ( 237 | startE <- CatalystExpression.unapply(args(0)); 238 | stopE <- CatalystExpression.unapply(args(1)); 239 | stepE <- CatalystExpression.unapply(args(2)) 240 | ) yield sparkexpr.Sequence(startE, stopE, Some(stepE), None) 241 | } else if (removeSym.contains(id.symbol)) { 242 | for ( 243 | arrE <- CatalystExpression.unapply(args(0)); 244 | elemE <- CatalystExpression.unapply(args(1)) 245 | ) yield sparkexpr.ArrayRemove(arrE, elemE) 246 | } else if (exceptSym.contains(id.symbol)) { 247 | for ( 248 | lArr <- CatalystExpression.unapply(args(0)); 249 | rArr <- CatalystExpression.unapply(args(1)) 250 | ) yield sparkexpr.ArrayExcept(lArr, rArr) 251 | } else if (mapKeysSym.contains(id.symbol)) { 252 | for ( 253 | mapE <- CatalystExpression.unapply(args(0)) 254 | ) yield sparkexpr.MapKeys(mapE) 255 | } else if (mapValuesSym.contains(id.symbol)) { 256 | for ( 257 | mapE <- CatalystExpression.unapply(args(0)) 258 | ) yield sparkexpr.MapValues(mapE) 259 | } else None 260 | case _ => None 261 | } 262 | } 263 | 264 | object CollName { 265 | def unapply(t: mTree): 266 | Option[mTermName] = t match { 267 | case Select(Ident(collNm), TermName("apply")) if collNm.isTermName => 268 | Some(collNm.toTermName) 269 | case _ => None 270 | } 271 | } 272 | 273 | object GetEntryExpr { 274 | def unapply(vInfo: ValInfo, idxExpr: sparkexpr.Expression): 275 | Option[sparkexpr.Expression] = vInfo.typInfo.catalystType match { 276 | case a: ArrayType => Some(sparkexpr.GetArrayItem(vInfo.rhsExpr, idxExpr)) 277 | case m: MapType => Some(sparkexpr.GetMapValue(vInfo.rhsExpr, idxExpr)) 278 | case _ => None 279 | } 280 | } 281 | 282 | object CollectionApply { 283 | def unapply(t: mTree): Option[sparkexpr.Expression] = 284 | t match { 285 | case q"$id(..$args)" if args.size == 1 => 286 | for ( 287 | collNm <- CollName.unapply(id); 288 | vInfo <- scope.get(collNm); 289 | idxExpr <- CatalystExpression.unapply(args(0).asInstanceOf[mTree]); 290 | valExpr <- GetEntryExpr.unapply(vInfo, idxExpr) 291 | ) yield valExpr 292 | case _ => None 293 | } 294 | } 295 | 296 | object CollectionConstruct { 297 | val arrApplySym = arrObjTyp.decl(TermName("apply")) 298 | val mapApplySym = mapObjTyp.member(TermName("apply")) 299 | 300 | def unapply(t: mTree): Option[sparkexpr.Expression] = 301 | t match { 302 | case q"$id(..$args)" if arrApplySym.alternatives.contains(id.symbol) => 303 | for ( 304 | entries <- CatalystExpressions.unapplySeq(args) 305 | ) yield sparkexpr.CreateArray(entries) 306 | case q"$id(..$args)(..$implArgs)" if arrApplySym.alternatives.contains(id.symbol) => 307 | for ( 308 | entries <- CatalystExpressions.unapplySeq(args) 309 | ) yield sparkexpr.CreateArray(entries) 310 | case q"$id(..$args)" if mapApplySym.alternatives.contains(id.symbol) => 311 | for ( 312 | entries <- CatalystExpressions.unapplySeq(args) 313 | if entries.forall(_.isInstanceOf[sparkexpr.CreateNamedStruct]) 314 | ) yield { 315 | val mEntries = entries.flatMap(_.asInstanceOf[sparkexpr.CreateNamedStruct].valExprs) 316 | sparkexpr.CreateMap(mEntries) 317 | } 318 | case _ => None 319 | } 320 | } 321 | 322 | } 323 | 324 | } 325 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Conditionals.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 20 | import org.apache.spark.sql.types.BooleanType 21 | 22 | /** 23 | * Translation for logical operators(AND, OR, NOT), for comparison operators 24 | * (`>, >=, <, <=, ==, !=`), string predicate functions(startsWith, endsWith, contains) 25 | * the `if` statement and the `case` statement. 26 | * 27 | * Support for case statements is limited: 28 | * - case pattern must be `cq"$pat => $expr2"`, so no `if` in case 29 | * - the pattern must be a literal for constructor pattern like `(a,b)`, `Point(1,2)` etc. 30 | * 31 | */ 32 | trait Conditionals { self: ExprTranslator => 33 | 34 | import macroUniverse._ 35 | 36 | val predUtilsTyp = typeOf[PredicateUtils.type ] 37 | val any2PredSym = predUtilsTyp.member(TermName("any2Preds")) 38 | 39 | val anyWithPredsTyp = typeOf[PredicateUtils.AnyWithPreds] 40 | 41 | val nullCheckMethods = Seq("is_null", "is_not_null") 42 | val null_safe_eq = "null_safe_eq" 43 | val inMethods = Seq("in", "not_in") 44 | 45 | private def isStrMethodCall(t : mTree) : Boolean = 46 | startsWithSyms.contains(t.symbol) || endsWithSym == t.symbol || containsSym == t.symbol 47 | 48 | private def boolExpr(expr : sparkexpr.Expression) : Option[sparkexpr.Expression] = 49 | expr.dataType match { 50 | case BooleanType => Some(expr) 51 | case dt if sparkexpr.Cast.canCast(dt, BooleanType) => Some(sparkexpr.Cast(expr, BooleanType)) 52 | case _ => None 53 | } 54 | 55 | private def boolExprs(lTree : mTree, 56 | rTree : mTree) : Option[(sparkexpr.Expression, sparkexpr.Expression)] = { 57 | for ( 58 | (lexpr, rexpr) <- binaryArgs(lTree, rTree); 59 | lboolExpr <- boolExpr(lexpr); 60 | rboolExpr <- boolExpr(rexpr) 61 | ) yield (lboolExpr, rboolExpr) 62 | } 63 | 64 | private def compareExprs(lTree : mTree, 65 | rTree : mTree) : Option[(sparkexpr.Expression, sparkexpr.Expression)] = { 66 | for ( 67 | (lexpr, rexpr) <- binaryArgs(lTree, rTree) 68 | if sparkexpr.RowOrdering.isOrderable(lexpr.dataType) 69 | ) yield (lexpr, rexpr) 70 | } 71 | 72 | object Predicates { 73 | 74 | private def nm(t : TermName) = t.decodedName.toString 75 | 76 | def unapply(t: mTree): Option[sparkexpr.Expression] = 77 | t match { 78 | case q"!$cond" => 79 | for ( 80 | expr <- CatalystExpression.unapply(cond); 81 | boolExpr <- boolExpr(expr) 82 | ) yield boolExpr 83 | case q"$lCond && $rCond" => 84 | for ( 85 | (l, r) <- boolExprs(lCond, rCond) 86 | ) yield sparkexpr.And(l, r) 87 | case q"$lCond || $rCond" => 88 | for ( 89 | (l, r) <- boolExprs(lCond, rCond) 90 | ) yield sparkexpr.Or(l, r) 91 | 
case q"$lT > $rT" => 92 | for ( 93 | (l, r) <- compareExprs(lT, rT) 94 | ) yield sparkexpr.GreaterThan(l, r) 95 | case q"$lT >= $rT" => 96 | for ( 97 | (l, r) <- compareExprs(lT, rT) 98 | ) yield sparkexpr.GreaterThanOrEqual(l, r) 99 | case q"$lT < $rT" => 100 | for ( 101 | (l, r) <- compareExprs(lT, rT) 102 | ) yield sparkexpr.LessThan(l, r) 103 | case q"$lT <= $rT" => 104 | for ( 105 | (l, r) <- compareExprs(lT, rT) 106 | ) yield sparkexpr.LessThanOrEqual(l, r) 107 | case q"$lT == $rT" => 108 | for ( 109 | (l, r) <- compareExprs(lT, rT) 110 | ) yield sparkexpr.EqualTo(l, r) 111 | case q"$lT != $rT" => 112 | for ( 113 | (l, r) <- compareExprs(lT, rT) 114 | ) yield sparkexpr.Not(sparkexpr.EqualTo(l, r)) 115 | case q"$id(..$args).$m" 116 | if id.symbol == any2PredSym && args.size == 1 && 117 | nullCheckMethods.contains(nm(m)) => 118 | for ( 119 | expr <- CatalystExpression.unapply(args(0)) 120 | ) yield { 121 | if (nm(m) == "is_null") { 122 | sparkexpr.IsNull(expr) 123 | } else { 124 | sparkexpr.IsNotNull(expr) 125 | } 126 | } 127 | case q"$id(..$args1).$m(..$args2)" 128 | if id.symbol == any2PredSym && args1.size == 1 && 129 | inMethods.contains(nm(m)) || nm(m) == null_safe_eq => 130 | for ( 131 | lexpr <- CatalystExpression.unapply(args1(0)); 132 | inExprs <- CatalystExpressions.unapplySeq(args2) 133 | ) yield { 134 | if (nm(m) == "in") { 135 | sparkexpr.In(lexpr, inExprs) 136 | } else if (nm(m) == "not_in") { 137 | sparkexpr.Not(sparkexpr.In(lexpr, inExprs)) 138 | } else { 139 | sparkexpr.EqualNullSafe(lexpr, inExprs.head) 140 | } 141 | } 142 | case q"$id(..$args)" if isStrMethodCall(id) => 143 | id match { 144 | case q"$l.$m" => 145 | for ( 146 | lexpr <- CatalystExpression.unapply(l); 147 | rexpr <- CatalystExpression.unapply(args(0)) 148 | ) yield { 149 | if ( nm(m) == "startsWith" ) { 150 | sparkexpr.StartsWith(lexpr, rexpr) 151 | } else if ( nm(m) == "endsWith" ) { 152 | sparkexpr.EndsWith(lexpr, rexpr) 153 | } else { 154 | sparkexpr.Contains(lexpr, rexpr) 155 | } 156 | } 157 | case _ => None 158 | } 159 | case _ => None 160 | } 161 | } 162 | 163 | object IFCase { 164 | 165 | private def caseEntries(caseTrees : Seq[mTree]) : Option[Seq[(mTree, mTree)]] = 166 | SparkSQLMacroUtils.sequence( 167 | caseTrees map { 168 | case cq"$pat => $expr2" => Some((pat, expr2)) 169 | case _ => None 170 | } 171 | ) 172 | 173 | private def caseExpr(caseEntry : (mTree, mTree)) : Option[Seq[sparkexpr.Expression]] 174 | = caseEntry match { 175 | case (pq"_", t) => CatalystExpressions.unapplySeq(Seq(t)) 176 | case (t1, t2) => CatalystExpressions.unapplySeq(Seq(t1, t2)) 177 | } 178 | 179 | def unapply(t: mTree): Option[sparkexpr.Expression] = t match { 180 | case q"if ($cond) $thenp else $elsep" => 181 | for ( 182 | condExpr <- CatalystExpression.unapply(cond); 183 | thenExpr <- CatalystExpression.unapply(thenp); 184 | elseExpr <- CatalystExpression.unapply(elsep) 185 | ) yield sparkexpr.If(condExpr, thenExpr, elseExpr) 186 | case q"$expr match { case ..$cases } " => 187 | for ( 188 | matchExpr <- CatalystExpression.unapply((expr)); 189 | caseEntries <- caseEntries(cases); 190 | exprs <- SparkSQLMacroUtils.sequence(caseEntries.map(caseExpr)) 191 | ) yield sparkexpr.CaseKeyWhen(matchExpr, exprs.flatten) 192 | 193 | case _ => None 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/DateTime.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * 
Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import java.time.ZoneId 20 | 21 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr, CatalystTypeConverters} 22 | import org.apache.spark.sql.types._ 23 | import org.apache.spark.unsafe.types.CalendarInterval 24 | 25 | 26 | /** 27 | * Handle translation of ''Date, Timestamp and Interval'' values. We support 28 | * translation of functions from [[MacroDateTimeUtils]] module. 29 | * 30 | * ''Construction support:'' 31 | * - `new java.sql.Date(epochVal)` is translated to 32 | * `Cast(MillisToTimestamp(spark_expr(epochVal)), DateType)` 33 | * - `new java.sql.Timestamp(epochVal` is translated to `MillisToTimestamp(spark_expr(epochVal))` 34 | * - `java.time.LocalDate.of(yr, month, dayOfMonth)` is translated to 35 | * `MakeDate(spark_expr(yr), spark_expr(month), spark_expr(dayOfMonth))` 36 | * - we don't support translation of construction of `java.time.Instant` 37 | * because there is no spark expression to construct a Date value from 38 | * `long epochSecond, int nanos`. 39 | * 40 | * ''Certain Arguments must be macro compile time static values:'' 41 | * - there is no spark expression to construct a [[CalendarInterval]], so these 42 | * must be static values. 43 | * - [[ZoneId]] must be a static value. 44 | */ 45 | trait DateTime { self: ExprTranslator => 46 | 47 | import macroUniverse._ 48 | 49 | val dateTyp = typeOf[java.sql.Date] 50 | val localDateTyp = typeOf[java.time.LocalDate] 51 | val timestampTyp = typeOf[java.sql.Timestamp] 52 | val instantTyp = typeOf[java.time.Instant] 53 | val calIntervalTyp = typeOf[CalendarInterval] 54 | val zoneIdTyp = typeOf[ZoneId] 55 | 56 | val dateConstructors = dateTyp.member(termNames.CONSTRUCTOR).alternatives 57 | val timeStampConstructors = timestampTyp.member(termNames.CONSTRUCTOR).alternatives 58 | val localDteOfMethod = localDateTyp.companion.member(TermName("of")). 
59 | alternatives.filter(m => m.asMethod.paramLists(0)(1).typeSignature =:= typeOf[Int]) 60 | 61 | val dateTimeUtils = typeOf[DateTimeUtils.type] 62 | 63 | val fromJavaDateSym = dateTimeUtils.member(TermName("fromJavaDate")) 64 | val fromJavaTimestamp = dateTimeUtils.member(TermName("fromJavaTimestamp")) 65 | val instantToMicrosSym = dateTimeUtils.member(TermName("instantToMicros")) 66 | val localDateToDaysSym = dateTimeUtils.member(TermName("localDateToDays")) 67 | 68 | val microsToDaysSym = dateTimeUtils.member(TermName("microsToDays")) 69 | val daysToMicrosSym = dateTimeUtils.member(TermName("daysToMicros")) 70 | val millisToMicrosSym = dateTimeUtils.member(TermName("millisToMicros")) 71 | 72 | val stringToTimestampSym = dateTimeUtils.member(TermName("stringToTimestamp")) 73 | val stringToTimestampAnsiSym = dateTimeUtils.member(TermName("stringToTimestampAnsi")) 74 | val stringToDateSym = dateTimeUtils.member(TermName("stringToDate")) 75 | 76 | val getHoursSym = dateTimeUtils.member(TermName("getHours")) 77 | val getMinutesSym = dateTimeUtils.member(TermName("getMinutes")) 78 | val getSecondsSym = dateTimeUtils.member(TermName("getSeconds")) 79 | val getSecondsWithFractionSym = dateTimeUtils.member(TermName("getSecondsWithFraction")) 80 | // val getMicrosecondsSym = dateTimeUtils.member(TermName("getMicroseconds")) 81 | val getDayInYearSym = dateTimeUtils.member(TermName("getDayInYear")) 82 | val getYearSym = dateTimeUtils.member(TermName("getYear")) 83 | val getWeekBasedYearSym = dateTimeUtils.member(TermName("getWeekBasedYear")) 84 | val getQuarterSym = dateTimeUtils.member(TermName("getQuarter")) 85 | val getMonthSym = dateTimeUtils.member(TermName("getMonth")) 86 | val getDayOfMonthSym = dateTimeUtils.member(TermName("getDayOfMonth")) 87 | val getDayOfWeekSym = dateTimeUtils.member(TermName("getDayOfWeek")) 88 | val getWeekDaySym = dateTimeUtils.member(TermName("getWeekDay")) 89 | val getWeekOfYearSym = dateTimeUtils.member(TermName("getWeekOfYear")) 90 | 91 | val dateAddMonths = dateTimeUtils.member(TermName("dateAddMonths")) 92 | val timestampAddIntervalSym = dateTimeUtils.member(TermName("timestampAddInterval")) 93 | val dateAddIntervalSym = dateTimeUtils.member(TermName("dateAddInterval")) 94 | val monthsBetweenSym = dateTimeUtils.member(TermName("monthsBetween")) 95 | val getNextDateForDayOfWeekSym = dateTimeUtils.member(TermName("getNextDateForDayOfWeek")) 96 | val getLastDayOfMonth = dateTimeUtils.member(TermName("getLastDayOfMonth")) 97 | val truncDateSym = dateTimeUtils.member(TermName("truncDate")) 98 | val truncTimestampSym = dateTimeUtils.member(TermName("truncTimestamp")) 99 | val fromUTCTimeSym = dateTimeUtils.member(TermName("fromUTCTimeSym")) 100 | val toUTCTimeSym = dateTimeUtils.member(TermName("toUTCTime")) 101 | val currentTimestampSym = dateTimeUtils.member(TermName("currentTimestamp")) 102 | val currentDateSym = dateTimeUtils.member(TermName("currentDate")) 103 | 104 | private val dayOfWeekMap = { 105 | // from DateTimeUtils 106 | val SUNDAY = 3 107 | val MONDAY = 4 108 | val TUESDAY = 5 109 | val WEDNESDAY = 6 110 | val THURSDAY = 0 111 | val FRIDAY = 1 112 | val SATURDAY = 2 113 | sparkexpr.Literal( 114 | CatalystTypeConverters.convertToCatalyst( 115 | Map(SUNDAY -> "SU", MONDAY -> "MO", TUESDAY -> "TU", WEDNESDAY -> "WE", 116 | THURSDAY -> "TH", FRIDAY -> "FR", SATURDAY -> "SA") 117 | ), MapType(IntegerType, StringType) 118 | ) 119 | } 120 | 121 | private val truncLevelMap = { 122 | val TRUNC_TO_MICROSECOND = 0 123 | val TRUNC_TO_MILLISECOND = 1 124 | val 
TRUNC_TO_SECOND = 2 125 | val TRUNC_TO_MINUTE = 3 126 | val TRUNC_TO_HOUR = 4 127 | val TRUNC_TO_DAY = 5 128 | val TRUNC_TO_WEEK = 6 129 | val TRUNC_TO_MONTH = 7 130 | val TRUNC_TO_QUARTER = 8 131 | val TRUNC_TO_YEAR = 9 132 | sparkexpr.Literal( 133 | CatalystTypeConverters.convertToCatalyst( 134 | Map( 135 | TRUNC_TO_MICROSECOND -> "MICROSECOND", 136 | TRUNC_TO_MILLISECOND -> "MILLISECOND", 137 | TRUNC_TO_SECOND -> "SECOND", 138 | TRUNC_TO_MINUTE -> "MINUTE", 139 | TRUNC_TO_HOUR -> "HOUR", 140 | TRUNC_TO_DAY -> "DAY", 141 | TRUNC_TO_WEEK -> "WEEK", 142 | TRUNC_TO_MONTH -> "MON", 143 | TRUNC_TO_QUARTER -> "QUARTER", 144 | TRUNC_TO_YEAR -> "YEAR" 145 | ) 146 | ), MapType(IntegerType, StringType) 147 | ) 148 | } 149 | 150 | object DateTimePatterns { 151 | def unapply(t: mTree): Option[sparkexpr.Expression] = 152 | t match { 153 | case q"new $id(..$args)" if dateConstructors.contains(t.symbol) && args.size == 1 => 154 | for (entries <- CatalystExpressions.unapplySeq(args)) 155 | yield sparkexpr.Cast(sparkexpr.MillisToTimestamp(entries.head), DateType) 156 | case q"new $id(..$args)" if timeStampConstructors.contains(t.symbol) && args.size == 1 => 157 | for (entries <- CatalystExpressions.unapplySeq(args)) 158 | yield sparkexpr.MillisToTimestamp(entries.head) 159 | case q"$id(..$args)" => 160 | id match { 161 | case id if localDteOfMethod.contains(id.symbol) => 162 | for (entries <- CatalystExpressions.unapplySeq(args)) yield 163 | sparkexpr.MakeDate(entries(0), entries(1), entries(2)) 164 | case id if id.symbol == microsToDaysSym => 165 | for (expr <- CatalystExpression.unapply(args.head); 166 | zId <- zoneId(args.tail.head)) 167 | yield sparkexpr.Cast(expr, DateType, Some(zId.toString)) 168 | case id if id.symbol == daysToMicrosSym => 169 | for (expr <- CatalystExpression.unapply(args.head); 170 | zId <- zoneId(args.tail.head)) 171 | yield sparkexpr.Cast(expr, TimestampType, Some(zId.toString)) 172 | /* 173 | * Child is already a internal Date/Timestamp form, so no conversion needed 174 | */ 175 | case id if id.symbol == instantToMicrosSym => 176 | for (entries <- CatalystExpressions.unapplySeq(args)) yield entries.head 177 | case id if id.symbol == fromJavaDateSym => 178 | for (entries <- CatalystExpressions.unapplySeq(args)) yield entries.head 179 | case id if id.symbol == fromJavaTimestamp => 180 | for (entries <- CatalystExpressions.unapplySeq(args)) yield entries.head 181 | case id if id.symbol == localDateToDaysSym => 182 | for (entries <- CatalystExpressions.unapplySeq(args)) yield entries.head 183 | /* 184 | * End no conversion needed 185 | */ 186 | case id if id.symbol == millisToMicrosSym => 187 | for (entries <- CatalystExpressions.unapplySeq(args)) 188 | yield sparkexpr.MillisToTimestamp(entries.head) 189 | case id if id.symbol == stringToTimestampSym => 190 | for (strExpr <- CatalystExpression.unapply(args.head); 191 | zId <- zoneId(args.tail.head)) yield 192 | sparkexpr.objects.WrapOption( 193 | sparkexpr.Cast(strExpr, TimestampType, Some(zId.toString)), 194 | TimestampType 195 | ) 196 | case id if id.symbol == stringToTimestampAnsiSym => 197 | for (strExpr <- CatalystExpression.unapply(args.head); 198 | zId <- zoneId(args.tail.head)) 199 | yield sparkexpr.AnsiCast(strExpr, TimestampType, Some(zId.toString)) 200 | case id if id.symbol == stringToDateSym => 201 | for (strExpr <- CatalystExpression.unapply(args.head); 202 | zId <- zoneId(args.tail.head)) 203 | yield sparkexpr.objects.WrapOption( 204 | sparkexpr.Cast(strExpr, DateType, Some(zId.toString)), 205 | DateType) 206 | 
case id if id.symbol == getHoursSym => 207 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.Hour(expr) 208 | case id if id.symbol == getMinutesSym => 209 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.Minute(expr) 210 | case id if id.symbol == getSecondsSym => 211 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.Second(expr) 212 | case id if id.symbol == getSecondsWithFractionSym => 213 | for (expr <- CatalystExpression.unapply(args.head); 214 | zId <- zoneId(args.tail.head)) 215 | yield sparkexpr.SecondWithFraction(expr, Some(zId.toString)) 216 | case id if id.symbol == getDayInYearSym => 217 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.DayOfYear(expr) 218 | case id if id.symbol == getYearSym => 219 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.Year(expr) 220 | case id if id.symbol == getWeekBasedYearSym => 221 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.YearOfWeek(expr) 222 | case id if id.symbol == getQuarterSym => 223 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.Quarter(expr) 224 | case id if id.symbol == getMonthSym => 225 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.Month(expr) 226 | case id if id.symbol == getDayOfMonthSym => 227 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.DayOfMonth(expr) 228 | case id if id.symbol == getDayOfWeekSym => 229 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.DayOfWeek(expr) 230 | case id if id.symbol == getWeekDaySym => 231 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.WeekDay(expr) 232 | case id if id.symbol == getWeekOfYearSym => 233 | for (expr <- CatalystExpression.unapply(args.head)) yield sparkexpr.WeekOfYear(expr) 234 | case id if id.symbol == dateAddMonths => 235 | for (entries <- CatalystExpressions.unapplySeq(args)) 236 | yield sparkexpr.AddMonths(entries.head, entries.tail.head) 237 | case id if id.symbol == timestampAddIntervalSym => 238 | for (expr <- CatalystExpression.unapply(args(0)); 239 | calInt <- staticCalInterval(args(1), "interval"); 240 | zId <- zoneId(args(2))) 241 | yield 242 | sparkexpr.TimeAdd( 243 | expr, 244 | sparkexpr.Literal( 245 | calInt, 246 | CalendarIntervalType), 247 | Some(zId.toString)) 248 | case id if id.symbol == dateAddIntervalSym => 249 | for (expr <- CatalystExpression.unapply(args(0)); 250 | calInt <- staticCalInterval(args(1), "interval") 251 | ) yield sparkexpr.DateAddInterval(expr, 252 | sparkexpr.Literal(calInt, CalendarIntervalType)) 253 | case id if id.symbol == monthsBetweenSym => 254 | for (micros1 <- CatalystExpression.unapply(args(0)); 255 | micros2 <- CatalystExpression.unapply(args(1)); 256 | roundoff <- CatalystExpression.unapply(args(2)); 257 | zId <- zoneId(args(3)) 258 | ) yield sparkexpr.MonthsBetween(micros1, micros2, roundoff, Some(zId.toString)) 259 | case id if id.symbol == getNextDateForDayOfWeekSym => 260 | for (expr <- CatalystExpression.unapply(args(0)); 261 | dayOfWeek <- CatalystExpression.unapply(args(1)) 262 | ) yield sparkexpr.NextDay(expr, dayOfWeek) 263 | case id if id.symbol == getLastDayOfMonth => 264 | for (expr <- CatalystExpression.unapply(args(0)) 265 | ) yield sparkexpr.LastDay(expr) 266 | case id if id.symbol == truncDateSym => 267 | for (days <- CatalystExpression.unapply(args(0)); 268 | level <- CatalystExpression.unapply(args(1))) yield 269 | sparkexpr.TruncDate(days, level) 270 | case id if id.symbol == 
truncTimestampSym => 271 | for (micros <- CatalystExpression.unapply(args(0)); 272 | level <- CatalystExpression.unapply(args(1)); 273 | zId <- zoneId(args(2)) 274 | ) yield 275 | sparkexpr.TruncTimestamp(level, micros, Some(zId.toString)) 276 | case id if id.symbol == fromUTCTimeSym => 277 | for (expr <- CatalystExpression.unapply(args(0)); 278 | zId <- CatalystExpression.unapply(args(1))) yield 279 | sparkexpr.FromUTCTimestamp(expr, zId) 280 | case id if id.symbol == toUTCTimeSym => 281 | for (expr <- CatalystExpression.unapply(args(0)); 282 | zId <- CatalystExpression.unapply(args(1))) yield 283 | sparkexpr.ToUTCTimestamp(expr, zId) 284 | case id if id.symbol == currentTimestampSym => 285 | Some(sparkexpr.CurrentTimestamp()) 286 | case id if id.symbol == currentDateSym => 287 | for (zId <- zoneId(args(0))) yield 288 | sparkexpr.CurrentDate(Some(zId.toString)) 289 | case _ => None 290 | } 291 | case _ => None 292 | } 293 | } 294 | 295 | private def zoneId(t: mTree): Option[ZoneId] = 296 | staticValue[ZoneId](t, "evaluate ZoneId expression as a static value") 297 | 298 | private def staticIntValue(t: mTree, typStr: String): Option[Int] = 299 | staticValue[Int](t, s"evaluate ${typStr} expression as a static int value") 300 | 301 | private def staticLongValue(t: mTree, typStr: String): Option[Long] = 302 | staticValue[Long](t, s"evaluate ${typStr} expression as a static long value") 303 | 304 | private def staticCalInterval(t: mTree, typStr: String): Option[CalendarInterval] = 305 | staticValue[CalendarInterval](t, 306 | s"evaluate ${typStr} expression as a static cal_interval value" 307 | ) 308 | 309 | } 310 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/DateTimeUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import java.sql.{Date, Timestamp} 20 | import java.time.{DateTimeException, Instant, LocalDate, ZoneId} 21 | 22 | import org.apache.spark.sql.types.Decimal 23 | import org.apache.spark.unsafe.types.CalendarInterval 24 | 25 | // scalastyle:off 26 | /** 27 | * A marker interface for code inside Macros, that is a replacement for functions in 28 | * [[org.apache.spark.sql.catalyst.util.DateTimeUtils]]. Use the 29 | * functions with [[Date]] values instead of [[Int]], [[Timestamp]] 30 | * values instead of [[Long]] and [[String]] values instead of 31 | * [[org.apache.spark.unsafe.types.UTF8String]] values. 
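 * As a hedged illustration of the value-level swap (the internal Spark signature is
 * quoted from memory and may differ slightly), where Spark's
 * `org.apache.spark.sql.catalyst.util.DateTimeUtils` exposes roughly
 * {{{ def getHours(micros: Long, zoneId: ZoneId): Int }}}
 * this marker object declares
 * {{{ def getHours(micros: Timestamp, zoneId: ZoneId): Int }}}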
Other interface changes: 32 | * - getNextDateForDayOfWeek takes a String value for the dayOfWeek param 33 | * - `truncDate`, `truncTimestamp` take String values for the level param 34 | */ 35 | object DateTimeUtils { 36 | 37 | /** 38 | * Converts days since 1970-01-01 at the given zone ID to microseconds since 1970-01-01 00:00:00Z. 39 | */ 40 | def daysToMicros(days: Date, zoneId: ZoneId): Timestamp = ??? 41 | 42 | /** 43 | * Trims and parses a given UTF8 timestamp string to a corresponding [[Long]] 44 | * value. The return type is [[Option]] in order to distinguish between 0L and null. The following 45 | * formats are allowed: 46 | * 47 | * `yyyy` 48 | * `yyyy-[m]m` 49 | * `yyyy-[m]m-[d]d` 50 | * `yyyy-[m]m-[d]d ` 51 | * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` 52 | * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` 53 | * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` 54 | * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` 55 | * 56 | * where `zone_id` should have one of the forms: 57 | * - Z - Zulu time zone UTC+0 58 | * - +|-[h]h:[m]m 59 | * - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS 60 | * - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-, 61 | * and a suffix in the formats: 62 | * - +|-h[h] 63 | * - +|-hh[:]mm 64 | * - +|-hh:mm:ss 65 | * - +|-hhmmss 66 | * - Region-based zone IDs in the form `area/city`, such as `Europe/Paris` 67 | */ 68 | def stringToTimestamp(s: String, timeZoneId: ZoneId): Option[Timestamp] = ??? 69 | 70 | def stringToTimestampAnsi(s: String, timeZoneId: ZoneId): Timestamp = ??? 71 | 72 | /** 73 | * Gets the number of microseconds since the epoch of 1970-01-01 00:00:00Z from the given 74 | * instance of `java.time.Instant`. The epoch microsecond count is a simple incrementing count of 75 | * microseconds where microsecond 0 is 1970-01-01 00:00:00Z. 76 | */ 77 | def instantToMicros(instant: Instant): Timestamp = ??? 78 | 79 | /** 80 | * Converts the local date to the number of days since 1970-01-01. 81 | */ 82 | def localDateToDays(localDate: LocalDate): Date = ??? 83 | 84 | /** 85 | * Trims and parses a given UTF8 date string to a corresponding [[Int]] value. 86 | * The return type is [[Option]] in order to distinguish between 0 and null. The following 87 | * formats are allowed: 88 | * 89 | * `yyyy` 90 | * `yyyy-[m]m` 91 | * `yyyy-[m]m-[d]d` 92 | * `yyyy-[m]m-[d]d ` 93 | * `yyyy-[m]m-[d]d *` 94 | * `yyyy-[m]m-[d]dT*` 95 | */ 96 | def stringToDate(s: String, zoneId: ZoneId): Option[Date] = ??? 97 | 98 | /** 99 | * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. 100 | */ 101 | def getHours(micros: Timestamp, zoneId: ZoneId): Int = ??? 102 | 103 | /** 104 | * Returns the minute value of a given timestamp value. The timestamp is expressed in 105 | * microseconds since the epoch. 106 | */ 107 | def getMinutes(micros: Timestamp, zoneId: ZoneId): Int = ??? 108 | 109 | /** 110 | * Returns the second value of a given timestamp value. The timestamp is expressed in 111 | * microseconds since the epoch. 112 | */ 113 | def getSeconds(micros: Timestamp, zoneId: ZoneId): Int = ??? 114 | 115 | /** 116 | * Returns the seconds part and its fractional part with microseconds. 117 | */ 118 | def getSecondsWithFraction(micros: Timestamp, zoneId: ZoneId): Decimal = ??? 119 | 120 | /** 121 | * Returns local seconds, including fractional parts, multiplied by 1000000.
122 | * 123 | * @param micros The number of microseconds since the epoch. 124 | * @param zoneId The time zone id in which the microseconds should be obtained. 125 | */ 126 | def getMicroseconds(micros: Timestamp, zoneId: ZoneId): Int = ??? 127 | 128 | /** 129 | * Returns the 'day in year' value for the given number of days since 1970-01-01. 130 | */ 131 | def getDayInYear(days: Date): Int = ??? 132 | 133 | /** 134 | * Returns the year value for the given number of days since 1970-01-01. 135 | */ 136 | def getYear(days: Date): Int = ??? 137 | 138 | /** 139 | * Returns the year which conforms to ISO 8601. Each ISO 8601 week-numbering 140 | * year begins with the Monday of the week containing the 4th of January. 141 | */ 142 | def getWeekBasedYear(days: Date): Int = ??? 143 | 144 | /** Returns the quarter for the given number of days since 1970-01-01. */ 145 | def getQuarter(days: Date): Int = ??? 146 | 147 | /** 148 | * Returns the month value for the given number of days since 1970-01-01. 149 | * January is month 1. 150 | */ 151 | def getMonth(days: Date): Int = ??? 152 | 153 | /** 154 | * Returns the 'day of month' value for the given number of days since 1970-01-01. 155 | */ 156 | def getDayOfMonth(days: Date): Int = ??? 157 | 158 | /** 159 | * Returns the day of the week for the given number of days since 1970-01-01 160 | * (1 = Sunday, 2 = Monday, ..., 7 = Saturday). 161 | */ 162 | def getDayOfWeek(days: Date): Int = ??? 163 | 164 | /** 165 | * Returns the day of the week for the given number of days since 1970-01-01 166 | * (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). 167 | */ 168 | def getWeekDay(days: Date): Int = ??? 169 | 170 | /** 171 | * Returns the week of the year of the given date expressed as the number of days from 1970-01-01. 172 | * A week is considered to start on a Monday and week 1 is the first week with > 3 days. 173 | */ 174 | def getWeekOfYear(days: Date): Int = ??? 175 | 176 | /** 177 | * Adds a year-month interval to a date represented as days since 1970-01-01. 178 | * @return a date value, expressed in days since 1970-01-01. 179 | */ 180 | def dateAddMonths(days: Date, months: Int): Date = ??? 181 | 182 | /** 183 | * Adds a full interval (months, days, microseconds) to a timestamp represented as the number of 184 | * microseconds since 1970-01-01 00:00:00Z. 185 | * @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z. 186 | */ 187 | def timestampAddInterval(start: Timestamp, int: CalendarInterval, zoneId: ZoneId): Timestamp = 188 | ??? 189 | 190 | /** 191 | * Adds the interval's months and days to a date expressed as days since the epoch. 192 | * @return A date value, expressed in days since 1970-01-01. 193 | * 194 | * @throws DateTimeException if the result exceeds the supported date range 195 | * @throws IllegalArgumentException if the interval has a `microseconds` part 196 | */ 197 | def dateAddInterval(start: Date, interval: CalendarInterval): Date = ??? 198 | 199 | /** 200 | * Returns the number of months between micros1 and micros2. micros1 and micros2 are expressed in 201 | * microseconds since 1970-01-01. If micros1 is later than micros2, the result is positive. 202 | * 203 | * If micros1 and micros2 are on the same day of month, or both are the last day of month, 204 | * the time of day will be ignored. 205 | * 206 | * Otherwise, the difference is calculated based on 31 days per month. 207 | * The result is rounded to 8 decimal places if `roundOff` is set to true.
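 * A hedged worked example, following the documented behaviour of Spark's
 * `months_between` SQL function (`ts(...)` is a hypothetical helper that parses a
 * timestamp literal; the value is quoted from the Spark docs from memory):
 * {{{
 *   monthsBetween(ts("1997-02-28 10:30:00"), ts("1996-10-30 00:00:00"), roundOff = true, zoneId)
 *   // => approximately 3.94959677
 * }}}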
208 | */ 209 | def monthsBetween( 210 | micros1: Timestamp, 211 | micros2: Timestamp, 212 | roundOff: Boolean, 213 | zoneId: ZoneId): Double = ??? 214 | 215 | /** 216 | * Returns the first date which is later than startDate and is of the given dayOfWeek. 217 | * dayOfWeek can be 218 | * "SU" | "SUN" | "SUNDAY", "MO" | "MON" | "MONDAY", 219 | * "TU" | "TUE" | "TUESDAY", "WE" | "WED" | "WEDNESDAY", 220 | * "TH" | "THU" | "THURSDAY" => THURSDAY, "FR" | "FRI" | "FRIDAY" => FRIDAY, 221 | * "SA" | "SAT" | "SATURDAY" 222 | */ 223 | def getNextDateForDayOfWeek(startDay: Date, dayOfWeek: String): Date = ??? 224 | 225 | /** Returns last day of the month for the given number of days since 1970-01-01. */ 226 | def getLastDayOfMonth(days: Date): Date = ??? 227 | 228 | /** 229 | * Returns the trunc date from original date and trunc level. 230 | * level can be: 231 | * "WEEK", "MON" | "MONTH" | "MM", "QUARTER", 232 | * "YEAR" | "YYYY" | "YY" 233 | */ 234 | def truncDate(days: Date, level: String): Date = ??? 235 | 236 | /** 237 | * Returns the trunc date time from original date time and trunc level. 238 | * level can be: 239 | * "MICROSECOND", "MILLISECOND", "SECOND", "MINUTE", "HOUR", 240 | * "DAY" | "DD", "WEEK", "MON" | "MONTH" | "MM", "QUARTER", 241 | * "YEAR" | "YYYY" | "YY" 242 | */ 243 | def truncTimestamp(micros: Timestamp, level: String, zoneId: ZoneId): Timestamp = ??? 244 | 245 | /** 246 | * Returns a timestamp of given timezone from UTC timestamp, with the same string 247 | * representation in their timezone. 248 | */ 249 | def fromUTCTime(micros: Timestamp, timeZone: String): Timestamp = ??? 250 | 251 | /** 252 | * Returns a utc timestamp from a given timestamp from a given timezone, with the same 253 | * string representation in their timezone. 254 | */ 255 | def toUTCTime(micros: Timestamp, timeZone: String): Timestamp = ??? 256 | 257 | /** 258 | * Obtains the current instant as microseconds since the epoch at the UTC time zone. 259 | */ 260 | def currentTimestamp(): Timestamp = ??? 261 | 262 | /** 263 | * Obtains the current date as days since the epoch in the specified time-zone. 264 | */ 265 | def currentDate(zoneId: ZoneId): Date = ??? 266 | 267 | /** 268 | * Converts notational shorthands that are converted to ordinary timestamps. 269 | * 270 | * @param input A trimmed string 271 | * @param zoneId Zone identifier used to get the current date. 272 | * @return Some of microseconds since the epoch if the conversion completed 273 | * successfully otherwise None. 274 | */ 275 | def convertSpecialTimestamp(input: String, zoneId: ZoneId): Option[Timestamp] = ??? 276 | 277 | /** 278 | * Converts notational shorthands that are converted to ordinary dates. 279 | * 280 | * @param input A trimmed string 281 | * @param zoneId Zone identifier used to get the current date. 282 | * @return Some of days since the epoch if the conversion completed successfully otherwise None. 283 | */ 284 | def convertSpecialDate(input: String, zoneId: ZoneId): Option[Date] = ??? 285 | 286 | /** 287 | * Subtracts two dates expressed as days since 1970-01-01. 288 | * 289 | * @param endDay The end date, exclusive 290 | * @param startDay The start date, inclusive 291 | * @return An interval between two dates. The interval can be negative 292 | * if the end date is before the start date. 293 | */ 294 | def subtractDates(endDay: Date, startDay: Date): CalendarInterval = ??? 
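  /*
   * A minimal usage sketch (the call shape below is an assumption for illustration;
   * the real entry points are the generated `udm`/`registerMacro` methods and the
   * implicit SparkSession enrichment defined elsewhere in this project):
   * {{{
   *   spark.registerMacro("day_of_year",
   *     spark.udm((d: java.sql.Date) => DateTimeUtils.getDayInYear(d)))
   * }}}
   * With this, `day_of_year(c)` in SQL is expanded to the equivalent of `dayofyear(c)`
   * rather than invoking an opaque UDF.
   */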
295 | 296 | } 297 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/ExprBuilders.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr, InternalRow} 21 | import org.apache.spark.sql.types.StringType 22 | 23 | /** 24 | * Writing match patterns: 25 | * - One cannot write a pattern like this 26 | * {{{q"${l : sparkexpr.Expression} + ${r: sparkexpr.Expression}"}}} 27 | * because the tree for {{{1 + 2}}} is 28 | * {{{Apply(Select(Literal(Constant(1)), TermName("$plus")), List(Literal(Constant(2))))}}} 29 | * and the tree for {{{Array(5)}}} is 30 | * {{{Apply(Select(Ident(scala.Array), TermName("apply")), List(Literal(Constant(5))))}}} 31 | * Pattern matching for {{{Array(5)}}} triggers a recursive unapply on 32 | * {{{CatalystExpression}}} for the subtree {{{Ident(scala.Array)}}}, which doesn't match 33 | * any valid ''Expression'' pattern; so we attempt to ''evaluate'' the tree, and 34 | * this fails.
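 * Instead, the extractors in this trait match on those tree shapes directly and
 * recursively translate the sub-trees. A rough, hedged sketch of the arithmetic
 * case (the real pattern lives in the Arithmetics trait and may differ):
 * {{{
 *   case q"$l + $r" =>
 *     for ((le, re) <- binaryArgs(l, r)) yield sparkexpr.Add(le, re)
 * }}}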
35 | */ 36 | trait ExprBuilders 37 | extends Arithmetics with Collections with Structs with Tuples 38 | with DateTime with Options with RecursiveSparkApply with Conditionals 39 | with Strings { self : ExprTranslator => 40 | 41 | import macroUniverse._ 42 | 43 | val predefTyp = typeOf[scala.Predef.type] 44 | 45 | object Literals { 46 | def unapply(t: mTree): Option[sparkexpr.Expression] = 47 | Option(t match { 48 | case Literal(Constant(v)) => 49 | scala.util.Try(sparkexpr.Literal(v)).toOption.orNull 50 | case _ => null 51 | }) 52 | } 53 | 54 | object Reference { 55 | 56 | private def exprInScope(nm : TermName) : Option[sparkexpr.Expression] = 57 | scope.get(nm).map(_.rhsExpr) 58 | 59 | def unapply(t: mTree): Option[sparkexpr.Expression] = 60 | Option(t match { 61 | case Ident(tNm : TermName) => exprInScope(tNm).orNull 62 | case tNm : TermName => exprInScope(tNm).orNull 63 | case _ => null 64 | }) 65 | } 66 | 67 | object StaticValue { 68 | def unapply(t: mTree): Option[sparkexpr.Expression] = { 69 | if (t.tpe =:= typeOf[org.apache.spark.unsafe.types.UTF8String]) { 70 | doWithWarning[sparkexpr.Expression](t, 71 | "evaluate to a static value", 72 | { 73 | val v = eval_tree(t) 74 | new sparkexpr.Literal(v, StringType) 75 | }) 76 | } else { 77 | (for (typInfo <- TypeInfo.unapply(t)) yield { 78 | doWithWarning[sparkexpr.Expression](t, 79 | "evaluate to a static value", 80 | { 81 | val v = eval_tree(t) 82 | val iRow = InternalRow(v) 83 | val lVal = typInfo.exprEnc.objSerializer.eval(iRow) 84 | new sparkexpr.Literal(lVal, typInfo.catalystType) 85 | }) 86 | }).flatten 87 | } 88 | } 89 | } 90 | 91 | private[sqlmacros] def binaryArgs(lT : mTree, rT : mTree) : 92 | Option[(sparkexpr.Expression, sparkexpr.Expression)] = { 93 | for ( 94 | l <- CatalystExpression.unapply(lT); 95 | r <- CatalystExpression.unapply(rT) 96 | ) yield (l, r) 97 | } 98 | 99 | object CatalystExpression { 100 | def unapply(t: mTree): Option[sparkexpr.Expression] = 101 | Option(t match { 102 | case Literals(e) => e 103 | case Reference(e) => e 104 | case BasicArith(e) => e 105 | case StringPatterns(e) => e 106 | case JavaMathFuncs(e) => e 107 | case CollectionTrees(e) => e 108 | case TupleConstruct(e) => e 109 | case FieldAccess(e) => e 110 | case StructConstruct(e) => e 111 | case DateTimePatterns(e) => e 112 | case OptionPatterns(e) => e 113 | case Predicates(e) => e 114 | case IFCase(e) => e 115 | case FunctionBuilderApplication(e) => e 116 | case StaticValue(e) => e 117 | case _ => null 118 | }) 119 | } 120 | 121 | object CatalystExpressions { 122 | def unapplySeq(tS: Seq[mTree]): Option[Seq[sparkexpr.Expression]] = 123 | SparkSQLMacroUtils.sequence(tS.map(CatalystExpression.unapply(_))) 124 | } 125 | 126 | implicit val toExpr = new Unliftable[sparkexpr.Expression] { 127 | def unapply(t: c.Tree): Option[sparkexpr.Expression] = CatalystExpression.unapply(t) 128 | } 129 | 130 | } 131 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/ExprOptimize.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 20 | import org.apache.spark.sql.types.IntegralType 21 | 22 | /** 23 | * Collapse [[sparkexpr.GetMapValue]], [[sparkexpr.GetStructField]] and 24 | * [[sparkexpr.GetArrayItem]] expressions. Also simplify `Unwrap <- Wrap` expression 25 | * sub-trees. 26 | */ 27 | trait ExprOptimize { 28 | self : ExprTranslator => 29 | 30 | private def hasStaticKeys(m: sparkexpr.CreateMap): Boolean = 31 | m.keys.forall(_.isInstanceOf[sparkexpr.Literal]) 32 | 33 | private def getValueExpr( 34 | m: sparkexpr.CreateMap, 35 | key: sparkexpr.Literal): sparkexpr.Expression = { 36 | var valExpr: sparkexpr.Expression = null 37 | 38 | for (i <- 0 until m.keys.size) { 39 | if (m.keys(i).canonicalized == key) { 40 | valExpr = m.values(i) 41 | } 42 | } 43 | 44 | if (valExpr == null) { 45 | valExpr = new sparkexpr.Literal(null, m.dataType.valueType) 46 | } 47 | valExpr 48 | } 49 | 50 | private def isValidField(s: sparkexpr.CreateNamedStruct, fIdx: Int): Boolean = { 51 | fIdx >= 0 && fIdx < s.dataType.fields.size 52 | } 53 | 54 | private def getFieldExpr(s: sparkexpr.CreateNamedStruct, fIdx: Int): sparkexpr.Expression = { 55 | s.valExprs(fIdx) 56 | } 57 | 58 | private def validIndex(a: sparkexpr.CreateArray, l: sparkexpr.Literal): Option[Int] = { 59 | if (l.dataType.isInstanceOf[IntegralType]) { 60 | val typ: IntegralType = l.dataType.asInstanceOf[IntegralType] 61 | val idx = typ.numeric.toInt(l.value.asInstanceOf[typ.InternalType]) 62 | if (idx >= 0 && idx < a.children.size) { 63 | Some(idx) 64 | } else None 65 | } else None 66 | } 67 | 68 | private def geArrEntryExpr(a: sparkexpr.CreateArray, eIdx: Int): sparkexpr.Expression = { 69 | a.children(eIdx) 70 | } 71 | 72 | def optimizeExpr(expr: sparkexpr.Expression): sparkexpr.Expression = expr transformUp { 73 | case sparkexpr.objects.UnwrapOption(_, sparkexpr.objects.WrapOption(c, _)) => c 74 | case sparkexpr.GetMapValue(cm: sparkexpr.CreateMap, k: sparkexpr.Literal, false) 75 | if hasStaticKeys(cm) => 76 | getValueExpr(cm, k) 77 | case e @ sparkexpr.GetStructField(s: sparkexpr.CreateNamedStruct, fIdx, _) 78 | if isValidField(s, fIdx) => 79 | getFieldExpr(s, fIdx) 80 | case e @ sparkexpr.GetArrayItem(a: sparkexpr.CreateArray, ordExpr: sparkexpr.Literal, _) => 81 | validIndex(a, ordExpr).map(geArrEntryExpr(a, _)).getOrElse(e) 82 | case e @ sparkexpr.GetMapValue(_: sparkexpr.Literal, _: sparkexpr.Literal, false) => 83 | val value = e.eval(null) 84 | sparkexpr.Literal(value, e.dataType) 85 | case e @ sparkexpr.GetStructField(l: sparkexpr.Literal, fIdx, _) => 86 | val value = e.eval(null) 87 | sparkexpr.Literal(value, e.dataType) 88 | case e @ sparkexpr.GetArrayItem(_: sparkexpr.Literal, _: sparkexpr.Literal, false) => 89 | val value = e.eval(null) 90 | sparkexpr.Literal(value, e.dataType) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/ExprTranslator.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import scala.collection.mutable.{Map => MMap} 21 | 22 | import org.apache.spark.internal.Logging 23 | import org.apache.spark.sql.SparkSession 24 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 25 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 26 | import org.apache.spark.sql.types.DataType 27 | 28 | trait ExprTranslator extends MacrosEnv with ExprBuilders with ExprOptimize with Logging { 29 | import macroUniverse._ 30 | 31 | object MacroTransException extends Exception 32 | 33 | def sparkSession : Option[SparkSession] 34 | 35 | val scope = MMap[mTermName, ValInfo]() 36 | 37 | /** 38 | * - log the issue and also record it in the context as a warning. 39 | * - don't issue a `c.abort` or `c.error` because it doesn't let us recover and return a `None`. 40 | * - not doing a `c.error` or `c.abort` is safe because we immediately throw and go into 41 | * 'exit mode' (i.e. do any cleanup and return a None) 42 | * 43 | * @param tree 44 | * @param reason 45 | * @param kind 46 | */ 47 | def logIssue(tree : mTree, 48 | reason : String, 49 | kind : String) : Unit = { 50 | val msg = s"$reason: '${showCode(tree)}'" 51 | logDebug( 52 | s"""Spark SQL Macro translation ${kind} at pos ${tree.pos.toString}: 53 | |$msg""".stripMargin) 54 | c.warning(tree.pos, msg) 55 | } 56 | 57 | def quit(tree : mTree, 58 | reason : String) : Nothing = { 59 | logIssue(tree, reason, "error") 60 | throw MacroTransException 61 | } 62 | 63 | def warn(tree : mTree, reason : String) : Unit = logIssue(tree, reason, "warning") 64 | 65 | def doWithWarning[T](tree : mTree, 66 | action : String, 67 | fn : => T) : Option[T] = { 68 | try { 69 | Some(fn) 70 | } catch { 71 | case e : Exception => 72 | warn(tree, s"Failed to do $action (exception = ${e.getMessage})") 73 | None 74 | case e : Error => 75 | warn(tree, s"Failed to do $action (error = ${e.getMessage})") 76 | None 77 | } 78 | } 79 | 80 | def staticValue[T : TypeTag](tree : mTree, action : String) : Option[T] = { 81 | if (tree.tpe <:< typeOf[T]) { 82 | doWithWarning[T](tree, action, { 83 | val v = eval_tree(tree) 84 | v.asInstanceOf[T] 85 | }) 86 | } else None 87 | } 88 | 89 | /** 90 | * There should be nothing to do here. 91 | * For functions the [[ExpressionEncoder]] for the argument's [[DataType]] 92 | * will convert an internal value into a catalyst value. In case of macros 93 | * the catalyst value is fed to a catalyst Expression that is equivalent to 94 | * the function body.
95 | * @param typInfo 96 | * @param pos 97 | * @return 98 | */ 99 | private def paramSparkExpr(typInfo : TypeInfo, 100 | pos : Int) : sparkexpr.Expression = { 101 | MacroArg(pos, typInfo.catalystType) 102 | } 103 | 104 | def translateExprTree(tree : mTree) : sparkexpr.Expression = tree match { 105 | case CatalystExpression(e) => e 106 | case _ => quit(tree, "Not able to translate to a Spark Expression") 107 | } 108 | 109 | def translateValDef(tree : ValDef, 110 | exprBldr : Function2[TypeInfo, mTree, sparkexpr.Expression] 111 | ) : Unit = (tree, exprBldr) match { 112 | case ValInfo(_) => () 113 | case _ => quit(tree, "Not able to translate Value definition") 114 | } 115 | 116 | def translateParam(pos : Int, tree : mTree) : Unit = tree match { 117 | case v : ValDef => translateValDef(v, (typInfo, tree) => paramSparkExpr(typInfo, pos)) 118 | case _ => quit(tree, "Not able to translate function param") 119 | } 120 | 121 | 122 | def translateStat(tree : mTree) : Unit = tree match { 123 | case v : ValDef => translateValDef(v, (typInfo, tree) => translateExprTree(tree)) 124 | case _ => warn(tree, s"Ignoring statement") 125 | } 126 | 127 | 128 | def expressionEncoder(typ : mType) : ExpressionEncoder[_] = { 129 | _expressionEncoder(convertType(typ)) 130 | } 131 | 132 | def extractFuncParamsStats(fTree : mTree) : (Seq[mTree], Seq[mTree]) = fTree match { 133 | case q"(..$params) => {..$stats}" => (params, stats) 134 | case q"{(..$params) => {..$stats}}" => (params, stats) 135 | case _ => quit(fTree, "Not able to recognize function structure") 136 | } 137 | 138 | case class TypeInfo(mTyp : mType, 139 | rTyp : ruType, 140 | catalystType : DataType, 141 | _exprEnc : () => ExpressionEncoder[_]) { 142 | lazy val exprEnc = _exprEnc() 143 | override def toString: String = s"""${catalystType.toString}""" 144 | } 145 | object TypeInfo { 146 | def unapply(typTree : mTree) : Option[TypeInfo] = unapply(typTree.tpe) 147 | 148 | def unapply(typ : mType) : Option[TypeInfo] = (scala.util.Try { 149 | val rTyp = convertType(typ) 150 | val cSchema = MacrosScalaReflection.schemaFor(rTyp) 151 | TypeInfo(typ, rTyp, cSchema.dataType, () => _expressionEncoder(rTyp)) 152 | }).toOption 153 | } 154 | 155 | case class ValInfo(vDef : macroUniverse.ValDef, 156 | name : String, 157 | typInfo : TypeInfo, 158 | rhsExpr : sparkexpr.Expression) { 159 | override def toString: String = { 160 | s"""ValDef: 161 | | vDef: ${macroUniverse.show(vDef)} 162 | | name: ${name} 163 | | type: ${typInfo.toString}""".stripMargin 164 | } 165 | } 166 | 167 | object ValInfo { 168 | def unapply(arg : (mTree, Function2[TypeInfo, mTree, sparkexpr.Expression])) : Option[ValInfo] 169 | = { 170 | val (t, exprBldr) = arg 171 | t match { 172 | case vDef@ValDef(mods, tNm, TypeInfo(tInfo), rhsTree) => 173 | val nm = tNm.decodedName.toString 174 | val vInfo = ValInfo(vDef, nm, tInfo, exprBldr(tInfo, rhsTree)) 175 | scope(tNm) = vInfo 176 | Some(vInfo) 177 | case _ => None 178 | } 179 | } 180 | } 181 | 182 | /** 183 | * return the elem being operated on and args1 and args2 184 | */ 185 | object InstanceMethodCall { 186 | def unapply(t: mTree): Option[(mTree, Seq[mTree], Seq[mTree])] = t match { 187 | case q"$ent.$_" => Some((ent, Seq.empty, Seq.empty)) 188 | case q"$ent.$_[..$_]" => Some((ent, Seq.empty, Seq.empty)) 189 | case q"$ent.$_(..$args1)" => Some((ent, args1, Seq.empty)) 190 | case q"$ent.$_[..$_](..$args1)" => Some((ent, args1, Seq.empty)) 191 | case q"$ent.$_(..$args1)(..$args2)" => Some((ent, args1, args2)) 192 | case 
q"$ent.$_[..$_](..$args1)(..$args2)" => Some((ent, args1, args2)) 193 | case _ => None 194 | } 195 | } 196 | 197 | object ModuleMethodCall { 198 | def unapply(t: mTree): Option[(mTree, Seq[mTree], Seq[mTree])] = t match { 199 | case q"$id(..$args1)" => Some((id, args1, Seq.empty)) 200 | case q"$id[..$_](..$args1)" => Some((id, args1, Seq.empty)) 201 | case q"$id(..$args1)(..$args2)" => Some((id, args1, args2)) 202 | case q"$id[..$_](..$args1)(..$args2)" => Some((id, args1, args2)) 203 | case _ => None 204 | } 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/MacrosEnv.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import scala.reflect.macros.blackbox 21 | import scala.tools.reflect.ToolBox 22 | 23 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 24 | 25 | trait MacrosEnv { 26 | 27 | val c: blackbox.Context 28 | 29 | val macroUniverse: c.universe.type = c.universe 30 | val runtimeUniverse = MacrosScalaReflection.universe 31 | type mType = macroUniverse.Type 32 | type mTree = macroUniverse.Tree 33 | type ruTree = runtimeUniverse.Tree 34 | type ruType = runtimeUniverse.Type 35 | type mTermName = macroUniverse.TermName 36 | 37 | lazy val ruToolBox = MacrosScalaReflection.mirror.mkToolBox() 38 | 39 | lazy val ruImporter = { 40 | val importer0 = runtimeUniverse.internal.createImporter(macroUniverse) 41 | importer0.asInstanceOf[runtimeUniverse.Importer {val from: macroUniverse.type}] 42 | } 43 | 44 | private[sqlmacros] def _expressionEncoder(rTyp : ruType) : ExpressionEncoder[_] = { 45 | MacroUtils.expressionEncoder(rTyp) 46 | } 47 | 48 | def convertTree(tree : mTree) : ruTree = { 49 | val imported = ruImporter.importTree(tree) 50 | val treeR = ruToolBox.untypecheck(imported.duplicate) 51 | ruToolBox.typecheck(treeR, ruToolBox.TERMmode) 52 | } 53 | 54 | def convertType(typ : mType) : ruType = { 55 | ruImporter.importType(typ) 56 | } 57 | 58 | def eval_tree(tree : mTree) : Any = { 59 | val e = c.Expr(c.untypecheck(tree.duplicate)) 60 | c.eval(e) 61 | } 62 | } 63 | 64 | object MacroUtils { 65 | 66 | import MacrosScalaReflection._ 67 | 68 | val ru = universe 69 | 70 | def expressionEncoder[T : ru.TypeTag](): ExpressionEncoder[T] = { 71 | import scala.reflect.ClassTag 72 | 73 | val tpe = ru.typeTag[T].in(mirror).tpe 74 | 75 | val cls = mirror.runtimeClass(tpe) 76 | val serializer = serializerForType(tpe) 77 | val deserializer = deserializerForType(tpe) 78 | 79 | new ExpressionEncoder[T]( 80 | serializer, 81 | deserializer, 82 | ClassTag[T](cls)) 
83 | } 84 | 85 | def expressionEncoder(tpe : ru.Type) : ExpressionEncoder[_] = { 86 | import scala.reflect.ClassTag 87 | 88 | val cls = mirror.runtimeClass(tpe) 89 | val serializer = serializerForType(tpe) 90 | val deserializer = deserializerForType(tpe) 91 | 92 | new ExpressionEncoder( 93 | serializer, 94 | deserializer, 95 | ClassTag(cls)) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Options.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 20 | 21 | trait Options { self: ExprTranslator => 22 | 23 | import macroUniverse._ 24 | 25 | object OptionPatterns { 26 | 27 | val someApplySym = typeOf[Some.type].member(TermName("apply")) 28 | 29 | def unapply(t: mTree): Option[sparkexpr.Expression] = 30 | t match { 31 | case q"$op.get" => 32 | for ( 33 | e <- CatalystExpression.unapply(op) 34 | ) yield sparkexpr.objects.UnwrapOption(e.dataType, e) 35 | case q"$id[$_]($op)" if id.symbol == someApplySym => 36 | for ( 37 | e <- CatalystExpression.unapply(op) 38 | ) yield sparkexpr.objects.WrapOption(e, e.dataType) 39 | case _ => None 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/PredicateUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import scala.language.implicitConversions 20 | 21 | /** 22 | * Provide scala marker functions to perform null-checks, null_safe_= 23 | * and in/not_in check on Arrays. 
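 * A hedged usage sketch inside a macro body (the operators come from the implicit
 * `AnyWithPreds` wrapper below):
 * {{{
 *   (v: Int) => v.is_not_null && v.in(1, 2, 3)
 * }}}
 * which the macro translator is intended to turn into the corresponding null-check
 * and IN predicates over Catalyst expressions.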
24 | */ 25 | object PredicateUtils { 26 | 27 | implicit def any2Preds(v : Any) : AnyWithPreds = new AnyWithPreds(v) 28 | 29 | // scalastyle:off 30 | class AnyWithPreds(val v : Any) extends AnyVal { 31 | def is_null : Boolean = ??? 32 | def is_not_null : Boolean = ??? 33 | def null_safe_eq(o : Any) : Boolean = ??? 34 | def in(a : Any*) : Boolean = ??? 35 | def not_in(a : Any*) : Boolean = ??? 36 | } 37 | // scalastyle:on 38 | } 39 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/RecursiveSparkApply.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr, FunctionIdentifier} 20 | import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder 21 | 22 | trait RecursiveSparkApply { self: ExprTranslator => 23 | 24 | import macroUniverse._ 25 | 26 | val reg_macros = typeOf[registered_macros.type] 27 | val dyn_applySym = reg_macros.member(TermName("applyDynamic")) 28 | 29 | private def funcBldr(nmLst : Seq[mTree]) : Option[FunctionBuilder] = { 30 | val macroNm = nmLst match { 31 | case Literal(Constant(x : String)) :: Nil => Some(x) 32 | case _ => None 33 | } 34 | for( 35 | nm <- macroNm; 36 | ss <- sparkSession; 37 | fb <- ss.sessionState.functionRegistry.lookupFunctionBuilder(FunctionIdentifier(nm)) 38 | ) yield fb 39 | } 40 | 41 | object FunctionBuilderApplication { 42 | def unapply(t: mTree): Option[sparkexpr.Expression] = 43 | t match { 44 | case q"$id(..$nmLst)(..$args)" if args.size == 1 => 45 | for ( 46 | fb <- funcBldr(nmLst); 47 | exprs <- CatalystExpressions.unapplySeq(args) 48 | ) yield fb(exprs) 49 | case _ => None 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/SQLMacro.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import scala.language.experimental.macros 21 | import scala.reflect.macros.blackbox._ 22 | 23 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 24 | 25 | // scalastyle:off line.size.limit 26 | class SQLMacro(val c : Context) extends ExprTranslator { 27 | 28 | lazy val sparkSession = SparkSQLMacroUtils.currentSparkSessionOption 29 | 30 | def buildExpression(params : Seq[mTree], 31 | stats : Seq[mTree]) : Option[sparkexpr.Expression] = { 32 | 33 | try { 34 | 35 | for ((p, i) <- params.zipWithIndex) { 36 | translateParam(i, p) 37 | } 38 | 39 | for (s <- stats.init) { 40 | translateStat(s) 41 | } 42 | 43 | Some(translateExprTree(stats.last)).map(optimizeExpr) 44 | 45 | } catch { 46 | case MacroTransException => None 47 | case e : Throwable => throw e 48 | } 49 | 50 | } 51 | 52 | def translateFunc(fTree : mTree) : Option[sparkexpr.Expression] = { 53 | val (params, stats) = extractFuncParamsStats(fTree) 54 | buildExpression(params, stats) 55 | } 56 | 57 | def udm1_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag](f : c.Expr[Function1[A1, RT]]) 58 | : c.Expr[Either[Function1[A1, RT], SQLMacroExpressionBuilder]] = { 59 | 60 | import macroUniverse._ 61 | val expr = translateFunc(f.tree) 62 | 63 | expr match { 64 | case Some(e) => 65 | val eSer = SQLMacroExpressionBuilder.serialize(e) 66 | c.Expr[Either[Function1[A1, RT], SQLMacroExpressionBuilder]]( 67 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 68 | case None => 69 | c.Expr[Either[Function1[A1, RT], SQLMacroExpressionBuilder]]( 70 | q"scala.util.Left(${f.tree})") 71 | } 72 | } 73 | 74 | // GENERATED using [[GenMacroFuncs] 75 | 76 | def udm2_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag](f : c.Expr[Function2[A1, A2, RT]]) 77 | : c.Expr[Either[Function2[A1, A2, RT], SQLMacroExpressionBuilder]] = { 78 | 79 | import macroUniverse._ 80 | val expr = translateFunc(f.tree) 81 | 82 | expr match { 83 | case Some(e) => 84 | val eSer = SQLMacroExpressionBuilder.serialize(e) 85 | c.Expr[Either[Function2[A1, A2, RT], SQLMacroExpressionBuilder]]( 86 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 87 | case None => 88 | c.Expr[Either[Function2[A1, A2, RT], SQLMacroExpressionBuilder]]( 89 | q"scala.util.Left(${f.tree})") 90 | } 91 | } 92 | 93 | def udm3_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag](f : c.Expr[Function3[A1, A2, A3, RT]]) 94 | : c.Expr[Either[Function3[A1, A2, A3, RT], SQLMacroExpressionBuilder]] = { 95 | 96 | import macroUniverse._ 97 | val expr = translateFunc(f.tree) 98 | 99 | expr match { 100 | case Some(e) => 101 | val eSer = SQLMacroExpressionBuilder.serialize(e) 102 | c.Expr[Either[Function3[A1, A2, A3, RT], SQLMacroExpressionBuilder]]( 103 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 104 | case None => 105 | c.Expr[Either[Function3[A1, A2, A3, RT], SQLMacroExpressionBuilder]]( 106 | q"scala.util.Left(${f.tree})") 107 | } 108 | } 109 | 110 | def 
udm4_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag](f : c.Expr[Function4[A1, A2, A3, A4, RT]]) 111 | : c.Expr[Either[Function4[A1, A2, A3, A4, RT], SQLMacroExpressionBuilder]] = { 112 | 113 | import macroUniverse._ 114 | val expr = translateFunc(f.tree) 115 | 116 | expr match { 117 | case Some(e) => 118 | val eSer = SQLMacroExpressionBuilder.serialize(e) 119 | c.Expr[Either[Function4[A1, A2, A3, A4, RT], SQLMacroExpressionBuilder]]( 120 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 121 | case None => 122 | c.Expr[Either[Function4[A1, A2, A3, A4, RT], SQLMacroExpressionBuilder]]( 123 | q"scala.util.Left(${f.tree})") 124 | } 125 | } 126 | 127 | def udm5_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag, A5 : c.WeakTypeTag](f : c.Expr[Function5[A1, A2, A3, A4, A5, RT]]) 128 | : c.Expr[Either[Function5[A1, A2, A3, A4, A5, RT], SQLMacroExpressionBuilder]] = { 129 | 130 | import macroUniverse._ 131 | val expr = translateFunc(f.tree) 132 | 133 | expr match { 134 | case Some(e) => 135 | val eSer = SQLMacroExpressionBuilder.serialize(e) 136 | c.Expr[Either[Function5[A1, A2, A3, A4, A5, RT], SQLMacroExpressionBuilder]]( 137 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 138 | case None => 139 | c.Expr[Either[Function5[A1, A2, A3, A4, A5, RT], SQLMacroExpressionBuilder]]( 140 | q"scala.util.Left(${f.tree})") 141 | } 142 | } 143 | 144 | def udm6_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag, A5 : c.WeakTypeTag, A6 : c.WeakTypeTag](f : c.Expr[Function6[A1, A2, A3, A4, A5, A6, RT]]) 145 | : c.Expr[Either[Function6[A1, A2, A3, A4, A5, A6, RT], SQLMacroExpressionBuilder]] = { 146 | 147 | import macroUniverse._ 148 | val expr = translateFunc(f.tree) 149 | 150 | expr match { 151 | case Some(e) => 152 | val eSer = SQLMacroExpressionBuilder.serialize(e) 153 | c.Expr[Either[Function6[A1, A2, A3, A4, A5, A6, RT], SQLMacroExpressionBuilder]]( 154 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 155 | case None => 156 | c.Expr[Either[Function6[A1, A2, A3, A4, A5, A6, RT], SQLMacroExpressionBuilder]]( 157 | q"scala.util.Left(${f.tree})") 158 | } 159 | } 160 | 161 | def udm7_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag, A5 : c.WeakTypeTag, A6 : c.WeakTypeTag, A7 : c.WeakTypeTag](f : c.Expr[Function7[A1, A2, A3, A4, A5, A6, A7, RT]]) 162 | : c.Expr[Either[Function7[A1, A2, A3, A4, A5, A6, A7, RT], SQLMacroExpressionBuilder]] = { 163 | 164 | import macroUniverse._ 165 | val expr = translateFunc(f.tree) 166 | 167 | expr match { 168 | case Some(e) => 169 | val eSer = SQLMacroExpressionBuilder.serialize(e) 170 | c.Expr[Either[Function7[A1, A2, A3, A4, A5, A6, A7, RT], SQLMacroExpressionBuilder]]( 171 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 172 | case None => 173 | c.Expr[Either[Function7[A1, A2, A3, A4, A5, A6, A7, RT], SQLMacroExpressionBuilder]]( 174 | q"scala.util.Left(${f.tree})") 175 | } 176 | } 177 | 178 | def udm8_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag, A5 : c.WeakTypeTag, A6 : c.WeakTypeTag, A7 : c.WeakTypeTag, A8 : c.WeakTypeTag](f : c.Expr[Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]]) 179 | : c.Expr[Either[Function8[A1, A2, A3, A4, A5, A6, A7, 
A8, RT], SQLMacroExpressionBuilder]] = { 180 | 181 | import macroUniverse._ 182 | val expr = translateFunc(f.tree) 183 | 184 | expr match { 185 | case Some(e) => 186 | val eSer = SQLMacroExpressionBuilder.serialize(e) 187 | c.Expr[Either[Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], SQLMacroExpressionBuilder]]( 188 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 189 | case None => 190 | c.Expr[Either[Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], SQLMacroExpressionBuilder]]( 191 | q"scala.util.Left(${f.tree})") 192 | } 193 | } 194 | 195 | def udm9_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag, A5 : c.WeakTypeTag, A6 : c.WeakTypeTag, A7 : c.WeakTypeTag, A8 : c.WeakTypeTag, A9 : c.WeakTypeTag](f : c.Expr[Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]]) 196 | : c.Expr[Either[Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], SQLMacroExpressionBuilder]] = { 197 | 198 | import macroUniverse._ 199 | val expr = translateFunc(f.tree) 200 | 201 | expr match { 202 | case Some(e) => 203 | val eSer = SQLMacroExpressionBuilder.serialize(e) 204 | c.Expr[Either[Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], SQLMacroExpressionBuilder]]( 205 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 206 | case None => 207 | c.Expr[Either[Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], SQLMacroExpressionBuilder]]( 208 | q"scala.util.Left(${f.tree})") 209 | } 210 | } 211 | 212 | def udm10_impl[RT : c.WeakTypeTag, A1 : c.WeakTypeTag, A2 : c.WeakTypeTag, A3 : c.WeakTypeTag, A4 : c.WeakTypeTag, A5 : c.WeakTypeTag, A6 : c.WeakTypeTag, A7 : c.WeakTypeTag, A8 : c.WeakTypeTag, A9 : c.WeakTypeTag, A10 : c.WeakTypeTag](f : c.Expr[Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]]) 213 | : c.Expr[Either[Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], SQLMacroExpressionBuilder]] = { 214 | 215 | import macroUniverse._ 216 | val expr = translateFunc(f.tree) 217 | 218 | expr match { 219 | case Some(e) => 220 | val eSer = SQLMacroExpressionBuilder.serialize(e) 221 | c.Expr[Either[Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], SQLMacroExpressionBuilder]]( 222 | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder(${eSer}))") 223 | case None => 224 | c.Expr[Either[Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], SQLMacroExpressionBuilder]]( 225 | q"scala.util.Left(${f.tree})") 226 | } 227 | } 228 | } 229 | 230 | // scalastyle:off println 231 | /* UNCOMMENT to generate macro signature functions 232 | object GenMacroFuncs extends App { 233 | 234 | case class MacroCode(i : Int) { 235 | val argTypes : Seq[String] = (1 until i + 1).map(j => s"A$j") 236 | 237 | val funcType = { 238 | (argTypes :+ "RT").mkString(s"Function${i}[", ", ", "]") 239 | } 240 | 241 | def impl_method : String = { 242 | val methodNm = s"udm${i}_impl" 243 | 244 | val methodTypeArg = { 245 | val argTypeCtx : Seq[String] = argTypes.map(at => s"${at} : c.WeakTypeTag") 246 | ("RT : c.WeakTypeTag" +: argTypeCtx).mkString("[", ", ", "]") 247 | } 248 | 249 | val methodSig = 250 | s"""def ${methodNm}${methodTypeArg}(f : c.Expr[$funcType]) 251 | | : c.Expr[Either[${funcType}, SQLMacroExpressionBuilder]]""".stripMargin 252 | 253 | s"""$methodSig = { 254 | | 255 | | import macroUniverse._ 256 | | val expr = translateFunc(f.tree) 257 | | 258 | | expr match { 259 | | case Some(e) => 260 | | val eSer = SQLMacroExpressionBuilder.serialize(e) 261 | | c.Expr[Either[${funcType}, 
SQLMacroExpressionBuilder]]( 262 | | q"scala.util.Right(org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder($${eSer}))") 263 | | case None => 264 | | c.Expr[Either[${funcType}, SQLMacroExpressionBuilder]]( 265 | | q"scala.util.Left($${f.tree})") 266 | | } 267 | | }""".stripMargin 268 | } 269 | 270 | def udm_method : String = { 271 | val implmethodNm = s"udm${i}_impl" 272 | 273 | val methodTypeArg = ("RT" +: argTypes).mkString("[", ", ", "]") 274 | val methodSig = s"""def udm${methodTypeArg}(f: $funcType) : 275 | | Either[$funcType, SQLMacroExpressionBuilder]""".stripMargin 276 | 277 | val regMacroTypeArg = { 278 | val argTypeCtx : Seq[String] = argTypes.map(at => s"${at} : TypeTag") 279 | ("RT : TypeTag" +: argTypeCtx).mkString("[", ", ", "]") 280 | } 281 | 282 | s"""$methodSig = macro SQLMacro.${implmethodNm}${methodTypeArg} 283 | | 284 | |def registerMacro${regMacroTypeArg}(nm : String, 285 | | udm : Either[${funcType}, SQLMacroExpressionBuilder] 286 | | ) : Unit = { 287 | | udm match { 288 | | case Left(fn) => 289 | | sparkSession.udf.register(nm, udf(fn)) 290 | | case Right(sqlMacroBldr) => 291 | | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 292 | | } 293 | |} 294 | |""".stripMargin 295 | } 296 | 297 | } 298 | 299 | lazy val udm_impl_methods : Seq[String] = for (i <- (1 until 11)) yield MacroCode(i).impl_method 300 | 301 | lazy val udm_methods : Seq[String] = for (i <- (1 until 11)) yield MacroCode(i).udm_method 302 | 303 | 304 | // println(udm_impl_methods.mkString("\n\n")) 305 | 306 | println(udm_methods.mkString("\n\n")) 307 | 308 | } 309 | */ -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/SQLMacroExpressions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import java.nio.ByteBuffer 21 | 22 | import org.apache.spark.SparkConf 23 | import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerInstance} 24 | import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, LeafExpression, Unevaluable} 25 | import org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder.deserialize 26 | import org.apache.spark.sql.types.DataType 27 | 28 | 29 | /** 30 | * A [[org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder]] 31 | * for ''Spark SQL macros'' 32 | * 33 | * @param macroExpr 34 | */ 35 | case class SQLMacroExpressionBuilder(macroExprSer : Array[Byte]) 36 | extends Function1[Seq[Expression], Expression] { 37 | 38 | @transient lazy val macroExpr = deserialize(macroExprSer) 39 | 40 | override def apply(args: Seq[Expression]): Expression = { 41 | macroExpr transformUp { 42 | case MacroArg(argPos, dt) if argPos < args.size && 43 | Cast.canCast(args(argPos).dataType, dt) => 44 | if (dt == args(argPos).dataType) { 45 | args(argPos) 46 | } else { 47 | Cast(args(argPos), dt) 48 | } 49 | } 50 | } 51 | } 52 | 53 | /** 54 | * Represent holes in ''Macro SQL'' that will be filled in with the macro invocation argument 55 | * expressions. 56 | * 57 | * @param argPos 58 | * @param dataType 59 | */ 60 | @SerialVersionUID(-4890323739479048322L) 61 | case class MacroArg(argPos : Int, 62 | dataType : DataType) 63 | extends LeafExpression with Unevaluable { 64 | 65 | override def nullable: Boolean = true 66 | 67 | override def sql: String = s"macroarg($argPos)" 68 | 69 | } 70 | 71 | object SQLMacroExpressionBuilder { 72 | 73 | 74 | /** 75 | * This is not ideal. On each macro invocation we are setting up a [[JavaSerializer]]. 76 | * This is needed because the macro invocation runs in an independent ClassLoader. 77 | * 78 | * The alternate to this is to implement [[Liftable]] for all Expression classes. 79 | * This is a lot of work. Deferring this for now. 80 | * 81 | * The `deserialize` call happens within an `SparkEnv`; but we don't know 82 | * which [[Serializer]] is configured within it; so we use our own 83 | * [[JavaSerializer]] to deserialize the ''macroExprSer''. 84 | * 85 | */ 86 | def serializerInstance : SerializerInstance = { 87 | val sC = new SparkConf(false) 88 | val factory = new JavaSerializer(sC) 89 | factory.newInstance() 90 | } 91 | 92 | def serialize(e : Expression) : Array[Byte] = { 93 | val bb = serializerInstance.serialize[Expression](e) 94 | if (bb.hasArray) { 95 | bb.array() 96 | } else { 97 | val arr = new Array[Byte](bb.remaining()) 98 | bb.get(arr) 99 | arr 100 | } 101 | } 102 | 103 | def deserialize(arr : Array[Byte]) : Expression = { 104 | val bb = ByteBuffer.wrap(arr) 105 | serializerInstance.deserialize[Expression](bb) 106 | } 107 | } -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/SparkSQLMacroUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import java.util.Locale 21 | 22 | import scala.util.Random 23 | 24 | import org.apache.spark.{SparkConf, SparkContext, SparkEnv} 25 | import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession, SQLContext} 26 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 27 | import org.apache.spark.sql.internal.SQLConf 28 | import org.apache.spark.sql.types.{DataType, FractionalType, NumericType, StructType} 29 | import org.apache.spark.util.Utils 30 | 31 | object SparkSQLMacroUtils { 32 | 33 | def getLocalDir(conf: SparkConf): String = { 34 | Utils.getLocalDir(conf) 35 | } 36 | 37 | def isNumeric(dt: DataType): Boolean = NumericType.acceptsType(dt) 38 | def isApproximateNumeric(dt: DataType): Boolean = dt.isInstanceOf[FractionalType] 39 | 40 | def setLogLevel(logLevel: String): Unit = { 41 | val upperCased = logLevel.toUpperCase(Locale.ENGLISH) 42 | org.apache.spark.util.Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel)) 43 | } 44 | 45 | def currentSparkSessionOption : Option[SparkSession] = { 46 | var spkSessionO = SparkSession.getActiveSession 47 | if (!spkSessionO.isDefined) { 48 | spkSessionO = SparkSession.getDefaultSession 49 | } 50 | spkSessionO 51 | } 52 | 53 | def currentSparkSession: SparkSession = { 54 | currentSparkSessionOption.getOrElse(???) 55 | } 56 | 57 | def getSparkClassLoader: ClassLoader = Utils.getSparkClassLoader 58 | 59 | def currentSparkContext: SparkContext = currentSparkSession.sparkContext 60 | 61 | def currentSQLConf: SQLConf = { 62 | var spkSessionO = SparkSession.getActiveSession 63 | if (!spkSessionO.isDefined) { 64 | spkSessionO = SparkSession.getDefaultSession 65 | } 66 | 67 | spkSessionO.map(_.sqlContext.conf).getOrElse { 68 | val sprkConf = SparkEnv.get.conf 69 | val sqlConf = new SQLConf 70 | sprkConf.getAll.foreach { 71 | case (k, v) => 72 | sqlConf.setConfString(k, v) 73 | } 74 | sqlConf 75 | } 76 | } 77 | 78 | def dataFrame(lP: LogicalPlan)(implicit sqlContext: SQLContext): DataFrame = { 79 | Dataset.ofRows(sqlContext.sparkSession, lP) 80 | } 81 | 82 | def defaultParallelism(sparkSession: SparkSession) : Int = 83 | sparkSession.sparkContext.schedulerBackend.defaultParallelism() 84 | 85 | def throwAnalysisException[T](msg: => String): T = { 86 | throw new AnalysisException(msg) 87 | } 88 | 89 | private val r = new Random() 90 | 91 | def nextRandomInt(n: Int): Int = r.nextInt(n) 92 | 93 | /** 94 | * from fpinscala book 95 | * 96 | * @param a 97 | * @tparam A 98 | * @return 99 | */ 100 | def sequence[A](a: Seq[Option[A]]): Option[Seq[A]] = 101 | a match { 102 | case Nil => Some(Nil) 103 | case s => s.head flatMap (hh => sequence(s.tail) map (hh +: _)) 104 | } 105 | } -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/StringUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
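For reference, the `sequence` helper above collapses a `Seq[Option[A]]` into `Some(Seq[A])` only when every element is defined; extractors such as `ConcatWSArgs` in Strings.scala use it so that a whole translation fails when any sub-tree fails to translate. For example:

    SparkSQLMacroUtils.sequence(Seq(Some(1), Some(2)))  // Some(Seq(1, 2))
    SparkSQLMacroUtils.sequence(Seq(Some(1), None))     // None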
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import scala.language.implicitConversions 20 | 21 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 22 | 23 | // scalastyle:off 24 | /** 25 | * A marker interface for code inside Macros, that provides a way 26 | * for users to write scala code on 27 | * [[sparkexpr.Expression ] catalyst spark expressions]. 28 | */ 29 | object StringUtils { 30 | 31 | sealed trait ConcatWSArgType 32 | case class StringConcatWARg(s : String) extends ConcatWSArgType 33 | case class ArrStringConcatWARg(arr : Array[String]) extends ConcatWSArgType 34 | implicit def toConcatWSStr(s : String) : ConcatWSArgType = StringConcatWARg(s) 35 | implicit def toConcatWSArrStr(arr : Array[String]) : ConcatWSArgType = ArrStringConcatWARg(arr) 36 | 37 | def concatWs(sep : String, inputs : ConcatWSArgType*) : String = ??? 38 | 39 | def elt(n : Int, inputs : String*) : String = ??? 40 | def elt(n : Int, inputs : Array[Byte]*) : Array[Byte] = ??? 41 | 42 | def overlay(input : String, replace : String, pos : Int) : String = ??? 43 | def overlay(input : String, replace : String, pos : Int, len : Int) : String = ??? 44 | def overlay(input : Array[Byte], replace : Array[Byte], pos : Int) : String = ??? 45 | def overlay(input : Array[Byte], replace : Array[Byte], pos : Int, len : Int) : String = ??? 46 | 47 | def translate(input : String, from : String, to : String) : String = ??? 48 | } 49 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Strings.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
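A hedged usage sketch for the string translations that follow (the session, macro name and table are illustrative; `unit_test` with its `c_varchar2_40` column is the table created in the test suites): a plain Scala string expression can be registered as a macro and should plan to native catalyst string expressions instead of a ScalaUDF call:

    import org.apache.spark.sql.defineMacros._
    import org.apache.spark.sql.sqlmacros.StringUtils

    spark.registerMacro("prefixUDM", spark.udm((s: String) =>
      StringUtils.concatWs("-", s.toUpperCase, s.substring(0, 2))))

    // toUpperCase / substring / concatWs are matched by StringPatterns below and become
    // Upper / Substring / ConcatWs, so the following is planned without a UDF:
    spark.sql("select prefixUDM(c_varchar2_40) from unit_test")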
16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 21 | 22 | trait Strings { self : ExprTranslator => 23 | 24 | import macroUniverse._ 25 | 26 | val strTyp = typeOf[String] 27 | val strOpsTyp = typeOf[scala.collection.immutable.StringOps] 28 | val strUtilsTyp = typeOf[StringUtils.type] 29 | 30 | val toUpperSym = strTyp.member(TermName("toUpperCase")).alternatives 31 | val toLowerSym = strTyp.member(TermName("toLowerCase")).alternatives 32 | val replaceSym = strTyp.member(TermName("replace")).alternatives 33 | val trimSym = strTyp.member(TermName("trim")).alternatives 34 | val indexOfSym = strTyp.member(TermName("indexOf")).alternatives 35 | val substringSym = strTyp.member(TermName("substring")).alternatives 36 | val lengthSym = strTyp.member(TermName("length")).alternatives 37 | val startsWithSyms = strTyp.member(TermName("startsWith")).alternatives 38 | val endsWithSym = strTyp.member(TermName("endsWith")) 39 | val containsSym = strTyp.member(TermName("contains")) 40 | 41 | val strOpsTimesSym = strOpsTyp.member(TermName("$times")).alternatives 42 | 43 | val toConcatWSStrSym = strUtilsTyp.member(TermName("toConcatWSStr")).alternatives 44 | val toConcatWSArrStrSym = strUtilsTyp.member(TermName("toConcatWSArrStr")).alternatives 45 | val concatWsSym = strUtilsTyp.member(TermName("concatWs")).alternatives 46 | val eltSym = strUtilsTyp.member(TermName("elt")).alternatives 47 | val overlaySym = strUtilsTyp.member(TermName("overlay")).alternatives 48 | val translateSym = strUtilsTyp.member(TermName("translate")).alternatives 49 | 50 | /* 51 | * Notes: 52 | * 1. String.replace/replaceAll cannot be translated because 53 | * StringReplace replaces all occurrences of `search` with `replace`." 54 | */ 55 | 56 | /* 57 | functions TODO: 58 | 59 | ConcatWS(sep, Array[String] | String,...) <- Array[String].mkString(sep) 60 | Elt(idx, child1, child2, ...) 
61 | 62 | Upper(child) <- str.toUpper 63 | Lower <- str.toLowerCase 64 | Contains <- str.contains 65 | StartsWith <- str.startsWith 66 | EndsWith <- str.endsWith 67 | 68 | StringReplace <- str.replace 69 | Overlay <- by a StringUtils func 70 | StringTranslate <- by a StringUtils func 71 | FindInSet <- by a StringUtils func 72 | 73 | StringTrim <- str.trim 74 | StringTrimLeft <- by a StringUtils func 75 | StringTrimRight <- by a StringUtils func 76 | 77 | StringInstr <- str.indexOf 78 | 79 | SubstringIndex <- by a StringUtils func 80 | StringLocate <- by a StringUtils func 81 | 82 | StringLPad <- by a StringUtils func 83 | StringRPad <- by a StringUtils func 84 | 85 | ParseUrl <- by a URLUtils 86 | 87 | FormatString <- no translate 88 | 89 | InitCap <- by a StringUtils func 90 | 91 | StringRepeat <- str * int 92 | StringSpace <- by a StringUtils func 93 | 94 | Substring <- str.substring 95 | Right <- by a StringUtils func 96 | Left <- by a StringUtils func 97 | 98 | Length <- str.length 99 | BitLength <- by a StringUtils func 100 | OctetLength <- by a StringUtils func 101 | Levenshtein <- by a StringUtils func 102 | SoundEx <- by a StringUtils func 103 | Ascii <- by a StringUtils func 104 | 105 | Chr <- by a StringUtils func 106 | Base64 <- by a StringUtils func 107 | UnBase64 <- by a StringUtils func 108 | Decode <- by a StringUtils func 109 | Encode <- by a StringUtils func 110 | FormatNumber <- by a StringUtils func 111 | Sentences <- by a StringUtils func 112 | 113 | Concat(in collectionOps.scala) <- "a" + "b" 114 | */ 115 | 116 | object StringPatterns { 117 | def unapply(t: mTree): Option[sparkexpr.Expression] = 118 | t match { 119 | case InstanceMethodCall(elem, args1, args2) => 120 | if (toUpperSym.contains(t.symbol)) { 121 | for (strE <- CatalystExpression.unapply(elem)) yield sparkexpr.Upper(strE) 122 | } else if (toLowerSym.contains(t.symbol) ) { 123 | for (strE <- CatalystExpression.unapply(elem)) yield sparkexpr.Lower(strE) 124 | } else if (trimSym.contains(t.symbol)) { 125 | for (strE <- CatalystExpression.unapply(elem)) yield sparkexpr.StringTrim(strE) 126 | } else if (indexOfSym.contains(t.symbol)) { 127 | for ( 128 | strE <- CatalystExpression.unapply(elem); 129 | argE <- CatalystExpression.unapply(args1(0)) if args1(0).tpe <:< strTyp 130 | ) yield sparkexpr.StringInstr(strE, argE) 131 | } else if (substringSym.contains(t.symbol) && args1.size == 2) { 132 | for ( 133 | strE <- CatalystExpression.unapply(elem); 134 | posE <- CatalystExpression.unapply(args1(0)) if args1(0).tpe <:< typeOf[Int]; 135 | lenE <- CatalystExpression.unapply(args1(1)) if args1(1).tpe <:< typeOf[Int] 136 | ) yield sparkexpr.Substring(strE, posE, lenE) 137 | } else if (substringSym.contains(t.symbol)) { 138 | for ( 139 | strE <- CatalystExpression.unapply(elem); 140 | posE <- CatalystExpression.unapply(args1(0)) if args1(0).tpe <:< typeOf[Int] 141 | ) yield sparkexpr.Substring(strE, posE, sparkexpr.Literal(Int.MaxValue)) 142 | } else if (lengthSym.contains(t.symbol)) { 143 | for (strE <- CatalystExpression.unapply(elem)) yield sparkexpr.Length(strE) 144 | } else None 145 | case ModuleMethodCall(id, args1, args2) => 146 | if (eltSym.contains(id.symbol) && args1.size > 1) { 147 | for ( 148 | nE <- CatalystExpression.unapply(args1(0)); 149 | inputsE <- CatalystExpressions.unapplySeq(args1.tail) 150 | ) yield sparkexpr.Elt(nE +: inputsE) 151 | } else if (concatWsSym.contains(id.symbol) && args1.size > 1) { 152 | for ( 153 | sepE <- CatalystExpression.unapply(args1(0)); 154 | inputsE <- 
ConcatWSArgs.unapplySeq(args1.tail) 155 | ) yield sparkexpr.ConcatWs(sepE +: inputsE) 156 | } else None 157 | case q"$l(..$str).$m(..$args)" if strOpsTimesSym.contains(t.symbol) => 158 | for ( 159 | timesE <- CatalystExpression.unapply(args(0)) if args(0).tpe <:< typeOf[Int]; 160 | strE <- CatalystExpression.unapply(str(0)) if str(0).tpe <:< typeOf[String] 161 | ) yield sparkexpr.StringRepeat(strE, timesE) 162 | case _ => None 163 | case _ => None 164 | } 165 | } 166 | 167 | object ConcatWSArgs { 168 | 169 | def unapplySeq(tS: Seq[mTree]): Option[Seq[sparkexpr.Expression]] = 170 | SparkSQLMacroUtils.sequence(tS.map(ConcatWSArg.unapply(_))) 171 | 172 | object ConcatWSArg { 173 | def unapply(t: mTree): Option[sparkexpr.Expression] = t match { 174 | case q"$id(..$args)" if toConcatWSStrSym.contains(id.symbol) && args.size == 1 => 175 | CatalystExpression.unapply(args(0)) 176 | case q"$id(..$args)" if toConcatWSArrStrSym.contains(id.symbol) && args.size == 1 => 177 | CatalystExpression.unapply(args(0)) 178 | case _ => None 179 | } 180 | } 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Structs.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
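A hedged sketch of what the Structs translator below aims at (it assumes the `Point` case class from the macro test sources and an active `spark` session): case-class construction becomes CreateNamedStruct and field access becomes GetStructField:

    import macrotest.ExampleStructs.Point
    import org.apache.spark.sql.defineMacros._

    spark.registerMacro("mkPoint", spark.udm((x: Int, y: Int) => Point(x + 1, y)))
    // body ~ named_struct('x', macroarg(0) + 1, 'y', macroarg(1))

    spark.registerMacro("pointX", spark.udm((p: Point) => p.x))
    // body ~ GetStructField(macroarg(0), 0, Some("x"))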
16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 20 | import org.apache.spark.sql.types.{DataType, StructType} 21 | 22 | trait Structs { self: ExprTranslator => 23 | 24 | import macroUniverse._ 25 | 26 | private object VarAndFieldName { 27 | def unapply(t : mTree) : 28 | Option[(mTermName, String)] = t match { 29 | case Select(Ident(varNm), TermName(fNm)) if varNm.isTermName => 30 | Some((varNm.toTermName, fNm)) 31 | case _ => None 32 | } 33 | } 34 | 35 | private def fieldIndex(dt : DataType, 36 | fNm : String) : Option[Int] = { 37 | dt match { 38 | case sT : StructType => sT.getFieldIndex(fNm) 39 | case _ => None 40 | } 41 | } 42 | 43 | object FieldAccess { 44 | def unapply(t: mTree): Option[sparkexpr.Expression] = 45 | t match { 46 | case VarAndFieldName(varNm, fNm) => 47 | for ( 48 | vInfo <- scope.get(varNm); 49 | fIdx <- fieldIndex(vInfo.typInfo.catalystType, fNm) 50 | ) yield sparkexpr.GetStructField(vInfo.rhsExpr, fIdx, Some(fNm)) 51 | case q"$l.$r" => 52 | for ( 53 | lExpr <- CatalystExpression.unapply(l); 54 | fNm = r.decodedName.toString; 55 | fIdx <- fieldIndex(lExpr.dataType, fNm) 56 | ) yield sparkexpr.GetStructField(lExpr, fIdx, Some(fNm)) 57 | case _ => None 58 | } 59 | } 60 | 61 | object StructConstruct { 62 | 63 | private def isADTConstruction(applyTree : mTree) : Boolean = { 64 | applyTree.tpe != null && 65 | applyTree.symbol == applyTree.tpe.resultType.companion.member(TermName("apply")) 66 | } 67 | 68 | private def isCandidateType(dt : DataType, 69 | numArgs : Int) : Boolean = { 70 | 71 | dt match { 72 | case sT: StructType if sT.fields.size == numArgs => true 73 | case _ => false 74 | } 75 | } 76 | 77 | private def isCandidateParms(params : Seq[sparkexpr.Expression], 78 | sT : StructType) : Boolean = { 79 | params.zip(sT.fields.map(_.dataType)).forall { 80 | case (e, dt) => sparkexpr.Cast.canCast(e.dataType, dt) 81 | } 82 | } 83 | 84 | def unapply(t: mTree): Option[sparkexpr.Expression] = 85 | t match { 86 | case q"$id(..$args)" if isADTConstruction(id) => 87 | for ( 88 | typInfo <- TypeInfo.unapply(t) if isCandidateType(typInfo.catalystType, args.size); 89 | sT = typInfo.catalystType.asInstanceOf[StructType]; 90 | exprs <- CatalystExpressions.unapplySeq(args) if isCandidateParms(exprs, sT) 91 | ) yield { 92 | val params = exprs.zip(sT.fieldNames).flatMap { 93 | case (e, fN) => Seq(sparkexpr.Literal(fN), e) 94 | } 95 | sparkexpr.CreateNamedStruct(params) 96 | } 97 | case _ => None 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/Tuples.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
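Similarly, a small hedged illustration of the Tuples translations below (again assuming an active `spark` session): tuple literals and `->` associations both become CreateStruct:

    import org.apache.spark.sql.defineMacros._

    spark.registerMacro("pairUDM", spark.udm((i: Int) => (i, i * 2)))
    // (i, i * 2)     ~>  struct(macroarg(0), macroarg(0) * 2)

    spark.registerMacro("kvUDM", spark.udm((s: String) => s -> s.length))
    // s -> s.length  ~>  struct(macroarg(0), length(macroarg(0)))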
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.sqlmacros 19 | 20 | import org.apache.spark.sql.catalyst.{expressions => sparkexpr} 21 | 22 | trait Tuples { self: ExprTranslator => 23 | 24 | import macroUniverse._ 25 | 26 | object TupleConstruct { 27 | /** 28 | Generated by: 29 | {{{ 30 | println((for(i <- 2 until 23) yield s"""val tup${i}ApplySym = typeOf[scala.Tuple${i}.type].member(TermName("apply"))""").mkString("\n")) 31 | }}} 32 | in scala shell 33 | */ 34 | 35 | val tup2ApplySym = typeOf[scala.Tuple2.type].member(TermName("apply")) 36 | val tup3ApplySym = typeOf[scala.Tuple3.type].member(TermName("apply")) 37 | val tup4ApplySym = typeOf[scala.Tuple4.type].member(TermName("apply")) 38 | val tup5ApplySym = typeOf[scala.Tuple5.type].member(TermName("apply")) 39 | val tup6ApplySym = typeOf[scala.Tuple6.type].member(TermName("apply")) 40 | val tup7ApplySym = typeOf[scala.Tuple7.type].member(TermName("apply")) 41 | val tup8ApplySym = typeOf[scala.Tuple8.type].member(TermName("apply")) 42 | val tup9ApplySym = typeOf[scala.Tuple9.type].member(TermName("apply")) 43 | val tup10ApplySym = typeOf[scala.Tuple10.type].member(TermName("apply")) 44 | val tup11ApplySym = typeOf[scala.Tuple11.type].member(TermName("apply")) 45 | val tup12ApplySym = typeOf[scala.Tuple12.type].member(TermName("apply")) 46 | val tup13ApplySym = typeOf[scala.Tuple13.type].member(TermName("apply")) 47 | val tup14ApplySym = typeOf[scala.Tuple14.type].member(TermName("apply")) 48 | val tup15ApplySym = typeOf[scala.Tuple15.type].member(TermName("apply")) 49 | val tup16ApplySym = typeOf[scala.Tuple16.type].member(TermName("apply")) 50 | val tup17ApplySym = typeOf[scala.Tuple17.type].member(TermName("apply")) 51 | val tup18ApplySym = typeOf[scala.Tuple18.type].member(TermName("apply")) 52 | val tup19ApplySym = typeOf[scala.Tuple19.type].member(TermName("apply")) 53 | val tup20ApplySym = typeOf[scala.Tuple20.type].member(TermName("apply")) 54 | val tup21ApplySym = typeOf[scala.Tuple21.type].member(TermName("apply")) 55 | val tup22ApplySym = typeOf[scala.Tuple22.type].member(TermName("apply")) 56 | 57 | val arrAssocSym = typeOf[scala.Predef.ArrowAssoc[_]].member(TermName("$minus$greater")) 58 | 59 | /** 60 | cases generated by 61 | {{{ 62 | for (i <- 3 until 23) yield { 63 | val patStr = (for (j <- 1 until i + 1) yield s"$$a_$j").mkString("(", ", ", ")") 64 | val treeSeq = (for (j <- 1 until i + 1) yield s"a_$j").mkString("Seq(", ", ", ")") 65 | s"""case q"${patStr}" if t.symbol == tup${i}ApplySym => 66 | CatalystExpressions.unapplySeq(${treeSeq}). 67 | map(es => sparkexpr.CreateStruct(es))""".stripMargin 68 | } 69 | }}} 70 | */ 71 | 72 | // scalastyle:off line.size.limit 73 | 74 | def unapply(t: mTree): Option[sparkexpr.Expression] = 75 | t match { 76 | case q"($l, $r)" if t.symbol == tup2ApplySym => 77 | CatalystExpressions.unapplySeq(Seq(l, r)). 78 | map(es => sparkexpr.CreateStruct(es)) 79 | case q"Predef.ArrowAssoc[$_]($l).->[$_]($r)" if t.symbol == arrAssocSym => 80 | CatalystExpressions.unapplySeq(Seq(l, r)). 81 | map(es => sparkexpr.CreateStruct(es)) 82 | case q"($a_1, $a_2, $a_3)" if t.symbol == tup3ApplySym => 83 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3)). 84 | map(es => sparkexpr.CreateStruct(es)) 85 | case q"($a_1, $a_2, $a_3, $a_4)" if t.symbol == tup4ApplySym => 86 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4)). 
87 | map(es => sparkexpr.CreateStruct(es)) 88 | case q"($a_1, $a_2, $a_3, $a_4, $a_5)" if t.symbol == tup5ApplySym => 89 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5)). 90 | map(es => sparkexpr.CreateStruct(es)) 91 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6)" if t.symbol == tup6ApplySym => 92 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6)). 93 | map(es => sparkexpr.CreateStruct(es)) 94 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7)" if t.symbol == tup7ApplySym => 95 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7)). 96 | map(es => sparkexpr.CreateStruct(es)) 97 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8)" if t.symbol == tup8ApplySym => 98 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8)). 99 | map(es => sparkexpr.CreateStruct(es)) 100 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9)" if t.symbol == tup9ApplySym => 101 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9)). 102 | map(es => sparkexpr.CreateStruct(es)) 103 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10)" if t.symbol == tup10ApplySym => 104 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10)). 105 | map(es => sparkexpr.CreateStruct(es)) 106 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11)" if t.symbol == tup11ApplySym => 107 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11)). 108 | map(es => sparkexpr.CreateStruct(es)) 109 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12)" if t.symbol == tup12ApplySym => 110 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12)). 111 | map(es => sparkexpr.CreateStruct(es)) 112 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13)" if t.symbol == tup13ApplySym => 113 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13)). 114 | map(es => sparkexpr.CreateStruct(es)) 115 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14)" if t.symbol == tup14ApplySym => 116 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14)). 117 | map(es => sparkexpr.CreateStruct(es)) 118 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15)" if t.symbol == tup15ApplySym => 119 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15)). 120 | map(es => sparkexpr.CreateStruct(es)) 121 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16)" if t.symbol == tup16ApplySym => 122 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16)). 123 | map(es => sparkexpr.CreateStruct(es)) 124 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16, $a_17)" if t.symbol == tup17ApplySym => 125 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16, a_17)). 
126 | map(es => sparkexpr.CreateStruct(es)) 127 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16, $a_17, $a_18)" if t.symbol == tup18ApplySym => 128 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16, a_17, a_18)). 129 | map(es => sparkexpr.CreateStruct(es)) 130 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16, $a_17, $a_18, $a_19)" if t.symbol == tup19ApplySym => 131 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16, a_17, a_18, a_19)). 132 | map(es => sparkexpr.CreateStruct(es)) 133 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16, $a_17, $a_18, $a_19, $a_20)" if t.symbol == tup20ApplySym => 134 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16, a_17, a_18, a_19, a_20)). 135 | map(es => sparkexpr.CreateStruct(es)) 136 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16, $a_17, $a_18, $a_19, $a_20, $a_21)" if t.symbol == tup21ApplySym => 137 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16, a_17, a_18, a_19, a_20, a_21)). 138 | map(es => sparkexpr.CreateStruct(es)) 139 | case q"($a_1, $a_2, $a_3, $a_4, $a_5, $a_6, $a_7, $a_8, $a_9, $a_10, $a_11, $a_12, $a_13, $a_14, $a_15, $a_16, $a_17, $a_18, $a_19, $a_20, $a_21, $a_22)" if t.symbol == tup22ApplySym => 140 | CatalystExpressions.unapplySeq(Seq(a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_10, a_11, a_12, a_13, a_14, a_15, a_16, a_17, a_18, a_19, a_20, a_21, a_22)). 141 | map(es => sparkexpr.CreateStruct(es)) 142 | case _ => None 143 | } 144 | } 145 | } -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql 18 | 19 | /** 20 | * Spark SQL Macros provides a capability to register custom functions into a [[SparkSession]]. 21 | * This is similar to [[UDFRegistration]]. The difference being SQL Macro attempts to generate 22 | * an equivalent [[Expression]] for the function body. 
23 | * 24 | * Given a function registration: 25 | * {{{ 26 | * spark.udf.register("intUDF", (i: Int) => { 27 | *val j = 2 28 | *i + j 29 | *}) 30 | * }}} 31 | * The following query(assuming `sparktest.unit_test` has a column `c_int : Int`): 32 | * {{{ 33 | * select intUDF(c_int) 34 | * from unit_test 35 | * where intUDF(c_int) > 1 36 | * }}} 37 | * generates the following physical plan: 38 | * {{{ 39 | +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 40 | |plan 41 | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 42 | |== Physical Plan == 43 | *(1) Project [if (isnull(c_int#9)) null else intUDF(knownnotnull(c_int#9)) AS intUDF(c_int)#10] 44 | +- *(1) Filter (if (isnull(c_int#9)) null else intUDF(knownnotnull(c_int#9)) > 1) 45 | +- *(1) ColumnarToRow 46 | +- FileScan parquet default.unit_test[c_int#9] Batched: true, DataFilters: [(if (isnull(c_int#9)) null else intUDF(knownnotnull(c_int#9)) > 1)], Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/qy/qtpc2h2n3sn74gfxpjr6nqdc0000gn/T/warehouse-8b1e79b..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct 47 | 48 | | 49 | +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 50 | * 51 | * }}} 52 | * The `intUDF` is invoked in the `Filter operator` for evaluating the `intUDF(c_int) > 1` predicate; 53 | * and in the `Project operator` to evaluate the projection `intUDF(c_int)` 54 | * 55 | * But the `intUDF` is a trivial function that just adds `2` to its argument. 
56 | * With Spark SQL Macros you can register the function as a macro like this: 57 | * {{{ 58 | * 59 | * import org.apache.spark.sql.defineMacros._ 60 | * 61 | * spark.registerMacro("intUDM", spark.udm((i: Int) => { 62 | * val j = 2 63 | * i + j 64 | * })) 65 | * }}} 66 | * The query: 67 | * {{{ 68 | * select intUDM(c_int) 69 | * from sparktest.unit_test 70 | * where intUDM(c_int) < 0 71 | * }}} 72 | * generates the following physical plan: 73 | * {{{ 74 | +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 75 | |plan 76 | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 77 | |== Physical Plan == 78 | *(1) Project [(c_int#9 + 1) AS (c_int + 1)#27] 79 | +- *(1) Filter (isnotnull(c_int#9) AND ((c_int#9 + 1) > 1)) 80 | +- *(1) ColumnarToRow 81 | +- FileScan parquet default.unit_test[c_int#9] Batched: true, DataFilters: [isnotnull(c_int#9), ((c_int#9 + 1) > 1)], Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/qy/qtpc2h2n3sn74gfxpjr6nqdc0000gn/T/warehouse-8b1e79b..., PartitionFilters: [], PushedFilters: [IsNotNull(c_int)], ReadSchema: struct 82 | 83 | | 84 | +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 85 | * }}} 86 | * The predicate `intUDM(c_int) < 0` becomes `("C_INT" + 1) < 0` 87 | * and the projection `intUDM(c_int)` becomes `"C_INT" + 2`. 88 | * 89 | * '''DESIGN NOTES''' 90 | * 91 | * '''Injection of Static Values:''' 92 | * We allow macro call-site static values to be used in the macro code. 93 | * These values need to be translated to catalyst expression trees. 94 | * Spark's [[org.apache.spark.sql.catalyst.ScalaReflection]] and already 95 | * provides a mechanism for inferring and converting to catalyst expressions 96 | * (via [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s) 97 | * values of supported types. We leverage 98 | * this mechanism. But in order to leverage it we need to stand-up 99 | * a runtime Universe inside the macro invocation. This is fine because 100 | * [[SQLMacro]] is invoked in an env. that has all the Spark classes in the 101 | * classpath. The only issue it that we cannot use the Thread Classloader 102 | * of the Macro invocation. For this reason [[MacrosScalaReflection]] 103 | * is a copy of [[org.apache.spark.sql.catalyst.ScalaReflection]] with its 104 | * `mirror` setup on `org.apache.spark.util.Utils.getSparkClassLoader` 105 | * 106 | * '''Transferring Catalyst Expression Tree by Serialization:''' 107 | * Instead of developing a new builder capability to construct 108 | * macro universe Trees of catalyst Expressions, we directly construct 109 | * catalyst Expressions. 
To Lift these catalyst Expressions back to 110 | * the runtime world we use the serialization mechanism of catalyst 111 | * Expressions. So the [[SQLMacroExpressionBuilder]] is constructed 112 | * with the serialized form of the catalyst Expression that represents 113 | * the original macro code. In the runtime world this serialized form 114 | * is deserialized and on macro invocation [[MacroArg]] positions 115 | * are replaced with the Catalyst expressions at the invocation site. 116 | * 117 | */ 118 | package object sqlmacros {} 119 | -------------------------------------------------------------------------------- /macros/src/main/scala/org/apache/spark/sql/sqlmacros/registered_macros.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.sql.sqlmacros 18 | 19 | import scala.language.dynamics 20 | 21 | // scalastyle:off 22 | object registered_macros extends Dynamic { 23 | 24 | def applyDynamic(name: String)(args: Any*) = ??? 25 | } 26 | // scalastyle:on -------------------------------------------------------------------------------- /macros/src/test/scala/macrotest/ExampleStructs.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package macrotest 19 | 20 | object ExampleStructs { 21 | 22 | case class Point(x: Int, y: Int) 23 | 24 | } 25 | -------------------------------------------------------------------------------- /project/Assembly.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | import sbt._ 19 | import sbt.Keys.fullClasspath 20 | import sbtassembly.AssemblyPlugin.autoImport.{ 21 | MergeStrategy, 22 | PathList, 23 | assemblyExcludedJars, 24 | assemblyMergeStrategy, 25 | assemblyOption 26 | } 27 | import sbtassembly.AssemblyKeys.assembly 28 | 29 | 30 | object Assembly { 31 | 32 | def assemblyPredicate(d: Attributed[File]): Boolean = { 33 | true 34 | } 35 | 36 | lazy val assemblySettings = 37 | Seq( 38 | assemblyOption in assembly := 39 | (assemblyOption in assembly).value.copy(includeScala = false), 40 | assemblyExcludedJars in assembly := { 41 | val cp = (fullClasspath in assembly).value 42 | cp filter assemblyPredicate 43 | }, 44 | assemblyMergeStrategy in assembly := { 45 | case PathList("META-INF", "MANIFEST.MF") => MergeStrategy.discard 46 | case PathList("META-INF", "maven", ps @ _*) => MergeStrategy.first 47 | case PathList("META-INF", "services", ps @ _*) => MergeStrategy.first 48 | case PathList("com", "fasterxml", "jackson", "annotation", _*) => MergeStrategy.first 49 | case PathList(ps @ _*) if ps.last == "pom.properties" => MergeStrategy.first 50 | case x => 51 | val oldStrategy = (assemblyMergeStrategy in assembly).value 52 | oldStrategy(x) 53 | } 54 | ) 55 | 56 | } 57 | -------------------------------------------------------------------------------- /project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
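Note on the assembly settings above: `assemblyPredicate` returns `true` for every jar, so all dependency jars end up in `assemblyExcludedJars`; together with `includeScala = false` this means the produced assembly contains only this project's own classes.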
16 | */ 17 | 18 | import sbt.{ModuleID, _} 19 | 20 | object Dependencies { 21 | import Versions._ 22 | 23 | object scala { 24 | val dependencies = Seq( 25 | "org.scala-lang.modules" %% "scala-xml" % scalaXMLVersion % "provided", 26 | "org.scala-lang" % "scala-compiler" % scalaVersion % "provided", 27 | "org.scala-lang" % "scala-reflect" % scalaVersion % "provided", 28 | "org.scala-lang.modules" %% "scala-parser-combinators" % scalaParseCombVersion % "provided") 29 | } 30 | 31 | object spark { 32 | val dependencies = Seq( 33 | "org.apache.spark" %% "spark-core" % sparkVersion % "provided", 34 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", 35 | "org.apache.spark" %% "spark-hive" % sparkVersion % "provided", 36 | "org.apache.spark" %% "spark-hive" % sparkVersion % "test" classifier "tests", 37 | "org.apache.spark" %% "spark-hive-thriftserver" % sparkVersion % "provided", 38 | "org.apache.spark" %% "spark-repl" % sparkVersion % "provided", 39 | "org.apache.spark" %% "spark-unsafe" % sparkVersion % "provided") 40 | } 41 | 42 | object utils { 43 | val dependencies = Seq( 44 | "org.json4s" %% "json4s-jackson" % json4sVersion % "provided", 45 | "org.slf4j" % "slf4j-api" % slf4jVersion % "provided", 46 | "org.slf4j" % "slf4j-log4j12" % slf4jVersion % "provided", 47 | "org.slf4j" % "jul-to-slf4j" % slf4jVersion % "provided", 48 | "org.slf4j" % "jcl-over-slf4j" % slf4jVersion % "provided", 49 | "log4j" % "log4j" % log4jVersion % "provided") 50 | } 51 | 52 | object test_infra { 53 | val dependencies = Seq( 54 | "org.scalatest" %% "scalatest" % scalatestVersion % "test", 55 | "org.apache.hadoop" % "hadoop-client" % hadoopVersion % "test", 56 | "org.apache.derby" % "derby" % derbyVersion % "test", 57 | "org.scalacheck" %% "scalacheck" % "1.14.1" % "test") 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /project/Versions.scala: -------------------------------------------------------------------------------- 1 | object Versions { 2 | val javaVersion = "1.8" 3 | val scalaVersion = "2.12.10" 4 | val oraVersion = "19.6.0.0" 5 | val sparkVersion = "3.1.0" 6 | val log4jVersion = "1.2.17" 7 | val slf4jVersion = "1.7.30" 8 | val json4sVersion = "3.6.6" 9 | val scalaXMLVersion = "1.2.0" 10 | val scalaParseCombVersion = "1.1.2" 11 | 12 | val scalatestVersion = "3.0.8" 13 | val hadoopVersion = "3.2.0" 14 | val hiveVersion = "2.3.7" 15 | val derbyVersion = "10.12.1.1" 16 | 17 | val sparksqlMacrosVersion = "0.1.0-SNAPSHOT" 18 | } 19 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.2.8 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") 19 | 20 | // sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. 21 | libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.25" 22 | 23 | // checkstyle uses guava 23.0. 24 | libraryDependencies += "com.google.guava" % "guava" % "23.0" 25 | 26 | addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2") 27 | 28 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") 29 | 30 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") 31 | 32 | addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.7.5") 33 | 34 | addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "1.0.0") -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 17 | 39 | 40 | 41 | Scalastyle standard configuration 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | true 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW 126 | 127 | 128 | 129 | 130 | 131 | ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 144 | 145 | 146 | 147 | ^println$ 148 | 152 | 153 | 154 | 155 | spark(.sqlContext)?.sparkContext.hadoopConfiguration 156 | 165 | 166 | 167 | 168 | @VisibleForTesting 169 | 172 | 173 | 174 | 175 | Runtime\.getRuntime\.addShutdownHook 176 | 184 | 185 | 186 | 187 | mutable\.SynchronizedBuffer 188 | 196 | 197 | 198 | 199 | Class\.forName 200 | 207 | 208 | 209 | 210 | Await\.result 211 | 218 | 219 | 220 | 221 | Await\.ready 222 | 229 | 230 | 231 | 232 | (\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\))) 233 | 242 | 243 | 244 | 245 | throw new \w+Error\( 246 | 253 | 254 | 255 | 256 | 257 | JavaConversions 258 | Instead of importing implicits in scala.collection.JavaConversions._, import 259 | scala.collection.JavaConverters._ and use .asScala / .asJava methods 260 | 261 | 262 | 263 | org\.apache\.commons\.lang\. 264 | Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead 265 | of Commons Lang 2 (package org.apache.commons.lang.*) 266 | 267 | 268 | 269 | extractOpt 270 | Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter 271 | is slower. 272 | 273 | 274 | 275 | 276 | java,scala,3rdParty,spark 277 | javax?\..* 278 | scala\..* 279 | (?!org\.apache\.spark\.).* 280 | org\.apache\.spark\..* 281 | 282 | 283 | 284 | 285 | 286 | COMMA 287 | 288 | 289 | 290 | 291 | 292 | \)\{ 293 | 296 | 297 | 298 | 299 | (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] 300 | Use Javadoc style indentation for multiline comments 301 | 302 | 303 | 304 | case[^\n>]*=>\s*\{ 305 | Omit braces in case clauses. 
306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 800> 359 | 360 | 361 | 362 | 363 | 30 364 | 365 | 366 | 367 | 368 | 10 369 | 370 | 371 | 372 | 373 | 50 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | -1,0,1,2,3 385 | 386 | 387 | 388 | -------------------------------------------------------------------------------- /sql/src/main/scala/org/apache/spark/sql/defineMacros.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql 19 | 20 | import scala.language.experimental.macros 21 | import scala.language.implicitConversions 22 | import scala.reflect.runtime.universe.TypeTag 23 | 24 | import org.apache.spark.sql.functions._ 25 | import org.apache.spark.sql.sqlmacros.{SQLMacro, _} 26 | 27 | // scalastyle:off 28 | object defineMacros { 29 | // scalastyle:on 30 | 31 | // scalastyle:off line.size.limit 32 | 33 | class SparkSessionMacroExt(val sparkSession: SparkSession) extends AnyVal { 34 | 35 | def udm[RT, A1](f: Function1[A1, RT]) : 36 | Either[Function1[A1, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm1_impl[RT, A1] 37 | 38 | def registerMacro[RT : TypeTag, A1 : TypeTag](nm : String, 39 | udm : Either[Function1[A1, RT], SQLMacroExpressionBuilder] 40 | ) : Unit = { 41 | udm match { 42 | case Left(fn) => 43 | sparkSession.udf.register(nm, udf(fn)) 44 | case Right(sqlMacroBldr) => 45 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 46 | } 47 | } 48 | 49 | def udm[RT, A1, A2](f: Function2[A1, A2, RT]) : 50 | Either[Function2[A1, A2, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm2_impl[RT, A1, A2] 51 | 52 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag](nm : String, 53 | udm : Either[Function2[A1, A2, RT], SQLMacroExpressionBuilder] 54 | ) : Unit = { 55 | udm match { 56 | case Left(fn) => 57 | sparkSession.udf.register(nm, udf(fn)) 58 | case Right(sqlMacroBldr) => 59 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 60 | } 61 | } 62 | 63 | // GENERATED using [[GenMacroFuncs] 64 | 65 | def udm[RT, A1, A2, A3](f: Function3[A1, A2, A3, RT]) : 66 | Either[Function3[A1, A2, A3, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm3_impl[RT, A1, A2, A3] 67 | 68 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag](nm : String, 69 | udm : Either[Function3[A1, A2, A3, RT], 
SQLMacroExpressionBuilder] 70 | ) : Unit = { 71 | udm match { 72 | case Left(fn) => 73 | sparkSession.udf.register(nm, udf(fn)) 74 | case Right(sqlMacroBldr) => 75 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 76 | } 77 | } 78 | 79 | 80 | def udm[RT, A1, A2, A3, A4](f: Function4[A1, A2, A3, A4, RT]) : 81 | Either[Function4[A1, A2, A3, A4, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm4_impl[RT, A1, A2, A3, A4] 82 | 83 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag](nm : String, 84 | udm : Either[Function4[A1, A2, A3, A4, RT], SQLMacroExpressionBuilder] 85 | ) : Unit = { 86 | udm match { 87 | case Left(fn) => 88 | sparkSession.udf.register(nm, udf(fn)) 89 | case Right(sqlMacroBldr) => 90 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 91 | } 92 | } 93 | 94 | 95 | def udm[RT, A1, A2, A3, A4, A5](f: Function5[A1, A2, A3, A4, A5, RT]) : 96 | Either[Function5[A1, A2, A3, A4, A5, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm5_impl[RT, A1, A2, A3, A4, A5] 97 | 98 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag, A5 : TypeTag](nm : String, 99 | udm : Either[Function5[A1, A2, A3, A4, A5, RT], SQLMacroExpressionBuilder] 100 | ) : Unit = { 101 | udm match { 102 | case Left(fn) => 103 | sparkSession.udf.register(nm, udf(fn)) 104 | case Right(sqlMacroBldr) => 105 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 106 | } 107 | } 108 | 109 | 110 | def udm[RT, A1, A2, A3, A4, A5, A6](f: Function6[A1, A2, A3, A4, A5, A6, RT]) : 111 | Either[Function6[A1, A2, A3, A4, A5, A6, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm6_impl[RT, A1, A2, A3, A4, A5, A6] 112 | 113 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag, A5 : TypeTag, A6 : TypeTag](nm : String, 114 | udm : Either[Function6[A1, A2, A3, A4, A5, A6, RT], SQLMacroExpressionBuilder] 115 | ) : Unit = { 116 | udm match { 117 | case Left(fn) => 118 | sparkSession.udf.register(nm, udf(fn)) 119 | case Right(sqlMacroBldr) => 120 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 121 | } 122 | } 123 | 124 | 125 | def udm[RT, A1, A2, A3, A4, A5, A6, A7](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]) : 126 | Either[Function7[A1, A2, A3, A4, A5, A6, A7, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm7_impl[RT, A1, A2, A3, A4, A5, A6, A7] 127 | 128 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag, A5 : TypeTag, A6 : TypeTag, A7 : TypeTag](nm : String, 129 | udm : Either[Function7[A1, A2, A3, A4, A5, A6, A7, RT], SQLMacroExpressionBuilder] 130 | ) : Unit = { 131 | udm match { 132 | case Left(fn) => 133 | sparkSession.udf.register(nm, udf(fn)) 134 | case Right(sqlMacroBldr) => 135 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 136 | } 137 | } 138 | 139 | 140 | def udm[RT, A1, A2, A3, A4, A5, A6, A7, A8](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]) : 141 | Either[Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm8_impl[RT, A1, A2, A3, A4, A5, A6, A7, A8] 142 | 143 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag, A5 : TypeTag, A6 : TypeTag, A7 : TypeTag, A8 : TypeTag](nm : String, 144 | udm : Either[Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], SQLMacroExpressionBuilder] 145 | ) : Unit = 
{ 146 | udm match { 147 | case Left(fn) => 148 | sparkSession.udf.register(nm, udf(fn)) 149 | case Right(sqlMacroBldr) => 150 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 151 | } 152 | } 153 | 154 | 155 | def udm[RT, A1, A2, A3, A4, A5, A6, A7, A8, A9](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]) : 156 | Either[Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm9_impl[RT, A1, A2, A3, A4, A5, A6, A7, A8, A9] 157 | 158 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag, A5 : TypeTag, A6 : TypeTag, A7 : TypeTag, A8 : TypeTag, A9 : TypeTag](nm : String, 159 | udm : Either[Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], SQLMacroExpressionBuilder] 160 | ) : Unit = { 161 | udm match { 162 | case Left(fn) => 163 | sparkSession.udf.register(nm, udf(fn)) 164 | case Right(sqlMacroBldr) => 165 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 166 | } 167 | } 168 | 169 | 170 | def udm[RT, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]) : 171 | Either[Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], SQLMacroExpressionBuilder] = macro SQLMacro.udm10_impl[RT, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10] 172 | 173 | def registerMacro[RT : TypeTag, A1 : TypeTag, A2 : TypeTag, A3 : TypeTag, A4 : TypeTag, A5 : TypeTag, A6 : TypeTag, A7 : TypeTag, A8 : TypeTag, A9 : TypeTag, A10 : TypeTag](nm : String, 174 | udm : Either[Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], SQLMacroExpressionBuilder] 175 | ) : Unit = { 176 | udm match { 177 | case Left(fn) => 178 | sparkSession.udf.register(nm, udf(fn)) 179 | case Right(sqlMacroBldr) => 180 | sparkSession.sessionState.functionRegistry.createOrReplaceTempFunction(nm, sqlMacroBldr) 181 | } 182 | } 183 | } 184 | 185 | implicit def ssWithMacros(ss : SparkSession) : SparkSessionMacroExt = new SparkSessionMacroExt(ss) 186 | 187 | } 188 | -------------------------------------------------------------------------------- /sql/src/test/scala/org/apache/spark/sql/AbstractTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
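A short, hedged end-to-end sketch for the implicit extension defined above (session and table names are illustrative; when the body cannot be translated, `registerMacro` receives a `Left` and falls back to registering a plain UDF):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.defineMacros._

    val spark: SparkSession = ???   // an existing session

    // Two-argument macro: a translated body is registered as a native function builder,
    // an untranslatable one goes through spark.udf.register.
    spark.registerMacro("addUDM", spark.udm((a: Int, b: Int) => a + b))

    spark.sql("select addUDM(c_int, 1) from unit_test").explain(true)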
16 | */ 17 | 18 | package org.apache.spark.sql 19 | 20 | import org.scalatest.{fixture, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.internal.Logging 23 | import org.apache.spark.sql.hive.test.sqlmacros.TestSQLMacrosHive 24 | import org.apache.spark.sql.sqlmacros.SQLMacroExpressionBuilder 25 | 26 | abstract class AbstractTest 27 | extends fixture.FunSuite 28 | with fixture.TestDataFixture 29 | with BeforeAndAfterAll 30 | with Logging { 31 | 32 | import scala.tools.reflect.ToolBox 33 | import org.apache.spark.sql.catalyst.ScalaReflection._ 34 | import universe._ 35 | 36 | protected val tb = mirror.mkToolBox() 37 | 38 | protected def eval(fnTree : Tree) : Either[_, SQLMacroExpressionBuilder] = { 39 | tb.eval( 40 | q"""{ 41 | new org.apache.spark.sql.defineMacros.SparkSessionMacroExt( 42 | org.apache.spark.sql.hive.test.sqlmacros.TestSQLMacrosHive.sparkSession 43 | ).udm(${fnTree}) 44 | } 45 | """).asInstanceOf[Either[_, SQLMacroExpressionBuilder]] 46 | } 47 | 48 | protected def register(nm : String, fnTree : Tree): Unit = { 49 | tb.eval( 50 | q"""{ 51 | import org.apache.spark.sql.defineMacros._ 52 | val ss = org.apache.spark.sql.hive.test.sqlmacros.TestSQLMacrosHive.sparkSession 53 | ss.registerMacro($nm,ss.udm(${fnTree})) 54 | }""" 55 | ) 56 | } 57 | 58 | // scalastyle:off println 59 | protected def handleMacroOutput(r: Either[Any, SQLMacroExpressionBuilder]) = { 60 | r match { 61 | case Left(fn) => println(s"Failed to create expression for ${fn}") 62 | case Right(fb) => 63 | val s = fb.macroExpr.treeString(false).split("\n"). 64 | map(s => if (s.length > 100) s.substring(0, 97) + "..." else s).mkString("\n") 65 | println( 66 | s"""Spark SQL expression is 67 | |${fb.macroExpr.sql}""".stripMargin) 68 | } 69 | } 70 | 71 | def printOut(s : => String) : Unit = { 72 | println(s) 73 | } 74 | 75 | // scalastyle:on 76 | 77 | 78 | override def beforeAll(): Unit = { 79 | TestSQLMacrosHive.sql( 80 | """create table if not exists unit_test( 81 | | c_varchar2_40 string, 82 | | c_number decimal(38,18), 83 | | c_int int 84 | |) 85 | |using parquet""".stripMargin 86 | ) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /sql/src/test/scala/org/apache/spark/sql/CollectionMacrosTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package org.apache.spark.sql 18 | 19 | import org.apache.spark.unsafe.types.CalendarInterval 20 | 21 | class CollectionMacrosTest extends AbstractTest { 22 | 23 | import org.apache.spark.sql.catalyst.ScalaReflection._ 24 | import universe._ 25 | 26 | test("arrFuncs") { td => 27 | 28 | handleMacroOutput(eval(reify {(s : String) => 29 | Array(s).size 30 | }.tree) 31 | ) 32 | 33 | handleMacroOutput(eval(reify {(s : String) => 34 | Array(s).zip[String, String, Array[(String, String)]](Array(s)) 35 | }.tree) 36 | ) 37 | 38 | handleMacroOutput(eval(reify {(s : String) => 39 | (Array(s, s).mkString, Array(s, s).mkString(", ")) 40 | }.tree) 41 | ) 42 | 43 | handleMacroOutput(eval(reify {(s : String) => 44 | (Array(s, s).min[String], Array(s, s).max[String]) 45 | }.tree) 46 | ) 47 | 48 | handleMacroOutput(eval(reify {(s : String) => 49 | Array(s, s).slice(0, 2).reverse ++[String, Array[String]] Array(s, s) 50 | }.tree) 51 | ) 52 | 53 | handleMacroOutput(eval(reify {(s : String) => 54 | (Array(Array(s, s), Array(s, s)).flatten[String] intersect[String] Array(s, s)).distinct 55 | }.tree) 56 | ) 57 | 58 | handleMacroOutput(eval(reify {(s : String) => 59 | Array.fill[String](5)(s) 60 | }.tree) 61 | ) 62 | } 63 | 64 | test("collectionUtilFuncs") { td => 65 | import org.apache.spark.sql.sqlmacros.CollectionUtils._ 66 | 67 | handleMacroOutput(eval(reify {(s : String) => 68 | mapFromEntries(mapEntries(Map(s -> s))) 69 | }.tree) 70 | ) 71 | 72 | handleMacroOutput(eval(reify {(s : String) => 73 | overlapArrays( 74 | sortArray(shuffleArray(Array(s, s)), true), 75 | Array(s, s) 76 | ) 77 | }.tree) 78 | ) 79 | 80 | handleMacroOutput(eval(reify {(s : String) => 81 | positionArray(Array(s, s), s) 82 | }.tree) 83 | ) 84 | 85 | handleMacroOutput(eval(reify {(start : Int, stop : Int) => 86 | sequence(start, stop, 1) 87 | }.tree) 88 | ) 89 | 90 | handleMacroOutput(eval(reify {(start : java.sql.Date, stop : java.sql.Date) => 91 | date_sequence(start, stop, new CalendarInterval(0,0, 1000L)) 92 | }.tree) 93 | ) 94 | 95 | handleMacroOutput(eval(reify {(start : java.sql.Timestamp, stop : java.sql.Timestamp) => 96 | timestamp_sequence(start, stop, new CalendarInterval(0,0, 1000L)) 97 | }.tree) 98 | ) 99 | 100 | handleMacroOutput(eval(reify {(s : String) => 101 | exceptArray(removeArray(Array(s, s), s), Array(s)) 102 | }.tree) 103 | ) 104 | 105 | handleMacroOutput(eval(reify {(s : String) => 106 | mapKeys(Map(s -> s) ++ Map(s -> s)) ++[String, Array[String]] mapValues(Map(s -> s)) 107 | }.tree) 108 | ) 109 | } 110 | 111 | } 112 | -------------------------------------------------------------------------------- /sql/src/test/scala/org/apache/spark/sql/MacrosTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql 19 | 20 | import org.apache.spark.sql.hive.test.sqlmacros.TestSQLMacrosHive 21 | 22 | class MacrosTest extends AbstractTest { 23 | 24 | import org.apache.spark.sql.defineMacros._ 25 | import org.apache.spark.sql.catalyst.ScalaReflection._ 26 | import universe._ 27 | 28 | test("compileTime") { td => 29 | handleMacroOutput(TestSQLMacrosHive.sparkSession.udm((i: Int) => i)) 30 | 31 | handleMacroOutput(TestSQLMacrosHive.sparkSession.udm((i: Int) => i + 1)) 32 | 33 | handleMacroOutput(TestSQLMacrosHive.sparkSession.udm((i: Int) => { 34 | val j = 5 35 | j 36 | })) 37 | 38 | handleMacroOutput(TestSQLMacrosHive.sparkSession.udm((i: Int) => { 39 | val b = Array(5) 40 | val j = 5 41 | j 42 | })) 43 | } 44 | 45 | test("basics") {td => 46 | 47 | handleMacroOutput(eval(q"(i : Int) => i")) 48 | 49 | handleMacroOutput(eval(q"(i : java.lang.Integer) => i")) 50 | 51 | // val a = Array(5) 52 | // handleMacroOutput(eval[Int, Int](q"(i : Int) => a(0)")) 53 | 54 | handleMacroOutput(eval(q"(i : Int) => i + 5")) 55 | 56 | handleMacroOutput(eval( 57 | q"""{(i : Int) => 58 | val b = Array(5) 59 | val j = 5 60 | j 61 | }""")) 62 | 63 | handleMacroOutput(eval(reify { 64 | (i : Int) => org.apache.spark.SPARK_BRANCH.length + i 65 | }.tree)) 66 | 67 | handleMacroOutput(eval(reify {(i : Int) => 68 | val b = Array(5, 6) 69 | val j = b(0) 70 | i + j + Math.abs(j)}.tree)) 71 | 72 | handleMacroOutput(eval( 73 | q"""{(i : Int) => 74 | val b = Array(5) 75 | val j = 5 76 | j 77 | }""")) 78 | } 79 | 80 | test("udts") {td => 81 | import macrotest.ExampleStructs.Point 82 | handleMacroOutput(eval( 83 | reify {(p : Point) => 84 | Point(1, 2) 85 | }.tree 86 | ) 87 | ) 88 | 89 | handleMacroOutput(eval( 90 | reify {(p : Point) => 91 | p.x + p.y 92 | }.tree 93 | ) 94 | ) 95 | 96 | handleMacroOutput(eval( 97 | reify {(p : Point) => 98 | Point(p.x + p.y, p.y) 99 | }.tree 100 | ) 101 | ) 102 | } 103 | 104 | test("optimizeExpr") { td => 105 | import macrotest.ExampleStructs.Point 106 | handleMacroOutput(eval( 107 | reify {(p : Point) => 108 | val p1 = Point(p.x, p.y) 109 | val a = Array(1) 110 | val m = Map(1 -> 2) 111 | p1.x + p1.y + a(0) + m(1) 112 | }.tree 113 | ) 114 | ) 115 | } 116 | 117 | test("tuples") {td => 118 | handleMacroOutput(eval( 119 | reify {(t : Tuple2[Int, Int]) => 120 | (t._2, t._1) 121 | }.tree 122 | ) 123 | ) 124 | 125 | handleMacroOutput(eval( 126 | reify {(t : Tuple2[Int, Int]) => 127 | t._2 -> t._1 128 | }.tree 129 | ) 130 | ) 131 | 132 | handleMacroOutput(eval( 133 | reify {(t : Tuple4[Float, Double, Int, Int]) => 134 | (t._4 + t._3, t._4) 135 | }.tree 136 | ) 137 | ) 138 | } 139 | 140 | test("arrays") {td => 141 | handleMacroOutput(eval( 142 | reify {(i : Int) => 143 | val b = Array(5, i) 144 | val j = b(0) 145 | j + b(1) 146 | }.tree)) 147 | } 148 | 149 | test("maps") {td => 150 | handleMacroOutput(eval( 151 | reify {(i : Int) => 152 | val b = Map(0 -> i, 1 -> (i + 1)) 153 | val j = b(0) 154 | j + b(1) 155 | }.tree)) 156 | } 157 | 158 | test("datetimes") {td => 159 | import java.sql.Date 160 | import java.sql.Timestamp 161 | import java.time.ZoneId 162 | import java.time.Instant 163 | import org.apache.spark.unsafe.types.CalendarInterval 164 | import org.apache.spark.sql.sqlmacros.DateTimeUtils._ 165 | 166 | handleMacroOutput(eval( 167 | reify {(dt : Date) => 168 | val dtVal = dt 169 | val dtVal2 = new Date(System.currentTimeMillis()) 170 | val tVal = new 
Timestamp(System.currentTimeMillis()) 171 | val dVal3 = localDateToDays(java.time.LocalDate.of(2000, 1, 1)) 172 | val t2 = instantToMicros(Instant.now()) 173 | val t3 = stringToTimestamp("2000-01-01", ZoneId.systemDefault()).get 174 | val t4 = daysToMicros(dtVal, ZoneId.systemDefault()) 175 | getDayInYear(dtVal) + getDayOfMonth(dtVal) + getDayOfWeek(dtVal2) + 176 | getHours(tVal, ZoneId.systemDefault) + getSeconds(t2, ZoneId.systemDefault) + 177 | getMinutes(t3, ZoneId.systemDefault()) + 178 | getDayInYear(dateAddMonths(dtVal, getMonth(dtVal2))) + 179 | getDayInYear(dVal3) + 180 | getHours( 181 | timestampAddInterval(t4, new CalendarInterval(1, 1, 1), ZoneId.systemDefault()), 182 | ZoneId.systemDefault) + 183 | getDayInYear(dateAddInterval(dtVal, new CalendarInterval(1, 1, 1L))) + 184 | monthsBetween(t2, t3, true, ZoneId.systemDefault()) + 185 | getDayOfMonth(getNextDateForDayOfWeek(dtVal2, "MO")) + 186 | getDayInYear(getLastDayOfMonth(dtVal2)) + getDayOfWeek(truncDate(dtVal, "week")) + 187 | getHours(toUTCTime(t3, ZoneId.systemDefault().toString), ZoneId.systemDefault()) 188 | }.tree)) 189 | } 190 | 191 | test("taxAndDiscount") { td => 192 | import org.apache.spark.sql.sqlmacros.DateTimeUtils._ 193 | import java.sql.Date 194 | import java.time.ZoneId 195 | 196 | handleMacroOutput(eval( 197 | reify { (prodCat : String, amt : Double) => 198 | val taxRate = prodCat match { 199 | case "grocery" => 0.0 200 | case "alcohol" => 10.5 201 | case _ => 9.5 202 | } 203 | val currDate = currentDate(ZoneId.systemDefault()) 204 | val discount = if (getDayOfWeek(currDate) == 1 && prodCat == "alcohol") 0.05 else 0.0 205 | 206 | amt * ( 1.0 - discount) * (1.0 + taxRate) 207 | 208 | }.tree)) 209 | } 210 | 211 | test("taxAndDiscountMultiMacro") { td => 212 | import org.apache.spark.sql.sqlmacros.DateTimeUtils._ 213 | import java.sql.Date 214 | import java.time.ZoneId 215 | 216 | import org.apache.spark.sql.sqlmacros.registered_macros 217 | 218 | register("taxRate", reify {(prodCat : String) => 219 | prodCat match { 220 | case "grocery" => 0.0 221 | case "alcohol" => 10.5 222 | case _ => 9.5 223 | } 224 | }.tree) 225 | 226 | register("discount", reify {(prodCat : String) => 227 | val currDate = currentDate(ZoneId.systemDefault()) 228 | if (getDayOfWeek(currDate) == 1 && prodCat == "alcohol") 0.05 else 0.0 229 | }.tree) 230 | 231 | register("taxAndDiscount", reify {(prodCat : String, amt : Double) => 232 | val taxRate : Double = registered_macros.taxRate(prodCat) 233 | val discount : Double = registered_macros.discount(prodCat) 234 | amt * ( 1.0 - discount) * (1.0 + taxRate) 235 | }.tree) 236 | 237 | val dfM = 238 | TestSQLMacrosHive.sql( 239 | "select taxAndDiscount(c_varchar2_40, c_number) from unit_test" 240 | ) 241 | printOut( 242 | s"""Macro based Plan: 243 | |${dfM.queryExecution.analyzed}""".stripMargin 244 | ) 245 | 246 | } 247 | 248 | test("conditionals") { td => 249 | import org.apache.spark.sql.sqlmacros.PredicateUtils._ 250 | import macrotest.ExampleStructs.Point 251 | 252 | handleMacroOutput(eval( 253 | reify { (i: Int) => 254 | val j = if (i > 7 && i < 20 && i.is_not_null) { 255 | i 256 | } else if (i == 6 || i.in(4, 5) ) { 257 | i + 1 258 | } else i + 2 259 | val k = i match { 260 | case 1 => i + 2 261 | case _ => i + 3 262 | } 263 | val l = (j, k) match { 264 | case (1, 2) => 1 265 | case (3, 4) => 2 266 | case _ => 3 267 | } 268 | val p = Point(k, l) 269 | val m = p match { 270 | case Point(1, 2) => 1 271 | case _ => 2 272 | } 273 | j + k + l + m 274 | }.tree)) 275 | 276 | 
handleMacroOutput(eval( 277 | reify { (s: String) => 278 | val i = if (s.endsWith("abc")) 1 else 0 279 | val j = if (s.contains("abc")) 1 else 0 280 | val k = if (s.is_not_null && s.not_in("abc")) 1 else 0 281 | i + j + k 282 | }.tree)) 283 | } 284 | 285 | test("macroVsFuncPlan") { td => 286 | 287 | TestSQLMacrosHive.sparkSession.registerMacro("intUDM", { 288 | TestSQLMacrosHive.sparkSession.udm((i: Int) => i + 1) 289 | }) 290 | 291 | TestSQLMacrosHive.udf.register("intUDF", (i: Int) => i + 1) 292 | 293 | printOut("Function based Plan:") 294 | TestSQLMacrosHive.sql("explain select intUDF(c_int) from unit_test where intUDF(c_int) > 1"). 295 | show(1000, false) 296 | 297 | printOut("Macro based Plan:") 298 | TestSQLMacrosHive.sql("explain select intUDM(c_int) from unit_test where intUDM(c_int) > 1"). 299 | show(1000, false) 300 | 301 | } 302 | 303 | test("macroPlan") { td => 304 | 305 | import TestSQLMacrosHive.sparkSession.implicits._ 306 | 307 | TestSQLMacrosHive.sparkSession.registerMacro("m1", 308 | TestSQLMacrosHive.sparkSession.udm( 309 | {(i : Int) => 310 | val b = Array(5, 6) 311 | val j = b(0) 312 | val k = new java.sql.Date(System.currentTimeMillis()).getTime 313 | i + j + k + Math.abs(j) 314 | } 315 | ) 316 | ) 317 | 318 | val dfM = TestSQLMacrosHive.sql("select m1(c_int) from unit_test") 319 | printOut( 320 | s"""Macro based Plan: 321 | |${dfM.queryExecution.analyzed}""".stripMargin 322 | ) 323 | 324 | } 325 | 326 | test("macroWithinMacro") { td => 327 | 328 | import org.apache.spark.sql.sqlmacros.registered_macros 329 | 330 | register("m2", reify {(i : Int) => 331 | val b = Array(5, 6) 332 | val j = b(0) 333 | val k = new java.sql.Date(System.currentTimeMillis()).getTime 334 | i + j + k + Math.abs(j) 335 | }.tree) 336 | 337 | register("m3", reify {(i : Int) => 338 | val l : Int = registered_macros.m2(i) 339 | i + l 340 | }.tree) 341 | 342 | val dfM = TestSQLMacrosHive.sql("select m3(c_int) from unit_test") 343 | printOut( 344 | s"""Macro based Plan: 345 | |${dfM.queryExecution.analyzed}""".stripMargin 346 | ) 347 | } 348 | 349 | } 350 | -------------------------------------------------------------------------------- /sql/src/test/scala/org/apache/spark/sql/StringMacrosTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package org.apache.spark.sql 18 | 19 | import org.apache.spark.sql.hive.test.sqlmacros.TestSQLMacrosHive 20 | 21 | class StringMacrosTest extends AbstractTest { 22 | 23 | import org.apache.spark.sql.defineMacros._ 24 | import org.apache.spark.sql.catalyst.ScalaReflection._ 25 | import universe._ 26 | 27 | test("basics") { td => 28 | handleMacroOutput(eval(reify {(s : String) => 29 | (s * 3).toLowerCase(). 30 | toUpperCase(). 31 | trim. 32 | substring(5). 33 | substring(0, 5). 34 | indexOf("a")}.tree) 35 | ) 36 | } 37 | 38 | test("stringOps") { td => 39 | import org.apache.spark.sql.sqlmacros.StringUtils._ 40 | 41 | handleMacroOutput(eval(reify {(s : String) => 42 | elt(0, concatWs(" ", s, s, Array(s, s)), s * 3) 43 | }.tree) 44 | ) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /sql/src/test/scala/org/apache/spark/sql/hive/test/sqlmacros/TestSQLMacrosHive.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.hive.test.sqlmacros 19 | 20 | import org.apache.spark.{SparkConf, SparkContext} 21 | import org.apache.spark.internal.config 22 | import org.apache.spark.internal.config.UI.UI_ENABLED 23 | import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation 24 | import org.apache.spark.sql.hive.HiveUtils 25 | import org.apache.spark.sql.hive.test.TestHiveContext 26 | import org.apache.spark.sql.internal.SQLConf 27 | import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH 28 | 29 | object SQLMacrosTestConf { 30 | 31 | lazy val localConf = new SparkConf() 32 | .set("spark.sql.test", "") 33 | .set(SQLConf.CODEGEN_FALLBACK.key, "false") 34 | .set( 35 | HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key, 36 | "org.apache.spark.sql.hive.execution.PairSerDe") 37 | .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath) 38 | // SPARK-8910 39 | .set(UI_ENABLED, false) 40 | .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) 41 | // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes 42 | // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764. 43 | .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false") 44 | // Disable ConvertToLocalRelation for better test coverage. Test cases built on 45 | // LocalRelation will exercise the optimization rules better by disabling it as 46 | // this rule may potentially block testing of other optimization rules such as 47 | // ConstantPropagation etc. 
48 | .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) 49 | /* 50 | Uncomment to see Plan rewrites 51 | .set("spark.sql.planChangeLog.level", "ERROR") 52 | .set( 53 | "spark.sql.planChangeLog.rules", 54 | "org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions," + 55 | "org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown," + 56 | "org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery," + 57 | "org.apache.spark.sql.catalyst.optimizer.PullupCorrelatedPredicates") 58 | */ 59 | /* Use these settings to turn off some of the code generation 60 | .set("spark.sql.codegen.factoryMode", "NO_CODEGEN") 61 | .set("spark.sql.codegen.maxFields", "0") 62 | .set("spark.sql.codegen.wholeStage", "false") 63 | */ 64 | 65 | def testMaster: String = "local[*]" 66 | 67 | } 68 | 69 | object TestSQLMacrosHive 70 | extends TestHiveContext( 71 | new SparkContext( 72 | System.getProperty("spark.sql.test.master", SQLMacrosTestConf.testMaster), 73 | "TestSQLContext", 74 | SQLMacrosTestConf.localConf), 75 | false) 76 | --------------------------------------------------------------------------------
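The listings above show the macro API and its test harness, but not a plain end-to-end usage. The snippet below is an illustrative sketch, not part of the repository: it relies only on the SparkSessionMacroExt methods visible in defineMacros.scala (udm, registerMacro, and the ssWithMacros implicit) and mirrors the pattern used in MacrosTest.scala. The local session setup, the temp view name macro_demo, and the column c_int are placeholder assumptions, and compiling the lambda passed to udm assumes the macros module is on the compile classpath.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.defineMacros._   // brings udm/registerMacro onto SparkSession

object MacroUsageSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder local session; any existing SparkSession works the same way.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("sql-macros-sketch")
      .getOrCreate()

    // Placeholder data so the SQL below has something to run against.
    spark.range(10).selectExpr("cast(id as int) as c_int")
      .createOrReplaceTempView("macro_demo")

    // udm is a compile-time macro: if the lambda body can be translated into a
    // Catalyst Expression it returns Right(SQLMacroExpressionBuilder); otherwise
    // it falls back to Left(the original function). registerMacro accepts either,
    // registering a temp function in the first case and a plain UDF in the second.
    spark.registerMacro("add_one", spark.udm((i: Int) => i + 1))

    // Because add_one was registered as an expression builder, the call below is
    // expanded inline and the filter stays visible to the optimizer, unlike an
    // opaque UDF invocation.
    spark.sql("select add_one(c_int) from macro_demo where add_one(c_int) > 1")
      .explain(true)
  }
}

Run against the unit_test table created in AbstractTest.scala, the same pattern is what the macroVsFuncPlan test exercises when it compares the intUDM and intUDF plans.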