├── .github
│   ├── CODEOWNERS
│   └── workflows
│       ├── ci-quality.yml
│       ├── ci.yml
│       ├── docs.yml
│       └── release.yml
├── .gitignore
├── .scalafmt.conf
├── LICENSE
├── README.md
├── benchmarks
│   └── src
│       └── main
│           └── scala
│               └── com
│                   └── github
│                       └── mrpowers
│                           └── spark
│                               └── fast
│                                   └── tests
│                                       ├── ColumnComparerBenchmark.scala
│                                       └── DataFrameComparerBenchmark.scala
├── build.sbt
├── core
│   └── src
│       ├── main
│       │   └── scala
│       │       └── com
│       │           └── github
│       │               └── mrpowers
│       │                   └── spark
│       │                       └── fast
│       │                           └── tests
│       │                               ├── ArrayUtil.scala
│       │                               ├── ColumnComparer.scala
│       │                               ├── DataFrameComparer.scala
│       │                               ├── DataFramePrettyPrint.scala
│       │                               ├── DatasetComparer.scala
│       │                               ├── ProductUtil.scala
│       │                               ├── RDDComparer.scala
│       │                               ├── RddHelpers.scala
│       │                               ├── RowComparer.scala
│       │                               ├── SchemaComparer.scala
│       │                               ├── SchemaDiffOutputFormat.scala
│       │                               ├── SeqLikesExtensions.scala
│       │                               └── ufansi
│       │                                   ├── Fansi.scala
│       │                                   └── FansiExtensions.scala
│       └── test
│           ├── resources
│           │   └── log4j.properties
│           └── scala
│               └── com
│                   └── github
│                       └── mrpowers
│                           └── spark
│                               └── fast
│                                   └── tests
│                                       ├── ArrayUtilTest.scala
│                                       ├── ColumnComparerTest.scala
│                                       ├── DataFrameComparerTest.scala
│                                       ├── DataFramePrettyPrintTest.scala
│                                       ├── DatasetComparerTest.scala
│                                       ├── ExamplesTest.scala
│                                       ├── RDDComparerTest.scala
│                                       ├── RowComparerTest.scala
│                                       ├── SchemaComparerTest.scala
│                                       ├── SeqLikesExtensionsTest.scala
│                                       ├── SparkSessionExt.scala
│                                       ├── SparkSessionTestWrapper.scala
│                                       └── TestUtilsExt.scala
├── docs
│   └── about
│       └── README.md
├── images
│   ├── assertColumnEquality_error_message.png
│   ├── assertSchemaEquality_tree_message.png
│   ├── assertSmallDataFrameEquality_DatasetContentMissmatch_message.png
│   ├── assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png
│   ├── assertSmallDataFrameEquality_error_message.png
│   └── assertSmallDatasetEquality_error_message.png
├── project
│   ├── build.properties
│   └── plugins.sbt
└── scripts
    └── multi_spark_releases.sh
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @alfonsorr
2 | * @SemyonSinchenko
3 |
--------------------------------------------------------------------------------
/.github/workflows/ci-quality.yml:
--------------------------------------------------------------------------------
1 | name: CI-quality
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 |
8 | jobs:
9 | build:
10 | strategy:
11 | fail-fast: false
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: coursier/cache-action@v6
16 | with:
17 | extraKey: quality-check
18 | - uses: coursier/setup-action@v1
19 | with:
20 | jvm: zulu:8
21 | - name: Setup sbt launcher
22 | uses: sbt/setup-sbt@v1
23 | with:
24 | sbt-version: 1.10.1
25 | - name: ScalaFmt
26 | id: fmt
27 | run: sbt scalafmtCheckAll
28 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on:
3 | pull_request:
4 |
5 | jobs:
6 | build:
7 | strategy:
8 | fail-fast: false
9 | matrix:
10 | spark: ["3.1.3","3.2.4", "3.3.4", "3.4.3", "3.5.3"]
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 | - uses: coursier/cache-action@v6
15 | with:
16 | extraKey: ${{ matrix.spark }}
17 | - uses: coursier/setup-action@v1
18 | with:
19 | jvm: zulu:8
20 | - name: Setup sbt launcher
21 | uses: sbt/setup-sbt@v1
22 |         with:
23 | sbt-version: 1.10.1
24 | - name: Test
25 | run: sbt -Dspark.version=${{ matrix.spark }} +test
26 | - name: Benchmark
27 | run: sbt -Dspark.version=${{ matrix.spark }} +benchmarks/Jmh/run
28 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Docs
2 | on:
3 | workflow_dispatch:
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v1
10 | - uses: olafurpg/setup-scala@v10
11 | - name: Setup sbt launcher
12 | uses: sbt/setup-sbt@v1
13 | with:
14 | sbt-version: 1.10.1
15 | - name: Build docs
16 | run: sbt "project docs" laikaSite
17 | - name: Deploy to GH Pages
18 | uses: peaceiris/actions-gh-pages@v4
19 | with:
20 | github_token: ${{ secrets.GITHUB_TOKEN }}
21 | publish_dir: ./target/docs/site
22 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 | push:
4 | tags: [ "*" ]
5 | branches:
6 | - main
7 | jobs:
8 | publish:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - uses: coursier/setup-action@v1
13 | with:
14 | jvm: zulu:11
15 | apps: sbt scala scalafmt
16 | - run: sbt -Dsun.net.client.defaultReadTimeout=60000 -Dsun.net.client.defaultConnectTimeout=60000 -Dspark.version=3.5.3 -v ci-release
17 | env:
18 | PGP_PASSPHRASE: ${{secrets.PGP_PASSPHRASE}}
19 | PGP_SECRET: ${{secrets.PGP_LONG_ID}}
20 | SONATYPE_PASSWORD: ${{secrets.SONATYPE_PASSWORD}}
21 | SONATYPE_USERNAME: ${{secrets.SONATYPE_USERNAME}}
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | lib_managed/
3 | src_managed/
4 | project/boot/
5 |
6 | .DS_Store
7 |
8 | .idea
9 | *.iml
10 |
11 | *.swp
12 |
13 | *.class
14 | *.log
15 |
16 | # sbt specific
17 | .cache
18 | .history
19 | .lib/
20 | dist/*
21 | target/
22 | lib_managed/
23 | src_managed/
24 | project/boot/
25 | project/plugins/project/
26 | **/.bloop
27 | .bsp/
28 |
29 | # Scala-IDE specific
30 | .scala_dependencies
31 | .worksheet
32 |
33 | .idea
34 |
35 | # ENSIME specific
36 | .ensime_cache/
37 | .ensime
38 |
39 | .metals/
40 | .ammonite/
41 | metals.sbt
42 | metals/project/
43 |
44 | .vscode/
45 |
46 | local.*
47 |
48 | .DS_Store
49 |
50 | node_modules
51 |
52 | lib/core/metadata.js
53 | lib/core/MetadataBlog.js
54 |
55 | website/translated_docs
56 | website/build/
57 | website/yarn.lock
58 | website/node_modules
59 | website/i18n/*
60 | !website/i18n/en.json
61 | website/.docusaurus
62 | website/.cache-loader
63 |
64 | project/metals.sbt
65 | coursier
--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | version = 3.8.2
2 |
3 | align.preset = more
4 | runner.dialect = scala212
5 | maxColumn = 150
6 | docstrings.style = Asterisk
7 |
8 | //style = defaultWithAlign
9 | //align = true
10 | //danglingParentheses = false
11 | //indentOperator = spray
12 | //maxColumn = 300
13 | //rewrite.rules = [RedundantParens, AvoidInfix, SortImports, PreferCurlyFors]
14 | //unindentTopLevelOperators = true
15 | //verticalMultilineAtDefinitionSite = true
16 | //assumeStandardLibraryStripMargin = true
17 | //includeCurlyBraceInSelectChains = false
18 | //
19 | //align.openParenCallSite = false
20 | //align.openParenDefnSite = false
21 | //
22 | //binPack.parentConstructors = true
23 | //binPack.literalArgumentLists = true
24 | //binPack.unsafeCallSite = false
25 | //binPack.unsafeDefnSite = false
26 | //
27 | //continuationIndent.callSite = 2
28 | //continuationIndent.defnSite = 2
29 | //
30 | //newlines.alwaysBeforeTopLevelStatements = true
31 | //newlines.penalizeSingleSelectMultiArgList = false
32 | //
33 | //optIn.breakChainOnFirstMethodDot = true
34 | //
35 | //runner.optimizer.forceConfigStyleOnOffset = 1
36 | //runner.optimizer.forceConfigStyleMinArgCount = 2
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 MrPowers
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spark Fast Tests
2 |
3 | [](https://github.com/MrPowers/spark-fast-tests/actions/workflows/ci.yml)
4 |
5 | A fast Apache Spark testing helper library with beautifully formatted error messages! Works
6 | with [scalatest](https://github.com/scalatest/scalatest), [uTest](https://github.com/lihaoyi/utest),
7 | and [munit](https://github.com/scalameta/munit).
8 |
9 | Use [chispa](https://github.com/MrPowers/chispa) for PySpark applications.
10 |
11 | Read [Testing Spark Applications](https://leanpub.com/testing-spark) for a full explanation on the best way to test
12 | Spark code! Good test suites yield higher quality codebases that are easy to refactor.
13 |
14 | ## Table of Contents
15 | - [Install](#install)
16 | - [Examples](#simple-examples)
17 | - [Why is this library fast?](#why-is-this-library-fast)
18 | - [Usage](#usage)
19 | - [Local Testing SparkSession](#local-sparksession-for-test)
20 | - [DataFrameComparer](#datasetcomparer--dataframecomparer)
21 | - [Unordered DataFrames comparison](#unordered-dataframe-equality-comparisons)
22 | - [Approximate DataFrames comparison](#approximate-dataframe-equality)
23 | - [Ignore Nullable DataFrames comparison](#equality-comparisons-ignoring-the-nullable-flag)
24 | - [ColumnComparer](#column-equality)
25 | - [SchemaComparer](#schema-equality)
26 | - [Testing tips](#testing-tips)
27 |
28 |
29 | ## Install
30 |
31 | Fetch the JAR file from Maven.
32 |
33 | ```scala
34 | // for Spark 3
35 | libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "1.3.0" % "test"
36 | ```
37 |
38 | **Important: Future versions of spark-fast-tests will no longer support Spark 2.x. We recommend upgrading to Spark 3.x to
39 | ensure compatibility with upcoming releases.**
40 |
41 | Here's a link to the releases for different Scala versions:
42 |
43 | * [Scala 2.11 JAR files](https://repo1.maven.org/maven2/com/github/mrpowers/spark-fast-tests_2.11/)
44 | * [Scala 2.12 JAR files](https://repo1.maven.org/maven2/com/github/mrpowers/spark-fast-tests_2.12/)
45 | * [Scala 2.13 JAR files](https://repo1.maven.org/maven2/com/github/mrpowers/spark-fast-tests_2.13/)
46 | * [Legacy JAR files in Maven](https://mvnrepository.com/artifact/MrPowers/spark-fast-tests?repo=spark-packages).
47 |
48 | You should use Scala 2.11 with Spark 2 and Scala 2.12 / 2.13 with Spark 3.
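   |
   | For instance, a minimal sketch of the matching build.sbt lines (the version pairing shown is illustrative):
   |
   | ```scala
   | // Spark 3.5.x pairs with Scala 2.12 or 2.13
   | scalaVersion := "2.12.20"
   | libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "1.3.0" % "test"
   | ```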
49 |
50 | ## Simple examples
51 |
52 | The `assertSmallDataFrameEquality` method can be used to compare two DataFrames.
53 |
54 | ```scala
55 | val sourceDF = Seq(
56 | (1),
57 | (5)
58 | ).toDF("number")
59 |
60 | val expectedDF = Seq(
61 | (1),
62 | (3)
63 | ).toDF("number")
64 |
65 | assertSmallDataFrameEquality(sourceDF, expectedDF)
66 | ```
67 |
68 |
69 | ![assertSmallDataFrameEquality error message](images/assertSmallDataFrameEquality_error_message.png)
70 |
71 |
72 | The `assertSmallDatasetEquality` method can be used to compare two Datasets or DataFrames (`Dataset[Row]`).
73 | Nicely formatted error messages are displayed when the Datasets are not equal. Here is an example of content mismatch:
74 |
75 | ```scala
76 | val sourceDS = Seq(
77 | Person("juan", 5),
78 | Person("bob", 1),
79 | Person("li", 49),
80 | Person("alice", 5)
81 | ).toDS
82 |
83 | val expectedDS = Seq(
84 | Person("juan", 6),
85 | Person("frank", 10),
86 | Person("li", 49),
87 | Person("lucy", 5)
88 | ).toDS
89 | ```
90 |
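   | A minimal sketch of what the snippet above assumes: a simple `Person` case class and the assertion call that
   | produces the content-mismatch error shown below.
   |
   | ```scala
   | case class Person(name: String, age: Int)
   |
   | // throws DatasetContentMismatch because several rows differ
   | assertSmallDatasetEquality(sourceDS, expectedDS)
   | ```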
91 |
92 | ![assertSmallDatasetEquality error message](images/assertSmallDatasetEquality_error_message.png)
93 |
94 |
95 | The colors in the error message make it easy to identify the rows that aren't equal. This method also supports
96 | comparing DataFrames with different schemas.
97 |
98 | ```scala
99 | val sourceDF = spark.createDF(
100 | List(
101 | (1, 2.0),
102 | (5, 3.0)
103 | ),
104 | List(
105 | ("number", IntegerType, true),
106 | ("float", DoubleType, true)
107 | )
108 | )
109 |
110 | val expectedDF = spark.createDF(
111 | List(
112 | (1, "word", 1L),
113 | (5, "word", 2L)
114 | ),
115 | List(
116 | ("number", IntegerType, true),
117 | ("word", StringType, true),
118 | ("long", LongType, true)
119 | )
120 | )
121 |
122 | assertSmallDataFrameEquality(sourceDF, expectedDF)
123 | ```
124 |
125 |
126 | ![DatasetSchemaMismatch error message](images/assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png)
127 |
128 |
129 | The `DatasetComparer` has `assertSmallDatasetEquality` and `assertLargeDatasetEquality` methods to compare either
130 | Datasets or DataFrames.
131 |
132 | If you only need to compare DataFrames, you can use `DataFrameComparer` with the associated
133 | `assertSmallDataFrameEquality` and `assertLargeDataFrameEquality` methods. Under the hood, `DataFrameComparer` uses
134 | `assertSmallDatasetEquality` and `assertLargeDatasetEquality`.
135 |
136 | *Note: comparing Datasets can be tricky since some column names may be auto-generated by Spark when transformations are
137 | applied. Use the `ignoreColumnNames` boolean to skip name verification, as shown below.*
138 |
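   | A minimal sketch of the `ignoreColumnNames` flag (the column names are illustrative):
   |
   | ```scala
   | val actualDF   = Seq(1, 5).toDF("number")
   | val expectedDF = Seq(1, 5).toDF("num")
   |
   | // passes even though the column names differ
   | assertSmallDatasetEquality(actualDF, expectedDF, ignoreColumnNames = true)
   | ```
   |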
139 | ## Why is this library fast?
140 |
141 | This library provides three main methods to test your code.
142 |
143 | Suppose you'd like to test this function:
144 |
145 | ```scala
146 | def myLowerClean(col: Column): Column = {
147 | lower(regexp_replace(col, "\\s+", ""))
148 | }
149 | ```
150 |
151 | Here's how long the tests take to execute:
152 |
153 | | test method | runtime |
154 | |--------------------------------|------------------|
155 | | `assertLargeDataFrameEquality` | 709 milliseconds |
156 | | `assertSmallDataFrameEquality` | 166 milliseconds |
157 | | `assertColumnEquality` | 108 milliseconds |
158 | | `evalString` | 26 milliseconds |
159 |
160 | `evalString` isn't as robust, but is the fastest. `assertColumnEquality` is robust and saves a lot of time.
161 |
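   | For example, `myLowerClean` can be tested with `assertColumnEquality` like this (a sketch; the sample rows are
   | illustrative):
   |
   | ```scala
   | val df = Seq(
   |   ("  Some Thing  ", "something"),
   |   ("HI THERE", "hithere")
   | ).toDF("word", "expected")
   |   .withColumn("clean_word", myLowerClean(col("word")))
   |
   | assertColumnEquality(df, "clean_word", "expected")
   | ```
   |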
162 | Other testing libraries don't have methods like `assertSmallDataFrameEquality` or `assertColumnEquality` so they run
163 | slower.
164 |
165 | ## Usage
166 |
167 | ### Local SparkSession for test
168 | The spark-fast-tests project doesn't provide a SparkSession object in your test suite, so you'll need to make one
169 | yourself.
170 |
171 | ```scala
172 | import org.apache.spark.sql.SparkSession
173 |
174 | trait SparkSessionTestWrapper {
175 |
176 | lazy val spark: SparkSession = {
177 | SparkSession
178 | .builder()
179 | .master("local")
180 | .appName("spark session")
181 | .config("spark.sql.shuffle.partitions", "1")
182 | .getOrCreate()
183 | }
184 |
185 | }
186 | ```
187 |
188 | It's best to set the number of shuffle partitions to a small number like one or four in your test suite. This configuration
189 | can make your tests run up to 70% faster. You can remove this configuration option or adjust it if you're working with
190 | big DataFrames in your test suite.
191 |
192 | Make sure to only use the `SparkSessionTestWrapper` trait in your test suite. You don't want to use test specific
193 | configuration (like one shuffle partition) when running production code.
194 |
195 | ### DatasetComparer / DataFrameComparer
196 | The `DatasetComparer` trait defines the `assertSmallDatasetEquality` method. Extend your spec file with the
197 | `SparkSessionTestWrapper` trait to create DataFrames and the `DatasetComparer` trait to make DataFrame comparisons.
198 |
199 | ```scala
200 | import org.apache.spark.sql.types._
201 | import org.apache.spark.sql.Row
202 | import org.apache.spark.sql.functions._
203 | import com.github.mrpowers.spark.fast.tests.DatasetComparer
204 |
205 | class DatasetSpec extends FunSpec with SparkSessionTestWrapper with DatasetComparer {
206 |
207 | import spark.implicits._
208 |
209 | it("aliases a DataFrame") {
210 |
211 | val sourceDF = Seq(
212 | ("jose"),
213 | ("li"),
214 | ("luisa")
215 | ).toDF("name")
216 |
217 | val actualDF = sourceDF.select(col("name").alias("student"))
218 |
219 | val expectedDF = Seq(
220 | ("jose"),
221 | ("li"),
222 | ("luisa")
223 | ).toDF("student")
224 |
225 | assertSmallDatasetEquality(actualDF, expectedDF)
226 |
227 | }
228 | }
229 | ```
230 |
231 | To compare large DataFrames that are partitioned across different nodes in a cluster, use the
232 | `assertLargeDatasetEquality` method.
233 |
234 | ```scala
235 | assertLargeDatasetEquality(actualDF, expectedDF)
236 | ```
237 |
238 | `assertSmallDatasetEquality` is faster for test suites that run on your local machine. `assertLargeDatasetEquality`
239 | should only be used for DataFrames that are split across nodes in a cluster.
240 |
241 | #### Unordered DataFrame equality comparisons
242 |
243 | Suppose you have the following `actualDF`:
244 |
245 | ```
246 | +------+
247 | |number|
248 | +------+
249 | | 1|
250 | | 5|
251 | +------+
252 | ```
253 |
254 | And suppose you have the following `expectedDF`:
255 |
256 | ```
257 | +------+
258 | |number|
259 | +------+
260 | | 5|
261 | | 1|
262 | +------+
263 | ```
264 |
265 | The DataFrames have the same columns and rows, but the order is different.
266 |
267 | `assertSmallDataFrameEquality(actualDF, expectedDF)` will throw a `DatasetContentMismatch` error.
268 |
269 | We can set the `orderedComparison` boolean flag to `false` and spark-fast-tests will sort the DataFrames before
270 | performing the comparison.
271 |
272 | `assertSmallDataFrameEquality(actualDF, expectedDF, orderedComparison = false)` will not throw an error.
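   |
   | Putting it together, a sketch that matches the tables above:
   |
   | ```scala
   | val actualDF   = Seq(1, 5).toDF("number")
   | val expectedDF = Seq(5, 1).toDF("number")
   |
   | // throws DatasetContentMismatch:
   | // assertSmallDataFrameEquality(actualDF, expectedDF)
   |
   | // passes:
   | assertSmallDataFrameEquality(actualDF, expectedDF, orderedComparison = false)
   | ```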
273 |
274 | #### Equality comparisons ignoring the nullable flag
275 |
276 | You might also want to make equality comparisons that ignore the nullable flags for the DataFrame columns.
277 |
278 | Here is how to use the `ignoreNullable` flag to compare DataFrames without considering the nullable property of each
279 | column.
280 |
281 | ```scala
282 | val sourceDF = spark.createDF(
283 | List(
284 | (1),
285 | (5)
286 | ), List(
287 | ("number", IntegerType, false)
288 | )
289 | )
290 |
291 | val expectedDF = spark.createDF(
292 | List(
293 | (1),
294 | (5)
295 | ), List(
296 | ("number", IntegerType, true)
297 | )
298 | )
299 |
300 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreNullable = true)
301 | ```
302 |
303 | #### Approximate DataFrame Equality
304 |
305 | The `assertApproximateDataFrameEquality` function is useful for DataFrames that contain `DoubleType` columns. The
306 | precision threshold must be set when using the `assertApproximateDataFrameEquality` function.
307 |
308 | ```scala
309 | val sourceDF = spark.createDF(
310 | List(
311 | (1.2),
312 | (5.1),
313 | (null)
314 | ), List(
315 | ("number", DoubleType, true)
316 | )
317 | )
318 |
319 | val expectedDF = spark.createDF(
320 | List(
321 | (1.2),
322 | (5.1),
323 | (null)
324 | ), List(
325 | ("number", DoubleType, true)
326 | )
327 | )
328 |
329 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
330 | ```
331 |
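   | A small-DataFrame variant is also available in `DataFrameComparer` (a sketch mirroring the call above):
   |
   | ```scala
   | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, precision = 0.01)
   | ```
   |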
332 | ### Column Equality
333 |
334 | The `assertColumnEquality` method can be used to assess the equality of two columns in a DataFrame.
335 |
336 | Suppose you have the following DataFrame with two columns that are not equal.
337 |
338 | ```
339 | +-------+-------------+
340 | | name|expected_name|
341 | +-------+-------------+
342 | | phil| phil|
343 | | rashid| rashid|
344 | |matthew| mateo|
345 | | sami| sami|
346 | | li| feng|
347 | | null| null|
348 | +-------+-------------+
349 | ```
350 |
351 | The following code will throw a `ColumnMismatch` error message:
352 |
353 | ```scala
354 | assertColumnEquality(df, "name", "expected_name")
355 | ```
356 |
357 |
358 | ![assertColumnEquality error message](images/assertColumnEquality_error_message.png)
359 |
360 |
361 | Mix the `ColumnComparer` trait into your test class to access the `assertColumnEquality` method:
362 |
363 | ```scala
364 | import com.github.mrpowers.spark.fast.tests.ColumnComparer
365 |
366 | object MySpecialClassTest
367 | extends TestSuite
368 | with ColumnComparer
369 | with SparkSessionTestWrapper {
370 |
371 | // your tests
372 | }
373 | ```
374 |
375 | ### Schema Equality
376 |
377 | The `SchemaComparer` provides the `assertSchemaEqual` API, which is useful for comparing DataFrame schemas.
378 |
379 | Consider the following two schemas:
380 |
381 | ```scala
382 | val s1 = StructType(
383 | Seq(
384 | StructField("array", ArrayType(StringType, containsNull = true), true),
385 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
386 | StructField("something", StringType, true),
387 | StructField(
388 | "struct",
389 | StructType(
390 | StructType(
391 | Seq(
392 | StructField("mood", ArrayType(StringType, containsNull = false), true),
393 | StructField("something", StringType, false),
394 | StructField(
395 | "something2",
396 | StructType(
397 | Seq(
398 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
399 | StructField("something2", StringType, false)
400 | )
401 | ),
402 | false
403 | )
404 | )
405 | )
406 | ),
407 | true
408 | )
409 | )
410 | )
411 | val s2 = StructType(
412 | Seq(
413 | StructField("array", ArrayType(StringType, containsNull = true), true),
414 | StructField("something", StringType, true),
415 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
416 | StructField(
417 | "struct",
418 | StructType(
419 | StructType(
420 | Seq(
421 | StructField("something", StringType, false),
422 | StructField("mood", ArrayType(StringType, containsNull = false), true),
423 | StructField(
424 | "something3",
425 | StructType(
426 | Seq(
427 | StructField("mood3", ArrayType(StringType, containsNull = false), true)
428 | )
429 | ),
430 | false
431 | )
432 | )
433 | )
434 | ),
435 | true
436 | ),
437 | StructField("norma2", StringType, false)
438 | )
439 | )
440 |
441 | ```
442 |
443 | The `assertSchemaEqual` method supports two output formats: `SchemaDiffOutputFormat.Tree` and
444 | `SchemaDiffOutputFormat.Table`. The tree output format is useful when the schema is large and contains
445 | multi-level nested fields.
446 |
447 | ```scala
448 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
449 | ```
450 |
451 |
452 | ![assertSchemaEqual tree output](images/assertSchemaEquality_tree_message.png)
453 |
454 |
455 | By default, `SchemaDiffOutputFormat.Table` is used internally by all DataFrame/Dataset comparison APIs.
456 |
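   | To pick a format explicitly (a sketch; both names live in the `com.github.mrpowers.spark.fast.tests` package):
   |
   | ```scala
   | import com.github.mrpowers.spark.fast.tests.{SchemaComparer, SchemaDiffOutputFormat}
   |
   | SchemaComparer.assertSchemaEqual(s1, s2, outputFormat = SchemaDiffOutputFormat.Table)
   | ```
   |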
457 | ## Testing Tips
458 |
459 | * Use column functions instead of UDFs as described
460 | in [this blog post](https://medium.com/@mrpowers/spark-user-defined-functions-udfs-6c849e39443b)
461 | * Try to organize your code
462 | as [custom transformations](https://medium.com/@mrpowers/chaining-custom-dataframe-transformations-in-spark-a39e315f903c)
463 | so it's easy to test the logic elegantly
464 | * Don't write tests that read from files or write files. Dependency injection is a great way to avoid file I/O in your
465 |   test suite, as sketched below.
466 |
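   | A sketch of the dependency-injection idea (the `withCleanName` transformation is hypothetical):
   |
   | ```scala
   | import org.apache.spark.sql.DataFrame
   | import org.apache.spark.sql.functions._
   |
   | // the transformation accepts a DataFrame instead of reading a file path itself
   | def withCleanName(df: DataFrame): DataFrame =
   |   df.withColumn("clean_name", lower(trim(col("name"))))
   |
   | // the test injects an in-memory DataFrame, so no file I/O is needed
   | val actualDF   = Seq("  Jose ").toDF("name").transform(withCleanName)
   | val expectedDF = Seq(("  Jose ", "jose")).toDF("name", "clean_name")
   | assertSmallDataFrameEquality(actualDF, expectedDF)
   | ```
   |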
467 | ## Alternatives
468 |
469 | The [spark-testing-base](https://github.com/holdenk/spark-testing-base) project has more features (e.g. streaming
470 | support) and is compiled to support a variety of Scala and Spark versions.
471 |
472 | You might want to use spark-fast-tests instead of spark-testing-base in these cases:
473 |
474 | * You want to use uTest or a testing framework other than scalatest
475 | * You want to run tests in parallel (you need to set `parallelExecution in Test := false` with spark-testing-base)
476 | * You don't want to include hive as a project dependency
477 | * You don't want to restart the SparkSession after each test file executes so the suite runs faster
478 |
479 | ## Publishing
480 |
481 | GPG & Sonatype need to be setup properly before running these commands. See
482 | the [spark-daria](https://github.com/MrPowers/spark-daria) README for more information.
483 |
484 | It's a good idea to run `clean` before any publishing command, and to run it again between publishing runs for
485 | different Spark/Scala versions.
486 |
487 | There is a two-step process for publishing.
488 |
489 | Generate Scala 2.11 JAR files:
490 |
491 | * Run `sbt -Dspark.version=2.4.8`
492 | * Run `> ; + publishSigned; sonatypeBundleRelease` to create the JAR files and release them to Maven.
493 |
494 | Generate Scala 2.12 & Scala 2.13 JAR files:
495 |
496 | * Run `sbt`
497 | * Run `> ; + publishSigned; sonatypeBundleRelease`
498 |
499 | The `publishSigned` and `sonatypeBundleRelease` commands are made available by
500 | the [sbt-sonatype](https://github.com/xerial/sbt-sonatype) plugin.
501 |
502 | When the release command is run, you'll be prompted to enter your GPG passphrase.
503 |
504 | The Sonatype credentials should be stored in the `~/.sbt/sonatype_credentials` file in this format:
505 |
506 | ```
507 | realm=Sonatype Nexus Repository Manager
508 | host=oss.sonatype.org
509 | user=$USERNAME
510 | password=$PASSWORD
511 | ```
512 |
513 | ## Additional Goals
514 |
515 | * Use memory efficiently so Spark test runs don't crash
516 | * Provide readable error messages
517 | * Easy to use in conjunction with other test suites
518 | * Give the user control of the SparkSession
519 |
520 | ## Contributing
521 |
522 | Open an issue or send a pull request to contribute. Anyone who makes good contributions to the project will be promoted
523 | to project maintainer status.
524 |
525 | ## uTest settings to display color output
526 |
527 | Create a `CustomFramework` class with overrides that turn off the default uTest color settings.
528 |
529 | ```scala
530 | package com.github.mrpowers.spark.fast.tests
531 |
532 | class CustomFramework extends utest.runner.Framework {
533 | override def formatWrapWidth: Int = 300
534 |
535 | // turn off the default exception message color, so spark-fast-tests
536 | // can send messages with custom colors
537 | override def exceptionMsgColor = toggledColor(utest.ufansi.Attrs.Empty)
538 |
539 | override def exceptionPrefixColor = toggledColor(utest.ufansi.Attrs.Empty)
540 |
541 | override def exceptionMethodColor = toggledColor(utest.ufansi.Attrs.Empty)
542 |
543 | override def exceptionPunctuationColor = toggledColor(utest.ufansi.Attrs.Empty)
544 |
545 | override def exceptionLineNumberColor = toggledColor(utest.ufansi.Attrs.Empty)
546 | }
547 | ```
548 |
549 | Update the `build.sbt` file to use the `CustomFramework` class:
550 |
551 | ```scala
552 | testFrameworks += new TestFramework("com.github.mrpowers.spark.fast.tests.CustomFramework")
553 | ```
554 |
555 |
556 |
--------------------------------------------------------------------------------
/benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/ColumnComparerBenchmark.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.SparkSession
4 | import org.openjdk.jmh.annotations._
5 | import org.openjdk.jmh.infra.Blackhole
6 |
7 | import java.util.concurrent.TimeUnit
8 | import scala.util.Try
9 |
10 | private class ColumnComparerBenchmark extends ColumnComparer {
11 | @Benchmark
12 | @BenchmarkMode(Array(Mode.SingleShotTime))
13 | @Fork(value = 2)
14 | @Warmup(iterations = 10)
15 | @Measurement(iterations = 10)
16 | @OutputTimeUnit(TimeUnit.NANOSECONDS)
17 | def assertColumnEqualityBenchmarks(blackHole: Blackhole): Boolean = {
18 | val spark = SparkSession
19 | .builder()
20 | .master("local")
21 | .appName("spark session")
22 | .config("spark.sql.shuffle.partitions", "1")
23 | .getOrCreate()
24 | spark.sparkContext.setLogLevel("ERROR")
25 |
26 | import spark.implicits._
27 | val ds1 = (Seq.fill(100)(("1", "2")) ++ Seq.fill(100)(("3", "4"))).toDF("col_B", "col_A")
28 |
29 | val result = Try(assertColumnEquality(ds1, "col_B", "col_A"))
30 |
31 | blackHole.consume(result)
32 | result.isSuccess
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerBenchmark.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.SparkSession
4 | import org.openjdk.jmh.annotations._
5 | import org.openjdk.jmh.infra.Blackhole
6 |
7 | import java.util.concurrent.TimeUnit
8 | import scala.util.Try
9 |
10 | private class DataFrameComparerBenchmark extends DataFrameComparer {
11 | @Benchmark
12 | @BenchmarkMode(Array(Mode.SingleShotTime))
13 | @Fork(value = 2)
14 | @Warmup(iterations = 10)
15 | @Measurement(iterations = 10)
16 | @OutputTimeUnit(TimeUnit.NANOSECONDS)
17 | def assertApproximateDataFrameEqualityWithPrecision(blackHole: Blackhole): Boolean = {
18 | val spark = SparkSession
19 | .builder()
20 | .master("local")
21 | .appName("spark session")
22 | .config("spark.sql.shuffle.partitions", "1")
23 | .getOrCreate()
24 | spark.sparkContext.setLogLevel("ERROR")
25 |
26 | import spark.implicits._
27 | val sameData = Seq.fill(50)("1", "10/01/2019", 26.762499999999996)
28 | val ds1Data = sameData ++ Seq.fill(50)("1", "11/01/2019", 26.76249999999991)
29 | val ds2Data = sameData ++ Seq.fill(50)("3", "12/01/2019", 26.76249999999991)
30 |
31 | val ds1 = ds1Data.toDF("col_B", "col_C", "col_A")
32 | val ds2 = ds2Data.toDF("col_B", "col_C", "col_A")
33 |
34 | val result = Try(assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false))
35 |
36 | blackHole.consume(result)
37 | result.isSuccess
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | val versionRegex = """^(.*)\.(.*)\.(.*)$""".r
2 | val scala2_13 = "2.13.14"
3 | val scala2_12 = "2.12.20"
4 | val sparkVersion = System.getProperty("spark.version", "3.5.3")
5 | val noPublish = Seq(
6 | (publish / skip) := true,
7 | publishArtifact := false
8 | )
9 |
10 | inThisBuild(
11 | List(
12 | organization := "com.github.mrpowers",
13 | homepage := Some(url("https://github.com/mrpowers-io/spark-fast-tests")),
14 | licenses := Seq("MIT" -> url("http://opensource.org/licenses/MIT")),
15 | developers ++= List(
16 | Developer("MrPowers", "Matthew Powers", "@MrPowers", url("https://github.com/MrPowers"))
17 | ),
18 | Compile / scalafmtOnCompile := true,
19 | Test / fork := true,
20 | crossScalaVersions := {
21 | sparkVersion match {
22 | case versionRegex("3", m, _) if m.toInt >= 2 => Seq(scala2_12, scala2_13)
23 | case versionRegex("3", _, _) => Seq(scala2_12)
24 | }
25 | },
26 | scalaVersion := crossScalaVersions.value.head
27 | )
28 | )
29 |
30 | enablePlugins(GitVersioning)
31 |
32 | lazy val commonSettings = Seq(
33 | javaOptions ++= {
34 | Seq("-Xms512M", "-Xmx2048M", "-Duser.timezone=GMT") ++ (if (System.getProperty("java.version").startsWith("1.8.0"))
35 | Seq("-XX:+CMSClassUnloadingEnabled")
36 | else Seq.empty)
37 | },
38 | libraryDependencies ++= Seq(
39 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
40 | "org.scalatest" %% "scalatest" % "3.2.18" % "test"
41 | )
42 | )
43 |
44 | lazy val core = (project in file("core"))
45 | .settings(
46 | commonSettings,
47 | moduleName := "spark-fast-tests",
48 | name := moduleName.value,
49 | Compile / packageSrc / publishArtifact := true,
50 | Compile / packageDoc / publishArtifact := true
51 | )
52 |
53 | lazy val benchmarks = (project in file("benchmarks"))
54 | .dependsOn(core)
55 | .settings(commonSettings)
56 | .settings(
57 | libraryDependencies ++= Seq(
58 | "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" // required for jmh IDEA plugin. Make sure this version matches sbt-jmh version!
59 | ),
60 | name := "benchmarks"
61 | )
62 | .settings(noPublish)
63 | .enablePlugins(JmhPlugin)
64 |
65 | lazy val docs = (project in file("docs"))
66 | .dependsOn(core)
67 | .enablePlugins(LaikaPlugin)
68 | .settings(
69 | name := "docs",
70 | laikaTheme := {
71 | import laika.ast.Path.Root
72 | import laika.helium.Helium
73 | import laika.helium.config.*
74 |
75 | Helium.defaults.site
76 | .landingPage(
77 | title = Some("Spark Fast Tests"),
78 | subtitle = Some("Unit testing your Apache Spark application"),
79 | latestReleases = Seq(
80 | ReleaseInfo("Latest Stable Release", "1.0.0")
81 | ),
82 | license = Some("MIT"),
83 | titleLinks = Seq(
84 | VersionMenu.create(unversionedLabel = "Getting Started"),
85 | LinkGroup.create(
86 | IconLink.external("https://github.com/mrpowers-io/spark-fast-tests", HeliumIcon.github)
87 | )
88 | ),
89 | linkPanel = Some(
90 | LinkPanel(
91 | "Documentation",
92 | TextLink.internal(Root / "about" / "README.md", "Spark Fast Tests")
93 | )
94 | ),
95 | projectLinks = Seq(
96 | LinkGroup.create(
97 | TextLink.internal(Root / "api" / "com" / "github" / "mrpowers" / "spark" / "fast" / "tests" / "index.html", "API (Scaladoc)")
98 | )
99 | ),
100 | teasers = Seq(
101 | Teaser("Fast", "Handle small dataframes effectively and provide column assertions"),
102 | Teaser("Flexible", "Works fine with scalatest, uTest, munit")
103 | )
104 | )
105 | .build
106 | },
107 | laikaIncludeAPI := true,
108 | laikaExtensions ++= {
109 | import laika.config.SyntaxHighlighting
110 | import laika.format.Markdown
111 | Seq(Markdown.GitHubFlavor, SyntaxHighlighting)
112 | },
113 | Laika / sourceDirectories := Seq((ThisBuild / baseDirectory).value / "docs")
114 | )
115 | .settings(noPublish)
116 |
117 | lazy val root = (project in file("."))
118 | .settings(
119 | name := "spark-fast-tests-root",
120 | commonSettings,
121 | noPublish
122 | )
123 | .aggregate(core, benchmarks, docs)
124 |
125 | scmInfo := Some(ScmInfo(url("https://github.com/mrpowers-io/spark-fast-tests"), "git@github.com:MrPowers/spark-fast-tests.git"))
126 |
127 | updateOptions := updateOptions.value.withLatestSnapshots(false)
128 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ArrayUtil.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import java.sql.Date
4 | import com.github.mrpowers.spark.fast.tests.ufansi.EscapeAttr
5 | import java.time.format.DateTimeFormatter
6 | import org.apache.commons.lang3.StringUtils
7 |
8 | object ArrayUtil {
9 |
10 | def weirdTypesToStrings(arr: Array[(Any, Any)], truncate: Int = 20): Array[List[String]] = {
11 | arr.map { row =>
12 | row.productIterator.toList.map { cell =>
13 | val str = cell match {
14 | case null => "null"
15 | case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
16 | case array: Array[_] => array.mkString("[", ", ", "]")
17 | case seq: Seq[_] => seq.mkString("[", ", ", "]")
18 | case d: Date => d.toLocalDate.format(DateTimeFormatter.ISO_DATE)
19 | case _ => cell.toString
20 | }
21 | if (truncate > 0 && str.length > truncate) {
22 | // do not show ellipses for strings shorter than 4 characters.
23 | if (truncate < 4)
24 | str.substring(0, truncate)
25 | else
26 | str.substring(0, truncate - 3) + "..."
27 | } else {
28 | str
29 | }
30 | }
31 | }
32 | }
33 |
34 | def showTwoColumnString(arr: Array[(Any, Any)], truncate: Int = 20): String = {
35 | showTwoColumnStringColorCustomizable(arr, truncate = truncate)
36 | }
37 |
38 | def showTwoColumnStringColorCustomizable(
39 | arr: Array[(Any, Any)],
40 | rowEqual: Option[Array[Boolean]] = None,
41 | truncate: Int = 20,
42 | equalColor: EscapeAttr = ufansi.Color.Blue,
43 | unequalColorLeft: EscapeAttr = ufansi.Color.Red,
44 | unequalColorRight: EscapeAttr = ufansi.Color.Green
45 | ): String = {
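   |     // rowEqual optionally supplies precomputed per-row equality flags; when absent, the two cells are compared as strings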
46 | val sb = new StringBuilder
47 | val numCols = 2
48 | val rows = weirdTypesToStrings(arr, truncate)
49 |
50 | // Initialise the width of each column to a minimum value of '3'
51 | val colWidths = Array.fill(numCols)(3)
52 |
53 | // Compute the width of each column
54 | for (row <- rows) {
55 | for ((cell, i) <- row.zipWithIndex) {
56 | colWidths(i) = math.max(colWidths(i), cell.length)
57 | }
58 | }
59 |
60 | // Create SeparateLine
61 | val sep: String =
62 | colWidths
63 | .map("-" * _)
64 | .addString(sb, "+", "+", "+\n")
65 | .toString()
66 |
67 | // column names
68 | rows.head.zipWithIndex
69 | .map { case (cell, i) =>
70 | if (truncate > 0) {
71 | StringUtils.leftPad(cell, colWidths(i))
72 | } else {
73 | StringUtils.rightPad(cell, colWidths(i))
74 | }
75 | }
76 | .addString(sb, "|", "|", "|\n")
77 |
78 | sb.append(sep)
79 |
80 | // data
81 | rows.tail.zipWithIndex.map { case (row, j) =>
82 | row.zipWithIndex
83 | .map { case (cell, i) =>
84 | val r = if (truncate > 0) {
85 | StringUtils.leftPad(cell, colWidths(i))
86 | } else {
87 | StringUtils.rightPad(cell, colWidths(i))
88 | }
89 | if (rowEqual.fold(row.head == row(1))(_(j))) {
90 | equalColor(r)
91 | } else if (i == 0) {
92 | unequalColorLeft(r)
93 | } else {
94 | unequalColorRight(r)
95 | }
96 | }
97 | .addString(sb, "|", "|", "|\n")
98 | }
99 |
100 | sb.append(sep)
101 |
102 | sb.toString()
103 | }
104 |
105 | // The following code is taken from: https://github.com/scala/scala/blob/86e75db7f36bcafdd75302f2c2cca0c68413214d/src/partest/scala/tools/partest/Util.scala
106 | def prettyArray(a: Array[_]): collection.IndexedSeq[Any] = new collection.AbstractSeq[Any] with collection.IndexedSeq[Any] {
107 | def length: Int = a.length
108 |
109 | def apply(idx: Int): Any = a(idx) match {
110 | case x: AnyRef if x.getClass.isArray => prettyArray(x.asInstanceOf[Array[_]])
111 | case x => x
112 | }
113 |
114 | // Ignore deprecation warning in 2.13 - this is the correct def for 2.12 compatibility
115 | override def stringPrefix: String = "Array"
116 | }
117 |
118 | // The following code is taken from: https://github.com/scala/scala/blob/86e75db7f36bcafdd75302f2c2cca0c68413214d/src/partest/scala/tools/partest/Util.scala
119 | implicit class ArrayDeep(val a: Array[_]) extends AnyVal {
120 | def deep: collection.IndexedSeq[Any] = prettyArray(a)
121 | }
122 | }
123 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ColumnComparer.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.DataFrame
4 | import org.apache.spark.sql.Row
5 | import org.apache.spark.sql.functions._
6 | import org.apache.spark.sql.types.StructField
7 |
8 | case class ColumnMismatch(smth: String) extends Exception(smth)
9 |
10 | trait ColumnComparer {
11 |
12 | def assertColumnEquality(df: DataFrame, colName1: String, colName2: String): Unit = {
13 | val elements = df
14 | .select(colName1, colName2)
15 | .collect()
16 | val colName1Elements = elements.map(_(0))
17 | val colName2Elements = elements.map(_(1))
18 | if (!colName1Elements.sameElements(colName2Elements)) {
19 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs"
20 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnString(
21 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements)
22 | )
23 | throw ColumnMismatch(mismatchMessage)
24 | }
25 | }
26 |
27 | // ace stands for 'assertColumnEquality'
28 | def ace(df: DataFrame, colName1: String, colName2: String): Unit = {
29 | assertColumnEquality(df, colName1, colName2)
30 | }
31 |
32 | // possibly rename this to assertDeepColumnEquality...
33 | // would deep equality comparison help when comparing other types of columns?
34 | def assertBinaryTypeColumnEquality(df: DataFrame, colName1: String, colName2: String): Unit = {
35 | import ArrayUtil._
36 |
37 | val elements = df
38 | .select(colName1, colName2)
39 | .collect()
40 | val colName1Elements = elements.map(_(0))
41 | val colName2Elements = elements.map(_(1))
42 | if (!colName1Elements.deep.sameElements(colName2Elements.deep)) {
43 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs"
44 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnString(
45 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements)
46 | )
47 | throw ColumnMismatch(mismatchMessage)
48 | }
49 | }
50 |
51 | // private def approximatelyEqualDouble(x: Double, y: Double, precision: Double): Boolean = {
52 | // (x - y).abs < precision
53 | // }
54 | //
55 | // private def areDoubleArraysEqual(x: Array[Double], y: Array[Double], precision: Double): Boolean = {
56 | // x.zip(y).forall { t =>
57 | // approximatelyEqualDouble(t._1, t._2, precision)
58 | // }
59 | // }
60 |
61 | def assertDoubleTypeColumnEquality(df: DataFrame, colName1: String, colName2: String, precision: Double = 0.01): Unit = {
62 | val elements: Array[Row] = df
63 | .select(colName1, colName2)
64 | .collect()
65 | val rowsEqual = scala.collection.mutable.ArrayBuffer[Boolean]()
66 | var allRowsAreEqual = true
67 | for (i <- 0 until elements.length) {
68 |       val t = elements(i)
69 | if (t(0) == null && t(1) == null) {
70 | rowsEqual += true
71 | } else if (t(0) != null && t(1) == null) {
72 | rowsEqual += false
73 | allRowsAreEqual = false
74 | } else if (t(0) == null && t(1) != null) {
75 | rowsEqual += false
76 | allRowsAreEqual = false
77 | } else {
78 | if (!((t(0).toString.toDouble - t(1).toString.toDouble).abs < precision)) {
79 | rowsEqual += false
80 | allRowsAreEqual = false
81 | } else {
82 | rowsEqual += true
83 | }
84 | }
85 | }
86 | if (!allRowsAreEqual) {
87 | val colName1Elements = elements.map(_(0))
88 | val colName2Elements = elements.map(_(1))
89 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs"
90 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnStringColorCustomizable(
91 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements),
92 | Some(rowsEqual.toArray)
93 | )
94 | throw ColumnMismatch(mismatchMessage)
95 | }
96 | }
97 |
98 | // private def approximatelyEqualFloat(x: Float, y: Float, precision: Float): Boolean = {
99 | // (x - y).abs < precision
100 | // }
101 | //
102 | // private def areFloatArraysEqual(x: Array[Float], y: Array[Float], precision: Float): Boolean = {
103 | // x.zip(y).forall { t =>
104 | // approximatelyEqualFloat(t._1, t._2, precision)
105 | // }
106 | // }
107 |
108 | def assertFloatTypeColumnEquality(df: DataFrame, colName1: String, colName2: String, precision: Float): Unit = {
109 | val elements: Array[Row] = df
110 | .select(colName1, colName2)
111 | .collect()
112 | val rowsEqual = scala.collection.mutable.ArrayBuffer[Boolean]()
113 | var allRowsAreEqual = true
114 | for (i <- 0 until elements.length) {
115 |       val t = elements(i)
116 | if (t(0) == null && t(1) == null) {
117 | rowsEqual += true
118 | } else if (t(0) != null && t(1) == null) {
119 | rowsEqual += false
120 | allRowsAreEqual = false
121 | } else if (t(0) == null && t(1) != null) {
122 | rowsEqual += false
123 | allRowsAreEqual = false
124 | } else {
125 | if (!((t(0).toString.toFloat - t(1).toString.toFloat).abs < precision)) {
126 | rowsEqual += false
127 | allRowsAreEqual = false
128 | } else {
129 | rowsEqual += true
130 | }
131 | }
132 | }
133 | if (!allRowsAreEqual) {
134 | val colName1Elements = elements.map(_(0))
135 | val colName2Elements = elements.map(_(1))
136 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs"
137 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnStringColorCustomizable(
138 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements),
139 | Some(rowsEqual.toArray)
140 | )
141 | throw ColumnMismatch(mismatchMessage)
142 | }
143 | }
144 |
145 | def assertColEquality(df: DataFrame, colName1: String, colName2: String): Unit = {
146 | val cn = s"are_${colName1}_and_${colName2}_equal"
147 |
148 | val schema = df.schema
149 | val sf1: StructField = schema.find((sf: StructField) => sf.name == colName1).get
150 | val sf2 = schema.find((sf: StructField) => sf.name == colName2).get
151 |
152 | if (sf1.dataType != sf2.dataType) {
153 | throw ColumnMismatch(
154 | s"The column dataTypes are different. The `${colName1}` column has a `${sf1.dataType}` dataType and the `${colName2}` column has a `${sf2.dataType}` dataType."
155 | )
156 | }
157 |
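   |     // <=> is Spark's null-safe equality operator, so two null cells are considered equal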
158 | val r = df
159 | .withColumn(
160 | cn,
161 | col(colName1) <=> col(colName2)
162 | )
163 | .select(cn)
164 | .collect()
165 |
166 | if (r.contains(Row(false))) {
167 | val elements = df
168 | .select(colName1, colName2)
169 | .collect()
170 | val colName1Elements = elements.map(_(0))
171 | val colName2Elements = elements.map(_(1))
172 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs"
173 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnString(
174 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements)
175 | )
176 | throw ColumnMismatch(mismatchMessage)
177 | }
178 | }
179 |
180 | }
181 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparer.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.{DataFrame, Row}
4 | trait DataFrameComparer extends DatasetComparer {
5 |
6 | /**
7 | * Raises an error unless `actualDF` and `expectedDF` are equal
8 | */
9 | def assertSmallDataFrameEquality(
10 | actualDF: DataFrame,
11 | expectedDF: DataFrame,
12 | ignoreNullable: Boolean = false,
13 | ignoreColumnNames: Boolean = false,
14 | orderedComparison: Boolean = true,
15 | ignoreColumnOrder: Boolean = false,
16 | ignoreMetadata: Boolean = true,
17 | truncate: Int = 500
18 | ): Unit = {
19 | assertSmallDatasetEquality(
20 | actualDF,
21 | expectedDF,
22 | ignoreNullable,
23 | ignoreColumnNames,
24 | orderedComparison,
25 | ignoreColumnOrder,
26 | ignoreMetadata,
27 | truncate
28 | )
29 | }
30 |
31 | /**
32 | * Raises an error unless `actualDF` and `expectedDF` are equal
33 | */
34 | def assertLargeDataFrameEquality(
35 | actualDF: DataFrame,
36 | expectedDF: DataFrame,
37 | ignoreNullable: Boolean = false,
38 | ignoreColumnNames: Boolean = false,
39 | orderedComparison: Boolean = true,
40 | ignoreColumnOrder: Boolean = false,
41 | ignoreMetadata: Boolean = true
42 | ): Unit = {
43 | assertLargeDatasetEquality(
44 | actualDF,
45 | expectedDF,
46 | ignoreNullable = ignoreNullable,
47 | ignoreColumnNames = ignoreColumnNames,
48 | orderedComparison = orderedComparison,
49 | ignoreColumnOrder = ignoreColumnOrder,
50 | ignoreMetadata = ignoreMetadata
51 | )
52 | }
53 |
54 | /**
55 |    * Raises an error unless `actualDF` and `expectedDF` are approximately equal, within `precision`
56 | */
57 | def assertApproximateSmallDataFrameEquality(
58 | actualDF: DataFrame,
59 | expectedDF: DataFrame,
60 | precision: Double,
61 | ignoreNullable: Boolean = false,
62 | ignoreColumnNames: Boolean = false,
63 | orderedComparison: Boolean = true,
64 | ignoreColumnOrder: Boolean = false,
65 | ignoreMetadata: Boolean = true
66 | ): Unit = {
67 | assertSmallDatasetEquality[Row](
68 | actualDF,
69 | expectedDF,
70 | ignoreNullable,
71 | ignoreColumnNames,
72 | orderedComparison,
73 | ignoreColumnOrder,
74 | ignoreMetadata,
75 | equals = RowComparer.areRowsEqual(_, _, precision)
76 | )
77 | }
78 |
79 | /**
80 | * Raises an error unless `actualDF` and `expectedDF` are equal
81 | */
82 | def assertApproximateLargeDataFrameEquality(
83 | actualDF: DataFrame,
84 | expectedDF: DataFrame,
85 | precision: Double,
86 | ignoreNullable: Boolean = false,
87 | ignoreColumnNames: Boolean = false,
88 | orderedComparison: Boolean = true,
89 | ignoreColumnOrder: Boolean = false,
90 | ignoreMetadata: Boolean = true
91 | ): Unit = {
92 | assertLargeDatasetEquality[Row](
93 | actualDF,
94 | expectedDF,
95 | equals = RowComparer.areRowsEqual(_, _, precision),
96 | ignoreNullable,
97 | ignoreColumnNames,
98 | orderedComparison,
99 | ignoreColumnOrder,
100 | ignoreMetadata
101 | )
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFramePrettyPrint.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import java.sql.Date
4 | import java.time.format.DateTimeFormatter
5 | import org.apache.commons.lang3.StringUtils
6 |
7 | import org.apache.spark.sql.{DataFrame, Row}
8 |
9 | object DataFramePrettyPrint {
10 |
11 | def showString(df: DataFrame, _numRows: Int, truncate: Int = 20): String = {
12 | val numRows = _numRows.max(0)
13 | val takeResult = df.take(numRows + 1)
14 | val hasMoreData = takeResult.length > numRows
15 | val data = takeResult.take(numRows)
16 |
17 | // For array values, replace Seq and Array with square brackets
18 | // For cells that are beyond `truncate` characters, replace it with the
19 | // first `truncate-3` and "..."
20 | val rows: Seq[Seq[String]] = df.schema.fieldNames.toSeq +: data.map { row =>
21 | row.toSeq.map { cell =>
22 | val str = cell match {
23 | case null => "null"
24 | case binary: Array[Byte] =>
25 | binary
26 | .map("%02X".format(_))
27 | .mkString(
28 | "[",
29 | " ",
30 | "]"
31 | )
32 | case array: Array[_] =>
33 | array.mkString(
34 | "[",
35 | ", ",
36 | "]"
37 | )
38 | case seq: Seq[_] =>
39 | seq.mkString(
40 | "[",
41 | ", ",
42 | "]"
43 | )
44 | case d: Date =>
45 | d.toLocalDate.format(DateTimeFormatter.ISO_DATE)
46 | case r: Row =>
47 | r.schema.fieldNames
48 | .zip(r.toSeq)
49 | .map { case (k, v) =>
50 | s"$k -> $v"
51 | }
52 | .mkString(
53 | "{",
54 | ", ",
55 | "}"
56 | )
57 | case _ => cell.toString
58 | }
59 | if (truncate > 0 && str.length > truncate) {
60 | // do not show ellipses for strings shorter than 4 characters.
61 | if (truncate < 4)
62 | str.substring(
63 | 0,
64 | truncate
65 | )
66 | else
67 | str.substring(
68 | 0,
69 | truncate - 3
70 | ) + "..."
71 | } else {
72 | str
73 | }
74 | }: Seq[String]
75 | }
76 |
77 | val sb = new StringBuilder
78 | val numCols = df.schema.fieldNames.length
79 |
80 | // Initialise the width of each column to a minimum value of '3'
81 | val colWidths = Array.fill(numCols)(3)
82 |
83 | // Compute the width of each column
84 | for (row <- rows) {
85 | for ((cell, i) <- row.zipWithIndex) {
86 | colWidths(i) = math.max(
87 | colWidths(i),
88 | cell.length
89 | )
90 | }
91 | }
92 |
93 | // Create SeparateLine
94 | val sep: String =
95 | colWidths
96 | .map("-" * _)
97 | .addString(
98 | sb,
99 | "+",
100 | "+",
101 | "+\n"
102 | )
103 | .toString()
104 |
105 | // column names
106 | val h: Seq[(String, Int)] = rows.head.zipWithIndex
107 | h.map { case (cell, i) =>
108 | if (truncate > 0) {
109 | StringUtils.leftPad(
110 | cell,
111 | colWidths(i)
112 | )
113 | } else {
114 | StringUtils.rightPad(
115 | cell,
116 | colWidths(i)
117 | )
118 | }
119 | }.addString(
120 | sb,
121 | "|",
122 | "|",
123 | "|\n"
124 | )
125 |
126 | sb.append(sep)
127 |
128 | // data
129 | rows.tail.map {
130 | _.zipWithIndex
131 | .map { case (cell, i) =>
132 | if (truncate > 0) {
133 | StringUtils.leftPad(
134 | cell.toString,
135 | colWidths(i)
136 | )
137 | } else {
138 | StringUtils.rightPad(
139 | cell.toString,
140 | colWidths(i)
141 | )
142 | }
143 | }
144 | .addString(
145 | sb,
146 | "|",
147 | "|",
148 | "|\n"
149 | )
150 | }
151 |
152 | sb.append(sep)
153 |
154 | // For Data that has more than "numRows" records
155 | if (hasMoreData) {
156 | val rowsString = if (numRows == 1) "row" else "rows"
157 | sb.append(s"only showing top $numRows $rowsString\n")
158 | }
159 |
160 | sb.toString()
161 | }
162 |
163 | }
164 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
4 | import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
5 | import org.apache.spark.rdd.RDD
6 | import org.apache.spark.sql.functions._
7 | import org.apache.spark.sql.{DataFrame, Dataset, Row}
8 |
9 | import scala.reflect.ClassTag
10 |
11 | case class DatasetContentMismatch(smth: String) extends Exception(smth)
12 | case class DatasetCountMismatch(smth: String) extends Exception(smth)
13 |
14 | trait DatasetComparer {
15 | private def countMismatchMessage(actualCount: Long, expectedCount: Long): String = {
16 | s"""
17 | Actual DataFrame Row Count: '$actualCount'
18 | Expected DataFrame Row Count: '$expectedCount'
19 | """
20 | }
21 |
22 | private def unequalRDDMessage[T](unequalRDD: RDD[(Long, (T, T))], length: Int): String = {
23 | "\nRow Index | Actual Row | Expected Row\n" + unequalRDD
24 | .take(length)
25 | .map { case (idx, (left, right)) =>
26 | ufansi.Color.Red(s"$idx | $left | $right")
27 | }
28 | .mkString("\n")
29 | }
30 |
31 | /**
32 | * order ds1 column according to ds2 column order
33 | */
34 | def orderColumns[T](ds1: Dataset[T], ds2: Dataset[T]): Dataset[T] = {
35 | ds1.select(ds2.columns.map(col).toIndexedSeq: _*).as[T](ds2.encoder)
36 | }
37 |
38 | /**
39 | * Raises an error unless `actualDS` and `expectedDS` are equal
40 | */
41 | def assertSmallDatasetEquality[T: ClassTag](
42 | actualDS: Dataset[T],
43 | expectedDS: Dataset[T],
44 | ignoreNullable: Boolean = false,
45 | ignoreColumnNames: Boolean = false,
46 | orderedComparison: Boolean = true,
47 | ignoreColumnOrder: Boolean = false,
48 | ignoreMetadata: Boolean = true,
49 | truncate: Int = 500,
50 | equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2)
51 | ): Unit = {
52 | SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
53 | val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
54 | assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
55 | }
56 |
57 | def assertSmallDatasetContentEquality[T: ClassTag](
58 | actualDS: Dataset[T],
59 | expectedDS: Dataset[T],
60 | orderedComparison: Boolean,
61 | truncate: Int,
62 | equals: (T, T) => Boolean
63 | ): Unit = {
64 | if (orderedComparison)
65 | assertSmallDatasetContentEquality(actualDS, expectedDS, truncate, equals)
66 | else
67 | assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals)
68 | }
69 |
70 | def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
71 | val a = actualDS.collect().toSeq
72 | val e = expectedDS.collect().toSeq
73 | if (!a.approximateSameElements(e, equals)) {
74 | val arr = ("Actual Content", "Expected Content")
75 | val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
76 | throw DatasetContentMismatch(msg)
77 | }
78 | }
79 |
80 | def defaultSortDataset[T](ds: Dataset[T]): Dataset[T] = ds.sort(ds.columns.map(col).toIndexedSeq: _*)
81 |
82 | def sortPreciseColumns[T](ds: Dataset[T]): Dataset[T] = {
83 | val colNames = ds.dtypes
84 | .withFilter { dtype =>
85 | !Seq("DoubleType", "DecimalType", "FloatType").contains(dtype._2)
86 | }
87 | .map(_._1)
88 | val cols = colNames.map(col)
89 | ds.sort(cols: _*)
90 | }
91 |
92 | /**
93 | * Raises an error unless `actualDS` and `expectedDS` are equal
94 | */
95 | def assertLargeDatasetEquality[T: ClassTag](
96 | actualDS: Dataset[T],
97 | expectedDS: Dataset[T],
98 | equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2),
99 | ignoreNullable: Boolean = false,
100 | ignoreColumnNames: Boolean = false,
101 | orderedComparison: Boolean = true,
102 | ignoreColumnOrder: Boolean = false,
103 | ignoreMetadata: Boolean = true
104 | ): Unit = {
105 | // first check if the schemas are equal
106 | SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
107 | val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
108 | assertLargeDatasetContentEquality(actual, expectedDS, equals, orderedComparison)
109 | }
110 |
111 | def assertLargeDatasetContentEquality[T: ClassTag](
112 | actualDS: Dataset[T],
113 | expectedDS: Dataset[T],
114 | equals: (T, T) => Boolean,
115 | orderedComparison: Boolean
116 | ): Unit = {
117 | if (orderedComparison) {
118 | assertLargeDatasetContentEquality(actualDS, expectedDS, equals)
119 | } else {
120 | assertLargeDatasetContentEquality(sortPreciseColumns(actualDS), sortPreciseColumns(expectedDS), equals)
121 | }
122 | }
123 |
124 | def assertLargeDatasetContentEquality[T: ClassTag](ds1: Dataset[T], ds2: Dataset[T], equals: (T, T) => Boolean): Unit = {
125 | try {
126 | val ds1RDD = ds1.rdd.cache()
127 | val ds2RDD = ds2.rdd.cache()
128 |
129 | val actualCount = ds1RDD.count
130 | val expectedCount = ds2RDD.count
131 |
132 | if (actualCount != expectedCount) {
133 | throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
134 | }
135 |     val actualIndexValue = RddHelpers.zipWithIndex(ds1RDD)
136 |     val expectedIndexValue = RddHelpers.zipWithIndex(ds2RDD)
137 |     val unequalRDD = actualIndexValue
138 |       .join(expectedIndexValue)
139 | .filter { case (_, (o1, o2)) =>
140 | !equals(o1, o2)
141 | }
142 |
143 | if (!unequalRDD.isEmpty()) {
144 | throw DatasetContentMismatch(
145 | unequalRDDMessage(unequalRDD, maxUnequalRowsToShow)
146 | )
147 | }
148 |
149 | } finally {
150 | ds1.rdd.unpersist()
151 | ds2.rdd.unpersist()
152 | }
153 | }
154 |
155 | def assertApproximateDataFrameEquality(
156 | actualDF: DataFrame,
157 | expectedDF: DataFrame,
158 | precision: Double,
159 | ignoreNullable: Boolean = false,
160 | ignoreColumnNames: Boolean = false,
161 | orderedComparison: Boolean = true,
162 | ignoreColumnOrder: Boolean = false,
163 | ignoreMetadata: Boolean = true
164 | ): Unit = {
165 | val e = (r1: Row, r2: Row) => {
166 | r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, precision)
167 | }
168 | assertLargeDatasetEquality[Row](
169 | actualDF,
170 | expectedDF,
171 | equals = e,
172 | ignoreNullable,
173 | ignoreColumnNames,
174 | orderedComparison,
175 | ignoreColumnOrder,
176 | ignoreMetadata
177 | )
178 | }
179 | }
180 |
181 | object DatasetComparer {
182 | val maxUnequalRowsToShow = 10
183 | }
184 |
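A minimal usage sketch of DatasetComparer, assuming ScalaTest, a local SparkSession, and a hypothetical Person case class (none of these come from the file above):

    import com.github.mrpowers.spark.fast.tests.DatasetComparer
    import org.apache.spark.sql.SparkSession
    import org.scalatest.freespec.AnyFreeSpec

    // Hypothetical domain class used only for this sketch
    case class Person(name: String, age: Int)

    class PersonComparisonSpec extends AnyFreeSpec with DatasetComparer {
      lazy val spark: SparkSession = SparkSession.builder().master("local").appName("sketch").getOrCreate()
      import spark.implicits._

      "equal Datasets pass both comparers" in {
        val actualDS   = Seq(Person("bob", 1), Person("camila", 5)).toDS()
        val expectedDS = Seq(Person("camila", 5), Person("bob", 1)).toDS()

        // collect-based comparison: fast and precise, but only for Datasets that fit on the driver
        assertSmallDatasetEquality(actualDS, expectedDS, orderedComparison = false)

        // distributed comparison for Datasets too large to collect
        assertLargeDatasetEquality(actualDS.orderBy("name"), expectedDS.orderBy("name"))
      }
    }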
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
4 | import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
5 | import org.apache.commons.lang3.StringUtils
6 | import org.apache.spark.sql.Row
7 |
8 | import scala.reflect.ClassTag
9 |
10 | object ProductUtil {
11 | private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = {
12 | product match {
13 | case null => Seq.empty
14 | case a: Array[_] => a
15 | case i: Iterable[_] => i.toSeq
16 | case r: Row => r.toSeq
17 | case p: Product => p.productIterator.toSeq
18 | case s => Seq(s)
19 | }
20 | }
21 |
22 | private def rowFieldToString(fieldValue: Any): String = s"$fieldValue"
23 |
24 | private[mrpowers] def showProductDiff[T: ClassTag](
25 | header: (String, String),
26 | actual: Seq[T],
27 | expected: Seq[T],
28 | truncate: Int = 20,
29 | minColWidth: Int = 3
30 | ): String = {
31 |
32 | val runTimeClass = implicitly[ClassTag[T]].runtimeClass
33 | val (className, lBracket, rBracket) = if (runTimeClass == classOf[Row]) ("", "[", "]") else (runTimeClass.getSimpleName, "(", ")")
34 | val prodToString: Seq[Any] => String = s => s.mkString(s"$className$lBracket", ",", rBracket)
35 | val emptyProd = "MISSING"
36 |
37 | val sb = new StringBuilder
38 |
39 | val fullJoin = actual.zipAll(expected, null, null)
40 |
41 | val diff = fullJoin.map { case (actualRow, expectedRow) =>
42 | if (actualRow == expectedRow) {
43 | List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString))
44 | } else {
45 | val actualSeq = productOrRowToSeq(actualRow)
46 | val expectedSeq = productOrRowToSeq(expectedRow)
47 | if (actualSeq.isEmpty)
48 | List(Red(emptyProd), Green(prodToString(expectedSeq)))
49 | else if (expectedSeq.isEmpty)
50 | List(Red(prodToString(actualSeq)), Green(emptyProd))
51 | else {
52 | val withEquals = actualSeq
53 | .zipAll(expectedSeq, "MISSING", "MISSING")
54 | .map { case (actualRowField, expectedRowField) =>
55 | (actualRowField, expectedRowField, actualRowField == expectedRowField)
56 | }
57 | val allFieldsAreNotEqual = !withEquals.exists(_._3)
58 | if (allFieldsAreNotEqual) {
59 | List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq)))
60 | } else {
61 | val coloredDiff = withEquals.map {
62 | case (actualRowField, expectedRowField, true) =>
63 | (DarkGray(rowFieldToString(actualRowField)), DarkGray(rowFieldToString(expectedRowField)))
64 | case (actualRowField, expectedRowField, false) =>
65 | (Red(rowFieldToString(actualRowField)), Green(rowFieldToString(expectedRowField)))
66 | }
67 | val start = DarkGray(s"$className$lBracket")
68 | val sep = DarkGray(",")
69 | val end = DarkGray(rBracket)
70 | List(
71 | coloredDiff.map(_._1).mkStr(start, sep, end),
72 | coloredDiff.map(_._2).mkStr(start, sep, end)
73 | )
74 | }
75 | }
76 | }
77 | }
78 | val headerSeq = List(header._1, header._2)
79 | val numCols = 2
80 |
81 | // Initialise the width of each column to a minimum value
82 | val colWidths = Array.fill(numCols)(minColWidth)
83 |
84 | // Compute the width of each column
85 | headerSeq.zipWithIndex.foreach({ case (cell, i) =>
86 | colWidths(i) = math.max(colWidths(i), cell.length)
87 | })
88 |
89 | diff.foreach { row =>
90 | row.zipWithIndex.foreach { case (cell, i) =>
91 | colWidths(i) = math.max(colWidths(i), cell.length)
92 | }
93 | }
94 |
95 | // Create SeparateLine
96 | val sep: String =
97 | colWidths
98 | .map("-" * _)
99 | .addString(sb, "+", "+", "+\n")
100 | .toString
101 |
102 | // column names
103 | headerSeq.zipWithIndex
104 | .map { case (cell, i) =>
105 | if (truncate > 0) {
106 | StringUtils.leftPad(cell, colWidths(i))
107 | } else {
108 | StringUtils.rightPad(cell, colWidths(i))
109 | }
110 | }
111 | .addString(sb, "|", "|", "|\n")
112 |
113 | sb.append(sep)
114 |
115 | diff.map { row =>
116 | row.zipWithIndex
117 | .map { case (cell, i) =>
118 | val padsLen = colWidths(i) - cell.length
119 | val pads = if (padsLen > 0) " " * padsLen else ""
120 | if (truncate > 0) {
121 | pads + cell.toString
122 | } else {
123 | cell.toString + pads
124 | }
125 |
126 | }
127 | .addString(sb, "|", "|", "|\n")
128 | }
129 |
130 | sb.append(sep)
131 |
132 | sb.toString
133 | }
134 | }
135 |
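showProductDiff lays actual and expected rows out as a two-column table, padding each column to its widest cell. A self-contained sketch of that alignment idea, with the coloring and truncation logic omitted:

    val header = List("Actual Content", "Expected Content")
    val rows = List(
      List("Person(bob,1)", "Person(bob,1)"),
      List("Person(li,2)", "MISSING")
    )
    // column width = widest cell in that column (same idea as colWidths above)
    val widths = (header +: rows).transpose.map(column => column.map(_.length).max)
    val sep    = widths.map("-" * _).mkString("+", "+", "+")
    def render(row: List[String]): String =
      row.zip(widths).map { case (cell, w) => cell.padTo(w, ' ') }.mkString("|", "|", "|")
    println((render(header) +: sep +: rows.map(render)).mkString("\n"))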
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/RDDComparer.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | import scala.reflect.ClassTag
6 |
7 | case class RDDContentMismatch(smth: String) extends Exception(smth)
8 |
9 | trait RDDComparer {
10 |
11 | def contentMismatchMessage[T: ClassTag](actualRDD: RDD[T], expectedRDD: RDD[T]): String = {
12 | s"""
13 | Actual RDD Content:
14 | ${actualRDD.take(5).mkString("\n")}
15 | Expected RDD Content:
16 | ${expectedRDD.take(5).mkString("\n")}
17 | """
18 | }
19 |
20 | def assertSmallRDDEquality[T: ClassTag](actualRDD: RDD[T], expectedRDD: RDD[T]): Unit = {
21 | if (!actualRDD.collect().sameElements(expectedRDD.collect())) {
22 | throw RDDContentMismatch(
23 | contentMismatchMessage(actualRDD, expectedRDD)
24 | )
25 | }
26 | }
27 |
28 | }
29 |
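A minimal usage sketch of assertSmallRDDEquality, assuming a local SparkSession (the setup is illustrative, not part of the file above):

    import org.apache.spark.sql.SparkSession

    object RDDComparerExample extends RDDComparer {
      def main(args: Array[String]): Unit = {
        val sc = SparkSession.builder().master("local").getOrCreate().sparkContext
        val actualRDD   = sc.parallelize(Seq("a", "b", "c"))
        val expectedRDD = sc.parallelize(Seq("a", "b", "c"))
        assertSmallRDDEquality(actualRDD, expectedRDD) // passes
        // A mismatch throws RDDContentMismatch, showing up to five elements per side:
        // assertSmallRDDEquality(actualRDD, sc.parallelize(Seq("a", "x", "c")))
      }
    }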
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/RddHelpers.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | object RddHelpers {
6 |
7 | /**
8 |    * Zips an RDD with precise indexes. This lets the Rows of two DataFrames be joined by position, so they can be compared in order even
9 |    * when they come from different sources.
10 | */
11 | def zipWithIndex[T](rdd: RDD[T]): RDD[(Long, T)] = {
12 | rdd.zipWithIndex().map { case (row, idx) =>
13 | (idx, row)
14 | }
15 | }
16 |
17 | }
18 |
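This index-keyed shape is what DatasetComparer's large comparison builds on: both RDDs are keyed by position and joined, so rows at the same index can be checked pairwise without collecting either side. A sketch of the pattern, assuming an in-scope SparkContext named sc:

    val left  = RddHelpers.zipWithIndex(sc.parallelize(Seq("a", "b", "c")))
    val right = RddHelpers.zipWithIndex(sc.parallelize(Seq("a", "x", "c")))
    val mismatches = left.join(right).filter { case (_, (l, r)) => l != r }
    // mismatches.collect() => Array((1,(b,x))): only index 1 differs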
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/RowComparer.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.commons.math3.util.Precision
4 | import org.apache.spark.sql.Row
5 |
6 | import scala.math.abs
7 |
8 | object RowComparer {
9 |
10 | /** Approximate equality, based on equals from [[Row]] */
11 | def areRowsEqual(r1: Row, r2: Row, tol: Double = 0): Boolean = {
12 | if (tol == 0) {
13 | return r1 == r2
14 | }
15 | if (r1.length != r2.length) {
16 | return false
17 | }
18 | for (i <- 0 until r1.length) {
19 | if (r1.isNullAt(i) != r2.isNullAt(i)) {
20 | return false
21 | }
22 | if (!r1.isNullAt(i)) {
23 | val o1 = r1.get(i)
24 | val o2 = r2.get(i)
25 | val valid = o1 match {
26 | case b1: Array[Byte] =>
27 | o2.isInstanceOf[Array[Byte]] && java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])
28 | case f1: Float if o2.isInstanceOf[Float] => Precision.equalsIncludingNaN(f1, o2.asInstanceOf[Float], tol)
29 | case d1: Double if o2.isInstanceOf[Double] => Precision.equalsIncludingNaN(d1, o2.asInstanceOf[Double], tol)
30 | case bd1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] =>
31 | val bigDecimalCompare = bd1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs().compareTo(new java.math.BigDecimal(tol))
32 |           bigDecimalCompare <= 0
33 | case f1: Number if o2.isInstanceOf[Number] =>
34 | val bd1 = new java.math.BigDecimal(f1.toString)
35 | val bd2 = new java.math.BigDecimal(o2.toString)
36 |           bd1.subtract(bd2).abs().compareTo(new java.math.BigDecimal(tol)) <= 0
37 |         case t1: java.sql.Timestamp if o2.isInstanceOf[java.sql.Timestamp] => abs(t1.getTime - o2.asInstanceOf[java.sql.Timestamp].getTime) <= tol
38 |         case t1: java.time.Instant if o2.isInstanceOf[java.time.Instant] => abs(t1.toEpochMilli - o2.asInstanceOf[java.time.Instant].toEpochMilli) <= tol
39 | case rr1: Row if o2.isInstanceOf[Row] => areRowsEqual(rr1, o2.asInstanceOf[Row], tol)
40 | case _ => o1 == o2
41 | }
42 | if (!valid) {
43 | return false
44 | }
45 | }
46 | }
47 | true
48 | }
49 | }
50 |
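A quick illustration of the tolerance semantics, runnable anywhere Spark's Row is on the classpath:

    import org.apache.spark.sql.Row

    RowComparer.areRowsEqual(Row(1.0, "a"), Row(1.0, "a"))        // true: tol == 0 falls back to exact equality
    RowComparer.areRowsEqual(Row(1.0, "a"), Row(1.01, "a"), 0.1)  // true: doubles compared within the tolerance
    RowComparer.areRowsEqual(Row(1.0, "a"), Row(1.25, "a"), 0.1)  // false: difference exceeds the tolerance
    RowComparer.areRowsEqual(Row(Row(1.0)), Row(Row(1.05)), 0.1)  // true: nested Rows are compared recursively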
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff
4 | import com.github.mrpowers.spark.fast.tests.SchemaDiffOutputFormat.SchemaDiffOutputFormat
5 | import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
6 | import org.apache.spark.sql.Dataset
7 | import org.apache.spark.sql.types._
8 |
9 | import scala.util.Try
10 |
11 | object SchemaComparer {
12 | private val INDENT_GAP = 5
13 | private val DESCRIPTION_GAP = 21
14 | private val TREE_GAP = 6
15 | private val CONTAIN_NULLS =
16 | "sparkFastTestContainsNull" // Distinguishable metadata key for spark fast tests to avoid potential ambiguity with other metadata keys
17 | case class DatasetSchemaMismatch(smth: String) extends Exception(smth)
18 | private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
19 | showProductDiff(
20 | ("Actual Schema", "Expected Schema"),
21 | actualSchema.fields,
22 | expectedSchema.fields,
23 | truncate = 200
24 | )
25 | }
26 |
27 | private def treeSchemaMismatchMessage[T](actualSchema: StructType, expectedSchema: StructType): String = {
28 |     def flattenStructType(s: StructType, indent: Int, additionalGap: Int = 0): (Seq[(Int, StructField)], Int) = s
29 | .foldLeft((Seq.empty[(Int, StructField)], Int.MinValue)) { case ((fieldPair, maxWidth), f) =>
30 | val gap = indent * INDENT_GAP + DESCRIPTION_GAP + f.name.length + f.dataType.typeName.length + f.nullable.toString.length + additionalGap
31 | val pair = fieldPair :+ (indent, f)
32 | val newMaxWidth = scala.math.max(maxWidth, gap)
33 | f.dataType match {
34 | case st: StructType =>
35 |             val (flattenPair, width) = flattenStructType(st, indent + 1)
36 | (pair ++ flattenPair, scala.math.max(newMaxWidth, width))
37 | case ArrayType(elementType, containsNull) =>
38 | val arrStruct = StructType(
39 | Seq(
40 | StructField(
41 | "element",
42 | elementType,
43 | metadata = new MetadataBuilder()
44 | .putBoolean(CONTAIN_NULLS, value = containsNull)
45 | .build()
46 | )
47 | )
48 | )
49 |             val (flattenPair, width) = flattenStructType(arrStruct, indent + 1, 5)
50 | (pair ++ flattenPair, scala.math.max(newMaxWidth, width))
51 | case _ => (pair, newMaxWidth)
52 | }
53 | }
54 |
55 |     val (treeFieldPair1, tree1MaxWidth) = flattenStructType(actualSchema, 0)
56 |     val (treeFieldPair2, _) = flattenStructType(expectedSchema, 0)
57 | val (treePair, maxWidth) = treeFieldPair1
58 | .zipAll(treeFieldPair2, (0, null), (0, null))
59 | .foldLeft((Seq.empty[(String, String)], 0)) { case ((acc, maxWidth), ((indent1, field1), (indent2, field2))) =>
60 | val (prefix1, prefix2) = getIndentPair(indent1, indent2)
61 | val (name1, name2) = getNamePair(field1, field2)
62 | val (dtype1, dtype2) = getDataTypePair(field1, field2)
63 | val (nullable1, nullable2) = getNullablePair(field1, field2)
64 | val (containNulls1, containNulls2) = getContainNullsPair(field1, field2)
65 | val structString1 = formatField(field1, prefix1, name1, dtype1, nullable1, containNulls1)
66 | val structString2 = formatField(field2, prefix2, name2, dtype2, nullable2, containNulls2)
67 |
68 | (acc :+ (structString1, structString2), math.max(maxWidth, structString1.length))
69 | }
70 |
71 | val schemaGap = maxWidth + TREE_GAP
72 | val headerGap = tree1MaxWidth + TREE_GAP
73 | treePair
74 | .foldLeft(new StringBuilder("\nActual Schema".padTo(headerGap, ' ') + "Expected Schema\n")) { case (sb, (s1, s2)) =>
75 | val gap = if (s1.isEmpty) headerGap else schemaGap
76 | val s = if (s2.isEmpty) s1 else s1.padTo(gap, ' ')
77 | sb.append(s + s2 + "\n")
78 | }
79 | .toString()
80 | }
81 |
82 | private def getDescriptionPair(nullable: String, containNulls: String): String =
83 | if (containNulls.isEmpty) s"(nullable = $nullable)" else s"(containsNull = $containNulls)"
84 |
85 | private def formatField(field: StructField, prefix: String, name: String, dtype: String, nullable: String, containNulls: String): String = {
86 | if (field == null) {
87 | ""
88 | } else {
89 | val description = getDescriptionPair(nullable, containNulls)
90 | s"$prefix $name : $dtype $description"
91 | }
92 | }
93 |
94 | private def getContainNullsPair(field1: StructField, field2: StructField): (String, String) = {
95 | (field1, field2) match {
96 | case (f, null) =>
97 | val containNulls = Try(f.metadata.getBoolean(CONTAIN_NULLS).toString).getOrElse("")
98 | Red(containNulls).toString -> ""
99 | case (null, f) =>
100 | val containNulls = Try(f.metadata.getBoolean(CONTAIN_NULLS).toString).getOrElse("")
101 | "" -> Green(containNulls).toString
102 | case (StructField(_, _, _, m1), StructField(_, _, _, m2)) =>
103 | val isArrayElement1 = m1.contains(CONTAIN_NULLS)
104 | val isArrayElement2 = m2.contains(CONTAIN_NULLS)
105 | if (isArrayElement1 && isArrayElement2) {
106 | val containNulls1 = m1.getBoolean(CONTAIN_NULLS)
107 | val containNulls2 = m2.getBoolean(CONTAIN_NULLS)
108 | val (cn1, cn2) = if (containNulls1 != containNulls2) {
109 | (Red(containNulls1.toString), Green(containNulls2.toString))
110 | } else {
111 | (DarkGray(containNulls1.toString), DarkGray(containNulls2.toString))
112 | }
113 | (cn1.toString, cn2.toString)
114 | } else if (isArrayElement1) {
115 | (DarkGray(m1.getBoolean(CONTAIN_NULLS).toString).toString, "")
116 | } else if (isArrayElement2) {
117 | ("", DarkGray(m2.getBoolean(CONTAIN_NULLS).toString).toString)
118 | } else {
119 | ("", "")
120 | }
121 | }
122 | }
123 |
124 | private def getIndentPair(indent1: Int, indent2: Int): (String, String) = {
125 | def depthToIndentStr(depth: Int): String = Range(0, depth).map(_ => "| ").mkString + "|--"
126 | val prefix1 = depthToIndentStr(indent1)
127 | val prefix2 = depthToIndentStr(indent2)
128 | val (p1, p2) = if (indent1 != indent2) {
129 | (Red(prefix1), Green(prefix2))
130 | } else {
131 | (DarkGray(prefix1), DarkGray(prefix2))
132 | }
133 | (p1.toString, p2.toString)
134 | }
135 |
136 | private def getNamePair(field1: StructField, field2: StructField): (String, String) = (field1, field2) match {
137 | case (_: StructField, null) => (Red(field1.name).toString, "")
138 | case (null, _: StructField) => ("", Green(field2.name).toString)
139 | case (f1, f2) if f1.name == f2.name => (DarkGray(field1.name).toString, DarkGray(field2.name).toString)
140 | case (f1, f2) if f1.name != f2.name => (Red(field1.name).toString, Green(field2.name).toString)
141 | }
142 |
143 | private def getDataTypePair(field1: StructField, field2: StructField): (String, String) = {
144 | (field1, field2) match {
145 | case (f: StructField, null) => (Red(f.dataType.typeName).toString, "")
146 | case (null, f: StructField) => ("", Green(f.dataType.typeName).toString)
147 | case (f1, f2) if f1.dataType == f2.dataType => (DarkGray(f1.dataType.typeName).toString, DarkGray(f2.dataType.typeName).toString)
148 | case (f1, f2) if f1.dataType != f2.dataType => (Red(f1.dataType.typeName).toString, Green(f2.dataType.typeName).toString)
149 | }
150 | }
151 |
152 | private def getNullablePair(field1: StructField, field2: StructField): (String, String) = {
153 | (field1, field2) match {
154 | case (f: StructField, null) => (Red(f.nullable.toString).toString, "")
155 | case (null, f: StructField) => ("", Green(f.nullable.toString).toString)
156 | case (f1, f2) if f1.nullable == f2.nullable => (DarkGray(f1.nullable.toString).toString, DarkGray(f2.nullable.toString).toString)
157 | case (f1, f2) if f1.nullable != f2.nullable => (Red(f1.nullable.toString).toString, Green(f2.nullable.toString).toString)
158 | }
159 | }
160 |
161 | def assertDatasetSchemaEqual[T](
162 | actualDS: Dataset[T],
163 | expectedDS: Dataset[T],
164 | ignoreNullable: Boolean = false,
165 | ignoreColumnNames: Boolean = false,
166 | ignoreColumnOrder: Boolean = true,
167 | ignoreMetadata: Boolean = true,
168 | outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table
169 | ): Unit = {
170 | assertSchemaEqual(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata, outputFormat)
171 | }
172 |
173 | def assertSchemaEqual(
174 | actualSchema: StructType,
175 | expectedSchema: StructType,
176 | ignoreNullable: Boolean = false,
177 | ignoreColumnNames: Boolean = false,
178 | ignoreColumnOrder: Boolean = true,
179 | ignoreMetadata: Boolean = true,
180 | outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table
181 | ): Unit = {
182 | require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
183 | if (!SchemaComparer.equals(actualSchema, expectedSchema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)) {
184 | val diffString = outputFormat match {
185 | case SchemaDiffOutputFormat.Tree => treeSchemaMismatchMessage(actualSchema, expectedSchema)
186 | case SchemaDiffOutputFormat.Table => betterSchemaMismatchMessage(actualSchema, expectedSchema)
187 | }
188 |
189 | throw DatasetSchemaMismatch(s"Diffs\n$diffString")
190 | }
191 | }
192 |
193 | def equals(
194 | s1: StructType,
195 | s2: StructType,
196 | ignoreNullable: Boolean = false,
197 | ignoreColumnNames: Boolean = false,
198 | ignoreColumnOrder: Boolean = true,
199 | ignoreMetadata: Boolean = true
200 | ): Boolean = {
201 |     if (s1.length != s2.length) {
202 |       false
203 |     } else {
204 |       // The schemas have the same number of fields.
205 |       // When column order should be ignored, sort both sides by field name
206 |       // before pairing the fields up, so that columns are matched by name
207 |       // rather than by position in the per-field comparison below.
208 |       val zipStruct = if (ignoreColumnOrder) s1.sortBy(_.name) zip s2.sortBy(_.name) else s1 zip s2
209 |       zipStruct.forall { case (f1, f2) =>
210 |         (f1.nullable == f2.nullable || ignoreNullable) &&
211 |         (f1.name == f2.name || ignoreColumnNames) &&
212 |         (f1.metadata == f2.metadata || ignoreMetadata) &&
213 |         equals(f1.dataType, f2.dataType, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
214 |       }
215 |     }
216 |   }
217 |
218 | def equals(
219 | dt1: DataType,
220 | dt2: DataType,
221 | ignoreNullable: Boolean,
222 | ignoreColumnNames: Boolean,
223 | ignoreColumnOrder: Boolean,
224 | ignoreMetadata: Boolean
225 | ): Boolean = {
226 | (dt1, dt2) match {
227 | case (st1: StructType, st2: StructType) =>
228 |         equals(st1, st2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
229 | case (ArrayType(vdt1, _), ArrayType(vdt2, _)) =>
230 | equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
231 | case (MapType(kdt1, vdt1, _), MapType(kdt2, vdt2, _)) =>
232 | equals(kdt1, kdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) &&
233 | equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
234 | case _ => dt1 == dt2
235 | }
236 | }
237 | }
238 |
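A minimal usage sketch of the public entry points (the schemas are illustrative):

    import org.apache.spark.sql.types._

    val actual   = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
    val expected = StructType(Seq(StructField("a", IntegerType), StructField("b", LongType)))

    SchemaComparer.equals(actual, expected) // false: field b differs in type

    // Throws DatasetSchemaMismatch; Tree renders the diff as an indented schema tree
    // instead of the default two-column table:
    SchemaComparer.assertSchemaEqual(actual, expected, outputFormat = SchemaDiffOutputFormat.Tree)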
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaDiffOutputFormat.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | object SchemaDiffOutputFormat extends Enumeration {
4 | type SchemaDiffOutputFormat = Value
5 |
6 | val Tree, Table = Value
7 | }
8 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SeqLikesExtensions.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.Row
4 |
5 | import scala.util.Try
6 |
7 | object SeqLikesExtensions {
8 | implicit class SeqExtensions[T](val seq1: Seq[T]) extends AnyVal {
9 | def approximateSameElements(seq2: Seq[T], equals: (T, T) => Boolean): Boolean = (seq1, seq2) match {
10 | case (i1: IndexedSeq[_], i2: IndexedSeq[_]) =>
11 | val length = i1.length
12 | var equal = length == i2.length
13 | if (equal) {
14 | var index = 0
15 | val maxApplyCompare = {
16 | val preferredLength =
17 | Try(System.getProperty("scala.collection.immutable.IndexedSeq.defaultApplyPreferredMaxLength", "64").toInt).getOrElse(64)
18 | if (length > (preferredLength.toLong << 1)) preferredLength else length
19 | }
20 | while (index < maxApplyCompare && equal) {
21 | equal = equals(i1(index), i2(index))
22 | index += 1
23 | }
24 | if ((index < length) && equal) {
25 | val thisIt = i1.iterator.drop(index)
26 | val thatIt = i2.iterator.drop(index)
27 | while (equal && thisIt.hasNext) {
28 | equal = equals(thisIt.next(), thatIt.next())
29 | }
30 | }
31 | }
32 | equal
33 | case _ =>
34 | val thisKnownSize = getKnownSize(seq1)
35 | val knownSizeDifference = thisKnownSize != -1 && {
36 | val thatKnownSize = getKnownSize(seq2)
37 | thatKnownSize != -1 && thisKnownSize != thatKnownSize
38 | }
39 | if (knownSizeDifference) {
40 | return false
41 | }
42 | val these = seq1.iterator
43 | val those = seq2.iterator
44 | while (these.hasNext && those.hasNext)
45 | if (!equals(these.next(), those.next()))
46 | return false
47 | these.hasNext == those.hasNext
48 | }
49 |
50 |     // Scala 2.13 optimization: compare the number of elements when it can be computed cheaply
51 | private def getKnownSize(s: Seq[T]): Int = Try(s.getClass.getMethod("knownSize").invoke(s).asInstanceOf[Int]).getOrElse(s.length)
52 |
53 | private[mrpowers] def asRows: Seq[Row] = seq1.map {
54 | case x: Row => x
55 | case y: Product => Row(y.productIterator.toSeq: _*)
56 | case a => Row(a)
57 | }
58 | }
59 | }
60 |
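approximateSameElements powers the collect-based comparison in DatasetComparer, but it works on any Seq. A small sketch with a tolerant comparator:

    import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions._

    val withinTol = (a: Double, b: Double) => math.abs(a - b) <= 0.01
    Seq(1.0, 2.0).approximateSameElements(Seq(1.001, 2.0), withinTol) // true
    Seq(1.0, 2.0).approximateSameElements(Seq(1.5, 2.0), withinTol)   // false: first elements differ by more than 0.01
    Seq(1.0).approximateSameElements(Seq(1.0, 2.0), withinTol)        // false: different lengths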
--------------------------------------------------------------------------------
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/FansiExtensions.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests.ufansi
2 | object FansiExtensions {
3 | private[mrpowers] implicit class StrOps(c: Seq[Str]) {
4 | def mkStr(start: Str, sep: Str, end: Str): Str =
5 | start ++ c.reduce(_ ++ sep ++ _) ++ end
6 | }
7 | }
8 |
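Because StrOps is private[mrpowers], mkStr is only visible inside this project. A sketch of how ProductUtil uses it to join colored cells, assuming the vendored ufansi API matches upstream fansi:

    import com.github.mrpowers.spark.fast.tests.ufansi._
    import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions._

    val cells: Seq[Str] = Seq(Color.Red("1"), Color.Green("2"))
    // joins the cells with separators while preserving each cell's color codes
    cells.mkStr(Str("("), Str(","), Str(")")) // renders as "(1,2)"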
--------------------------------------------------------------------------------
/core/src/test/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | # Set everything to be logged to the console
2 | log4j.rootCategory=ERROR, console
3 | log4j.appender.console=org.apache.log4j.ConsoleAppender
4 | log4j.appender.console.target=System.err
5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
7 |
8 | # Settings to quiet third party logs that are too verbose
9 | log4j.logger.org.eclipse.jetty=WARN
10 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN
12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/ArrayUtilTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import java.sql.Date
4 | import java.time.LocalDate
5 | import org.scalatest.freespec.AnyFreeSpec
6 |
7 | class ArrayUtilTest extends AnyFreeSpec {
8 |
9 | "blah" in {
10 | val arr: Array[(Any, Any)] = Array(("hi", "there"), ("fun", "train"))
11 | val res = ArrayUtil.weirdTypesToStrings(arr, 10)
12 | assert(res sameElements Array(List("hi", "there"), List("fun", "train")))
13 | }
14 |
15 | "showTwoColumnString" in {
16 | val arr: Array[(Any, Any)] = Array(("word1", "word2"), ("hi", "there"), ("fun", "train"))
17 | println(ArrayUtil.showTwoColumnString(arr, 10))
18 | }
19 |
20 | "showTwoColumnDate" in {
21 | val now = LocalDate.now()
22 | val arr: Array[(Any, Any)] =
23 | Array(("word1", "word2"), (Date.valueOf(now), Date.valueOf(now)), (Date.valueOf(now.plusDays(-1)), Date.valueOf(now)))
24 | println(ArrayUtil.showTwoColumnString(arr, 10))
25 | }
26 |
27 | "dumbshowTwoColumnString" in {
28 | val arr: Array[(Any, Any)] = Array(("word1", "word2"), ("hi", "there"), ("fun", "train"))
29 | val rowEqual = Array(true, false)
30 | println(ArrayUtil.showTwoColumnStringColorCustomizable(arr, Some(rowEqual)))
31 | }
32 |
33 | }
34 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/ColumnComparerTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.Row
4 | import org.apache.spark.sql.functions._
5 | import org.apache.spark.sql.types._
6 | import java.sql.Date
7 | import java.sql.Timestamp
8 |
9 | import org.scalatest.freespec.AnyFreeSpec
10 |
11 | class ColumnComparerTest extends AnyFreeSpec with ColumnComparer with SparkSessionTestWrapper {
12 |
13 | "assertColumnEquality" - {
14 |
15 | "throws an easily readable error message" in {
16 | val sourceData = Seq(
17 | Row("phil", "phil"),
18 | Row("rashid", "rashid"),
19 | Row("matthew", "mateo"),
20 | Row("sami", "sami"),
21 | Row("this is something that is super crazy long", "sami"),
22 | Row("li", "feng"),
23 | Row(null, null)
24 | )
25 | val sourceSchema = List(
26 | StructField("name", StringType, true),
27 | StructField("expected_name", StringType, true)
28 | )
29 | val sourceDF = spark.createDataFrame(
30 | spark.sparkContext.parallelize(sourceData),
31 | StructType(sourceSchema)
32 | )
33 | val e = intercept[ColumnMismatch] {
34 | assertColumnEquality(sourceDF, "name", "expected_name")
35 | }
36 | val e2 = intercept[ColumnMismatch] {
37 | assertColEquality(sourceDF, "name", "expected_name")
38 | }
39 | }
40 |
41 | "doesn't thrown an error when the columns are equal" in {
42 | val sourceData = Seq(
43 | Row(1, 1),
44 | Row(5, 5),
45 | Row(null, null)
46 | )
47 | val sourceSchema = List(
48 | StructField("num", IntegerType, true),
49 | StructField("expected_num", IntegerType, true)
50 | )
51 | val sourceDF = spark.createDataFrame(
52 | spark.sparkContext.parallelize(sourceData),
53 | StructType(sourceSchema)
54 | )
55 | assertColumnEquality(sourceDF, "num", "expected_num")
56 | assertColEquality(sourceDF, "num", "expected_num")
57 | }
58 |
59 | "throws an error if the columns are not equal" in {
60 | val sourceData = Seq(
61 | Row(1, 3),
62 | Row(5, 5),
63 | Row(null, null)
64 | )
65 | val sourceSchema = List(
66 | StructField("num", IntegerType, true),
67 | StructField("expected_num", IntegerType, true)
68 | )
69 | val sourceDF = spark.createDataFrame(
70 | spark.sparkContext.parallelize(sourceData),
71 | StructType(sourceSchema)
72 | )
73 | val e = intercept[ColumnMismatch] {
74 | assertColumnEquality(sourceDF, "num", "expected_num")
75 | }
76 | val e2 = intercept[ColumnMismatch] {
77 | assertColEquality(sourceDF, "num", "expected_num")
78 | }
79 | }
80 |
81 | "throws an error if the columns are different types" in {
82 | val sourceData = Seq(
83 | Row(1, "hi"),
84 | Row(5, "bye"),
85 | Row(null, null)
86 | )
87 | val sourceSchema = List(
88 | StructField("num", IntegerType, true),
89 | StructField("word", StringType, true)
90 | )
91 | val sourceDF = spark.createDataFrame(
92 | spark.sparkContext.parallelize(sourceData),
93 | StructType(sourceSchema)
94 | )
95 | val e = intercept[ColumnMismatch] {
96 | assertColumnEquality(sourceDF, "num", "word")
97 | }
98 | val e2 = intercept[ColumnMismatch] {
99 | assertColEquality(sourceDF, "num", "word")
100 | }
101 | }
102 |
103 | "works properly, even when null is compared with a value" in {
104 | val sourceData = Seq(
105 | Row(1, 1),
106 | Row(null, 5),
107 | Row(null, null)
108 | )
109 | val sourceSchema = List(
110 | StructField("num", IntegerType, true),
111 | StructField("expected_num", IntegerType, true)
112 | )
113 | val sourceDF = spark.createDataFrame(
114 | spark.sparkContext.parallelize(sourceData),
115 | StructType(sourceSchema)
116 | )
117 | val e = intercept[ColumnMismatch] {
118 | assertColumnEquality(sourceDF, "num", "expected_num")
119 | }
120 | val e2 = intercept[ColumnMismatch] {
121 | assertColEquality(sourceDF, "num", "expected_num")
122 | }
123 | }
124 |
125 | "works for ArrayType columns" in {
126 | val sourceData = Seq(
127 | Row(Array("a"), Array("a")),
128 | Row(Array("a", "b"), Array("a", "b")),
129 | Row(Array(), Array()),
130 | Row(null, null)
131 | )
132 | val sourceSchema = List(
133 | StructField("l1", ArrayType(StringType, true), true),
134 | StructField("l2", ArrayType(StringType, true), true)
135 | )
136 | val sourceDF = spark.createDataFrame(
137 | spark.sparkContext.parallelize(sourceData),
138 | StructType(sourceSchema)
139 | )
140 | assertColumnEquality(sourceDF, "l1", "l2")
141 | assertColEquality(sourceDF, "l1", "l2")
142 | }
143 |
144 | "works for nested arrays" in {
145 | val sourceData = Seq(
146 | Row(Array(Array("a"), Array("a")), Array(Array("a"), Array("a"))),
147 | Row(Array(Array("a", "b"), Array("a", "b")), Array(Array("a", "b"), Array("a", "b"))),
148 | Row(Array(Array(), Array()), Array(Array(), Array())),
149 | Row(null, null)
150 | )
151 | val sourceSchema = List(
152 | StructField("l1", ArrayType(ArrayType(StringType, true)), true),
153 | StructField("l2", ArrayType(ArrayType(StringType, true)), true)
154 | )
155 | val sourceDF = spark.createDataFrame(
156 | spark.sparkContext.parallelize(sourceData),
157 | StructType(sourceSchema)
158 | )
159 | assertColumnEquality(sourceDF, "l1", "l2")
160 | assertColEquality(sourceDF, "l1", "l2")
161 | }
162 |
163 | "works for computed ArrayType columns" in {
164 | val sourceData = Seq(
165 | Row("i like blue and red", Array("blue", "red")),
166 | Row("you pink and blue", Array("blue", "pink")),
167 | Row("i like fun", Array(""))
168 | )
169 | val sourceSchema = List(
170 | StructField("words", StringType, true),
171 | StructField("expected_colors", ArrayType(StringType, true), true)
172 | )
173 | val sourceDF = spark.createDataFrame(
174 | spark.sparkContext.parallelize(sourceData),
175 | StructType(sourceSchema)
176 | )
177 | val actualDF = sourceDF.withColumn(
178 | "colors",
179 | coalesce(
180 | split(
181 | concat_ws(
182 | ",",
183 | when(col("words").contains("blue"), "blue"),
184 | when(col("words").contains("red"), "red"),
185 | when(col("words").contains("pink"), "pink"),
186 | when(col("words").contains("cyan"), "cyan")
187 | ),
188 | ","
189 | ),
190 | typedLit(Array())
191 | )
192 | )
193 | assertColumnEquality(actualDF, "colors", "expected_colors")
194 | assertColEquality(actualDF, "colors", "expected_colors")
195 | }
196 |
197 | "works for MapType columns" in {
198 | val data = Seq(
199 | Row(Map("good_song" -> "santeria", "bad_song" -> "doesn't exist"), Map("good_song" -> "santeria", "bad_song" -> "doesn't exist"))
200 | )
201 | val schema = List(
202 | StructField("m1", MapType(StringType, StringType, true), true),
203 | StructField("m2", MapType(StringType, StringType, true), true)
204 | )
205 | val df = spark.createDataFrame(
206 | spark.sparkContext.parallelize(data),
207 | StructType(schema)
208 | )
209 | assertColumnEquality(df, "m1", "m2")
210 | }
211 |
212 | "throws error when MapType columns aren't equal" in {
213 | val data = Seq(
214 | Row(Map("good_song" -> "santeria", "bad_song" -> "doesn't exist"), Map("good_song" -> "what i got", "bad_song" -> "doesn't exist"))
215 | )
216 | val schema = List(
217 | StructField("m1", MapType(StringType, StringType, true), true),
218 | StructField("m2", MapType(StringType, StringType, true), true)
219 | )
220 | val df = spark.createDataFrame(
221 | spark.sparkContext.parallelize(data),
222 | StructType(schema)
223 | )
224 | val e = intercept[ColumnMismatch] {
225 | assertColumnEquality(df, "m1", "m2")
226 | }
227 | }
228 |
229 | "works for MapType columns with deep comparisons" in {
230 | val data = Seq(
231 | Row(Map("good_song" -> Array(1, 2, 3, 4)), Map("good_song" -> Array(1, 2, 3, 4)))
232 | )
233 | val schema = List(
234 | StructField("m1", MapType(StringType, ArrayType(IntegerType, true), true), true),
235 | StructField("m2", MapType(StringType, ArrayType(IntegerType, true), true), true)
236 | )
237 | val df = spark.createDataFrame(
238 | spark.sparkContext.parallelize(data),
239 | StructType(schema)
240 | )
241 | assertColumnEquality(df, "m1", "m2")
242 | }
243 |
244 | "throws an errors for MapType columns with deep comparisons that aren't equal" in {
245 | val data = Seq(
246 | Row(Map("good_song" -> Array(1, 2, 3, 4)), Map("good_song" -> Array(1, 2, 3, 8)))
247 | )
248 | val schema = List(
249 | StructField("m1", MapType(StringType, ArrayType(IntegerType, true), true), true),
250 | StructField("m2", MapType(StringType, ArrayType(IntegerType, true), true), true)
251 | )
252 | val df = spark.createDataFrame(
253 | spark.sparkContext.parallelize(data),
254 | StructType(schema)
255 | )
256 | val e = intercept[ColumnMismatch] {
257 | assertColumnEquality(df, "m1", "m2")
258 | }
259 | }
260 |
261 | "works when DateType columns are equal" in {
262 | val sourceData = Seq(
263 | Row(Date.valueOf("2016-08-09"), Date.valueOf("2016-08-09")),
264 | Row(Date.valueOf("2019-01-01"), Date.valueOf("2019-01-01")),
265 | Row(null, null)
266 | )
267 | val sourceSchema = List(
268 | StructField("d1", DateType, true),
269 | StructField("d2", DateType, true)
270 | )
271 | val sourceDF = spark.createDataFrame(
272 | spark.sparkContext.parallelize(sourceData),
273 | StructType(sourceSchema)
274 | )
275 | assertColumnEquality(sourceDF, "d1", "d2")
276 | }
277 |
278 | "throws an error when DateType columns are not equal" in {
279 | val sourceData = Seq(
280 | Row(Date.valueOf("2010-07-07"), Date.valueOf("2016-08-09")),
281 | Row(Date.valueOf("2019-01-01"), Date.valueOf("2019-01-01")),
282 | Row(null, null)
283 | )
284 | val sourceSchema = List(
285 | StructField("d1", DateType, true),
286 | StructField("d2", DateType, true)
287 | )
288 | val sourceDF = spark.createDataFrame(
289 | spark.sparkContext.parallelize(sourceData),
290 | StructType(sourceSchema)
291 | )
292 | val e = intercept[ColumnMismatch] {
293 | assertColumnEquality(sourceDF, "d1", "d2")
294 | }
295 | }
296 |
297 | "works when TimestampType columns are equal" in {
298 | val sourceData = Seq(
299 | Row(Timestamp.valueOf("2016-08-09 09:57:00"), Timestamp.valueOf("2016-08-09 09:57:00")),
300 | Row(Timestamp.valueOf("2016-04-10 09:57:00"), Timestamp.valueOf("2016-04-10 09:57:00")),
301 | Row(null, null)
302 | )
303 | val sourceSchema = List(
304 | StructField("t1", TimestampType, true),
305 | StructField("t2", TimestampType, true)
306 | )
307 | val sourceDF = spark.createDataFrame(
308 | spark.sparkContext.parallelize(sourceData),
309 | StructType(sourceSchema)
310 | )
311 | assertColumnEquality(sourceDF, "t1", "t2")
312 | }
313 |
314 | "throws an error when TimestampType columns are not equal" in {
315 | val sourceData = Seq(
316 | Row(Timestamp.valueOf("2010-08-09 09:57:00"), Timestamp.valueOf("2016-08-09 09:57:00")),
317 | Row(Timestamp.valueOf("2016-04-10 10:01:00"), Timestamp.valueOf("2016-04-10 09:57:00")),
318 | Row(null, null)
319 | )
320 | val sourceSchema = List(
321 | StructField("t1", TimestampType, true),
322 | StructField("t2", TimestampType, true)
323 | )
324 | val sourceDF = spark.createDataFrame(
325 | spark.sparkContext.parallelize(sourceData),
326 | StructType(sourceSchema)
327 | )
328 | val e = intercept[ColumnMismatch] {
329 | assertColumnEquality(sourceDF, "t1", "t2")
330 | }
331 | }
332 |
333 | "works when ByteType columns are equal" in {
334 | val sourceData = Seq(
335 | Row(10.toByte, 10.toByte),
336 | Row(33.toByte, 33.toByte),
337 | Row(null, null)
338 | )
339 | val sourceSchema = List(
340 | StructField("b1", ByteType, true),
341 | StructField("b2", ByteType, true)
342 | )
343 | val sourceDF = spark.createDataFrame(
344 | spark.sparkContext.parallelize(sourceData),
345 | StructType(sourceSchema)
346 | )
347 | assertColumnEquality(sourceDF, "b1", "b2")
348 | }
349 |
350 | "throws an error when ByteType columns are not equal" in {
351 | val sourceData = Seq(
352 | Row(8.toByte, 10.toByte),
353 | Row(33.toByte, 33.toByte),
354 | Row(null, null)
355 | )
356 | val sourceSchema = List(
357 | StructField("b1", ByteType, true),
358 | StructField("b2", ByteType, true)
359 | )
360 | val sourceDF = spark.createDataFrame(
361 | spark.sparkContext.parallelize(sourceData),
362 | StructType(sourceSchema)
363 | )
364 | val e = intercept[ColumnMismatch] {
365 | assertColumnEquality(sourceDF, "b1", "b2")
366 | }
367 | }
368 |
369 | }
370 |
371 | "assertBinaryTypeColumnEquality" - {
372 |
373 | "works when BinaryType columns are equal" in {
374 | val sourceData = Seq(
375 | Row(Array(10.toByte, 15.toByte), Array(10.toByte, 15.toByte)),
376 | Row(Array(4.toByte, 33.toByte), Array(4.toByte, 33.toByte)),
377 | Row(null, null)
378 | )
379 | val sourceSchema = List(
380 | StructField("b1", BinaryType, true),
381 | StructField("b2", BinaryType, true)
382 | )
383 | val sourceDF = spark.createDataFrame(
384 | spark.sparkContext.parallelize(sourceData),
385 | StructType(sourceSchema)
386 | )
387 | assertBinaryTypeColumnEquality(sourceDF, "b1", "b2")
388 | }
389 |
390 | "throws an error when BinaryType columns are not equal" in {
391 | val sourceData = Seq(
392 | Row(Array(10.toByte, 15.toByte), Array(10.toByte, 15.toByte)),
393 | Row(Array(4.toByte, 33.toByte), Array(4.toByte, 33.toByte)),
394 | Row(null, null),
395 | Row(Array(7.toByte, 33.toByte), Array(4.toByte, 33.toByte)),
396 | Row(Array(4.toByte, 33.toByte), null),
397 | Row(null, Array(4.toByte, 33.toByte))
398 | )
399 | val sourceSchema = List(
400 | StructField("b1", BinaryType, true),
401 | StructField("b2", BinaryType, true)
402 | )
403 | val sourceDF = spark.createDataFrame(
404 | spark.sparkContext.parallelize(sourceData),
405 | StructType(sourceSchema)
406 | )
407 | val e = intercept[ColumnMismatch] {
408 | assertBinaryTypeColumnEquality(sourceDF, "b1", "b2")
409 | }
410 | }
411 |
412 | }
413 |
414 | "assertDoubleTypeColumnEquality" - {
415 |
416 | "doesn't throw an error when two DoubleType columns are equal" in {
417 | val sourceData = Seq(
418 | Row(1.3, 1.3),
419 | Row(5.01, 5.0101),
420 | Row(null, null)
421 | )
422 | val sourceSchema = List(
423 | StructField("d1", DoubleType, true),
424 | StructField("d2", DoubleType, true)
425 | )
426 | val df = spark.createDataFrame(
427 | spark.sparkContext.parallelize(sourceData),
428 | StructType(sourceSchema)
429 | )
430 | assertDoubleTypeColumnEquality(df, "d1", "d2", 0.01)
431 | }
432 |
433 | "throws an error when two DoubleType columns are not equal" in {
434 | val sourceData = Seq(
435 | Row(1.3, 1.8),
436 | Row(5.01, 5.0101),
437 | Row(null, 10.0),
438 | Row(3.4, null),
439 | Row(-1.1, -1.1),
440 | Row(null, null)
441 | )
442 | val sourceSchema = List(
443 | StructField("d1", DoubleType, true),
444 | StructField("d2", DoubleType, true)
445 | )
446 | val df = spark.createDataFrame(
447 | spark.sparkContext.parallelize(sourceData),
448 | StructType(sourceSchema)
449 | )
450 | val e = intercept[ColumnMismatch] {
451 | assertDoubleTypeColumnEquality(df, "d1", "d2", 0.01)
452 | }
453 | }
454 |
455 | }
456 |
457 | "assertFloatTypeColumnEquality" - {
458 |
459 | "doesn't throw an error when two FloatType columns are equal" in {
460 | val sourceData = Seq(
461 | Row(1.3f, 1.3f),
462 | Row(5.01f, 5.0101f),
463 | Row(null, null)
464 | )
465 | val sourceSchema = List(
466 | StructField("num1", FloatType, true),
467 | StructField("num2", FloatType, true)
468 | )
469 | val df = spark.createDataFrame(
470 | spark.sparkContext.parallelize(sourceData),
471 | StructType(sourceSchema)
472 | )
473 | assertFloatTypeColumnEquality(df, "num1", "num2", 0.01f)
474 | }
475 |
476 | "throws an error when two FloatType columns are not equal" in {
477 | val sourceData = Seq(
478 | Row(1.3f, 1.8f),
479 | Row(5.01f, 5.0101f),
480 | Row(null, 10.0f),
481 | Row(3.4f, null),
482 | Row(null, null)
483 | )
484 | val sourceSchema = List(
485 | StructField("d1", FloatType, true),
486 | StructField("d2", FloatType, true)
487 | )
488 | val df = spark.createDataFrame(
489 | spark.sparkContext.parallelize(sourceData),
490 | StructType(sourceSchema)
491 | )
492 | val e = intercept[ColumnMismatch] {
493 | assertFloatTypeColumnEquality(df, "d1", "d2", 0.01f)
494 | }
495 | }
496 |
497 | }
498 |
499 | "assertColEquality" - {
500 |
501 | "throws an easily readable error message" in {
502 | val sourceData = Seq(
503 | Row("phil", "phil"),
504 | Row("rashid", "rashid"),
505 | Row("matthew", "mateo"),
506 | Row("sami", "sami"),
507 | Row("this is something that is super crazy long", "sami"),
508 | Row("li", "feng"),
509 | Row(null, null)
510 | )
511 | val sourceSchema = List(
512 | StructField("name", StringType, true),
513 | StructField("expected_name", StringType, true)
514 | )
515 | val sourceDF = spark.createDataFrame(
516 | spark.sparkContext.parallelize(sourceData),
517 | StructType(sourceSchema)
518 | )
519 | val e = intercept[ColumnMismatch] {
520 | assertColEquality(sourceDF, "name", "expected_name")
521 | }
522 | }
523 |
524 | "doesn't thrown an error when the columns are equal" in {
525 | val sourceData = Seq(
526 | Row(1, 1),
527 | Row(5, 5),
528 | Row(null, null)
529 | )
530 | val sourceSchema = List(
531 | StructField("num", IntegerType, true),
532 | StructField("expected_num", IntegerType, true)
533 | )
534 | val sourceDF = spark.createDataFrame(
535 | spark.sparkContext.parallelize(sourceData),
536 | StructType(sourceSchema)
537 | )
538 | assertColEquality(sourceDF, "num", "expected_num")
539 | }
540 |
541 | "throws an error if the columns are not equal" in {
542 | val sourceData = Seq(
543 | Row(1, 3),
544 | Row(5, 5),
545 | Row(null, null)
546 | )
547 | val sourceSchema = List(
548 | StructField("num", IntegerType, true),
549 | StructField("expected_num", IntegerType, true)
550 | )
551 | val sourceDF = spark.createDataFrame(
552 | spark.sparkContext.parallelize(sourceData),
553 | StructType(sourceSchema)
554 | )
555 | val e = intercept[ColumnMismatch] {
556 | assertColEquality(sourceDF, "num", "expected_num")
557 | }
558 | }
559 |
560 | "throws an error if the columns are different types" in {
561 | val sourceData = Seq(
562 | Row(1, "hi"),
563 | Row(5, "bye"),
564 | Row(null, null)
565 | )
566 | val sourceSchema = List(
567 | StructField("num", IntegerType, true),
568 | StructField("word", StringType, true)
569 | )
570 | val sourceDF = spark.createDataFrame(
571 | spark.sparkContext.parallelize(sourceData),
572 | StructType(sourceSchema)
573 | )
574 | val e = intercept[ColumnMismatch] {
575 | assertColEquality(sourceDF, "num", "word")
576 | }
577 | }
578 |
579 | "works properly, even when null is compared with a value" in {
580 | val sourceData = Seq(
581 | Row(1, 1),
582 | Row(null, 5),
583 | Row(null, null)
584 | )
585 | val sourceSchema = List(
586 | StructField("num", IntegerType, true),
587 | StructField("expected_num", IntegerType, true)
588 | )
589 | val sourceDF = spark.createDataFrame(
590 | spark.sparkContext.parallelize(sourceData),
591 | StructType(sourceSchema)
592 | )
593 | val e = intercept[ColumnMismatch] {
594 | assertColumnEquality(sourceDF, "num", "expected_num")
595 | }
596 | }
597 |
598 | "works for ArrayType columns" in {
599 | val sourceData = Seq(
600 | Row(Array("a"), Array("a")),
601 | Row(Array("a", "b"), Array("a", "b")),
602 | Row(Array(), Array()),
603 | Row(null, null)
604 | )
605 | val sourceSchema = List(
606 | StructField("l1", ArrayType(StringType, true), true),
607 | StructField("l2", ArrayType(StringType, true), true)
608 | )
609 | val sourceDF = spark.createDataFrame(
610 | spark.sparkContext.parallelize(sourceData),
611 | StructType(sourceSchema)
612 | )
613 | assertColumnEquality(sourceDF, "l1", "l2")
614 | }
615 |
616 | "throws an error for unequal nested StructType columns with same schema" in {
617 | val sourceData = Seq(
618 | Row(Row("John", 30), Row("John", 31)),
619 | Row(Row("Jane", 25), Row("Jane", 25)),
620 | Row(Row("Jake", 40), Row("Jake", 40)),
621 | Row(null, null)
622 | )
623 | val nestedSchema = StructType(
624 | List(
625 | StructField("name", StringType, true),
626 | StructField("age", IntegerType, true)
627 | )
628 | )
629 | val sourceSchema = List(
630 | StructField("struct1", nestedSchema, true),
631 | StructField("struct2", nestedSchema, true)
632 | )
633 | val sourceDF = spark.createDataFrame(
634 | spark.sparkContext.parallelize(sourceData),
635 | StructType(sourceSchema)
636 | )
637 | intercept[ColumnMismatch] {
638 | assertColumnEquality(sourceDF, "struct1", "struct2")
639 | }
640 | }
641 |
642 | "throws an error for unequal nested StructType columns with different schema" in {
643 | val sourceData = Seq(
644 | Row(Row("John", 30), Row("John")),
645 | Row(Row("Jane", 25), Row("Jane")),
646 | Row(Row("Jake", 40), Row("Jake")),
647 | Row(null, null)
648 | )
649 | val nestedSchema1 = StructType(
650 | List(
651 | StructField("name", StringType, true),
652 | StructField("age", IntegerType, true)
653 | )
654 | )
655 |
656 | val nestedSchema2 = StructType(
657 | List(
658 | StructField("name", StringType, true)
659 | )
660 | )
661 | val sourceSchema = List(
662 | StructField("struct1", nestedSchema1, true),
663 | StructField("struct2", nestedSchema2, true)
664 | )
665 | val sourceDF = spark.createDataFrame(
666 | spark.sparkContext.parallelize(sourceData),
667 | StructType(sourceSchema)
668 | )
669 | intercept[ColumnMismatch] {
670 | assertColumnEquality(sourceDF, "struct1", "struct2")
671 | }
672 | }
673 |
674 | "work with StructType columns" in {
675 | val sourceData = Seq(
676 | Row(Row("John", 30), Row("John", 30)),
677 | Row(Row("Jane", 25), Row("Jane", 25)),
678 | Row(Row("Jake", 40), Row("Jake", 40)),
679 | Row(null, null)
680 | )
681 | val nestedSchema = StructType(
682 | List(
683 | StructField("name", StringType, true),
684 | StructField("age", IntegerType, true)
685 | )
686 | )
687 | val sourceSchema = List(
688 | StructField("struct1", nestedSchema, true),
689 | StructField("struct2", nestedSchema, true)
690 | )
691 | val sourceDF = spark.createDataFrame(
692 | spark.sparkContext.parallelize(sourceData),
693 | StructType(sourceSchema)
694 | )
695 |
696 | assertColumnEquality(sourceDF, "struct1", "struct2")
697 | }
698 | }
699 |
700 | }
701 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType}
4 | import SparkSessionExt._
5 | import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
6 | import org.apache.spark.sql.functions.col
7 | import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps
8 | import org.scalatest.freespec.AnyFreeSpec
9 |
10 | import java.time.Instant
11 |
12 | class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with SparkSessionTestWrapper {
13 |
14 | "prints a descriptive error message if it bugs out" in {
15 | val sourceDF = spark.createDF(
16 | List(
17 | ("bob", 1, "uk"),
18 | ("camila", 5, "peru")
19 | ),
20 | List(
21 | ("name", StringType, true),
22 | ("age", IntegerType, true),
23 | ("country", StringType, true)
24 | )
25 | )
26 |
27 | val expectedDF = spark.createDF(
28 | List(
29 | ("bob", 1, "france"),
30 | ("camila", 5, "peru")
31 | ),
32 | List(
33 | ("name", StringType, true),
34 | ("age", IntegerType, true),
35 | ("country", StringType, true)
36 | )
37 | )
38 |
39 | val e = intercept[DatasetContentMismatch] {
40 | assertSmallDataFrameEquality(sourceDF, expectedDF)
41 | }
42 | assert(e.getMessage.indexOf("bob") >= 0)
43 | assert(e.getMessage.indexOf("camila") >= 0)
44 | }
45 |
46 | "Correctly mark unequal elements" in {
47 | val sourceDF = spark.createDF(
48 | List(
49 | ("bob", 1, "uk"),
50 | ("camila", 5, "peru"),
51 | ("steve", 10, "aus")
52 | ),
53 | List(
54 | ("name", StringType, true),
55 | ("age", IntegerType, true),
56 | ("country", StringType, true)
57 | )
58 | )
59 |
60 | val expectedDF = spark.createDF(
61 | List(
62 | ("bob", 1, "france"),
63 | ("camila", 5, "peru"),
64 | ("mark", 11, "usa")
65 | ),
66 | List(
67 | ("name", StringType, true),
68 | ("age", IntegerType, true),
69 | ("country", StringType, true)
70 | )
71 | )
72 |
73 | val e = intercept[DatasetContentMismatch] {
74 | assertSmallDataFrameEquality(expectedDF, sourceDF)
75 | }
76 |
77 | e.assertColorDiff(Seq("france", "[mark,11,usa]"), Seq("uk", "[steve,10,aus]"))
78 | }
79 |
80 | "Can handle unequal Dataframe containing null" in {
81 | val sourceDF = spark.createDF(
82 | List(
83 | ("bob", 1, "uk"),
84 | (null, 5, "peru"),
85 | ("steve", 10, "aus")
86 | ),
87 | List(
88 | ("name", StringType, true),
89 | ("age", IntegerType, true),
90 | ("country", StringType, true)
91 | )
92 | )
93 |
94 | val expectedDF = spark.createDF(
95 | List(
96 | ("bob", 1, "uk"),
97 | (null, 5, "peru"),
98 | (null, 10, "aus")
99 | ),
100 | List(
101 | ("name", StringType, true),
102 | ("age", IntegerType, true),
103 | ("country", StringType, true)
104 | )
105 | )
106 |
107 | val e = intercept[DatasetContentMismatch] {
108 | assertSmallDataFrameEquality(expectedDF, sourceDF)
109 | }
110 |
111 | e.assertColorDiff(Seq("null"), Seq("steve"))
112 | }
113 |
114 | "works well for wide DataFrames" in {
115 | val sourceDF = spark.createDF(
116 | List(
117 | ("bobisanicepersonandwelikehimOK", 1, "uk"),
118 | ("camila", 5, "peru")
119 | ),
120 | List(
121 | ("name", StringType, true),
122 | ("age", IntegerType, true),
123 | ("country", StringType, true)
124 | )
125 | )
126 |
127 | val expectedDF = spark.createDF(
128 | List(
129 | ("bobisanicepersonandwelikehimNOT", 1, "france"),
130 | ("camila", 5, "peru")
131 | ),
132 | List(
133 | ("name", StringType, true),
134 | ("age", IntegerType, true),
135 | ("country", StringType, true)
136 | )
137 | )
138 |
139 | intercept[DatasetContentMismatch] {
140 | assertSmallDataFrameEquality(sourceDF, expectedDF)
141 | }
142 | }
143 |
144 | "also print a descriptive error message if the right side is missing" in {
145 | val sourceDF = spark.createDF(
146 | List(
147 | ("bob", 1, "uk"),
148 | ("camila", 5, "peru"),
149 | ("jean-jacques", 4, "france")
150 | ),
151 | List(
152 | ("name", StringType, true),
153 | ("age", IntegerType, true),
154 | ("country", StringType, true)
155 | )
156 | )
157 |
158 | val expectedDF = spark.createDF(
159 | List(
160 | ("bob", 1, "france"),
161 | ("camila", 5, "peru")
162 | ),
163 | List(
164 | ("name", StringType, true),
165 | ("age", IntegerType, true),
166 | ("country", StringType, true)
167 | )
168 | )
169 |
170 | val e = intercept[DatasetContentMismatch] {
171 | assertSmallDataFrameEquality(sourceDF, expectedDF)
172 | }
173 |
174 | assert(e.getMessage.indexOf("jean-jacques") >= 0)
175 | }
176 |
177 | "also print a descriptive error message if the left side is missing" in {
178 | val sourceDF = spark.createDF(
179 | List(
180 | ("bob", 1, "uk"),
181 | ("camila", 5, "peru")
182 | ),
183 | List(
184 | ("name", StringType, true),
185 | ("age", IntegerType, true),
186 | ("country", StringType, true)
187 | )
188 | )
189 |
190 | val expectedDF = spark.createDF(
191 | List(
192 | ("bob", 1, "france"),
193 | ("camila", 5, "peru"),
194 | ("jean-claude", 4, "france")
195 | ),
196 | List(
197 | ("name", StringType, true),
198 | ("age", IntegerType, true),
199 | ("country", StringType, true)
200 | )
201 | )
202 |
203 | val e = intercept[DatasetContentMismatch] {
204 | assertSmallDataFrameEquality(sourceDF, expectedDF)
205 | }
206 |
207 | assert(e.getMessage.indexOf("jean-claude") >= 0)
208 | }
209 |
210 | "assertSmallDataFrameEquality" - {
211 |
212 | "does nothing if the DataFrames have the same schemas and content" in {
213 | val sourceDF = spark.createDF(
214 | List(
215 | (1),
216 | (5)
217 | ),
218 | List(("number", IntegerType, true))
219 | )
220 |
221 | val expectedDF = spark.createDF(
222 | List(
223 | (1),
224 | (5)
225 | ),
226 | List(("number", IntegerType, true))
227 | )
228 | assertLargeDataFrameEquality(sourceDF, expectedDF)
229 | }
230 |
231 | "throws an error if the DataFrames have different schemas" in {
232 | val sourceDF = spark.createDF(
233 | List(
234 | (1),
235 | (5)
236 | ),
237 | List(("number", IntegerType, true))
238 | )
239 |
240 | val expectedDF = spark.createDF(
241 | List(
242 | (1, "word"),
243 | (5, "word")
244 | ),
245 | List(
246 | ("number", IntegerType, true),
247 | ("word", StringType, true)
248 | )
249 | )
250 |
251 | intercept[DatasetSchemaMismatch] {
252 | assertLargeDataFrameEquality(sourceDF, expectedDF)
253 | }
254 | intercept[DatasetSchemaMismatch] {
255 | assertSmallDataFrameEquality(sourceDF, expectedDF)
256 | }
257 | }
258 |
259 | "throws an error if the DataFrames content is different" in {
260 | val sourceDF = spark.createDF(
261 | List(
262 | (1),
263 | (5)
264 | ),
265 | List(("number", IntegerType, true))
266 | )
267 |
268 | val expectedDF = spark.createDF(
269 | List(
270 | (10),
271 | (5)
272 | ),
273 | List(("number", IntegerType, true))
274 | )
275 |
276 | intercept[DatasetContentMismatch] {
277 | assertLargeDataFrameEquality(sourceDF, expectedDF)
278 | }
279 | intercept[DatasetContentMismatch] {
280 | assertSmallDataFrameEquality(sourceDF, expectedDF)
281 | }
282 | }
283 |
284 | "can performed unordered DataFrame comparisons" in {
285 | val sourceDF = spark.createDF(
286 | List(
287 | (1),
288 | (5)
289 | ),
290 | List(("number", IntegerType, true))
291 | )
292 | val expectedDF = spark.createDF(
293 | List(
294 | (5),
295 | (1)
296 | ),
297 | List(("number", IntegerType, true))
298 | )
299 | assertSmallDataFrameEquality(sourceDF, expectedDF, orderedComparison = false)
300 | }
301 |
302 | "throws an error for unordered DataFrame comparisons that don't match" in {
303 | val sourceDF = spark.createDF(
304 | List(
305 | (1),
306 | (5)
307 | ),
308 | List(("number", IntegerType, true))
309 | )
310 | val expectedDF = spark.createDF(
311 | List(
312 | (5),
313 | (1),
314 | (10)
315 | ),
316 | List(("number", IntegerType, true))
317 | )
318 | intercept[DatasetContentMismatch] {
319 | assertSmallDataFrameEquality(sourceDF, expectedDF, orderedComparison = false)
320 | }
321 | }
322 |
323 |     "can perform DataFrame comparisons with unordered columns" in {
324 | val sourceDF = spark.createDF(
325 | List(
326 | (1, "word"),
327 | (5, "word")
328 | ),
329 | List(
330 | ("number", IntegerType, true),
331 | ("word", StringType, true)
332 | )
333 | )
334 | val expectedDF = spark.createDF(
335 | List(
336 | ("word", 1),
337 | ("word", 5)
338 | ),
339 | List(
340 | ("word", StringType, true),
341 | ("number", IntegerType, true)
342 | )
343 | )
344 |       assertSmallDataFrameEquality(sourceDF, expectedDF, ignoreColumnOrder = true)
345 | }
346 |
347 | "should not ignore nullable if ignoreNullable is false" in {
348 | val sourceDF = spark.createDF(
349 | List(
350 | 1.2,
351 | 5.1
352 | ),
353 | List(("number", DoubleType, false))
354 | )
355 | val expectedDF = spark.createDF(
356 | List(
357 | 1.2,
358 | 5.1
359 | ),
360 | List(("number", DoubleType, true))
361 | )
362 |
363 | intercept[DatasetSchemaMismatch] {
364 |         assertSmallDataFrameEquality(sourceDF, expectedDF)
365 | }
366 | }
367 |
368 |     "correctly marks unequal schema fields" in {
369 | val sourceDF = spark.createDF(
370 | List(
371 | (1, 2.0),
372 | (5, 3.0)
373 | ),
374 | List(
375 | ("number", IntegerType, true),
376 | ("float", DoubleType, true)
377 | )
378 | )
379 |
380 | val expectedDF = spark.createDF(
381 | List(
382 | (1, "word", 1L),
383 | (5, "word", 2L)
384 | ),
385 | List(
386 | ("number", IntegerType, true),
387 | ("word", StringType, true),
388 | ("long", LongType, true)
389 | )
390 | )
391 |
392 | val e = intercept[DatasetSchemaMismatch] {
393 | assertSmallDataFrameEquality(sourceDF, expectedDF)
394 | }
395 |
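    |       // assertColorDiff is the TestUtilsExt test helper; it asserts that the colorized diff highlights exactly these tokens on each side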
396 | e.assertColorDiff(
397 | Seq("float", "DoubleType", "MISSING"),
398 | Seq("word", "StringType", "StructField(long,LongType,true,{})")
399 | )
400 | }
401 |
402 |     "can perform DataFrame comparisons and ignore metadata" in {
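    |       // Column.as(alias, metadata) attaches metadata to a column; with the default arguments the comparison ignores metadata differences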
403 | val sourceDF = spark
404 | .createDF(
405 | List(
406 | 1,
407 | 5
408 | ),
409 | List(("number", IntegerType, true))
410 | )
411 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build()))
412 |
413 | val expectedDF = spark
414 | .createDF(
415 | List(
416 | 1,
417 | 5
418 | ),
419 | List(("number", IntegerType, true))
420 | )
421 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build()))
422 |
423 |       assertSmallDataFrameEquality(sourceDF, expectedDF)
424 | }
425 |
426 |     "can perform DataFrame comparisons and compare metadata" in {
427 | val sourceDF = spark
428 | .createDF(
429 | List(
430 | 1,
431 | 5
432 | ),
433 | List(("number", IntegerType, true))
434 | )
435 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build()))
436 |
437 | val expectedDF = spark
438 | .createDF(
439 | List(
440 | 1,
441 | 5
442 | ),
443 | List(("number", IntegerType, true))
444 | )
445 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build()))
446 |
447 | intercept[DatasetSchemaMismatch] {
448 |         assertSmallDataFrameEquality(sourceDF, expectedDF, ignoreMetadata = false)
449 | }
450 | }
451 | }
452 |
453 | "assertApproximateDataFrameEquality" - {
454 |
455 | "does nothing if the DataFrames have the same schemas and content" in {
456 | val sourceDF = spark.createDF(
457 | List(
458 | 1.2,
459 | 5.1,
460 | null
461 | ),
462 | List(("number", DoubleType, true))
463 | )
464 | val expectedDF = spark.createDF(
465 | List(
466 | 1.2,
467 | 5.1,
468 | null
469 | ),
470 | List(("number", DoubleType, true))
471 | )
472 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
473 | }
474 |
475 | "throws an error if the rows are different" in {
476 | val sourceDF = spark.createDF(
477 | List(
478 | 100.9,
479 | 5.1
480 | ),
481 | List(("number", DoubleType, true))
482 | )
483 | val expectedDF = spark.createDF(
484 | List(
485 | 1.2,
486 | 5.1
487 | ),
488 | List(("number", DoubleType, true))
489 | )
490 | val e = intercept[DatasetContentMismatch] {
491 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
492 | }
493 | }
494 |
495 |     "throws an error if the DataFrames have a different number of rows" in {
496 | val sourceDF = spark.createDF(
497 | List(
498 | 1.2,
499 | 5.1,
500 | 8.8
501 | ),
502 | List(("number", DoubleType, true))
503 | )
504 | val expectedDF = spark.createDF(
505 | List(
506 | 1.2,
507 | 5.1
508 | ),
509 | List(("number", DoubleType, true))
510 | )
511 | val e = intercept[DatasetCountMismatch] {
512 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
513 | }
514 | }
515 |
516 | "should not ignore nullable if ignoreNullable is false" in {
517 | val sourceDF = spark.createDF(
518 | List(
519 | 1.2,
520 | 5.1
521 | ),
522 | List(("number", DoubleType, false))
523 | )
524 | val expectedDF = spark.createDF(
525 | List(
526 | 1.2,
527 | 5.1
528 | ),
529 | List(("number", DoubleType, true))
530 | )
531 |
532 | intercept[DatasetSchemaMismatch] {
533 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
534 | }
535 | }
536 |
537 | "can ignore the nullable property" in {
538 | val sourceDF = spark.createDF(
539 | List(
540 | 1.2,
541 | 5.1
542 | ),
543 | List(("number", DoubleType, false))
544 | )
545 | val expectedDF = spark.createDF(
546 | List(
547 | 1.2,
548 | 5.1
549 | ),
550 | List(("number", DoubleType, true))
551 | )
552 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreNullable = true)
553 | }
554 |
555 | "can ignore the column names" in {
556 | val sourceDF = spark.createDF(
557 | List(
558 | 1.2,
559 | 5.1,
560 | null
561 | ),
562 | List(("BLAHBLBH", DoubleType, true))
563 | )
564 | val expectedDF = spark.createDF(
565 | List(
566 | 1.2,
567 | 5.1,
568 | null
569 | ),
570 | List(("number", DoubleType, true))
571 | )
572 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreColumnNames = true)
573 | }
574 |
575 | "can work with precision and unordered comparison" in {
576 | import spark.implicits._
577 | val ds1 = Seq(
578 | ("1", "10/01/2019", 26.762499999999996),
579 | ("1", "11/01/2019", 26.762499999999996)
580 | ).toDF("col_B", "col_C", "col_A")
581 |
582 | val ds2 = Seq(
583 | ("1", "10/01/2019", 26.762499999999946),
584 | ("1", "11/01/2019", 26.76249999999991)
585 | ).toDF("col_B", "col_C", "col_A")
586 |
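    |       // the col_A values differ by less than 1e-13, comfortably inside the 1e-7 precision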
587 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
588 | }
589 |
590 | "can work with precision and unordered comparison 2" in {
591 | import spark.implicits._
592 | val ds1 = Seq(
593 | ("1", "10/01/2019", 26.762499999999996, "A"),
594 | ("1", "10/01/2019", 26.762499999999996, "B")
595 | ).toDF("col_B", "col_C", "col_A", "col_D")
596 |
597 | val ds2 = Seq(
598 | ("1", "10/01/2019", 26.762499999999946, "A"),
599 | ("1", "10/01/2019", 26.76249999999991, "B")
600 | ).toDF("col_B", "col_C", "col_A", "col_D")
601 |
602 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
603 | }
604 |
605 |     "throws an error when the precision is exceeded" in {
606 | import spark.implicits._
607 | val ds1 = Seq(
608 | ("1", "10/01/2019", 26.762499999999996),
609 | ("1", "11/01/2019", 26.762499999999996)
610 | ).toDF("col_B", "col_C", "col_A")
611 |
612 | val ds2 = Seq(
613 | ("1", "10/01/2019", 26.762499999999946),
614 | ("1", "11/01/2019", 28.76249999999991)
615 | ).toDF("col_B", "col_C", "col_A")
616 |
617 | intercept[DatasetContentMismatch] {
618 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
619 | }
620 | }
621 |
622 |     "throws an error when the precision is exceeded for TimestampType" in {
623 | import spark.implicits._
624 | val ds1 = Seq(
625 | ("1", Instant.parse("2019-10-01T00:00:00Z")),
626 | ("2", Instant.parse("2019-11-01T00:00:00Z"))
627 | ).toDF("col_B", "col_A")
628 |
629 | val ds2 = Seq(
630 | ("1", Instant.parse("2019-10-01T00:00:00Z")),
631 | ("2", Instant.parse("2019-12-01T00:00:00Z"))
632 | ).toDF("col_B", "col_A")
633 |
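    |       // the second timestamps are a month apart, which should far exceed a precision of 100 (assumed here to apply to the underlying epoch values)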
634 | intercept[DatasetContentMismatch] {
635 | assertApproximateDataFrameEquality(ds1, ds2, precision = 100, orderedComparison = false)
636 | }
637 | }
638 |
639 |     "throws an error when the precision is exceeded for BigDecimal" in {
640 | import spark.implicits._
641 | val ds1 = Seq(
642 | ("1", BigDecimal(101)),
643 | ("2", BigDecimal(200))
644 | ).toDF("col_B", "col_A")
645 |
646 | val ds2 = Seq(
647 | ("1", BigDecimal(101)),
648 | ("2", BigDecimal(203))
649 | ).toDF("col_B", "col_A")
650 |
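    |       // |203 - 200| = 3 exceeds the precision of 2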
651 | intercept[DatasetContentMismatch] {
652 | assertApproximateDataFrameEquality(ds1, ds2, precision = 2, orderedComparison = false)
653 | }
654 | }
655 |
656 | "can work with precision and unordered comparison on nested column" in {
657 | import spark.implicits._
658 | val ds1 = Seq(
659 | ("1", "10/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)),
660 | ("1", "11/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996))
661 | ).toDF("col_B", "col_C", "col_A", "col_D")
662 |
663 | val ds2 = Seq(
664 | ("1", "11/01/2019", 26.7624999999999961, Seq(26.7624999999999961, 26.7624999999999961)),
665 | ("1", "10/01/2019", 26.762499999999997, Seq(26.762499999999997, 26.762499999999997))
666 | ).toDF("col_B", "col_C", "col_A", "col_D")
667 |
668 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
669 | }
670 |
671 |     "can perform DataFrame comparisons and ignore metadata" in {
672 | val sourceDF = spark
673 | .createDF(
674 | List(
675 | 1,
676 | 5
677 | ),
678 | List(("number", IntegerType, true))
679 | )
680 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build()))
681 |
682 | val expectedDF = spark
683 | .createDF(
684 | List(
685 | 1,
686 | 5
687 | ),
688 | List(("number", IntegerType, true))
689 | )
690 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build()))
691 |
692 | assertApproximateDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001)
693 | }
694 |
695 |     "can perform DataFrame comparisons and compare metadata" in {
696 | val sourceDF = spark
697 | .createDF(
698 | List(
699 | 1,
700 | 5
701 | ),
702 | List(("number", IntegerType, true))
703 | )
704 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build()))
705 |
706 | val expectedDF = spark
707 | .createDF(
708 | List(
709 | 1,
710 | 5
711 | ),
712 | List(("number", IntegerType, true))
713 | )
714 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build()))
715 |
716 | intercept[DatasetSchemaMismatch] {
717 | assertApproximateDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001, ignoreMetadata = false)
718 | }
719 | }
720 | }
721 |
722 | "assertApproximateSmallDataFrameEquality" - {
723 |
724 | "does nothing if the DataFrames have the same schemas and content" in {
725 | val sourceDF = spark.createDF(
726 | List(
727 | 1.2,
728 | 5.1,
729 | null
730 | ),
731 | List(("number", DoubleType, true))
732 | )
733 | val expectedDF = spark.createDF(
734 | List(
735 | 1.2,
736 | 5.1,
737 | null
738 | ),
739 | List(("number", DoubleType, true))
740 | )
741 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01)
742 | }
743 |
744 | "throws an error if the rows are different" in {
745 | val sourceDF = spark.createDF(
746 | List(
747 | 100.9,
748 | 5.1
749 | ),
750 | List(("number", DoubleType, true))
751 | )
752 | val expectedDF = spark.createDF(
753 | List(
754 | 1.2,
755 | 5.1
756 | ),
757 | List(("number", DoubleType, true))
758 | )
759 | val e = intercept[DatasetContentMismatch] {
760 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01)
761 | }
762 | }
763 |
764 |     "throws an error if the DataFrames have a different number of rows" in {
765 | val sourceDF = spark.createDF(
766 | List(
767 | 1.2,
768 | 5.1,
769 | 8.8
770 | ),
771 | List(("number", DoubleType, true))
772 | )
773 | val expectedDF = spark.createDF(
774 | List(
775 | 1.2,
776 | 5.1
777 | ),
778 | List(("number", DoubleType, true))
779 | )
780 | val e = intercept[DatasetContentMismatch] {
781 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01)
782 | }
783 | }
784 |
785 | "can ignore the nullable property" in {
786 | val sourceDF = spark.createDF(
787 | List(
788 | 1.2,
789 | 5.1
790 | ),
791 | List(("number", DoubleType, false))
792 | )
793 | val expectedDF = spark.createDF(
794 | List(
795 | 1.2,
796 | 5.1
797 | ),
798 | List(("number", DoubleType, true))
799 | )
800 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreNullable = true)
801 | }
802 |
803 | "should not ignore nullable if ignoreNullable is false" in {
804 | val sourceDF = spark.createDF(
805 | List(
806 | 1.2,
807 | 5.1
808 | ),
809 | List(("number", DoubleType, false))
810 | )
811 | val expectedDF = spark.createDF(
812 | List(
813 | 1.2,
814 | 5.1
815 | ),
816 | List(("number", DoubleType, true))
817 | )
818 |
819 | intercept[DatasetSchemaMismatch] {
820 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01)
821 | }
822 | }
823 |
824 | "can ignore the column names" in {
825 | val sourceDF = spark.createDF(
826 | List(
827 | 1.2,
828 | 5.1,
829 | null
830 | ),
831 | List(("BLAHBLBH", DoubleType, true))
832 | )
833 | val expectedDF = spark.createDF(
834 | List(
835 | 1.2,
836 | 5.1,
837 | null
838 | ),
839 | List(("number", DoubleType, true))
840 | )
841 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreColumnNames = true)
842 | }
843 |
844 | "can work with precision and unordered comparison" in {
845 | import spark.implicits._
846 | val ds1 = Seq(
847 | ("1", "10/01/2019", 26.762499999999996),
848 | ("1", "11/01/2019", 26.762499999999996)
849 | ).toDF("col_B", "col_C", "col_A")
850 |
851 | val ds2 = Seq(
852 | ("1", "10/01/2019", 26.762499999999946),
853 | ("1", "11/01/2019", 26.76249999999991)
854 | ).toDF("col_B", "col_C", "col_A")
855 |
856 | assertApproximateSmallDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
857 | }
858 |
859 | "can work with precision and unordered comparison 2" in {
860 | import spark.implicits._
861 | val ds1 = Seq(
862 | ("1", "10/01/2019", "A", 26.762499999999996),
863 | ("1", "10/01/2019", "B", 26.762499999999996)
864 | ).toDF("col_B", "col_C", "col_A", "col_D")
865 |
866 | val ds2 = Seq(
867 | ("1", "10/01/2019", "A", 26.762499999999946),
868 | ("1", "10/01/2019", "B", 26.76249999999991)
869 | ).toDF("col_B", "col_C", "col_A", "col_D")
870 |
871 | assertApproximateSmallDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
872 | }
873 |
874 | "can work with precision and unordered comparison on nested column" in {
875 | import spark.implicits._
876 | val ds1 = Seq(
877 | ("1", "10/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)),
878 | ("2", "11/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996))
879 | ).toDF("col_B", "col_C", "col_A", "col_D")
880 |
881 | val ds2 = Seq(
882 | ("2", "11/01/2019", 26.7624999999999961, Seq(26.7624999999999961, 26.7624999999999961)),
883 | ("1", "10/01/2019", 26.762499999999997, Seq(26.762499999999997, 26.762499999999997))
884 | ).toDF("col_B", "col_C", "col_A", "col_D")
885 |
886 | assertApproximateSmallDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
887 | }
888 |
889 |     "can perform DataFrame comparisons and ignore metadata" in {
890 | val sourceDF = spark
891 | .createDF(
892 | List(
893 | 1,
894 | 5
895 | ),
896 | List(("number", IntegerType, true))
897 | )
898 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build()))
899 |
900 | val expectedDF = spark
901 | .createDF(
902 | List(
903 | 1,
904 | 5
905 | ),
906 | List(("number", IntegerType, true))
907 | )
908 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build()))
909 |
910 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001)
911 | }
912 |
913 |     "can perform DataFrame comparisons and compare metadata" in {
914 | val sourceDF = spark
915 | .createDF(
916 | List(
917 | 1,
918 | 5
919 | ),
920 | List(("number", IntegerType, true))
921 | )
922 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build()))
923 |
924 | val expectedDF = spark
925 | .createDF(
926 | List(
927 | 1,
928 | 5
929 | ),
930 | List(("number", IntegerType, true))
931 | )
932 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build()))
933 |
934 | intercept[DatasetSchemaMismatch] {
935 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001, ignoreMetadata = false)
936 | }
937 | }
938 | }
939 | }
940 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFramePrettyPrintTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.scalatest.freespec.AnyFreeSpec
4 |
5 | class DataFramePrettyPrintTest extends AnyFreeSpec with SparkSessionTestWrapper {
6 | "prints named_struct with keys" in {
7 | val inputDataframe = spark.sql("""select 1 as id, named_struct("k1", 2, "k2", 3) as to_show_with_k""")
8 | assert(
9 | DataFramePrettyPrint.showString(inputDataframe, 10) ==
10 | """+---+------------------+
11 | || id| to_show_with_k|
12 | |+---+------------------+
13 | || 1|{k1 -> 2, k2 -> 3}|
14 | |+---+------------------+
15 | |""".stripMargin
16 | )
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.types._
4 | import SparkSessionExt._
5 | import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
6 | import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps
7 | import org.apache.spark.sql.functions.col
8 | import org.scalatest.freespec.AnyFreeSpec
9 |
10 | object Person {
11 |
12 | def caseInsensitivePersonEquals(some: Person, other: Person): Boolean = {
13 | some.name.equalsIgnoreCase(other.name) && some.age == other.age
14 | }
15 | }
16 | case class Person(name: String, age: Int)
17 | case class PrecisePerson(name: String, age: Double)
18 |
19 | class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSessionTestWrapper {
20 |
21 | "checkDatasetEquality" - {
22 | import spark.implicits._
23 |
24 | "provides a good README example" in {
25 | val sourceDS = Seq(
26 | Person("juan", 5),
27 | Person("bob", 1),
28 | Person("li", 49),
29 | Person("alice", 5)
30 | ).toDS
31 |
32 | val expectedDS = Seq(
33 | Person("juan", 5),
34 | Person("frank", 10),
35 | Person("li", 49),
36 | Person("lucy", 5)
37 | ).toDS
38 |
39 | val e = intercept[DatasetContentMismatch] {
40 | assertSmallDatasetEquality(sourceDS, expectedDS)
41 | }
42 | }
43 |
44 |     "can compare unequal Datasets containing null in a column" in {
45 | val sourceDS = Seq(Person(null, 5), Person(null, 1)).toDS
46 | val expectedDS = Seq(Person("juan", 5), Person(null, 1)).toDS
47 |
48 | val e = intercept[DatasetContentMismatch] {
49 | assertSmallDatasetEquality(sourceDS, expectedDS, ignoreNullable = true, orderedComparison = false)
50 | }
51 |
52 | e.assertColorDiff(Seq("null"), Seq("juan"))
53 | }
54 |
55 |     "correctly marks unequal elements" in {
56 | val sourceDS = Seq(
57 | Person("juan", 5),
58 | Person("bob", 1),
59 | Person("li", 49),
60 | Person("alice", 5)
61 | ).toDS
62 |
63 | val expectedDS = Seq(
64 | Person("juan", 5),
65 | Person("frank", 10),
66 | Person("li", 49),
67 | Person("lucy", 5)
68 | ).toDS
69 |
70 | val e = intercept[DatasetContentMismatch] {
71 | assertSmallDatasetEquality(sourceDS, expectedDS)
72 | }
73 |
74 | e.assertColorDiff(Seq("Person(bob,1)", "alice"), Seq("Person(frank,10)", "lucy"))
75 | }
76 |
77 |     "correctly marks unequal elements for Dataset[String]" in {
78 | import spark.implicits._
79 | val sourceDS = Seq("word", "StringType", "StructField(long,LongType,true,{})").toDS
80 |
81 | val expectedDS = List("word", "StringType", "StructField(long,LongType2,true,{})").toDS
82 |
83 | val e = intercept[DatasetContentMismatch] {
84 | assertSmallDatasetEquality(sourceDS, expectedDS)
85 | }
86 |
87 | e.assertColorDiff(Seq("String(StructField(long,LongType,true,{}))"), Seq("String(StructField(long,LongType2,true,{}))"))
88 | }
89 |
90 |     "correctly marks unequal elements for Dataset[Seq[String]]" in {
91 | import spark.implicits._
92 |
93 | val sourceDS = Seq(
94 | Seq("apple", "banana", "cherry"),
95 | Seq("dog", "cat"),
96 | Seq("red", "green", "blue")
97 | ).toDS
98 |
99 | val expectedDS = Seq(
100 | Seq("apple", "banana2"),
101 | Seq("dog", "cat"),
102 | Seq("red", "green", "blue")
103 | ).toDS
104 |
105 | val e = intercept[DatasetContentMismatch] {
106 | assertSmallDatasetEquality(sourceDS, expectedDS)
107 | }
108 |
109 | e.assertColorDiff(Seq("banana", "cherry"), Seq("banana2", "MISSING"))
110 | }
111 |
112 |     "correctly marks unequal elements for Dataset[Array[String]]" in {
113 | import spark.implicits._
114 |
115 | val sourceDS = Seq(
116 | Array("apple", "banana", "cherry"),
117 | Array("dog", "cat"),
118 | Array("red", "green", "blue")
119 | ).toDS
120 |
121 | val expectedDS = Seq(
122 | Array("apple", "banana2"),
123 | Array("dog", "cat"),
124 | Array("red", "green", "blue")
125 | ).toDS
126 |
127 | val e = intercept[DatasetContentMismatch] {
128 | assertSmallDatasetEquality(sourceDS, expectedDS)
129 | }
130 |
131 | e.assertColorDiff(Seq("banana", "cherry"), Seq("banana2", "MISSING"))
132 | }
133 |
134 |     "correctly marks unequal elements for Dataset[Map[String, String]]" in {
135 | import spark.implicits._
136 |
137 | val sourceDS = Seq(
138 | Map("apple" -> "banana", "apple1" -> "banana1"),
139 | Map("apple" -> "banana", "apple1" -> "banana1")
140 | ).toDS
141 |
142 | val expectedDS = Seq(
143 | Map("apple" -> "banana1", "apple1" -> "banana1"),
144 | Map("apple" -> "banana", "apple1" -> "banana1")
145 | ).toDS
146 |
147 | val e = intercept[DatasetContentMismatch] {
148 | assertSmallDatasetEquality(sourceDS, expectedDS)
149 | }
150 |
151 | e.assertColorDiff(Seq("(apple,banana)"), Seq("(apple,banana1)"))
152 | }
153 |
154 | "works with really long columns" in {
155 | val sourceDS = Seq(
156 | Person("juanisareallygoodguythatilikealotOK", 5),
157 | Person("bob", 1),
158 | Person("li", 49),
159 | Person("alice", 5)
160 | ).toDS
161 |
162 | val expectedDS = Seq(
163 | Person("juanisareallygoodguythatilikealotNOT", 5),
164 | Person("frank", 10),
165 | Person("li", 49),
166 | Person("lucy", 5)
167 | ).toDS
168 |
169 | val e = intercept[DatasetContentMismatch] {
170 | assertSmallDatasetEquality(sourceDS, expectedDS)
171 | }
172 | }
173 |
174 | "does nothing if the DataFrames have the same schemas and content" in {
175 | val sourceDF = spark.createDF(
176 | List(
177 | (1),
178 | (5)
179 | ),
180 | List(("number", IntegerType, true))
181 | )
182 |
183 | val expectedDF = spark.createDF(
184 | List(
185 | (1),
186 | (5)
187 | ),
188 | List(("number", IntegerType, true))
189 | )
190 |
191 | assertSmallDatasetEquality(sourceDF, expectedDF)
192 | assertLargeDatasetEquality(sourceDF, expectedDF)
193 | }
194 |
195 | "does nothing if the Datasets have the same schemas and content" in {
196 | val sourceDS = spark.createDataset[Person](
197 | Seq(
198 | Person("Alice", 12),
199 | Person("Bob", 17)
200 | )
201 | )
202 |
203 | val expectedDS = spark.createDataset[Person](
204 | Seq(
205 | Person("Alice", 12),
206 | Person("Bob", 17)
207 | )
208 | )
209 |
210 | assertSmallDatasetEquality(sourceDS, expectedDS)
211 | assertLargeDatasetEquality(sourceDS, expectedDS)
212 | }
213 |
214 | "works with DataFrames that have ArrayType columns" in {
215 | val sourceDF = spark.createDF(
216 | List(
217 | (1, Array("word1", "blah")),
218 | (5, Array("hi", "there"))
219 | ),
220 | List(
221 | ("number", IntegerType, true),
222 | ("words", ArrayType(StringType, true), true)
223 | )
224 | )
225 |
226 | val expectedDF = spark.createDF(
227 | List(
228 | (1, Array("word1", "blah")),
229 | (5, Array("hi", "there"))
230 | ),
231 | List(
232 | ("number", IntegerType, true),
233 | ("words", ArrayType(StringType, true), true)
234 | )
235 | )
236 |
237 | assertLargeDatasetEquality(sourceDF, expectedDF)
238 | assertSmallDatasetEquality(sourceDF, expectedDF)
239 | }
240 |
241 | "throws an error if the DataFrames have different schemas" in {
242 | val nestedSchema = StructType(
243 | Seq(
244 | StructField(
245 | "attributes",
246 | StructType(
247 | Seq(
248 | StructField("PostCode", IntegerType, nullable = true)
249 | )
250 | ),
251 | nullable = true
252 | )
253 | )
254 | )
255 |
256 | val nestedSchema2 = StructType(
257 | Seq(
258 | StructField(
259 | "attributes",
260 | StructType(
261 | Seq(
262 | StructField("PostCode", StringType, nullable = true)
263 | )
264 | ),
265 | nullable = true
266 | )
267 | )
268 | )
269 |
270 | val sourceDF = spark.createDF(
271 | List(
272 | (1, 2.0, null),
273 | (5, 3.0, null)
274 | ),
275 | List(
276 | ("number", IntegerType, true),
277 | ("float", DoubleType, true),
278 | ("nestedField", nestedSchema, true)
279 | )
280 | )
281 |
282 | val expectedDF = spark.createDF(
283 | List(
284 | (1, "word", null, 1L),
285 | (5, "word", null, 2L)
286 | ),
287 | List(
288 | ("number", IntegerType, true),
289 | ("word", StringType, true),
290 | ("nestedField", nestedSchema2, true),
291 | ("long", LongType, true)
292 | )
293 | )
294 |
295 | intercept[DatasetSchemaMismatch] {
296 | assertLargeDatasetEquality(sourceDF, expectedDF)
297 | }
298 |
299 | intercept[DatasetSchemaMismatch] {
300 | assertSmallDatasetEquality(sourceDF, expectedDF)
301 | }
302 | }
303 |
304 |     "throws an error if the DataFrames' content is different" in {
305 | val sourceDF = Seq(
306 | (1), (5), (7), (1), (1)
307 | ).toDF("number")
308 |
309 | val expectedDF = Seq(
310 | (10), (5), (3), (7), (1)
311 | ).toDF("number")
312 |
313 | val e = intercept[DatasetContentMismatch] {
314 | assertLargeDatasetEquality(sourceDF, expectedDF)
315 | }
316 | val e2 = intercept[DatasetContentMismatch] {
317 | assertSmallDatasetEquality(sourceDF, expectedDF)
318 | }
319 | }
320 |
321 | "throws an error if the Dataset content is different" in {
322 | val sourceDS = spark.createDataset[Person](
323 | Seq(
324 | Person("Alice", 12),
325 | Person("Bob", 17)
326 | )
327 | )
328 |
329 | val expectedDS = spark.createDataset[Person](
330 | Seq(
331 | Person("Frank", 10),
332 | Person("Lucy", 5)
333 | )
334 | )
335 |
336 | val e = intercept[DatasetContentMismatch] {
337 | assertLargeDatasetEquality(sourceDS, expectedDS)
338 | }
339 | val e2 = intercept[DatasetContentMismatch] {
340 |         assertSmallDatasetEquality(sourceDS, expectedDS)
341 | }
342 | }
343 |
344 | "succeeds if custom comparator returns true" in {
345 | val sourceDS = spark.createDataset[Person](
346 | Seq(
347 | Person("bob", 1),
348 | Person("alice", 5)
349 | )
350 | )
351 | val expectedDS = spark.createDataset[Person](
352 | Seq(
353 | Person("Bob", 1),
354 | Person("Alice", 5)
355 | )
356 | )
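    |     // the custom equality defined on the Person companion ignores the casing of the name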
357 | assertLargeDatasetEquality(sourceDS, expectedDS, Person.caseInsensitivePersonEquals)
358 | }
359 |
360 |     "fails if the custom comparator returns false" in {
361 | val sourceDS = spark.createDataset[Person](
362 | Seq(
363 | Person("bob", 10),
364 | Person("alice", 5)
365 | )
366 | )
367 | val expectedDS = spark.createDataset[Person](
368 | Seq(
369 | Person("Bob", 1),
370 | Person("Alice", 5)
371 | )
372 | )
373 | val e = intercept[DatasetContentMismatch] {
374 | assertLargeDatasetEquality(sourceDS, expectedDS, Person.caseInsensitivePersonEquals)
375 | }
376 | }
377 |
378 | }
379 |
380 | "assertLargeDatasetEquality" - {
381 | import spark.implicits._
382 |
383 | "ignores the nullable flag when making DataFrame comparisons" in {
384 | val sourceDF = spark.createDF(
385 | List(
386 | (1),
387 | (5)
388 | ),
389 | List(("number", IntegerType, false))
390 | )
391 |
392 | val expectedDF = spark.createDF(
393 | List(
394 | (1),
395 | (5)
396 | ),
397 | List(("number", IntegerType, true))
398 | )
399 |
400 | assertLargeDatasetEquality(sourceDF, expectedDF, ignoreNullable = true)
401 | }
402 |
403 | "should not ignore nullable if ignoreNullable is false" in {
404 |
405 | val sourceDF = spark.createDF(
406 | List(
407 | (1),
408 | (5)
409 | ),
410 | List(("number", IntegerType, false))
411 | )
412 |
413 | val expectedDF = spark.createDF(
414 | List(
415 | (1),
416 | (5)
417 | ),
418 | List(("number", IntegerType, true))
419 | )
420 |
421 | intercept[DatasetSchemaMismatch] {
422 | assertLargeDatasetEquality(sourceDF, expectedDF)
423 | }
424 | }
425 |
426 |     "can perform unordered DataFrame comparisons" in {
427 | val sourceDF = spark.createDF(
428 | List(
429 | (1),
430 | (5)
431 | ),
432 | List(("number", IntegerType, true))
433 | )
434 |
435 | val expectedDF = spark.createDF(
436 | List(
437 | (5),
438 | (1)
439 | ),
440 | List(("number", IntegerType, true))
441 | )
442 |
443 | assertLargeDatasetEquality(sourceDF, expectedDF, orderedComparison = false)
444 | }
445 |
446 | "throws an error for unordered Dataset comparisons that don't match" in {
447 | val sourceDS = spark.createDataset[Person](
448 | Seq(
449 | Person("bob", 1),
450 | Person("frank", 5)
451 | )
452 | )
453 |
454 | val expectedDS = spark.createDataset[Person](
455 | Seq(
456 | Person("frank", 5),
457 | Person("bob", 1),
458 | Person("sadie", 2)
459 | )
460 | )
461 |
462 | val e = intercept[DatasetCountMismatch] {
463 | assertLargeDatasetEquality(sourceDS, expectedDS, orderedComparison = false)
464 | }
465 | }
466 |
467 | "throws an error for unordered DataFrame comparisons that don't match" in {
468 | val sourceDF = spark.createDF(
469 | List(
470 | (1),
471 | (5)
472 | ),
473 | List(("number", IntegerType, true))
474 | )
475 | val expectedDF = spark.createDF(
476 | List(
477 | (5),
478 | (1),
479 | (10)
480 | ),
481 | List(("number", IntegerType, true))
482 | )
483 |
484 | val e = intercept[DatasetCountMismatch] {
485 | assertLargeDatasetEquality(sourceDF, expectedDF, orderedComparison = false)
486 | }
487 | }
488 |
489 |     "throws an error if the DataFrames have a different number of rows" in {
490 | val sourceDF = spark.createDF(
491 | List(
492 | (1),
493 | (5)
494 | ),
495 | List(("number", IntegerType, true))
496 | )
497 | val expectedDF = spark.createDF(
498 | List(
499 | (1),
500 | (5),
501 | (10)
502 | ),
503 | List(("number", IntegerType, true))
504 | )
505 |
506 | val e = intercept[DatasetCountMismatch] {
507 | assertLargeDatasetEquality(sourceDF, expectedDF)
508 | }
509 | }
510 |
511 |     "can perform DataFrame comparisons with unordered columns" in {
512 | val sourceDF = spark.createDF(
513 | List(
514 | (1, "word"),
515 | (5, "word")
516 | ),
517 | List(
518 | ("number", IntegerType, true),
519 | ("word", StringType, true)
520 | )
521 | )
522 | val expectedDF = spark.createDF(
523 | List(
524 | ("word", 1),
525 | ("word", 5)
526 | ),
527 | List(
528 | ("word", StringType, true),
529 | ("number", IntegerType, true)
530 | )
531 | )
532 | assertLargeDatasetEquality(sourceDF, expectedDF, ignoreColumnOrder = true)
533 | }
534 |
535 |     "can perform Dataset comparisons with unordered columns" in {
536 | val ds1 = Seq(
537 | Person("juan", 5),
538 | Person("bob", 1),
539 | Person("li", 49),
540 | Person("alice", 5)
541 | ).toDS
542 |
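    |       // same rows, but select("age", "name") reverses the column order; .as(ds1.encoder) keeps the result a Dataset[Person]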
543 | val ds2 = Seq(
544 | Person("juan", 5),
545 | Person("bob", 1),
546 | Person("li", 49),
547 | Person("alice", 5)
548 | ).toDS.select("age", "name").as(ds1.encoder)
549 |
550 | assertLargeDatasetEquality(ds1, ds2, ignoreColumnOrder = true)
551 | assertLargeDatasetEquality(ds2, ds1, ignoreColumnOrder = true)
552 | }
553 |
554 |     "correctly marks unequal schema fields" in {
555 | val sourceDF = spark.createDF(
556 | List(
557 | (1, 2.0),
558 | (5, 3.0)
559 | ),
560 | List(
561 | ("number", IntegerType, true),
562 | ("float", DoubleType, true)
563 | )
564 | )
565 |
566 | val expectedDF = spark.createDF(
567 | List(
568 | (1, "word", 1L),
569 | (5, "word", 2L)
570 | ),
571 | List(
572 | ("number", IntegerType, true),
573 | ("word", StringType, true),
574 | ("long", LongType, true)
575 | )
576 | )
577 |
578 | val e = intercept[DatasetSchemaMismatch] {
579 | assertLargeDatasetEquality(sourceDF, expectedDF)
580 | }
581 |
582 | e.assertColorDiff(Seq("float", "DoubleType", "MISSING"), Seq("word", "StringType", "StructField(long,LongType,true,{})"))
583 | }
584 |
585 |     "can perform Dataset comparisons and ignore metadata" in {
586 | val ds1 = Seq(
587 | Person("juan", 5),
588 | Person("bob", 1),
589 | Person("li", 49),
590 | Person("alice", 5)
591 | ).toDS
592 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build()))
593 | .as[Person]
594 |
595 | val ds2 = Seq(
596 | Person("juan", 5),
597 | Person("bob", 1),
598 | Person("li", 49),
599 | Person("alice", 5)
600 | ).toDS
601 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build()))
602 | .as[Person]
603 |
604 | assertLargeDatasetEquality(ds2, ds1)
605 | }
606 |
607 |     "can perform Dataset comparisons and compare metadata" in {
608 | val ds1 = Seq(
609 | Person("juan", 5),
610 | Person("bob", 1),
611 | Person("li", 49),
612 | Person("alice", 5)
613 | ).toDS
614 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build()))
615 | .as[Person]
616 |
617 | val ds2 = Seq(
618 | Person("juan", 5),
619 | Person("bob", 1),
620 | Person("li", 49),
621 | Person("alice", 5)
622 | ).toDS
623 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build()))
624 | .as[Person]
625 |
626 | intercept[DatasetSchemaMismatch] {
627 | assertLargeDatasetEquality(ds2, ds1, ignoreMetadata = false)
628 | }
629 | }
630 | }
631 |
632 | "assertSmallDatasetEquality" - {
633 | import spark.implicits._
634 |
635 | "ignores the nullable flag when making DataFrame comparisons" in {
636 | val sourceDF = spark.createDF(
637 | List(
638 | (1),
639 | (5)
640 | ),
641 | List(("number", IntegerType, false))
642 | )
643 |
644 | val expectedDF = spark.createDF(
645 | List(
646 | (1),
647 | (5)
648 | ),
649 | List(("number", IntegerType, true))
650 | )
651 |
652 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreNullable = true)
653 | }
654 |
655 | "should not ignore nullable if ignoreNullable is false" in {
656 | val sourceDF = spark.createDF(
657 | List(
658 | (1),
659 | (5)
660 | ),
661 | List(("number", IntegerType, false))
662 | )
663 |
664 | val expectedDF = spark.createDF(
665 | List(
666 | (1),
667 | (5)
668 | ),
669 | List(("number", IntegerType, true))
670 | )
671 |
672 | intercept[DatasetSchemaMismatch] {
673 | assertSmallDatasetEquality(sourceDF, expectedDF)
674 | }
675 | }
676 |
677 |     "can perform unordered DataFrame comparisons" in {
678 | val sourceDF = spark.createDF(
679 | List(
680 | (1),
681 | (5)
682 | ),
683 | List(("number", IntegerType, true))
684 | )
685 | val expectedDF = spark.createDF(
686 | List(
687 | (5),
688 | (1)
689 | ),
690 | List(("number", IntegerType, true))
691 | )
692 | assertSmallDatasetEquality(sourceDF, expectedDF, orderedComparison = false)
693 | }
694 |
695 |     "can perform unordered Dataset comparisons" in {
696 | val sourceDS = spark.createDataset[Person](
697 | Seq(
698 | Person("bob", 1),
699 | Person("alice", 5)
700 | )
701 | )
702 | val expectedDS = spark.createDataset[Person](
703 | Seq(
704 | Person("alice", 5),
705 | Person("bob", 1)
706 | )
707 | )
708 | assertSmallDatasetEquality(sourceDS, expectedDS, orderedComparison = false)
709 | }
710 |
711 | "throws an error for unordered Dataset comparisons that don't match" in {
712 | val sourceDS = spark.createDataset[Person](
713 | Seq(
714 | Person("bob", 1),
715 | Person("frank", 5)
716 | )
717 | )
718 | val expectedDS = spark.createDataset[Person](
719 | Seq(
720 | Person("frank", 5),
721 | Person("bob", 1),
722 | Person("sadie", 2)
723 | )
724 | )
725 | val e = intercept[DatasetContentMismatch] {
726 | assertSmallDatasetEquality(sourceDS, expectedDS, orderedComparison = false)
727 | }
728 | }
729 |
730 | "throws an error for unordered DataFrame comparisons that don't match" in {
731 | val sourceDF = spark.createDF(
732 | List(
733 | (1),
734 | (5)
735 | ),
736 | List(("number", IntegerType, true))
737 | )
738 | val expectedDF = spark.createDF(
739 | List(
740 | (5),
741 | (1),
742 | (10)
743 | ),
744 | List(("number", IntegerType, true))
745 | )
746 | val e = intercept[DatasetContentMismatch] {
747 | assertSmallDatasetEquality(sourceDF, expectedDF, orderedComparison = false)
748 | }
749 | }
750 |
751 |     "throws an error if the DataFrames have a different number of rows" in {
752 | val sourceDF = spark.createDF(
753 | List(
754 | (1),
755 | (5)
756 | ),
757 | List(("number", IntegerType, true))
758 | )
759 | val expectedDF = spark.createDF(
760 | List(
761 | (1),
762 | (5),
763 | (10)
764 | ),
765 | List(("number", IntegerType, true))
766 | )
767 | val e = intercept[DatasetContentMismatch] {
768 | assertSmallDatasetEquality(sourceDF, expectedDF)
769 | }
770 | }
771 |
772 |     "can perform DataFrame comparisons with unordered columns" in {
773 | val sourceDF = spark.createDF(
774 | List(
775 | (1, "word"),
776 | (5, "word")
777 | ),
778 | List(
779 | ("number", IntegerType, true),
780 | ("word", StringType, true)
781 | )
782 | )
783 | val expectedDF = spark.createDF(
784 | List(
785 | ("word", 1),
786 | ("word", 5)
787 | ),
788 | List(
789 | ("word", StringType, true),
790 | ("number", IntegerType, true)
791 | )
792 | )
793 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreColumnOrder = true)
794 | }
795 |
796 |     "can perform Dataset comparisons with unordered columns" in {
797 | val ds1 = Seq(
798 | Person("juan", 5),
799 | Person("bob", 1),
800 | Person("li", 49),
801 | Person("alice", 5)
802 | ).toDS
803 |
804 | val ds2 = Seq(
805 | Person("juan", 5),
806 | Person("bob", 1),
807 | Person("li", 49),
808 | Person("alice", 5)
809 | ).toDS.select("age", "name").as(ds1.encoder)
810 |
811 | assertSmallDatasetEquality(ds2, ds1, ignoreColumnOrder = true)
812 | }
813 |
814 |     "correctly marks unequal schema fields" in {
815 | val sourceDF = spark.createDF(
816 | List(
817 | (1, 2.0),
818 | (5, 3.0)
819 | ),
820 | List(
821 | ("number", IntegerType, true),
822 | ("float", DoubleType, true)
823 | )
824 | )
825 |
826 | val expectedDF = spark.createDF(
827 | List(
828 | (1, "word", 1L),
829 | (5, "word", 2L)
830 | ),
831 | List(
832 | ("number", IntegerType, true),
833 | ("word", StringType, true),
834 | ("long", LongType, true)
835 | )
836 | )
837 |
838 | val e = intercept[DatasetSchemaMismatch] {
839 | assertSmallDatasetEquality(sourceDF, expectedDF)
840 | }
841 |
842 | e.assertColorDiff(Seq("float", "DoubleType", "MISSING"), Seq("word", "StringType", "StructField(long,LongType,true,{})"))
843 | }
844 |
845 |     "correctly marks schema with unequal metadata" in {
846 | val sourceDF = spark.createDF(
847 | List(
848 | (1, 2.0),
849 | (5, 3.0)
850 | ),
851 | List(
852 | ("number", IntegerType, true),
853 | ("float", DoubleType, true)
854 | )
855 | )
856 |
857 | val expectedDF = spark
858 | .createDF(
859 | List(
860 | (1, 2.0),
861 | (5, 3.0)
862 | ),
863 | List(
864 | ("number", IntegerType, true),
865 | ("float", DoubleType, true)
866 | )
867 | )
868 | .withColumn("float", col("float").as("float", new MetadataBuilder().putString("description", "a float").build()))
869 |
870 | val e = intercept[DatasetSchemaMismatch] {
871 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreMetadata = false)
872 | }
873 |
874 | e.assertColorDiff(
875 | Seq("{}"),
876 | Seq("{\"description\":\"a float\"}")
877 | )
878 | }
879 |
880 |     "can perform Dataset comparisons and ignore metadata" in {
881 | val ds1 = Seq(
882 | Person("juan", 5),
883 | Person("bob", 1),
884 | Person("li", 49),
885 | Person("alice", 5)
886 | ).toDS
887 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build()))
888 | .as[Person]
889 |
890 | val ds2 = Seq(
891 | Person("juan", 5),
892 | Person("bob", 1),
893 | Person("li", 49),
894 | Person("alice", 5)
895 | ).toDS
896 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build()))
897 | .as[Person]
898 |
899 | assertSmallDatasetEquality(ds2, ds1)
900 | }
901 |
902 |     "can perform Dataset comparisons and compare metadata" in {
903 | val ds1 = Seq(
904 | Person("juan", 5),
905 | Person("bob", 1),
906 | Person("li", 49),
907 | Person("alice", 5)
908 | ).toDS
909 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build()))
910 | .as[Person]
911 |
912 | val ds2 = Seq(
913 | Person("juan", 5),
914 | Person("bob", 1),
915 | Person("li", 49),
916 | Person("alice", 5)
917 | ).toDS
918 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build()))
919 | .as[Person]
920 |
921 | intercept[DatasetSchemaMismatch] {
922 | assertSmallDatasetEquality(ds2, ds1, ignoreMetadata = false)
923 | }
924 | }
925 | }
926 |
927 | "defaultSortDataset" - {
928 |
929 | "sorts a DataFrame by the column names in alphabetical order" in {
930 | val sourceDF = spark.createDF(
931 | List(
932 | (5, "bob"),
933 | (1, "phil"),
934 | (5, "anne")
935 | ),
936 | List(
937 | ("fun_level", IntegerType, true),
938 | ("name", StringType, true)
939 | )
940 | )
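    |       // defaultSortDataset sorts the rows using the columns in alphabetical name order, making the row order deterministic before comparison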
941 | val actualDF = defaultSortDataset(sourceDF)
942 | val expectedDF = spark.createDF(
943 | List(
944 | (1, "phil"),
945 | (5, "anne"),
946 | (5, "bob")
947 | ),
948 | List(
949 | ("fun_level", IntegerType, true),
950 | ("name", StringType, true)
951 | )
952 | )
953 | assertSmallDatasetEquality(actualDF, expectedDF)
954 | }
955 |
956 | }
957 |
958 | "assertApproximateDataFrameEquality" - {
959 |
960 | "does nothing if the DataFrames have the same schemas and content" in {
961 | val sourceDF = spark.createDF(
962 | List(
963 | (1.2),
964 | (5.1),
965 | (null)
966 | ),
967 | List(("number", DoubleType, true))
968 | )
969 | val expectedDF = spark.createDF(
970 | List(
971 | (1.2),
972 | (5.1),
973 | (null)
974 | ),
975 | List(("number", DoubleType, true))
976 | )
977 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
978 | }
979 |
980 | "throws an error if the rows are different" in {
981 | val sourceDF = spark.createDF(
982 | List(
983 | (100.9),
984 | (5.1)
985 | ),
986 | List(("number", DoubleType, true))
987 | )
988 | val expectedDF = spark.createDF(
989 | List(
990 | (1.2),
991 | (5.1)
992 | ),
993 | List(("number", DoubleType, true))
994 | )
995 | val e = intercept[DatasetContentMismatch] {
996 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
997 | }
998 | }
999 |
1000 |     "throws an error if the DataFrames have a different number of rows" in {
1001 | val sourceDF = spark.createDF(
1002 | List(
1003 | (1.2),
1004 | (5.1),
1005 | (8.8)
1006 | ),
1007 | List(("number", DoubleType, true))
1008 | )
1009 | val expectedDF = spark.createDF(
1010 | List(
1011 | (1.2),
1012 | (5.1)
1013 | ),
1014 | List(("number", DoubleType, true))
1015 | )
1016 | val e = intercept[DatasetCountMismatch] {
1017 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01)
1018 | }
1019 | }
1020 |
1021 | "can ignore the nullable property" in {
1022 | val sourceDF = spark.createDF(
1023 | List(
1024 | (1.2),
1025 | (5.1)
1026 | ),
1027 | List(("number", DoubleType, false))
1028 | )
1029 | val expectedDF = spark.createDF(
1030 | List(
1031 | (1.2),
1032 | (5.1)
1033 | ),
1034 | List(("number", DoubleType, true))
1035 | )
1036 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreNullable = true)
1037 | }
1038 |
1039 | "can ignore the column names" in {
1040 | val sourceDF = spark.createDF(
1041 | List(
1042 | (1.2),
1043 | (5.1),
1044 | (null)
1045 | ),
1046 | List(("BLAHBLBH", DoubleType, true))
1047 | )
1048 | val expectedDF = spark.createDF(
1049 | List(
1050 | (1.2),
1051 | (5.1),
1052 | (null)
1053 | ),
1054 | List(("number", DoubleType, true))
1055 | )
1056 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreColumnNames = true)
1057 | }
1058 |
1059 | "can work with precision and unordered comparison" in {
1060 | import spark.implicits._
1061 | val ds1 = Seq(
1062 | ("1", "10/01/2019", 26.762499999999996),
1063 | ("1", "11/01/2019", 26.762499999999996)
1064 | ).toDF("col_B", "col_C", "col_A")
1065 |
1066 | val ds2 = Seq(
1067 | ("1", "10/01/2019", 26.762499999999946),
1068 | ("1", "11/01/2019", 26.76249999999991)
1069 | ).toDF("col_B", "col_C", "col_A")
1070 |
1071 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
1072 | }
1073 |
1074 | "can work with precision and unordered comparison 2" in {
1075 | import spark.implicits._
1076 | val ds1 = Seq(
1077 | ("1", "10/01/2019", 26.762499999999996, "A"),
1078 | ("1", "10/01/2019", 26.762499999999996, "B")
1079 | ).toDF("col_B", "col_C", "col_A", "col_D")
1080 |
1081 | val ds2 = Seq(
1082 | ("1", "10/01/2019", 26.762499999999946, "A"),
1083 | ("1", "10/01/2019", 26.76249999999991, "B")
1084 | ).toDF("col_B", "col_C", "col_A", "col_D")
1085 |
1086 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
1087 | }
1088 |
1089 | "can work with precision and unordered comparison on nested column" in {
1090 | import spark.implicits._
1091 | val ds1 = Seq(
1092 | ("1", "10/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)),
1093 | ("1", "11/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996))
1094 | ).toDF("col_B", "col_C", "col_A", "col_D")
1095 |
1096 | val ds2 = Seq(
1097 | ("1", "11/01/2019", 26.7624999999999961, Seq(26.7624999999999961, 26.7624999999999961)),
1098 | ("1", "10/01/2019", 26.762499999999997, Seq(26.762499999999997, 26.762499999999997))
1099 | ).toDF("col_B", "col_C", "col_A", "col_D")
1100 |
1101 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)
1102 | }
1103 | }
1104 |
1105 | // "works with FloatType columns" - {
1106 | // val sourceDF = spark.createDF(
1107 | // List(
1108 | // (1.2),
1109 | // (5.1),
1110 | // (null)
1111 | // ),
1112 | // List(
1113 | // ("number", FloatType, true)
1114 | // )
1115 | // )
1116 | //
1117 | // val expectedDF = spark.createDF(
1118 | // List(
1119 | // (1.2),
1120 | // (5.1),
1121 | // (null)
1122 | // ),
1123 | // List(
1124 | // ("number", FloatType, true)
1125 | // )
1126 | // )
1127 | //
1128 | // assertApproximateDataFrameEquality(
1129 | // sourceDF,
1130 | // expectedDF,
1131 | // 0.01
1132 | // )
1133 | // }
1134 |
1135 | }
1136 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/ExamplesTest.scala:
--------------------------------------------------------------------------------
1 | //package com.github.mrpowers.spark.fast.tests
2 | //
3 | //import org.apache.spark.sql.types._
4 | //import SparkSessionExt._
5 | //
6 | //import org.scalatest.FreeSpec
7 | //
8 | //class ExamplesTest extends FreeSpec with SparkSessionTestWrapper with DataFrameComparer with ColumnComparer {
9 | //
10 | // "assertSmallDatasetEquality" - {
11 | //
12 | // "error when row counts don't match" in {
13 | //
14 | // val sourceDF = spark.createDF(
15 | // List(
16 | // (1),
17 | // (5)
18 | // ),
19 | // List(("number", IntegerType, true))
20 | // )
21 | //
22 | // val expectedDF = spark.createDF(
23 | // List(
24 | // (1),
25 | // (5),
26 | // (10)
27 | // ),
28 | // List(("number", IntegerType, true))
29 | // )
30 | //
31 | // assertSmallDatasetEquality(
32 | // sourceDF,
33 | // expectedDF
34 | // )
35 | //
36 | // }
37 | //
38 | // "error when schemas don't match" in {
39 | //
40 | // val sourceDF = spark.createDF(
41 | // List(
42 | // (1, "a"),
43 | // (5, "b")
44 | // ),
45 | // List(
46 | // ("number", IntegerType, true),
47 | // ("letter", StringType, true)
48 | // )
49 | // )
50 | //
51 | // val expectedDF = spark.createDF(
52 | // List(
53 | // (1, "a"),
54 | // (5, "b")
55 | // ),
56 | // List(
57 | // ("num", IntegerType, true),
58 | // ("letter", StringType, true)
59 | // )
60 | // )
61 | //
62 | // assertSmallDatasetEquality(
63 | // sourceDF,
64 | // expectedDF
65 | // )
66 | //
67 | // }
68 | //
69 | // "error when content doesn't match" in {
70 | //
71 | // val sourceDF = spark.createDF(
72 | // List(
73 | // (1, "z"),
74 | // (5, "b")
75 | // ),
76 | // List(
77 | // ("number", IntegerType, true),
78 | // ("letter", StringType, true)
79 | // )
80 | // )
81 | //
82 | // val expectedDF = spark.createDF(
83 | // List(
84 | // (1111, "a"),
85 | // (5, "b")
86 | // ),
87 | // List(
88 | // ("number", IntegerType, true),
89 | // ("letter", StringType, true)
90 | // )
91 | // )
92 | //
93 | // assertSmallDataFrameEquality(
94 | // sourceDF,
95 | // expectedDF
96 | // )
97 | //
98 | // }
99 | //
100 | // }
101 | //
102 | // "pretty column mismatch message" in {
103 | //
104 | // val df = spark.createDF(
105 | // List(
106 | // ("a", "z"),
107 | // ("b", "b")
108 | // ),
109 | // List(
110 | // ("letter1", StringType, true),
111 | // ("letter", StringType, true)
112 | // )
113 | // )
114 | //
115 | // assertColumnEquality(df, "letter1", "letter")
116 | //
117 | // }
118 | //
119 | //}
120 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/RDDComparerTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.scalatest.freespec.AnyFreeSpec
4 |
5 | class RDDComparerTest extends AnyFreeSpec with RDDComparer with SparkSessionTestWrapper {
6 |
7 | "contentMismatchMessage" - {
8 |
9 |     "returns a string to compare the expected and actual RDDs" in {
10 | val sourceData = List(
11 | ("cat"),
12 | ("dog"),
13 | ("frog")
14 | )
15 |
16 | val actualRDD = spark.sparkContext.parallelize(sourceData)
17 |
18 | val expectedData = List(
19 | ("man"),
20 | ("can"),
21 | ("pan")
22 | )
23 |
24 | val expectedRDD = spark.sparkContext.parallelize(expectedData)
25 |
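    |       // contentMismatchMessage renders the full contents of both RDDs for the failure report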
26 | val expected = """
27 | Actual RDD Content:
28 | cat
29 | dog
30 | frog
31 | Expected RDD Content:
32 | man
33 | can
34 | pan
35 | """
36 |
37 | assert(
38 | contentMismatchMessage(actualRDD, expectedRDD) == expected
39 | )
40 | }
41 |
42 | }
43 |
44 | "assertSmallRDDEquality" - {
45 |
46 | "does nothing if the RDDs have the same content" in {
47 | val sourceData = List(
48 | ("cat"),
49 | ("dog"),
50 | ("frog")
51 | )
52 |
53 | val sourceRDD = spark.sparkContext.parallelize(sourceData)
54 |
55 | val expectedData = List(
56 | ("cat"),
57 | ("dog"),
58 | ("frog")
59 | )
60 |
61 | val expectedRDD = spark.sparkContext.parallelize(expectedData)
62 |
63 | assertSmallRDDEquality(sourceRDD, expectedRDD)
64 | }
65 |
66 | "throws an error if the RDDs have different content" in {
67 | val sourceData = List(
68 | ("cat"),
69 | ("dog"),
70 | ("frog")
71 | )
72 | val sourceRDD = spark.sparkContext.parallelize(sourceData)
73 |
74 | val expectedData = List(
75 | ("mouse"),
76 | ("pig"),
77 | ("frog")
78 | )
79 |
80 | val expectedRDD = spark.sparkContext.parallelize(expectedData)
81 |
82 | val e = intercept[RDDContentMismatch] {
83 | assertSmallRDDEquality(sourceRDD, expectedRDD)
84 | }
85 | }
86 |
87 | }
88 |
89 | }
90 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/RowComparerTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.scalatest.freespec.AnyFreeSpec
4 |
5 | import org.apache.spark.sql.Row
6 |
7 | class RowComparerTest extends AnyFreeSpec {
8 |
9 | "areRowsEqual" - {
10 |
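    |     // the third argument to areRowsEqual is the numeric tolerance; 0.0 requires exact equality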
11 | "returns true for rows that contain the same elements" in {
12 | val r1 = Row("a", "b")
13 | val r2 = Row("a", "b")
14 | assert(
15 | RowComparer.areRowsEqual(r1, r2, 0.0)
16 | )
17 | }
18 |
19 |     "returns false for rows that don't contain the same elements" in {
20 | val r1 = Row("a", 3)
21 | val r2 = Row("a", 4)
22 | assert(
23 | !RowComparer.areRowsEqual(r1, r2, 0.0)
24 | )
25 | }
26 |
27 | }
28 |
29 | }
30 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/SchemaComparerTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
4 | import org.apache.spark.sql.types._
5 | import org.scalatest.freespec.AnyFreeSpec
6 |
7 | class SchemaComparerTest extends AnyFreeSpec {
8 |
9 | "equals" - {
10 |
11 | "returns true if the schemas are equal" in {
12 | val s1 = StructType(
13 | Seq(
14 | StructField("something", StringType, true),
15 | StructField("mood", StringType, true)
16 | )
17 | )
18 | val s2 = StructType(
19 | Seq(
20 | StructField("something", StringType, true),
21 | StructField("mood", StringType, true)
22 | )
23 | )
24 | assert(SchemaComparer.equals(s1, s2))
25 | }
26 |
27 | "works for single column schemas" in {
28 | val s1 = StructType(
29 | Seq(
30 | StructField("something", StringType, true)
31 | )
32 | )
33 | val s2 = StructType(
34 | Seq(
35 | StructField("something", StringType, false)
36 | )
37 | )
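    |       // the positional third argument is ignoreNullable (passed by name in later tests)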
38 | assert(SchemaComparer.equals(s1, s2, true))
39 | }
40 |
41 | "returns false if the schemas aren't equal" in {
42 | val s1 = StructType(
43 | Seq(
44 | StructField("something", StringType, true)
45 | )
46 | )
47 | val s2 = StructType(
48 | Seq(
49 | StructField("something", StringType, true),
50 | StructField("mood", StringType, true)
51 | )
52 | )
53 | assert(!SchemaComparer.equals(s1, s2))
54 | }
55 |
56 | "can ignore the nullable flag when determining equality" in {
57 | val s1 = StructType(
58 | Seq(
59 | StructField("something", StringType, true),
60 | StructField("mood", StringType, true)
61 | )
62 | )
63 | val s2 = StructType(
64 | Seq(
65 | StructField("something", StringType, false),
66 | StructField("mood", StringType, true)
67 | )
68 | )
69 | assert(SchemaComparer.equals(s1, s2, ignoreNullable = true))
70 | }
71 |
72 |     "does not ignore nullable when determining equality if ignoreNullable is false" in {
73 | val s1 = StructType(
74 | Seq(
75 | StructField("something", StringType, true),
76 | StructField("mood", StringType, true)
77 | )
78 | )
79 | val s2 = StructType(
80 | Seq(
81 | StructField("something", StringType, false),
82 | StructField("mood", StringType, true)
83 | )
84 | )
85 | assert(!SchemaComparer.equals(s1, s2))
86 | }
87 |
88 | "can ignore the nullable flag when determining equality on complex data types" in {
89 | val s1 = StructType(
90 | Seq(
91 | StructField("something", StringType, true),
92 | StructField("array", ArrayType(StringType, containsNull = true), true),
93 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
94 | StructField(
95 | "struct",
96 | StructType(
97 | StructType(
98 | Seq(
99 | StructField("something", StringType, false),
100 | StructField("mood", ArrayType(StringType, containsNull = false), true)
101 | )
102 | )
103 | ),
104 | true
105 | )
106 | )
107 | )
108 | val s2 = StructType(
109 | Seq(
110 | StructField("something", StringType, false),
111 | StructField("array", ArrayType(StringType, containsNull = false), true),
112 | StructField("map", MapType(StringType, StringType, valueContainsNull = true), true),
113 | StructField(
114 | "struct",
115 | StructType(
116 | StructType(
117 | Seq(
118 | StructField("something", StringType, false),
119 | StructField("mood", ArrayType(StringType, containsNull = true), true)
120 | )
121 | )
122 | ),
123 | false
124 | )
125 | )
126 | )
127 | assert(SchemaComparer.equals(s1, s2, ignoreNullable = true))
128 | }
129 |
130 |     "does not ignore nullable when determining equality on complex data types if ignoreNullable is false" in {
131 | val s1 = StructType(
132 | Seq(
133 | StructField("something", StringType, true),
134 | StructField("array", ArrayType(StringType, containsNull = true), true),
135 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
136 | StructField(
137 | "struct",
138 | StructType(
139 | StructType(
140 | Seq(
141 | StructField("something", StringType, false),
142 | StructField("mood", ArrayType(StringType, containsNull = false), true)
143 | )
144 | )
145 | ),
146 | true
147 | )
148 | )
149 | )
150 | val s2 = StructType(
151 | Seq(
152 | StructField("something", StringType, false),
153 | StructField("array", ArrayType(StringType, containsNull = false), true),
154 | StructField("map", MapType(StringType, StringType, valueContainsNull = true), true),
155 | StructField(
156 | "struct",
157 | StructType(
158 | StructType(
159 | Seq(
160 | StructField("something", StringType, false),
161 | StructField("mood", ArrayType(StringType, containsNull = true), true)
162 | )
163 | )
164 | ),
165 | false
166 | )
167 | )
168 | )
169 | assert(!SchemaComparer.equals(s1, s2))
170 | }
171 |
172 | "can ignore the column names flag when determining equality" in {
173 | val s1 = StructType(
174 | Seq(
175 | StructField("these", StringType, true),
176 | StructField("are", StringType, true)
177 | )
178 | )
179 | val s2 = StructType(
180 | Seq(
181 | StructField("very", StringType, true),
182 | StructField("different", StringType, true)
183 | )
184 | )
185 | assert(SchemaComparer.equals(s1, s2, ignoreColumnNames = true))
186 | }
187 |
188 | "can ignore the column order when determining equality" in {
189 | val s1 = StructType(
190 | Seq(
191 | StructField("these", StringType, true),
192 | StructField("are", StringType, true)
193 | )
194 | )
195 | val s2 = StructType(
196 | Seq(
197 | StructField("are", StringType, true),
198 | StructField("these", StringType, true)
199 | )
200 | )
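201 |       // SchemaComparer.equals ignores column order by default, so the reordered
202 |       // schemas are reported equal without passing any flag.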
201 | assert(SchemaComparer.equals(s1, s2))
202 | }
203 |
204 | "can ignore the column order when determining equality of complex type" in {
205 | val s1 = StructType(
206 | Seq(
207 | StructField("array", ArrayType(StringType, containsNull = true), true),
208 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
209 | StructField("something", StringType, true),
210 | StructField(
211 | "struct",
212 | StructType(
213 | StructType(
214 | Seq(
215 | StructField("mood", ArrayType(StringType, containsNull = false), true),
216 | StructField("something", StringType, false)
217 | )
218 | )
219 | ),
220 | true
221 | )
222 | )
223 | )
224 | val s2 = StructType(
225 | Seq(
226 | StructField("something", StringType, true),
227 | StructField("array", ArrayType(StringType, containsNull = true), true),
228 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
229 | StructField(
230 | "struct",
231 | StructType(
232 | StructType(
233 | Seq(
234 | StructField("something", StringType, false),
235 | StructField("mood", ArrayType(StringType, containsNull = false), true)
236 | )
237 | )
238 | ),
239 | true
240 | )
241 | )
242 | )
243 | assert(SchemaComparer.equals(s1, s2))
244 | }
245 |
246 |     "display schema diff as tree for schemas with different nesting depths" in {
247 | val s1 = StructType(
248 | Seq(
249 | StructField("array", ArrayType(StringType, containsNull = true), true),
250 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
251 | StructField("something", StringType, true),
252 | StructField(
253 | "struct",
254 | StructType(
255 | StructType(
256 | Seq(
257 | StructField("mood", ArrayType(StringType, containsNull = false), true),
258 | StructField("something", StringType, false),
259 | StructField(
260 | "something2",
261 | StructType(
262 | Seq(
263 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
264 | StructField("something2", StringType, false)
265 | )
266 | ),
267 | false
268 | )
269 | )
270 | )
271 | ),
272 | true
273 | )
274 | )
275 | )
276 | val s2 = StructType(
277 | Seq(
278 | StructField("array", ArrayType(StringType, containsNull = true), true),
279 | StructField("something", StringType, true),
280 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
281 | StructField(
282 | "struct",
283 | StructType(
284 | StructType(
285 | Seq(
286 | StructField("something", StringType, false),
287 | StructField("mood", ArrayType(StringType, containsNull = false), true),
288 | StructField(
289 | "something3",
290 | StructType(
291 | Seq(
292 | StructField("mood3", ArrayType(StringType, containsNull = false), true)
293 | )
294 | ),
295 | false
296 | )
297 | )
298 | )
299 | ),
300 | true
301 | ),
302 | StructField("norma2", StringType, false)
303 | )
304 | )
305 |
306 | val e = intercept[DatasetSchemaMismatch] {
307 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
308 | }
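309 |       // Color convention in the expected diff below: SGR 90 (gray) marks segments that
310 |       // match on both sides, SGR 31 (red) actual-side differences, and SGR 32 (green)
311 |       // expected-side differences.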
309 | val expectedMessage = """Diffs
310 | |
311 | |Actual Schema Expected Schema
312 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
313 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m)
314 | |\u001b[90m|--\u001b[39m \u001b[31mmap\u001b[39m : \u001b[31mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
315 | |\u001b[90m|--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32mmap\u001b[39m : \u001b[32mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
316 | |\u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
317 | |\u001b[90m| |--\u001b[39m \u001b[31mmood\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m)
318 | |\u001b[31m| | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[32m| |--\u001b[39m \u001b[32mmood\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
319 | |\u001b[31m| |--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) \u001b[32m| | |--\u001b[39m \u001b[32melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
320 | |\u001b[90m| |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething3\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
321 | |\u001b[90m| | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[32mmood3\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
322 | |\u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
323 | |\u001b[31m| | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[32m|--\u001b[39m \u001b[32mnorma2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
324 | |""".stripMargin
325 |
326 | assert(e.getMessage == expectedMessage)
327 | }
328 |
329 |     "display schema diff as tree with array of struct" in {
330 | val s1 = StructType(
331 | Seq(
332 | StructField("array", ArrayType(StructType(Seq(StructField("arrayChild1", StringType))), containsNull = true), true)
333 | )
334 | )
335 | val s2 = StructType(
336 | Seq(
337 | StructField("array", ArrayType(StructType(Seq(StructField("arrayChild2", IntegerType))), containsNull = false), true)
338 | )
339 | )
341 |
342 | val e = intercept[DatasetSchemaMismatch] {
343 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
344 | }
345 |
346 | val expectedMessage = """Diffs
347 | |
348 | |Actual Schema Expected Schema
349 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
350 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mstruct\u001b[39m (containsNull = \u001b[31mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32mstruct\u001b[39m (containsNull = \u001b[32mfalse\u001b[39m)
351 | |\u001b[90m| | |--\u001b[39m \u001b[31marrayChild1\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[32marrayChild2\u001b[39m : \u001b[32minteger\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
352 | |""".stripMargin
353 |
354 | assert(e.getMessage == expectedMessage)
355 | }
356 |
357 |     "display schema diff as tree with array of array of struct" in {
358 | val s1 = StructType(
359 | Seq(
360 | StructField("array", ArrayType(ArrayType(StructType(Seq(StructField("arrayChild1", StringType))), containsNull = true)))
361 | )
362 | )
363 | val s2 = StructType(
364 | Seq(
365 | StructField("array", ArrayType(ArrayType(StructType(Seq(StructField("arrayChild2", IntegerType))), containsNull = false)))
366 | )
367 | )
369 |
370 | val e = intercept[DatasetSchemaMismatch] {
371 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
372 | }
373 |
374 | val expectedMessage = """Diffs
375 | |
376 | |Actual Schema Expected Schema
377 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
378 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31marray\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32marray\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m)
379 | |\u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mstruct\u001b[39m (containsNull = \u001b[31mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32mstruct\u001b[39m (containsNull = \u001b[32mfalse\u001b[39m)
380 | |\u001b[90m| | | |--\u001b[39m \u001b[31marrayChild1\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[32marrayChild2\u001b[39m : \u001b[32minteger\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
381 | |""".stripMargin
382 |
383 | assert(e.getMessage == expectedMessage)
384 | }
385 |
386 |     "display schema diff as tree with array of simple type" in {
387 | val s1 = StructType(
388 | Seq(
389 | StructField("array", ArrayType(StringType, containsNull = true), true)
390 | )
391 | )
392 | val s2 = StructType(
393 | Seq(
394 | StructField("array", ArrayType(IntegerType, containsNull = true), true)
395 | )
396 | )
398 |
399 | val e = intercept[DatasetSchemaMismatch] {
400 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
401 | }
402 |
403 | val expectedMessage = """Diffs
404 | |
405 | |Actual Schema Expected Schema
406 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
407 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32minteger\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m)
408 | |""".stripMargin
409 |
410 | assert(e.getMessage == expectedMessage)
411 | }
412 |
413 |     "display schema diff as tree for a wide schema" in {
414 | val s1 = StructType(
415 | Seq(
416 | StructField("array", ArrayType(StringType, containsNull = true), true),
417 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
418 | StructField("something", StringType, true),
419 | StructField(
420 | "struct",
421 | StructType(
422 | StructType(
423 | Seq(
424 | StructField("mood", ArrayType(StringType, containsNull = false), true),
425 | StructField("something", StringType, false),
426 | StructField(
427 | "something2",
428 | StructType(
429 | Seq(
430 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
431 | StructField(
432 | "something2",
433 | StructType(
434 | Seq(
435 | StructField("mood", ArrayType(StringType, containsNull = false), true),
436 | StructField("something", StringType, false),
437 | StructField(
438 | "something2",
439 | StructType(
440 | Seq(
441 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
442 | StructField("something2", StringType, false)
443 | )
444 | ),
445 | false
446 | )
447 | )
448 | ),
449 | false
450 | )
451 | )
452 | ),
453 | false
454 | )
455 | )
456 | )
457 | ),
458 | true
459 | )
460 | )
461 | )
462 | val s2 = StructType(
463 | Seq(
464 | StructField("array", ArrayType(StringType, containsNull = true), true),
465 | StructField("something", StringType, true),
466 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
467 | StructField(
468 | "struct",
469 | StructType(
470 | StructType(
471 | Seq(
472 | StructField("something", StringType, false),
473 | StructField("mood", ArrayType(StringType, containsNull = false), true),
474 | StructField(
475 | "something3",
476 | StructType(
477 | Seq(
478 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
479 | StructField(
480 | "something2",
481 | StructType(
482 | Seq(
483 | StructField("mood", ArrayType(StringType, containsNull = false), true),
484 | StructField("something", StringType, false),
485 | StructField(
486 | "something2",
487 | StructType(
488 | Seq(
489 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
490 | StructField("something2", StringType, false)
491 | )
492 | ),
493 | false
494 | )
495 | )
496 | ),
497 | false
498 | )
499 | )
500 | ),
501 | false
502 | )
503 | )
504 | )
505 | ),
506 | true
507 | ),
508 | StructField("norma2", StringType, false)
509 | )
510 | )
511 |
512 | val e = intercept[DatasetSchemaMismatch] {
513 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
514 | }
515 | val expectedMessage = """Diffs
516 | |
517 | |Actual Schema Expected Schema
518 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
519 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m)
520 | |\u001b[90m|--\u001b[39m \u001b[31mmap\u001b[39m : \u001b[31mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
521 | |\u001b[90m|--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32mmap\u001b[39m : \u001b[32mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
522 | |\u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
523 | |\u001b[90m| |--\u001b[39m \u001b[31mmood\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m)
524 | |\u001b[31m| | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[32m| |--\u001b[39m \u001b[32mmood\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
525 | |\u001b[31m| |--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) \u001b[32m| | |--\u001b[39m \u001b[32melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
526 | |\u001b[90m| |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething3\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
527 | |\u001b[90m| | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
528 | |\u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
529 | |\u001b[90m| | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
530 | |\u001b[90m| | | |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
531 | |\u001b[90m| | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
532 | |\u001b[90m| | | |--\u001b[39m \u001b[90msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
533 | |\u001b[90m| | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
534 | |\u001b[90m| | | | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | | | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
535 | |\u001b[90m| | | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
536 | |\u001b[90m| | | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
537 | | \u001b[90m|--\u001b[39m \u001b[32mnorma2\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m)
538 | |""".stripMargin
539 |
540 | assert(e.getMessage == expectedMessage)
541 | }
542 |
543 |     "display schema diff as tree when the actual schema has more columns than the expected" in {
544 | val s1 = StructType(
545 | Seq(
546 | StructField("array", ArrayType(StringType, containsNull = true), true),
547 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
548 | StructField("something", StringType, true),
549 | StructField(
550 | "struct",
551 | StructType(
552 | StructType(
553 | Seq(
554 | StructField("mood", ArrayType(StringType, containsNull = false), true),
555 | StructField("something", StringType, false),
556 | StructField(
557 | "something2",
558 | StructType(
559 | Seq(
560 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
561 | StructField(
562 | "something2",
563 | StructType(
564 | Seq(
565 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
566 | StructField(
567 | "something2",
568 | StructType(
569 | Seq(
570 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
571 | StructField("something2", StringType, false)
572 | )
573 | ),
574 | false
575 | )
576 | )
577 | ),
578 | false
579 | )
580 | )
581 | ),
582 | false
583 | )
584 | )
585 | )
586 | ),
587 | true
588 | )
589 | )
590 | )
591 | val s2 = StructType(
592 | Seq(
593 | StructField("array", ArrayType(StringType, containsNull = true), true),
594 | StructField("something", StringType, true),
595 | StructField(
596 | "struct",
597 | StructType(
598 | StructType(
599 | Seq(
600 | StructField("something", StringType, false),
601 | StructField("mood", ArrayType(StringType, containsNull = false), true),
602 | StructField(
603 | "something3",
604 | StructType(
605 | Seq(
606 | StructField("mood3", ArrayType(StringType, containsNull = false), true)
607 | )
608 | ),
609 | false
610 | )
611 | )
612 | )
613 | ),
614 | true
615 | )
616 | )
617 | )
618 |
619 | val e = intercept[DatasetSchemaMismatch] {
620 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
621 | }
622 |
623 | val expectedMessage = """Diffs
624 | |
625 | |Actual Schema Expected Schema
626 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
627 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m)
628 | |\u001b[90m|--\u001b[39m \u001b[31mmap\u001b[39m : \u001b[31mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
629 | |\u001b[90m|--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32mstruct\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
630 | |\u001b[31m|--\u001b[39m \u001b[31mstruct\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) \u001b[32m| |--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m)
631 | |\u001b[90m| |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m)
632 | |\u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
633 | |\u001b[90m| |--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething3\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m)
634 | |\u001b[31m| |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) \u001b[32m| | |--\u001b[39m \u001b[32mmood3\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[32mtrue\u001b[39m)
635 | |\u001b[31m| | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[32m| | | |--\u001b[39m \u001b[32melement\u001b[39m : \u001b[32mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m)
636 | |\u001b[31m| | | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[31mfalse\u001b[39m)
637 | |\u001b[31m| | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mfalse\u001b[39m)
638 | |\u001b[31m| | | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m)
639 | |\u001b[31m| | | | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[31mfalse\u001b[39m)
640 | |\u001b[31m| | | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mfalse\u001b[39m)
641 | |\u001b[31m| | | | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m)
642 | |\u001b[31m| | | | | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[31mfalse\u001b[39m)
643 | |\u001b[31m| | | | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[31mfalse\u001b[39m)
644 | |""".stripMargin
645 |
646 | assert(e.getMessage == expectedMessage)
647 | }
648 | }
649 | }
650 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/SeqLikesExtensionsTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.scalatest.freespec.AnyFreeSpec
4 | import SeqLikesExtensions._
5 |
6 | class SeqLikesExtensionsTest extends AnyFreeSpec with SparkSessionTestWrapper {
7 |
8 | "check equality" - {
9 | import spark.implicits._
10 |
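11 |     // approximateSameElements (from SeqLikesExtensions) compares the two sequences
12 |     // pairwise with the supplied predicate, so Row equality can be delegated to
13 |     // RowComparer.areRowsEqual, optionally with a numeric tolerance.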
11 | "check equal Seq" in {
12 | val source = Seq(
13 | ("juan", 5),
14 | ("bob", 1),
15 | ("li", 49),
16 | ("alice", 5)
17 | )
18 |
19 | val expected = Seq(
20 | ("juan", 5),
21 | ("bob", 1),
22 | ("li", 49),
23 | ("alice", 5)
24 | )
25 |
26 | assert(source.approximateSameElements(expected, (s1, s2) => s1 == s2))
27 | }
28 |
29 | "check equal Seq[Row]" in {
30 |
31 | val source = Seq(
32 | ("juan", 5),
33 | ("bob", 1),
34 | ("li", 49),
35 | ("alice", 5)
36 | ).toDF.collect().toSeq
37 |
38 | val expected = Seq(
39 | ("juan", 5),
40 | ("bob", 1),
41 | ("li", 49),
42 | ("alice", 5)
43 | ).toDF.collect()
44 |
45 | assert(source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _)))
46 | }
47 |
48 | "check unequal Seq[Row]" in {
49 |
50 | val source = Seq(
51 | ("juan", 5),
52 | ("bob", 1),
53 | ("li", 49),
54 | ("alice", 5)
55 | ).toDF.collect().toSeq
56 |
57 | val expected = Seq(
58 | ("juan", 5),
59 | ("bob", 1),
60 | ("li", 40),
61 | ("alice", 5)
62 | ).toDF.collect()
63 |
64 | assert(!source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _)))
65 | }
66 |
67 | "check equal Seq[Row] with tolerance" in {
68 | val source = Seq(
69 | ("juan", 12.00000000001),
70 | ("bob", 1.00000000001),
71 | ("li", 49.00000000001),
72 | ("alice", 5.00000000001)
73 | ).toDF.collect().toSeq
74 |
75 | val expected = Seq(
76 | ("juan", 12.0),
77 | ("bob", 1.0),
78 | ("li", 49.0),
79 | ("alice", 5.0)
80 | ).toDF.collect()
81 | assert(source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, 0.0000000002)))
82 | }
83 |
84 |     "check equal IndexedSeq[Row] with tolerance" in {
85 | val source = Seq(
86 | ("juan", 12.00000000001),
87 | ("bob", 1.00000000001),
88 | ("li", 49.00000000001),
89 | ("alice", 5.00000000001)
90 | ).toDF.collect().toIndexedSeq
91 |
92 | val expected = Seq(
93 | ("juan", 12.0),
94 | ("bob", 1.0),
95 | ("li", 49.0),
96 | ("alice", 5.0)
97 | ).toDF.collect().toIndexedSeq
98 |
99 | assert(source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, 0.0000000002)))
100 | }
101 |
102 |     "check unequal Seq[Row] with tolerance" in {
103 | val source = Seq(
104 | ("juan", 12.00000000002),
105 | ("bob", 1.00000000002),
106 | ("li", 49.00000000002),
107 | ("alice", 5.00000000002)
108 | ).toDF.collect().toSeq
109 |
110 | val expected = Seq(
111 |         ("juan", 12.0),
112 |         ("bob", 1.0),
113 |         ("li", 49.0),
114 |         ("alice", 5.0)
115 | ).toDF.collect()
116 |
117 |       assert(!source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, 0.00000000001)))
118 | }
119 |
120 |     "check unequal IndexedSeq[Row] with tolerance" in {
121 | val source = Seq(
122 | ("juan", 12.00000000002),
123 | ("bob", 1.00000000002),
124 | ("li", 49.00000000002),
125 | ("alice", 5.00000000002)
126 | ).toDF.collect().toIndexedSeq
127 |
128 | val expected = Seq(
129 |         ("juan", 12.0),
130 |         ("bob", 1.0),
131 |         ("li", 49.0),
132 |         ("alice", 5.0)
133 | ).toDF.collect().toIndexedSeq
134 |
135 |       assert(!source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, 0.00000000001)))
136 | }
137 | }
138 | }
139 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/SparkSessionExt.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
4 | import org.apache.spark.sql.{DataFrame, Row, SparkSession}
5 | import org.apache.spark.sql.types.{DataType, StructField, StructType}
6 |
7 | object SparkSessionExt {
8 |
9 | implicit class SparkSessionMethods(spark: SparkSession) {
10 |
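11 |     // Accepts either ready-made StructFields or (name, dataType, nullable) tuples
12 |     // and normalizes both shapes to StructField.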
11 | private def asSchema[U](fields: List[U]): List[StructField] = {
12 | fields.map {
13 |         case x: StructField => x
14 | case (name: String, dataType: DataType, nullable: Boolean) =>
15 | StructField(name, dataType, nullable)
16 | }
17 | }
18 |
19 |     /**
20 |      * Creates a DataFrame like `createDataFrame`, but with better syntax: spark-daria defined a `createDF` method that combines the terse
21 |      * syntax of `toDF` with the control of `createDataFrame`.
22 |      *
23 |      * {{{ spark.createDF(List(("bob", 45), ("liz", 25), ("freeman", 32)), List(("name", StringType, true), ("age", IntegerType, false))) }}}
24 |      *
25 |      * The `createDF` method can also be used with lists of `Row` and `StructField` objects:
26 |      *
27 |      * {{{ spark.createDF(List(Row("bob", 45), Row("liz", 25), Row("freeman", 32)),
28 |      *   List(StructField("name", StringType, true), StructField("age", IntegerType, false))) }}}
29 |      */
30 | def createDF[U, T](rowData: List[U], fields: List[T]): DataFrame = {
31 | spark.createDataFrame(
32 | spark.sparkContext.parallelize(rowData.asRows),
33 | StructType(asSchema(fields))
34 | )
35 | }
36 |
37 | }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/SparkSessionTestWrapper.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import org.apache.spark.sql.SparkSession
4 |
5 | trait SparkSessionTestWrapper {
6 |
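7 |   // One lazily created local session shared across suites: a single shuffle partition
8 |   // keeps small test jobs fast, and the ERROR log level silences Spark's INFO output.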
7 | lazy val spark: SparkSession = {
8 | val session = SparkSession
9 | .builder()
10 | .master("local")
11 | .appName("spark session")
12 | .config("spark.sql.shuffle.partitions", "1")
13 | .getOrCreate()
14 | session.sparkContext.setLogLevel("ERROR")
15 | session
16 | }
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/github/mrpowers/spark/fast/tests/TestUtilsExt.scala:
--------------------------------------------------------------------------------
1 | package com.github.mrpowers.spark.fast.tests
2 |
3 | import scala.util.matching.Regex
4 |
5 | object TestUtilsExt {
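6 |   // Matches one ANSI escape sequence (e.g. ESC[31m) together with the text it colors,
7 |   // up to but not including the next escape sequence.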
6 | val coloredStringPattern: Regex = "(\u001B\\[\\d{1,2}m)([\\s\\S]*?)(?=\u001B\\[\\d{1,2}m)".r
7 | implicit class StringOps(s: String) {
8 | def extractColorGroup: Map[String, Seq[String]] = coloredStringPattern
9 | .findAllMatchIn(s)
10 | .map(m => (m.group(1), m.group(2)))
11 | .toSeq
12 | .groupBy(_._1)
13 | .mapValues(_.map(_._2))
14 | .view
15 | .toMap
16 |
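17 |     // Color convention in the comparer error messages: green segments carry
18 |     // expected-side values, red segments carry actual-side values.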
17 |     def assertColorDiff(actual: Seq[String], expected: Seq[String]): Unit = {
18 |       val colorGroups = extractColorGroup
19 |       val expectedColorGroup = colorGroups.get(Console.GREEN)
20 |       val actualColorGroup = colorGroups.get(Console.RED)
21 |       assert(expectedColorGroup.contains(expected))
22 |       assert(actualColorGroup.contains(actual))
23 |     }
24 | }
25 |
26 | implicit class ExceptionOps(e: Exception) {
27 | def assertColorDiff(actual: Seq[String], expected: Seq[String]): Unit = e.getMessage.assertColorDiff(actual, expected)
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/docs/about/README.md:
--------------------------------------------------------------------------------
1 | ../../README.md
--------------------------------------------------------------------------------
/images/assertColumnEquality_error_message.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertColumnEquality_error_message.png
--------------------------------------------------------------------------------
/images/assertSchemaEquality_tree_message.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSchemaEquality_tree_message.png
--------------------------------------------------------------------------------
/images/assertSmallDataFrameEquality_DatasetContentMissmatch_message.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDataFrameEquality_DatasetContentMissmatch_message.png
--------------------------------------------------------------------------------
/images/assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png
--------------------------------------------------------------------------------
/images/assertSmallDataFrameEquality_error_message.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDataFrameEquality_error_message.png
--------------------------------------------------------------------------------
/images/assertSmallDatasetEquality_error_message.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDatasetEquality_error_message.png
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.10.1
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | logLevel := Level.Warn
2 |
3 | resolvers += Resolver.bintrayIvyRepo("s22s", "sbt-plugins")
4 |
5 | resolvers += Resolver.typesafeRepo("releases")
6 |
7 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")
8 |
9 | addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.9.0")
10 |
11 | addSbtPlugin("org.typelevel" % "laika-sbt" % "1.2.0")
12 |
13 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.3")
--------------------------------------------------------------------------------
/scripts/multi_spark_releases.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Link to the spark-daria GitHub release script: https://github.com/MrPowers/spark-daria/blob/master/scripts/github_release.sh
4 | # Clone the spark-daria repo and pass the path to that release script as the first argument to this script.
5 |
6 | SPARK_DARIA_GITHUB_RELEASE=$1
7 | if [ -z "$SPARK_DARIA_GITHUB_RELEASE" ]
8 | then
9 | echo "spark-daria github_release script path must be set"
10 | exit 1
11 | fi
12 |
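13 | # Note: `sed -i ''` below is the BSD/macOS form of in-place editing; GNU sed uses `sed -i` with no suffix argument.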
13 | for sparkVersion in 2.2.0 2.2.1 2.3.0; do
14 | echo $sparkVersion
15 | sed -i '' "s/^val sparkVersion.*/val sparkVersion = \"$sparkVersion\"/" build.sbt
16 | $SPARK_DARIA_GITHUB_RELEASE package
17 | done
--------------------------------------------------------------------------------