├── .github └── workflows │ ├── ci.yaml │ └── docs.yaml ├── .gitignore ├── .publish └── publish.scala ├── LICENSE ├── README.md ├── USAGE-DEV.md ├── project.scala ├── scala-cli └── src ├── main ├── Aliasing.scala ├── ClassDataFrame.scala ├── CollectColumns.scala ├── Column.scala ├── ColumnOp.scala ├── DataFrame.scala ├── DataFrameBuilders.scala ├── FrameSchema.scala ├── Grouping.scala ├── Join.scala ├── JoinOnCondition.scala ├── MacroHelpers.scala ├── Name.scala ├── Repeated.scala ├── SchemaView.scala ├── SchemaViewProvider.scala ├── Select.scala ├── StructDataFrame.scala ├── UntypedOps.scala ├── When.scala ├── Where.scala ├── WithColumns.scala ├── api │ └── api.scala ├── functions │ ├── aggregates.scala │ ├── lit.scala │ └── when.scala ├── types │ ├── Coerce.scala │ ├── DataType.scala │ └── Encoder.scala └── untyped.scala └── test ├── AggregatorsTest.scala ├── CoerceTest.scala ├── ColumnsTest.scala ├── CompilationTest.scala ├── JoinTest.scala ├── OperatorsTest.scala ├── SparkUnitTest.scala ├── WhenTest.scala ├── WhereTest.scala ├── WithColumnsTest.scala └── example ├── Books.scala ├── Countries.scala └── Workers.scala /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Run CI 2 | 3 | on: 4 | push: 5 | workflow_dispatch: 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout 12 | uses: actions/checkout@v3 13 | with: 14 | fetch-depth: 0 15 | - name: Setup coursier cache 16 | uses: coursier/cache-action@v6.3 17 | - name: Setup scala-cli 18 | uses: VirtusLab/scala-cli-setup@v1.5.1 19 | - name: Run tests 20 | run: scala-cli test . --jvm temurin:11 21 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy documentation to Pages 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | pages: write 11 | id-token: write 12 | 13 | concurrency: 14 | group: "pages" 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | deploy: 19 | environment: 20 | name: github-pages 21 | url: ${{ steps.deployment.outputs.page_url }} 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v3 26 | with: 27 | fetch-depth: 0 28 | - name: Setup coursier cache 29 | uses: coursier/cache-action@v6.3 30 | - name: Setup scala-cli 31 | uses: VirtusLab/scala-cli-setup@v1.4.0 32 | - name: Generate documentation 33 | run: scala-cli doc . 
34 | - name: Setup Pages 35 | uses: actions/configure-pages@v2 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v1 38 | with: 39 | path: 'scala-doc' 40 | - name: Deploy to GitHub Pages 41 | id: deployment 42 | uses: actions/deploy-pages@v1 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .bloop/ 2 | .bsp/ 3 | .scala/ 4 | .scala-build/ 5 | .vscode/ 6 | .metals/ 7 | target/ 8 | metals.sbt 9 | scala-doc 10 | /.tmp 11 | -------------------------------------------------------------------------------- /.publish/publish.scala: -------------------------------------------------------------------------------- 1 | //> using publish.organization "org.virtuslab" 2 | //> using publish.name "iskra" 3 | //> using publish.computeVersion "git:tag" 4 | //> using publish.url "https://github.com/VirtusLab/iskra" 5 | //> using publish.versionControl "github:VirtusLab/iskra" 6 | //> using publish.license "Apache-2.0" 7 | //> using publish.repository "central" 8 | //> using publish.developer "prolativ|Michał Pałka|https://github.com/prolativ" 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Iskra
2 |
3 | Iskra is a Scala 3 wrapper library around the Apache Spark API which allows writing typesafe and boilerplate-free but still efficient Spark code.
4 |
5 | ## How is it possible to write Spark applications in Scala 3?
6 |
7 | Starting from the release of 3.2.0, Spark is cross-compiled also for Scala 2.13, which opens the way to using Spark from Scala 3 code, as Scala 3 projects can depend on Scala 2.13 artifacts.
8 |
9 | However, one might run into problems when trying to call a method requiring an implicit instance of Spark's `Encoder` type. Derivation of instances of `Encoder` relies on the presence of a `TypeTag` for a given type. However, `TypeTag`s are no longer generated by the Scala 3 compiler (and there are no plans to support this), so instances of `Encoder` cannot be automatically synthesized in most cases.
10 |
11 | Iskra tries to work around this problem by using its own encoders (unrelated to Spark's `Encoder` type) generated using Scala 3's new metaprogramming API.
12 |
13 | ## How does Iskra make things typesafe and efficient at the same time?
14 |
15 | Iskra provides thin (but strongly typed) wrappers around `DataFrame`s, which track the types and names of columns at compile time but let Catalyst perform all of its optimizations at runtime.
16 |
17 | Iskra uses structural types rather than case classes as data models, which gives us a lot of flexibility (no need to explicitly define a new case class when a column is added/removed/renamed!), while we still get compilation errors when we try to refer to a column which doesn't exist or can't be used in a given context, as the sketch below illustrates.
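For example, deriving a new column in a `select` simply changes the structural type of the resulting frame, and referring to a column that is no longer there fails at compile time. A minimal sketch (assuming the imports, the `SparkSession` setup and the `toTypedDF` method shown in the `Usage` section below; `Point` is a made-up example class):

```scala
import org.virtuslab.iskra.api.*

case class Point(x: Int, y: Int)

@main def structuralSchemaDemo(): Unit =
  given spark: SparkSession =
    SparkSession.builder().master("local").appName("demo").getOrCreate()

  val points = Seq(Point(1, 2), Point(3, 4)).toTypedDF

  // The result's schema is structural: it now contains a single column `sum`,
  // with no new case class defined anywhere.
  val sums = points.select(($.x + $.y).as("sum"))

  sums.select($.sum).show() // compiles
  // sums.select($.x)       // compile-time error: no column `x` in the schema
```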
18 |
19 | ## Usage
20 |
21 | :warning: This library is in an early stage of development - the syntax and type hierarchy might still change,
22 | the coverage of Spark's API is far from complete, and more tests are needed.
23 |
24 | 1) Add Iskra as a dependency to your project, e.g.
25 |
26 |    * in a file compiled with Scala CLI:
27 |      ```scala
28 |      //> using lib "org.virtuslab::iskra:0.0.3"
29 |      ```
30 |
31 |    * when starting Scala CLI REPL:
32 |      ```shell
33 |      scala-cli repl --dep org.virtuslab::iskra:0.0.3
34 |      ```
35 |
36 |    * in `build.sbt` in an sbt project:
37 |      ```scala
38 |      libraryDependencies += "org.virtuslab" %% "iskra" % "0.0.3"
39 |      ```
40 |
41 |    Iskra is built with Scala 3.1.3 so it's compatible with Scala 3.1.x and newer minor releases (starting from 3.2.0 you'll get code completions for names of columns in the REPL and Metals!).
42 |    Iskra transitively depends on Spark 3.2.0.
43 |
44 | 2) Import the basic definitions from the API
45 |    ```scala
46 |    import org.virtuslab.iskra.api.*
47 |    ```
48 |
49 | 3) Get a Spark session, e.g.
50 |    ```scala
51 |    given spark: SparkSession =
52 |      SparkSession
53 |        .builder()
54 |        .master("local")
55 |        .appName("my-spark-app")
56 |        .getOrCreate()
57 |    ```
58 |
59 | 4) Create a typed data frame in either of the two ways:
60 |    * by using the `toTypedDF` extension method on a `Seq` of case classes, e.g.
61 |      ```scala
62 |      Seq(Foo(1, "abc"), Foo(2, "xyz")).toTypedDF
63 |      ```
64 |    * by taking a good old (untyped) data frame and calling the `typed` extension method on it with a type parameter representing a case class, e.g.
65 |      ```scala
66 |      df.typed[Foo]
67 |      ```
68 |
69 |    In case you need to get back to the unsafe world of untyped data frames for some reason, just call `.untyped` on a typed data frame.
70 |
71 | 5) Follow your intuition of a Spark developer :wink:
72 |
73 |    This library intends to resemble the original API of Spark as closely as possible (e.g. by using the same names of methods) while trying to make the code feel more like regular Scala, without unnecessary boilerplate and with some other syntactic improvements.
74 |
75 |    Most important differences:
76 |    * Refer to columns (also with prefixes specifying the alias for a dataframe in case of ambiguities) simply with `$.foo.bar` instead of `$"foo.bar"` or `col("foo.bar")`. Use backticks when necessary, e.g. ``$.`column with spaces` ``.
77 |    * From inside of `.select(...)` or `.select{...}` you should return something that is a named column or a tuple of named columns. Because of how Scala syntax works, you can simply write `.select($.x, $.y)` instead of `.select(($.x, $.y))`. With braces you can compute intermediate values like
78 |      ```scala
79 |      .select {
80 |        val sum = ($.x + $.y).as("sum")
81 |        ($.x, $.y, sum)
82 |      }
83 |      ```
84 |    * The syntax for joins looks slightly more like SQL, but with dots and parentheses as for usual method calls, e.g.
85 |      ```scala
86 |      foos.innerJoin(bars).on($.foos.barId === $.bars.id).select(...)
87 |      ```
88 |    * As you might have noticed above, the aliases for `foos` and `bars` were automatically inferred
89 |
90 | 6) For reference look at the [examples](src/test/example/) and the [API docs](https://virtuslab.github.io/iskra/)
91 |
92 | ## Local development
93 |
94 | This project is built using [scala-cli](https://scala-cli.virtuslab.org/) so just use the traditional commands with `.` as the root, like `scala-cli compile .` or `scala-cli test .`.
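The CI and docs workflows in `.github/workflows/` run the same commands, so a typical local loop might look like this (the `scala-doc` output directory follows `docs.yaml` and `.gitignore`):

```shell
scala-cli compile .   # compile all sources
scala-cli test .      # run the test suite (CI runs this with --jvm temurin:11)
scala-cli doc .       # generate API docs into scala-doc/ (deployed by docs.yaml)
```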
95 |
96 | For a more recent version of the `Usage` section, look [here](./USAGE-DEV.md)
--------------------------------------------------------------------------------
/USAGE-DEV.md:
--------------------------------------------------------------------------------
1 | ## Usage (for devs)
2 |
3 | :warning: This library is in an early stage of development - the syntax and type hierarchy might still change,
4 | the coverage of Spark's API is far from complete, and more tests are needed.
5 |
6 | ### First steps
7 |
8 | 1) Add Iskra as a dependency to your project, e.g.
9 |
10 |    * in a file compiled with Scala CLI:
11 |      ```scala
12 |      //> using lib "org.virtuslab::iskra:0.0.4-SNAPSHOT"
13 |      ```
14 |
15 |    * when starting Scala CLI REPL:
16 |      ```shell
17 |      scala-cli repl --dep org.virtuslab::iskra:0.0.4-SNAPSHOT
18 |      ```
19 |
20 |    * in `build.sbt` in an sbt project:
21 |      ```scala
22 |      libraryDependencies += "org.virtuslab" %% "iskra" % "0.0.4-SNAPSHOT"
23 |      ```
24 |
25 |    Iskra is built with Scala 3.3.0 so it's compatible with Scala 3.3.x (LTS) and newer minor releases.
26 |    Iskra transitively depends on Spark 3.2.0.
27 |
28 | 2) Import the basic definitions from the API
29 |    ```scala
30 |    import org.virtuslab.iskra.api.*
31 |    ```
32 |
33 | 3) Get a Spark session, e.g.
34 |    ```scala
35 |    given spark: SparkSession =
36 |      SparkSession
37 |        .builder()
38 |        .master("local")
39 |        .appName("my-spark-app")
40 |        .getOrCreate()
41 |    ```
42 |
43 | 4) Create a typed data frame in either of the two ways:
44 |    * by using the `.toDF` extension method on a `Seq` of case classes, e.g.
45 |      ```scala
46 |      Seq(Foo(1, "abc"), Foo(2, "xyz")).toDF
47 |      ```
48 |      Note that this variant of `.toDF` comes from `org.virtuslab.iskra.api` rather than from `spark.implicits`.
49 |
50 |    * by taking a good old (untyped) data frame and calling the `typed` extension method on it with a type parameter representing a case class, e.g.
51 |      ```scala
52 |      df.typed[Foo]
53 |      ```
54 |      In case you need to get back to the unsafe world of untyped data frames for some reason, just call `.untyped` on a typed data frame.
55 |
56 | 5) Follow your intuition of a Spark developer :wink: This library intends to resemble the original API of Spark as closely as possible (e.g. by using the same names of methods) while trying to make the code feel more like regular Scala, without unnecessary boilerplate but with some other syntactic improvements.
57 |
58 | 6) Look at the [examples](src/test/example/).
59 |
60 | ### Key concepts and differences to untyped Spark
61 |
62 | Data frames created with `.toDF` or `.typed` like above are called `ClassDataFrame`s, as their compile-time schema is represented by a case class. As an alternative to case classes, schemas of data frames can be expressed structurally. In that case such a data frame is called a `StructDataFrame`. Some data frame operations might be available only for `ClassDataFrame`s or only for `StructDataFrame`s while some operations are available for both. This depends on the semantics of each operation and on implementation restrictions (which might get lifted in the future). To turn a `ClassDataFrame` into a `StructDataFrame` or vice versa use the `.asStruct` or `.asClass[A]` method respectively, as in the sketch below.
63 |
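A minimal sketch of switching between the two representations (assuming the `Foo` case class and the `SparkSession` from the snippets above):

```scala
case class Foo(a: Int, b: String)

val classDF  = Seq(Foo(1, "abc"), Foo(2, "xyz")).toDF // ClassDataFrame[Foo]
val structDF = classDF.asStruct                       // StructDataFrame with a structural schema
val classDF2 = structDF.asClass[Foo]                  // back to a ClassDataFrame[Foo]
```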
64 |
65 | When operating on a data frame, `$` represents the schema of this frame, from which columns can be selected like ordinary class members. So, to refer to a column called `foo`, instead of writing `col("foo")` or `$"foo"`, write `$.foo`. If the name of a column is not a valid Scala identifier, you can use backticks, e.g. ``$.`column with spaces` ``. Similarly, the syntax `$.foo.bar` can be used to refer to a column originating from a specific data frame to avoid ambiguities. This corresponds to `col("foo.bar")` or `$"foo.bar"` in vanilla Spark.
66 |
67 |
68 | Some operations like `.select(...)` or `.agg(...)` accept potentially multiple columns as arguments. You can pass individual columns separately, like `.select($.foo, $.bar)`, or you can aggregate them using `Columns(...)`, i.e. `.select(Columns($.foo, $.bar))`. `Columns` will eventually get flattened, so these two syntaxes are semantically equivalent. However, the `Columns(...)` syntax might come in handy, e.g. if you need to embed a block of code as an argument to `.select { ... }`:
69 | ```scala
70 | .select {
71 |   val sum = ($.x + $.y).as("sum")
72 |   Columns($.x, $.y, sum)
73 | }
74 | ```
75 |
76 | The syntax for joins looks slightly more like SQL, but with dots and parentheses as for usual method calls, e.g.
77 | ```scala
78 | foos.innerJoin(bars).on($.foos.barId === $.bars.id).select(...)
79 | ```
80 |
81 | As you might have noticed above, the aliases for `foos` and `bars` were automatically inferred so you don't have to write
82 | ```scala
83 | foos.as("foos").innerJoin(bars.as("bars"))
84 | ```
85 |
--------------------------------------------------------------------------------
/project.scala:
--------------------------------------------------------------------------------
1 | //> using scala "3.3.4"
2 | //> using dep "org.apache.spark:spark-core_2.13:3.2.0"
3 | //> using dep "org.apache.spark:spark-sql_2.13:3.2.0"
4 | //> using test.dep "org.scalatest::scalatest::3.2.19"
5 |
--------------------------------------------------------------------------------
/scala-cli:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | scala-cli --cli-version nightly "$@"
3 |
--------------------------------------------------------------------------------
/src/main/Aliasing.scala:
--------------------------------------------------------------------------------
1 | package org.virtuslab.iskra
2 |
3 | import scala.quoted.*
4 |
5 | object Aliasing:
6 |   given dataFrameAliasingOps: {} with
7 |     extension [S](df: StructDataFrame[S])
8 |       transparent inline def as(inline frameName: Name): StructDataFrame[?] = ${ Aliasing.aliasStructDataFrameImpl('df, 'frameName) }
9 |       transparent inline def alias(inline frameName: Name): StructDataFrame[?] = ${ Aliasing.aliasStructDataFrameImpl('df, 'frameName) }
10 |
11 |   def autoAliasImpl[DF <: StructDataFrame[?]
: Type](df: Expr[DF])(using Quotes): Expr[StructDataFrame[?]] = 12 | import quotes.reflect.* 13 | 14 | val identName = df.asTerm match 15 | case Inlined(_, _, Ident(name)) => 16 | Some(name) 17 | case Inlined(_, _, Select(This(_), name)) => 18 | Some(name) 19 | case _ => None 20 | 21 | identName match 22 | case None => df 23 | case Some(name) => 24 | Type.of[DF] match 25 | case '[StructDataFrame.WithAlias[alias]] => 26 | df 27 | case '[StructDataFrame[schema]] => 28 | ConstantType(StringConstant(name)).asType match 29 | case '[Name.Subtype[n]] => FrameSchema.reownType[n](Type.of[schema]) match 30 | case '[t] => 31 | '{ StructDataFrame[t](${df}.untyped.as(${Expr(name)})) } 32 | 33 | 34 | def aliasStructDataFrameImpl[S : Type](df: Expr[StructDataFrame[S]], frameName: Expr[String])(using Quotes) = 35 | import quotes.reflect.* 36 | ConstantType(StringConstant(frameName.valueOrAbort)).asType match 37 | case '[Name.Subtype[n]] => 38 | FrameSchema.reownType[n](Type.of[S]) match 39 | case '[reowned] => '{ 40 | new StructDataFrame[reowned](${ df }.untyped.alias(${ frameName })): 41 | type Alias = n 42 | } 43 | -------------------------------------------------------------------------------- /src/main/ClassDataFrame.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import scala.reflect.ClassTag 5 | 6 | import types.{DataType, Encoder, StructEncoder} 7 | 8 | 9 | class ClassDataFrame[A](val untyped: UntypedDataFrame) extends DataFrame 10 | 11 | object ClassDataFrame: 12 | extension [A](df: ClassDataFrame[A]) 13 | transparent inline def asStruct: StructDataFrame[?] = ${ asStructImpl('df) } 14 | 15 | inline def collect(): Array[A] = ${ collectImpl('df) } 16 | 17 | private def asStructImpl[A : Type](df: Expr[ClassDataFrame[A]])(using Quotes): Expr[StructDataFrame[?]] = 18 | import quotes.reflect.report 19 | 20 | Expr.summon[Encoder[A]] match 21 | case Some(encoder) => encoder match 22 | case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } => 23 | '{ StructDataFrame[structSchema](${ df }.untyped) } 24 | case '{ $enc: Encoder[A] { type ColumnType = colType } } => 25 | Type.of[colType] match 26 | case '[DataType.Subtype[t]] => 27 | '{ StructDataFrame[("value" := t)/* *: EmptyTuple */](${ df }.untyped) } // TODO: Get rid of non-tuple schemas? 
28 | case None => report.errorAndAbort(s"Could not summon encoder for ${Type.show[A]}") 29 | 30 | private def collectImpl[A : Type](df: Expr[ClassDataFrame[A]])(using Quotes): Expr[Array[A]] = 31 | import quotes.reflect.report 32 | 33 | Expr.summon[Encoder[A]] match 34 | case Some(encoder) => 35 | val classTag = Expr.summon[ClassTag[A]].getOrElse(report.errorAndAbort(s"Could not summon ClassTag for ${Type.show[A]}")) 36 | encoder match 37 | case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } => 38 | '{ ${ df }.untyped.collect.map(row => ${ enc }.decode(row).asInstanceOf[A])(${ classTag }) } 39 | case '{ $enc: Encoder[A] { type ColumnType = colType } } => 40 | '{ ${ df }.untyped.collect.map(row => ${ enc }.decode(row(0)).asInstanceOf[A])(${ classTag }) } 41 | case None => report.errorAndAbort(s"Could not summon encoder for ${Type.show[A]}") 42 | -------------------------------------------------------------------------------- /src/main/CollectColumns.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.compiletime.error 4 | 5 | import org.virtuslab.iskra.types.DataType 6 | 7 | // TODO should it be covariant or not? 8 | trait CollectColumns[-C]: 9 | type CollectedColumns <: Tuple 10 | def underlyingColumns(c: C): Seq[UntypedColumn] 11 | 12 | // Using `given ... with { ... }` syntax might sometimes break pattern match on `CollectColumns[...] { type CollectedColumns = cc }` 13 | 14 | object CollectColumns: 15 | given collectNamedColumn[N <: Name, T <: DataType]: CollectColumns[NamedColumn[N, T]] with 16 | type CollectedColumns = (N := T) *: EmptyTuple 17 | def underlyingColumns(c: NamedColumn[N, T]) = Seq(c.untyped) 18 | 19 | given collectColumnsWithSchema[S <: Tuple]: CollectColumns[ColumnsWithSchema[S]] with 20 | type CollectedColumns = S 21 | def underlyingColumns(c: ColumnsWithSchema[S]) = c.underlyingColumns 22 | 23 | given collectEmptyTuple[S]: CollectColumns[EmptyTuple] with 24 | type CollectedColumns = EmptyTuple 25 | def underlyingColumns(c: EmptyTuple) = Seq.empty 26 | 27 | given collectCons[H, T <: Tuple](using collectHead: CollectColumns[H], collectTail: CollectColumns[T]): (CollectColumns[H *: T] { type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns] }) = 28 | new CollectColumns[H *: T]: 29 | type CollectedColumns = Tuple.Concat[collectHead.CollectedColumns, collectTail.CollectedColumns] 30 | def underlyingColumns(c: H *: T) = collectHead.underlyingColumns(c.head) ++ collectTail.underlyingColumns(c.tail) 31 | 32 | 33 | // TODO Customize error message for different operations with an explanation 34 | class CannotCollectColumns(typeName: String) 35 | extends Exception(s"Could not find an instance of CollectColumns for ${typeName}") 36 | -------------------------------------------------------------------------------- /src/main/Column.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.language.implicitConversions 4 | 5 | import scala.quoted.* 6 | 7 | import org.apache.spark.sql.{Column => UntypedColumn} 8 | import types.DataType 9 | import MacroHelpers.TupleSubtype 10 | 11 | class Column(val untyped: UntypedColumn): 12 | inline def name(using v: ValueOf[Name]): Name = v.value 13 | 14 | object Column: 15 | implicit transparent inline def columnToNamedColumn(inline col: Col[?]): NamedColumn[?, ?] 
= 16 | ${ columnToNamedColumnImpl('col) } 17 | 18 | private def columnToNamedColumnImpl(col: Expr[Col[?]])(using Quotes): Expr[NamedColumn[?, ?]] = 19 | import quotes.reflect.* 20 | col match 21 | case '{ ($v: StructuralSchemaView).selectDynamic($nm: Name).$asInstanceOf$[Col[tp]] } => 22 | nm.asTerm.tpe.asType match 23 | case '[Name.Subtype[n]] => 24 | '{ NamedColumn[n, tp](${ col }.untyped.as(${ nm })) } 25 | case '{ $c: Col[tp] } => 26 | col.asTerm match 27 | case Inlined(_, _, Ident(name)) => 28 | ConstantType(StringConstant(name)).asType match 29 | case '[Name.Subtype[n]] => 30 | val alias = Literal(StringConstant(name)).asExprOf[Name] 31 | '{ NamedColumn[n, tp](${ col }.untyped.as(${ alias })) } 32 | 33 | extension [T <: DataType](col: Col[T]) 34 | inline def as[N <: Name](name: N): NamedColumn[N, T] = 35 | NamedColumn[N, T](col.untyped.as(name)) 36 | inline def alias[N <: Name](name: N): NamedColumn[N, T] = 37 | NamedColumn[N, T](col.untyped.as(name)) 38 | 39 | extension [T1 <: DataType](col1: Col[T1]) 40 | inline def +[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Plus[T1, T2]): Col[op.Out] = op(col1, col2) 41 | inline def -[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Minus[T1, T2]): Col[op.Out] = op(col1, col2) 42 | inline def *[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Mult[T1, T2]): Col[op.Out] = op(col1, col2) 43 | inline def /[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Div[T1, T2]): Col[op.Out] = op(col1, col2) 44 | inline def ++[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.PlusPlus[T1, T2]): Col[op.Out] = op(col1, col2) 45 | inline def <[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Lt[T1, T2]): Col[op.Out] = op(col1, col2) 46 | inline def <=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Le[T1, T2]): Col[op.Out] = op(col1, col2) 47 | inline def >[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Gt[T1, T2]): Col[op.Out] = op(col1, col2) 48 | inline def >=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Ge[T1, T2]): Col[op.Out] = op(col1, col2) 49 | inline def ===[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Eq[T1, T2]): Col[op.Out] = op(col1, col2) 50 | inline def =!=[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Ne[T1, T2]): Col[op.Out] = op(col1, col2) 51 | inline def &&[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.And[T1, T2]): Col[op.Out] = op(col1, col2) 52 | inline def ||[T2 <: DataType](col2: Col[T2])(using op: ColumnOp.Or[T1, T2]): Col[op.Out] = op(col1, col2) 53 | 54 | 55 | class Col[+T <: DataType](untyped: UntypedColumn) extends Column(untyped) 56 | 57 | 58 | object Columns: 59 | transparent inline def apply[C <: NamedColumns](columns: C): ColumnsWithSchema[?] = ${ applyImpl('columns) } 60 | 61 | private def applyImpl[C : Type](columns: Expr[C])(using Quotes): Expr[ColumnsWithSchema[?]] = 62 | import quotes.reflect.* 63 | 64 | Expr.summon[CollectColumns[C]] match 65 | case Some(collectColumns) => 66 | collectColumns match 67 | case '{ $cc: CollectColumns[?] 
{ type CollectedColumns = collectedColumns } } => 68 | Type.of[collectedColumns] match 69 | case '[TupleSubtype[collectedCols]] => 70 | '{ 71 | val cols = ${ cc }.underlyingColumns(${ columns }) 72 | ColumnsWithSchema[collectedCols](cols) 73 | } 74 | case None => 75 | throw CollectColumns.CannotCollectColumns(Type.show[C]) 76 | 77 | 78 | trait NamedColumnOrColumnsLike 79 | 80 | type NamedColumns = Repeated[NamedColumnOrColumnsLike] 81 | 82 | class NamedColumn[N <: Name, T <: DataType](val untyped: UntypedColumn) 83 | extends NamedColumnOrColumnsLike 84 | 85 | class ColumnsWithSchema[Schema <: Tuple](val underlyingColumns: Seq[UntypedColumn]) extends NamedColumnOrColumnsLike 86 | 87 | 88 | @annotation.showAsInfix 89 | trait :=[L <: ColumnLabel, T <: DataType] 90 | 91 | @annotation.showAsInfix 92 | trait /[+Prefix <: Name, +Suffix <: Name] 93 | 94 | type ColumnLabel = Name | (Name / Name) 95 | -------------------------------------------------------------------------------- /src/main/ColumnOp.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import org.apache.spark.sql 5 | import org.apache.spark.sql.functions.concat 6 | import org.virtuslab.iskra.Col 7 | import org.virtuslab.iskra.UntypedOps.typed 8 | import org.virtuslab.iskra.types.* 9 | import DataType.* 10 | 11 | trait ColumnOp: 12 | type Out <: DataType 13 | 14 | object ColumnOp: 15 | trait ResultType[T <: DataType] extends ColumnOp: 16 | override type Out = T 17 | 18 | abstract class BinaryColumnOp[T1 <: DataType, T2 <: DataType](untypedOp: (UntypedColumn, UntypedColumn) => UntypedColumn) extends ColumnOp: 19 | def apply(col1: Col[T1], col2: Col[T2]): Col[Out] = untypedOp(col1.untyped, col2.untyped).typed[Out] 20 | 21 | class Plus[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ + _) 22 | object Plus: 23 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Plus[T1, T2] { type Out = CommonNumericType[T1, T2] }) = 24 | new Plus[T1, T2] with ResultType[CommonNumericType[T1, T2]] 25 | 26 | class Minus[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ - _) 27 | object Minus: 28 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Minus[T1, T2] { type Out = CommonNumericType[T1, T2] }) = 29 | new Minus[T1, T2] with ResultType[CommonNumericType[T1, T2]] 30 | 31 | class Mult[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ * _) 32 | object Mult: 33 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Mult[T1, T2] { type Out = CommonNumericType[T1, T2] }) = 34 | new Mult[T1, T2] with ResultType[CommonNumericType[T1, T2]] 35 | 36 | class Div[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ / _) 37 | object Div: 38 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Div[T1, T2] { type Out = DoubleOfCommonNullability[T1, T2] }) = 39 | new Div[T1, T2] with ResultType[DoubleOfCommonNullability[T1, T2]] 40 | 41 | class PlusPlus[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](concat(_, _)) 42 | object PlusPlus: 43 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (PlusPlus[T1, T2] { type Out = StringOfCommonNullability[T1, T2] }) = 44 | new PlusPlus[T1, T2] with ResultType[StringOfCommonNullability[T1, T2]] 45 | 46 | class Eq[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ === _) 47 | object Eq: 48 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Eq[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 
49 | new Eq[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 50 | 51 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (Eq[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 52 | new Eq[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 53 | 54 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Eq[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 55 | new Eq[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 56 | 57 | given structs[S1 <: Tuple, S2 <: Tuple, T1 <: StructOptLike[S1], T2 <: StructOptLike[S2]]: (Eq[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 58 | new Eq[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 59 | 60 | class Ne[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ =!= _) 61 | object Ne: 62 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Ne[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 63 | new Ne[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 64 | 65 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (Ne[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 66 | new Ne[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 67 | 68 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Ne[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 69 | new Ne[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 70 | 71 | given structs[S1 <: Tuple, S2 <: Tuple, T1 <: StructOptLike[S1], T2 <: StructOptLike[S2]]: (Ne[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 72 | new Ne[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 73 | 74 | class Lt[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ < _) 75 | object Lt: 76 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Lt[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 77 | new Lt[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 78 | 79 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (Lt[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 80 | new Lt[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 81 | 82 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Lt[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 83 | new Lt[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 84 | 85 | given structs[S1 <: Tuple, S2 <: Tuple, T1 <: StructOptLike[S1], T2 <: StructOptLike[S2]]: (Lt[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 86 | new Lt[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 87 | 88 | class Le[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ <= _) 89 | object Le: 90 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Le[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 91 | new Le[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 92 | 93 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (Le[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 94 | new Le[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 95 | 96 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Le[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 97 | new Le[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 98 | 99 | given structs[S1 <: Tuple, S2 <: Tuple, T1 <: StructOptLike[S1], T2 <: StructOptLike[S2]]: (Le[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 100 | new Le[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 101 | 102 | class Gt[T1 <: 
DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ > _) 103 | object Gt: 104 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Gt[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 105 | new Gt[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 106 | 107 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (Gt[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 108 | new Gt[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 109 | 110 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Gt[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 111 | new Gt[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 112 | 113 | given structs[S1 <: Tuple, S2 <: Tuple, T1 <: StructOptLike[S1], T2 <: StructOptLike[S2]]: (Gt[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 114 | new Gt[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 115 | 116 | class Ge[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ >= _) 117 | object Ge: 118 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Ge[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 119 | new Ge[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 120 | 121 | given strings[T1 <: StringOptLike, T2 <: StringOptLike]: (Ge[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 122 | new Ge[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 123 | 124 | given numerics[T1 <: DoubleOptLike, T2 <: DoubleOptLike]: (Ge[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 125 | new Ge[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 126 | 127 | given structs[S1 <: Tuple, S2 <: Tuple, T1 <: StructOptLike[S1], T2 <: StructOptLike[S2]]: (Ge[T1, T2] { type Out = BooleanOfCommonNullability[T1, T2] }) = 128 | new Ge[T1, T2] with ResultType[BooleanOfCommonNullability[T1, T2]] 129 | 130 | class And[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ && _) 131 | object And: 132 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (And[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 133 | new And[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 134 | 135 | class Or[T1 <: DataType, T2 <: DataType] extends BinaryColumnOp[T1, T2](_ || _) 136 | object Or: 137 | given booleans[T1 <: BooleanOptLike, T2 <: BooleanOptLike]: (Or[T1, T2] { type Out = CommonBooleanType[T1, T2] }) = 138 | new Or[T1, T2] with ResultType[CommonBooleanType[T1, T2]] 139 | -------------------------------------------------------------------------------- /src/main/DataFrame.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import org.apache.spark.sql 4 | import org.apache.spark.sql.SparkSession 5 | 6 | 7 | trait DataFrame: 8 | type Alias 9 | def untyped: UntypedDataFrame 10 | 11 | object DataFrame: 12 | export Aliasing.dataFrameAliasingOps 13 | export Select.dataFrameSelectOps 14 | export Join.dataFrameJoinOps 15 | export GroupBy.dataFrameGroupByOps 16 | export Where.dataFrameWhereOps 17 | export WithColumns.dataFrameWithColumnsOps 18 | 19 | given dataFrameOps: {} with 20 | extension (df: DataFrame) 21 | inline def show(): Unit = df.untyped.show() 22 | -------------------------------------------------------------------------------- /src/main/DataFrameBuilders.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted._ 4 | import org.apache.spark.sql 5 | import 
org.apache.spark.sql.SparkSession 6 | import org.virtuslab.iskra.DataFrame 7 | import org.virtuslab.iskra.types.{DataType, Encoder, StructEncoder, PrimitiveEncoder} 8 | 9 | object DataFrameBuilders: 10 | extension [A](seq: Seq[A])(using encoder: Encoder[A]) 11 | inline def toDF(using spark: SparkSession): ClassDataFrame[A] = ${ toTypedDFImpl('seq, 'encoder, 'spark) } 12 | 13 | private def toTypedDFImpl[A : Type](seq: Expr[Seq[A]], encoder: Expr[Encoder[A]], spark: Expr[SparkSession])(using Quotes) = 14 | val (schema, encodeFun) = encoder match 15 | case '{ $e: StructEncoder.Aux[A, t] } => 16 | val schema = '{ ${ e }.catalystType } 17 | val encodeFun: Expr[A => sql.Row] = '{ ${ e }.encode } 18 | (schema, encodeFun) 19 | case '{ $e: Encoder.Aux[tpe, t] } => 20 | val schema = '{ 21 | sql.types.StructType(Seq( 22 | sql.types.StructField("value", ${ encoder }.catalystType, ${ encoder }.isNullable ) 23 | )) 24 | } 25 | val encodeFun: Expr[A => sql.Row] = '{ (value: A) => sql.Row(${ encoder }.encode(value)) } 26 | (schema, encodeFun) 27 | 28 | '{ 29 | val rowRDD = ${ spark }.sparkContext.parallelize(${ seq }.map(${ encodeFun })) 30 | ClassDataFrame[A](${ spark }.createDataFrame(rowRDD, ${ schema })) 31 | } 32 | -------------------------------------------------------------------------------- /src/main/FrameSchema.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import scala.deriving.Mirror 5 | import types.{DataType, Encoder, StructEncoder} 6 | import MacroHelpers.TupleSubtype 7 | 8 | object FrameSchema: 9 | type AsTuple[A] = A match 10 | case Tuple => A 11 | case Any => A *: EmptyTuple 12 | 13 | type FromTuple[T] = T match 14 | case h *: EmptyTuple => h 15 | case Tuple => T 16 | 17 | type Merge[S1, S2] = (S1, S2) match 18 | case (Tuple, Tuple) => 19 | Tuple.Concat[S1, S2] 20 | case (Any, Tuple) => 21 | S1 *: S2 22 | case (Tuple, Any) => 23 | Tuple.Append[S1, S2] 24 | case (Any, Any) => 25 | (S1, S2) 26 | 27 | type NullableLabeledDataType[T] = T match 28 | case label := tpe => label := DataType.AsNullable[tpe] 29 | 30 | type NullableSchema[T] = T match 31 | case Tuple => Tuple.Map[T, NullableLabeledDataType] 32 | case Any => NullableLabeledDataType[T] 33 | 34 | def reownType[Owner <: Name : Type](schema: Type[?])(using Quotes): Type[?] = 35 | schema match 36 | case '[EmptyTuple] => Type.of[EmptyTuple] 37 | case '[head *: tail] => 38 | reownType[Owner](Type.of[head]) match 39 | case '[reownedHead] => 40 | reownType[Owner](Type.of[tail]) match 41 | case '[TupleSubtype[reownedTail]] => 42 | Type.of[reownedHead *: reownedTail] 43 | case '[namePrefix / nameSuffix := dataType] => 44 | Type.of[Owner / nameSuffix := dataType] 45 | case '[Name.Subtype[name] := dataType] => 46 | Type.of[Owner / name := dataType] 47 | 48 | def isValidType(tpe: Type[?])(using Quotes): Boolean = tpe match 49 | case '[EmptyTuple] => true 50 | case '[(label := column) *: tail] => isValidType(Type.of[tail]) 51 | case '[label := column] => true 52 | case _ => false 53 | 54 | def schemaTypeFromColumnsTypes(colTypes: Seq[Type[?]])(using Quotes): Type[? 
<: Tuple] = 55 | colTypes match 56 | case Nil => Type.of[EmptyTuple] 57 | case '[TupleSubtype[headTpes]] :: tail => 58 | schemaTypeFromColumnsTypes(tail) match 59 | case '[TupleSubtype[tailTpes]] => Type.of[Tuple.Concat[headTpes, tailTpes]] 60 | -------------------------------------------------------------------------------- /src/main/Grouping.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import org.virtuslab.iskra.types.DataType 5 | import MacroHelpers.TupleSubtype 6 | 7 | class GroupBy[View <: SchemaView](val view: View, val underlying: UntypedDataFrame) 8 | 9 | object GroupBy: 10 | given dataFrameGroupByOps: {} with 11 | extension [S] (df: StructDataFrame[S]) 12 | transparent inline def groupBy: GroupBy[?] = ${ GroupBy.groupByImpl[S]('{df}) } 13 | 14 | given groupByOps: {} with 15 | extension [View <: SchemaView](groupBy: GroupBy[View]) 16 | transparent inline def apply[C <: NamedColumns](groupingColumns: View ?=> C) = ${ applyImpl[View, C]('groupBy, 'groupingColumns) } 17 | 18 | private def groupByImpl[S : Type](df: Expr[StructDataFrame[S]])(using Quotes): Expr[GroupBy[?]] = 19 | import quotes.reflect.asTerm 20 | val viewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[S]] 21 | viewExpr.asTerm.tpe.asType match 22 | case '[SchemaView.Subtype[v]] => 23 | '{ GroupBy[v](${ viewExpr }.asInstanceOf[v], ${ df }.untyped) } 24 | 25 | private def applyImpl[View <: SchemaView : Type, C : Type](groupBy: Expr[GroupBy[View]], groupingColumns: Expr[View ?=> C])(using Quotes): Expr[GroupedDataFrame[View]] = 26 | import quotes.reflect.* 27 | 28 | Expr.summon[CollectColumns[C]] match 29 | case Some(collectColumns) => 30 | collectColumns match 31 | case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => 32 | Type.of[collectedColumns] match 33 | case '[TupleSubtype[collectedCols]] => 34 | val groupedViewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[collectedCols]] 35 | groupedViewExpr.asTerm.tpe.asType match 36 | case '[SchemaView.Subtype[groupedView]] => 37 | '{ 38 | val groupingCols = ${ cc }.underlyingColumns(${ groupingColumns }(using ${ groupBy }.view)) 39 | new GroupedDataFrame[View]: 40 | type GroupingKeys = collectedCols 41 | type GroupedView = groupedView 42 | def underlying = ${ groupBy }.underlying.groupBy(groupingCols*) 43 | def fullView = ${ groupBy }.view 44 | def groupedView = ${ groupedViewExpr }.asInstanceOf[GroupedView] 45 | } 46 | case None => 47 | throw CollectColumns.CannotCollectColumns(Type.show[C]) 48 | 49 | // TODO: Rename to RelationalGroupedDataset and handle other aggregations: cube, rollup (and pivot?) 50 | trait GroupedDataFrame[FullView <: SchemaView]: 51 | import GroupedDataFrame.* 52 | 53 | type GroupingKeys <: Tuple 54 | type GroupedView <: SchemaView 55 | def underlying: UntypedRelationalGroupedDataset 56 | def fullView: FullView 57 | def groupedView: GroupedView 58 | 59 | object GroupedDataFrame: 60 | given groupedDataFrameOps: {} with 61 | extension [FullView <: SchemaView, GroupKeys <: Tuple, GroupView <: SchemaView](gdf: GroupedDataFrame[FullView]{ type GroupedView = GroupView; type GroupingKeys = GroupKeys }) 62 | transparent inline def agg[C <: NamedColumns](columns: (Agg { type View = FullView }, GroupView) ?=> C): StructDataFrame[?] 
= 63 | ${ aggImpl[FullView, GroupKeys, GroupView, C]('gdf, 'columns) } 64 | 65 | private def aggImpl[FullView <: SchemaView : Type, GroupingKeys <: Tuple : Type, GroupView <: SchemaView : Type, C : Type]( 66 | gdf: Expr[GroupedDataFrame[FullView] { type GroupedView = GroupView }], 67 | columns: Expr[(Agg { type View = FullView }, GroupView) ?=> C] 68 | )(using Quotes): Expr[StructDataFrame[?]] = 69 | import quotes.reflect.* 70 | 71 | val aggWrapper = '{ 72 | new Agg: 73 | type View = FullView 74 | val view = ${ gdf }.fullView 75 | } 76 | 77 | Expr.summon[CollectColumns[C]] match 78 | case Some(collectColumns) => 79 | collectColumns match 80 | case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => 81 | '{ 82 | // TODO assert cols is not empty 83 | val cols = ${ cc }.underlyingColumns(${ columns }(using ${ aggWrapper }, ${ gdf }.groupedView)) 84 | StructDataFrame[FrameSchema.Merge[GroupingKeys, collectedColumns]]( 85 | ${ gdf }.underlying.agg(cols.head, cols.tail*) 86 | ) 87 | } 88 | case None => 89 | throw CollectColumns.CannotCollectColumns(Type.show[C]) 90 | 91 | trait Agg: 92 | type View <: SchemaView 93 | def view: View 94 | -------------------------------------------------------------------------------- /src/main/Join.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import scala.compiletime.erasedValue 5 | 6 | enum JoinType: 7 | case Inner 8 | case Left 9 | case Right 10 | case Full 11 | case Semi 12 | case Anti 13 | 14 | object JoinType: 15 | inline def typeName[T <: JoinType] = inline erasedValue[T] match 16 | case _: Inner.type => "inner" 17 | case _: Left.type => "left" 18 | case _: Right.type => "right" 19 | case _: Full.type => "full" 20 | case _: Semi.type => "semi" 21 | case _: Anti.type => "anti" 22 | 23 | trait Join[T <: JoinType](val left: UntypedDataFrame, val right: UntypedDataFrame): 24 | type Left <: StructDataFrame[?] 25 | type Right <: StructDataFrame[?] 
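// Note: a `Join` value only captures the two (auto-aliased) frames and the requested
// join type; the actual join is built later, either by `.on(...)` (see JoinOnCondition.scala)
// or, for cross joins, directly by `crossJoin` below.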
26 | 27 | object Join: 28 | given dataFrameJoinOps: {} with 29 | extension [LeftDF <: StructDataFrame[?]](inline leftDF: LeftDF) 30 | transparent inline def join[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Inner.type] = 31 | ${ joinImpl[JoinType.Inner.type, LeftDF, RightDF]('leftDF, 'rightDF) } 32 | transparent inline def innerJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Inner.type] = 33 | ${ joinImpl[JoinType.Inner.type, LeftDF, RightDF]('leftDF, 'rightDF) } 34 | transparent inline def leftJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Left.type] = 35 | ${ joinImpl[JoinType.Left.type, LeftDF, RightDF]('leftDF, 'rightDF) } 36 | transparent inline def rightJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Right.type] = 37 | ${ joinImpl[JoinType.Right.type, LeftDF, RightDF]('leftDF, 'rightDF) } 38 | transparent inline def fullJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Full.type] = 39 | ${ joinImpl[JoinType.Full.type, LeftDF, RightDF]('leftDF, 'rightDF) } 40 | transparent inline def semiJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Semi.type] = 41 | ${ joinImpl[JoinType.Semi.type, LeftDF, RightDF]('leftDF, 'rightDF) } 42 | transparent inline def antiJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): Join[JoinType.Anti.type] = 43 | ${ joinImpl[JoinType.Anti.type, LeftDF, RightDF]('leftDF, 'rightDF) } 44 | 45 | transparent inline def crossJoin[RightDF <: StructDataFrame[?]](inline rightDF: RightDF): StructDataFrame[?] = 46 | ${ crossJoinImpl[LeftDF, RightDF]('leftDF, 'rightDF) } 47 | end dataFrameJoinOps 48 | 49 | def joinImpl[T <: JoinType : Type, LeftDF <: StructDataFrame[?] : Type, RightDF <: StructDataFrame[?] : Type]( 50 | leftDF: Expr[LeftDF], rightDF: Expr[RightDF] 51 | )(using Quotes) = 52 | Aliasing.autoAliasImpl(leftDF) match 53 | case '{ $left: l } => 54 | Aliasing.autoAliasImpl(rightDF) match 55 | case '{ $right: r } => 56 | '{ 57 | (new Join[T](${left}.untyped, ${right}.untyped) { type Left = l; type Right = r }) 58 | : Join[T]{ type Left = l; type Right = r } 59 | } 60 | 61 | def crossJoinImpl[LeftDF <: StructDataFrame[?] : Type, RightDF <: StructDataFrame[?] 
: Type]( 62 | leftDF: Expr[LeftDF], rightDF: Expr[RightDF] 63 | )(using Quotes): Expr[StructDataFrame[?]] = 64 | Aliasing.autoAliasImpl(leftDF) match 65 | case '{ $left: StructDataFrame[l] } => 66 | Aliasing.autoAliasImpl(rightDF) match 67 | case '{ $right: StructDataFrame[r] } => 68 | '{ 69 | val joined = ${left}.untyped.crossJoin(${right}.untyped) 70 | StructDataFrame[FrameSchema.Merge[l, r]](joined) 71 | } 72 | 73 | export JoinOnCondition.joinOnConditionOps 74 | -------------------------------------------------------------------------------- /src/main/JoinOnCondition.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.language.implicitConversions 4 | 5 | import scala.quoted.* 6 | import org.virtuslab.iskra.types.BooleanOptLike 7 | 8 | trait OnConditionJoiner[Join <: JoinType, Left, Right] 9 | 10 | trait JoinOnCondition[Join <: JoinType, Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: 11 | type JoiningView <: SchemaView 12 | type JoinedSchema 13 | def joiningView: JoiningView 14 | 15 | object JoinOnCondition: 16 | private type LeftWithRight[Left, Right] = FrameSchema.Merge[Left, Right] 17 | private type LeftWithOptionalRight[Left, Right] = FrameSchema.Merge[Left, FrameSchema.NullableSchema[Right]] 18 | private type OptionalLeftWithRight[Left, Right] = FrameSchema.Merge[FrameSchema.NullableSchema[Left], Right] 19 | private type OptionalLeftWithOptionalRight[Left, Right] = FrameSchema.Merge[FrameSchema.NullableSchema[Left], FrameSchema.NullableSchema[Right]] 20 | private type OnlyLeft[Left, Right] = Left 21 | 22 | transparent inline given inner[Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: JoinOnCondition[JoinType.Inner.type, Left, Right] = 23 | ${ joinOnConditionImpl[JoinType.Inner.type, LeftWithRight, Left, Right] } 24 | 25 | transparent inline given left[Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: JoinOnCondition[JoinType.Left.type, Left, Right] = 26 | ${ joinOnConditionImpl[JoinType.Left.type, LeftWithOptionalRight, Left, Right] } 27 | 28 | transparent inline given right[Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: JoinOnCondition[JoinType.Right.type, Left, Right] = 29 | ${ joinOnConditionImpl[JoinType.Right.type, OptionalLeftWithRight, Left, Right] } 30 | 31 | transparent inline given full[Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: JoinOnCondition[JoinType.Full.type, Left, Right] = 32 | ${ joinOnConditionImpl[JoinType.Full.type, [S1, S2] =>> FrameSchema.Merge[FrameSchema.NullableSchema[S1], FrameSchema.NullableSchema[S2]], Left, Right] } 33 | 34 | transparent inline given semi[Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: JoinOnCondition[JoinType.Semi.type, Left, Right] = 35 | ${ joinOnConditionImpl[JoinType.Semi.type, OnlyLeft, Left, Right] } 36 | 37 | transparent inline given anti[Left <: StructDataFrame[?], Right <: StructDataFrame[?]]: JoinOnCondition[JoinType.Anti.type, Left, Right] = 38 | ${ joinOnConditionImpl[JoinType.Anti.type, OnlyLeft, Left, Right] } 39 | 40 | 41 | def joinOnConditionImpl[Join <: JoinType : Type, MergeSchemas[S1, S2] : Type, Left <: StructDataFrame[?] : Type, Right <: StructDataFrame[?] 
: Type](using Quotes) = 42 | import quotes.reflect.* 43 | 44 | Type.of[Left] match 45 | case '[StructDataFrame[s1]] => 46 | Type.of[Right] match 47 | case '[StructDataFrame[s2]] => 48 | Type.of[FrameSchema.Merge[s1, s2]] match 49 | case '[viewSchema] if FrameSchema.isValidType(Type.of[viewSchema]) => 50 | Type.of[MergeSchemas[s1, s2]] match 51 | case '[joinedSchema] if FrameSchema.isValidType(Type.of[joinedSchema]) => 52 | val viewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[viewSchema]] 53 | viewExpr.asTerm.tpe.asType match 54 | case '[SchemaView.Subtype[v]] => 55 | '{ 56 | new JoinOnCondition[Join, Left, Right] { 57 | override type JoinedSchema = joinedSchema 58 | override type JoiningView = v 59 | val joiningView: JoiningView = (${ viewExpr }).asInstanceOf[v] 60 | } 61 | } 62 | 63 | implicit def joinOnConditionOps[T <: JoinType](join: Join[T])(using joc: JoinOnCondition[T, join.Left, join.Right]): JoinOnConditionOps[T, joc.JoiningView, joc.JoinedSchema] = 64 | new JoinOnConditionOps[T, joc.JoiningView, joc.JoinedSchema](join, joc.joiningView) 65 | 66 | class JoinOnConditionOps[T <: JoinType, JoiningView <: SchemaView, JoinedSchema](join: Join[T], joiningView: JoiningView): 67 | inline def on[Condition](condition: JoiningView ?=> Condition): StructDataFrame[JoinedSchema] = 68 | ${ joinOnImpl[T, JoiningView, JoinedSchema, Condition]('join, 'joiningView, 'condition) } 69 | 70 | def joinOnImpl[T <: JoinType : Type, JoiningView <: SchemaView : Type, JoinedSchema : Type, Condition : Type]( 71 | join: Expr[Join[?]], joiningView: Expr[JoiningView], condition: Expr[JoiningView ?=> Condition] 72 | )(using Quotes) = 73 | import quotes.reflect.* 74 | 75 | '{ ${ condition }(using ${ joiningView }) } match 76 | case '{ $cond: Col[BooleanOptLike] } => 77 | '{ 78 | val joined = ${ join }.left.join(${ join }.right, ${ cond }.untyped, JoinType.typeName[T]) 79 | StructDataFrame[JoinedSchema](joined) 80 | } 81 | case '{ $cond: condType } => 82 | val errorMsg = s"""The join condition of `on` has to be a (potentially nullable) boolean column but it has type: ${Type.show[condType]}""" 83 | // TODO: improve error position 84 | report.errorAndAbort(errorMsg) 85 | -------------------------------------------------------------------------------- /src/main/MacroHelpers.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | 5 | private[iskra] object MacroHelpers: 6 | def callPosition(ownerExpr: Expr[?])(using Quotes): quotes.reflect.Position = 7 | import quotes.reflect.* 8 | val file = Position.ofMacroExpansion.sourceFile 9 | val start = ownerExpr.asTerm.pos.end 10 | val end = Position.ofMacroExpansion.end 11 | Position(file, start, end) 12 | 13 | type TupleSubtype[T <: Tuple] = T 14 | -------------------------------------------------------------------------------- /src/main/Name.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | type Name = String & Singleton 4 | object Name: 5 | type Subtype[T <: Name] = T 6 | 7 | //TODO: Reverse Scala's mangling of operators, e.g. 
currently the name for `a!` is `a$bang` 8 | def escape(name: String) = s"`${name}`" -------------------------------------------------------------------------------- /src/main/Repeated.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | type Repeated[A] = 4 | A 5 | | (A, A) 6 | | (A, A, A) 7 | | (A, A, A, A) 8 | | (A, A, A, A, A) 9 | | (A, A, A, A, A, A) 10 | | (A, A, A, A, A, A, A) 11 | | (A, A, A, A, A, A, A, A) 12 | | (A, A, A, A, A, A, A, A, A) 13 | | (A, A, A, A, A, A, A, A, A, A) 14 | | (A, A, A, A, A, A, A, A, A, A, A) 15 | | (A, A, A, A, A, A, A, A, A, A, A, A) 16 | | (A, A, A, A, A, A, A, A, A, A, A, A, A) 17 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A) 18 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 19 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 20 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 21 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 22 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 23 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 24 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) 25 | | (A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A, A) // 22 is maximal arity 26 | -------------------------------------------------------------------------------- /src/main/SchemaView.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import org.apache.spark.sql.functions.col 5 | import types.DataType 6 | import MacroHelpers.TupleSubtype 7 | 8 | inline def $(using view: SchemaView): view.type = view 9 | 10 | trait SchemaView 11 | 12 | object SchemaView: 13 | type Subtype[T <: SchemaView] = T 14 | 15 | private[iskra] def exprForDataFrame[DF <: DataFrame : Type](using quotes: Quotes): Option[Expr[SchemaView]] = 16 | Type.of[DF] match 17 | case '[ClassDataFrame[a]] => 18 | Expr.summon[SchemaViewProvider[a]].map { 19 | case '{ $provider } => '{ ${ provider }.view } 20 | } 21 | case '[StructDataFrame.Subtype[df]] => 22 | Some(StructSchemaView.schemaViewExpr[df]) 23 | 24 | trait StructuralSchemaView extends SchemaView, Selectable: 25 | def selectDynamic(name: String): AliasedSchemaView | Column 26 | 27 | trait StructSchemaView extends StructuralSchemaView: 28 | def frameAliases: Seq[String] // TODO: get rid of this at runtime 29 | 30 | // TODO: What should be the semantics of `*`? How to handle ambiguous columns?
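  // Illustrative sketch (hypothetical names, not part of the original sources): on a view
  // over a join of frames aliased `foos` and `bars`, member selection dispatches as follows:
  //   view.selectDynamic("foos")   // "foos" is a frame alias => AliasedSchemaView("foos")
  //   view.selectDynamic("title")  // plain column => Col[DataType](col("`title`"))
  // so an expression like `$.foos.title` resolves the frame alias first, then the column within it.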
31 | // type AllColumns <: Tuple 32 | // def * : AllColumns 33 | 34 | override def selectDynamic(name: String): AliasedSchemaView | Column = 35 | if frameAliases.contains(name) 36 | then AliasedSchemaView(name) 37 | else Col[DataType](col(Name.escape(name))) 38 | 39 | 40 | object StructSchemaView: 41 | type Subtype[T <: StructSchemaView] = T 42 | 43 | private def refineType(using Quotes)(base: quotes.reflect.TypeRepr, refinements: List[(String, quotes.reflect.TypeRepr)]): quotes.reflect.TypeRepr = 44 | import quotes.reflect.* 45 | refinements match 46 | case Nil => base 47 | case (name, info) :: refinementsTail => 48 | val newBase = Refinement(base, name, info) 49 | refineType(newBase, refinementsTail) 50 | 51 | private def schemaViewType(using Quotes)(base: quotes.reflect.TypeRepr, schemaType: Type[?]): quotes.reflect.TypeRepr = 52 | import quotes.reflect.* 53 | schemaType match 54 | case '[EmptyTuple] => base 55 | case '[(headLabelPrefix / headLabelName := headType) *: tail] => // TODO: get rid of duplicates 56 | val nameType = Type.of[headLabelName] match 57 | case '[Name.Subtype[name]] => 58 | Type.of[name] 59 | case '[(Name.Subtype[framePrefix], Name.Subtype[name])] => 60 | Type.of[name] 61 | val name = nameType match 62 | case '[n] => Type.valueOfConstant[n].get.toString 63 | val newBase = Refinement(base, name, TypeRepr.of[Col[headType]]) 64 | schemaViewType(newBase, Type.of[tail]) 65 | 66 | // private def reifyColumns[T <: Tuple : Type](using Quotes): Expr[Tuple] = reifyCols(Type.of[T]) 67 | 68 | // private def reifyCols(using Quotes)(schemaType: Type[?]): Expr[Tuple] = 69 | // import quotes.reflect.* 70 | // schemaType match 71 | // case '[EmptyTuple] => '{ EmptyTuple } 72 | // case '[(headLabel1 := headType) *: tail] => 73 | // headLabel1 match 74 | // case '[Name.Subtype[name]] => // TODO: handle frame prefixes 75 | // val label = Expr(Type.valueOfConstant[name].get.toString) 76 | // '{ Col[Nothing](col(Name.escape(${ label }))) *: ${ reifyCols(Type.of[tail]) } } 77 | 78 | def schemaViewExpr[DF <: StructDataFrame[?] : Type](using Quotes): Expr[StructSchemaView] = 79 | import quotes.reflect.* 80 | Type.of[DF] match 81 | case '[StructDataFrame[schema]] => 82 | val schemaType = Type.of[FrameSchema.AsTuple[schema]] 83 | val aliasViewsByName = frameAliasViewsByName(schemaType) 84 | val columns = unambiguousColumns(schemaType) 85 | val frameAliasNames = Expr(aliasViewsByName.map(_._1)) 86 | val baseType = TypeRepr.of[StructSchemaView] 87 | val viewType = refineType(refineType(baseType, columns), aliasViewsByName) // TODO: conflicting name of frame alias and column? 
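        // For intuition (hypothetical schema, an illustrative sketch only): for a schema such as
        //   ("foos" / "id" := IntNotNull) *: ("bars" / "id" := IntNotNull) *: EmptyTuple
        // `viewType` is roughly
        //   StructSchemaView { def foos: AliasedSchemaView { def id: Col[IntNotNull] }
        //                      def bars: AliasedSchemaView { def id: Col[IntNotNull] } }
        // with no top-level `id` refinement, since `unambiguousColumns` drops duplicated names.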
88 | 89 | viewType.asType match 90 | case '[StructSchemaView.Subtype[t]] => 91 | '{ 92 | new StructSchemaView { 93 | override def frameAliases: Seq[String] = ${ frameAliasNames } 94 | // TODO: Reintroduce `*` selector 95 | // type AllColumns = schema 96 | // override def * : AllColumns = ${ reifyColumns[schema] }.asInstanceOf[AllColumns] 97 | }.asInstanceOf[t] 98 | } 99 | 100 | def allPrefixedColumns(using Quotes)(schemaType: Type[?]): List[(String, (String, quotes.reflect.TypeRepr))] = 101 | import quotes.reflect.* 102 | 103 | schemaType match 104 | case '[EmptyTuple] => List.empty 105 | case '[(Name.Subtype[name] := dataType) *: tail] => 106 | allPrefixedColumns(Type.of[tail]) 107 | case '[(framePrefix / name := dataType) *: tail] => 108 | val prefix = Type.valueOfConstant[framePrefix].get.toString 109 | val colName = Type.valueOfConstant[name].get.toString 110 | (prefix -> (colName -> TypeRepr.of[Col[dataType]])) :: allPrefixedColumns(Type.of[tail]) 111 | 112 | // TODO Show this case to users as propagated error 113 | case _ => 114 | List.empty 115 | 116 | def frameAliasViewsByName(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] = 117 | import quotes.reflect.* 118 | allPrefixedColumns(schemaType).groupBy(_._1).map { (frameName, values) => 119 | val columnsTypes = values.map(_._2) 120 | frameName -> refineType(TypeRepr.of[AliasedSchemaView], columnsTypes) 121 | }.toList 122 | 123 | def unambiguousColumns(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] = 124 | allColumns(schemaType).groupBy(_._1).collect { 125 | case (name, List((_, col))) => name -> col 126 | }.toList 127 | 128 | def allColumns(using Quotes)(schemaType: Type[?]): List[(String, quotes.reflect.TypeRepr)] = 129 | import quotes.reflect.* 130 | schemaType match 131 | case '[EmptyTuple] => List.empty 132 | case '[(Name.Subtype[name] := dataType) *: tail] => 133 | val colName = Type.valueOfConstant[name].get.toString 134 | val namedColumn = colName -> TypeRepr.of[Col[dataType]] 135 | namedColumn :: allColumns(Type.of[tail]) 136 | case '[((Name.Subtype[framePrefix] / Name.Subtype[name]) := dataType) *: tail] => 137 | val colName = Type.valueOfConstant[name].get.toString 138 | val namedColumn = colName -> TypeRepr.of[Col[dataType]] 139 | namedColumn :: allColumns(Type.of[tail]) 140 | 141 | // TODO Show this case to users as propagated error 142 | case _ => 143 | List.empty 144 | 145 | class AliasedSchemaView(frameAliasName: String) extends StructuralSchemaView: 146 | override def selectDynamic(name: String): Column = 147 | val columnName = s"${Name.escape(frameAliasName)}.${Name.escape(name)}" 148 | Col[DataType](col(columnName)) -------------------------------------------------------------------------------- /src/main/SchemaViewProvider.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import types.StructEncoder 5 | 6 | trait SchemaViewProvider[A]: 7 | type View <: SchemaView 8 | def view: View 9 | 10 | object SchemaViewProvider: 11 | transparent inline given derived[A]: SchemaViewProvider[A] = ${ derivedImpl[A]} 12 | 13 | private def derivedImpl[A : Type](using Quotes): Expr[SchemaViewProvider[A]] = 14 | import quotes.reflect.* 15 | 16 | Expr.summon[StructEncoder[A]] match 17 | case Some(encoder) => encoder match 18 | case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } => 19 | val schemaView = 
StructSchemaView.schemaViewExpr[StructDataFrame[structSchema]] 20 | schemaView.asTerm.tpe.asType match 21 | case '[SchemaView.Subtype[v]] => 22 | '{ 23 | new SchemaViewProvider[A] { 24 | override type View = v 25 | override def view = ${ schemaView }.asInstanceOf[v] 26 | } 27 | } 28 | case None => 29 | report.errorAndAbort(s"SchemaViewProvider cannot be derived for type ${Type.show[A]} because a given instance of Encoder is missing") 30 | -------------------------------------------------------------------------------- /src/main/Select.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | 5 | class Select[View <: SchemaView](val view: View, val underlying: UntypedDataFrame) 6 | 7 | object Select: 8 | given dataFrameSelectOps: {} with 9 | extension [DF <: DataFrame](df: DF) 10 | transparent inline def select: Select[?] = ${ selectImpl[DF]('{df}) } 11 | 12 | private def selectImpl[DF <: DataFrame : Type](df: Expr[DF])(using Quotes): Expr[Select[?]] = 13 | import quotes.reflect.{asTerm, report} 14 | 15 | val schemaView = SchemaView.exprForDataFrame[DF].getOrElse( 16 | report.errorAndAbort(s"A given instance of SchemaViewProvider for the model type of ${Type.show[DF]} is required to make `.select` possible") 17 | ) 18 | 19 | schemaView.asTerm.tpe.asType match 20 | case '[SchemaView.Subtype[v]] => 21 | '{ 22 | new Select[v]( 23 | view = ${ schemaView }.asInstanceOf[v], 24 | underlying = ${ df }.untyped 25 | ) 26 | } 27 | 28 | given selectOps: {} with 29 | extension [View <: SchemaView](select: Select[View]) 30 | transparent inline def apply[C <: NamedColumns](columns: View ?=> C): StructDataFrame[?] = 31 | ${ applyImpl[View, C]('select, 'columns) } 32 | 33 | private def applyImpl[View <: SchemaView : Type, C : Type](using Quotes)(select: Expr[Select[View]], columns: Expr[View ?=> C]) = 34 | import quotes.reflect.* 35 | 36 | Expr.summon[CollectColumns[C]] match 37 | case Some(collectColumns) => 38 | collectColumns match 39 | case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } => 40 | Type.of[FrameSchema.FromTuple[collectedColumns]] match 41 | case '[s] => 42 | '{ 43 | val cols = ${ cc }.underlyingColumns(${ columns }(using ${ select }.view)) 44 | StructDataFrame[s](${ select }.underlying.select(cols*)) 45 | } 46 | case None => 47 | throw CollectColumns.CannotCollectColumns(Type.show[C]) 48 | -------------------------------------------------------------------------------- /src/main/StructDataFrame.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | 5 | import types.{DataType, Encoder, StructEncoder} 6 | import MacroHelpers.TupleSubtype 7 | 8 | 9 | class StructDataFrame[Schema](val untyped: UntypedDataFrame) extends DataFrame 10 | 11 | object StructDataFrame: 12 | type Subtype[T <: StructDataFrame[?]] = T 13 | type WithAlias[T <: String & Singleton] = StructDataFrame[?] 
{ type Alias = T } 14 | 15 | extension [Schema](df: StructDataFrame[Schema]) 16 | inline def asClass[A]: ClassDataFrame[A] = ${ asClassImpl[Schema, A]('df) } 17 | 18 | private def asClassImpl[FrameSchema : Type, A : Type](df: Expr[StructDataFrame[FrameSchema]])(using Quotes): Expr[ClassDataFrame[A]] = 19 | import quotes.reflect.report 20 | 21 | Expr.summon[Encoder[A]] match 22 | case Some(encoder) => encoder match 23 | case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } => 24 | val frameSchemaTuple = Type.of[FrameSchema] match 25 | case '[TupleSubtype[t]] => 26 | Type.of[t] 27 | case '[t] => 28 | Type.of[t *: EmptyTuple] 29 | 30 | frameSchemaTuple match 31 | case '[`structSchema`] => 32 | '{ ClassDataFrame[A](${ df }.untyped) } 33 | case _ => 34 | val frameColumns = allColumns(Type.of[FrameSchema]) 35 | val structColumns = allColumns(Type.of[structSchema]) 36 | val errorMsg = s"A structural data frame with columns:\n${showColumns(frameColumns)}\nis not equivalent to a class data frame of ${Type.show[A]}, which would be encoded as a row with columns:\n${showColumns(structColumns)}" 37 | quotes.reflect.report.errorAndAbort(errorMsg) 38 | case '{ $enc: Encoder[A] { type ColumnType = colType } } => 39 | def fromDataType[T : Type] = 40 | Type.of[T] match 41 | case '[`colType`] => 42 | '{ ClassDataFrame[A](${ df }.untyped) } 43 | case '[t] => 44 | val frameColumns = allColumns(Type.of[FrameSchema]) 45 | val errorMsg = s"A structural data frame with columns:\n${showColumns(frameColumns)}\nis not equivalent to a class data frame of ${Type.show[A]}" 46 | quotes.reflect.report.errorAndAbort(errorMsg) 47 | Type.of[FrameSchema] match 48 | case '[label := dataType] => 49 | fromDataType[dataType] 50 | case '[(label := dataType) *: EmptyTuple] => 51 | fromDataType[dataType] 52 | case '[t] => 53 | val frameColumns = allColumns(Type.of[FrameSchema]) 54 | val errorMsg = s"A structural data frame with columns:\n${showColumns(frameColumns)}\nis not equivalent to a class data frame of ${Type.show[A]}" 55 | quotes.reflect.report.errorAndAbort(errorMsg) 56 | case None => report.errorAndAbort(s"Could not summon encoder for ${Type.show[A]}") 57 | 58 | private def allColumns(schemaType: Type[?])(using Quotes): Seq[Type[?]] = 59 | schemaType match 60 | case '[prefix / suffix := dataType] => Seq(Type.of[suffix := dataType]) 61 | case '[Name.Subtype[suffix] := dataType] => Seq(Type.of[suffix := dataType]) 62 | case '[EmptyTuple] => Seq.empty 63 | case '[head *: tail] => allColumns(Type.of[head]) ++ allColumns(Type.of[tail]) 64 | 65 | private def showColumns(columnsTypes: Seq[Type[?]])(using Quotes): String = 66 | val columns = columnsTypes.map { 67 | case '[label := dataType] => 68 | val shortDataType = Type.show[dataType].split("\\.").last 69 | s"${Type.show[label]} := ${shortDataType}" 70 | } 71 | columns.mkString(", ") 72 | -------------------------------------------------------------------------------- /src/main/UntypedOps.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import types.{DataType, Encoder, struct, StructEncoder, StructNotNull} 5 | 6 | object UntypedOps: 7 | extension (untyped: UntypedColumn) 8 | def typed[A <: DataType] = Col[A](untyped) 9 | 10 | extension (df: UntypedDataFrame) 11 | transparent inline def typed[A](using encoder: StructEncoder[A]): ClassDataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match? 
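  // Hedged usage sketch (`Person` and `raw` are illustrative, not from these sources):
  //   case class Person(name: String, age: Int)
  //   val raw: UntypedDataFrame = ???   // e.g. a DataFrame loaded from storage
  //   val people = raw.typed[Person]    // ClassDataFrame[Person]; StructEncoder derived at compile time
  // As the TODO above notes, the runtime schema is not (yet) checked against `Person`.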
12 | 13 | private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[StructEncoder[A]])(using Quotes) = 14 | encoder match 15 | case '{ ${e}: Encoder.Aux[tpe, StructNotNull[t]] } => 16 | '{ ClassDataFrame[A](${ df }) } 17 | -------------------------------------------------------------------------------- /src/main/When.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import org.apache.spark.sql.{functions => f, Column => UntypedColumn} 4 | import org.virtuslab.iskra.types.{Coerce, DataType, BooleanOptLike} 5 | 6 | object When: 7 | class WhenColumn[T <: DataType](untyped: UntypedColumn) extends Col[DataType.AsNullable[T]](untyped): 8 | def when[U <: DataType](condition: Col[BooleanOptLike], value: Col[U])(using coerce: Coerce[T, U]): WhenColumn[coerce.Coerced] = 9 | WhenColumn(this.untyped.when(condition.untyped, value.untyped)) 10 | def otherwise[U <: DataType](value: Col[U])(using coerce: Coerce[T, U]): Col[coerce.Coerced] = 11 | Col(this.untyped.otherwise(value.untyped)) 12 | 13 | def when[T <: DataType](condition: Col[BooleanOptLike], value: Col[T]): WhenColumn[T] = 14 | WhenColumn(f.when(condition.untyped, value.untyped)) 15 | -------------------------------------------------------------------------------- /src/main/Where.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import org.virtuslab.iskra.types.BooleanOptLike 5 | 6 | trait Where[Schema, View <: SchemaView]: 7 | val view: View 8 | def underlying: UntypedDataFrame 9 | 10 | object Where: 11 | given dataFrameWhereOps: {} with 12 | extension [Schema](df: StructDataFrame[Schema]) 13 | transparent inline def where: Where[Schema, ?] 
= ${ Where.whereImpl[Schema]('{df}) } 14 | 15 | def whereImpl[Schema : Type](df: Expr[StructDataFrame[Schema]])(using Quotes): Expr[Where[Schema, ?]] = 16 | import quotes.reflect.asTerm 17 | val viewExpr = StructSchemaView.schemaViewExpr[StructDataFrame[Schema]] 18 | viewExpr.asTerm.tpe.asType match 19 | case '[SchemaView.Subtype[v]] => 20 | '{ 21 | new Where[Schema, v] { 22 | val view = ${ viewExpr }.asInstanceOf[v] 23 | def underlying = ${ df }.untyped 24 | } 25 | } 26 | 27 | given whereApply: {} with 28 | extension [Schema, View <: SchemaView](where: Where[Schema, View]) 29 | inline def apply[Condition](condition: View ?=> Condition): StructDataFrame[Schema] = 30 | ${ Where.applyImpl[Schema, View, Condition]('where, 'condition) } 31 | 32 | def applyImpl[Schema : Type, View <: SchemaView : Type, Condition : Type]( 33 | where: Expr[Where[Schema, View]], 34 | condition: Expr[View ?=> Condition] 35 | )(using Quotes): Expr[StructDataFrame[Schema]] = 36 | import quotes.reflect.* 37 | 38 | '{ ${ condition }(using ${ where }.view) } match 39 | case '{ $cond: Col[BooleanOptLike] } => 40 | '{ 41 | val filtered = ${ where }.underlying.where(${ cond }.untyped) 42 | StructDataFrame[Schema](filtered) 43 | } 44 | case '{ $cond: condType } => 45 | val errorMsg = s"""The filtering condition of `where` has to be a (potentially nullable) boolean column but it has type: ${Type.show[condType]}""" 46 | report.errorAndAbort(errorMsg, MacroHelpers.callPosition(where)) 47 | -------------------------------------------------------------------------------- /src/main/WithColumns.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | import scala.quoted.* 4 | import MacroHelpers.TupleSubtype 5 | 6 | class WithColumns[Schema, View <: SchemaView](val view: View, val underlying: UntypedDataFrame) 7 | 8 | object WithColumns: 9 | given dataFrameWithColumnsOps: {} with 10 | extension [Schema, DF <: StructDataFrame[Schema]](df: DF) 11 | transparent inline def withColumns: WithColumns[Schema, ?] = ${ withColumnsImpl[Schema, DF]('{df}) } 12 | 13 | def withColumnsImpl[Schema : Type, DF <: StructDataFrame[Schema] : Type](df: Expr[DF])(using Quotes): Expr[WithColumns[Schema, ?]] = 14 | import quotes.reflect.asTerm 15 | val viewExpr = StructSchemaView.schemaViewExpr[DF] 16 | viewExpr.asTerm.tpe.asType match 17 | case '[SchemaView.Subtype[v]] => 18 | '{ 19 | new WithColumns[Schema, v]( 20 | view = ${ viewExpr }.asInstanceOf[v], 21 | underlying = ${ df }.untyped 22 | ) 23 | } 24 | 25 | given withColumnsApply: {} with 26 | extension [Schema <: Tuple, View <: SchemaView](withColumns: WithColumns[Schema, View]) 27 | transparent inline def apply[C <: NamedColumns](columns: View ?=> C): StructDataFrame[?] = 28 | ${ applyImpl[Schema, View, C]('withColumns, 'columns) } 29 | 30 | private def applyImpl[Schema <: Tuple : Type, View <: SchemaView : Type, C : Type]( 31 | withColumns: Expr[WithColumns[Schema, View]], 32 | columns: Expr[View ?=> C] 33 | )(using Quotes): Expr[StructDataFrame[?]] = 34 | import quotes.reflect.* 35 | 36 | Expr.summon[CollectColumns[C]] match 37 | case Some(collectColumns) => 38 | collectColumns match 39 | case '{ $cc: CollectColumns[?] 
{ type CollectedColumns = collectedColumns } } => 40 | Type.of[collectedColumns] match 41 | case '[TupleSubtype[collectedCols]] => 42 | '{ 43 | val cols = 44 | org.apache.spark.sql.functions.col("*") +: ${ cc }.underlyingColumns(${ columns }(using ${ withColumns }.view)) 45 | val withColumnsAppended = 46 | ${ withColumns }.underlying.select(cols*) 47 | StructDataFrame[Tuple.Concat[Schema, collectedCols]](withColumnsAppended) 48 | } 49 | case None => 50 | throw CollectColumns.CannotCollectColumns(Type.show[C]) 51 | -------------------------------------------------------------------------------- /src/main/api/api.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | package api 3 | 4 | export DataFrameBuilders.toDF 5 | export types.{ 6 | boolean, 7 | boolean_?, 8 | BooleanNotNull, 9 | BooleanOrNull, 10 | string, 11 | string_?, 12 | StringNotNull, 13 | StringOrNull, 14 | byte, 15 | byte_?, 16 | ByteNotNull, 17 | ByteOrNull, 18 | short, 19 | short_?, 20 | ShortNotNull, 21 | ShortOrNull, 22 | int, 23 | int_?, 24 | IntNotNull, 25 | IntOrNull, 26 | long, 27 | long_?, 28 | LongNotNull, 29 | LongOrNull, 30 | float, 31 | float_?, 32 | FloatNotNull, 33 | FloatOrNull, 34 | double, 35 | double_?, 36 | DoubleNotNull, 37 | DoubleOrNull, 38 | struct, 39 | struct_?, 40 | StructNotNull, 41 | StructOrNull 42 | } 43 | export UntypedOps.typed 44 | export org.virtuslab.iskra.$ 45 | export org.virtuslab.iskra.{Column, Columns, Col, DataFrame, ClassDataFrame, NamedColumns, StructDataFrame, UntypedColumn, UntypedDataFrame, :=, /} 46 | 47 | object functions: 48 | export org.virtuslab.iskra.functions.{lit, when} 49 | export org.virtuslab.iskra.functions.Aggregates.* 50 | 51 | export org.apache.spark.sql.SparkSession 52 | -------------------------------------------------------------------------------- /src/main/functions/aggregates.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.functions 2 | 3 | import org.apache.spark.sql 4 | import org.virtuslab.iskra.Agg 5 | import org.virtuslab.iskra.Col 6 | import org.virtuslab.iskra.UntypedOps.typed 7 | import org.virtuslab.iskra.types.* 8 | import org.virtuslab.iskra.types.DataType.AsNullable 9 | 10 | class Sum[A <: Agg](val agg: A): 11 | def apply[T <: DoubleOptLike](column: agg.View ?=> Col[T]): Col[AsNullable[T]] = 12 | sql.functions.sum(column(using agg.view).untyped).typed 13 | 14 | class Max[A <: Agg](val agg: A): 15 | def apply[T <: DoubleOptLike](column: agg.View ?=> Col[T]): Col[AsNullable[T]] = 16 | sql.functions.max(column(using agg.view).untyped).typed 17 | 18 | class Min[A <: Agg](val agg: A): 19 | def apply[T <: DoubleOptLike](column: agg.View ?=> Col[T]): Col[AsNullable[T]] = 20 | sql.functions.min(column(using agg.view).untyped).typed 21 | 22 | class Avg[A <: Agg](val agg: A): 23 | def apply(column: agg.View ?=> Col[DoubleOptLike]): Col[DoubleOrNull] = 24 | sql.functions.avg(column(using agg.view).untyped).typed 25 | 26 | object Aggregates: 27 | def sum(using agg: Agg): Sum[agg.type] = new Sum(agg) 28 | def max(using agg: Agg): Max[agg.type] = new Max(agg) 29 | def min(using agg: Agg): Min[agg.type] = new Min(agg) 30 | def avg(using agg: Agg): Avg[agg.type] = new Avg(agg) 31 | -------------------------------------------------------------------------------- /src/main/functions/lit.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.functions 2 | 3 | 
import org.apache.spark.sql 4 | import org.virtuslab.iskra.Col 5 | import org.virtuslab.iskra.types.PrimitiveEncoder 6 | 7 | def lit[A](value: A)(using encoder: PrimitiveEncoder[A]): Col[encoder.ColumnType] = Col(sql.functions.lit(encoder.encode(value))) 8 | -------------------------------------------------------------------------------- /src/main/functions/when.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | package functions 3 | 4 | export When.when 5 | -------------------------------------------------------------------------------- /src/main/types/Coerce.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | package types 3 | 4 | trait Coerce[A <: DataType, B <: DataType]: 5 | type Coerced <: DataType 6 | 7 | object Coerce extends CoerceLowPrio: 8 | given sameType[A <: FinalDataType]: Coerce[A, A] with 9 | override type Coerced = A 10 | 11 | given nullableFirst[A <: FinalDataType & Nullable, B <: FinalDataType & NonNullable](using A <:< NullableOf[B]): Coerce[A, B] with 12 | override type Coerced = A 13 | 14 | given nullableSecond[A <: FinalDataType & NonNullable, B <: FinalDataType & Nullable](using A <:< NonNullableOf[B]): Coerce[A, B] with 15 | override type Coerced = B 16 | 17 | trait CoerceLowPrio: 18 | given numeric[A <: FinalDataType & DoubleOptLike, B <: FinalDataType & DoubleOptLike]: (Coerce[A, B] { type Coerced = CommonNumericType[A, B] }) = 19 | new Coerce[A, B]: 20 | override type Coerced = CommonNumericType[A, B] 21 | -------------------------------------------------------------------------------- /src/main/types/DataType.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | package types 3 | 4 | sealed trait Nullability 5 | 6 | sealed trait Nullable extends Nullability 7 | sealed trait NonNullable extends Nullability 8 | 9 | trait NullableOf[T <: DataType & NonNullable] extends Nullable 10 | trait NonNullableOf[T <: DataType & Nullable] extends NonNullable 11 | 12 | 13 | trait DataType 14 | 15 | abstract class FinalDataType extends DataType { 16 | self: Nullability => 17 | } 18 | 19 | object DataType: 20 | type Subtype[T <: DataType] = T 21 | 22 | type AsNullable[T <: DataType] <: DataType = T match 23 | case NonNullableOf[t] => t 24 | case Nullable => T 25 | 26 | 27 | trait BooleanOptLike extends DataType 28 | trait BooleanLike extends BooleanOptLike 29 | final class boolean_? extends FinalDataType, NullableOf[boolean], BooleanOptLike 30 | final class boolean extends FinalDataType, NonNullableOf[boolean_?], BooleanLike 31 | type BooleanOrNull = boolean_? 32 | type BooleanNotNull = boolean 33 | 34 | trait StringOptLike extends DataType 35 | trait StringLike extends StringOptLike 36 | final class string_? extends FinalDataType, NullableOf[string], StringOptLike 37 | final class string extends FinalDataType, NonNullableOf[string_?], StringLike 38 | type StringOrNull = string_? 39 | type StringNotNull = string 40 | 41 | trait DoubleOptLike extends DataType 42 | trait DoubleLike extends DoubleOptLike 43 | final class double_? extends FinalDataType, NullableOf[double], DoubleOptLike 44 | final class double extends FinalDataType, NonNullableOf[double_?], DoubleLike 45 | type DoubleOrNull = double_? 46 | type DoubleNotNull = double 47 | 48 | trait FloatOptLike extends DoubleOptLike 49 | trait FloatLike extends FloatOptLike, DoubleLike 50 | final class float_? 
extends FinalDataType, NullableOf[float], FloatOptLike 51 | final class float extends FinalDataType, NonNullableOf[float_?], FloatLike 52 | type FloatOrNull = float_? 53 | type FloatNotNull = float 54 | 55 | trait LongOptLike extends FloatOptLike 56 | trait LongLike extends LongOptLike, FloatLike 57 | final class long_? extends FinalDataType, NullableOf[long], LongOptLike 58 | final class long extends FinalDataType, NonNullableOf[long_?], LongLike 59 | type LongOrNull = long_? 60 | type LongNotNull = long 61 | 62 | trait IntOptLike extends LongOptLike 63 | trait IntLike extends IntOptLike, LongLike 64 | final class int_? extends FinalDataType, NullableOf[int], IntOptLike 65 | final class int extends FinalDataType, NonNullableOf[int_?], IntLike 66 | type IntOrNull = int_? 67 | type IntNotNull = int 68 | 69 | trait ShortOptLike extends IntOptLike 70 | trait ShortLike extends ShortOptLike, IntLike 71 | final class short_? extends FinalDataType, NullableOf[short], ShortOptLike 72 | final class short extends FinalDataType, NonNullableOf[short_?], ShortLike 73 | type ShortOrNull = short_? 74 | type ShortNotNull = short 75 | 76 | trait ByteOptLike extends ShortOptLike 77 | trait ByteLike extends ByteOptLike, ShortLike 78 | final class byte_? extends FinalDataType, NullableOf[byte], ByteOptLike 79 | final class byte extends FinalDataType, NonNullableOf[byte_?], ByteLike 80 | type ByteOrNull = byte_? 81 | type ByteNotNull = byte 82 | 83 | trait StructOptLike[Schema <: Tuple] extends DataType 84 | trait StructLike[Schema <: Tuple] extends StructOptLike[Schema] 85 | final class struct_?[Schema <: Tuple] extends FinalDataType, NullableOf[struct[Schema]], StructOptLike[Schema] 86 | final class struct[Schema <: Tuple] extends FinalDataType, NonNullableOf[struct_?[Schema]], StructLike[Schema] 87 | type StructOrNull[Schema <: Tuple] = struct_?[Schema] 88 | type StructNotNull[Schema <: Tuple] = struct[Schema] 89 | 90 | type CommonNumericType[T1 <: DataType, T2 <: DataType] <: DataType = (T1, T2) match 91 | case (ByteLike, ByteLike) => byte 92 | case (ByteOptLike, ByteOptLike) => byte_? 93 | case (ShortLike, ShortLike) => short 94 | case (ShortOptLike, ShortOptLike) => short_? 95 | case (IntLike, IntLike) => int 96 | case (IntOptLike, IntOptLike) => int_? 97 | case (LongLike, LongLike) => long 98 | case (LongOptLike, LongOptLike) => long_? 99 | case (FloatLike, FloatLike) => float 100 | case (FloatOptLike, FloatOptLike) => float_? 101 | case (DoubleLike, DoubleLike) => double 102 | case (DoubleOptLike, DoubleOptLike) => double_? 
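// Worked examples (derived from the match cases above, for illustration):
//   CommonNumericType[byte, byte]     =:=  byte       // first matching case
//   CommonNumericType[int, long]      =:=  long       // both satisfy LongLike
//   CommonNumericType[int, double_?]  =:=  double_?   // mixing in a nullable type widens the result to nullable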
103 | 104 | type CommonBooleanType[T1 <: DataType, T2 <: DataType] <: DataType = (T1, T2) match 105 | case (BooleanLike, BooleanLike) => BooleanNotNull 106 | case (BooleanOptLike, BooleanOptLike) => BooleanOrNull 107 | 108 | type CommonNullability[T1 <: Nullability, T2 <: Nullability] <: Nullability = (T1, T2) match 109 | case (NonNullable, NonNullable) => NonNullable 110 | case _ => Nullable 111 | 112 | type BooleanOfNullability[N <: Nullability] <: DataType = N match 113 | case NonNullable => BooleanNotNull 114 | case Nullable => BooleanOrNull 115 | 116 | type BooleanOfCommonNullability[T1, T2] <: DataType = (T1, T2) match 117 | case (NonNullable, NonNullable) => BooleanNotNull 118 | case (Nullability, Nullability) => BooleanOrNull 119 | 120 | type DoubleOfCommonNullability[T1 <: DoubleOptLike, T2 <: DoubleOptLike] <: DataType = (T1, T2) match 121 | case (DoubleLike, DoubleLike) => DoubleNotNull 122 | case (DoubleOptLike, DoubleOptLike) => DoubleOrNull 123 | 124 | type StringOfCommonNullability[T1 <: StringOptLike, T2 <: StringOptLike] <: DataType = (T1, T2) match 125 | case (StringLike, StringLike) => StringNotNull 126 | case (StringOptLike, StringOptLike) => StringOrNull 127 | -------------------------------------------------------------------------------- /src/main/types/Encoder.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | package types 3 | 4 | import scala.quoted._ 5 | import scala.deriving.Mirror 6 | import org.apache.spark.sql 7 | import MacroHelpers.TupleSubtype 8 | 9 | 10 | trait Encoder[-A]: 11 | type ColumnType <: DataType 12 | def encode(value: A): Any 13 | def decode(value: Any): Any 14 | def catalystType: sql.types.DataType 15 | def isNullable: Boolean 16 | 17 | trait PrimitiveEncoder[-A] extends Encoder[A] 18 | 19 | trait PrimitiveNullableEncoder[-A] extends PrimitiveEncoder[Option[A]]: 20 | def encode(value: Option[A]) = value.orNull 21 | def decode(value: Any) = Option(value) 22 | def isNullable = true 23 | 24 | trait PrimitiveNonNullableEncoder[-A] extends PrimitiveEncoder[A]: 25 | def encode(value: A) = value 26 | def decode(value: Any) = value 27 | def isNullable = false 28 | 29 | 30 | object Encoder: 31 | type Aux[-A, E <: DataType] = Encoder[A] { type ColumnType = E } 32 | 33 | inline given booleanEncoder: PrimitiveNonNullableEncoder[Boolean] with 34 | type ColumnType = BooleanNotNull 35 | def catalystType = sql.types.BooleanType 36 | inline given booleanOptEncoder: PrimitiveNullableEncoder[Boolean] with 37 | type ColumnType = BooleanOrNull 38 | def catalystType = sql.types.BooleanType 39 | 40 | inline given stringEncoder: PrimitiveNonNullableEncoder[String] with 41 | type ColumnType = StringNotNull 42 | def catalystType = sql.types.StringType 43 | inline given stringOptEncoder: PrimitiveNullableEncoder[String] with 44 | type ColumnType = StringOrNull 45 | def catalystType = sql.types.StringType 46 | 47 | inline given byteEncoder: PrimitiveNonNullableEncoder[Byte] with 48 | type ColumnType = ByteNotNull 49 | def catalystType = sql.types.ByteType 50 | inline given byteOptEncoder: PrimitiveNullableEncoder[Byte] with 51 | type ColumnType = ByteOrNull 52 | def catalystType = sql.types.ByteType 53 | 54 | inline given shortEncoder: PrimitiveNonNullableEncoder[Short] with 55 | type ColumnType = ShortNotNull 56 | def catalystType = sql.types.ShortType 57 | inline given shortOptEncoder: PrimitiveNullableEncoder[Short] with 58 | type ColumnType = ShortOrNull 59 | def catalystType = 
sql.types.ShortType 60 | 61 | inline given intEncoder: PrimitiveNonNullableEncoder[Int] with 62 | type ColumnType = IntNotNull 63 | def catalystType = sql.types.IntegerType 64 | inline given intOptEncoder: PrimitiveNullableEncoder[Int] with 65 | type ColumnType = IntOrNull 66 | def catalystType = sql.types.IntegerType 67 | 68 | inline given longEncoder: PrimitiveNonNullableEncoder[Long] with 69 | type ColumnType = LongNotNull 70 | def catalystType = sql.types.LongType 71 | inline given longOptEncoder: PrimitiveNullableEncoder[Long] with 72 | type ColumnType = LongOrNull 73 | def catalystType = sql.types.LongType 74 | 75 | inline given floatEncoder: PrimitiveNonNullableEncoder[Float] with 76 | type ColumnType = FloatNotNull 77 | def catalystType = sql.types.FloatType 78 | inline given floatOptEncoder: PrimitiveNullableEncoder[Float] with 79 | type ColumnType = FloatOrNull 80 | def catalystType = sql.types.FloatType 81 | 82 | inline given doubleEncoder: PrimitiveNonNullableEncoder[Double] with 83 | type ColumnType = DoubleNotNull 84 | def catalystType = sql.types.DoubleType 85 | inline given doubleOptEncoder: PrimitiveNullableEncoder[Double] with 86 | type ColumnType = DoubleOrNull 87 | def catalystType = sql.types.DoubleType 88 | 89 | export StructEncoder.{fromMirror, optFromMirror} 90 | 91 | trait StructEncoder[-A] extends Encoder[A]: 92 | type StructSchema <: Tuple 93 | type ColumnType = StructNotNull[StructSchema] 94 | override def catalystType: sql.types.StructType 95 | override def encode(a: A): sql.Row 96 | 97 | object StructEncoder: 98 | type Aux[-A, E <: Tuple] = StructEncoder[A] { type StructSchema = E } 99 | 100 | private case class ColumnInfo( 101 | labelType: Type[?], 102 | labelValue: String, 103 | decodedType: Type[?], 104 | encoder: Expr[Encoder[?]] 105 | ) 106 | 107 | private def getColumnSchemaType(using quotes: Quotes)(subcolumnInfos: List[ColumnInfo]): Type[?] 
= 108 | subcolumnInfos match 109 | case Nil => Type.of[EmptyTuple] 110 | case info :: tail => 111 | info.labelType match 112 | case '[Name.Subtype[lt]] => 113 | info.encoder match 114 | case '{ ${encoder}: Encoder.Aux[tpe, DataType.Subtype[e]] } => 115 | getColumnSchemaType(tail) match 116 | case '[TupleSubtype[tailType]] => 117 | Type.of[(lt := e) *: tailType] 118 | 119 | private def getSubcolumnInfos(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[ColumnInfo] = 120 | import quotes.reflect.Select 121 | elemLabels match 122 | case '[EmptyTuple] => Nil 123 | case '[label *: labels] => 124 | val labelValue = Type.valueOfConstant[label].get.toString 125 | elemTypes match 126 | case '[tpe *: types] => 127 | Expr.summon[Encoder[tpe]] match 128 | case Some(encoderExpr) => 129 | ColumnInfo(Type.of[label], labelValue, Type.of[tpe], encoderExpr) :: getSubcolumnInfos(Type.of[labels], Type.of[types]) 130 | case _ => quotes.reflect.report.errorAndAbort(s"Could not summon encoder for ${Type.show[tpe]}") 131 | 132 | transparent inline given fromMirror[A]: StructEncoder[A] = ${ fromMirrorImpl[A] } 133 | 134 | def fromMirrorImpl[A : Type](using q: Quotes): Expr[StructEncoder[A]] = 135 | Expr.summon[Mirror.Of[A]].getOrElse(throw new Exception(s"Could not find Mirror when generating encoder for ${Type.show[A]}")) match 136 | case '{ ${mirror}: Mirror.ProductOf[A] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes } } => 137 | val subcolumnInfos = getSubcolumnInfos(Type.of[elementLabels], Type.of[elementTypes]) 138 | val columnSchemaType = getColumnSchemaType(subcolumnInfos) 139 | 140 | val structFieldExprs = subcolumnInfos.map { info => 141 | '{ sql.types.StructField(${Expr(info.labelValue)}, ${info.encoder}.catalystType, ${info.encoder}.isNullable) } 142 | } 143 | val structFields = Expr.ofSeq(structFieldExprs) 144 | 145 | def rowElements(value: Expr[A]) = 146 | subcolumnInfos.map { info => 147 | import quotes.reflect.* 148 | info.decodedType match 149 | case '[t] => 150 | '{ ${info.encoder}.asInstanceOf[Encoder[t]].encode(${ Select.unique(value.asTerm, info.labelValue).asExprOf[t] }) } 151 | } 152 | 153 | def rowElementsTuple(row: Expr[sql.Row]): Expr[Tuple] = 154 | val elements = subcolumnInfos.zipWithIndex.map { (info, idx) => 155 | given Quotes = q 156 | '{ ${info.encoder}.decode(${row}.get(${Expr(idx)})) } 157 | } 158 | Expr.ofTupleFromSeq(elements) 159 | 160 | columnSchemaType match 161 | case '[TupleSubtype[t]] => 162 | '{ 163 | (new StructEncoder[A] { 164 | override type StructSchema = t 165 | override def catalystType = sql.types.StructType(${ structFields }) 166 | override def isNullable = false 167 | override def encode(a: A) = 168 | sql.Row.fromSeq(${ Expr.ofSeq(rowElements('a)) }) 169 | override def decode(a: Any) = 170 | ${mirror}.fromProduct(${ rowElementsTuple('{a.asInstanceOf[sql.Row]}) }) 171 | }): StructEncoder[A] { type StructSchema = t } 172 | } 173 | end fromMirrorImpl 174 | 175 | given optFromMirror[A](using encoder: StructEncoder[A]): (Encoder[Option[A]] { type ColumnType = StructOrNull[encoder.StructSchema] }) = 176 | new Encoder[Option[A]]: 177 | override type ColumnType = StructOrNull[encoder.StructSchema] 178 | override def encode(value: Option[A]): Any = value.map(encoder.encode).orNull 179 | override def decode(value: Any): Any = Option(value).map(encoder.decode) 180 | override def catalystType = encoder.catalystType 181 | override def isNullable = true 182 | --------------------------------------------------------------------------------
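A hedged sketch of what the derived encoder does at a use site; the `Point` class and the values below are illustrative only, not taken from the repository:

  import org.apache.spark.sql
  import org.virtuslab.iskra.types.StructEncoder

  case class Point(x: Int, label: Option[String])

  val enc = summon[StructEncoder[Point]]   // derived by StructEncoder.fromMirror
  // enc.catalystType is StructType(Seq(
  //   StructField("x", IntegerType, nullable = false),
  //   StructField("label", StringType, nullable = true)))
  enc.encode(Point(1, Some("a")))          // a Row(1, "a"): the Option is unwrapped to a nullable cell
  enc.decode(sql.Row(1, "a"))              // Point(1, Some("a")): rebuilt via the product Mirror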
/src/main/untyped.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | 3 | type UntypedDataFrame = org.apache.spark.sql.DataFrame 4 | type UntypedColumn = org.apache.spark.sql.Column 5 | type UntypedRelationalGroupedDataset = org.apache.spark.sql.RelationalGroupedDataset -------------------------------------------------------------------------------- /src/test/AggregatorsTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | class AggregatorsTest extends SparkUnitTest: 4 | import org.virtuslab.iskra.api.* 5 | import functions.* 6 | 7 | case class Foo(string: String, int: Int, intOpt: Option[Int], float: Float, floatOpt: Option[Float]) 8 | 9 | val foos = Seq( 10 | Foo("a", 1, Some(1), 1.0f, Some(1.0f)), 11 | Foo("a", 3, None, 3.0f, None) 12 | ).toDF.asStruct 13 | 14 | test("sum") { 15 | val result = foos.groupBy($.string).agg( 16 | sum($.int).as("_1"), 17 | sum($.intOpt).as("_2"), 18 | sum($.float).as("_3"), 19 | sum($.floatOpt).as("_4"), 20 | ) 21 | .select($._1, $._2, $._3, $._4) 22 | .asClass[(Option[Int], Option[Int], Option[Float], Option[Float])].collect().toList 23 | 24 | result shouldEqual List((Some(4), Some(1), Some(4.0f), Some(1.0f))) 25 | } 26 | 27 | test("max") { 28 | val result = foos.groupBy($.string).agg( 29 | max($.int).as("_1"), 30 | max($.intOpt).as("_2"), 31 | max($.float).as("_3"), 32 | max($.floatOpt).as("_4"), 33 | ) 34 | .select($._1, $._2, $._3, $._4) 35 | .asClass[(Option[Int], Option[Int], Option[Float], Option[Float])].collect().toList 36 | 37 | result shouldEqual List((Some(3), Some(1), Some(3.0f), Some(1.0f))) 38 | } 39 | 40 | test("min") { 41 | val result = foos.groupBy($.string).agg( 42 | min($.int).as("_1"), 43 | min($.intOpt).as("_2"), 44 | min($.float).as("_3"), 45 | min($.floatOpt).as("_4"), 46 | ) 47 | .select($._1, $._2, $._3, $._4) 48 | .asClass[(Option[Int], Option[Int], Option[Float], Option[Float])].collect().toList 49 | 50 | result shouldEqual List((Some(1), Some(1), Some(1.0f), Some(1.0f))) 51 | } 52 | 53 | test("avg") { 54 | val result = foos.groupBy($.string).agg( 55 | avg($.int).as("_1"), 56 | avg($.intOpt).as("_2"), 57 | avg($.float).as("_3"), 58 | avg($.floatOpt).as("_4"), 59 | ) 60 | .select($._1, $._2, $._3, $._4) 61 | .asClass[(Option[Double], Option[Double], Option[Double], Option[Double])].collect().toList 62 | 63 | result shouldEqual List((Some(2.0), Some(1.0), Some(2.0), Some(1.0))) 64 | } 65 | -------------------------------------------------------------------------------- /src/test/CoerceTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra 2 | package test 3 | 4 | import org.scalatest.funsuite.AnyFunSuite 5 | import types.* 6 | 7 | class CoerceTest extends AnyFunSuite: 8 | test("coerce-int-double") { 9 | val c = summon[Coerce[IntNotNull, DoubleNotNull]] 10 | summon[c.Coerced =:= DoubleNotNull] 11 | } 12 | 13 | test("coerce-short-short-opt") { 14 | val c = summon[Coerce[ShortNotNull, ShortOrNull]] 15 | summon[c.Coerced =:= ShortOrNull] 16 | } 17 | 18 | test("coerce-long-byte-opt") { 19 | val c = summon[Coerce[LongNotNull, ByteOrNull]] 20 | summon[c.Coerced =:= LongOrNull] 21 | } 22 | 23 | test("coerce-string-string-opt") { 24 | val c = summon[Coerce[StringNotNull, StringOrNull]] 25 | summon[c.Coerced =:= StringOrNull] 26 | } 27 | 28 | test("coerce-string-opt-string") { 29 | val c = summon[Coerce[StringOrNull, 
StringNotNull]] 30 | summon[c.Coerced =:= StringOrNull] 31 | } 32 | 33 | test("coerce-string-opt-string-opt") { 34 | val c = summon[Coerce[StringOrNull, StringOrNull]] 35 | summon[c.Coerced =:= StringOrNull] 36 | } 37 | -------------------------------------------------------------------------------- /src/test/ColumnsTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | import org.scalatest.funsuite.AnyFunSuite 4 | import org.scalatest.BeforeAndAfterAll 5 | import org.scalatest.matchers.should.Matchers.shouldEqual 6 | 7 | class ColumnsTest extends SparkUnitTest: 8 | import org.virtuslab.iskra.api.* 9 | 10 | case class Foo(x1: Int, x2: Int, x3: Int, x4: Int) 11 | 12 | val foos = Seq( 13 | Foo(1, 2, 3, 4) 14 | ).toDF.asStruct 15 | 16 | test("plus") { 17 | val result = foos.select { 18 | val cols1 = Columns($.x1) 19 | val cols2 = Columns($.x2, $.x3) 20 | (cols1, cols2, $.x4) 21 | }.asClass[Foo].collect().toList 22 | 23 | result shouldEqual List(Foo(1, 2, 3, 4)) 24 | } 25 | -------------------------------------------------------------------------------- /src/test/CompilationTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | import org.scalatest.funsuite.AnyFunSuite 4 | 5 | class CompilationTest extends AnyFunSuite: 6 | test("Select nonexistent column") { 7 | assertCompiles(""" 8 | |import org.virtuslab.iskra.api.* 9 | |case class Foo(string: String) 10 | |given spark: SparkSession = ??? 11 | |val elements = Seq(Foo("abc")).toDF.asStruct.select($.string) 12 | |""".stripMargin) 13 | 14 | assertDoesNotCompile(""" 15 | |import org.virtuslab.iskra.api.* 16 | |case class Foo(string: String) 17 | |given spark: SparkSession = ??? 
18 | |val elements = Seq(Foo("abc")).toDF.asStruct.select($.strin) 19 | |""".stripMargin) 20 | } 21 | -------------------------------------------------------------------------------- /src/test/JoinTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | class JoinTest extends SparkUnitTest: 4 | import org.virtuslab.iskra.api.* 5 | import functions.lit 6 | import Column.=== // by default shadowed by === from scalatest 7 | 8 | 9 | case class Foo(int: Int, long: Long) 10 | case class Bar(int: Int, string: String) 11 | 12 | val foos = Seq( 13 | Foo(1, 10), 14 | Foo(2, 20) 15 | ).toDF.asStruct 16 | 17 | val bars = Seq( 18 | Bar(2, "b"), 19 | Bar(3, "c") 20 | ).toDF.asStruct 21 | 22 | test("join-inner-on") { 23 | val joined = foos.join(bars).on($.foos.int === $.bars.int) 24 | 25 | val typedJoined: StructDataFrame[( 26 | "foos" / "int" := IntNotNull, 27 | "foos" / "long" := LongNotNull, 28 | "bars" / "int" := IntNotNull, 29 | "bars" / "string" := StringNotNull 30 | )] = joined 31 | 32 | val result = joined.select( 33 | $.foos.int.as("_1"), 34 | $.foos.long.as("_2"), 35 | $.bars.int.as("_3"), 36 | $.bars.string.as("_4") 37 | ).asClass[(Int, Long, Int, String)].collect().toList 38 | 39 | result shouldEqual List( 40 | (2, 20, 2, "b") 41 | ) 42 | } 43 | 44 | test("join-left-on") { 45 | val joined = foos.leftJoin(bars).on($.foos.int === $.bars.int) 46 | 47 | val typedJoined: StructDataFrame[( 48 | "foos" / "int" := IntNotNull, 49 | "foos" / "long" := LongNotNull, 50 | "bars" / "int" := IntOrNull, 51 | "bars" / "string" := StringOrNull 52 | )] = joined 53 | 54 | val result = joined.select( 55 | $.foos.int.as("_1"), 56 | $.foos.long.as("_2"), 57 | $.bars.int.as("_3"), 58 | $.bars.string.as("_4") 59 | ).asClass[(Int, Long, Option[Int], Option[String])].collect().toList 60 | 61 | result shouldEqual List( 62 | (1, 10, None, None), 63 | (2, 20, Some(2), Some("b")) 64 | ) 65 | } 66 | 67 | test("join-right-on") { 68 | val joined = foos.rightJoin(bars).on($.foos.int === $.bars.int) 69 | 70 | val typedJoined: StructDataFrame[( 71 | "foos" / "int" := IntOrNull, 72 | "foos" / "long" := LongOrNull, 73 | "bars" / "int" := IntNotNull, 74 | "bars" / "string" := StringNotNull 75 | )] = joined 76 | 77 | val result = joined.select( 78 | $.foos.int.as("_1"), 79 | $.foos.long.as("_2"), 80 | $.bars.int.as("_3"), 81 | $.bars.string.as("_4") 82 | ).asClass[(Option[Int], Option[Long], Int, String)].collect().toList 83 | 84 | result shouldEqual List( 85 | (Some(2), Some(20), 2, "b"), 86 | (None, None, 3, "c") 87 | ) 88 | } 89 | 90 | test("join-full-on") { 91 | val joined = foos.fullJoin(bars).on($.foos.int === $.bars.int) 92 | 93 | val typedJoined: StructDataFrame[( 94 | "foos" / "int" := IntOrNull, 95 | "foos" / "long" := LongOrNull, 96 | "bars" / "int" := IntOrNull, 97 | "bars" / "string" := StringOrNull 98 | )] = joined 99 | 100 | val result = joined.select( 101 | $.foos.int.as("_1"), 102 | $.foos.long.as("_2"), 103 | $.bars.int.as("_3"), 104 | $.bars.string.as("_4") 105 | ).asClass[(Option[Int], Option[Long], Option[Int], Option[String])].collect().toList 106 | 107 | result shouldEqual List( 108 | (Some(1), Some(10), None, None), 109 | (Some(2), Some(20), Some(2), Some("b")), 110 | (None, None, Some(3), Some("c")) 111 | ) 112 | } 113 | 114 | test("join-semi-on") { 115 | val joined = foos.semiJoin(bars).on($.foos.int === $.bars.int) 116 | 117 | val typedJoined: StructDataFrame[( 118 | "foos" / "int" := IntNotNull, 119 | "foos" / "long" := 
LongNotNull 120 | )] = joined 121 | 122 | val result = joined.select( 123 | $.foos.int.as("_1"), 124 | $.foos.long.as("_2"), 125 | ).asClass[(Int, Long)].collect().toList 126 | 127 | result shouldEqual List( 128 | (2, 20) 129 | ) 130 | } 131 | 132 | test("join-anti-on") { 133 | val joined = foos.antiJoin(bars).on($.foos.int === $.bars.int) 134 | 135 | val typedJoined: StructDataFrame[( 136 | "foos" / "int" := IntNotNull, 137 | "foos" / "long" := LongNotNull 138 | )] = joined 139 | 140 | val result = joined.select( 141 | $.foos.int.as("_1"), 142 | $.foos.long.as("_2"), 143 | ).asClass[(Int, Long)].collect().toList 144 | 145 | result shouldEqual List( 146 | (1, 10) 147 | ) 148 | } 149 | 150 | test("join-cross") { 151 | val joined = foos.crossJoin(bars) 152 | 153 | val typedJoined: StructDataFrame[( 154 | "foos" / "int" := IntNotNull, 155 | "foos" / "long" := LongNotNull, 156 | "bars" / "int" := IntNotNull, 157 | "bars" / "string" := StringNotNull 158 | )] = joined 159 | 160 | val result = joined.select( 161 | $.foos.int.as("_1"), 162 | $.foos.long.as("_2"), 163 | $.bars.int.as("_3"), 164 | $.bars.string.as("_4") 165 | ).asClass[(Int, Long, Int, String)].collect().toList 166 | 167 | result shouldEqual List( 168 | (1, 10, 2, "b"), 169 | (1, 10, 3, "c"), 170 | (2, 20, 2, "b"), 171 | (2, 20, 3, "c"), 172 | ) 173 | } 174 | 175 | test("join-preserve-aliases") { 176 | val joined1 = foos.as("fooz").join(bars).on($.fooz.int === $.bars.int) 177 | val barz = bars.as("barz") 178 | val joined2 = foos.join(barz).on($.foos.int === $.barz.int) 179 | 180 | val result1 = joined1.select( 181 | $.fooz.int.as("_1"), 182 | $.fooz.long.as("_2"), 183 | $.bars.int.as("_3"), 184 | $.bars.string.as("_4") 185 | ).asClass[(Int, Long, Int, String)].collect().toList 186 | 187 | val result2 = joined2.select( 188 | $.foos.int.as("_1"), 189 | $.foos.long.as("_2"), 190 | $.barz.int.as("_3"), 191 | $.barz.string.as("_4") 192 | ).asClass[(Int, Long, Int, String)].collect().toList 193 | 194 | result1 shouldEqual List( 195 | (2, 20, 2, "b") 196 | ) 197 | 198 | result2 shouldEqual List( 199 | (2, 20, 2, "b") 200 | ) 201 | } 202 | -------------------------------------------------------------------------------- /src/test/OperatorsTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | import org.scalatest.funsuite.AnyFunSuite 4 | import org.scalatest.BeforeAndAfterAll 5 | import org.scalatest.matchers.should.Matchers.shouldEqual 6 | 7 | class OperatorsTest extends SparkUnitTest: 8 | import org.virtuslab.iskra.api.* 9 | 10 | case class Foo(boolean: Boolean, string: String, byte: Byte, short: Short, int: Int, long: Long, float: Float, double: Double) 11 | case class Bar(int: Int, intSome: Option[Int], intNone: Option[Int]) 12 | 13 | val foos = Seq( 14 | Foo(true, "abc", 1, 2, 3, 4, 5.0, 6.0) 15 | ).toDF.asStruct 16 | 17 | val bars = Seq( 18 | Bar(1, Some(10), None), 19 | ).toDF.asStruct 20 | 21 | test("plus") { 22 | val result = foos.select( 23 | ($.byte + $.byte).as("_1"), 24 | ($.short + $.short).as("_2"), 25 | ($.int + $.int).as("_3"), 26 | ($.long + $.long).as("_4"), 27 | ($.float + $.float).as("_5"), 28 | ($.double + $.double).as("_6"), 29 | ($.short + $.float).as("_7"), 30 | ).asClass[(Byte, Short, Int, Long, Float, Double, Float)].collect().toList 31 | 32 | result shouldEqual List((2, 4, 6, 8, 10.0, 12.0, 7.0)) 33 | } 34 | 35 | test("minus") { 36 | val result = foos.select( 37 | ($.byte - $.byte).as("_1"), 38 | ($.short - $.short).as("_2"), 39 | 
($.int - $.int).as("_3"), 40 | ($.long - $.long).as("_4"), 41 | ($.float - $.float).as("_5"), 42 | ($.double - $.double).as("_6"), 43 | ($.double - $.byte).as("_7") 44 | ).asClass[(Byte, Short, Int, Long, Float, Double, Double)].collect().toList 45 | 46 | result shouldEqual List((0, 0, 0, 0, 0.0, 0.0, 5.0)) 47 | } 48 | 49 | test("mult") { 50 | val result = foos.select( 51 | ($.byte * $.byte).as("_1"), 52 | ($.short * $.short).as("_2"), 53 | ($.int * $.int).as("_3"), 54 | ($.long * $.long).as("_4"), 55 | ($.float * $.float).as("_5"), 56 | ($.double * $.double).as("_6"), 57 | ($.int * $.float).as("_7"), 58 | ).asClass[(Byte, Short, Int, Long, Float, Double, Float)].collect().toList 59 | 60 | result shouldEqual List((1, 4, 9, 16, 25.0, 36.0, 15.0)) 61 | } 62 | 63 | test("div") { 64 | val result = foos.select( 65 | ($.byte / $.byte).as("_1"), 66 | ($.short / $.short).as("_2"), 67 | ($.int / $.int).as("_3"), 68 | ($.long / $.long).as("_4"), 69 | ($.float / $.float).as("_5"), 70 | ($.double / $.double).as("_6"), 71 | ($.long / $.short).as("_7"), 72 | ).asClass[(Double, Double, Double, Double, Double, Double, Double)].collect().toList 73 | 74 | result shouldEqual List((1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0)) 75 | } 76 | 77 | test("plusplus") { 78 | val result = foos.select( 79 | ($.string ++ $.string).as("_1"), 80 | ).asClass[String].collect().toList 81 | 82 | result shouldEqual List("abcabc") 83 | } 84 | 85 | test("eq") { 86 | import Column.=== // by default shadowed by === from scalatest 87 | 88 | val result = foos.select( 89 | ($.boolean === $.boolean).as("_1"), 90 | ($.string === $.string).as("_2"), 91 | ($.byte === $.byte).as("_3"), 92 | ($.short === $.short).as("_4"), 93 | ($.int === $.int).as("_5"), 94 | ($.long === $.long).as("_6"), 95 | ($.float === $.float).as("_7"), 96 | ($.double === $.double).as("_8"), 97 | ($.byte === $.int).as("_9"), 98 | ).asClass[(Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean)].collect().toList 99 | 100 | result shouldEqual List((true, true, true, true, true, true, true, true, false)) 101 | } 102 | 103 | test("ne") { 104 | val result = foos.select( 105 | ($.boolean =!= $.boolean).as("_1"), 106 | ($.string =!= $.string).as("_2"), 107 | ($.byte =!= $.byte).as("_3"), 108 | ($.short =!= $.short).as("_4"), 109 | ($.int =!= $.int).as("_5"), 110 | ($.long =!= $.long).as("_6"), 111 | ($.float =!= $.float).as("_7"), 112 | ($.double =!= $.double).as("_8"), 113 | ($.short =!= $.float).as("_9"), 114 | ).asClass[(Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean)].collect().toList 115 | 116 | result shouldEqual List((false, false, false, false, false, false, false, false, true)) 117 | } 118 | 119 | test("lt") { 120 | val result = foos.select( 121 | ($.boolean < $.boolean).as("_1"), 122 | ($.string < $.string).as("_2"), 123 | ($.byte < $.byte).as("_3"), 124 | ($.short < $.short).as("_4"), 125 | ($.int < $.int).as("_5"), 126 | ($.long < $.long).as("_6"), 127 | ($.float < $.float).as("_7"), 128 | ($.double < $.double).as("_8"), 129 | ($.byte < $.double).as("_9"), 130 | ).asClass[(Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean)].collect().toList 131 | 132 | result shouldEqual List((false, false, false, false, false, false, false, false, true)) 133 | } 134 | 135 | test("le") { 136 | val result = foos.select( 137 | ($.boolean <= $.boolean).as("_1"), 138 | ($.string <= $.string).as("_2"), 139 | ($.byte <= $.byte).as("_3"), 140 | ($.short <= $.short).as("_4"), 141 | ($.int <= $.int).as("_5"), 
142 | ($.long <= $.long).as("_6"), 143 | ($.float <= $.float).as("_7"), 144 | ($.double <= $.double).as("_8"), 145 | ($.long <= $.byte).as("_9"), 146 | ).asClass[(Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean)].collect().toList 147 | 148 | result shouldEqual List((true, true, true, true, true, true, true, true, false)) 149 | } 150 | 151 | test("gt") { 152 | val result = foos.select( 153 | ($.boolean > $.boolean).as("_1"), 154 | ($.string > $.string).as("_2"), 155 | ($.byte > $.byte).as("_3"), 156 | ($.short > $.short).as("_4"), 157 | ($.int > $.int).as("_5"), 158 | ($.long > $.long).as("_6"), 159 | ($.float > $.float).as("_7"), 160 | ($.double > $.double).as("_8"), 161 | ($.float > $.long).as("_9"), 162 | ).asClass[(Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean)].collect().toList 163 | 164 | result shouldEqual List((false, false, false, false, false, false, false, false, true)) 165 | } 166 | 167 | test("ge") { 168 | val result = foos.select( 169 | ($.boolean >= $.boolean).as("_1"), 170 | ($.string >= $.string).as("_2"), 171 | ($.byte >= $.byte).as("_3"), 172 | ($.short >= $.short).as("_4"), 173 | ($.int >= $.int).as("_5"), 174 | ($.long >= $.long).as("_6"), 175 | ($.float >= $.float).as("_7"), 176 | ($.double >= $.double).as("_8"), 177 | ($.short >= $.int).as("_9"), 178 | ).asClass[(Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean, Boolean)].collect().toList 179 | 180 | result shouldEqual List((true, true, true, true, true, true, true, true, false)) 181 | } 182 | 183 | test("and") { 184 | val result = foos.select( 185 | ($.boolean && $.boolean).as("_1"), 186 | ).asClass[Boolean].collect().toList 187 | 188 | result shouldEqual List(true) 189 | } 190 | 191 | test("or") { 192 | val result = foos.select( 193 | ($.boolean || $.boolean).as("_1"), 194 | ).asClass[Boolean].collect().toList 195 | 196 | result shouldEqual List(true) 197 | } 198 | 199 | test("plus nullable") { 200 | val result = bars.select( 201 | ($.int + $.intSome).as("_1"), 202 | ($.int + $.intNone).as("_2"), 203 | ($.intSome + $.int).as("_3"), 204 | ($.intNone + $.int).as("_4"), 205 | ($.intSome + $.intSome).as("_5"), 206 | ($.intSome + $.intNone).as("_6"), 207 | ($.intNone + $.intSome).as("_7"), 208 | ($.intNone + $.intNone).as("_8"), 209 | ).asClass[(Option[Int], Option[Int], Option[Int], Option[Int], Option[Int], Option[Int], Option[Int], Option[Int])].collect().toList 210 | 211 | result shouldEqual List((Some(11), None, Some(11), None, Some(20), None, None, None)) 212 | } 213 | -------------------------------------------------------------------------------- /src/test/SparkUnitTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | import org.scalatest.funsuite.AnyFunSuite 4 | import org.scalatest.BeforeAndAfterAll 5 | import org.virtuslab.iskra.api.* 6 | 7 | abstract class SparkUnitTest extends AnyFunSuite, BeforeAndAfterAll: 8 | def appName: String = getClass.getSimpleName 9 | 10 | given spark: SparkSession = 11 | SparkSession 12 | .builder() 13 | .master("local") 14 | .appName(appName) 15 | .getOrCreate() 16 | 17 | override def afterAll(): Unit = 18 | spark.stop() 19 | 20 | export org.scalatest.matchers.should.Matchers.shouldEqual -------------------------------------------------------------------------------- /src/test/WhenTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 |
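// A summary of the `when` behavior exercised by the tests in this file (inferred
// from their expected results rather than stated elsewhere):
//   when(cond, v)                     -- no fallback, so the column is nullable (collects as Option[...])
//   when(c1, v1).when(c2, v2)         -- chained branches; still nullable without a fallback
//   when(c1, v1).otherwise(fallback)  -- exhaustive, hence non-nullable; branch types unify
//                                        (e.g. an Int branch with a Double fallback collects as Double)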
class WhenTest extends SparkUnitTest: 4 | import org.virtuslab.iskra.api.* 5 | import functions.{lit, when} 6 | import Column.=== // by default shadowed by === from scalatest 7 | 8 | case class Foo(int: Int) 9 | 10 | val foos = Seq( 11 | Foo(1), 12 | Foo(2), 13 | Foo(3) 14 | ).toDF.asStruct 15 | 16 | test("when-without-fallback") { 17 | val result = foos 18 | .select(when($.int === lit(1), lit("a")).as("strOpt")) 19 | .asClass[Option[String]].collect().toList 20 | 21 | result shouldEqual Seq(Some("a"), None, None) 22 | } 23 | 24 | test("when-with-fallback") { 25 | val result = foos 26 | .select{ 27 | when($.int === lit(1), lit(10)) 28 | .otherwise(lit(100d)) 29 | .as("double") 30 | } 31 | .asClass[Double].collect().toList 32 | 33 | result shouldEqual Seq(10d, 100d, 100d) 34 | } 35 | 36 | test("when-else-when-without-fallback") { 37 | val result = foos 38 | .select{ 39 | when($.int === lit(1), lit(10)) 40 | .when($.int === lit(2), lit(100L)) 41 | .as("longOpt") 42 | } 43 | .asClass[Option[Long]].collect().toList 44 | 45 | result shouldEqual Seq(Some(10L), Some(100L), None) 46 | } 47 | 48 | test("when-else-when-with-fallback") { 49 | val result = foos 50 | .select{ 51 | when($.int === lit(1), lit(10)) 52 | .when($.int === lit(2), lit(100L)) 53 | .otherwise(lit(1000d)) 54 | .as("double") 55 | } 56 | .asClass[Double].collect().toList 57 | 58 | result shouldEqual Seq(10d, 100d, 1000d) 59 | } 60 | -------------------------------------------------------------------------------- /src/test/WhereTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | class WhereTest extends SparkUnitTest: 4 | import org.virtuslab.iskra.api.* 5 | import functions.lit 6 | 7 | case class Foo(int: Int, intOpt: Option[Int]) 8 | 9 | val foos = Seq( 10 | Foo(1, Some(1)), 11 | Foo(2, None), 12 | Foo(3, Some(3)) 13 | ).toDF.asStruct 14 | 15 | test("where-nonnullable") { 16 | val result = foos 17 | .where($.int >= lit(2)) 18 | .select($.intOpt) 19 | .asClass[Option[Int]].collect().toList 20 | 21 | result shouldEqual Seq(None, Some(3)) 22 | } 23 | 24 | test("where-nullable") { 25 | val result = foos 26 | .where($.intOpt >= lit(2)) 27 | .select($.int) 28 | .asClass[Int].collect().toList 29 | 30 | result shouldEqual Seq(3) 31 | } 32 | -------------------------------------------------------------------------------- /src/test/WithColumnsTest.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.test 2 | 3 | class WithColumnsTest extends SparkUnitTest: 4 | import org.virtuslab.iskra.api.* 5 | 6 | case class Foo(a: Int, b: Int) 7 | case class Bar(a: Int, b: Int, c: Int) 8 | case class Baz(a: Int, b: Int, c: Int, d: Int) 9 | 10 | val foos = Seq( 11 | Foo(1, 2) 12 | ).toDF.asStruct 13 | 14 | test("withColumns-single") { 15 | val result = foos 16 | .withColumns( 17 | ($.a + $.b).as("c") 18 | ) 19 | .asClass[Bar].collect().toList 20 | 21 | result shouldEqual List(Bar(1, 2, 3)) 22 | } 23 | 24 | test("withColumns-single-autoAliased") { 25 | val result = foos 26 | .withColumns { 27 | val c = ($.a + $.b) 28 | c 29 | } 30 | .asClass[Bar].collect().toList 31 | 32 | result shouldEqual List(Bar(1, 2, 3)) 33 | } 34 | 35 | test("withColumns-many") { 36 | val result = foos 37 | .withColumns( 38 | ($.a + $.b).as("c"), 39 | ($.a - $.b).as("d"), 40 | ) 41 | .asClass[Baz].collect().toList 42 | 43 | result shouldEqual List(Baz(1, 2, 3, -1)) 44 | } 45 | 46 | test("withColumns-many-autoAliased") { 47 | val 
result = foos 48 | .withColumns{ 49 | val c = ($.a + $.b) 50 | val d = ($.a - $.b) 51 | (c, d) 52 | } 53 | .asClass[Baz].collect().toList 54 | 55 | result shouldEqual List(Baz(1, 2, 3, -1)) 56 | } -------------------------------------------------------------------------------- /src/test/example/Books.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.example.books 2 | 3 | import org.virtuslab.iskra.api.* 4 | 5 | @main def runExample(dataFilePath: String): Unit = 6 | given spark: SparkSession = { 7 | SparkSession 8 | .builder() 9 | .master("local") 10 | .appName("books") 11 | .getOrCreate() 12 | } 13 | 14 | case class Book(title: String, author: String, publicationYear: Int) 15 | 16 | val untypedBooks: UntypedDataFrame = spark.read.options(Map("header" -> "true")).csv(dataFilePath) // UntypedDataFrame = sql.DataFrame 17 | untypedBooks.show() 18 | val books = untypedBooks 19 | .typed[Book] // Unsafe: make sure `untypedBooks` has the right schema 20 | .asStruct 21 | 22 | import org.apache.spark.sql.functions.lower 23 | 24 | val authorlessBooks = books.select( 25 | lower($.title.untyped).typed[StringNotNull].as("title"), 26 | $.publicationYear 27 | ) 28 | authorlessBooks.show() 29 | 30 | spark.stop() 31 | 32 | import org.scalatest.funsuite.AnyFunSuite 33 | import org.scalatest.matchers.should.Matchers.shouldEqual 34 | import java.io.ByteArrayOutputStream 35 | import java.io.File 36 | import java.nio.file.Files 37 | 38 | class ExampleTest extends AnyFunSuite: 39 | test("Books example") { 40 | val fileContent = """title,author,publicationYear 41 | |"The Call of Cthulhu","H. P. Lovecraft",1928 42 | |"The Hobbit, or There and Back Again","J. R. R. Tolkien",1937 43 | |"Alice's Adventures in Wonderland","Lewis Carroll",1865 44 | |"Murder on the Orient Express","Agatha Christie",1934 45 | """.stripMargin 46 | 47 | val file = File.createTempFile("books", ".csv") 48 | file.deleteOnExit() 49 | 50 | Files.write(file.getAbsoluteFile.toPath, fileContent.getBytes) 51 | 52 | val outCapture = new ByteArrayOutputStream 53 | Console.withOut(outCapture) { runExample(file.getAbsolutePath) } 54 | val result = new String(outCapture.toByteArray) 55 | 56 | val expected = """+--------------------+----------------+---------------+ 57 | ||               title|          author|publicationYear| 58 | |+--------------------+----------------+---------------+ 59 | || The Call of Cthulhu| H. P. Lovecraft|           1928| 60 | ||The Hobbit, or Th...|J. R. R. 
Tolkien| 1937| 61 | ||Alice's Adventure...| Lewis Carroll| 1865| 62 | ||Murder on the Ori...| Agatha Christie| 1934| 63 | |+--------------------+----------------+---------------+ 64 | | 65 | |+--------------------+---------------+ 66 | || title|publicationYear| 67 | |+--------------------+---------------+ 68 | || the call of cthulhu| 1928| 69 | ||the hobbit, or th...| 1937| 70 | ||alice's adventure...| 1865| 71 | ||murder on the ori...| 1934| 72 | |+--------------------+---------------+ 73 | | 74 | |""".stripMargin 75 | 76 | result shouldEqual expected 77 | } 78 | -------------------------------------------------------------------------------- /src/test/example/Countries.scala: -------------------------------------------------------------------------------- 1 | package org.virtuslab.iskra.example.countries 2 | 3 | import org.virtuslab.iskra.api.* 4 | import functions.avg 5 | 6 | @main def runExample(): Unit = 7 | given spark: SparkSession = { 8 | SparkSession 9 | .builder() 10 | .master("local") 11 | .appName("countries") 12 | .getOrCreate() 13 | } 14 | 15 | case class City(name: String, population: Int) 16 | case class Country(name: String, continent: String, capital: String, population: Int, gdp: Int) 17 | 18 | val cities = Seq( 19 | City("Warsaw", 1794532), 20 | City("Krakow", 769595), 21 | City("Paris", 11142303), 22 | City("Washington", 718355), 23 | City("London", 9540576), 24 | City("Ottawa", 1422635) 25 | ).toDF.asStruct 26 | 27 | val countries = Seq( 28 | Country("United Kingdom", "Europe", "London", 67886011, 39532), 29 | Country("France", "Europe", "Paris", 65273511, 39827), 30 | Country("USA", "North America", "Washington", 331002651, 59939), 31 | Country("Poland", "Europe", "Warsaw", 37846611, 13871), 32 | Country("Canada", "North America", "Ottawa", 37742154, 44841) 33 | ).toDF.asStruct 34 | 35 | countries.join(cities) // shorthand for: countries.as("countries").join(cities.as("cities")) 36 | .on($.countries.capital === $.cities.name) 37 | .select( 38 | $.countries.name.as("country"), 39 | $.continent, 40 | $.cities.population.as("capital population") 41 | ).show() 42 | 43 | countries.groupBy($.continent).agg(avg($.gdp).as("avg gdp")).show() 44 | 45 | spark.stop() 46 | 47 | 48 | 49 | import org.scalatest.funsuite.AnyFunSuite 50 | import org.scalatest.matchers.should.Matchers.shouldEqual 51 | import java.io.ByteArrayOutputStream 52 | 53 | class ExampleTest extends AnyFunSuite: 54 | test("Countries example") { 55 | val outCapture = new ByteArrayOutputStream 56 | Console.withOut(outCapture) { runExample() } 57 | val result = new String(outCapture.toByteArray) 58 | 59 | val expected = """|+--------------+-------------+------------------+ 60 | || country| continent|capital population| 61 | |+--------------+-------------+------------------+ 62 | ||United Kingdom| Europe| 9540576| 63 | || Canada|North America| 1422635| 64 | || France| Europe| 11142303| 65 | || Poland| Europe| 1794532| 66 | || USA|North America| 718355| 67 | |+--------------+-------------+------------------+ 68 | | 69 | |+-------------+------------------+ 70 | || continent| avg gdp| 71 | |+-------------+------------------+ 72 | || Europe|31076.666666666668| 73 | ||North America| 52390.0| 74 | |+-------------+------------------+ 75 | | 76 | |""".stripMargin 77 | 78 | result shouldEqual expected 79 | } 80 | -------------------------------------------------------------------------------- /src/test/example/Workers.scala: -------------------------------------------------------------------------------- 1 | package 
org.virtuslab.iskra.example.workers 2 | 3 | import org.virtuslab.iskra.api.* 4 | import functions.lit 5 | 6 | @main def runExample(): Unit = 7 | given spark: SparkSession = { 8 | SparkSession 9 | .builder() 10 | .master("local") 11 | .appName("workers") 12 | .getOrCreate() 13 | } 14 | 15 | case class Worker(id: Long, firstName: String, lastName: String, yearsInCompany: Int) 16 | case class Supervision(subordinateId: Long, supervisorId: Long) 17 | case class Room(number: Int, name: String, desksCount: Int) 18 | 19 | val workers = Seq( 20 | Worker(3, "Bob", "Smith", 8), 21 | Worker(13, "Alice", "Potter", 4), 22 | Worker(38, "John", "Parker", 1), 23 | Worker(21, "Julia", "Taylor", 3), 24 | Worker(11, "Emma", "Brown", 6), 25 | Worker(8, "Michael", "Johnson", 7), 26 | Worker(18, "Natalie", "Evans", 4), 27 | Worker(22, "Paul", "Wilson", 3), 28 | Worker(44, "Daniel", "Jones", 1) 29 | ).toDF.asStruct 30 | 31 | val supervisions = Seq( 32 | 44 -> 21, 22 -> 21, 38 -> 13, 11 -> 3, 21 -> 18, 13 -> 8, 3 -> 8, 18 -> 8 33 | ).map{ case (id1, id2) => Supervision(id1, id2) }.toDF.asStruct 34 | 35 | workers.as("subordinates") 36 | .leftJoin(supervisions).on($.subordinates.id === $.subordinateId) 37 | .leftJoin(workers.as("supervisors")).on($.supervisorId === $.supervisors.id) 38 | .select { 39 | val salary = lit(4732) + $.subordinates.yearsInCompany * lit(214) 40 | val supervisor = $.supervisors.firstName ++ lit(" ") ++ $.supervisors.lastName 41 | ($.subordinates.firstName, $.subordinates.lastName, supervisor, salary) 42 | } 43 | .where($.salary > lit(5000)) 44 | .show() 45 | 46 | spark.stop() 47 | 48 | 49 | 50 | import org.scalatest.funsuite.AnyFunSuite 51 | import org.scalatest.matchers.should.Matchers.shouldEqual 52 | import java.io.ByteArrayOutputStream 53 | 54 | class ExampleTest extends AnyFunSuite: 55 | test("Workers example") { 56 | val outCapture = new ByteArrayOutputStream 57 | Console.withOut(outCapture) { runExample() } 58 | val result = new String(outCapture.toByteArray) 59 | 60 | val expected = """|+---------+--------+---------------+------+ 61 | ||firstName|lastName| supervisor|salary| 62 | |+---------+--------+---------------+------+ 63 | || Michael| Johnson| null| 6230| 64 | || Emma| Brown| Bob Smith| 6016| 65 | || Bob| Smith|Michael Johnson| 6444| 66 | || Alice| Potter|Michael Johnson| 5588| 67 | || Natalie| Evans|Michael Johnson| 5588| 68 | || Julia| Taylor| Natalie Evans| 5374| 69 | || Paul| Wilson| Julia Taylor| 5374| 70 | |+---------+--------+---------------+------+ 71 | | 72 | |""".stripMargin 73 | 74 | result shouldEqual expected 75 | } 76 | --------------------------------------------------------------------------------
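The tests and examples above all exercise the same basic flow: lift a Seq of case classes into a typed frame with `.toDF.asStruct`, transform it with `select`, `where`, `withColumns`, or `join`, and read typed results back with `.asClass[...].collect()`. As a recap, here is a minimal sketch of that flow, using only the API demonstrated in the files above (the `Person` class, its values, and the app name are made up for illustration):

    import org.virtuslab.iskra.api.*
    import functions.lit

    @main def runSketch(): Unit =
      given spark: SparkSession = {
        SparkSession
          .builder()
          .master("local")
          .appName("sketch")
          .getOrCreate()
      }

      case class Person(name: String, age: Int)

      // lift plain case-class data into a typed (struct) data frame
      val people = Seq(
        Person("Ada", 36),
        Person("Alan", 41)
      ).toDF.asStruct

      // typed filter and projection; collects to List("Alan")
      val names = people
        .where($.age > lit(40))
        .select($.name)
        .asClass[String].collect().toList

      println(names)

      spark.stop()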