├── .github ├── CODEOWNERS └── workflows │ ├── ci-quality.yml │ ├── ci.yml │ ├── docs.yml │ └── release.yml ├── .gitignore ├── .scalafmt.conf ├── LICENSE ├── README.md ├── benchmarks └── src │ └── main │ └── scala │ └── com │ └── github │ └── mrpowers │ └── spark │ └── fast │ └── tests │ ├── ColumnComparerBenchmark.scala │ └── DataFrameComparerBenchmark.scala ├── build.sbt ├── core └── src │ ├── main │ └── scala │ │ └── com │ │ └── github │ │ └── mrpowers │ │ └── spark │ │ └── fast │ │ └── tests │ │ ├── ArrayUtil.scala │ │ ├── ColumnComparer.scala │ │ ├── DataFrameComparer.scala │ │ ├── DataFramePrettyPrint.scala │ │ ├── DatasetComparer.scala │ │ ├── ProductUtil.scala │ │ ├── RDDComparer.scala │ │ ├── RddHelpers.scala │ │ ├── RowComparer.scala │ │ ├── SchemaComparer.scala │ │ ├── SchemaDiffOutputFormat.scala │ │ ├── SeqLikesExtensions.scala │ │ └── ufansi │ │ ├── Fansi.scala │ │ └── FansiExtensions.scala │ └── test │ ├── resources │ └── log4j.properties │ └── scala │ └── com │ └── github │ └── mrpowers │ └── spark │ └── fast │ └── tests │ ├── ArrayUtilTest.scala │ ├── ColumnComparerTest.scala │ ├── DataFrameComparerTest.scala │ ├── DataFramePrettyPrintTest.scala │ ├── DatasetComparerTest.scala │ ├── ExamplesTest.scala │ ├── RDDComparerTest.scala │ ├── RowComparerTest.scala │ ├── SchemaComparerTest.scala │ ├── SeqLikesExtensionsTest.scala │ ├── SparkSessionExt.scala │ ├── SparkSessionTestWrapper.scala │ └── TestUtilsExt.scala ├── docs └── about │ └── README.md ├── images ├── assertColumnEquality_error_message.png ├── assertSchemaEquality_tree_message.png ├── assertSmallDataFrameEquality_DatasetContentMissmatch_message.png ├── assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png ├── assertSmallDataFrameEquality_error_message.png └── assertSmallDatasetEquality_error_message.png ├── project ├── build.properties └── plugins.sbt └── scripts └── multi_spark_releases.sh /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @alfonsorr 2 | * @SemyonSinchenko 3 | -------------------------------------------------------------------------------- /.github/workflows/ci-quality.yml: -------------------------------------------------------------------------------- 1 | name: CI-quality 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | 8 | jobs: 9 | build: 10 | strategy: 11 | fail-fast: false 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: coursier/cache-action@v6 16 | with: 17 | extraKey: quality-check 18 | - uses: coursier/setup-action@v1 19 | with: 20 | jvm: zulu:8 21 | - name: Setup sbt launcher 22 | uses: sbt/setup-sbt@v1 23 | with: 24 | sbt-version: 1.10.1 25 | - name: ScalaFmt 26 | id: fmt 27 | run: sbt scalafmtCheckAll 28 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | pull_request: 4 | 5 | jobs: 6 | build: 7 | strategy: 8 | fail-fast: false 9 | matrix: 10 | spark: ["3.1.3","3.2.4", "3.3.4", "3.4.3", "3.5.3"] 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: coursier/cache-action@v6 15 | with: 16 | extraKey: ${{ matrix.spark }} 17 | - uses: coursier/setup-action@v1 18 | with: 19 | jvm: zulu:8 20 | - name: Setup sbt launcher 21 | uses: sbt/setup-sbt@v1 22 | with : 23 | sbt-version: 1.10.1 24 | - name: Test 25 | run: sbt -Dspark.version=${{ matrix.spark }} +test 26 | - name: Benchmark 27 | 
run: sbt -Dspark.version=${{ matrix.spark }} +benchmarks/Jmh/run 28 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | on: 3 | workflow_dispatch: 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v1 10 | - uses: olafurpg/setup-scala@v10 11 | - name: Setup sbt launcher 12 | uses: sbt/setup-sbt@v1 13 | with: 14 | sbt-version: 1.10.1 15 | - name: Build docs 16 | run: sbt "project docs" laikaSite 17 | - name: Deploy to GH Pages 18 | uses: peaceiris/actions-gh-pages@v4 19 | with: 20 | github_token: ${{ secrets.GITHUB_TOKEN }} 21 | publish_dir: ./target/docs/site 22 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | tags: [ "*" ] 5 | branches: 6 | - main 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: coursier/setup-action@v1 13 | with: 14 | jvm: zulu:11 15 | apps: sbt scala scalafmt 16 | - run: sbt -Dsun.net.client.defaultReadTimeout=60000 -Dsun.net.client.defaultConnectTimeout=60000 -Dspark.version=3.5.3 -v ci-release 17 | env: 18 | PGP_PASSPHRASE: ${{secrets.PGP_PASSPHRASE}} 19 | PGP_SECRET: ${{secrets.PGP_LONG_ID}} 20 | SONATYPE_PASSWORD: ${{secrets.SONATYPE_PASSWORD}} 21 | SONATYPE_USERNAME: ${{secrets.SONATYPE_USERNAME}} 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | lib_managed/ 3 | src_managed/ 4 | project/boot/ 5 | 6 | .DS_Store 7 | 8 | .idea 9 | *.iml 10 | 11 | *.swp 12 | 13 | *.class 14 | *.log 15 | 16 | # sbt specific 17 | .cache 18 | .history 19 | .lib/ 20 | dist/* 21 | target/ 22 | lib_managed/ 23 | src_managed/ 24 | project/boot/ 25 | project/plugins/project/ 26 | **/.bloop 27 | .bsp/ 28 | 29 | # Scala-IDE specific 30 | .scala_dependencies 31 | .worksheet 32 | 33 | .idea 34 | 35 | # ENSIME specific 36 | .ensime_cache/ 37 | .ensime 38 | 39 | .metals/ 40 | .ammonite/ 41 | metals.sbt 42 | metals/project/ 43 | 44 | .vscode/ 45 | 46 | local.* 47 | 48 | .DS_Store 49 | 50 | node_modules 51 | 52 | lib/core/metadata.js 53 | lib/core/MetadataBlog.js 54 | 55 | website/translated_docs 56 | website/build/ 57 | website/yarn.lock 58 | website/node_modules 59 | website/i18n/* 60 | !website/i18n/en.json 61 | website/.docusaurus 62 | website/.cache-loader 63 | 64 | project/metals.sbt 65 | coursier -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version = 3.8.2 2 | 3 | align.preset = more 4 | runner.dialect = scala212 5 | maxColumn = 150 6 | docstrings.style = Asterisk 7 | 8 | //style = defaultWithAlign 9 | //align = true 10 | //danglingParentheses = false 11 | //indentOperator = spray 12 | //maxColumn = 300 13 | //rewrite.rules = [RedundantParens, AvoidInfix, SortImports, PreferCurlyFors] 14 | //unindentTopLevelOperators = true 15 | //verticalMultilineAtDefinitionSite = true 16 | //assumeStandardLibraryStripMargin = true 17 | //includeCurlyBraceInSelectChains = false 18 | // 19 | //align.openParenCallSite = false 20 | //align.openParenDefnSite = false 21 | // 22 | 
//binPack.parentConstructors = true 23 | //binPack.literalArgumentLists = true 24 | //binPack.unsafeCallSite = false 25 | //binPack.unsafeDefnSite = false 26 | // 27 | //continuationIndent.callSite = 2 28 | //continuationIndent.defnSite = 2 29 | // 30 | //newlines.alwaysBeforeTopLevelStatements = true 31 | //newlines.penalizeSingleSelectMultiArgList = false 32 | // 33 | //optIn.breakChainOnFirstMethodDot = true 34 | // 35 | //runner.optimizer.forceConfigStyleOnOffset = 1 36 | //runner.optimizer.forceConfigStyleMinArgCount = 2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 MrPowers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spark Fast Tests 2 | 3 | [![CI](https://github.com/MrPowers/spark-fast-tests/actions/workflows/ci.yml/badge.svg)](https://github.com/MrPowers/spark-fast-tests/actions/workflows/ci.yml) 4 | 5 | A fast Apache Spark testing helper library with beautifully formatted error messages! Works 6 | with [scalatest](https://github.com/scalatest/scalatest), [uTest](https://github.com/lihaoyi/utest), 7 | and [munit](https://github.com/scalameta/munit). 8 | 9 | Use [chispa](https://github.com/MrPowers/chispa) for PySpark applications. 10 | 11 | Read [Testing Spark Applications](https://leanpub.com/testing-spark) for a full explanation on the best way to test 12 | Spark code! Good test suites yield higher quality codebases that are easy to refactor. 13 | 14 | ## Table of Contents 15 | - [Install](#install) 16 | - [Examples](#simple-examples) 17 | - [Why is this library fast?](#why-is-this-library-fast) 18 | - [Usage](#usage) 19 | - [Local Testing SparkSession](#local-sparksession-for-test) 20 | - [DataFrameComparer](#datasetcomparer--dataframecomparer) 21 | - [Unordered DataFrames comparison](#unordered-dataframe-equality-comparisons) 22 | - [Approximate DataFrames comparison](#approximate-dataframe-equality) 23 | - [Ignore Nullable DataFrames comparison](#equality-comparisons-ignoring-the-nullable-flag) 24 | - [ColumnComparer](#column-equality) 25 | - [SchemaComparer](#schema-equality) 26 | - [Testing tips](#testing-tips) 27 | 28 | 29 | ## Install 30 | 31 | Fetch the JAR file from Maven. 
32 | 
33 | ```scala
34 | // for Spark 3
35 | libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "1.3.0" % "test"
36 | ```
37 | 
38 | **Important: Future versions of spark-fast-tests will no longer support Spark 2.x. We recommend upgrading to Spark 3.x to
39 | ensure compatibility with upcoming releases.**
40 | 
41 | Here are links to the releases for different Scala versions:
42 | 
43 | * [Scala 2.11 JAR files](https://repo1.maven.org/maven2/com/github/mrpowers/spark-fast-tests_2.11/)
44 | * [Scala 2.12 JAR files](https://repo1.maven.org/maven2/com/github/mrpowers/spark-fast-tests_2.12/)
45 | * [Scala 2.13 JAR files](https://repo1.maven.org/maven2/com/github/mrpowers/spark-fast-tests_2.13/)
46 | * [Legacy JAR files in Maven](https://mvnrepository.com/artifact/MrPowers/spark-fast-tests?repo=spark-packages).
47 | 
48 | You should use Scala 2.11 with Spark 2 and Scala 2.12 / 2.13 with Spark 3.
49 | 
50 | ## Simple examples
51 | 
52 | The `assertSmallDataFrameEquality` method can be used to compare two DataFrames.
53 | 
54 | ```scala
55 | val sourceDF = Seq(
56 |   (1),
57 |   (5)
58 | ).toDF("number")
59 | 
60 | val expectedDF = Seq(
61 |   (1),
62 |   (3)
63 | ).toDF("number")
64 | 
65 | assertSmallDataFrameEquality(sourceDF, expectedDF)
66 | ```
67 | 
68 | 

69 | assertSmallDataFrameEquality_DatasetContentMissmatch_message 70 |
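The snippets in this README assume a `SparkSession` named `spark` is in scope with its implicits imported, so that `Seq(...).toDF` and `Seq(...).toDS` compile. A minimal sketch of that setup (the [Local Testing SparkSession](#local-sparksession-for-test) section shows the recommended wrapper trait):

```scala
import org.apache.spark.sql.SparkSession

// Local session used by the example snippets; see the Usage section
// for the SparkSessionTestWrapper trait used in real test suites.
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
```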

71 | 
72 | The `assertSmallDatasetEquality` method can be used to compare two Datasets or DataFrames (`Dataset[Row]`).
73 | Nicely formatted error messages are displayed when the Datasets are not equal. Here is an example of a content mismatch:
74 | 
75 | ```scala
76 | val sourceDS = Seq(
77 |   Person("juan", 5),
78 |   Person("bob", 1),
79 |   Person("li", 49),
80 |   Person("alice", 5)
81 | ).toDS
82 | 
83 | val expectedDS = Seq(
84 |   Person("juan", 6),
85 |   Person("frank", 10),
86 |   Person("li", 49),
87 |   Person("lucy", 5)
88 | ).toDS
89 | ```
90 | 
91 | 
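The `Person` type above isn't defined in the snippet; any case class with matching fields works. A hypothetical definition the example assumes:

```scala
// Hypothetical case class backing the Dataset example above.
case class Person(name: String, age: Int)
```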

92 | assertSmallDatasetEquality_error_message 93 |

94 | 
95 | The colors in the error message make it easy to identify the rows that aren't equal. These methods also support
96 | comparing DataFrames with different schemas.
97 | 
98 | ```scala
99 | val sourceDF = spark.createDF(
100 |   List(
101 |     (1, 2.0),
102 |     (5, 3.0)
103 |   ),
104 |   List(
105 |     ("number", IntegerType, true),
106 |     ("float", DoubleType, true)
107 |   )
108 | )
109 | 
110 | val expectedDF = spark.createDF(
111 |   List(
112 |     (1, "word", 1L),
113 |     (5, "word", 2L)
114 |   ),
115 |   List(
116 |     ("number", IntegerType, true),
117 |     ("word", StringType, true),
118 |     ("long", LongType, true)
119 |   )
120 | )
121 | 
122 | assertSmallDataFrameEquality(sourceDF, expectedDF)
123 | ```
124 | 
125 | 

126 | assertSmallDataFrameEquality_DatasetSchemaMisMatch_message 127 |
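*Note: the `createDF` method used above is not part of the vanilla Spark API; it comes from [spark-daria](https://github.com/MrPowers/spark-daria) (this repo carries a `SparkSessionExt` test helper for it). A rough plain-Spark equivalent for the `sourceDF` above (a sketch, not this library's API):*

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Build the same two-column DataFrame with only vanilla Spark APIs.
val sourceDF = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(1, 2.0), Row(5, 3.0))),
  StructType(Seq(
    StructField("number", IntegerType, nullable = true),
    StructField("float", DoubleType, nullable = true)
  ))
)
```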

128 | 
129 | The `DatasetComparer` has `assertSmallDatasetEquality` and `assertLargeDatasetEquality` methods to compare either
130 | Datasets or DataFrames.
131 | 
132 | If you only need to compare DataFrames, you can use `DataFrameComparer` with the associated
133 | `assertSmallDataFrameEquality` and `assertLargeDataFrameEquality` methods. Under the hood, `DataFrameComparer` uses
134 | `assertSmallDatasetEquality` and `assertLargeDatasetEquality`.
135 | 
136 | *Note: comparing Datasets can be tricky since some column names might be assigned by Spark when applying transformations.
137 | Use the `ignoreColumnNames` boolean to skip name verification.*
138 | 
139 | ## Why is this library fast?
140 | 
141 | This library provides three main methods to test your code.
142 | 
143 | Suppose you'd like to test this function:
144 | 
145 | ```scala
146 | def myLowerClean(col: Column): Column = {
147 |   lower(regexp_replace(col, "\\s+", ""))
148 | }
149 | ```
150 | 
151 | Here's how long the tests take to execute:
152 | 
153 | | test method                    | runtime          |
154 | |--------------------------------|------------------|
155 | | `assertLargeDataFrameEquality` | 709 milliseconds |
156 | | `assertSmallDataFrameEquality` | 166 milliseconds |
157 | | `assertColumnEquality`         | 108 milliseconds |
158 | | `evalString`                   | 26 milliseconds  |
159 | 
160 | `evalString` isn't as robust, but it is the fastest. `assertColumnEquality` is robust and saves a lot of time.
161 | 
162 | Other testing libraries don't have methods like `assertSmallDataFrameEquality` or `assertColumnEquality`, so they run
163 | slower.
164 | 
165 | ## Usage
166 | 
167 | ### Local SparkSession for test
168 | The spark-fast-tests project doesn't provide a SparkSession object in your test suite, so you'll need to make one
169 | yourself.
170 | 
171 | ```scala
172 | import org.apache.spark.sql.SparkSession
173 | 
174 | trait SparkSessionTestWrapper {
175 | 
176 |   lazy val spark: SparkSession = {
177 |     SparkSession
178 |       .builder()
179 |       .master("local")
180 |       .appName("spark session")
181 |       .config("spark.sql.shuffle.partitions", "1")
182 |       .getOrCreate()
183 |   }
184 | 
185 | }
186 | ```
187 | 
188 | It's best to set the number of shuffle partitions to a small number like one or four in your test suite. This configuration
189 | can make your tests run up to 70% faster. You can remove this configuration option or adjust it if you're working with
190 | big DataFrames in your test suite.
191 | 
192 | Make sure to only use the `SparkSessionTestWrapper` trait in your test suite. You don't want to use test-specific
193 | configuration (like one shuffle partition) when running production code.
194 | 
195 | ### DatasetComparer / DataFrameComparer
196 | The `DatasetComparer` trait defines the `assertSmallDatasetEquality` method. Extend your spec file with the
197 | `SparkSessionTestWrapper` trait to create DataFrames and the `DatasetComparer` trait to make DataFrame comparisons.
198 | 
199 | ```scala
200 | import org.apache.spark.sql.types._
201 | import org.apache.spark.sql.Row
202 | import org.apache.spark.sql.functions._
203 | import com.github.mrpowers.spark.fast.tests.DatasetComparer
204 | import org.scalatest.funspec.AnyFunSpec
205 | class DatasetSpec extends AnyFunSpec with SparkSessionTestWrapper with DatasetComparer {
206 | 
207 |   import spark.implicits._
208 | 
209 |   it("aliases a DataFrame") {
210 | 
211 |     val sourceDF = Seq(
212 |       ("jose"),
213 |       ("li"),
214 |       ("luisa")
215 |     ).toDF("name")
216 | 
217 |     val actualDF = sourceDF.select(col("name").alias("student"))
218 | 
219 |     val expectedDF = Seq(
220 |       ("jose"),
221 |       ("li"),
222 |       ("luisa")
223 |     ).toDF("student")
224 | 
225 |     assertSmallDatasetEquality(actualDF, expectedDF)
226 | 
227 |   }
228 | }
229 | ```
230 | 
231 | To compare large DataFrames that are partitioned across different nodes in a cluster, use the
232 | `assertLargeDatasetEquality` method.
233 | 
234 | ```scala
235 | assertLargeDatasetEquality(actualDF, expectedDF)
236 | ```
237 | 
238 | `assertSmallDatasetEquality` is faster for test suites that run on your local machine. `assertLargeDatasetEquality`
239 | should only be used for DataFrames that are split across nodes in a cluster.
240 | 
241 | #### Unordered DataFrame equality comparisons
242 | 
243 | Suppose you have the following `actualDF`:
244 | 
245 | ```
246 | +------+
247 | |number|
248 | +------+
249 | |     1|
250 | |     5|
251 | +------+
252 | ```
253 | 
254 | And suppose you have the following `expectedDF`:
255 | 
256 | ```
257 | +------+
258 | |number|
259 | +------+
260 | |     5|
261 | |     1|
262 | +------+
263 | ```
264 | 
265 | The DataFrames have the same columns and rows, but the order is different.
266 | 
267 | `assertSmallDataFrameEquality(actualDF, expectedDF)` will throw a `DatasetContentMismatch` error.
268 | 
269 | We can set the `orderedComparison` boolean flag to `false` and spark-fast-tests will sort the DataFrames before
270 | performing the comparison.
271 | 
272 | `assertSmallDataFrameEquality(actualDF, expectedDF, orderedComparison = false)` will not throw an error.
273 | 
274 | #### Equality comparisons ignoring the nullable flag
275 | 
276 | You might also want to make equality comparisons that ignore the nullable flags for the DataFrame columns.
277 | 
278 | Here is how to use the `ignoreNullable` flag to compare DataFrames without considering the nullable property of each
279 | column.
280 | 
281 | ```scala
282 | val sourceDF = spark.createDF(
283 |   List(
284 |     (1),
285 |     (5)
286 |   ), List(
287 |     ("number", IntegerType, false)
288 |   )
289 | )
290 | 
291 | val expectedDF = spark.createDF(
292 |   List(
293 |     (1),
294 |     (5)
295 |   ), List(
296 |     ("number", IntegerType, true)
297 |   )
298 | )
299 | 
300 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreNullable = true)
301 | ```
302 | 
303 | #### Approximate DataFrame Equality
304 | 
305 | The `assertApproximateDataFrameEquality` function is useful for DataFrames that contain `DoubleType` columns. A
306 | precision threshold must be supplied when calling it.
307 | 308 | ```scala 309 | val sourceDF = spark.createDF( 310 | List( 311 | (1.2), 312 | (5.1), 313 | (null) 314 | ), List( 315 | ("number", DoubleType, true) 316 | ) 317 | ) 318 | 319 | val expectedDF = spark.createDF( 320 | List( 321 | (1.2), 322 | (5.1), 323 | (null) 324 | ), List( 325 | ("number", DoubleType, true) 326 | ) 327 | ) 328 | 329 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 330 | ``` 331 | 332 | ### Column Equality 333 | 334 | The `assertColumnEquality` method can be used to assess the equality of two columns in a DataFrame. 335 | 336 | Suppose you have the following DataFrame with two columns that are not equal. 337 | 338 | ``` 339 | +-------+-------------+ 340 | | name|expected_name| 341 | +-------+-------------+ 342 | | phil| phil| 343 | | rashid| rashid| 344 | |matthew| mateo| 345 | | sami| sami| 346 | | li| feng| 347 | | null| null| 348 | +-------+-------------+ 349 | ``` 350 | 351 | The following code will throw a `ColumnMismatch` error message: 352 | 353 | ```scala 354 | assertColumnEquality(df, "name", "expected_name") 355 | ``` 356 | 357 |

358 | Description 359 |
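For reference, the DataFrame above can be constructed with `Option` values so the null row is typed correctly (a minimal sketch, assuming the `spark` session from the test wrapper):

```scala
import spark.implicits._

// Option[String] gives Spark nullable string columns, including the null row.
val df = Seq(
  (Some("phil"), Some("phil")),
  (Some("rashid"), Some("rashid")),
  (Some("matthew"), Some("mateo")),
  (Some("sami"), Some("sami")),
  (Some("li"), Some("feng")),
  (None, None)
).toDF("name", "expected_name")
```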

360 | 
361 | Mix the `ColumnComparer` trait into your test class to access the `assertColumnEquality` method:
362 | 
363 | ```scala
364 | import com.github.mrpowers.spark.fast.tests.ColumnComparer
365 | 
366 | object MySpecialClassTest
367 |     extends TestSuite
368 |     with ColumnComparer
369 |     with SparkSessionTestWrapper {
370 | 
371 |   // your tests
372 | }
373 | ```
374 | 
375 | ### Schema Equality
376 | 
377 | `SchemaComparer` provides an `assertSchemaEqual` API, which is useful for comparing DataFrame schemas directly.
378 | 
379 | Consider the following two schemas:
380 | 
381 | ```scala
382 | val s1 = StructType(
383 |   Seq(
384 |     StructField("array", ArrayType(StringType, containsNull = true), true),
385 |     StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
386 |     StructField("something", StringType, true),
387 |     StructField(
388 |       "struct",
389 |       StructType(
390 |         StructType(
391 |           Seq(
392 |             StructField("mood", ArrayType(StringType, containsNull = false), true),
393 |             StructField("something", StringType, false),
394 |             StructField(
395 |               "something2",
396 |               StructType(
397 |                 Seq(
398 |                   StructField("mood2", ArrayType(DoubleType, containsNull = false), true),
399 |                   StructField("something2", StringType, false)
400 |                 )
401 |               ),
402 |               false
403 |             )
404 |           )
405 |         )
406 |       ),
407 |       true
408 |     )
409 |   )
410 | )
411 | val s2 = StructType(
412 |   Seq(
413 |     StructField("array", ArrayType(StringType, containsNull = true), true),
414 |     StructField("something", StringType, true),
415 |     StructField("map", MapType(StringType, StringType, valueContainsNull = false), true),
416 |     StructField(
417 |       "struct",
418 |       StructType(
419 |         StructType(
420 |           Seq(
421 |             StructField("something", StringType, false),
422 |             StructField("mood", ArrayType(StringType, containsNull = false), true),
423 |             StructField(
424 |               "something3",
425 |               StructType(
426 |                 Seq(
427 |                   StructField("mood3", ArrayType(StringType, containsNull = false), true)
428 |                 )
429 |               ),
430 |               false
431 |             )
432 |           )
433 |         )
434 |       ),
435 |       true
436 |     ),
437 |     StructField("norma2", StringType, false)
438 |   )
439 | )
440 | 
441 | ```
442 | 
443 | `assertSchemaEqual` supports two output formats, `SchemaDiffOutputFormat.Tree` and `SchemaDiffOutputFormat.Table`.
444 | The tree output format is useful when the schema is large and contains multi-level
445 | nested fields.
446 | 
447 | ```scala
448 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
449 | ```
450 | 
451 | 

452 | assert_column_equality_error_message 453 |
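The same check can be run with the other parameters; for example, matching fields by name instead of position and rendering the diff as a table (a sketch reusing `s1` and `s2` from above):

```scala
// ignoreColumnOrder = true matches s1/s2 fields by name rather than position;
// any remaining differences are rendered with the table output format.
SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = true, outputFormat = SchemaDiffOutputFormat.Table)
```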

454 | 
455 | By default, `SchemaDiffOutputFormat.Table` is used internally by all DataFrame/Dataset comparison APIs.
456 | 
457 | ## Testing Tips
458 | 
459 | * Use column functions instead of UDFs as described
460 |   in [this blog post](https://medium.com/@mrpowers/spark-user-defined-functions-udfs-6c849e39443b)
461 | * Try to organize your code
462 |   as [custom transformations](https://medium.com/@mrpowers/chaining-custom-dataframe-transformations-in-spark-a39e315f903c)
463 |   so it's easy to test the logic elegantly
464 | * Don't write tests that read from files or write files. Dependency injection is a great way to avoid file I/O in your
465 |   test suite.
466 | 
467 | ## Alternatives
468 | 
469 | The [spark-testing-base](https://github.com/holdenk/spark-testing-base) project has more features (e.g. streaming
470 | support) and is compiled to support a variety of Scala and Spark versions.
471 | 
472 | You might want to use spark-fast-tests instead of spark-testing-base in these cases:
473 | 
474 | * You want to use uTest or a testing framework other than scalatest
475 | * You want to run tests in parallel (you need to set `parallelExecution in Test := false` with spark-testing-base)
476 | * You don't want to include hive as a project dependency
477 | * You don't want to restart the SparkSession after each test file executes so the suite runs faster
478 | 
479 | ## Publishing
480 | 
481 | GPG & Sonatype need to be set up properly before running these commands. See
482 | the [spark-daria](https://github.com/MrPowers/spark-daria) README for more information.
483 | 
484 | It's a good idea to always run `clean` before any publishing command, and to run it again between publishing
485 | commands for different Scala / Spark versions.
486 | 
487 | There is a two-step process for publishing.
488 | 
489 | Generate Scala 2.11 JAR files:
490 | 
491 | * Run `sbt -Dspark.version=2.4.8`
492 | * Run `> ; + publishSigned; sonatypeBundleRelease` to create the JAR files and release them to Maven.
493 | 
494 | Generate Scala 2.12 & Scala 2.13 JAR files:
495 | 
496 | * Run `sbt`
497 | * Run `> ; + publishSigned; sonatypeBundleRelease`
498 | 
499 | The `publishSigned` and `sonatypeBundleRelease` commands are made available by
500 | the [sbt-sonatype](https://github.com/xerial/sbt-sonatype) plugin.
501 | 
502 | When the release command is run, you'll be prompted to enter your GPG passphrase.
503 | 
504 | The Sonatype credentials should be stored in the `~/.sbt/sonatype_credentials` file in this format:
505 | 
506 | ```
507 | realm=Sonatype Nexus Repository Manager
508 | host=oss.sonatype.org
509 | user=$USERNAME
510 | password=$PASSWORD
511 | ```
512 | 
513 | ## Additional Goals
514 | 
515 | * Use memory efficiently so Spark test runs don't crash
516 | * Provide readable error messages
517 | * Easy to use in conjunction with other test suites
518 | * Give the user control of the SparkSession
519 | 
520 | ## Contributing
521 | 
522 | Open an issue or send a pull request to contribute. Anyone who makes good contributions to the project will be promoted
523 | to project maintainer status.
524 | 
525 | ## uTest settings to display color output
526 | 
527 | Create a `CustomFramework` class with overrides that turn off the default uTest color settings.
528 | 529 | ```scala 530 | package com.github.mrpowers.spark.fast.tests 531 | 532 | class CustomFramework extends utest.runner.Framework { 533 | override def formatWrapWidth: Int = 300 534 | 535 | // turn off the default exception message color, so spark-fast-tests 536 | // can send messages with custom colors 537 | override def exceptionMsgColor = toggledColor(utest.ufansi.Attrs.Empty) 538 | 539 | override def exceptionPrefixColor = toggledColor(utest.ufansi.Attrs.Empty) 540 | 541 | override def exceptionMethodColor = toggledColor(utest.ufansi.Attrs.Empty) 542 | 543 | override def exceptionPunctuationColor = toggledColor(utest.ufansi.Attrs.Empty) 544 | 545 | override def exceptionLineNumberColor = toggledColor(utest.ufansi.Attrs.Empty) 546 | } 547 | ``` 548 | 549 | Update the `build.sbt` file to use the `CustomFramework` class: 550 | 551 | ```scala 552 | testFrameworks += new TestFramework("com.github.mrpowers.spark.fast.tests.CustomFramework") 553 | ``` 554 | 555 | 556 | -------------------------------------------------------------------------------- /benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/ColumnComparerBenchmark.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.SparkSession 4 | import org.openjdk.jmh.annotations._ 5 | import org.openjdk.jmh.infra.Blackhole 6 | 7 | import java.util.concurrent.TimeUnit 8 | import scala.util.Try 9 | 10 | private class ColumnComparerBenchmark extends ColumnComparer { 11 | @Benchmark 12 | @BenchmarkMode(Array(Mode.SingleShotTime)) 13 | @Fork(value = 2) 14 | @Warmup(iterations = 10) 15 | @Measurement(iterations = 10) 16 | @OutputTimeUnit(TimeUnit.NANOSECONDS) 17 | def assertColumnEqualityBenchmarks(blackHole: Blackhole): Boolean = { 18 | val spark = SparkSession 19 | .builder() 20 | .master("local") 21 | .appName("spark session") 22 | .config("spark.sql.shuffle.partitions", "1") 23 | .getOrCreate() 24 | spark.sparkContext.setLogLevel("ERROR") 25 | 26 | import spark.implicits._ 27 | val ds1 = (Seq.fill(100)(("1", "2")) ++ Seq.fill(100)(("3", "4"))).toDF("col_B", "col_A") 28 | 29 | val result = Try(assertColumnEquality(ds1, "col_B", "col_A")) 30 | 31 | blackHole.consume(result) 32 | result.isSuccess 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerBenchmark.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.SparkSession 4 | import org.openjdk.jmh.annotations._ 5 | import org.openjdk.jmh.infra.Blackhole 6 | 7 | import java.util.concurrent.TimeUnit 8 | import scala.util.Try 9 | 10 | private class DataFrameComparerBenchmark extends DataFrameComparer { 11 | @Benchmark 12 | @BenchmarkMode(Array(Mode.SingleShotTime)) 13 | @Fork(value = 2) 14 | @Warmup(iterations = 10) 15 | @Measurement(iterations = 10) 16 | @OutputTimeUnit(TimeUnit.NANOSECONDS) 17 | def assertApproximateDataFrameEqualityWithPrecision(blackHole: Blackhole): Boolean = { 18 | val spark = SparkSession 19 | .builder() 20 | .master("local") 21 | .appName("spark session") 22 | .config("spark.sql.shuffle.partitions", "1") 23 | .getOrCreate() 24 | spark.sparkContext.setLogLevel("ERROR") 25 | 26 | import spark.implicits._ 27 | val sameData = Seq.fill(50)("1", "10/01/2019", 26.762499999999996) 28 | val ds1Data = 
sameData ++ Seq.fill(50)("1", "11/01/2019", 26.76249999999991) 29 | val ds2Data = sameData ++ Seq.fill(50)("3", "12/01/2019", 26.76249999999991) 30 | 31 | val ds1 = ds1Data.toDF("col_B", "col_C", "col_A") 32 | val ds2 = ds2Data.toDF("col_B", "col_C", "col_A") 33 | 34 | val result = Try(assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false)) 35 | 36 | blackHole.consume(result) 37 | result.isSuccess 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | val versionRegex = """^(.*)\.(.*)\.(.*)$""".r 2 | val scala2_13 = "2.13.14" 3 | val scala2_12 = "2.12.20" 4 | val sparkVersion = System.getProperty("spark.version", "3.5.3") 5 | val noPublish = Seq( 6 | (publish / skip) := true, 7 | publishArtifact := false 8 | ) 9 | 10 | inThisBuild( 11 | List( 12 | organization := "com.github.mrpowers", 13 | homepage := Some(url("https://github.com/mrpowers-io/spark-fast-tests")), 14 | licenses := Seq("MIT" -> url("http://opensource.org/licenses/MIT")), 15 | developers ++= List( 16 | Developer("MrPowers", "Matthew Powers", "@MrPowers", url("https://github.com/MrPowers")) 17 | ), 18 | Compile / scalafmtOnCompile := true, 19 | Test / fork := true, 20 | crossScalaVersions := { 21 | sparkVersion match { 22 | case versionRegex("3", m, _) if m.toInt >= 2 => Seq(scala2_12, scala2_13) 23 | case versionRegex("3", _, _) => Seq(scala2_12) 24 | } 25 | }, 26 | scalaVersion := crossScalaVersions.value.head 27 | ) 28 | ) 29 | 30 | enablePlugins(GitVersioning) 31 | 32 | lazy val commonSettings = Seq( 33 | javaOptions ++= { 34 | Seq("-Xms512M", "-Xmx2048M", "-Duser.timezone=GMT") ++ (if (System.getProperty("java.version").startsWith("1.8.0")) 35 | Seq("-XX:+CMSClassUnloadingEnabled") 36 | else Seq.empty) 37 | }, 38 | libraryDependencies ++= Seq( 39 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", 40 | "org.scalatest" %% "scalatest" % "3.2.18" % "test" 41 | ) 42 | ) 43 | 44 | lazy val core = (project in file("core")) 45 | .settings( 46 | commonSettings, 47 | moduleName := "spark-fast-tests", 48 | name := moduleName.value, 49 | Compile / packageSrc / publishArtifact := true, 50 | Compile / packageDoc / publishArtifact := true 51 | ) 52 | 53 | lazy val benchmarks = (project in file("benchmarks")) 54 | .dependsOn(core) 55 | .settings(commonSettings) 56 | .settings( 57 | libraryDependencies ++= Seq( 58 | "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" // required for jmh IDEA plugin. Make sure this version matches sbt-jmh version! 
59 | ), 60 | name := "benchmarks" 61 | ) 62 | .settings(noPublish) 63 | .enablePlugins(JmhPlugin) 64 | 65 | lazy val docs = (project in file("docs")) 66 | .dependsOn(core) 67 | .enablePlugins(LaikaPlugin) 68 | .settings( 69 | name := "docs", 70 | laikaTheme := { 71 | import laika.ast.Path.Root 72 | import laika.helium.Helium 73 | import laika.helium.config.* 74 | 75 | Helium.defaults.site 76 | .landingPage( 77 | title = Some("Spark Fast Tests"), 78 | subtitle = Some("Unit testing your Apache Spark application"), 79 | latestReleases = Seq( 80 | ReleaseInfo("Latest Stable Release", "1.0.0") 81 | ), 82 | license = Some("MIT"), 83 | titleLinks = Seq( 84 | VersionMenu.create(unversionedLabel = "Getting Started"), 85 | LinkGroup.create( 86 | IconLink.external("https://github.com/mrpowers-io/spark-fast-tests", HeliumIcon.github) 87 | ) 88 | ), 89 | linkPanel = Some( 90 | LinkPanel( 91 | "Documentation", 92 | TextLink.internal(Root / "about" / "README.md", "Spark Fast Tests") 93 | ) 94 | ), 95 | projectLinks = Seq( 96 | LinkGroup.create( 97 | TextLink.internal(Root / "api" / "com" / "github" / "mrpowers" / "spark" / "fast" / "tests" / "index.html", "API (Scaladoc)") 98 | ) 99 | ), 100 | teasers = Seq( 101 | Teaser("Fast", "Handle small dataframes effectively and provide column assertions"), 102 | Teaser("Flexible", "Works fine with scalatest, uTest, munit") 103 | ) 104 | ) 105 | .build 106 | }, 107 | laikaIncludeAPI := true, 108 | laikaExtensions ++= { 109 | import laika.config.SyntaxHighlighting 110 | import laika.format.Markdown 111 | Seq(Markdown.GitHubFlavor, SyntaxHighlighting) 112 | }, 113 | Laika / sourceDirectories := Seq((ThisBuild / baseDirectory).value / "docs") 114 | ) 115 | .settings(noPublish) 116 | 117 | lazy val root = (project in file(".")) 118 | .settings( 119 | name := "spark-fast-tests-root", 120 | commonSettings, 121 | noPublish 122 | ) 123 | .aggregate(core, benchmarks, docs) 124 | 125 | scmInfo := Some(ScmInfo(url("https://github.com/mrpowers-io/spark-fast-tests"), "git@github.com:MrPowers/spark-fast-tests.git")) 126 | 127 | updateOptions := updateOptions.value.withLatestSnapshots(false) 128 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/ArrayUtil.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import java.sql.Date 4 | import com.github.mrpowers.spark.fast.tests.ufansi.EscapeAttr 5 | import java.time.format.DateTimeFormatter 6 | import org.apache.commons.lang3.StringUtils 7 | 8 | object ArrayUtil { 9 | 10 | def weirdTypesToStrings(arr: Array[(Any, Any)], truncate: Int = 20): Array[List[String]] = { 11 | arr.map { row => 12 | row.productIterator.toList.map { cell => 13 | val str = cell match { 14 | case null => "null" 15 | case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") 16 | case array: Array[_] => array.mkString("[", ", ", "]") 17 | case seq: Seq[_] => seq.mkString("[", ", ", "]") 18 | case d: Date => d.toLocalDate.format(DateTimeFormatter.ISO_DATE) 19 | case _ => cell.toString 20 | } 21 | if (truncate > 0 && str.length > truncate) { 22 | // do not show ellipses for strings shorter than 4 characters. 23 | if (truncate < 4) 24 | str.substring(0, truncate) 25 | else 26 | str.substring(0, truncate - 3) + "..." 
27 | } else { 28 | str 29 | } 30 | } 31 | } 32 | } 33 | 34 | def showTwoColumnString(arr: Array[(Any, Any)], truncate: Int = 20): String = { 35 | showTwoColumnStringColorCustomizable(arr, truncate = truncate) 36 | } 37 | 38 | def showTwoColumnStringColorCustomizable( 39 | arr: Array[(Any, Any)], 40 | rowEqual: Option[Array[Boolean]] = None, 41 | truncate: Int = 20, 42 | equalColor: EscapeAttr = ufansi.Color.Blue, 43 | unequalColorLeft: EscapeAttr = ufansi.Color.Red, 44 | unequalColorRight: EscapeAttr = ufansi.Color.Green 45 | ): String = { 46 | val sb = new StringBuilder 47 | val numCols = 2 48 | val rows = weirdTypesToStrings(arr, truncate) 49 | 50 | // Initialise the width of each column to a minimum value of '3' 51 | val colWidths = Array.fill(numCols)(3) 52 | 53 | // Compute the width of each column 54 | for (row <- rows) { 55 | for ((cell, i) <- row.zipWithIndex) { 56 | colWidths(i) = math.max(colWidths(i), cell.length) 57 | } 58 | } 59 | 60 | // Create SeparateLine 61 | val sep: String = 62 | colWidths 63 | .map("-" * _) 64 | .addString(sb, "+", "+", "+\n") 65 | .toString() 66 | 67 | // column names 68 | rows.head.zipWithIndex 69 | .map { case (cell, i) => 70 | if (truncate > 0) { 71 | StringUtils.leftPad(cell, colWidths(i)) 72 | } else { 73 | StringUtils.rightPad(cell, colWidths(i)) 74 | } 75 | } 76 | .addString(sb, "|", "|", "|\n") 77 | 78 | sb.append(sep) 79 | 80 | // data 81 | rows.tail.zipWithIndex.map { case (row, j) => 82 | row.zipWithIndex 83 | .map { case (cell, i) => 84 | val r = if (truncate > 0) { 85 | StringUtils.leftPad(cell, colWidths(i)) 86 | } else { 87 | StringUtils.rightPad(cell, colWidths(i)) 88 | } 89 | if (rowEqual.fold(row.head == row(1))(_(j))) { 90 | equalColor(r) 91 | } else if (i == 0) { 92 | unequalColorLeft(r) 93 | } else { 94 | unequalColorRight(r) 95 | } 96 | } 97 | .addString(sb, "|", "|", "|\n") 98 | } 99 | 100 | sb.append(sep) 101 | 102 | sb.toString() 103 | } 104 | 105 | // The following code is taken from: https://github.com/scala/scala/blob/86e75db7f36bcafdd75302f2c2cca0c68413214d/src/partest/scala/tools/partest/Util.scala 106 | def prettyArray(a: Array[_]): collection.IndexedSeq[Any] = new collection.AbstractSeq[Any] with collection.IndexedSeq[Any] { 107 | def length: Int = a.length 108 | 109 | def apply(idx: Int): Any = a(idx) match { 110 | case x: AnyRef if x.getClass.isArray => prettyArray(x.asInstanceOf[Array[_]]) 111 | case x => x 112 | } 113 | 114 | // Ignore deprecation warning in 2.13 - this is the correct def for 2.12 compatibility 115 | override def stringPrefix: String = "Array" 116 | } 117 | 118 | // The following code is taken from: https://github.com/scala/scala/blob/86e75db7f36bcafdd75302f2c2cca0c68413214d/src/partest/scala/tools/partest/Util.scala 119 | implicit class ArrayDeep(val a: Array[_]) extends AnyVal { 120 | def deep: collection.IndexedSeq[Any] = prettyArray(a) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/ColumnComparer.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.DataFrame 4 | import org.apache.spark.sql.Row 5 | import org.apache.spark.sql.functions._ 6 | import org.apache.spark.sql.types.StructField 7 | 8 | case class ColumnMismatch(smth: String) extends Exception(smth) 9 | 10 | trait ColumnComparer { 11 | 12 | def assertColumnEquality(df: DataFrame, colName1: String, colName2: 
String): Unit = { 13 | val elements = df 14 | .select(colName1, colName2) 15 | .collect() 16 | val colName1Elements = elements.map(_(0)) 17 | val colName2Elements = elements.map(_(1)) 18 | if (!colName1Elements.sameElements(colName2Elements)) { 19 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs" 20 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnString( 21 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements) 22 | ) 23 | throw ColumnMismatch(mismatchMessage) 24 | } 25 | } 26 | 27 | // ace stands for 'assertColumnEquality' 28 | def ace(df: DataFrame, colName1: String, colName2: String): Unit = { 29 | assertColumnEquality(df, colName1, colName2) 30 | } 31 | 32 | // possibly rename this to assertDeepColumnEquality... 33 | // would deep equality comparison help when comparing other types of columns? 34 | def assertBinaryTypeColumnEquality(df: DataFrame, colName1: String, colName2: String): Unit = { 35 | import ArrayUtil._ 36 | 37 | val elements = df 38 | .select(colName1, colName2) 39 | .collect() 40 | val colName1Elements = elements.map(_(0)) 41 | val colName2Elements = elements.map(_(1)) 42 | if (!colName1Elements.deep.sameElements(colName2Elements.deep)) { 43 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs" 44 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnString( 45 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements) 46 | ) 47 | throw ColumnMismatch(mismatchMessage) 48 | } 49 | } 50 | 51 | // private def approximatelyEqualDouble(x: Double, y: Double, precision: Double): Boolean = { 52 | // (x - y).abs < precision 53 | // } 54 | // 55 | // private def areDoubleArraysEqual(x: Array[Double], y: Array[Double], precision: Double): Boolean = { 56 | // x.zip(y).forall { t => 57 | // approximatelyEqualDouble(t._1, t._2, precision) 58 | // } 59 | // } 60 | 61 | def assertDoubleTypeColumnEquality(df: DataFrame, colName1: String, colName2: String, precision: Double = 0.01): Unit = { 62 | val elements: Array[Row] = df 63 | .select(colName1, colName2) 64 | .collect() 65 | val rowsEqual = scala.collection.mutable.ArrayBuffer[Boolean]() 66 | var allRowsAreEqual = true 67 | for (i <- 0 until elements.length) { 68 | var t = elements(i) 69 | if (t(0) == null && t(1) == null) { 70 | rowsEqual += true 71 | } else if (t(0) != null && t(1) == null) { 72 | rowsEqual += false 73 | allRowsAreEqual = false 74 | } else if (t(0) == null && t(1) != null) { 75 | rowsEqual += false 76 | allRowsAreEqual = false 77 | } else { 78 | if (!((t(0).toString.toDouble - t(1).toString.toDouble).abs < precision)) { 79 | rowsEqual += false 80 | allRowsAreEqual = false 81 | } else { 82 | rowsEqual += true 83 | } 84 | } 85 | } 86 | if (!allRowsAreEqual) { 87 | val colName1Elements = elements.map(_(0)) 88 | val colName2Elements = elements.map(_(1)) 89 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs" 90 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnStringColorCustomizable( 91 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements), 92 | Some(rowsEqual.toArray) 93 | ) 94 | throw ColumnMismatch(mismatchMessage) 95 | } 96 | } 97 | 98 | // private def approximatelyEqualFloat(x: Float, y: Float, precision: Float): Boolean = { 99 | // (x - y).abs < precision 100 | // } 101 | // 102 | // private def areFloatArraysEqual(x: Array[Float], y: Array[Float], precision: Float): Boolean = { 103 | // x.zip(y).forall { t => 104 | // 
approximatelyEqualFloat(t._1, t._2, precision) 105 | // } 106 | // } 107 | 108 | def assertFloatTypeColumnEquality(df: DataFrame, colName1: String, colName2: String, precision: Float): Unit = { 109 | val elements: Array[Row] = df 110 | .select(colName1, colName2) 111 | .collect() 112 | val rowsEqual = scala.collection.mutable.ArrayBuffer[Boolean]() 113 | var allRowsAreEqual = true 114 | for (i <- 0 until elements.length) { 115 | var t = elements(i) 116 | if (t(0) == null && t(1) == null) { 117 | rowsEqual += true 118 | } else if (t(0) != null && t(1) == null) { 119 | rowsEqual += false 120 | allRowsAreEqual = false 121 | } else if (t(0) == null && t(1) != null) { 122 | rowsEqual += false 123 | allRowsAreEqual = false 124 | } else { 125 | if (!((t(0).toString.toFloat - t(1).toString.toFloat).abs < precision)) { 126 | rowsEqual += false 127 | allRowsAreEqual = false 128 | } else { 129 | rowsEqual += true 130 | } 131 | } 132 | } 133 | if (!allRowsAreEqual) { 134 | val colName1Elements = elements.map(_(0)) 135 | val colName2Elements = elements.map(_(1)) 136 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs" 137 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnStringColorCustomizable( 138 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements), 139 | Some(rowsEqual.toArray) 140 | ) 141 | throw ColumnMismatch(mismatchMessage) 142 | } 143 | } 144 | 145 | def assertColEquality(df: DataFrame, colName1: String, colName2: String): Unit = { 146 | val cn = s"are_${colName1}_and_${colName2}_equal" 147 | 148 | val schema = df.schema 149 | val sf1: StructField = schema.find((sf: StructField) => sf.name == colName1).get 150 | val sf2 = schema.find((sf: StructField) => sf.name == colName2).get 151 | 152 | if (sf1.dataType != sf2.dataType) { 153 | throw ColumnMismatch( 154 | s"The column dataTypes are different. The `${colName1}` column has a `${sf1.dataType}` dataType and the `${colName2}` column has a `${sf2.dataType}` dataType." 
155 | ) 156 | } 157 | 158 | val r = df 159 | .withColumn( 160 | cn, 161 | col(colName1) <=> col(colName2) 162 | ) 163 | .select(cn) 164 | .collect() 165 | 166 | if (r.contains(Row(false))) { 167 | val elements = df 168 | .select(colName1, colName2) 169 | .collect() 170 | val colName1Elements = elements.map(_(0)) 171 | val colName2Elements = elements.map(_(1)) 172 | // Diffs\n is a hack, but a newline isn't added in ScalaTest unless we add "Diffs" 173 | val mismatchMessage = "Diffs\n" + ArrayUtil.showTwoColumnString( 174 | Array((colName1, colName2)) ++ colName1Elements.zip(colName2Elements) 175 | ) 176 | throw ColumnMismatch(mismatchMessage) 177 | } 178 | } 179 | 180 | } 181 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparer.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.{DataFrame, Row} 4 | trait DataFrameComparer extends DatasetComparer { 5 | 6 | /** 7 | * Raises an error unless `actualDF` and `expectedDF` are equal 8 | */ 9 | def assertSmallDataFrameEquality( 10 | actualDF: DataFrame, 11 | expectedDF: DataFrame, 12 | ignoreNullable: Boolean = false, 13 | ignoreColumnNames: Boolean = false, 14 | orderedComparison: Boolean = true, 15 | ignoreColumnOrder: Boolean = false, 16 | ignoreMetadata: Boolean = true, 17 | truncate: Int = 500 18 | ): Unit = { 19 | assertSmallDatasetEquality( 20 | actualDF, 21 | expectedDF, 22 | ignoreNullable, 23 | ignoreColumnNames, 24 | orderedComparison, 25 | ignoreColumnOrder, 26 | ignoreMetadata, 27 | truncate 28 | ) 29 | } 30 | 31 | /** 32 | * Raises an error unless `actualDF` and `expectedDF` are equal 33 | */ 34 | def assertLargeDataFrameEquality( 35 | actualDF: DataFrame, 36 | expectedDF: DataFrame, 37 | ignoreNullable: Boolean = false, 38 | ignoreColumnNames: Boolean = false, 39 | orderedComparison: Boolean = true, 40 | ignoreColumnOrder: Boolean = false, 41 | ignoreMetadata: Boolean = true 42 | ): Unit = { 43 | assertLargeDatasetEquality( 44 | actualDF, 45 | expectedDF, 46 | ignoreNullable = ignoreNullable, 47 | ignoreColumnNames = ignoreColumnNames, 48 | orderedComparison = orderedComparison, 49 | ignoreColumnOrder = ignoreColumnOrder, 50 | ignoreMetadata = ignoreMetadata 51 | ) 52 | } 53 | 54 | /** 55 | * Raises an error unless `actualDF` and `expectedDF` are equal 56 | */ 57 | def assertApproximateSmallDataFrameEquality( 58 | actualDF: DataFrame, 59 | expectedDF: DataFrame, 60 | precision: Double, 61 | ignoreNullable: Boolean = false, 62 | ignoreColumnNames: Boolean = false, 63 | orderedComparison: Boolean = true, 64 | ignoreColumnOrder: Boolean = false, 65 | ignoreMetadata: Boolean = true 66 | ): Unit = { 67 | assertSmallDatasetEquality[Row]( 68 | actualDF, 69 | expectedDF, 70 | ignoreNullable, 71 | ignoreColumnNames, 72 | orderedComparison, 73 | ignoreColumnOrder, 74 | ignoreMetadata, 75 | equals = RowComparer.areRowsEqual(_, _, precision) 76 | ) 77 | } 78 | 79 | /** 80 | * Raises an error unless `actualDF` and `expectedDF` are equal 81 | */ 82 | def assertApproximateLargeDataFrameEquality( 83 | actualDF: DataFrame, 84 | expectedDF: DataFrame, 85 | precision: Double, 86 | ignoreNullable: Boolean = false, 87 | ignoreColumnNames: Boolean = false, 88 | orderedComparison: Boolean = true, 89 | ignoreColumnOrder: Boolean = false, 90 | ignoreMetadata: Boolean = true 91 | ): Unit = { 92 | assertLargeDatasetEquality[Row]( 93 | 
actualDF, 94 | expectedDF, 95 | equals = RowComparer.areRowsEqual(_, _, precision), 96 | ignoreNullable, 97 | ignoreColumnNames, 98 | orderedComparison, 99 | ignoreColumnOrder, 100 | ignoreMetadata 101 | ) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFramePrettyPrint.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import java.sql.Date 4 | import java.time.format.DateTimeFormatter 5 | import org.apache.commons.lang3.StringUtils 6 | 7 | import org.apache.spark.sql.{DataFrame, Row} 8 | 9 | object DataFramePrettyPrint { 10 | 11 | def showString(df: DataFrame, _numRows: Int, truncate: Int = 20): String = { 12 | val numRows = _numRows.max(0) 13 | val takeResult = df.take(numRows + 1) 14 | val hasMoreData = takeResult.length > numRows 15 | val data = takeResult.take(numRows) 16 | 17 | // For array values, replace Seq and Array with square brackets 18 | // For cells that are beyond `truncate` characters, replace it with the 19 | // first `truncate-3` and "..." 20 | val rows: Seq[Seq[String]] = df.schema.fieldNames.toSeq +: data.map { row => 21 | row.toSeq.map { cell => 22 | val str = cell match { 23 | case null => "null" 24 | case binary: Array[Byte] => 25 | binary 26 | .map("%02X".format(_)) 27 | .mkString( 28 | "[", 29 | " ", 30 | "]" 31 | ) 32 | case array: Array[_] => 33 | array.mkString( 34 | "[", 35 | ", ", 36 | "]" 37 | ) 38 | case seq: Seq[_] => 39 | seq.mkString( 40 | "[", 41 | ", ", 42 | "]" 43 | ) 44 | case d: Date => 45 | d.toLocalDate.format(DateTimeFormatter.ISO_DATE) 46 | case r: Row => 47 | r.schema.fieldNames 48 | .zip(r.toSeq) 49 | .map { case (k, v) => 50 | s"$k -> $v" 51 | } 52 | .mkString( 53 | "{", 54 | ", ", 55 | "}" 56 | ) 57 | case _ => cell.toString 58 | } 59 | if (truncate > 0 && str.length > truncate) { 60 | // do not show ellipses for strings shorter than 4 characters. 61 | if (truncate < 4) 62 | str.substring( 63 | 0, 64 | truncate 65 | ) 66 | else 67 | str.substring( 68 | 0, 69 | truncate - 3 70 | ) + "..." 
71 | } else { 72 | str 73 | } 74 | }: Seq[String] 75 | } 76 | 77 | val sb = new StringBuilder 78 | val numCols = df.schema.fieldNames.length 79 | 80 | // Initialise the width of each column to a minimum value of '3' 81 | val colWidths = Array.fill(numCols)(3) 82 | 83 | // Compute the width of each column 84 | for (row <- rows) { 85 | for ((cell, i) <- row.zipWithIndex) { 86 | colWidths(i) = math.max( 87 | colWidths(i), 88 | cell.length 89 | ) 90 | } 91 | } 92 | 93 | // Create SeparateLine 94 | val sep: String = 95 | colWidths 96 | .map("-" * _) 97 | .addString( 98 | sb, 99 | "+", 100 | "+", 101 | "+\n" 102 | ) 103 | .toString() 104 | 105 | // column names 106 | val h: Seq[(String, Int)] = rows.head.zipWithIndex 107 | h.map { case (cell, i) => 108 | if (truncate > 0) { 109 | StringUtils.leftPad( 110 | cell, 111 | colWidths(i) 112 | ) 113 | } else { 114 | StringUtils.rightPad( 115 | cell, 116 | colWidths(i) 117 | ) 118 | } 119 | }.addString( 120 | sb, 121 | "|", 122 | "|", 123 | "|\n" 124 | ) 125 | 126 | sb.append(sep) 127 | 128 | // data 129 | rows.tail.map { 130 | _.zipWithIndex 131 | .map { case (cell, i) => 132 | if (truncate > 0) { 133 | StringUtils.leftPad( 134 | cell.toString, 135 | colWidths(i) 136 | ) 137 | } else { 138 | StringUtils.rightPad( 139 | cell.toString, 140 | colWidths(i) 141 | ) 142 | } 143 | } 144 | .addString( 145 | sb, 146 | "|", 147 | "|", 148 | "|\n" 149 | ) 150 | } 151 | 152 | sb.append(sep) 153 | 154 | // For Data that has more than "numRows" records 155 | if (hasMoreData) { 156 | val rowsString = if (numRows == 1) "row" else "rows" 157 | sb.append(s"only showing top $numRows $rowsString\n") 158 | } 159 | 160 | sb.toString() 161 | } 162 | 163 | } 164 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow 4 | import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.sql.functions._ 7 | import org.apache.spark.sql.{DataFrame, Dataset, Row} 8 | 9 | import scala.reflect.ClassTag 10 | 11 | case class DatasetContentMismatch(smth: String) extends Exception(smth) 12 | case class DatasetCountMismatch(smth: String) extends Exception(smth) 13 | 14 | trait DatasetComparer { 15 | private def countMismatchMessage(actualCount: Long, expectedCount: Long): String = { 16 | s""" 17 | Actual DataFrame Row Count: '$actualCount' 18 | Expected DataFrame Row Count: '$expectedCount' 19 | """ 20 | } 21 | 22 | private def unequalRDDMessage[T](unequalRDD: RDD[(Long, (T, T))], length: Int): String = { 23 | "\nRow Index | Actual Row | Expected Row\n" + unequalRDD 24 | .take(length) 25 | .map { case (idx, (left, right)) => 26 | ufansi.Color.Red(s"$idx | $left | $right") 27 | } 28 | .mkString("\n") 29 | } 30 | 31 | /** 32 | * order ds1 column according to ds2 column order 33 | */ 34 | def orderColumns[T](ds1: Dataset[T], ds2: Dataset[T]): Dataset[T] = { 35 | ds1.select(ds2.columns.map(col).toIndexedSeq: _*).as[T](ds2.encoder) 36 | } 37 | 38 | /** 39 | * Raises an error unless `actualDS` and `expectedDS` are equal 40 | */ 41 | def assertSmallDatasetEquality[T: ClassTag]( 42 | actualDS: Dataset[T], 43 | expectedDS: Dataset[T], 44 | ignoreNullable: Boolean = false, 45 | ignoreColumnNames: 
Boolean = false, 46 | orderedComparison: Boolean = true, 47 | ignoreColumnOrder: Boolean = false, 48 | ignoreMetadata: Boolean = true, 49 | truncate: Int = 500, 50 | equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2) 51 | ): Unit = { 52 | SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) 53 | val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS 54 | assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals) 55 | } 56 | 57 | def assertSmallDatasetContentEquality[T: ClassTag]( 58 | actualDS: Dataset[T], 59 | expectedDS: Dataset[T], 60 | orderedComparison: Boolean, 61 | truncate: Int, 62 | equals: (T, T) => Boolean 63 | ): Unit = { 64 | if (orderedComparison) 65 | assertSmallDatasetContentEquality(actualDS, expectedDS, truncate, equals) 66 | else 67 | assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals) 68 | } 69 | 70 | def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { 71 | val a = actualDS.collect().toSeq 72 | val e = expectedDS.collect().toSeq 73 | if (!a.approximateSameElements(e, equals)) { 74 | val arr = ("Actual Content", "Expected Content") 75 | val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate) 76 | throw DatasetContentMismatch(msg) 77 | } 78 | } 79 | 80 | def defaultSortDataset[T](ds: Dataset[T]): Dataset[T] = ds.sort(ds.columns.map(col).toIndexedSeq: _*) 81 | 82 | def sortPreciseColumns[T](ds: Dataset[T]): Dataset[T] = { 83 | val colNames = ds.dtypes 84 | .withFilter { dtype => 85 | !Seq("DoubleType", "DecimalType", "FloatType").contains(dtype._2) 86 | } 87 | .map(_._1) 88 | val cols = colNames.map(col) 89 | ds.sort(cols: _*) 90 | } 91 | 92 | /** 93 | * Raises an error unless `actualDS` and `expectedDS` are equal 94 | */ 95 | def assertLargeDatasetEquality[T: ClassTag]( 96 | actualDS: Dataset[T], 97 | expectedDS: Dataset[T], 98 | equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2), 99 | ignoreNullable: Boolean = false, 100 | ignoreColumnNames: Boolean = false, 101 | orderedComparison: Boolean = true, 102 | ignoreColumnOrder: Boolean = false, 103 | ignoreMetadata: Boolean = true 104 | ): Unit = { 105 | // first check if the schemas are equal 106 | SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) 107 | val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS 108 | assertLargeDatasetContentEquality(actual, expectedDS, equals, orderedComparison) 109 | } 110 | 111 | def assertLargeDatasetContentEquality[T: ClassTag]( 112 | actualDS: Dataset[T], 113 | expectedDS: Dataset[T], 114 | equals: (T, T) => Boolean, 115 | orderedComparison: Boolean 116 | ): Unit = { 117 | if (orderedComparison) { 118 | assertLargeDatasetContentEquality(actualDS, expectedDS, equals) 119 | } else { 120 | assertLargeDatasetContentEquality(sortPreciseColumns(actualDS), sortPreciseColumns(expectedDS), equals) 121 | } 122 | } 123 | 124 | def assertLargeDatasetContentEquality[T: ClassTag](ds1: Dataset[T], ds2: Dataset[T], equals: (T, T) => Boolean): Unit = { 125 | try { 126 | val ds1RDD = ds1.rdd.cache() 127 | val ds2RDD = ds2.rdd.cache() 128 | 129 | val actualCount = ds1RDD.count 130 | val expectedCount = ds2RDD.count 131 | 132 | if (actualCount != expectedCount) { 133 | 
throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount)) 134 | } 135 | val expectedIndexValue = RddHelpers.zipWithIndex(ds1RDD) 136 | val resultIndexValue = RddHelpers.zipWithIndex(ds2RDD) 137 | val unequalRDD = expectedIndexValue 138 | .join(resultIndexValue) 139 | .filter { case (_, (o1, o2)) => 140 | !equals(o1, o2) 141 | } 142 | 143 | if (!unequalRDD.isEmpty()) { 144 | throw DatasetContentMismatch( 145 | unequalRDDMessage(unequalRDD, maxUnequalRowsToShow) 146 | ) 147 | } 148 | 149 | } finally { 150 | ds1.rdd.unpersist() 151 | ds2.rdd.unpersist() 152 | } 153 | } 154 | 155 | def assertApproximateDataFrameEquality( 156 | actualDF: DataFrame, 157 | expectedDF: DataFrame, 158 | precision: Double, 159 | ignoreNullable: Boolean = false, 160 | ignoreColumnNames: Boolean = false, 161 | orderedComparison: Boolean = true, 162 | ignoreColumnOrder: Boolean = false, 163 | ignoreMetadata: Boolean = true 164 | ): Unit = { 165 | val e = (r1: Row, r2: Row) => { 166 | r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, precision) 167 | } 168 | assertLargeDatasetEquality[Row]( 169 | actualDF, 170 | expectedDF, 171 | equals = e, 172 | ignoreNullable, 173 | ignoreColumnNames, 174 | orderedComparison, 175 | ignoreColumnOrder, 176 | ignoreMetadata 177 | ) 178 | } 179 | } 180 | 181 | object DatasetComparer { 182 | val maxUnequalRowsToShow = 10 183 | } 184 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red} 4 | import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps 5 | import org.apache.commons.lang3.StringUtils 6 | import org.apache.spark.sql.Row 7 | 8 | import scala.reflect.ClassTag 9 | 10 | object ProductUtil { 11 | private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { 12 | product match { 13 | case null => Seq.empty 14 | case a: Array[_] => a 15 | case i: Iterable[_] => i.toSeq 16 | case r: Row => r.toSeq 17 | case p: Product => p.productIterator.toSeq 18 | case s => Seq(s) 19 | } 20 | } 21 | 22 | private def rowFieldToString(fieldValue: Any): String = s"$fieldValue" 23 | 24 | private[mrpowers] def showProductDiff[T: ClassTag]( 25 | header: (String, String), 26 | actual: Seq[T], 27 | expected: Seq[T], 28 | truncate: Int = 20, 29 | minColWidth: Int = 3 30 | ): String = { 31 | 32 | val runTimeClass = implicitly[ClassTag[T]].runtimeClass 33 | val (className, lBracket, rBracket) = if (runTimeClass == classOf[Row]) ("", "[", "]") else (runTimeClass.getSimpleName, "(", ")") 34 | val prodToString: Seq[Any] => String = s => s.mkString(s"$className$lBracket", ",", rBracket) 35 | val emptyProd = "MISSING" 36 | 37 | val sb = new StringBuilder 38 | 39 | val fullJoin = actual.zipAll(expected, null, null) 40 | 41 | val diff = fullJoin.map { case (actualRow, expectedRow) => 42 | if (actualRow == expectedRow) { 43 | List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString)) 44 | } else { 45 | val actualSeq = productOrRowToSeq(actualRow) 46 | val expectedSeq = productOrRowToSeq(expectedRow) 47 | if (actualSeq.isEmpty) 48 | List(Red(emptyProd), Green(prodToString(expectedSeq))) 49 | else if (expectedSeq.isEmpty) 50 | List(Red(prodToString(actualSeq)), Green(emptyProd)) 51 | else { 52 | val withEquals = actualSeq 53 | .zipAll(expectedSeq, 
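// Scala's zipAll pads the shorter collection with the supplied defaults, so
// the field-by-field diff built below still lines up when one row has fewer
// fields than the other; a quick illustration with arbitrary values:
//
//   Seq(1, 2, 3).zipAll(Seq("a"), 0, "?")  // List((1,"a"), (2,"?"), (3,"?"))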
"MISSING", "MISSING") 54 | .map { case (actualRowField, expectedRowField) => 55 | (actualRowField, expectedRowField, actualRowField == expectedRowField) 56 | } 57 | val allFieldsAreNotEqual = !withEquals.exists(_._3) 58 | if (allFieldsAreNotEqual) { 59 | List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq))) 60 | } else { 61 | val coloredDiff = withEquals.map { 62 | case (actualRowField, expectedRowField, true) => 63 | (DarkGray(rowFieldToString(actualRowField)), DarkGray(rowFieldToString(expectedRowField))) 64 | case (actualRowField, expectedRowField, false) => 65 | (Red(rowFieldToString(actualRowField)), Green(rowFieldToString(expectedRowField))) 66 | } 67 | val start = DarkGray(s"$className$lBracket") 68 | val sep = DarkGray(",") 69 | val end = DarkGray(rBracket) 70 | List( 71 | coloredDiff.map(_._1).mkStr(start, sep, end), 72 | coloredDiff.map(_._2).mkStr(start, sep, end) 73 | ) 74 | } 75 | } 76 | } 77 | } 78 | val headerSeq = List(header._1, header._2) 79 | val numCols = 2 80 | 81 | // Initialise the width of each column to a minimum value 82 | val colWidths = Array.fill(numCols)(minColWidth) 83 | 84 | // Compute the width of each column 85 | headerSeq.zipWithIndex.foreach({ case (cell, i) => 86 | colWidths(i) = math.max(colWidths(i), cell.length) 87 | }) 88 | 89 | diff.foreach { row => 90 | row.zipWithIndex.foreach { case (cell, i) => 91 | colWidths(i) = math.max(colWidths(i), cell.length) 92 | } 93 | } 94 | 95 | // Create SeparateLine 96 | val sep: String = 97 | colWidths 98 | .map("-" * _) 99 | .addString(sb, "+", "+", "+\n") 100 | .toString 101 | 102 | // column names 103 | headerSeq.zipWithIndex 104 | .map { case (cell, i) => 105 | if (truncate > 0) { 106 | StringUtils.leftPad(cell, colWidths(i)) 107 | } else { 108 | StringUtils.rightPad(cell, colWidths(i)) 109 | } 110 | } 111 | .addString(sb, "|", "|", "|\n") 112 | 113 | sb.append(sep) 114 | 115 | diff.map { row => 116 | row.zipWithIndex 117 | .map { case (cell, i) => 118 | val padsLen = colWidths(i) - cell.length 119 | val pads = if (padsLen > 0) " " * padsLen else "" 120 | if (truncate > 0) { 121 | pads + cell.toString 122 | } else { 123 | cell.toString + pads 124 | } 125 | 126 | } 127 | .addString(sb, "|", "|", "|\n") 128 | } 129 | 130 | sb.append(sep) 131 | 132 | sb.toString 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/RDDComparer.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | import scala.reflect.ClassTag 6 | 7 | case class RDDContentMismatch(smth: String) extends Exception(smth) 8 | 9 | trait RDDComparer { 10 | 11 | def contentMismatchMessage[T: ClassTag](actualRDD: RDD[T], expectedRDD: RDD[T]): String = { 12 | s""" 13 | Actual RDD Content: 14 | ${actualRDD.take(5).mkString("\n")} 15 | Expected RDD Content: 16 | ${expectedRDD.take(5).mkString("\n")} 17 | """ 18 | } 19 | 20 | def assertSmallRDDEquality[T: ClassTag](actualRDD: RDD[T], expectedRDD: RDD[T]): Unit = { 21 | if (!actualRDD.collect().sameElements(expectedRDD.collect())) { 22 | throw RDDContentMismatch( 23 | contentMismatchMessage(actualRDD, expectedRDD) 24 | ) 25 | } 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/RddHelpers.scala: 
-------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | object RddHelpers { 6 | 7 | /** 8 | * Zip RDD's with precise indexes. This is used so we can join two DataFrame's Rows together regardless of if the source is different but still 9 | * compare based on the order. 10 | */ 11 | def zipWithIndex[T](rdd: RDD[T]): RDD[(Long, T)] = { 12 | rdd.zipWithIndex().map { case (row, idx) => 13 | (idx, row) 14 | } 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/RowComparer.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.commons.math3.util.Precision 4 | import org.apache.spark.sql.Row 5 | 6 | import scala.math.abs 7 | 8 | object RowComparer { 9 | 10 | /** Approximate equality, based on equals from [[Row]] */ 11 | def areRowsEqual(r1: Row, r2: Row, tol: Double = 0): Boolean = { 12 | if (tol == 0) { 13 | return r1 == r2 14 | } 15 | if (r1.length != r2.length) { 16 | return false 17 | } 18 | for (i <- 0 until r1.length) { 19 | if (r1.isNullAt(i) != r2.isNullAt(i)) { 20 | return false 21 | } 22 | if (!r1.isNullAt(i)) { 23 | val o1 = r1.get(i) 24 | val o2 = r2.get(i) 25 | val valid = o1 match { 26 | case b1: Array[Byte] => 27 | o2.isInstanceOf[Array[Byte]] && java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]]) 28 | case f1: Float if o2.isInstanceOf[Float] => Precision.equalsIncludingNaN(f1, o2.asInstanceOf[Float], tol) 29 | case d1: Double if o2.isInstanceOf[Double] => Precision.equalsIncludingNaN(d1, o2.asInstanceOf[Double], tol) 30 | case bd1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] => 31 | val bigDecimalCompare = bd1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs().compareTo(new java.math.BigDecimal(tol)) 32 | bigDecimalCompare == -1 || bigDecimalCompare == 0 33 | case f1: Number if o2.isInstanceOf[Number] => 34 | val bd1 = new java.math.BigDecimal(f1.toString) 35 | val bd2 = new java.math.BigDecimal(o2.toString) 36 | bd1.subtract(bd2).abs().compareTo(new java.math.BigDecimal(tol)) == -1 37 | case t1: java.sql.Timestamp => abs(t1.getTime - o2.asInstanceOf[java.sql.Timestamp].getTime) <= tol 38 | case t1: java.time.Instant => abs(t1.toEpochMilli - o2.asInstanceOf[java.time.Instant].toEpochMilli) <= tol 39 | case rr1: Row if o2.isInstanceOf[Row] => areRowsEqual(rr1, o2.asInstanceOf[Row], tol) 40 | case _ => o1 == o2 41 | } 42 | if (!valid) { 43 | return false 44 | } 45 | } 46 | } 47 | true 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff 4 | import com.github.mrpowers.spark.fast.tests.SchemaDiffOutputFormat.SchemaDiffOutputFormat 5 | import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red} 6 | import org.apache.spark.sql.Dataset 7 | import org.apache.spark.sql.types._ 8 | 9 | import scala.util.Try 10 | 11 | object SchemaComparer { 12 | private val INDENT_GAP = 5 13 | private val DESCRIPTION_GAP = 21 14 | private val TREE_GAP = 6 15 | private val CONTAIN_NULLS = 16 | 
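// A small illustration of the tolerance-based comparison defined in
// RowComparer above (the values are arbitrary):
//
//   import org.apache.spark.sql.Row
//   RowComparer.areRowsEqual(Row(1.0), Row(1.005), tol = 0.01)   // true:  |1.0 - 1.005| <= 0.01
//   RowComparer.areRowsEqual(Row(1.0), Row(1.005), tol = 0.001)  // false: the difference exceeds the tolerance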
"sparkFastTestContainsNull" // Distinguishable metadata key for spark fast tests to avoid potential ambiguity with other metadata keys 17 | case class DatasetSchemaMismatch(smth: String) extends Exception(smth) 18 | private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = { 19 | showProductDiff( 20 | ("Actual Schema", "Expected Schema"), 21 | actualSchema.fields, 22 | expectedSchema.fields, 23 | truncate = 200 24 | ) 25 | } 26 | 27 | private def treeSchemaMismatchMessage[T](actualSchema: StructType, expectedSchema: StructType): String = { 28 | def flattenStrucType(s: StructType, indent: Int, additionalGap: Int = 0): (Seq[(Int, StructField)], Int) = s 29 | .foldLeft((Seq.empty[(Int, StructField)], Int.MinValue)) { case ((fieldPair, maxWidth), f) => 30 | val gap = indent * INDENT_GAP + DESCRIPTION_GAP + f.name.length + f.dataType.typeName.length + f.nullable.toString.length + additionalGap 31 | val pair = fieldPair :+ (indent, f) 32 | val newMaxWidth = scala.math.max(maxWidth, gap) 33 | f.dataType match { 34 | case st: StructType => 35 | val (flattenPair, width) = flattenStrucType(st, indent + 1) 36 | (pair ++ flattenPair, scala.math.max(newMaxWidth, width)) 37 | case ArrayType(elementType, containsNull) => 38 | val arrStruct = StructType( 39 | Seq( 40 | StructField( 41 | "element", 42 | elementType, 43 | metadata = new MetadataBuilder() 44 | .putBoolean(CONTAIN_NULLS, value = containsNull) 45 | .build() 46 | ) 47 | ) 48 | ) 49 | val (flattenPair, width) = flattenStrucType(arrStruct, indent + 1, 5) 50 | (pair ++ flattenPair, scala.math.max(newMaxWidth, width)) 51 | case _ => (pair, newMaxWidth) 52 | } 53 | } 54 | 55 | val (treeFieldPair1, tree1MaxWidth) = flattenStrucType(actualSchema, 0) 56 | val (treeFieldPair2, _) = flattenStrucType(expectedSchema, 0) 57 | val (treePair, maxWidth) = treeFieldPair1 58 | .zipAll(treeFieldPair2, (0, null), (0, null)) 59 | .foldLeft((Seq.empty[(String, String)], 0)) { case ((acc, maxWidth), ((indent1, field1), (indent2, field2))) => 60 | val (prefix1, prefix2) = getIndentPair(indent1, indent2) 61 | val (name1, name2) = getNamePair(field1, field2) 62 | val (dtype1, dtype2) = getDataTypePair(field1, field2) 63 | val (nullable1, nullable2) = getNullablePair(field1, field2) 64 | val (containNulls1, containNulls2) = getContainNullsPair(field1, field2) 65 | val structString1 = formatField(field1, prefix1, name1, dtype1, nullable1, containNulls1) 66 | val structString2 = formatField(field2, prefix2, name2, dtype2, nullable2, containNulls2) 67 | 68 | (acc :+ (structString1, structString2), math.max(maxWidth, structString1.length)) 69 | } 70 | 71 | val schemaGap = maxWidth + TREE_GAP 72 | val headerGap = tree1MaxWidth + TREE_GAP 73 | treePair 74 | .foldLeft(new StringBuilder("\nActual Schema".padTo(headerGap, ' ') + "Expected Schema\n")) { case (sb, (s1, s2)) => 75 | val gap = if (s1.isEmpty) headerGap else schemaGap 76 | val s = if (s2.isEmpty) s1 else s1.padTo(gap, ' ') 77 | sb.append(s + s2 + "\n") 78 | } 79 | .toString() 80 | } 81 | 82 | private def getDescriptionPair(nullable: String, containNulls: String): String = 83 | if (containNulls.isEmpty) s"(nullable = $nullable)" else s"(containsNull = $containNulls)" 84 | 85 | private def formatField(field: StructField, prefix: String, name: String, dtype: String, nullable: String, containNulls: String): String = { 86 | if (field == null) { 87 | "" 88 | } else { 89 | val description = getDescriptionPair(nullable, containNulls) 90 | s"$prefix $name : $dtype $description" 91 
| } 92 | } 93 | 94 | private def getContainNullsPair(field1: StructField, field2: StructField): (String, String) = { 95 | (field1, field2) match { 96 | case (f, null) => 97 | val containNulls = Try(f.metadata.getBoolean(CONTAIN_NULLS).toString).getOrElse("") 98 | Red(containNulls).toString -> "" 99 | case (null, f) => 100 | val containNulls = Try(f.metadata.getBoolean(CONTAIN_NULLS).toString).getOrElse("") 101 | "" -> Green(containNulls).toString 102 | case (StructField(_, _, _, m1), StructField(_, _, _, m2)) => 103 | val isArrayElement1 = m1.contains(CONTAIN_NULLS) 104 | val isArrayElement2 = m2.contains(CONTAIN_NULLS) 105 | if (isArrayElement1 && isArrayElement2) { 106 | val containNulls1 = m1.getBoolean(CONTAIN_NULLS) 107 | val containNulls2 = m2.getBoolean(CONTAIN_NULLS) 108 | val (cn1, cn2) = if (containNulls1 != containNulls2) { 109 | (Red(containNulls1.toString), Green(containNulls2.toString)) 110 | } else { 111 | (DarkGray(containNulls1.toString), DarkGray(containNulls2.toString)) 112 | } 113 | (cn1.toString, cn2.toString) 114 | } else if (isArrayElement1) { 115 | (DarkGray(m1.getBoolean(CONTAIN_NULLS).toString).toString, "") 116 | } else if (isArrayElement2) { 117 | ("", DarkGray(m2.getBoolean(CONTAIN_NULLS).toString).toString) 118 | } else { 119 | ("", "") 120 | } 121 | } 122 | } 123 | 124 | private def getIndentPair(indent1: Int, indent2: Int): (String, String) = { 125 | def depthToIndentStr(depth: Int): String = Range(0, depth).map(_ => "| ").mkString + "|--" 126 | val prefix1 = depthToIndentStr(indent1) 127 | val prefix2 = depthToIndentStr(indent2) 128 | val (p1, p2) = if (indent1 != indent2) { 129 | (Red(prefix1), Green(prefix2)) 130 | } else { 131 | (DarkGray(prefix1), DarkGray(prefix2)) 132 | } 133 | (p1.toString, p2.toString) 134 | } 135 | 136 | private def getNamePair(field1: StructField, field2: StructField): (String, String) = (field1, field2) match { 137 | case (_: StructField, null) => (Red(field1.name).toString, "") 138 | case (null, _: StructField) => ("", Green(field2.name).toString) 139 | case (f1, f2) if f1.name == f2.name => (DarkGray(field1.name).toString, DarkGray(field2.name).toString) 140 | case (f1, f2) if f1.name != f2.name => (Red(field1.name).toString, Green(field2.name).toString) 141 | } 142 | 143 | private def getDataTypePair(field1: StructField, field2: StructField): (String, String) = { 144 | (field1, field2) match { 145 | case (f: StructField, null) => (Red(f.dataType.typeName).toString, "") 146 | case (null, f: StructField) => ("", Green(f.dataType.typeName).toString) 147 | case (f1, f2) if f1.dataType == f2.dataType => (DarkGray(f1.dataType.typeName).toString, DarkGray(f2.dataType.typeName).toString) 148 | case (f1, f2) if f1.dataType != f2.dataType => (Red(f1.dataType.typeName).toString, Green(f2.dataType.typeName).toString) 149 | } 150 | } 151 | 152 | private def getNullablePair(field1: StructField, field2: StructField): (String, String) = { 153 | (field1, field2) match { 154 | case (f: StructField, null) => (Red(f.nullable.toString).toString, "") 155 | case (null, f: StructField) => ("", Green(f.nullable.toString).toString) 156 | case (f1, f2) if f1.nullable == f2.nullable => (DarkGray(f1.nullable.toString).toString, DarkGray(f2.nullable.toString).toString) 157 | case (f1, f2) if f1.nullable != f2.nullable => (Red(f1.nullable.toString).toString, Green(f2.nullable.toString).toString) 158 | } 159 | } 160 | 161 | def assertDatasetSchemaEqual[T]( 162 | actualDS: Dataset[T], 163 | expectedDS: Dataset[T], 164 | ignoreNullable: Boolean = false, 165 | 
ignoreColumnNames: Boolean = false, 166 |       ignoreColumnOrder: Boolean = true, 167 |       ignoreMetadata: Boolean = true, 168 |       outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table 169 |   ): Unit = { 170 |     assertSchemaEqual(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata, outputFormat) 171 |   } 172 | 173 |   def assertSchemaEqual( 174 |       actualSchema: StructType, 175 |       expectedSchema: StructType, 176 |       ignoreNullable: Boolean = false, 177 |       ignoreColumnNames: Boolean = false, 178 |       ignoreColumnOrder: Boolean = true, 179 |       ignoreMetadata: Boolean = true, 180 |       outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table 181 |   ): Unit = { 182 |     require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.") 183 |     if (!SchemaComparer.equals(actualSchema, expectedSchema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)) { 184 |       val diffString = outputFormat match { 185 |         case SchemaDiffOutputFormat.Tree => treeSchemaMismatchMessage(actualSchema, expectedSchema) 186 |         case SchemaDiffOutputFormat.Table => betterSchemaMismatchMessage(actualSchema, expectedSchema) 187 |       } 188 | 189 |       throw DatasetSchemaMismatch(s"Diffs\n$diffString") 190 |     } 191 |   } 192 | 193 |   def equals( 194 |       s1: StructType, 195 |       s2: StructType, 196 |       ignoreNullable: Boolean = false, 197 |       ignoreColumnNames: Boolean = false, 198 |       ignoreColumnOrder: Boolean = true, 199 |       ignoreMetadata: Boolean = true 200 |   ): Boolean = { 201 |     if (s1.length != s2.length) { 202 |       false 203 |     } else { 204 |       val zipStruct = if (ignoreColumnOrder) s1.sortBy(_.name) zip s2.sortBy(_.name) else s1 zip s2 205 |       zipStruct.forall { case (f1, f2) => 206 |         (f1.nullable == f2.nullable || ignoreNullable) && 207 |         (f1.name == f2.name || ignoreColumnNames) && 208 |         (f1.metadata == f2.metadata || ignoreMetadata) && 209 |         equals(f1.dataType, f2.dataType, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) 210 |       } 211 |     } 212 |   } 213 | 214 |   def equals( 215 |       dt1: DataType, 216 |       dt2: DataType, 217 |       ignoreNullable: Boolean, 218 |       ignoreColumnNames: Boolean, 219 |       ignoreColumnOrder: Boolean, 220 |       ignoreMetadata: Boolean 221 |   ): Boolean = { 222 |     (dt1, dt2) match { 223 |       case (st1: StructType, st2: StructType) => 224 |         equals(st1, st2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) 225 |       case (ArrayType(vdt1, _), ArrayType(vdt2, _)) => 226 |         equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) 227 |       case (MapType(kdt1, vdt1, _), MapType(kdt2, vdt2, _)) => 228 |         equals(kdt1, kdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) && 229 |         equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata) 230 |       case _ => dt1 == dt2 231 |     } 232 |   } 233 | } 234 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaDiffOutputFormat.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | object SchemaDiffOutputFormat extends Enumeration { 4 | type SchemaDiffOutputFormat = Value 5 | 6 | val Tree, Table = Value 7 | } 8 | --------------------------------------------------------------------------------
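A brief sketch of how SchemaComparer and the output-format enum above fit together, assuming both are imported (the schemas are illustrative):

    import org.apache.spark.sql.types._

    val actualSchema   = StructType(Seq(StructField("a", IntegerType, nullable = false)))
    val expectedSchema = StructType(Seq(StructField("a", IntegerType, nullable = true)))

    // passes: nullability is the only difference and it is explicitly ignored
    SchemaComparer.assertSchemaEqual(actualSchema, expectedSchema, ignoreNullable = true)

    // throws DatasetSchemaMismatch, rendering the diff as a side-by-side tree
    SchemaComparer.assertSchemaEqual(actualSchema, expectedSchema, outputFormat = SchemaDiffOutputFormat.Tree)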
/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SeqLikesExtensions.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.Row 4 | 5 | import scala.util.Try 6 | 7 | object SeqLikesExtensions { 8 | implicit class SeqExtensions[T](val seq1: Seq[T]) extends AnyVal { 9 | def approximateSameElements(seq2: Seq[T], equals: (T, T) => Boolean): Boolean = (seq1, seq2) match { 10 | case (i1: IndexedSeq[_], i2: IndexedSeq[_]) => 11 | val length = i1.length 12 | var equal = length == i2.length 13 | if (equal) { 14 | var index = 0 15 | val maxApplyCompare = { 16 | val preferredLength = 17 | Try(System.getProperty("scala.collection.immutable.IndexedSeq.defaultApplyPreferredMaxLength", "64").toInt).getOrElse(64) 18 | if (length > (preferredLength.toLong << 1)) preferredLength else length 19 | } 20 | while (index < maxApplyCompare && equal) { 21 | equal = equals(i1(index), i2(index)) 22 | index += 1 23 | } 24 | if ((index < length) && equal) { 25 | val thisIt = i1.iterator.drop(index) 26 | val thatIt = i2.iterator.drop(index) 27 | while (equal && thisIt.hasNext) { 28 | equal = equals(thisIt.next(), thatIt.next()) 29 | } 30 | } 31 | } 32 | equal 33 | case _ => 34 | val thisKnownSize = getKnownSize(seq1) 35 | val knownSizeDifference = thisKnownSize != -1 && { 36 | val thatKnownSize = getKnownSize(seq2) 37 | thatKnownSize != -1 && thisKnownSize != thatKnownSize 38 | } 39 | if (knownSizeDifference) { 40 | return false 41 | } 42 | val these = seq1.iterator 43 | val those = seq2.iterator 44 | while (these.hasNext && those.hasNext) 45 | if (!equals(these.next(), those.next())) 46 | return false 47 | these.hasNext == those.hasNext 48 | } 49 | 50 | // scala2.13 optimization: check number of element if it can be cheaply computed 51 | private def getKnownSize(s: Seq[T]): Int = Try(s.getClass.getMethod("knownSize").invoke(s).asInstanceOf[Int]).getOrElse(s.length) 52 | 53 | private[mrpowers] def asRows: Seq[Row] = seq1.map { 54 | case x: Row => x 55 | case y: Product => Row(y.productIterator.toSeq: _*) 56 | case a => Row(a) 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /core/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/FansiExtensions.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests.ufansi 2 | object FansiExtensions { 3 | private[mrpowers] implicit class StrOps(c: Seq[Str]) { 4 | def mkStr(start: Str, sep: Str, end: Str): Str = 5 | start ++ c.reduce(_ ++ sep ++ _) ++ end 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /core/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=ERROR, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.eclipse.jetty=WARN 10 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN 12 | 
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/ArrayUtilTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import java.sql.Date 4 | import java.time.LocalDate 5 | import org.scalatest.freespec.AnyFreeSpec 6 | 7 | class ArrayUtilTest extends AnyFreeSpec { 8 | 9 | "blah" in { 10 | val arr: Array[(Any, Any)] = Array(("hi", "there"), ("fun", "train")) 11 | val res = ArrayUtil.weirdTypesToStrings(arr, 10) 12 | assert(res sameElements Array(List("hi", "there"), List("fun", "train"))) 13 | } 14 | 15 | "showTwoColumnString" in { 16 | val arr: Array[(Any, Any)] = Array(("word1", "word2"), ("hi", "there"), ("fun", "train")) 17 | println(ArrayUtil.showTwoColumnString(arr, 10)) 18 | } 19 | 20 | "showTwoColumnDate" in { 21 | val now = LocalDate.now() 22 | val arr: Array[(Any, Any)] = 23 | Array(("word1", "word2"), (Date.valueOf(now), Date.valueOf(now)), (Date.valueOf(now.plusDays(-1)), Date.valueOf(now))) 24 | println(ArrayUtil.showTwoColumnString(arr, 10)) 25 | } 26 | 27 | "dumbshowTwoColumnString" in { 28 | val arr: Array[(Any, Any)] = Array(("word1", "word2"), ("hi", "there"), ("fun", "train")) 29 | val rowEqual = Array(true, false) 30 | println(ArrayUtil.showTwoColumnStringColorCustomizable(arr, Some(rowEqual))) 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/ColumnComparerTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.Row 4 | import org.apache.spark.sql.functions._ 5 | import org.apache.spark.sql.types._ 6 | import java.sql.Date 7 | import java.sql.Timestamp 8 | 9 | import org.scalatest.freespec.AnyFreeSpec 10 | 11 | class ColumnComparerTest extends AnyFreeSpec with ColumnComparer with SparkSessionTestWrapper { 12 | 13 | "assertColumnEquality" - { 14 | 15 | "throws an easily readable error message" in { 16 | val sourceData = Seq( 17 | Row("phil", "phil"), 18 | Row("rashid", "rashid"), 19 | Row("matthew", "mateo"), 20 | Row("sami", "sami"), 21 | Row("this is something that is super crazy long", "sami"), 22 | Row("li", "feng"), 23 | Row(null, null) 24 | ) 25 | val sourceSchema = List( 26 | StructField("name", StringType, true), 27 | StructField("expected_name", StringType, true) 28 | ) 29 | val sourceDF = spark.createDataFrame( 30 | spark.sparkContext.parallelize(sourceData), 31 | StructType(sourceSchema) 32 | ) 33 | val e = intercept[ColumnMismatch] { 34 | assertColumnEquality(sourceDF, "name", "expected_name") 35 | } 36 | val e2 = intercept[ColumnMismatch] { 37 | assertColEquality(sourceDF, "name", "expected_name") 38 | } 39 | } 40 | 41 | "doesn't thrown an error when the columns are equal" in { 42 | val sourceData = Seq( 43 | Row(1, 1), 44 | Row(5, 5), 45 | Row(null, null) 46 | ) 47 | val sourceSchema = List( 48 | StructField("num", IntegerType, true), 49 | StructField("expected_num", IntegerType, true) 50 | ) 51 | val sourceDF = spark.createDataFrame( 52 | spark.sparkContext.parallelize(sourceData), 53 | StructType(sourceSchema) 54 | ) 55 | assertColumnEquality(sourceDF, "num", "expected_num") 56 | assertColEquality(sourceDF, "num", "expected_num") 57 | } 58 | 59 | "throws an error if the 
columns are not equal" in { 60 | val sourceData = Seq( 61 | Row(1, 3), 62 | Row(5, 5), 63 | Row(null, null) 64 | ) 65 | val sourceSchema = List( 66 | StructField("num", IntegerType, true), 67 | StructField("expected_num", IntegerType, true) 68 | ) 69 | val sourceDF = spark.createDataFrame( 70 | spark.sparkContext.parallelize(sourceData), 71 | StructType(sourceSchema) 72 | ) 73 | val e = intercept[ColumnMismatch] { 74 | assertColumnEquality(sourceDF, "num", "expected_num") 75 | } 76 | val e2 = intercept[ColumnMismatch] { 77 | assertColEquality(sourceDF, "num", "expected_num") 78 | } 79 | } 80 | 81 | "throws an error if the columns are different types" in { 82 | val sourceData = Seq( 83 | Row(1, "hi"), 84 | Row(5, "bye"), 85 | Row(null, null) 86 | ) 87 | val sourceSchema = List( 88 | StructField("num", IntegerType, true), 89 | StructField("word", StringType, true) 90 | ) 91 | val sourceDF = spark.createDataFrame( 92 | spark.sparkContext.parallelize(sourceData), 93 | StructType(sourceSchema) 94 | ) 95 | val e = intercept[ColumnMismatch] { 96 | assertColumnEquality(sourceDF, "num", "word") 97 | } 98 | val e2 = intercept[ColumnMismatch] { 99 | assertColEquality(sourceDF, "num", "word") 100 | } 101 | } 102 | 103 | "works properly, even when null is compared with a value" in { 104 | val sourceData = Seq( 105 | Row(1, 1), 106 | Row(null, 5), 107 | Row(null, null) 108 | ) 109 | val sourceSchema = List( 110 | StructField("num", IntegerType, true), 111 | StructField("expected_num", IntegerType, true) 112 | ) 113 | val sourceDF = spark.createDataFrame( 114 | spark.sparkContext.parallelize(sourceData), 115 | StructType(sourceSchema) 116 | ) 117 | val e = intercept[ColumnMismatch] { 118 | assertColumnEquality(sourceDF, "num", "expected_num") 119 | } 120 | val e2 = intercept[ColumnMismatch] { 121 | assertColEquality(sourceDF, "num", "expected_num") 122 | } 123 | } 124 | 125 | "works for ArrayType columns" in { 126 | val sourceData = Seq( 127 | Row(Array("a"), Array("a")), 128 | Row(Array("a", "b"), Array("a", "b")), 129 | Row(Array(), Array()), 130 | Row(null, null) 131 | ) 132 | val sourceSchema = List( 133 | StructField("l1", ArrayType(StringType, true), true), 134 | StructField("l2", ArrayType(StringType, true), true) 135 | ) 136 | val sourceDF = spark.createDataFrame( 137 | spark.sparkContext.parallelize(sourceData), 138 | StructType(sourceSchema) 139 | ) 140 | assertColumnEquality(sourceDF, "l1", "l2") 141 | assertColEquality(sourceDF, "l1", "l2") 142 | } 143 | 144 | "works for nested arrays" in { 145 | val sourceData = Seq( 146 | Row(Array(Array("a"), Array("a")), Array(Array("a"), Array("a"))), 147 | Row(Array(Array("a", "b"), Array("a", "b")), Array(Array("a", "b"), Array("a", "b"))), 148 | Row(Array(Array(), Array()), Array(Array(), Array())), 149 | Row(null, null) 150 | ) 151 | val sourceSchema = List( 152 | StructField("l1", ArrayType(ArrayType(StringType, true)), true), 153 | StructField("l2", ArrayType(ArrayType(StringType, true)), true) 154 | ) 155 | val sourceDF = spark.createDataFrame( 156 | spark.sparkContext.parallelize(sourceData), 157 | StructType(sourceSchema) 158 | ) 159 | assertColumnEquality(sourceDF, "l1", "l2") 160 | assertColEquality(sourceDF, "l1", "l2") 161 | } 162 | 163 | "works for computed ArrayType columns" in { 164 | val sourceData = Seq( 165 | Row("i like blue and red", Array("blue", "red")), 166 | Row("you pink and blue", Array("blue", "pink")), 167 | Row("i like fun", Array("")) 168 | ) 169 | val sourceSchema = List( 170 | StructField("words", StringType, true), 171 | 
StructField("expected_colors", ArrayType(StringType, true), true) 172 | ) 173 | val sourceDF = spark.createDataFrame( 174 | spark.sparkContext.parallelize(sourceData), 175 | StructType(sourceSchema) 176 | ) 177 | val actualDF = sourceDF.withColumn( 178 | "colors", 179 | coalesce( 180 | split( 181 | concat_ws( 182 | ",", 183 | when(col("words").contains("blue"), "blue"), 184 | when(col("words").contains("red"), "red"), 185 | when(col("words").contains("pink"), "pink"), 186 | when(col("words").contains("cyan"), "cyan") 187 | ), 188 | "," 189 | ), 190 | typedLit(Array()) 191 | ) 192 | ) 193 | assertColumnEquality(actualDF, "colors", "expected_colors") 194 | assertColEquality(actualDF, "colors", "expected_colors") 195 | } 196 | 197 | "works for MapType columns" in { 198 | val data = Seq( 199 | Row(Map("good_song" -> "santeria", "bad_song" -> "doesn't exist"), Map("good_song" -> "santeria", "bad_song" -> "doesn't exist")) 200 | ) 201 | val schema = List( 202 | StructField("m1", MapType(StringType, StringType, true), true), 203 | StructField("m2", MapType(StringType, StringType, true), true) 204 | ) 205 | val df = spark.createDataFrame( 206 | spark.sparkContext.parallelize(data), 207 | StructType(schema) 208 | ) 209 | assertColumnEquality(df, "m1", "m2") 210 | } 211 | 212 | "throws error when MapType columns aren't equal" in { 213 | val data = Seq( 214 | Row(Map("good_song" -> "santeria", "bad_song" -> "doesn't exist"), Map("good_song" -> "what i got", "bad_song" -> "doesn't exist")) 215 | ) 216 | val schema = List( 217 | StructField("m1", MapType(StringType, StringType, true), true), 218 | StructField("m2", MapType(StringType, StringType, true), true) 219 | ) 220 | val df = spark.createDataFrame( 221 | spark.sparkContext.parallelize(data), 222 | StructType(schema) 223 | ) 224 | val e = intercept[ColumnMismatch] { 225 | assertColumnEquality(df, "m1", "m2") 226 | } 227 | } 228 | 229 | "works for MapType columns with deep comparisons" in { 230 | val data = Seq( 231 | Row(Map("good_song" -> Array(1, 2, 3, 4)), Map("good_song" -> Array(1, 2, 3, 4))) 232 | ) 233 | val schema = List( 234 | StructField("m1", MapType(StringType, ArrayType(IntegerType, true), true), true), 235 | StructField("m2", MapType(StringType, ArrayType(IntegerType, true), true), true) 236 | ) 237 | val df = spark.createDataFrame( 238 | spark.sparkContext.parallelize(data), 239 | StructType(schema) 240 | ) 241 | assertColumnEquality(df, "m1", "m2") 242 | } 243 | 244 | "throws an errors for MapType columns with deep comparisons that aren't equal" in { 245 | val data = Seq( 246 | Row(Map("good_song" -> Array(1, 2, 3, 4)), Map("good_song" -> Array(1, 2, 3, 8))) 247 | ) 248 | val schema = List( 249 | StructField("m1", MapType(StringType, ArrayType(IntegerType, true), true), true), 250 | StructField("m2", MapType(StringType, ArrayType(IntegerType, true), true), true) 251 | ) 252 | val df = spark.createDataFrame( 253 | spark.sparkContext.parallelize(data), 254 | StructType(schema) 255 | ) 256 | val e = intercept[ColumnMismatch] { 257 | assertColumnEquality(df, "m1", "m2") 258 | } 259 | } 260 | 261 | "works when DateType columns are equal" in { 262 | val sourceData = Seq( 263 | Row(Date.valueOf("2016-08-09"), Date.valueOf("2016-08-09")), 264 | Row(Date.valueOf("2019-01-01"), Date.valueOf("2019-01-01")), 265 | Row(null, null) 266 | ) 267 | val sourceSchema = List( 268 | StructField("d1", DateType, true), 269 | StructField("d2", DateType, true) 270 | ) 271 | val sourceDF = spark.createDataFrame( 272 | 
spark.sparkContext.parallelize(sourceData), 273 | StructType(sourceSchema) 274 | ) 275 | assertColumnEquality(sourceDF, "d1", "d2") 276 | } 277 | 278 | "throws an error when DateType columns are not equal" in { 279 | val sourceData = Seq( 280 | Row(Date.valueOf("2010-07-07"), Date.valueOf("2016-08-09")), 281 | Row(Date.valueOf("2019-01-01"), Date.valueOf("2019-01-01")), 282 | Row(null, null) 283 | ) 284 | val sourceSchema = List( 285 | StructField("d1", DateType, true), 286 | StructField("d2", DateType, true) 287 | ) 288 | val sourceDF = spark.createDataFrame( 289 | spark.sparkContext.parallelize(sourceData), 290 | StructType(sourceSchema) 291 | ) 292 | val e = intercept[ColumnMismatch] { 293 | assertColumnEquality(sourceDF, "d1", "d2") 294 | } 295 | } 296 | 297 | "works when TimestampType columns are equal" in { 298 | val sourceData = Seq( 299 | Row(Timestamp.valueOf("2016-08-09 09:57:00"), Timestamp.valueOf("2016-08-09 09:57:00")), 300 | Row(Timestamp.valueOf("2016-04-10 09:57:00"), Timestamp.valueOf("2016-04-10 09:57:00")), 301 | Row(null, null) 302 | ) 303 | val sourceSchema = List( 304 | StructField("t1", TimestampType, true), 305 | StructField("t2", TimestampType, true) 306 | ) 307 | val sourceDF = spark.createDataFrame( 308 | spark.sparkContext.parallelize(sourceData), 309 | StructType(sourceSchema) 310 | ) 311 | assertColumnEquality(sourceDF, "t1", "t2") 312 | } 313 | 314 | "throws an error when TimestampType columns are not equal" in { 315 | val sourceData = Seq( 316 | Row(Timestamp.valueOf("2010-08-09 09:57:00"), Timestamp.valueOf("2016-08-09 09:57:00")), 317 | Row(Timestamp.valueOf("2016-04-10 10:01:00"), Timestamp.valueOf("2016-04-10 09:57:00")), 318 | Row(null, null) 319 | ) 320 | val sourceSchema = List( 321 | StructField("t1", TimestampType, true), 322 | StructField("t2", TimestampType, true) 323 | ) 324 | val sourceDF = spark.createDataFrame( 325 | spark.sparkContext.parallelize(sourceData), 326 | StructType(sourceSchema) 327 | ) 328 | val e = intercept[ColumnMismatch] { 329 | assertColumnEquality(sourceDF, "t1", "t2") 330 | } 331 | } 332 | 333 | "works when ByteType columns are equal" in { 334 | val sourceData = Seq( 335 | Row(10.toByte, 10.toByte), 336 | Row(33.toByte, 33.toByte), 337 | Row(null, null) 338 | ) 339 | val sourceSchema = List( 340 | StructField("b1", ByteType, true), 341 | StructField("b2", ByteType, true) 342 | ) 343 | val sourceDF = spark.createDataFrame( 344 | spark.sparkContext.parallelize(sourceData), 345 | StructType(sourceSchema) 346 | ) 347 | assertColumnEquality(sourceDF, "b1", "b2") 348 | } 349 | 350 | "throws an error when ByteType columns are not equal" in { 351 | val sourceData = Seq( 352 | Row(8.toByte, 10.toByte), 353 | Row(33.toByte, 33.toByte), 354 | Row(null, null) 355 | ) 356 | val sourceSchema = List( 357 | StructField("b1", ByteType, true), 358 | StructField("b2", ByteType, true) 359 | ) 360 | val sourceDF = spark.createDataFrame( 361 | spark.sparkContext.parallelize(sourceData), 362 | StructType(sourceSchema) 363 | ) 364 | val e = intercept[ColumnMismatch] { 365 | assertColumnEquality(sourceDF, "b1", "b2") 366 | } 367 | } 368 | 369 | } 370 | 371 | "assertBinaryTypeColumnEquality" - { 372 | 373 | "works when BinaryType columns are equal" in { 374 | val sourceData = Seq( 375 | Row(Array(10.toByte, 15.toByte), Array(10.toByte, 15.toByte)), 376 | Row(Array(4.toByte, 33.toByte), Array(4.toByte, 33.toByte)), 377 | Row(null, null) 378 | ) 379 | val sourceSchema = List( 380 | StructField("b1", BinaryType, true), 381 | StructField("b2", 
BinaryType, true) 382 | ) 383 | val sourceDF = spark.createDataFrame( 384 | spark.sparkContext.parallelize(sourceData), 385 | StructType(sourceSchema) 386 | ) 387 | assertBinaryTypeColumnEquality(sourceDF, "b1", "b2") 388 | } 389 | 390 | "throws an error when BinaryType columns are not equal" in { 391 | val sourceData = Seq( 392 | Row(Array(10.toByte, 15.toByte), Array(10.toByte, 15.toByte)), 393 | Row(Array(4.toByte, 33.toByte), Array(4.toByte, 33.toByte)), 394 | Row(null, null), 395 | Row(Array(7.toByte, 33.toByte), Array(4.toByte, 33.toByte)), 396 | Row(Array(4.toByte, 33.toByte), null), 397 | Row(null, Array(4.toByte, 33.toByte)) 398 | ) 399 | val sourceSchema = List( 400 | StructField("b1", BinaryType, true), 401 | StructField("b2", BinaryType, true) 402 | ) 403 | val sourceDF = spark.createDataFrame( 404 | spark.sparkContext.parallelize(sourceData), 405 | StructType(sourceSchema) 406 | ) 407 | val e = intercept[ColumnMismatch] { 408 | assertBinaryTypeColumnEquality(sourceDF, "b1", "b2") 409 | } 410 | } 411 | 412 | } 413 | 414 | "assertDoubleTypeColumnEquality" - { 415 | 416 | "doesn't throw an error when two DoubleType columns are equal" in { 417 | val sourceData = Seq( 418 | Row(1.3, 1.3), 419 | Row(5.01, 5.0101), 420 | Row(null, null) 421 | ) 422 | val sourceSchema = List( 423 | StructField("d1", DoubleType, true), 424 | StructField("d2", DoubleType, true) 425 | ) 426 | val df = spark.createDataFrame( 427 | spark.sparkContext.parallelize(sourceData), 428 | StructType(sourceSchema) 429 | ) 430 | assertDoubleTypeColumnEquality(df, "d1", "d2", 0.01) 431 | } 432 | 433 | "throws an error when two DoubleType columns are not equal" in { 434 | val sourceData = Seq( 435 | Row(1.3, 1.8), 436 | Row(5.01, 5.0101), 437 | Row(null, 10.0), 438 | Row(3.4, null), 439 | Row(-1.1, -1.1), 440 | Row(null, null) 441 | ) 442 | val sourceSchema = List( 443 | StructField("d1", DoubleType, true), 444 | StructField("d2", DoubleType, true) 445 | ) 446 | val df = spark.createDataFrame( 447 | spark.sparkContext.parallelize(sourceData), 448 | StructType(sourceSchema) 449 | ) 450 | val e = intercept[ColumnMismatch] { 451 | assertDoubleTypeColumnEquality(df, "d1", "d2", 0.01) 452 | } 453 | } 454 | 455 | } 456 | 457 | "assertFloatTypeColumnEquality" - { 458 | 459 | "doesn't throw an error when two FloatType columns are equal" in { 460 | val sourceData = Seq( 461 | Row(1.3f, 1.3f), 462 | Row(5.01f, 5.0101f), 463 | Row(null, null) 464 | ) 465 | val sourceSchema = List( 466 | StructField("num1", FloatType, true), 467 | StructField("num2", FloatType, true) 468 | ) 469 | val df = spark.createDataFrame( 470 | spark.sparkContext.parallelize(sourceData), 471 | StructType(sourceSchema) 472 | ) 473 | assertFloatTypeColumnEquality(df, "num1", "num2", 0.01f) 474 | } 475 | 476 | "throws an error when two FloatType columns are not equal" in { 477 | val sourceData = Seq( 478 | Row(1.3f, 1.8f), 479 | Row(5.01f, 5.0101f), 480 | Row(null, 10.0f), 481 | Row(3.4f, null), 482 | Row(null, null) 483 | ) 484 | val sourceSchema = List( 485 | StructField("d1", FloatType, true), 486 | StructField("d2", FloatType, true) 487 | ) 488 | val df = spark.createDataFrame( 489 | spark.sparkContext.parallelize(sourceData), 490 | StructType(sourceSchema) 491 | ) 492 | val e = intercept[ColumnMismatch] { 493 | assertFloatTypeColumnEquality(df, "d1", "d2", 0.01f) 494 | } 495 | } 496 | 497 | } 498 | 499 | "assertColEquality" - { 500 | 501 | "throws an easily readable error message" in { 502 | val sourceData = Seq( 503 | Row("phil", "phil"), 504 | 
Row("rashid", "rashid"), 505 | Row("matthew", "mateo"), 506 | Row("sami", "sami"), 507 | Row("this is something that is super crazy long", "sami"), 508 | Row("li", "feng"), 509 | Row(null, null) 510 | ) 511 | val sourceSchema = List( 512 | StructField("name", StringType, true), 513 | StructField("expected_name", StringType, true) 514 | ) 515 | val sourceDF = spark.createDataFrame( 516 | spark.sparkContext.parallelize(sourceData), 517 | StructType(sourceSchema) 518 | ) 519 | val e = intercept[ColumnMismatch] { 520 | assertColEquality(sourceDF, "name", "expected_name") 521 | } 522 | } 523 | 524 | "doesn't thrown an error when the columns are equal" in { 525 | val sourceData = Seq( 526 | Row(1, 1), 527 | Row(5, 5), 528 | Row(null, null) 529 | ) 530 | val sourceSchema = List( 531 | StructField("num", IntegerType, true), 532 | StructField("expected_num", IntegerType, true) 533 | ) 534 | val sourceDF = spark.createDataFrame( 535 | spark.sparkContext.parallelize(sourceData), 536 | StructType(sourceSchema) 537 | ) 538 | assertColEquality(sourceDF, "num", "expected_num") 539 | } 540 | 541 | "throws an error if the columns are not equal" in { 542 | val sourceData = Seq( 543 | Row(1, 3), 544 | Row(5, 5), 545 | Row(null, null) 546 | ) 547 | val sourceSchema = List( 548 | StructField("num", IntegerType, true), 549 | StructField("expected_num", IntegerType, true) 550 | ) 551 | val sourceDF = spark.createDataFrame( 552 | spark.sparkContext.parallelize(sourceData), 553 | StructType(sourceSchema) 554 | ) 555 | val e = intercept[ColumnMismatch] { 556 | assertColEquality(sourceDF, "num", "expected_num") 557 | } 558 | } 559 | 560 | "throws an error if the columns are different types" in { 561 | val sourceData = Seq( 562 | Row(1, "hi"), 563 | Row(5, "bye"), 564 | Row(null, null) 565 | ) 566 | val sourceSchema = List( 567 | StructField("num", IntegerType, true), 568 | StructField("word", StringType, true) 569 | ) 570 | val sourceDF = spark.createDataFrame( 571 | spark.sparkContext.parallelize(sourceData), 572 | StructType(sourceSchema) 573 | ) 574 | val e = intercept[ColumnMismatch] { 575 | assertColEquality(sourceDF, "num", "word") 576 | } 577 | } 578 | 579 | "works properly, even when null is compared with a value" in { 580 | val sourceData = Seq( 581 | Row(1, 1), 582 | Row(null, 5), 583 | Row(null, null) 584 | ) 585 | val sourceSchema = List( 586 | StructField("num", IntegerType, true), 587 | StructField("expected_num", IntegerType, true) 588 | ) 589 | val sourceDF = spark.createDataFrame( 590 | spark.sparkContext.parallelize(sourceData), 591 | StructType(sourceSchema) 592 | ) 593 | val e = intercept[ColumnMismatch] { 594 | assertColumnEquality(sourceDF, "num", "expected_num") 595 | } 596 | } 597 | 598 | "works for ArrayType columns" in { 599 | val sourceData = Seq( 600 | Row(Array("a"), Array("a")), 601 | Row(Array("a", "b"), Array("a", "b")), 602 | Row(Array(), Array()), 603 | Row(null, null) 604 | ) 605 | val sourceSchema = List( 606 | StructField("l1", ArrayType(StringType, true), true), 607 | StructField("l2", ArrayType(StringType, true), true) 608 | ) 609 | val sourceDF = spark.createDataFrame( 610 | spark.sparkContext.parallelize(sourceData), 611 | StructType(sourceSchema) 612 | ) 613 | assertColumnEquality(sourceDF, "l1", "l2") 614 | } 615 | 616 | "throws an error for unequal nested StructType columns with same schema" in { 617 | val sourceData = Seq( 618 | Row(Row("John", 30), Row("John", 31)), 619 | Row(Row("Jane", 25), Row("Jane", 25)), 620 | Row(Row("Jake", 40), Row("Jake", 40)), 621 | Row(null, 
null) 622 | ) 623 | val nestedSchema = StructType( 624 | List( 625 | StructField("name", StringType, true), 626 | StructField("age", IntegerType, true) 627 | ) 628 | ) 629 | val sourceSchema = List( 630 | StructField("struct1", nestedSchema, true), 631 | StructField("struct2", nestedSchema, true) 632 | ) 633 | val sourceDF = spark.createDataFrame( 634 | spark.sparkContext.parallelize(sourceData), 635 | StructType(sourceSchema) 636 | ) 637 | intercept[ColumnMismatch] { 638 | assertColumnEquality(sourceDF, "struct1", "struct2") 639 | } 640 | } 641 | 642 | "throws an error for unequal nested StructType columns with different schema" in { 643 | val sourceData = Seq( 644 | Row(Row("John", 30), Row("John")), 645 | Row(Row("Jane", 25), Row("Jane")), 646 | Row(Row("Jake", 40), Row("Jake")), 647 | Row(null, null) 648 | ) 649 | val nestedSchema1 = StructType( 650 | List( 651 | StructField("name", StringType, true), 652 | StructField("age", IntegerType, true) 653 | ) 654 | ) 655 | 656 | val nestedSchema2 = StructType( 657 | List( 658 | StructField("name", StringType, true) 659 | ) 660 | ) 661 | val sourceSchema = List( 662 | StructField("struct1", nestedSchema1, true), 663 | StructField("struct2", nestedSchema2, true) 664 | ) 665 | val sourceDF = spark.createDataFrame( 666 | spark.sparkContext.parallelize(sourceData), 667 | StructType(sourceSchema) 668 | ) 669 | intercept[ColumnMismatch] { 670 | assertColumnEquality(sourceDF, "struct1", "struct2") 671 | } 672 | } 673 | 674 | "work with StructType columns" in { 675 | val sourceData = Seq( 676 | Row(Row("John", 30), Row("John", 30)), 677 | Row(Row("Jane", 25), Row("Jane", 25)), 678 | Row(Row("Jake", 40), Row("Jake", 40)), 679 | Row(null, null) 680 | ) 681 | val nestedSchema = StructType( 682 | List( 683 | StructField("name", StringType, true), 684 | StructField("age", IntegerType, true) 685 | ) 686 | ) 687 | val sourceSchema = List( 688 | StructField("struct1", nestedSchema, true), 689 | StructField("struct2", nestedSchema, true) 690 | ) 691 | val sourceDF = spark.createDataFrame( 692 | spark.sparkContext.parallelize(sourceData), 693 | StructType(sourceSchema) 694 | ) 695 | 696 | assertColumnEquality(sourceDF, "struct1", "struct2") 697 | } 698 | } 699 | 700 | } 701 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} 4 | import SparkSessionExt._ 5 | import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch 6 | import org.apache.spark.sql.functions.col 7 | import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps 8 | import org.scalatest.freespec.AnyFreeSpec 9 | 10 | import java.time.Instant 11 | 12 | class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with SparkSessionTestWrapper { 13 | 14 | "prints a descriptive error message if it bugs out" in { 15 | val sourceDF = spark.createDF( 16 | List( 17 | ("bob", 1, "uk"), 18 | ("camila", 5, "peru") 19 | ), 20 | List( 21 | ("name", StringType, true), 22 | ("age", IntegerType, true), 23 | ("country", StringType, true) 24 | ) 25 | ) 26 | 27 | val expectedDF = spark.createDF( 28 | List( 29 | ("bob", 1, "france"), 30 | ("camila", 5, "peru") 31 | ), 32 | List( 33 | ("name", StringType, 
true), 34 | ("age", IntegerType, true), 35 | ("country", StringType, true) 36 | ) 37 | ) 38 | 39 | val e = intercept[DatasetContentMismatch] { 40 | assertSmallDataFrameEquality(sourceDF, expectedDF) 41 | } 42 | assert(e.getMessage.indexOf("bob") >= 0) 43 | assert(e.getMessage.indexOf("camila") >= 0) 44 | } 45 | 46 | "Correctly mark unequal elements" in { 47 | val sourceDF = spark.createDF( 48 | List( 49 | ("bob", 1, "uk"), 50 | ("camila", 5, "peru"), 51 | ("steve", 10, "aus") 52 | ), 53 | List( 54 | ("name", StringType, true), 55 | ("age", IntegerType, true), 56 | ("country", StringType, true) 57 | ) 58 | ) 59 | 60 | val expectedDF = spark.createDF( 61 | List( 62 | ("bob", 1, "france"), 63 | ("camila", 5, "peru"), 64 | ("mark", 11, "usa") 65 | ), 66 | List( 67 | ("name", StringType, true), 68 | ("age", IntegerType, true), 69 | ("country", StringType, true) 70 | ) 71 | ) 72 | 73 | val e = intercept[DatasetContentMismatch] { 74 | assertSmallDataFrameEquality(expectedDF, sourceDF) 75 | } 76 | 77 | e.assertColorDiff(Seq("france", "[mark,11,usa]"), Seq("uk", "[steve,10,aus]")) 78 | } 79 | 80 | "Can handle unequal Dataframe containing null" in { 81 | val sourceDF = spark.createDF( 82 | List( 83 | ("bob", 1, "uk"), 84 | (null, 5, "peru"), 85 | ("steve", 10, "aus") 86 | ), 87 | List( 88 | ("name", StringType, true), 89 | ("age", IntegerType, true), 90 | ("country", StringType, true) 91 | ) 92 | ) 93 | 94 | val expectedDF = spark.createDF( 95 | List( 96 | ("bob", 1, "uk"), 97 | (null, 5, "peru"), 98 | (null, 10, "aus") 99 | ), 100 | List( 101 | ("name", StringType, true), 102 | ("age", IntegerType, true), 103 | ("country", StringType, true) 104 | ) 105 | ) 106 | 107 | val e = intercept[DatasetContentMismatch] { 108 | assertSmallDataFrameEquality(expectedDF, sourceDF) 109 | } 110 | 111 | e.assertColorDiff(Seq("null"), Seq("steve")) 112 | } 113 | 114 | "works well for wide DataFrames" in { 115 | val sourceDF = spark.createDF( 116 | List( 117 | ("bobisanicepersonandwelikehimOK", 1, "uk"), 118 | ("camila", 5, "peru") 119 | ), 120 | List( 121 | ("name", StringType, true), 122 | ("age", IntegerType, true), 123 | ("country", StringType, true) 124 | ) 125 | ) 126 | 127 | val expectedDF = spark.createDF( 128 | List( 129 | ("bobisanicepersonandwelikehimNOT", 1, "france"), 130 | ("camila", 5, "peru") 131 | ), 132 | List( 133 | ("name", StringType, true), 134 | ("age", IntegerType, true), 135 | ("country", StringType, true) 136 | ) 137 | ) 138 | 139 | intercept[DatasetContentMismatch] { 140 | assertSmallDataFrameEquality(sourceDF, expectedDF) 141 | } 142 | } 143 | 144 | "also print a descriptive error message if the right side is missing" in { 145 | val sourceDF = spark.createDF( 146 | List( 147 | ("bob", 1, "uk"), 148 | ("camila", 5, "peru"), 149 | ("jean-jacques", 4, "france") 150 | ), 151 | List( 152 | ("name", StringType, true), 153 | ("age", IntegerType, true), 154 | ("country", StringType, true) 155 | ) 156 | ) 157 | 158 | val expectedDF = spark.createDF( 159 | List( 160 | ("bob", 1, "france"), 161 | ("camila", 5, "peru") 162 | ), 163 | List( 164 | ("name", StringType, true), 165 | ("age", IntegerType, true), 166 | ("country", StringType, true) 167 | ) 168 | ) 169 | 170 | val e = intercept[DatasetContentMismatch] { 171 | assertSmallDataFrameEquality(sourceDF, expectedDF) 172 | } 173 | 174 | assert(e.getMessage.indexOf("jean-jacques") >= 0) 175 | } 176 | 177 | "also print a descriptive error message if the left side is missing" in { 178 | val sourceDF = spark.createDF( 179 | List( 180 | ("bob", 1, "uk"), 
181 | ("camila", 5, "peru") 182 | ), 183 | List( 184 | ("name", StringType, true), 185 | ("age", IntegerType, true), 186 | ("country", StringType, true) 187 | ) 188 | ) 189 | 190 | val expectedDF = spark.createDF( 191 | List( 192 | ("bob", 1, "france"), 193 | ("camila", 5, "peru"), 194 | ("jean-claude", 4, "france") 195 | ), 196 | List( 197 | ("name", StringType, true), 198 | ("age", IntegerType, true), 199 | ("country", StringType, true) 200 | ) 201 | ) 202 | 203 | val e = intercept[DatasetContentMismatch] { 204 | assertSmallDataFrameEquality(sourceDF, expectedDF) 205 | } 206 | 207 | assert(e.getMessage.indexOf("jean-claude") >= 0) 208 | } 209 | 210 | "assertSmallDataFrameEquality" - { 211 | 212 | "does nothing if the DataFrames have the same schemas and content" in { 213 | val sourceDF = spark.createDF( 214 | List( 215 | (1), 216 | (5) 217 | ), 218 | List(("number", IntegerType, true)) 219 | ) 220 | 221 | val expectedDF = spark.createDF( 222 | List( 223 | (1), 224 | (5) 225 | ), 226 | List(("number", IntegerType, true)) 227 | ) 228 | assertLargeDataFrameEquality(sourceDF, expectedDF) 229 | } 230 | 231 | "throws an error if the DataFrames have different schemas" in { 232 | val sourceDF = spark.createDF( 233 | List( 234 | (1), 235 | (5) 236 | ), 237 | List(("number", IntegerType, true)) 238 | ) 239 | 240 | val expectedDF = spark.createDF( 241 | List( 242 | (1, "word"), 243 | (5, "word") 244 | ), 245 | List( 246 | ("number", IntegerType, true), 247 | ("word", StringType, true) 248 | ) 249 | ) 250 | 251 | intercept[DatasetSchemaMismatch] { 252 | assertLargeDataFrameEquality(sourceDF, expectedDF) 253 | } 254 | intercept[DatasetSchemaMismatch] { 255 | assertSmallDataFrameEquality(sourceDF, expectedDF) 256 | } 257 | } 258 | 259 | "throws an error if the DataFrames content is different" in { 260 | val sourceDF = spark.createDF( 261 | List( 262 | (1), 263 | (5) 264 | ), 265 | List(("number", IntegerType, true)) 266 | ) 267 | 268 | val expectedDF = spark.createDF( 269 | List( 270 | (10), 271 | (5) 272 | ), 273 | List(("number", IntegerType, true)) 274 | ) 275 | 276 | intercept[DatasetContentMismatch] { 277 | assertLargeDataFrameEquality(sourceDF, expectedDF) 278 | } 279 | intercept[DatasetContentMismatch] { 280 | assertSmallDataFrameEquality(sourceDF, expectedDF) 281 | } 282 | } 283 | 284 | "can performed unordered DataFrame comparisons" in { 285 | val sourceDF = spark.createDF( 286 | List( 287 | (1), 288 | (5) 289 | ), 290 | List(("number", IntegerType, true)) 291 | ) 292 | val expectedDF = spark.createDF( 293 | List( 294 | (5), 295 | (1) 296 | ), 297 | List(("number", IntegerType, true)) 298 | ) 299 | assertSmallDataFrameEquality(sourceDF, expectedDF, orderedComparison = false) 300 | } 301 | 302 | "throws an error for unordered DataFrame comparisons that don't match" in { 303 | val sourceDF = spark.createDF( 304 | List( 305 | (1), 306 | (5) 307 | ), 308 | List(("number", IntegerType, true)) 309 | ) 310 | val expectedDF = spark.createDF( 311 | List( 312 | (5), 313 | (1), 314 | (10) 315 | ), 316 | List(("number", IntegerType, true)) 317 | ) 318 | intercept[DatasetContentMismatch] { 319 | assertSmallDataFrameEquality(sourceDF, expectedDF, orderedComparison = false) 320 | } 321 | } 322 | 323 | "can performed DataFrame comparisons with unordered column" in { 324 | val sourceDF = spark.createDF( 325 | List( 326 | (1, "word"), 327 | (5, "word") 328 | ), 329 | List( 330 | ("number", IntegerType, true), 331 | ("word", StringType, true) 332 | ) 333 | ) 334 | val expectedDF = spark.createDF( 335 | List( 
336 | ("word", 1), 337 | ("word", 5) 338 | ), 339 | List( 340 | ("word", StringType, true), 341 | ("number", IntegerType, true) 342 | ) 343 | ) 344 | assertLargeDataFrameEquality(sourceDF, expectedDF, ignoreColumnOrder = true) 345 | } 346 | 347 | "should not ignore nullable if ignoreNullable is false" in { 348 | val sourceDF = spark.createDF( 349 | List( 350 | 1.2, 351 | 5.1 352 | ), 353 | List(("number", DoubleType, false)) 354 | ) 355 | val expectedDF = spark.createDF( 356 | List( 357 | 1.2, 358 | 5.1 359 | ), 360 | List(("number", DoubleType, true)) 361 | ) 362 | 363 | intercept[DatasetSchemaMismatch] { 364 | assertLargeDataFrameEquality(sourceDF, expectedDF) 365 | } 366 | } 367 | 368 | "correctly mark unequal schema field" in { 369 | val sourceDF = spark.createDF( 370 | List( 371 | (1, 2.0), 372 | (5, 3.0) 373 | ), 374 | List( 375 | ("number", IntegerType, true), 376 | ("float", DoubleType, true) 377 | ) 378 | ) 379 | 380 | val expectedDF = spark.createDF( 381 | List( 382 | (1, "word", 1L), 383 | (5, "word", 2L) 384 | ), 385 | List( 386 | ("number", IntegerType, true), 387 | ("word", StringType, true), 388 | ("long", LongType, true) 389 | ) 390 | ) 391 | 392 | val e = intercept[DatasetSchemaMismatch] { 393 | assertSmallDataFrameEquality(sourceDF, expectedDF) 394 | } 395 | 396 | e.assertColorDiff( 397 | Seq("float", "DoubleType", "MISSING"), 398 | Seq("word", "StringType", "StructField(long,LongType,true,{})") 399 | ) 400 | } 401 | 402 | "can performed Dataset comparisons and ignore metadata" in { 403 | val sourceDF = spark 404 | .createDF( 405 | List( 406 | 1, 407 | 5 408 | ), 409 | List(("number", IntegerType, true)) 410 | ) 411 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build())) 412 | 413 | val expectedDF = spark 414 | .createDF( 415 | List( 416 | 1, 417 | 5 418 | ), 419 | List(("number", IntegerType, true)) 420 | ) 421 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build())) 422 | 423 | assertLargeDataFrameEquality(sourceDF, expectedDF) 424 | } 425 | 426 | "can performed Dataset comparisons and compare metadata" in { 427 | val sourceDF = spark 428 | .createDF( 429 | List( 430 | 1, 431 | 5 432 | ), 433 | List(("number", IntegerType, true)) 434 | ) 435 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build())) 436 | 437 | val expectedDF = spark 438 | .createDF( 439 | List( 440 | 1, 441 | 5 442 | ), 443 | List(("number", IntegerType, true)) 444 | ) 445 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build())) 446 | 447 | intercept[DatasetSchemaMismatch] { 448 | assertLargeDataFrameEquality(sourceDF, expectedDF, ignoreMetadata = false) 449 | } 450 | } 451 | } 452 | 453 | "assertApproximateDataFrameEquality" - { 454 | 455 | "does nothing if the DataFrames have the same schemas and content" in { 456 | val sourceDF = spark.createDF( 457 | List( 458 | 1.2, 459 | 5.1, 460 | null 461 | ), 462 | List(("number", DoubleType, true)) 463 | ) 464 | val expectedDF = spark.createDF( 465 | List( 466 | 1.2, 467 | 5.1, 468 | null 469 | ), 470 | List(("number", DoubleType, true)) 471 | ) 472 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 473 | } 474 | 475 | "throws an error if the rows are different" in { 476 | val sourceDF = spark.createDF( 477 | List( 478 | 100.9, 479 | 5.1 480 | ), 481 | List(("number", DoubleType, true)) 
482 | ) 483 | val expectedDF = spark.createDF( 484 | List( 485 | 1.2, 486 | 5.1 487 | ), 488 | List(("number", DoubleType, true)) 489 | ) 490 | val e = intercept[DatasetContentMismatch] { 491 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 492 | } 493 | } 494 | 495 | "throws an error if the DataFrames have a different number of rows" in { 496 | val sourceDF = spark.createDF( 497 | List( 498 | 1.2, 499 | 5.1, 500 | 8.8 501 | ), 502 | List(("number", DoubleType, true)) 503 | ) 504 | val expectedDF = spark.createDF( 505 | List( 506 | 1.2, 507 | 5.1 508 | ), 509 | List(("number", DoubleType, true)) 510 | ) 511 | val e = intercept[DatasetCountMismatch] { 512 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 513 | } 514 | } 515 | 516 | "should not ignore nullable if ignoreNullable is false" in { 517 | val sourceDF = spark.createDF( 518 | List( 519 | 1.2, 520 | 5.1 521 | ), 522 | List(("number", DoubleType, false)) 523 | ) 524 | val expectedDF = spark.createDF( 525 | List( 526 | 1.2, 527 | 5.1 528 | ), 529 | List(("number", DoubleType, true)) 530 | ) 531 | 532 | intercept[DatasetSchemaMismatch] { 533 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 534 | } 535 | } 536 | 537 | "can ignore the nullable property" in { 538 | val sourceDF = spark.createDF( 539 | List( 540 | 1.2, 541 | 5.1 542 | ), 543 | List(("number", DoubleType, false)) 544 | ) 545 | val expectedDF = spark.createDF( 546 | List( 547 | 1.2, 548 | 5.1 549 | ), 550 | List(("number", DoubleType, true)) 551 | ) 552 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreNullable = true) 553 | } 554 | 555 | "can ignore the column names" in { 556 | val sourceDF = spark.createDF( 557 | List( 558 | 1.2, 559 | 5.1, 560 | null 561 | ), 562 | List(("BLAHBLBH", DoubleType, true)) 563 | ) 564 | val expectedDF = spark.createDF( 565 | List( 566 | 1.2, 567 | 5.1, 568 | null 569 | ), 570 | List(("number", DoubleType, true)) 571 | ) 572 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreColumnNames = true) 573 | } 574 | 575 | "can work with precision and unordered comparison" in { 576 | import spark.implicits._ 577 | val ds1 = Seq( 578 | ("1", "10/01/2019", 26.762499999999996), 579 | ("1", "11/01/2019", 26.762499999999996) 580 | ).toDF("col_B", "col_C", "col_A") 581 | 582 | val ds2 = Seq( 583 | ("1", "10/01/2019", 26.762499999999946), 584 | ("1", "11/01/2019", 26.76249999999991) 585 | ).toDF("col_B", "col_C", "col_A") 586 | 587 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 588 | } 589 | 590 | "can work with precision and unordered comparison 2" in { 591 | import spark.implicits._ 592 | val ds1 = Seq( 593 | ("1", "10/01/2019", 26.762499999999996, "A"), 594 | ("1", "10/01/2019", 26.762499999999996, "B") 595 | ).toDF("col_B", "col_C", "col_A", "col_D") 596 | 597 | val ds2 = Seq( 598 | ("1", "10/01/2019", 26.762499999999946, "A"), 599 | ("1", "10/01/2019", 26.76249999999991, "B") 600 | ).toDF("col_B", "col_C", "col_A", "col_D") 601 | 602 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 603 | } 604 | 605 | "throws an error when the precision is exceeded" in { 606 | import spark.implicits._ 607 | val ds1 = Seq( 608 | ("1", "10/01/2019", 26.762499999999996), 609 | ("1", "11/01/2019", 26.762499999999996) 610 | ).toDF("col_B", "col_C", "col_A") 611 | 612 | val ds2 = Seq( 613 | ("1", "10/01/2019", 26.762499999999946), 614 | ("1", "11/01/2019", 28.76249999999991) 615 | ).toDF("col_B", "col_C", "col_A")
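// Why the intercept below should fire (inferred from these fixtures, not from the library
// internals): precision acts as an absolute per-value tolerance, so the first row's col_A
// gap of ~5e-14 passes at 1e-7, while the second row differs by
// |28.76249999999991 - 26.762499999999996| ~= 2.0 and fails with DatasetContentMismatch.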
616 | 617 | intercept[DatasetContentMismatch] { 618 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 619 | } 620 | } 621 | 622 | "throws an error when the precision is exceeded for TimestampType" in { 623 | import spark.implicits._ 624 | val ds1 = Seq( 625 | ("1", Instant.parse("2019-10-01T00:00:00Z")), 626 | ("2", Instant.parse("2019-11-01T00:00:00Z")) 627 | ).toDF("col_B", "col_A") 628 | 629 | val ds2 = Seq( 630 | ("1", Instant.parse("2019-10-01T00:00:00Z")), 631 | ("2", Instant.parse("2019-12-01T00:00:00Z")) 632 | ).toDF("col_B", "col_A") 633 | 634 | intercept[DatasetContentMismatch] { 635 | assertApproximateDataFrameEquality(ds1, ds2, precision = 100, orderedComparison = false) 636 | } 637 | } 638 | 639 | "throws an error when the precision is exceeded for BigDecimal" in { 640 | import spark.implicits._ 641 | val ds1 = Seq( 642 | ("1", BigDecimal(101)), 643 | ("2", BigDecimal(200)) 644 | ).toDF("col_B", "col_A") 645 | 646 | val ds2 = Seq( 647 | ("1", BigDecimal(101)), 648 | ("2", BigDecimal(203)) 649 | ).toDF("col_B", "col_A") 650 | 651 | intercept[DatasetContentMismatch] { 652 | assertApproximateDataFrameEquality(ds1, ds2, precision = 2, orderedComparison = false) 653 | } 654 | } 655 | 656 | "can work with precision and unordered comparison on nested column" in { 657 | import spark.implicits._ 658 | val ds1 = Seq( 659 | ("1", "10/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)), 660 | ("1", "11/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)) 661 | ).toDF("col_B", "col_C", "col_A", "col_D") 662 | 663 | val ds2 = Seq( 664 | ("1", "11/01/2019", 26.7624999999999961, Seq(26.7624999999999961, 26.7624999999999961)), 665 | ("1", "10/01/2019", 26.762499999999997, Seq(26.762499999999997, 26.762499999999997)) 666 | ).toDF("col_B", "col_C", "col_A", "col_D") 667 | 668 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 669 | } 670 | 671 | "can perform Dataset comparisons and ignore metadata" in { 672 | val sourceDF = spark 673 | .createDF( 674 | List( 675 | 1, 676 | 5 677 | ), 678 | List(("number", IntegerType, true)) 679 | ) 680 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build())) 681 | 682 | val expectedDF = spark 683 | .createDF( 684 | List( 685 | 1, 686 | 5 687 | ), 688 | List(("number", IntegerType, true)) 689 | ) 690 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build())) 691 | 692 | assertApproximateDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001) 693 | } 694 | 695 | "can perform Dataset comparisons and compare metadata" in { 696 | val sourceDF = spark 697 | .createDF( 698 | List( 699 | 1, 700 | 5 701 | ), 702 | List(("number", IntegerType, true)) 703 | ) 704 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build())) 705 | 706 | val expectedDF = spark 707 | .createDF( 708 | List( 709 | 1, 710 | 5 711 | ), 712 | List(("number", IntegerType, true)) 713 | ) 714 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build())) 715 | 716 | intercept[DatasetSchemaMismatch] { 717 | assertApproximateDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001, ignoreMetadata = false) 718 | } 719 | } 720 | } 721 | 722 | "assertApproximateSmallDataFrameEquality" - { 723 | 724 | "does nothing if the DataFrames
have the same schemas and content" in { 725 | val sourceDF = spark.createDF( 726 | List( 727 | 1.2, 728 | 5.1, 729 | null 730 | ), 731 | List(("number", DoubleType, true)) 732 | ) 733 | val expectedDF = spark.createDF( 734 | List( 735 | 1.2, 736 | 5.1, 737 | null 738 | ), 739 | List(("number", DoubleType, true)) 740 | ) 741 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01) 742 | } 743 | 744 | "throws an error if the rows are different" in { 745 | val sourceDF = spark.createDF( 746 | List( 747 | 100.9, 748 | 5.1 749 | ), 750 | List(("number", DoubleType, true)) 751 | ) 752 | val expectedDF = spark.createDF( 753 | List( 754 | 1.2, 755 | 5.1 756 | ), 757 | List(("number", DoubleType, true)) 758 | ) 759 | val e = intercept[DatasetContentMismatch] { 760 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01) 761 | } 762 | } 763 | 764 | "throws an error if the DataFrames have a different number of rows" in { 765 | val sourceDF = spark.createDF( 766 | List( 767 | 1.2, 768 | 5.1, 769 | 8.8 770 | ), 771 | List(("number", DoubleType, true)) 772 | ) 773 | val expectedDF = spark.createDF( 774 | List( 775 | 1.2, 776 | 5.1 777 | ), 778 | List(("number", DoubleType, true)) 779 | ) 780 | val e = intercept[DatasetContentMismatch] { 781 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01) 782 | } 783 | } 784 | 785 | "can ignore the nullable property" in { 786 | val sourceDF = spark.createDF( 787 | List( 788 | 1.2, 789 | 5.1 790 | ), 791 | List(("number", DoubleType, false)) 792 | ) 793 | val expectedDF = spark.createDF( 794 | List( 795 | 1.2, 796 | 5.1 797 | ), 798 | List(("number", DoubleType, true)) 799 | ) 800 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreNullable = true) 801 | } 802 | 803 | "should not ignore nullable if ignoreNullable is false" in { 804 | val sourceDF = spark.createDF( 805 | List( 806 | 1.2, 807 | 5.1 808 | ), 809 | List(("number", DoubleType, false)) 810 | ) 811 | val expectedDF = spark.createDF( 812 | List( 813 | 1.2, 814 | 5.1 815 | ), 816 | List(("number", DoubleType, true)) 817 | ) 818 | 819 | intercept[DatasetSchemaMismatch] { 820 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01) 821 | } 822 | } 823 | 824 | "can ignore the column names" in { 825 | val sourceDF = spark.createDF( 826 | List( 827 | 1.2, 828 | 5.1, 829 | null 830 | ), 831 | List(("BLAHBLBH", DoubleType, true)) 832 | ) 833 | val expectedDF = spark.createDF( 834 | List( 835 | 1.2, 836 | 5.1, 837 | null 838 | ), 839 | List(("number", DoubleType, true)) 840 | ) 841 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreColumnNames = true) 842 | } 843 | 844 | "can work with precision and unordered comparison" in { 845 | import spark.implicits._ 846 | val ds1 = Seq( 847 | ("1", "10/01/2019", 26.762499999999996), 848 | ("1", "11/01/2019", 26.762499999999996) 849 | ).toDF("col_B", "col_C", "col_A") 850 | 851 | val ds2 = Seq( 852 | ("1", "10/01/2019", 26.762499999999946), 853 | ("1", "11/01/2019", 26.76249999999991) 854 | ).toDF("col_B", "col_C", "col_A") 855 | 856 | assertApproximateSmallDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 857 | } 858 | 859 | "can work with precision and unordered comparison 2" in { 860 | import spark.implicits._ 861 | val ds1 = Seq( 862 | ("1", "10/01/2019", "A", 26.762499999999996), 863 | ("1", "10/01/2019", "B", 26.762499999999996) 864 | ).toDF("col_B", "col_C", "col_A", "col_D") 865 | 866 | val ds2 = Seq( 867 | ("1", "10/01/2019", "A",
26.762499999999946), 868 | ("1", "10/01/2019", "B", 26.76249999999991) 869 | ).toDF("col_B", "col_C", "col_A", "col_D") 870 | 871 | assertApproximateSmallDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 872 | } 873 | 874 | "can work with precision and unordered comparison on nested column" in { 875 | import spark.implicits._ 876 | val ds1 = Seq( 877 | ("1", "10/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)), 878 | ("2", "11/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)) 879 | ).toDF("col_B", "col_C", "col_A", "col_D") 880 | 881 | val ds2 = Seq( 882 | ("2", "11/01/2019", 26.7624999999999961, Seq(26.7624999999999961, 26.7624999999999961)), 883 | ("1", "10/01/2019", 26.762499999999997, Seq(26.762499999999997, 26.762499999999997)) 884 | ).toDF("col_B", "col_C", "col_A", "col_D") 885 | 886 | assertApproximateSmallDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 887 | } 888 | 889 | "can perform Dataset comparisons and ignore metadata" in { 890 | val sourceDF = spark 891 | .createDF( 892 | List( 893 | 1, 894 | 5 895 | ), 896 | List(("number", IntegerType, true)) 897 | ) 898 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build())) 899 | 900 | val expectedDF = spark 901 | .createDF( 902 | List( 903 | 1, 904 | 5 905 | ), 906 | List(("number", IntegerType, true)) 907 | ) 908 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build())) 909 | 910 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001) 911 | } 912 | 913 | "can perform Dataset comparisons and compare metadata" in { 914 | val sourceDF = spark 915 | .createDF( 916 | List( 917 | 1, 918 | 5 919 | ), 920 | List(("number", IntegerType, true)) 921 | ) 922 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small int").build())) 923 | 924 | val expectedDF = spark 925 | .createDF( 926 | List( 927 | 1, 928 | 5 929 | ), 930 | List(("number", IntegerType, true)) 931 | ) 932 | .withColumn("number", col("number").as("number", new MetadataBuilder().putString("description", "small number").build())) 933 | 934 | intercept[DatasetSchemaMismatch] { 935 | assertApproximateSmallDataFrameEquality(sourceDF, expectedDF, precision = 0.0000001, ignoreMetadata = false) 936 | } 937 | } 938 | } 939 | } 940 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFramePrettyPrintTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.scalatest.freespec.AnyFreeSpec 4 | 5 | class DataFramePrettyPrintTest extends AnyFreeSpec with SparkSessionTestWrapper { 6 | "prints named_struct with keys" in { 7 | val inputDataframe = spark.sql("""select 1 as id, named_struct("k1", 2, "k2", 3) as to_show_with_k""") 8 | assert( 9 | DataFramePrettyPrint.showString(inputDataframe, 10) == 10 | """+---+------------------+ 11 | || id| to_show_with_k| 12 | |+---+------------------+ 13 | || 1|{k1 -> 2, k2 -> 3}| 14 | |+---+------------------+ 15 | |""".stripMargin 16 | ) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala:
-------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.types._ 4 | import SparkSessionExt._ 5 | import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch 6 | import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps 7 | import org.apache.spark.sql.functions.col 8 | import org.scalatest.freespec.AnyFreeSpec 9 | 10 | object Person { 11 | 12 | def caseInsensitivePersonEquals(some: Person, other: Person): Boolean = { 13 | some.name.equalsIgnoreCase(other.name) && some.age == other.age 14 | } 15 | } 16 | case class Person(name: String, age: Int) 17 | case class PrecisePerson(name: String, age: Double) 18 | 19 | class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSessionTestWrapper { 20 | 21 | "checkDatasetEquality" - { 22 | import spark.implicits._ 23 | 24 | "provides a good README example" in { 25 | val sourceDS = Seq( 26 | Person("juan", 5), 27 | Person("bob", 1), 28 | Person("li", 49), 29 | Person("alice", 5) 30 | ).toDS 31 | 32 | val expectedDS = Seq( 33 | Person("juan", 5), 34 | Person("frank", 10), 35 | Person("li", 49), 36 | Person("lucy", 5) 37 | ).toDS 38 | 39 | val e = intercept[DatasetContentMismatch] { 40 | assertSmallDatasetEquality(sourceDS, expectedDS) 41 | } 42 | } 43 | 44 | "can compare unequal Dataset containing null in column" in { 45 | val sourceDS = Seq(Person(null, 5), Person(null, 1)).toDS 46 | val expectedDS = Seq(Person("juan", 5), Person(null, 1)).toDS 47 | 48 | val e = intercept[DatasetContentMismatch] { 49 | assertSmallDatasetEquality(sourceDS, expectedDS, ignoreNullable = true, orderedComparison = false) 50 | } 51 | 52 | e.assertColorDiff(Seq("null"), Seq("juan")) 53 | } 54 | 55 | "Correctly mark unequal elements" in { 56 | val sourceDS = Seq( 57 | Person("juan", 5), 58 | Person("bob", 1), 59 | Person("li", 49), 60 | Person("alice", 5) 61 | ).toDS 62 | 63 | val expectedDS = Seq( 64 | Person("juan", 5), 65 | Person("frank", 10), 66 | Person("li", 49), 67 | Person("lucy", 5) 68 | ).toDS 69 | 70 | val e = intercept[DatasetContentMismatch] { 71 | assertSmallDatasetEquality(sourceDS, expectedDS) 72 | } 73 | 74 | e.assertColorDiff(Seq("Person(bob,1)", "alice"), Seq("Person(frank,10)", "lucy")) 75 | } 76 | 77 | "correctly mark unequal element for Dataset[String]" in { 78 | import spark.implicits._ 79 | val sourceDS = Seq("word", "StringType", "StructField(long,LongType,true,{})").toDS 80 | 81 | val expectedDS = List("word", "StringType", "StructField(long,LongType2,true,{})").toDS 82 | 83 | val e = intercept[DatasetContentMismatch] { 84 | assertSmallDatasetEquality(sourceDS, expectedDS) 85 | } 86 | 87 | e.assertColorDiff(Seq("String(StructField(long,LongType,true,{}))"), Seq("String(StructField(long,LongType2,true,{}))")) 88 | } 89 | 90 | "correctly mark unequal element for Dataset[Seq[String]]" in { 91 | import spark.implicits._ 92 | 93 | val sourceDS = Seq( 94 | Seq("apple", "banana", "cherry"), 95 | Seq("dog", "cat"), 96 | Seq("red", "green", "blue") 97 | ).toDS 98 | 99 | val expectedDS = Seq( 100 | Seq("apple", "banana2"), 101 | Seq("dog", "cat"), 102 | Seq("red", "green", "blue") 103 | ).toDS 104 | 105 | val e = intercept[DatasetContentMismatch] { 106 | assertSmallDatasetEquality(sourceDS, expectedDS) 107 | } 108 | 109 | e.assertColorDiff(Seq("banana", "cherry"), Seq("banana2", "MISSING")) 110 | } 111 | 112 | "correctly mark unequal element for Dataset[Array[String]]" in { 113 | import 
spark.implicits._ 114 | 115 | val sourceDS = Seq( 116 | Array("apple", "banana", "cherry"), 117 | Array("dog", "cat"), 118 | Array("red", "green", "blue") 119 | ).toDS 120 | 121 | val expectedDS = Seq( 122 | Array("apple", "banana2"), 123 | Array("dog", "cat"), 124 | Array("red", "green", "blue") 125 | ).toDS 126 | 127 | val e = intercept[DatasetContentMismatch] { 128 | assertSmallDatasetEquality(sourceDS, expectedDS) 129 | } 130 | 131 | e.assertColorDiff(Seq("banana", "cherry"), Seq("banana2", "MISSING")) 132 | } 133 | 134 | "correctly mark unequal element for Dataset[Map[String, String]]" in { 135 | import spark.implicits._ 136 | 137 | val sourceDS = Seq( 138 | Map("apple" -> "banana", "apple1" -> "banana1"), 139 | Map("apple" -> "banana", "apple1" -> "banana1") 140 | ).toDS 141 | 142 | val expectedDS = Seq( 143 | Map("apple" -> "banana1", "apple1" -> "banana1"), 144 | Map("apple" -> "banana", "apple1" -> "banana1") 145 | ).toDS 146 | 147 | val e = intercept[DatasetContentMismatch] { 148 | assertSmallDatasetEquality(sourceDS, expectedDS) 149 | } 150 | 151 | e.assertColorDiff(Seq("(apple,banana)"), Seq("(apple,banana1)")) 152 | } 153 | 154 | "works with really long columns" in { 155 | val sourceDS = Seq( 156 | Person("juanisareallygoodguythatilikealotOK", 5), 157 | Person("bob", 1), 158 | Person("li", 49), 159 | Person("alice", 5) 160 | ).toDS 161 | 162 | val expectedDS = Seq( 163 | Person("juanisareallygoodguythatilikealotNOT", 5), 164 | Person("frank", 10), 165 | Person("li", 49), 166 | Person("lucy", 5) 167 | ).toDS 168 | 169 | val e = intercept[DatasetContentMismatch] { 170 | assertSmallDatasetEquality(sourceDS, expectedDS) 171 | } 172 | } 173 | 174 | "does nothing if the DataFrames have the same schemas and content" in { 175 | val sourceDF = spark.createDF( 176 | List( 177 | (1), 178 | (5) 179 | ), 180 | List(("number", IntegerType, true)) 181 | ) 182 | 183 | val expectedDF = spark.createDF( 184 | List( 185 | (1), 186 | (5) 187 | ), 188 | List(("number", IntegerType, true)) 189 | ) 190 | 191 | assertSmallDatasetEquality(sourceDF, expectedDF) 192 | assertLargeDatasetEquality(sourceDF, expectedDF) 193 | } 194 | 195 | "does nothing if the Datasets have the same schemas and content" in { 196 | val sourceDS = spark.createDataset[Person]( 197 | Seq( 198 | Person("Alice", 12), 199 | Person("Bob", 17) 200 | ) 201 | ) 202 | 203 | val expectedDS = spark.createDataset[Person]( 204 | Seq( 205 | Person("Alice", 12), 206 | Person("Bob", 17) 207 | ) 208 | ) 209 | 210 | assertSmallDatasetEquality(sourceDS, expectedDS) 211 | assertLargeDatasetEquality(sourceDS, expectedDS) 212 | } 213 | 214 | "works with DataFrames that have ArrayType columns" in { 215 | val sourceDF = spark.createDF( 216 | List( 217 | (1, Array("word1", "blah")), 218 | (5, Array("hi", "there")) 219 | ), 220 | List( 221 | ("number", IntegerType, true), 222 | ("words", ArrayType(StringType, true), true) 223 | ) 224 | ) 225 | 226 | val expectedDF = spark.createDF( 227 | List( 228 | (1, Array("word1", "blah")), 229 | (5, Array("hi", "there")) 230 | ), 231 | List( 232 | ("number", IntegerType, true), 233 | ("words", ArrayType(StringType, true), true) 234 | ) 235 | ) 236 | 237 | assertLargeDatasetEquality(sourceDF, expectedDF) 238 | assertSmallDatasetEquality(sourceDF, expectedDF) 239 | } 240 | 241 | "throws an error if the DataFrames have different schemas" in { 242 | val nestedSchema = StructType( 243 | Seq( 244 | StructField( 245 | "attributes", 246 | StructType( 247 | Seq( 248 | StructField("PostCode", IntegerType, nullable = true) 
249 | ) 250 | ), 251 | nullable = true 252 | ) 253 | ) 254 | ) 255 | 256 | val nestedSchema2 = StructType( 257 | Seq( 258 | StructField( 259 | "attributes", 260 | StructType( 261 | Seq( 262 | StructField("PostCode", StringType, nullable = true) 263 | ) 264 | ), 265 | nullable = true 266 | ) 267 | ) 268 | ) 269 | 270 | val sourceDF = spark.createDF( 271 | List( 272 | (1, 2.0, null), 273 | (5, 3.0, null) 274 | ), 275 | List( 276 | ("number", IntegerType, true), 277 | ("float", DoubleType, true), 278 | ("nestedField", nestedSchema, true) 279 | ) 280 | ) 281 | 282 | val expectedDF = spark.createDF( 283 | List( 284 | (1, "word", null, 1L), 285 | (5, "word", null, 2L) 286 | ), 287 | List( 288 | ("number", IntegerType, true), 289 | ("word", StringType, true), 290 | ("nestedField", nestedSchema2, true), 291 | ("long", LongType, true) 292 | ) 293 | ) 294 | 295 | intercept[DatasetSchemaMismatch] { 296 | assertLargeDatasetEquality(sourceDF, expectedDF) 297 | } 298 | 299 | intercept[DatasetSchemaMismatch] { 300 | assertSmallDatasetEquality(sourceDF, expectedDF) 301 | } 302 | } 303 | 304 | "throws an error if the DataFrames' content is different" in { 305 | val sourceDF = Seq( 306 | (1), (5), (7), (1), (1) 307 | ).toDF("number") 308 | 309 | val expectedDF = Seq( 310 | (10), (5), (3), (7), (1) 311 | ).toDF("number") 312 | 313 | val e = intercept[DatasetContentMismatch] { 314 | assertLargeDatasetEquality(sourceDF, expectedDF) 315 | } 316 | val e2 = intercept[DatasetContentMismatch] { 317 | assertSmallDatasetEquality(sourceDF, expectedDF) 318 | } 319 | } 320 | 321 | "throws an error if the Dataset content is different" in { 322 | val sourceDS = spark.createDataset[Person]( 323 | Seq( 324 | Person("Alice", 12), 325 | Person("Bob", 17) 326 | ) 327 | ) 328 | 329 | val expectedDS = spark.createDataset[Person]( 330 | Seq( 331 | Person("Frank", 10), 332 | Person("Lucy", 5) 333 | ) 334 | ) 335 | 336 | val e = intercept[DatasetContentMismatch] { 337 | assertLargeDatasetEquality(sourceDS, expectedDS) 338 | } 339 | val e2 = intercept[DatasetContentMismatch] { 340 | assertSmallDatasetEquality(sourceDS, expectedDS) 341 | } 342 | } 343 | 344 | "succeeds if custom comparator returns true" in { 345 | val sourceDS = spark.createDataset[Person]( 346 | Seq( 347 | Person("bob", 1), 348 | Person("alice", 5) 349 | ) 350 | ) 351 | val expectedDS = spark.createDataset[Person]( 352 | Seq( 353 | Person("Bob", 1), 354 | Person("Alice", 5) 355 | ) 356 | ) 357 | assertLargeDatasetEquality(sourceDS, expectedDS, Person.caseInsensitivePersonEquals) 358 | } 359 | 360 | "fails if custom comparator returns false" in { 361 | val sourceDS = spark.createDataset[Person]( 362 | Seq( 363 | Person("bob", 10), 364 | Person("alice", 5) 365 | ) 366 | ) 367 | val expectedDS = spark.createDataset[Person]( 368 | Seq( 369 | Person("Bob", 1), 370 | Person("Alice", 5) 371 | ) 372 | ) 373 | val e = intercept[DatasetContentMismatch] { 374 | assertLargeDatasetEquality(sourceDS, expectedDS, Person.caseInsensitivePersonEquals) 375 | } 376 | } 377 | 378 | } 379 | 380 | "assertLargeDatasetEquality" - { 381 | import spark.implicits._ 382 | 383 | "ignores the nullable flag when making DataFrame comparisons" in { 384 | val sourceDF = spark.createDF( 385 | List( 386 | (1), 387 | (5) 388 | ), 389 | List(("number", IntegerType, false)) 390 | ) 391 | 392 | val expectedDF = spark.createDF( 393 | List( 394 | (1), 395 | (5) 396 | ), 397 | List(("number", IntegerType, true)) 398 | ) 399 | 400 | assertLargeDatasetEquality(sourceDF, expectedDF, ignoreNullable = true) 401
| } 402 | 403 | "should not ignore nullable if ignoreNullable is false" in { 404 | 405 | val sourceDF = spark.createDF( 406 | List( 407 | (1), 408 | (5) 409 | ), 410 | List(("number", IntegerType, false)) 411 | ) 412 | 413 | val expectedDF = spark.createDF( 414 | List( 415 | (1), 416 | (5) 417 | ), 418 | List(("number", IntegerType, true)) 419 | ) 420 | 421 | intercept[DatasetSchemaMismatch] { 422 | assertLargeDatasetEquality(sourceDF, expectedDF) 423 | } 424 | } 425 | 426 | "can perform unordered DataFrame comparisons" in { 427 | val sourceDF = spark.createDF( 428 | List( 429 | (1), 430 | (5) 431 | ), 432 | List(("number", IntegerType, true)) 433 | ) 434 | 435 | val expectedDF = spark.createDF( 436 | List( 437 | (5), 438 | (1) 439 | ), 440 | List(("number", IntegerType, true)) 441 | ) 442 | 443 | assertLargeDatasetEquality(sourceDF, expectedDF, orderedComparison = false) 444 | } 445 | 446 | "throws an error for unordered Dataset comparisons that don't match" in { 447 | val sourceDS = spark.createDataset[Person]( 448 | Seq( 449 | Person("bob", 1), 450 | Person("frank", 5) 451 | ) 452 | ) 453 | 454 | val expectedDS = spark.createDataset[Person]( 455 | Seq( 456 | Person("frank", 5), 457 | Person("bob", 1), 458 | Person("sadie", 2) 459 | ) 460 | ) 461 | 462 | val e = intercept[DatasetCountMismatch] { 463 | assertLargeDatasetEquality(sourceDS, expectedDS, orderedComparison = false) 464 | } 465 | } 466 | 467 | "throws an error for unordered DataFrame comparisons that don't match" in { 468 | val sourceDF = spark.createDF( 469 | List( 470 | (1), 471 | (5) 472 | ), 473 | List(("number", IntegerType, true)) 474 | ) 475 | val expectedDF = spark.createDF( 476 | List( 477 | (5), 478 | (1), 479 | (10) 480 | ), 481 | List(("number", IntegerType, true)) 482 | ) 483 | 484 | val e = intercept[DatasetCountMismatch] { 485 | assertLargeDatasetEquality(sourceDF, expectedDF, orderedComparison = false) 486 | } 487 | } 488 | 489 | "throws an error if the DataFrames have a different number of rows" in { 490 | val sourceDF = spark.createDF( 491 | List( 492 | (1), 493 | (5) 494 | ), 495 | List(("number", IntegerType, true)) 496 | ) 497 | val expectedDF = spark.createDF( 498 | List( 499 | (1), 500 | (5), 501 | (10) 502 | ), 503 | List(("number", IntegerType, true)) 504 | ) 505 | 506 | val e = intercept[DatasetCountMismatch] { 507 | assertLargeDatasetEquality(sourceDF, expectedDF) 508 | } 509 | } 510 | 511 | "can perform DataFrame comparisons with unordered columns" in { 512 | val sourceDF = spark.createDF( 513 | List( 514 | (1, "word"), 515 | (5, "word") 516 | ), 517 | List( 518 | ("number", IntegerType, true), 519 | ("word", StringType, true) 520 | ) 521 | ) 522 | val expectedDF = spark.createDF( 523 | List( 524 | ("word", 1), 525 | ("word", 5) 526 | ), 527 | List( 528 | ("word", StringType, true), 529 | ("number", IntegerType, true) 530 | ) 531 | ) 532 | assertLargeDatasetEquality(sourceDF, expectedDF, ignoreColumnOrder = true) 533 | } 534 | 535 | "can perform Dataset comparisons with unordered columns" in { 536 | val ds1 = Seq( 537 | Person("juan", 5), 538 | Person("bob", 1), 539 | Person("li", 49), 540 | Person("alice", 5) 541 | ).toDS 542 | 543 | val ds2 = Seq( 544 | Person("juan", 5), 545 | Person("bob", 1), 546 | Person("li", 49), 547 | Person("alice", 5) 548 | ).toDS.select("age", "name").as(ds1.encoder) 549 | 550 | assertLargeDatasetEquality(ds1, ds2, ignoreColumnOrder = true) 551 | assertLargeDatasetEquality(ds2, ds1, ignoreColumnOrder = true) 552 | } 553 | 554 | "correctly mark unequal schema field" in { 555
| val sourceDF = spark.createDF( 556 | List( 557 | (1, 2.0), 558 | (5, 3.0) 559 | ), 560 | List( 561 | ("number", IntegerType, true), 562 | ("float", DoubleType, true) 563 | ) 564 | ) 565 | 566 | val expectedDF = spark.createDF( 567 | List( 568 | (1, "word", 1L), 569 | (5, "word", 2L) 570 | ), 571 | List( 572 | ("number", IntegerType, true), 573 | ("word", StringType, true), 574 | ("long", LongType, true) 575 | ) 576 | ) 577 | 578 | val e = intercept[DatasetSchemaMismatch] { 579 | assertLargeDatasetEquality(sourceDF, expectedDF) 580 | } 581 | 582 | e.assertColorDiff(Seq("float", "DoubleType", "MISSING"), Seq("word", "StringType", "StructField(long,LongType,true,{})")) 583 | } 584 | 585 | "can perform Dataset comparisons and ignore metadata" in { 586 | val ds1 = Seq( 587 | Person("juan", 5), 588 | Person("bob", 1), 589 | Person("li", 49), 590 | Person("alice", 5) 591 | ).toDS 592 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build())) 593 | .as[Person] 594 | 595 | val ds2 = Seq( 596 | Person("juan", 5), 597 | Person("bob", 1), 598 | Person("li", 49), 599 | Person("alice", 5) 600 | ).toDS 601 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build())) 602 | .as[Person] 603 | 604 | assertLargeDatasetEquality(ds2, ds1) 605 | } 606 | 607 | "can perform Dataset comparisons and compare metadata" in { 608 | val ds1 = Seq( 609 | Person("juan", 5), 610 | Person("bob", 1), 611 | Person("li", 49), 612 | Person("alice", 5) 613 | ).toDS 614 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build())) 615 | .as[Person] 616 | 617 | val ds2 = Seq( 618 | Person("juan", 5), 619 | Person("bob", 1), 620 | Person("li", 49), 621 | Person("alice", 5) 622 | ).toDS 623 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build())) 624 | .as[Person] 625 | 626 | intercept[DatasetSchemaMismatch] { 627 | assertLargeDatasetEquality(ds2, ds1, ignoreMetadata = false) 628 | } 629 | } 630 | } 631 | 632 | "assertSmallDatasetEquality" - { 633 | import spark.implicits._ 634 | 635 | "ignores the nullable flag when making DataFrame comparisons" in { 636 | val sourceDF = spark.createDF( 637 | List( 638 | (1), 639 | (5) 640 | ), 641 | List(("number", IntegerType, false)) 642 | ) 643 | 644 | val expectedDF = spark.createDF( 645 | List( 646 | (1), 647 | (5) 648 | ), 649 | List(("number", IntegerType, true)) 650 | ) 651 | 652 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreNullable = true) 653 | } 654 | 655 | "should not ignore nullable if ignoreNullable is false" in { 656 | val sourceDF = spark.createDF( 657 | List( 658 | (1), 659 | (5) 660 | ), 661 | List(("number", IntegerType, false)) 662 | ) 663 | 664 | val expectedDF = spark.createDF( 665 | List( 666 | (1), 667 | (5) 668 | ), 669 | List(("number", IntegerType, true)) 670 | ) 671 | 672 | intercept[DatasetSchemaMismatch] { 673 | assertSmallDatasetEquality(sourceDF, expectedDF) 674 | } 675 | } 676 | 677 | "can perform unordered DataFrame comparisons" in { 678 | val sourceDF = spark.createDF( 679 | List( 680 | (1), 681 | (5) 682 | ), 683 | List(("number", IntegerType, true)) 684 | ) 685 | val expectedDF = spark.createDF( 686 | List( 687 | (5), 688 | (1) 689 | ), 690 | List(("number", IntegerType, true)) 691 | ) 692 | assertSmallDatasetEquality(sourceDF, expectedDF, orderedComparison = false) 693 | } 694 |
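// A rough mental model for orderedComparison = false, sketched from this suite rather
// than from the library internals: both sides are sorted before the row-by-row check,
// so the test above should behave like comparing defaultSortDataset(sourceDF) with
// defaultSortDataset(expectedDF) under an ordered comparison (defaultSortDataset is
// exercised directly later in this file).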
695 | "can perform unordered Dataset comparisons" in { 696 | val sourceDS = spark.createDataset[Person]( 697 | Seq( 698 | Person("bob", 1), 699 | Person("alice", 5) 700 | ) 701 | ) 702 | val expectedDS = spark.createDataset[Person]( 703 | Seq( 704 | Person("alice", 5), 705 | Person("bob", 1) 706 | ) 707 | ) 708 | assertSmallDatasetEquality(sourceDS, expectedDS, orderedComparison = false) 709 | } 710 | 711 | "throws an error for unordered Dataset comparisons that don't match" in { 712 | val sourceDS = spark.createDataset[Person]( 713 | Seq( 714 | Person("bob", 1), 715 | Person("frank", 5) 716 | ) 717 | ) 718 | val expectedDS = spark.createDataset[Person]( 719 | Seq( 720 | Person("frank", 5), 721 | Person("bob", 1), 722 | Person("sadie", 2) 723 | ) 724 | ) 725 | val e = intercept[DatasetContentMismatch] { 726 | assertSmallDatasetEquality(sourceDS, expectedDS, orderedComparison = false) 727 | } 728 | } 729 | 730 | "throws an error for unordered DataFrame comparisons that don't match" in { 731 | val sourceDF = spark.createDF( 732 | List( 733 | (1), 734 | (5) 735 | ), 736 | List(("number", IntegerType, true)) 737 | ) 738 | val expectedDF = spark.createDF( 739 | List( 740 | (5), 741 | (1), 742 | (10) 743 | ), 744 | List(("number", IntegerType, true)) 745 | ) 746 | val e = intercept[DatasetContentMismatch] { 747 | assertSmallDatasetEquality(sourceDF, expectedDF, orderedComparison = false) 748 | } 749 | } 750 | 751 | "throws an error if the DataFrames have a different number of rows" in { 752 | val sourceDF = spark.createDF( 753 | List( 754 | (1), 755 | (5) 756 | ), 757 | List(("number", IntegerType, true)) 758 | ) 759 | val expectedDF = spark.createDF( 760 | List( 761 | (1), 762 | (5), 763 | (10) 764 | ), 765 | List(("number", IntegerType, true)) 766 | ) 767 | val e = intercept[DatasetContentMismatch] { 768 | assertSmallDatasetEquality(sourceDF, expectedDF) 769 | } 770 | } 771 | 772 | "can perform DataFrame comparisons with unordered columns" in { 773 | val sourceDF = spark.createDF( 774 | List( 775 | (1, "word"), 776 | (5, "word") 777 | ), 778 | List( 779 | ("number", IntegerType, true), 780 | ("word", StringType, true) 781 | ) 782 | ) 783 | val expectedDF = spark.createDF( 784 | List( 785 | ("word", 1), 786 | ("word", 5) 787 | ), 788 | List( 789 | ("word", StringType, true), 790 | ("number", IntegerType, true) 791 | ) 792 | ) 793 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreColumnOrder = true) 794 | } 795 | 796 | "can perform Dataset comparisons with unordered columns" in { 797 | val ds1 = Seq( 798 | Person("juan", 5), 799 | Person("bob", 1), 800 | Person("li", 49), 801 | Person("alice", 5) 802 | ).toDS 803 | 804 | val ds2 = Seq( 805 | Person("juan", 5), 806 | Person("bob", 1), 807 | Person("li", 49), 808 | Person("alice", 5) 809 | ).toDS.select("age", "name").as(ds1.encoder) 810 | 811 | assertSmallDatasetEquality(ds2, ds1, ignoreColumnOrder = true) 812 | } 813 | 814 | "correctly mark unequal schema field" in { 815 | val sourceDF = spark.createDF( 816 | List( 817 | (1, 2.0), 818 | (5, 3.0) 819 | ), 820 | List( 821 | ("number", IntegerType, true), 822 | ("float", DoubleType, true) 823 | ) 824 | ) 825 | 826 | val expectedDF = spark.createDF( 827 | List( 828 | (1, "word", 1L), 829 | (5, "word", 2L) 830 | ), 831 | List( 832 | ("number", IntegerType, true), 833 | ("word", StringType, true), 834 | ("long", LongType, true) 835 | ) 836 | ) 837 | 838 | val e = intercept[DatasetSchemaMismatch] { 839 | assertSmallDatasetEquality(sourceDF, expectedDF) 840 | } 841 | 842 |
e.assertColorDiff(Seq("float", "DoubleType", "MISSING"), Seq("word", "StringType", "StructField(long,LongType,true,{})")) 843 | } 844 | 845 | "correctly mark schema with unequal metadata" in { 846 | val sourceDF = spark.createDF( 847 | List( 848 | (1, 2.0), 849 | (5, 3.0) 850 | ), 851 | List( 852 | ("number", IntegerType, true), 853 | ("float", DoubleType, true) 854 | ) 855 | ) 856 | 857 | val expectedDF = spark 858 | .createDF( 859 | List( 860 | (1, 2.0), 861 | (5, 3.0) 862 | ), 863 | List( 864 | ("number", IntegerType, true), 865 | ("float", DoubleType, true) 866 | ) 867 | ) 868 | .withColumn("float", col("float").as("float", new MetadataBuilder().putString("description", "a float").build())) 869 | 870 | val e = intercept[DatasetSchemaMismatch] { 871 | assertSmallDatasetEquality(sourceDF, expectedDF, ignoreMetadata = false) 872 | } 873 | 874 | e.assertColorDiff( 875 | Seq("{}"), 876 | Seq("{\"description\":\"a float\"}") 877 | ) 878 | } 879 | 880 | "can perform Dataset comparisons and ignore metadata" in { 881 | val ds1 = Seq( 882 | Person("juan", 5), 883 | Person("bob", 1), 884 | Person("li", 49), 885 | Person("alice", 5) 886 | ).toDS 887 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build())) 888 | .as[Person] 889 | 890 | val ds2 = Seq( 891 | Person("juan", 5), 892 | Person("bob", 1), 893 | Person("li", 49), 894 | Person("alice", 5) 895 | ).toDS 896 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build())) 897 | .as[Person] 898 | 899 | assertSmallDatasetEquality(ds2, ds1) 900 | } 901 | 902 | "can perform Dataset comparisons and compare metadata" in { 903 | val ds1 = Seq( 904 | Person("juan", 5), 905 | Person("bob", 1), 906 | Person("li", 49), 907 | Person("alice", 5) 908 | ).toDS 909 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build())) 910 | .as[Person] 911 | 912 | val ds2 = Seq( 913 | Person("juan", 5), 914 | Person("bob", 1), 915 | Person("li", 49), 916 | Person("alice", 5) 917 | ).toDS 918 | .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build())) 919 | .as[Person] 920 | 921 | intercept[DatasetSchemaMismatch] { 922 | assertSmallDatasetEquality(ds2, ds1, ignoreMetadata = false) 923 | } 924 | } 925 | } 926 | 927 | "defaultSortDataset" - { 928 | 929 | "sorts a DataFrame by the column names in alphabetical order" in { 930 | val sourceDF = spark.createDF( 931 | List( 932 | (5, "bob"), 933 | (1, "phil"), 934 | (5, "anne") 935 | ), 936 | List( 937 | ("fun_level", IntegerType, true), 938 | ("name", StringType, true) 939 | ) 940 | ) 941 | val actualDF = defaultSortDataset(sourceDF) 942 | val expectedDF = spark.createDF( 943 | List( 944 | (1, "phil"), 945 | (5, "anne"), 946 | (5, "bob") 947 | ), 948 | List( 949 | ("fun_level", IntegerType, true), 950 | ("name", StringType, true) 951 | ) 952 | ) 953 | assertSmallDatasetEquality(actualDF, expectedDF) 954 | } 955 | 956 | } 957 | 958 | "assertApproximateDataFrameEquality" - { 959 | 960 | "does nothing if the DataFrames have the same schemas and content" in { 961 | val sourceDF = spark.createDF( 962 | List( 963 | (1.2), 964 | (5.1), 965 | (null) 966 | ), 967 | List(("number", DoubleType, true)) 968 | ) 969 | val expectedDF = spark.createDF( 970 | List( 971 | (1.2), 972 | (5.1), 973 | (null) 974 | ), 975 | List(("number", DoubleType, true)) 976 | ) 977 |
assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 978 | } 979 | 980 | "throws an error if the rows are different" in { 981 | val sourceDF = spark.createDF( 982 | List( 983 | (100.9), 984 | (5.1) 985 | ), 986 | List(("number", DoubleType, true)) 987 | ) 988 | val expectedDF = spark.createDF( 989 | List( 990 | (1.2), 991 | (5.1) 992 | ), 993 | List(("number", DoubleType, true)) 994 | ) 995 | val e = intercept[DatasetContentMismatch] { 996 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 997 | } 998 | } 999 | 1000 | "throws an error if the DataFrames have a different number of rows" in { 1001 | val sourceDF = spark.createDF( 1002 | List( 1003 | (1.2), 1004 | (5.1), 1005 | (8.8) 1006 | ), 1007 | List(("number", DoubleType, true)) 1008 | ) 1009 | val expectedDF = spark.createDF( 1010 | List( 1011 | (1.2), 1012 | (5.1) 1013 | ), 1014 | List(("number", DoubleType, true)) 1015 | ) 1016 | val e = intercept[DatasetCountMismatch] { 1017 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01) 1018 | } 1019 | } 1020 | 1021 | "can ignore the nullable property" in { 1022 | val sourceDF = spark.createDF( 1023 | List( 1024 | (1.2), 1025 | (5.1) 1026 | ), 1027 | List(("number", DoubleType, false)) 1028 | ) 1029 | val expectedDF = spark.createDF( 1030 | List( 1031 | (1.2), 1032 | (5.1) 1033 | ), 1034 | List(("number", DoubleType, true)) 1035 | ) 1036 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreNullable = true) 1037 | } 1038 | 1039 | "can ignore the column names" in { 1040 | val sourceDF = spark.createDF( 1041 | List( 1042 | (1.2), 1043 | (5.1), 1044 | (null) 1045 | ), 1046 | List(("BLAHBLBH", DoubleType, true)) 1047 | ) 1048 | val expectedDF = spark.createDF( 1049 | List( 1050 | (1.2), 1051 | (5.1), 1052 | (null) 1053 | ), 1054 | List(("number", DoubleType, true)) 1055 | ) 1056 | assertApproximateDataFrameEquality(sourceDF, expectedDF, 0.01, ignoreColumnNames = true) 1057 | } 1058 | 1059 | "can work with precision and unordered comparison" in { 1060 | import spark.implicits._ 1061 | val ds1 = Seq( 1062 | ("1", "10/01/2019", 26.762499999999996), 1063 | ("1", "11/01/2019", 26.762499999999996) 1064 | ).toDF("col_B", "col_C", "col_A") 1065 | 1066 | val ds2 = Seq( 1067 | ("1", "10/01/2019", 26.762499999999946), 1068 | ("1", "11/01/2019", 26.76249999999991) 1069 | ).toDF("col_B", "col_C", "col_A") 1070 | 1071 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 1072 | } 1073 | 1074 | "can work with precision and unordered comparison 2" in { 1075 | import spark.implicits._ 1076 | val ds1 = Seq( 1077 | ("1", "10/01/2019", 26.762499999999996, "A"), 1078 | ("1", "10/01/2019", 26.762499999999996, "B") 1079 | ).toDF("col_B", "col_C", "col_A", "col_D") 1080 | 1081 | val ds2 = Seq( 1082 | ("1", "10/01/2019", 26.762499999999946, "A"), 1083 | ("1", "10/01/2019", 26.76249999999991, "B") 1084 | ).toDF("col_B", "col_C", "col_A", "col_D") 1085 | 1086 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 1087 | } 1088 | 1089 | "can work with precision and unordered comparison on nested column" in { 1090 | import spark.implicits._ 1091 | val ds1 = Seq( 1092 | ("1", "10/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)), 1093 | ("1", "11/01/2019", 26.762499999999996, Seq(26.762499999999996, 26.762499999999996)) 1094 | ).toDF("col_B", "col_C", "col_A", "col_D") 1095 | 1096 | val ds2 = Seq( 1097 | ("1", "11/01/2019", 26.7624999999999961, Seq(26.7624999999999961,
26.7624999999999961)), 1098 | ("1", "10/01/2019", 26.762499999999997, Seq(26.762499999999997, 26.762499999999997)) 1099 | ).toDF("col_B", "col_C", "col_A", "col_D") 1100 | 1101 | assertApproximateDataFrameEquality(ds1, ds2, precision = 0.0000001, orderedComparison = false) 1102 | } 1103 | } 1104 | 1105 | // "works with FloatType columns" - { 1106 | // val sourceDF = spark.createDF( 1107 | // List( 1108 | // (1.2), 1109 | // (5.1), 1110 | // (null) 1111 | // ), 1112 | // List( 1113 | // ("number", FloatType, true) 1114 | // ) 1115 | // ) 1116 | // 1117 | // val expectedDF = spark.createDF( 1118 | // List( 1119 | // (1.2), 1120 | // (5.1), 1121 | // (null) 1122 | // ), 1123 | // List( 1124 | // ("number", FloatType, true) 1125 | // ) 1126 | // ) 1127 | // 1128 | // assertApproximateDataFrameEquality( 1129 | // sourceDF, 1130 | // expectedDF, 1131 | // 0.01 1132 | // ) 1133 | // } 1134 | 1135 | } 1136 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/ExamplesTest.scala: -------------------------------------------------------------------------------- 1 | //package com.github.mrpowers.spark.fast.tests 2 | // 3 | //import org.apache.spark.sql.types._ 4 | //import SparkSessionExt._ 5 | // 6 | //import org.scalatest.FreeSpec 7 | // 8 | //class ExamplesTest extends FreeSpec with SparkSessionTestWrapper with DataFrameComparer with ColumnComparer { 9 | // 10 | // "assertSmallDatasetEquality" - { 11 | // 12 | // "error when row counts don't match" in { 13 | // 14 | // val sourceDF = spark.createDF( 15 | // List( 16 | // (1), 17 | // (5) 18 | // ), 19 | // List(("number", IntegerType, true)) 20 | // ) 21 | // 22 | // val expectedDF = spark.createDF( 23 | // List( 24 | // (1), 25 | // (5), 26 | // (10) 27 | // ), 28 | // List(("number", IntegerType, true)) 29 | // ) 30 | // 31 | // assertSmallDatasetEquality( 32 | // sourceDF, 33 | // expectedDF 34 | // ) 35 | // 36 | // } 37 | // 38 | // "error when schemas don't match" in { 39 | // 40 | // val sourceDF = spark.createDF( 41 | // List( 42 | // (1, "a"), 43 | // (5, "b") 44 | // ), 45 | // List( 46 | // ("number", IntegerType, true), 47 | // ("letter", StringType, true) 48 | // ) 49 | // ) 50 | // 51 | // val expectedDF = spark.createDF( 52 | // List( 53 | // (1, "a"), 54 | // (5, "b") 55 | // ), 56 | // List( 57 | // ("num", IntegerType, true), 58 | // ("letter", StringType, true) 59 | // ) 60 | // ) 61 | // 62 | // assertSmallDatasetEquality( 63 | // sourceDF, 64 | // expectedDF 65 | // ) 66 | // 67 | // } 68 | // 69 | // "error when content doesn't match" in { 70 | // 71 | // val sourceDF = spark.createDF( 72 | // List( 73 | // (1, "z"), 74 | // (5, "b") 75 | // ), 76 | // List( 77 | // ("number", IntegerType, true), 78 | // ("letter", StringType, true) 79 | // ) 80 | // ) 81 | // 82 | // val expectedDF = spark.createDF( 83 | // List( 84 | // (1111, "a"), 85 | // (5, "b") 86 | // ), 87 | // List( 88 | // ("number", IntegerType, true), 89 | // ("letter", StringType, true) 90 | // ) 91 | // ) 92 | // 93 | // assertSmallDataFrameEquality( 94 | // sourceDF, 95 | // expectedDF 96 | // ) 97 | // 98 | // } 99 | // 100 | // } 101 | // 102 | // "pretty column mismatch message" in { 103 | // 104 | // val df = spark.createDF( 105 | // List( 106 | // ("a", "z"), 107 | // ("b", "b") 108 | // ), 109 | // List( 110 | // ("letter1", StringType, true), 111 | // ("letter", StringType, true) 112 | // ) 113 | // ) 114 | // 115 | // assertColumnEquality(df, "letter1", "letter") 116 
| // 117 | // } 118 | // 119 | //} 120 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/RDDComparerTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.scalatest.freespec.AnyFreeSpec 4 | 5 | class RDDComparerTest extends AnyFreeSpec with RDDComparer with SparkSessionTestWrapper { 6 | 7 | "contentMismatchMessage" - { 8 | 9 | "returns a string comparing the expected and actual RDDs" in { 10 | val sourceData = List( 11 | ("cat"), 12 | ("dog"), 13 | ("frog") 14 | ) 15 | 16 | val actualRDD = spark.sparkContext.parallelize(sourceData) 17 | 18 | val expectedData = List( 19 | ("man"), 20 | ("can"), 21 | ("pan") 22 | ) 23 | 24 | val expectedRDD = spark.sparkContext.parallelize(expectedData) 25 | 26 | val expected = """ 27 | Actual RDD Content: 28 | cat 29 | dog 30 | frog 31 | Expected RDD Content: 32 | man 33 | can 34 | pan 35 | """ 36 | 37 | assert( 38 | contentMismatchMessage(actualRDD, expectedRDD) == expected 39 | ) 40 | } 41 | 42 | } 43 | 44 | "assertSmallRDDEquality" - { 45 | 46 | "does nothing if the RDDs have the same content" in { 47 | val sourceData = List( 48 | ("cat"), 49 | ("dog"), 50 | ("frog") 51 | ) 52 | 53 | val sourceRDD = spark.sparkContext.parallelize(sourceData) 54 | 55 | val expectedData = List( 56 | ("cat"), 57 | ("dog"), 58 | ("frog") 59 | ) 60 | 61 | val expectedRDD = spark.sparkContext.parallelize(expectedData) 62 | 63 | assertSmallRDDEquality(sourceRDD, expectedRDD) 64 | } 65 | 66 | "throws an error if the RDDs have different content" in { 67 | val sourceData = List( 68 | ("cat"), 69 | ("dog"), 70 | ("frog") 71 | ) 72 | val sourceRDD = spark.sparkContext.parallelize(sourceData) 73 | 74 | val expectedData = List( 75 | ("mouse"), 76 | ("pig"), 77 | ("frog") 78 | ) 79 | 80 | val expectedRDD = spark.sparkContext.parallelize(expectedData) 81 | 82 | val e = intercept[RDDContentMismatch] { 83 | assertSmallRDDEquality(sourceRDD, expectedRDD) 84 | } 85 | } 86 | 87 | } 88 | 89 | } 90 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/RowComparerTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.scalatest.freespec.AnyFreeSpec 4 | 5 | import org.apache.spark.sql.Row 6 | 7 | class RowComparerTest extends AnyFreeSpec { 8 | 9 | "areRowsEqual" - { 10 | 11 | "returns true for rows that contain the same elements" in { 12 | val r1 = Row("a", "b") 13 | val r2 = Row("a", "b") 14 | assert( 15 | RowComparer.areRowsEqual(r1, r2, 0.0) 16 | ) 17 | } 18 | 19 | "returns false for rows that don't contain the same elements" in { 20 | val r1 = Row("a", 3) 21 | val r2 = Row("a", 4) 22 | assert( 23 | !RowComparer.areRowsEqual(r1, r2, 0.0) 24 | ) 25 | } 26 | 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/SchemaComparerTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch 4 | import org.apache.spark.sql.types._ 5 | import org.scalatest.freespec.AnyFreeSpec 6 | 7 | class SchemaComparerTest extends AnyFreeSpec { 8 | 9 |
"equals" - { 10 | 11 | "returns true if the schemas are equal" in { 12 | val s1 = StructType( 13 | Seq( 14 | StructField("something", StringType, true), 15 | StructField("mood", StringType, true) 16 | ) 17 | ) 18 | val s2 = StructType( 19 | Seq( 20 | StructField("something", StringType, true), 21 | StructField("mood", StringType, true) 22 | ) 23 | ) 24 | assert(SchemaComparer.equals(s1, s2)) 25 | } 26 | 27 | "works for single column schemas" in { 28 | val s1 = StructType( 29 | Seq( 30 | StructField("something", StringType, true) 31 | ) 32 | ) 33 | val s2 = StructType( 34 | Seq( 35 | StructField("something", StringType, false) 36 | ) 37 | ) 38 | assert(SchemaComparer.equals(s1, s2, true)) 39 | } 40 | 41 | "returns false if the schemas aren't equal" in { 42 | val s1 = StructType( 43 | Seq( 44 | StructField("something", StringType, true) 45 | ) 46 | ) 47 | val s2 = StructType( 48 | Seq( 49 | StructField("something", StringType, true), 50 | StructField("mood", StringType, true) 51 | ) 52 | ) 53 | assert(!SchemaComparer.equals(s1, s2)) 54 | } 55 | 56 | "can ignore the nullable flag when determining equality" in { 57 | val s1 = StructType( 58 | Seq( 59 | StructField("something", StringType, true), 60 | StructField("mood", StringType, true) 61 | ) 62 | ) 63 | val s2 = StructType( 64 | Seq( 65 | StructField("something", StringType, false), 66 | StructField("mood", StringType, true) 67 | ) 68 | ) 69 | assert(SchemaComparer.equals(s1, s2, ignoreNullable = true)) 70 | } 71 | 72 | "do not ignore nullable when determining equality if ignoreNullable is true" in { 73 | val s1 = StructType( 74 | Seq( 75 | StructField("something", StringType, true), 76 | StructField("mood", StringType, true) 77 | ) 78 | ) 79 | val s2 = StructType( 80 | Seq( 81 | StructField("something", StringType, false), 82 | StructField("mood", StringType, true) 83 | ) 84 | ) 85 | assert(!SchemaComparer.equals(s1, s2)) 86 | } 87 | 88 | "can ignore the nullable flag when determining equality on complex data types" in { 89 | val s1 = StructType( 90 | Seq( 91 | StructField("something", StringType, true), 92 | StructField("array", ArrayType(StringType, containsNull = true), true), 93 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 94 | StructField( 95 | "struct", 96 | StructType( 97 | StructType( 98 | Seq( 99 | StructField("something", StringType, false), 100 | StructField("mood", ArrayType(StringType, containsNull = false), true) 101 | ) 102 | ) 103 | ), 104 | true 105 | ) 106 | ) 107 | ) 108 | val s2 = StructType( 109 | Seq( 110 | StructField("something", StringType, false), 111 | StructField("array", ArrayType(StringType, containsNull = false), true), 112 | StructField("map", MapType(StringType, StringType, valueContainsNull = true), true), 113 | StructField( 114 | "struct", 115 | StructType( 116 | StructType( 117 | Seq( 118 | StructField("something", StringType, false), 119 | StructField("mood", ArrayType(StringType, containsNull = true), true) 120 | ) 121 | ) 122 | ), 123 | false 124 | ) 125 | ) 126 | ) 127 | assert(SchemaComparer.equals(s1, s2, ignoreNullable = true)) 128 | } 129 | 130 | "do not ignore nullable when determining equality on complex data types if ignoreNullable is true" in { 131 | val s1 = StructType( 132 | Seq( 133 | StructField("something", StringType, true), 134 | StructField("array", ArrayType(StringType, containsNull = true), true), 135 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 136 | StructField( 137 | "struct", 138 | StructType( 
139 | StructType( 140 | Seq( 141 | StructField("something", StringType, false), 142 | StructField("mood", ArrayType(StringType, containsNull = false), true) 143 | ) 144 | ) 145 | ), 146 | true 147 | ) 148 | ) 149 | ) 150 | val s2 = StructType( 151 | Seq( 152 | StructField("something", StringType, false), 153 | StructField("array", ArrayType(StringType, containsNull = false), true), 154 | StructField("map", MapType(StringType, StringType, valueContainsNull = true), true), 155 | StructField( 156 | "struct", 157 | StructType( 158 | StructType( 159 | Seq( 160 | StructField("something", StringType, false), 161 | StructField("mood", ArrayType(StringType, containsNull = true), true) 162 | ) 163 | ) 164 | ), 165 | false 166 | ) 167 | ) 168 | ) 169 | assert(!SchemaComparer.equals(s1, s2)) 170 | } 171 | 172 | "can ignore the column names flag when determining equality" in { 173 | val s1 = StructType( 174 | Seq( 175 | StructField("these", StringType, true), 176 | StructField("are", StringType, true) 177 | ) 178 | ) 179 | val s2 = StructType( 180 | Seq( 181 | StructField("very", StringType, true), 182 | StructField("different", StringType, true) 183 | ) 184 | ) 185 | assert(SchemaComparer.equals(s1, s2, ignoreColumnNames = true)) 186 | } 187 | 188 | "can ignore the column order when determining equality" in { 189 | val s1 = StructType( 190 | Seq( 191 | StructField("these", StringType, true), 192 | StructField("are", StringType, true) 193 | ) 194 | ) 195 | val s2 = StructType( 196 | Seq( 197 | StructField("are", StringType, true), 198 | StructField("these", StringType, true) 199 | ) 200 | ) 201 | assert(SchemaComparer.equals(s1, s2)) 202 | } 203 | 204 | "can ignore the column order when determining equality of complex type" in { 205 | val s1 = StructType( 206 | Seq( 207 | StructField("array", ArrayType(StringType, containsNull = true), true), 208 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 209 | StructField("something", StringType, true), 210 | StructField( 211 | "struct", 212 | StructType( 213 | StructType( 214 | Seq( 215 | StructField("mood", ArrayType(StringType, containsNull = false), true), 216 | StructField("something", StringType, false) 217 | ) 218 | ) 219 | ), 220 | true 221 | ) 222 | ) 223 | ) 224 | val s2 = StructType( 225 | Seq( 226 | StructField("something", StringType, true), 227 | StructField("array", ArrayType(StringType, containsNull = true), true), 228 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 229 | StructField( 230 | "struct", 231 | StructType( 232 | StructType( 233 | Seq( 234 | StructField("something", StringType, false), 235 | StructField("mood", ArrayType(StringType, containsNull = false), true) 236 | ) 237 | ) 238 | ), 239 | true 240 | ) 241 | ) 242 | ) 243 | assert(SchemaComparer.equals(s1, s2)) 244 | } 245 | 246 | "display schema diff as tree with different depth" in { 247 | val s1 = StructType( 248 | Seq( 249 | StructField("array", ArrayType(StringType, containsNull = true), true), 250 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 251 | StructField("something", StringType, true), 252 | StructField( 253 | "struct", 254 | StructType( 255 | StructType( 256 | Seq( 257 | StructField("mood", ArrayType(StringType, containsNull = false), true), 258 | StructField("something", StringType, false), 259 | StructField( 260 | "something2", 261 | StructType( 262 | Seq( 263 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 264 | 
StructField("something2", StringType, false) 265 | ) 266 | ), 267 | false 268 | ) 269 | ) 270 | ) 271 | ), 272 | true 273 | ) 274 | ) 275 | ) 276 | val s2 = StructType( 277 | Seq( 278 | StructField("array", ArrayType(StringType, containsNull = true), true), 279 | StructField("something", StringType, true), 280 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 281 | StructField( 282 | "struct", 283 | StructType( 284 | StructType( 285 | Seq( 286 | StructField("something", StringType, false), 287 | StructField("mood", ArrayType(StringType, containsNull = false), true), 288 | StructField( 289 | "something3", 290 | StructType( 291 | Seq( 292 | StructField("mood3", ArrayType(StringType, containsNull = false), true) 293 | ) 294 | ), 295 | false 296 | ) 297 | ) 298 | ) 299 | ), 300 | true 301 | ), 302 | StructField("norma2", StringType, false) 303 | ) 304 | ) 305 | 306 | val e = intercept[DatasetSchemaMismatch] { 307 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree) 308 | } 309 | val expectedMessage = """Diffs 310 | | 311 | |Actual Schema Expected Schema 312 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 313 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) 314 | |\u001b[90m|--\u001b[39m \u001b[31mmap\u001b[39m : \u001b[31mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 315 | |\u001b[90m|--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32mmap\u001b[39m : \u001b[32mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 316 | |\u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 317 | |\u001b[90m| |--\u001b[39m \u001b[31mmood\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m) 318 | |\u001b[31m| | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[32m| |--\u001b[39m \u001b[32mmood\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 319 | |\u001b[31m| |--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) \u001b[32m| | |--\u001b[39m \u001b[32melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 320 | |\u001b[90m| |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething3\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 321 | |\u001b[90m| | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m 
\u001b[32mmood3\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 322 | |\u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 323 | |\u001b[31m| | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[32m|--\u001b[39m \u001b[32mnorma2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 324 | |""".stripMargin 325 | 326 | assert(e.getMessage == expectedMessage) 327 | } 328 | 329 | "display schema diff for tree with array of struct" in { 330 | val s1 = StructType( 331 | Seq( 332 | StructField("array", ArrayType(StructType(Seq(StructField("arrayChild1", StringType))), containsNull = true), true) 333 | ) 334 | ) 335 | val s2 = StructType( 336 | Seq( 337 | StructField("array", ArrayType(StructType(Seq(StructField("arrayChild2", IntegerType))), containsNull = false), true) 338 | ) 339 | ) 340 | s1.printTreeString() 341 | 342 | val e = intercept[DatasetSchemaMismatch] { 343 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree) 344 | } 345 | 346 | val expectedMessage = """Diffs 347 | | 348 | |Actual Schema Expected Schema 349 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 350 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mstruct\u001b[39m (containsNull = \u001b[31mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32mstruct\u001b[39m (containsNull = \u001b[32mfalse\u001b[39m) 351 | |\u001b[90m| | |--\u001b[39m \u001b[31marrayChild1\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[32marrayChild2\u001b[39m : \u001b[32minteger\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 352 | |""".stripMargin 353 | 354 | assert(e.getMessage == expectedMessage) 355 | } 356 | 357 | "display schema diff for tree with array of array of struct" in { 358 | val s1 = StructType( 359 | Seq( 360 | StructField("array", ArrayType(ArrayType(StructType(Seq(StructField("arrayChild1", StringType))), containsNull = true))) 361 | ) 362 | ) 363 | val s2 = StructType( 364 | Seq( 365 | StructField("array", ArrayType(ArrayType(StructType(Seq(StructField("arrayChild2", IntegerType))), containsNull = false))) 366 | ) 367 | ) 368 | s1.printTreeString() 369 | 370 | val e = intercept[DatasetSchemaMismatch] { 371 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree) 372 | } 373 | 374 | val expectedMessage = """Diffs 375 | | 376 | |Actual Schema Expected Schema 377 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 378 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31marray\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32marray\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) 379 | |\u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : 
\u001b[31mstruct\u001b[39m (containsNull = \u001b[31mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32mstruct\u001b[39m (containsNull = \u001b[32mfalse\u001b[39m) 380 | |\u001b[90m| | | |--\u001b[39m \u001b[31marrayChild1\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[32marrayChild2\u001b[39m : \u001b[32minteger\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 381 | |""".stripMargin 382 | 383 | assert(e.getMessage == expectedMessage) 384 | } 385 | 386 | "display schema diff for tree with array of simple type" in { 387 | val s1 = StructType( 388 | Seq( 389 | StructField("array", ArrayType(StringType, containsNull = true), true) 390 | ) 391 | ) 392 | val s2 = StructType( 393 | Seq( 394 | StructField("array", ArrayType(IntegerType, containsNull = true), true) 395 | ) 396 | ) 397 | s1.printTreeString() 398 | 399 | val e = intercept[DatasetSchemaMismatch] { 400 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree) 401 | } 402 | 403 | val expectedMessage = """Diffs 404 | | 405 | |Actual Schema Expected Schema 406 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 407 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[31mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[32minteger\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) 408 | |""".stripMargin 409 | 410 | assert(e.getMessage == expectedMessage) 411 | } 412 | 413 | "display schema diff for wide tree" in { 414 | val s1 = StructType( 415 | Seq( 416 | StructField("array", ArrayType(StringType, containsNull = true), true), 417 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 418 | StructField("something", StringType, true), 419 | StructField( 420 | "struct", 421 | StructType( 422 | StructType( 423 | Seq( 424 | StructField("mood", ArrayType(StringType, containsNull = false), true), 425 | StructField("something", StringType, false), 426 | StructField( 427 | "something2", 428 | StructType( 429 | Seq( 430 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 431 | StructField( 432 | "something2", 433 | StructType( 434 | Seq( 435 | StructField("mood", ArrayType(StringType, containsNull = false), true), 436 | StructField("something", StringType, false), 437 | StructField( 438 | "something2", 439 | StructType( 440 | Seq( 441 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 442 | StructField("something2", StringType, false) 443 | ) 444 | ), 445 | false 446 | ) 447 | ) 448 | ), 449 | false 450 | ) 451 | ) 452 | ), 453 | false 454 | ) 455 | ) 456 | ) 457 | ), 458 | true 459 | ) 460 | ) 461 | ) 462 | val s2 = StructType( 463 | Seq( 464 | StructField("array", ArrayType(StringType, containsNull = true), true), 465 | StructField("something", StringType, true), 466 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 467 | StructField( 468 | "struct", 469 | StructType( 470 | StructType( 471 | Seq( 472 | StructField("something", StringType, false), 473 | StructField("mood", ArrayType(StringType, containsNull = false), true), 474 | StructField( 475 | "something3", 476 | StructType( 477 | Seq( 478 | 
StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 479 | StructField( 480 | "something2", 481 | StructType( 482 | Seq( 483 | StructField("mood", ArrayType(StringType, containsNull = false), true), 484 | StructField("something", StringType, false), 485 | StructField( 486 | "something2", 487 | StructType( 488 | Seq( 489 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 490 | StructField("something2", StringType, false) 491 | ) 492 | ), 493 | false 494 | ) 495 | ) 496 | ), 497 | false 498 | ) 499 | ) 500 | ), 501 | false 502 | ) 503 | ) 504 | ) 505 | ), 506 | true 507 | ), 508 | StructField("norma2", StringType, false) 509 | ) 510 | ) 511 | 512 | val e = intercept[DatasetSchemaMismatch] { 513 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree) 514 | } 515 | val expectedMessage = """Diffs 516 | | 517 | |Actual Schema Expected Schema 518 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 519 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) 520 | |\u001b[90m|--\u001b[39m \u001b[31mmap\u001b[39m : \u001b[31mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 521 | |\u001b[90m|--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32mmap\u001b[39m : \u001b[32mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 522 | |\u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90mstruct\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 523 | |\u001b[90m| |--\u001b[39m \u001b[31mmood\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m) 524 | |\u001b[31m| | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[32m| |--\u001b[39m \u001b[32mmood\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 525 | |\u001b[31m| |--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) \u001b[32m| | |--\u001b[39m \u001b[32melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 526 | |\u001b[90m| |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething3\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 527 | |\u001b[90m| | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 528 | |\u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = 
\u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 529 | |\u001b[90m| | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 530 | |\u001b[90m| | | |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 531 | |\u001b[90m| | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 532 | |\u001b[90m| | | |--\u001b[39m \u001b[90msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90msomething\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 533 | |\u001b[90m| | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 534 | |\u001b[90m| | | | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| | | | |--\u001b[39m \u001b[90mmood2\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 535 | |\u001b[90m| | | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | | | | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mdouble\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 536 | |\u001b[90m| | | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| | | | |--\u001b[39m \u001b[90msomething2\u001b[39m : \u001b[90mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 537 | | \u001b[90m|--\u001b[39m \u001b[32mnorma2\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m) 538 | |""".stripMargin 539 | 540 | assert(e.getMessage == expectedMessage) 541 | } 542 | 543 | "display schema diff as tree with more actual Column 2" in { 544 | val s1 = StructType( 545 | Seq( 546 | StructField("array", ArrayType(StringType, containsNull = true), true), 547 | StructField("map", MapType(StringType, StringType, valueContainsNull = false), true), 548 | StructField("something", StringType, true), 549 | StructField( 550 | "struct", 551 | StructType( 552 | StructType( 553 | Seq( 554 | StructField("mood", ArrayType(StringType, containsNull = false), true), 555 | StructField("something", StringType, false), 556 | StructField( 557 | "something2", 558 | StructType( 559 | Seq( 560 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 561 | StructField( 562 | "something2", 563 | StructType( 564 | Seq( 565 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 566 | StructField( 567 | "something2", 568 | StructType( 569 | Seq( 570 | StructField("mood2", ArrayType(DoubleType, containsNull = false), true), 571 | StructField("something2", StringType, false) 572 | ) 573 | ), 574 | false 575 | ) 576 | ) 
577 | ), 578 | false 579 | ) 580 | ) 581 | ), 582 | false 583 | ) 584 | ) 585 | ) 586 | ), 587 | true 588 | ) 589 | ) 590 | ) 591 | val s2 = StructType( 592 | Seq( 593 | StructField("array", ArrayType(StringType, containsNull = true), true), 594 | StructField("something", StringType, true), 595 | StructField( 596 | "struct", 597 | StructType( 598 | StructType( 599 | Seq( 600 | StructField("something", StringType, false), 601 | StructField("mood", ArrayType(StringType, containsNull = false), true), 602 | StructField( 603 | "something3", 604 | StructType( 605 | Seq( 606 | StructField("mood3", ArrayType(StringType, containsNull = false), true) 607 | ) 608 | ), 609 | false 610 | ) 611 | ) 612 | ) 613 | ), 614 | true 615 | ) 616 | ) 617 | ) 618 | 619 | val e = intercept[DatasetSchemaMismatch] { 620 | SchemaComparer.assertSchemaEqual(s1, s2, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree) 621 | } 622 | 623 | val expectedMessage = """Diffs 624 | | 625 | |Actual Schema Expected Schema 626 | |\u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[90marray\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 627 | |\u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mtrue\u001b[39m) 628 | |\u001b[90m|--\u001b[39m \u001b[31mmap\u001b[39m : \u001b[31mmap\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 629 | |\u001b[90m|--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m|--\u001b[39m \u001b[32mstruct\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 630 | |\u001b[31m|--\u001b[39m \u001b[31mstruct\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) \u001b[32m| |--\u001b[39m \u001b[32msomething\u001b[39m : \u001b[32mstring\u001b[39m (nullable = \u001b[32mfalse\u001b[39m) 631 | |\u001b[90m| |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[90mmood\u001b[39m : \u001b[90marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) 632 | |\u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) \u001b[90m| | |--\u001b[39m \u001b[90melement\u001b[39m : \u001b[90mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 633 | |\u001b[90m| |--\u001b[39m \u001b[31msomething\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) \u001b[90m| |--\u001b[39m \u001b[32msomething3\u001b[39m : \u001b[32mstruct\u001b[39m (nullable = \u001b[90mfalse\u001b[39m) 634 | |\u001b[31m| |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) \u001b[32m| | |--\u001b[39m \u001b[32mmood3\u001b[39m : \u001b[32marray\u001b[39m (nullable = \u001b[32mtrue\u001b[39m) 635 | |\u001b[31m| | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[90mtrue\u001b[39m) \u001b[32m| | | |--\u001b[39m \u001b[32melement\u001b[39m : \u001b[32mstring\u001b[39m (containsNull = \u001b[90mfalse\u001b[39m) 636 | |\u001b[31m| | | |--\u001b[39m 
\u001b[31melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[31mfalse\u001b[39m) 637 | |\u001b[31m| | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) 638 | |\u001b[31m| | | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) 639 | |\u001b[31m| | | | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[31mfalse\u001b[39m) 640 | |\u001b[31m| | | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstruct\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) 641 | |\u001b[31m| | | | |--\u001b[39m \u001b[31mmood2\u001b[39m : \u001b[31marray\u001b[39m (nullable = \u001b[31mtrue\u001b[39m) 642 | |\u001b[31m| | | | | |--\u001b[39m \u001b[31melement\u001b[39m : \u001b[31mdouble\u001b[39m (containsNull = \u001b[31mfalse\u001b[39m) 643 | |\u001b[31m| | | | |--\u001b[39m \u001b[31msomething2\u001b[39m : \u001b[31mstring\u001b[39m (nullable = \u001b[31mfalse\u001b[39m) 644 | |""".stripMargin 645 | 646 | assert(e.getMessage == expectedMessage) 647 | } 648 | } 649 | } 650 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/SeqLikesExtensionsTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.scalatest.freespec.AnyFreeSpec 4 | import SeqLikesExtensions._ 5 | 6 | class SeqLikesExtensionsTest extends AnyFreeSpec with SparkSessionTestWrapper { 7 | 8 | "check equality" - { 9 | import spark.implicits._ 10 | 11 | "check equal Seq" in { 12 | val source = Seq( 13 | ("juan", 5), 14 | ("bob", 1), 15 | ("li", 49), 16 | ("alice", 5) 17 | ) 18 | 19 | val expected = Seq( 20 | ("juan", 5), 21 | ("bob", 1), 22 | ("li", 49), 23 | ("alice", 5) 24 | ) 25 | 26 | assert(source.approximateSameElements(expected, (s1, s2) => s1 == s2)) 27 | } 28 | 29 | "check equal Seq[Row]" in { 30 | 31 | val source = Seq( 32 | ("juan", 5), 33 | ("bob", 1), 34 | ("li", 49), 35 | ("alice", 5) 36 | ).toDF.collect().toSeq 37 | 38 | val expected = Seq( 39 | ("juan", 5), 40 | ("bob", 1), 41 | ("li", 49), 42 | ("alice", 5) 43 | ).toDF.collect() 44 | 45 | assert(source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _))) 46 | } 47 | 48 | "check unequal Seq[Row]" in { 49 | 50 | val source = Seq( 51 | ("juan", 5), 52 | ("bob", 1), 53 | ("li", 49), 54 | ("alice", 5) 55 | ).toDF.collect().toSeq 56 | 57 | val expected = Seq( 58 | ("juan", 5), 59 | ("bob", 1), 60 | ("li", 40), 61 | ("alice", 5) 62 | ).toDF.collect() 63 | 64 | assert(!source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _))) 65 | } 66 | 67 | "check equal Seq[Row] with tolerance" in { 68 | val source = Seq( 69 | ("juan", 12.00000000001), 70 | ("bob", 1.00000000001), 71 | ("li", 49.00000000001), 72 | ("alice", 5.00000000001) 73 | ).toDF.collect().toSeq 74 | 75 | val expected = Seq( 76 | ("juan", 12.0), 77 | ("bob", 1.0), 78 | ("li", 49.0), 79 | ("alice", 5.0) 80 | ).toDF.collect() 81 | assert(source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, 0.0000000002))) 82 | } 83 | 84 | "check indexedSeq[Row] with tolerance" in { 85 | val source = Seq( 86 | ("juan", 12.00000000001), 87 | ("bob", 1.00000000001), 88 | ("li", 49.00000000001), 89 | ("alice", 5.00000000001) 90 | ).toDF.collect().toIndexedSeq 91 | 92 | val expected = Seq( 93 | ("juan", 12.0), 94 | ("bob", 1.0), 95 | ("li", 
49.0), 96 | ("alice", 5.0) 97 | ).toDF.collect().toIndexedSeq 98 | 99 | assert(source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, 0.0000000002))) 100 | } 101 | 102 | "check unequal Seq[Row] with tolerance" in { 103 | val source = Seq( 104 | ("juan", 12.00000000002), 105 | ("bob", 1.00000000002), 106 | ("li", 49.00000000002), 107 | ("alice", 5.00000000002) 108 | ).toDF.collect().toSeq 109 | 110 | val expected = Seq( 111 | ("juan", 12), 112 | ("bob", 1), 113 | ("li", 49), 114 | ("alice", 5) 115 | ).toDF.collect() 116 | 117 | assert(!source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, .00000000001))) 118 | } 119 | 120 | "check unequal indexedSeq[Row] with tolerance" in { 121 | val source = Seq( 122 | ("juan", 12.00000000002), 123 | ("bob", 1.00000000002), 124 | ("li", 49.00000000002), 125 | ("alice", 5.00000000002) 126 | ).toDF.collect().toIndexedSeq 127 | 128 | val expected = Seq( 129 | ("juan", 12), 130 | ("bob", 1), 131 | ("li", 49), 132 | ("alice", 5) 133 | ).toDF.collect().toIndexedSeq 134 | 135 | assert(!source.approximateSameElements(expected, RowComparer.areRowsEqual(_, _, .00000000001))) 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/SparkSessionExt.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions 4 | import org.apache.spark.sql.{DataFrame, Row, SparkSession} 5 | import org.apache.spark.sql.types.{DataType, StructField, StructType} 6 | 7 | object SparkSessionExt { 8 | 9 | implicit class SparkSessionMethods(spark: SparkSession) { 10 | 11 | private def asSchema[U](fields: List[U]): List[StructField] = { 12 | fields.map { 13 | case x: StructField => x 14 | case (name: String, dataType: DataType, nullable: Boolean) => 15 | StructField(name, dataType, nullable) 16 | } 17 | } 18 | 19 | /** 20 | * Creates a DataFrame, similar to createDataFrame, but with better syntax. spark-daria defined a createDF method that allows for the terse 21 | * syntax of `toDF` and the control of `createDataFrame`. 22 | * 23 | * spark.createDF( List( ("bob", 45), ("liz", 25), ("freeman", 32) ), List( ("name", StringType, true), ("age", IntegerType, false) ) ) 24 | * 25 | * The `createDF` method can also be used with lists of `Row` and `StructField` objects.
26 | * 27 | * spark.createDF( List( Row("bob", 45), Row("liz", 25), Row("freeman", 32) ), List( StructField("name", StringType, true), StructField("age", 28 | * IntegerType, false) ) ) 29 | */ 30 | def createDF[U, T](rowData: List[U], fields: List[T]): DataFrame = { 31 | spark.createDataFrame( 32 | spark.sparkContext.parallelize(rowData.asRows), 33 | StructType(asSchema(fields)) 34 | ) 35 | } 36 | 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/SparkSessionTestWrapper.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import org.apache.spark.sql.SparkSession 4 | 5 | trait SparkSessionTestWrapper { 6 | 7 | lazy val spark: SparkSession = { 8 | val session = SparkSession 9 | .builder() 10 | .master("local") 11 | .appName("spark session") 12 | .config("spark.sql.shuffle.partitions", "1") 13 | .getOrCreate() 14 | session.sparkContext.setLogLevel("ERROR") 15 | session 16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /core/src/test/scala/com/github/mrpowers/spark/fast/tests/TestUtilsExt.scala: -------------------------------------------------------------------------------- 1 | package com.github.mrpowers.spark.fast.tests 2 | 3 | import scala.util.matching.Regex 4 | 5 | object TestUtilsExt { 6 | val coloredStringPattern: Regex = "(\u001B\\[\\d{1,2}m)([\\s\\S]*?)(?=\u001B\\[\\d{1,2}m)".r 7 | implicit class StringOps(s: String) { 8 | def extractColorGroup: Map[String, Seq[String]] = coloredStringPattern 9 | .findAllMatchIn(s) 10 | .map(m => (m.group(1), m.group(2))) 11 | .toSeq 12 | .groupBy(_._1) 13 | .mapValues(_.map(_._2)) 14 | .view 15 | .toMap 16 | 17 | def assertColorDiff(actual: Seq[String], expected: Seq[String]): Unit = { 18 | val colourGroup = extractColorGroup 19 | val expectedColourGroup = colourGroup.get(Console.GREEN) 20 | val actualColourGroup = colourGroup.get(Console.RED) 21 | assert(expectedColourGroup.contains(expected)) 22 | assert(actualColourGroup.contains(actual)) 23 | } 24 | } 25 | 26 | implicit class ExceptionOps(e: Exception) { 27 | def assertColorDiff(actual: Seq[String], expected: Seq[String]): Unit = e.getMessage.assertColorDiff(actual, expected) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /docs/about/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /images/assertColumnEquality_error_message.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertColumnEquality_error_message.png -------------------------------------------------------------------------------- /images/assertSchemaEquality_tree_message.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSchemaEquality_tree_message.png -------------------------------------------------------------------------------- /images/assertSmallDataFrameEquality_DatasetContentMissmatch_message.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDataFrameEquality_DatasetContentMissmatch_message.png -------------------------------------------------------------------------------- /images/assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDataFrameEquality_DatasetSchemaMisMatch_message.png -------------------------------------------------------------------------------- /images/assertSmallDataFrameEquality_error_message.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDataFrameEquality_error_message.png -------------------------------------------------------------------------------- /images/assertSmallDatasetEquality_error_message.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrpowers-io/spark-fast-tests/ea31bc9f2563069ae95865494fb26706bc81bc60/images/assertSmallDatasetEquality_error_message.png -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.10.1 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn 2 | 3 | resolvers += Resolver.bintrayIvyRepo("s22s", "sbt-plugins") 4 | 5 | resolvers += Resolver.typesafeRepo("releases") 6 | 7 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") 8 | 9 | addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.9.0") 10 | 11 | addSbtPlugin("org.typelevel" % "laika-sbt" % "1.2.0") 12 | 13 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.3") -------------------------------------------------------------------------------- /scripts/multi_spark_releases.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # link to the spark-daria GitHub release script: https://github.com/MrPowers/spark-daria/blob/master/scripts/github_release.sh 4 | # clone the spark-daria repo and pass the path to its release script as the first argument to this script 5 | # usage: ./scripts/multi_spark_releases.sh /path/to/spark-daria/scripts/github_release.sh 6 | 7 | SPARK_DARIA_GITHUB_RELEASE=$1 8 | if [ "$SPARK_DARIA_GITHUB_RELEASE" = "" ] 9 | then 10 | echo "spark-daria github_release script path must be set" 11 | exit 1 12 | fi 13 | 14 | for sparkVersion in 2.2.0 2.2.1 2.3.0; do 15 | echo "$sparkVersion" 16 | # note: sed -i '' is the BSD/macOS form of in-place editing; GNU sed expects plain -i 17 | sed -i '' "s/^val sparkVersion.*/val sparkVersion = \"$sparkVersion\"/" build.sbt 18 | "$SPARK_DARIA_GITHUB_RELEASE" package 19 | done --------------------------------------------------------------------------------
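For quick orientation, a condensed sketch of the SchemaComparer call that the tree-diff tests above exercise. The object name SchemaTreeDiffExample is hypothetical; the assertSchemaEqual signature and the SchemaDiffOutputFormat.Tree flag are taken verbatim from SchemaComparerTest, and the call is expected to throw DatasetSchemaMismatch with the side-by-side rendering asserted there.

package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SchemaTreeDiffExample extends App {
  val actual   = StructType(Seq(StructField("name", StringType, true)))
  val expected = StructType(Seq(StructField("name", IntegerType, true)))

  // The type mismatch makes this throw DatasetSchemaMismatch; the message shows
  // "Actual Schema" and "Expected Schema" trees with red/green ANSI highlighting,
  // as asserted by the expected messages in SchemaComparerTest above.
  SchemaComparer.assertSchemaEqual(actual, expected, ignoreColumnOrder = false, outputFormat = SchemaDiffOutputFormat.Tree)
}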
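And a minimal sketch of how the test helpers in this directory compose in a spec. CreateDFUsageSpec is a hypothetical name; SparkSessionTestWrapper and SparkSessionExt.createDF are defined verbatim above, while assertSmallDataFrameEquality is assumed to come from the core DataFrameComparer trait with the (actual, expected) signature that the error-message screenshots in images/ suggest.

package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.types.{IntegerType, StringType}
import org.scalatest.freespec.AnyFreeSpec
import SparkSessionExt._

class CreateDFUsageSpec extends AnyFreeSpec with SparkSessionTestWrapper with DataFrameComparer {

  "createDF builds two identical DataFrames that compare equal" in {
    // terse (value, ...) rows plus (name, type, nullable) field tuples, per the scaladoc above
    val actual = spark.createDF(
      List(("bob", 45), ("liz", 25)),
      List(("name", StringType, true), ("age", IntegerType, false))
    )
    val expected = spark.createDF(
      List(("bob", 45), ("liz", 25)),
      List(("name", StringType, true), ("age", IntegerType, false))
    )
    assertSmallDataFrameEquality(actual, expected) // assumed to throw on schema or content mismatch
  }
}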