├── pysparkler ├── tests │ ├── __init__.py │ ├── sample │ │ ├── config.yaml │ │ ├── input_pyspark.py │ │ └── InputPySparkNotebook.ipynb │ ├── conftest.py │ ├── test_cli.py │ ├── test_base.py │ ├── test_api.py │ ├── test_sql_21_to_33.py │ ├── test_pyspark_32_to_33.py │ ├── test_pyspark_23_to_24.py │ └── test_pyspark_31_to_32.py ├── .gitignore ├── Makefile ├── bump_version.sh ├── pysparkler │ ├── __init__.py │ ├── pyspark_32_to_33.py │ ├── pyspark_31_to_32.py │ ├── pyspark_23_to_24.py │ └── api.py ├── pyproject.toml └── .pre-commit-config.yaml ├── project └── build.properties ├── sql ├── src │ └── sparksql_upgrade │ │ ├── plugin_default_config.cfg │ │ ├── __init__.py │ │ └── plugin.py ├── MANIFEST.in ├── requirements.txt ├── test │ └── rules │ │ ├── test_cases │ │ ├── SPARK_SQL_CAST_CHANGE.yml │ │ ├── SPARK_SQL_APPROX_PERCENTILE.yml │ │ ├── SPARK_SQL_EXTRACT_SECOND.yml │ │ └── SPARK_SQL_RESERVED_PROPERTIES.yml │ │ └── rule_test_cases_test.py └── setup.py ├── scalafix ├── project │ ├── build.properties │ └── plugins.sbt ├── scalafix ├── output │ └── src │ │ └── main │ │ └── scala │ │ └── fix │ │ ├── SparkAutoUpgrade.scala │ │ ├── FunkyTest.scala │ │ ├── ScalaTestExtendsFix2.scala │ │ ├── AllEquivalentExprsTest.scala │ │ ├── DontMigrateTrigger.scala │ │ ├── MigrateTrigger.scala │ │ ├── HiveContextRenamed.scala │ │ ├── Accumulator.scala │ │ ├── ExpressionEncoder.scala │ │ ├── UnionRewrite.scala │ │ ├── HiveContext.scala │ │ ├── SparkSQLCallExternal.scala │ │ ├── OldReaderAddImports.scala │ │ ├── OldReader.scala │ │ ├── GroupByKeyRewrite.scala │ │ ├── SQLContextConstructor.scala │ │ └── GroupByKeyRenameColumnQQ.scala ├── input │ └── src │ │ └── main │ │ ├── scala │ │ └── fix │ │ │ ├── SparkAutoUpgrade.scala │ │ │ ├── ScalaTestExtendsFix2.scala │ │ │ ├── AllEquivalentExprsTest.scala │ │ │ ├── DontMigrateTrigger.scala │ │ │ ├── MigrateTrigger.scala │ │ │ ├── HiveContextRenamed.scala │ │ │ ├── ExpressionEncoder.scala │ │ │ ├── Accumulator.scala │ │ │ ├── HiveContext.scala │ │ │ ├── MetadataWarnQQ.scala │ │ │ ├── SparkSQLCallExternal.scala │ │ │ ├── MultiLineDatasetReadWarn.scala │ │ │ ├── UnionRewrite.scala │ │ │ ├── GroupByKeyWarn.scala │ │ │ ├── GroupByKeyRewrite.scala │ │ │ ├── SQLContextConstructor.scala │ │ │ └── GroupByKeyRenameColumnQQ.scala │ │ └── scala-2.12 │ │ └── fix │ │ ├── FunkyTest.scala │ │ ├── ExecutorPluginWarn.scala │ │ ├── OldReaderAddImports.scala │ │ └── OldReader.scala ├── tests │ └── src │ │ └── test │ │ └── scala │ │ └── fix │ │ └── RuleSuite.scala ├── .scalafix-warn.conf ├── rules │ └── src │ │ └── main │ │ ├── scala-2.11 │ │ └── fix │ │ │ ├── ExecutorPluginWarn.scala │ │ │ └── UnionRewrite.scala │ │ ├── scala │ │ └── fix │ │ │ ├── SparkAutoUpgrade.scala │ │ │ ├── ScalaTestExtendsFix.scala │ │ │ ├── IsRunningLocally.scala │ │ │ ├── AllEquivalentExprs.scala │ │ │ ├── OnFailureFix.scala │ │ │ ├── MigrateDeprecatedDataFrameReaderFuns.scala │ │ │ ├── MigrateTrigger.scala │ │ │ ├── MultiLineDatasetReadWarn.scala │ │ │ ├── ExpressionEncoder.scala │ │ │ ├── MigrateToSparkSessionBuilder.scala │ │ │ ├── MetadataWarnQQ.scala │ │ │ ├── GroupByKeyWarn.scala │ │ │ ├── SparkSQLCallExternal.scala │ │ │ ├── ScalaTestImportChange.scala │ │ │ ├── AccumulatorUpgrade.scala │ │ │ ├── MigrateHiveContext.scala │ │ │ ├── Utils.scala │ │ │ └── GroupByKeyRenameColumnQQ.scala │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── scalafix.v1.Rule │ │ ├── scala-2.12 │ │ └── fix │ │ │ ├── ExecutorPluginWarn.scala │ │ │ └── UnionRewrite.scala │ │ └── scala-2.13 │ │ └── fix │ │ ├── 
ExecutorPluginWarn.scala │ │ └── UnionRewrite.scala ├── .scalafix.conf ├── readme.md ├── build.sbt └── build │ ├── sbt │ └── sbt-launch-lib.bash ├── iceberg-spark-upgrade-wap-plugin ├── .gitignore ├── project │ ├── build.properties │ └── Dependencies.scala ├── src │ ├── main │ │ └── scala │ │ │ └── com │ │ │ └── holdenkarau │ │ │ └── spark │ │ │ └── upgrade │ │ │ └── wap │ │ │ └── plugin │ │ │ ├── Agent.scala │ │ │ └── IcebergListener.scala │ └── test │ │ └── scala │ │ └── com │ │ └── holdenkarau │ │ └── spark │ │ └── upgrade │ │ └── wap │ │ └── plugin │ │ └── WAPIcebergSpec.scala └── build.sbt ├── pipelinecompare ├── requirements.txt ├── utils.py ├── README.md └── spark_utils.py ├── e2e_demo └── scala │ ├── sparkdemoproject │ ├── project │ │ ├── build.properties │ │ └── plugins.sbt │ ├── settings.gradle │ ├── settings.gradle-33-example │ ├── gradle.properties │ ├── .travis.yml │ ├── src │ │ ├── test │ │ │ └── scala │ │ │ │ └── com │ │ │ │ └── holdenkarau │ │ │ │ └── sparkdemoproject │ │ │ │ └── WordCountTest.scala │ │ └── main │ │ │ └── scala │ │ │ └── com │ │ │ └── holdenkarau │ │ │ └── sparkdemoproject │ │ │ ├── CountingApp.scala │ │ │ └── WordCount.scala │ ├── build.gradle │ ├── build.gradle-33-example │ ├── build.gradle.scalafix │ ├── build.sbt │ └── .gitignore │ ├── python_check.py │ ├── cleanup.sh │ ├── update_gradle_settings.py │ ├── update_gradle_build.py │ └── dl_dependencies.sh ├── conf_migrate ├── README.md └── migrate.py ├── SECURITY.md ├── .gitignore ├── docs └── scala │ ├── sbt.md │ └── gradle.md ├── .github └── workflows │ ├── release.yml │ └── github-actions-basic.yml └── README.md /pysparkler/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.10.2 2 | -------------------------------------------------------------------------------- /sql/src/sparksql_upgrade/plugin_default_config.cfg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scalafix/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.10.2 2 | -------------------------------------------------------------------------------- /pysparkler/.gitignore: -------------------------------------------------------------------------------- 1 | **/.pytest_cache/ 2 | **/__pycache__/ 3 | -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/.gitignore: -------------------------------------------------------------------------------- 1 | /.bsp/ 2 | target/ 3 | -------------------------------------------------------------------------------- /sql/src/sparksql_upgrade/__init__.py: -------------------------------------------------------------------------------- 1 | """Example sqlfluff plugin.""" 2 | -------------------------------------------------------------------------------- /pipelinecompare/requirements.txt: -------------------------------------------------------------------------------- 1 | iceberg 2 | lakefs_client 3 | colorama 4 | -------------------------------------------------------------------------------- /sql/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include 
src/sparksql_upgrade/plugin_default_config.cfg 2 | -------------------------------------------------------------------------------- /sql/requirements.txt: -------------------------------------------------------------------------------- 1 | sqlfluff==2.3.2 2 | flake8 3 | pytest 4 | autopep8 5 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.10.2 2 | -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.10.2 2 | -------------------------------------------------------------------------------- /scalafix/scalafix: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holdenk/spark-upgrade/HEAD/scalafix/scalafix -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'sparkdemoproject_2.12' 2 | include('src') -------------------------------------------------------------------------------- /e2e_demo/scala/python_check.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info < (3, 9): 4 | sys.exit("Please use Python 3.9+") 5 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/settings.gradle-33-example: -------------------------------------------------------------------------------- 1 | rootProject.name = 'sparkdemoproject-3.3_2.12' 2 | include('src') 3 | -------------------------------------------------------------------------------- /conf_migrate/README.md: -------------------------------------------------------------------------------- 1 | Tool to migrate Spark configuration such that legacy behaviour is (for the most part) maintained when possible. 2 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/SparkAutoUpgrade.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | object SparkAutoUpgrade { 4 | // Add code that needs fixing here. 
5 | } 6 | -------------------------------------------------------------------------------- /pysparkler/tests/sample/config.yaml: -------------------------------------------------------------------------------- 1 | pysparkler: 2 | dry_run: False 3 | PY24-30-001: 4 | comment: A new comment 5 | PY24-30-002: 6 | enabled: False -------------------------------------------------------------------------------- /e2e_demo/scala/cleanup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | rm -rf sparkdemoproject-3 6 | rm -rf ../../pipelinecompare/warehouse/* 7 | rm -f /tmp/spark-migration-jars/* 8 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/gradle.properties: -------------------------------------------------------------------------------- 1 | scalaVersion=2.12.13 2 | sparkVersion=2.4.8 3 | org.gradle.jvmargs=-Xms512M -Xmx2048M -XX:MaxPermSize=2048M -XX:+CMSClassUnloadingEnabled -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/SparkAutoUpgrade.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule = SparkAutoUpgrade 3 | */ 4 | package fix 5 | 6 | object SparkAutoUpgrade { 7 | // Add code that needs fixing here. 8 | } 9 | -------------------------------------------------------------------------------- /e2e_demo/scala/update_gradle_settings.py: -------------------------------------------------------------------------------- 1 | import re,sys; 2 | 3 | print( 4 | re.sub(r"rootProject.name\s*=\s*[\'\"](.*?)(-3)?(_2.1[12])?[\"\']", "rootProject.name = \"\\1-3\\3\"", sys.stdin.read())) 5 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/FunkyTest.scala: -------------------------------------------------------------------------------- 1 | import org.scalatest.matchers.should.Matchers._ 2 | import org.scalatest.funsuite.AnyFunSuite 3 | 4 | class OldTest extends AnyFunSuite { val a = 1 } 5 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/ScalaTestExtendsFix2.scala: -------------------------------------------------------------------------------- 1 | trait Farts { 2 | } 3 | 4 | trait AnyFunSuite { 5 | } 6 | 7 | class OldTest2 extends AnyFunSuite with Farts { 8 | val a = 1 9 | } 10 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala-2.12/fix/FunkyTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=ScalaTestImportChange 3 | */ 4 | import org.scalatest.Matchers._ 5 | import org.scalatest.FunSuite 6 | 7 | class OldTest extends FunSuite { val a = 1 } 8 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/ScalaTestExtendsFix2.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=ScalaTestExtendsFix 3 | */ 4 | trait Farts { 5 | } 6 | 7 | trait FunSuite { 8 | } 9 | 10 | class OldTest2 extends FunSuite with Farts { 11 | val a = 1 12 | } 13 | -------------------------------------------------------------------------------- /conf_migrate/migrate.py: -------------------------------------------------------------------------------- 1 | legacy_apped_rules = { 2 | # SQL - 
https://spark.apache.org/docs/3.0.0/sql-migration-guide.html 3 | "spark.sql.storeAssignmentPolicy": "Legacy", 4 | "spark.sql.legacy.setCommandRejectsSparkCoreConfs": "false", 5 | } 6 | -------------------------------------------------------------------------------- /scalafix/tests/src/test/scala/fix/RuleSuite.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import scalafix.testkit.AbstractSemanticRuleSuite 3 | import org.scalatest.FunSuiteLike 4 | 5 | class RuleSuite extends AbstractSemanticRuleSuite with FunSuiteLike { 6 | runAllTests() 7 | } 8 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/AllEquivalentExprsTest.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions 2 | 3 | object EETest { 4 | def boop(e: EquivalentExpressions) = { 5 | e.getCommonSubexpressions.map(List(_)) 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/DontMigrateTrigger.scala: -------------------------------------------------------------------------------- 1 | import scala.concurrent.duration._ 2 | import org.apache.spark._ 3 | import org.apache.spark.sql.streaming._ 4 | 5 | object DontMigrateTrigger { 6 | def boop(): Unit = { 7 | val sc = new SparkContext() 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /pysparkler/Makefile: -------------------------------------------------------------------------------- 1 | install: 2 | pip install poetry 3 | poetry install 4 | 5 | lint: 6 | poetry run pre-commit run --all-files 7 | 8 | test: 9 | poetry run pytest tests/ ${PYTEST_ARGS} 10 | 11 | publish: 12 | pip install poetry 13 | ./bump_version.sh 14 | poetry publish --build 15 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/AllEquivalentExprsTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=AllEquivalentExprs 3 | */ 4 | import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions 5 | 6 | object EETest { 7 | def boop(e: EquivalentExpressions) = { 8 | e.getAllEquivalentExprs 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala-2.12/fix/ExecutorPluginWarn.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule = ExecutorPluginWarn 3 | */ 4 | 5 | import org.apache.spark.ExecutorPlugin // assert: ExecutorPluginWarn 6 | 7 | class TestExecutorPlugin() extends ExecutorPlugin { // assert: ExecutorPluginWarn 8 | override def shutdown() = { 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/DontMigrateTrigger.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateTrigger 3 | */ 4 | import scala.concurrent.duration._ 5 | import org.apache.spark._ 6 | import org.apache.spark.sql.streaming._ 7 | 8 | object DontMigrateTrigger { 9 | def boop(): Unit = { 10 | val sc = new SparkContext() 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 
# Security Policy 2 | 3 | ## Supported Versions 4 | 5 | Only the latest version has security support. 6 | 7 | ## Reporting a Vulnerability 8 | 9 | To report a security vulnerability, please use the 10 | [Tidelift security contact](https://tidelift.com/security). 11 | Tidelift will coordinate the fix and disclosure. 12 | -------------------------------------------------------------------------------- /sql/test/rules/test_cases/SPARK_SQL_CAST_CHANGE.yml: -------------------------------------------------------------------------------- 1 | rule: SPARKSQLCAST_L001 2 | 3 | cast_as_int: 4 | configs: 5 | core: 6 | dialect: sparksql 7 | fail_str: | 8 | select 9 | cast(a as int), 10 | cast(b as int) 11 | from tbl 12 | fix_str: | 13 | select 14 | int(a), 15 | int(b) 16 | from tbl 17 | -------------------------------------------------------------------------------- /scalafix/.scalafix-warn.conf: -------------------------------------------------------------------------------- 1 | rules = [ 2 | GroupByKeyWarn, 3 | MetadataWarnQQ, 4 | MultiLineDatasetReadWarn 5 | ] 6 | UnionRewrite.deprecatedMethod { 7 | "unionAll" = "union" 8 | } 9 | 10 | OrganizeImports { 11 | blankLines = Auto 12 | groups = [ 13 | "re:javax?\\." 14 | "scala." 15 | "org.apache.spark." 16 | "*" 17 | ] 18 | } -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/MigrateTrigger.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateTrigger 3 | */ 4 | import scala.concurrent.duration._ 5 | import org.apache.spark._ 6 | import org.apache.spark.sql.streaming._ 7 | 8 | object MigrateTrigger { 9 | def boop(): Unit = { 10 | val sc = new SparkContext() 11 | val trigger = ProcessingTime(1.second) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /scalafix/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += Resolver.sonatypeRepo("releases") 2 | addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.10.4") 3 | 4 | addDependencyTreePlugin 5 | 6 | ThisBuild / libraryDependencySchemes ++= Seq( 7 | "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always 8 | ) 9 | 10 | addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.11") 11 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/MigrateTrigger.scala: -------------------------------------------------------------------------------- 1 | import scala.concurrent.duration._ 2 | import org.apache.spark._ 3 | import org.apache.spark.sql.streaming._ 4 | import org.apache.spark.sql.streaming.Trigger._ 5 | 6 | object MigrateTrigger { 7 | def boop(): Unit = { 8 | val sc = new SparkContext() 9 | val trigger = ProcessingTime(1.second) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/HiveContextRenamed.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.sql._ 3 | 4 | object BadHiveContextMagic2 { 5 | def hiveContextFunc(sc: SparkContext): SQLContext = { 6 | val hiveContext1 = SparkSession.builder.enableHiveSupport().getOrCreate().sqlContext 7 | import hiveContext1.implicits._ 8 | hiveContext1 9 | } 10 | } 11 | -------------------------------------------------------------------------------- 
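As a minimal sketch of exercising the SPARKSQLCAST_L001 rule shown in the test case above (assuming the sparksql-upgrade plugin under sql/ has been pip-installed alongside sqlfluff 2.3.2, and that query.sql is a hypothetical file containing the failing CAST expression):

```bash
# Install the sqlfluff plugin from the repository root, then rewrite cast(x as int) to int(x) in a Spark SQL file.
pip install ./sql
sqlfluff fix --dialect sparksql --rules SPARKSQLCAST_L001 query.sql
```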
/scalafix/rules/src/main/scala-2.11/fix/ExecutorPluginWarn.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class ExecutorPluginWarn extends SemanticRule("ExecutorPluginWarn") { 7 | // Executor plugin does not exist in early versions of Spark so skip the rule. 8 | 9 | override def fix(implicit doc: SemanticDocument): Patch = { 10 | None.asPatch 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | 3 | # These directories are cached to S3 at the end of the build 4 | cache: 5 | directories: 6 | - $HOME/.ivy2/cache 7 | - $HOME/.sbt/boot/ 8 | - $HOME/.sbt/launchers 9 | - $HOME/build 10 | 11 | jdk: 12 | - oraclejdk8 13 | scala: 14 | - 2.12.13 15 | after_success: 16 | - bash <(curl -s https://codecov.io/bash) 17 | sudo: false -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") 2 | 3 | resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" 4 | 5 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.1.1") 6 | 7 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.6.0") 8 | 9 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.15.0") 10 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/HiveContextRenamed.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateHiveContext 3 | */ 4 | import org.apache.spark._ 5 | import org.apache.spark.sql._ 6 | import org.apache.spark.sql.hive.{HiveContext => HiveCtx} 7 | 8 | object BadHiveContextMagic2 { 9 | def hiveContextFunc(sc: SparkContext): HiveCtx = { 10 | val hiveContext1 = new HiveCtx(sc) 11 | import hiveContext1.implicits._ 12 | hiveContext1 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /pipelinecompare/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from colorama import init as colorama_init 3 | from colorama import Fore 4 | from colorama import Style 5 | 6 | colorama_init() 7 | 8 | def eprint(*args, **kwargs): 9 | print(Fore.RED, file=sys.stderr) 10 | print(*args, file=sys.stderr, **kwargs) 11 | print(Style.RESET_ALL, file=sys.stderr) 12 | 13 | 14 | def error(*args, **kwargs): 15 | eprint(*args, **kwargs) 16 | raise Exception(*args) 17 | -------------------------------------------------------------------------------- /pysparkler/bump_version.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -efx 2 | 3 | if [ "$1" == '--no-op' ] 4 | then 5 | echo='echo' 6 | else 7 | echo= 8 | fi 9 | 10 | PYSPARKLER_VERSION=$(poetry version --short --dry-run); 11 | 12 | # Check if version is development release 13 | if [[ ${PYSPARKLER_VERSION} =~ dev.* ]]; then 14 | # Append epoch time to dev version 15 | PYSPARKLER_VERSION="${PYSPARKLER_VERSION}$(date +'%s')" 16 | fi 17 | 18 | ${echo} poetry version "${PYSPARKLER_VERSION}" 19 | -------------------------------------------------------------------------------- 
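The HiveContextRenamed input/output pair above shows the rewrite that the MigrateHiveContext rule produces. As a rough sketch, on a project with sbt-scalafix and the spark-scalafix-rules artifact wired in as described in docs/scala/sbt.md, the rule could be invoked by name, or every rule listed in .scalafix.conf could be run together:

```bash
# Apply a single named rewrite through the sbt-scalafix plugin (assumes semanticdb is enabled for the build).
sbt "scalafix MigrateHiveContext"
# Or apply every rule configured in .scalafix.conf.
sbt scalafix
```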
/scalafix/output/src/main/scala/fix/Accumulator.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | 3 | object BadAcc { 4 | def boop(): Unit = { 5 | val sc = new SparkContext() 6 | val num = 0 7 | val numAcc = /*sc.accumulator(num)*/ null 8 | val litAcc = /*sc.accumulator(0)*/ null 9 | val litLongAcc = sc.longAccumulator 10 | val namedAcc = /*sc.accumulator(0, "cheese")*/ null 11 | val litDoubleAcc = sc.doubleAccumulator 12 | val rdd = sc.parallelize(List(1,2,3)) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/SparkAutoUpgrade.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | 5 | import scala.meta._ 6 | 7 | class SparkAutoUpgrade extends SemanticRule("SparkAutoUpgrade") { 8 | override def fix(implicit doc: SemanticDocument): Patch = { 9 | // println("Tree.syntax: " + doc.tree.syntax) 10 | // println("Tree.structure: " + doc.tree.structure) 11 | // println("Tree.structureLabeled: " + doc.tree.structureLabeled) 12 | 13 | Patch.empty 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/ExpressionEncoder.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 3 | 4 | object BadEncoderEx { 5 | def boop() = { 6 | // Round trip with toRow and fromRow. 7 | val stringEncoder = ExpressionEncoder[String] 8 | val intEncoder = ExpressionEncoder[Int] 9 | val row = stringEncoder.createSerializer()("hello world") 10 | val decoded = stringEncoder.createDeserializer()(row) 11 | val intRow = intEncoder.createSerializer()(1) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/ExpressionEncoder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=ExpressionEncoder 3 | */ 4 | import org.apache.spark._ 5 | import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 6 | 7 | object BadEncoderEx { 8 | def boop() = { 9 | // Round trip with toRow and fromRow. 
10 | val stringEncoder = ExpressionEncoder[String] 11 | val intEncoder = ExpressionEncoder[Int] 12 | val row = stringEncoder.toRow("hello world") 13 | val decoded = stringEncoder.fromRow(row) 14 | val intRow = intEncoder.toRow(1) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /sql/test/rules/rule_test_cases_test.py: -------------------------------------------------------------------------------- 1 | """Runs the rule test cases.""" 2 | import os 3 | 4 | import pytest 5 | 6 | from sqlfluff.utils.testing.rules import load_test_cases, rules__test_helper 7 | 8 | ids, test_cases = load_test_cases( 9 | test_cases_path=os.path.join( 10 | os.path.abspath(os.path.dirname(__file__)), "test_cases", "*.yml" 11 | ) 12 | ) 13 | 14 | 15 | @pytest.mark.parametrize("test_case", test_cases, ids=ids) 16 | def test__rule_test_case(test_case): 17 | """Run the tests.""" 18 | rules__test_helper(test_case) 19 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/UnionRewrite.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import org.apache.spark.sql.{DataFrame, Dataset} 3 | 4 | object UnionRewrite { 5 | def inSource( 6 | df1: DataFrame, 7 | df2: DataFrame, 8 | df3: DataFrame, 9 | ds1: Dataset[String], 10 | ds2: Dataset[String] 11 | ): Unit = { 12 | val res1 = df1.union(df2) 13 | val res2 = df1.union(df2).union(df3) 14 | val res3 = Seq(df1, df2, df3).reduce(_ union _) 15 | val res4 = ds1.union(ds2) 16 | val res5 = Seq(ds1, ds2).reduce(_ union _) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/HiveContext.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.sql._ 3 | 4 | object BadHiveContextMagic { 5 | def hiveContextFunc(sc: SparkContext): SQLContext = { 6 | val hiveContext1 = SparkSession.builder.enableHiveSupport().getOrCreate().sqlContext 7 | import hiveContext1.implicits._ 8 | hiveContext1 9 | } 10 | 11 | def makeSparkConf() = { 12 | val sparkConf = new SparkConf(true) 13 | sparkConf 14 | } 15 | 16 | def throwSomeCrap() = { 17 | throw new RuntimeException("mr farts!") 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /sql/test/rules/test_cases/SPARK_SQL_APPROX_PERCENTILE.yml: -------------------------------------------------------------------------------- 1 | rule: SPARKSQL_L005 2 | 3 | approx_percent: 4 | configs: 5 | core: 6 | dialect: sparksql 7 | fail_str: | 8 | SELECT approx_percentile(col, array(0.5, 0.4, 0.1), temp); 9 | fix_str: | 10 | SELECT approx_percentile(col, array(0.5, 0.4, 0.1), cast(temp as int)); 11 | 12 | percent_approx: 13 | configs: 14 | core: 15 | dialect: sparksql 16 | fail_str: | 17 | SELECT percentile_approx(col, 0.2, temp); 18 | fix_str: | 19 | SELECT percentile_approx(col, 0.2, cast(temp as int)); 20 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/SparkSQLCallExternal.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.sql._ 3 | 4 | object OldQuery { 5 | def doQuery(s: SparkSession) { 6 | // We should be able to rewrite this one 7 | s.sql("""select 8 | int(a), 9 | int(b)from fart_tbl 10 | """) 11 | // We can't auto 
rewrite this :( easily. 12 | val q = "SELECT * FROM FARTS LIMIT 1" 13 | s.sql(q) 14 | // we should not change this 15 | fart("magic farts") 16 | } 17 | 18 | def fart(str: String) = { 19 | println(s"Fart ${str}") 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /scalafix/.scalafix.conf: -------------------------------------------------------------------------------- 1 | rules = [ 2 | UnionRewrite, 3 | AccumulatorUpgrade, 4 | ScalaTestImportChange, 5 | GroupByKeyRewrite, 6 | MigrateHiveContext, 7 | MigrateTrigger, 8 | MigrateDeprecatedDataFrameReaderFuns, 9 | ScalaTestExtendsFix, 10 | MigrateToSparkSessionBuilder, 11 | GroupByKeyRenameColumnQQ, 12 | ExpressionEncoder, 13 | ] 14 | UnionRewrite.deprecatedMethod { 15 | "unionAll" = "union" 16 | } 17 | 18 | OrganizeImports { 19 | blankLines = Auto 20 | groups = [ 21 | "re:javax?\\." 22 | "scala." 23 | "org.apache.spark." 24 | "*" 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/resources/META-INF/services/scalafix.v1.Rule: -------------------------------------------------------------------------------- 1 | fix.SparkAutoUpgrade 2 | fix.MigrateHiveContext 3 | fix.MigrateToSparkSessionBuilder 4 | fix.MigrateDeprecatedDataFrameReaderFuns 5 | fix.AccumulatorUpgrade 6 | fix.OnFailureFix 7 | fix.ExecutorPluginWarn 8 | fix.UnionRewrite 9 | fix.GroupByKeyWarn 10 | fix.GroupByKeyRewrite 11 | fix.MetadataWarnQQ 12 | fix.ScalaTestImportChange 13 | fix.ScalaTestExtendsFix 14 | fix.MigrateTrigger 15 | fix.GroupByKeyRenameColumnQQ 16 | fix.MultiLineDatasetReadWarn 17 | fix.ExpressionEncoder 18 | fix.SparkSQLCallExternal 19 | fix.AllEquivalentExprs -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | 3 | object Dependencies { 4 | lazy val scalaTest = "org.scalatest" %% "scalatest" % "3.2.11" 5 | lazy val iceberg = "org.apache.iceberg" % "iceberg-core" % "0.9.1" 6 | lazy val logback = "ch.qos.logback" % "logback-classic" % "1.2.10" 7 | lazy val scalaLogging = "com.typesafe.scala-logging" %% "scala-logging" % "3.9.4" 8 | lazy val sparkTestingBase = "com.holdenkarau" %% "spark-testing-base" % "3.2.2_1.3.4" 9 | lazy val icebergSparkRuntime = "org.apache.iceberg" %% "iceberg-spark-runtime-3.2" % "1.1.0" 10 | } 11 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/Accumulator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=AccumulatorUpgrade 3 | */ 4 | import org.apache.spark._ 5 | 6 | object BadAcc { 7 | def boop(): Unit = { 8 | val sc = new SparkContext() 9 | val num = 0 10 | val numAcc = sc.accumulator(num)// assert: AccumulatorUpgrade 11 | val litAcc = sc.accumulator(0)// assert: AccumulatorUpgrade 12 | val litLongAcc = sc.accumulator(0L) 13 | val namedAcc = sc.accumulator(0, "cheese")// assert: AccumulatorUpgrade 14 | val litDoubleAcc = sc.accumulator(0.0) 15 | val rdd = sc.parallelize(List(1,2,3)) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/HiveContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateHiveContext 3 | */ 4 | import org.apache.spark._ 5 | import 
org.apache.spark.sql._ 6 | import org.apache.spark.sql.hive.HiveContext 7 | 8 | object BadHiveContextMagic { 9 | def hiveContextFunc(sc: SparkContext): HiveContext = { 10 | val hiveContext1 = new HiveContext(sc) 11 | import hiveContext1.implicits._ 12 | hiveContext1 13 | } 14 | 15 | def makeSparkConf() = { 16 | val sparkConf = new SparkConf(true) 17 | sparkConf 18 | } 19 | 20 | def throwSomeCrap() = { 21 | throw new RuntimeException("mr farts!") 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/src/main/scala/com/holdenkarau/spark/upgrade/wap/plugin/Agent.scala: -------------------------------------------------------------------------------- 1 | package com.holdenkarau.spark.upgrade.wap.plugin 2 | 3 | import java.lang.instrument.Instrumentation; 4 | 5 | import org.apache.iceberg.events.{CreateSnapshotEvent, Listeners} 6 | 7 | object Agent { 8 | def premain(agentOps: String, inst: Instrumentation): Unit = { 9 | registerListener() 10 | } 11 | def agentmain(agentOps: String, inst: Instrumentation): Unit = { 12 | registerListener() 13 | } 14 | def registerListener(): Unit = { 15 | Listeners.register(WAPIcebergListener, classOf[CreateSnapshotEvent]) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/MetadataWarnQQ.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MetadataWarnQQ 3 | */ 4 | package fix 5 | 6 | import org.apache.spark.sql.{DataFrame, SparkSession} 7 | import org.apache.spark.sql.functions.col 8 | import org.apache.spark.sql.types.Metadata // assert: MetadataWarnQQ 9 | 10 | object MetadataWarnQQ{ 11 | def inSource(sparkSession: SparkSession, df: DataFrame): Unit = { 12 | val ndf = df.select( 13 | col("id"), 14 | col("v").as( 15 | "newV", 16 | Metadata.fromJson( // assert: MetadataWarnQQ 17 | """{"desc": "replace old V"}""" 18 | ) 19 | ) 20 | ) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/SparkSQLCallExternal.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=SparkSQLCallExternal 3 | */ 4 | import org.apache.spark._ 5 | import org.apache.spark.sql._ 6 | 7 | object OldQuery { 8 | def doQuery(s: SparkSession) { 9 | // We should be able to rewrite this one 10 | s.sql("""select 11 | cast(a as int), 12 | cast(b as int) 13 | from fart_tbl""") 14 | // We can't auto rewrite this :( easily. 
15 | val q = "SELECT * FROM FARTS LIMIT 1" 16 | s.sql(q) 17 | // we should not change this 18 | fart("magic farts") 19 | } 20 | 21 | def fart(str: String) = { 22 | println(s"Fart ${str}") 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | */build/sbt*.jar 4 | */target/ 5 | .metals/ 6 | target/ 7 | *.jar 8 | spark-*-bin-*hadoop* 9 | sparkdemoproject-3 10 | */__pycache__/* 11 | project/metals.sbt 12 | */*.egg-info/* 13 | .bsp/ 14 | hadoop-* 15 | pipelinecompare/warehouse/ 16 | pipelinecompare/metastore_db/ 17 | 18 | # PyCharm/Intellij Ignores 19 | **/.idea/* 20 | **/venv/* 21 | *.iml 22 | **/dist/* 23 | **/__pycache__/* 24 | .DS_Store 25 | **/build/* 26 | **/*.egg-info/* 27 | 28 | # Annoying macOS Ignores 29 | **.DS_Store 30 | 31 | 32 | # Some deps we download 33 | hadoop-2.*.0* 34 | # Generated in testing 35 | spark-warehouse 36 | metastore_db 37 | warehouse 38 | 39 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/MultiLineDatasetReadWarn.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MultiLineDatasetReadWarn 3 | */ 4 | package fix 5 | import org.apache.spark.sql.{SparkSession, Dataset} 6 | 7 | class MultiLineDatasetReadWarn { 8 | def inSource(sparkSession: SparkSession): Unit = { 9 | import sparkSession.implicits._ 10 | val df = (sparkSession // assert: MultiLineDatasetReadWarn 11 | .read 12 | .format("csv") 13 | .option("multiline", true) 14 | ) 15 | 16 | val okDf = (sparkSession 17 | .read 18 | .format("csv") 19 | ) 20 | 21 | val ds7 = Seq("test 1", "test 2", "test 3").toDF().groupBy("value").count() 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/UnionRewrite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule = UnionRewrite 3 | UnionRewrite.deprecatedMethod { 4 | "unionAll" = "union" 5 | } 6 | */ 7 | package fix 8 | import org.apache.spark.sql.{DataFrame, Dataset} 9 | 10 | object UnionRewrite { 11 | def inSource( 12 | df1: DataFrame, 13 | df2: DataFrame, 14 | df3: DataFrame, 15 | ds1: Dataset[String], 16 | ds2: Dataset[String] 17 | ): Unit = { 18 | val res1 = df1.unionAll(df2) 19 | val res2 = df1.unionAll(df2).unionAll(df3) 20 | val res3 = Seq(df1, df2, df3).reduce(_ unionAll _) 21 | val res4 = ds1.unionAll(ds2) 22 | val res5 = Seq(ds1, ds2).reduce(_ unionAll _) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/src/main/scala/com/holdenkarau/spark/upgrade/wap/plugin/IcebergListener.scala: -------------------------------------------------------------------------------- 1 | package com.holdenkarau.spark.upgrade.wap.plugin 2 | 3 | import org.apache.iceberg.events.{CreateSnapshotEvent, Listener} 4 | 5 | object WAPIcebergListener extends Listener[CreateSnapshotEvent] { 6 | // For testing 7 | private[holdenkarau] var lastLog = "" 8 | 9 | override def notify(event: CreateSnapshotEvent): Unit = { 10 | val msg = s"IcebergListener: Created snapshot ${event.snapshotId()} on table " + 11 | s"${event.tableName()} summary ${event.summary()} from operation " + 12 | s"${event.operation()}" 13 | lastLog = msg 14 | System.err.println(msg) 15 | } 16 | } 17 | 
-------------------------------------------------------------------------------- /sql/test/rules/test_cases/SPARK_SQL_EXTRACT_SECOND.yml: -------------------------------------------------------------------------------- 1 | rule: SPARKSQL_L004 2 | 3 | extract_second: 4 | configs: 5 | core: 6 | dialect: sparksql 7 | fail_str: | 8 | select extract(second from to_timestamp('2019-09-20 10:10:10.1')) 9 | fix_str: | 10 | select int(extract(second from to_timestamp('2019-09-20 10:10:10.1'))) 11 | 12 | extract_second_existing_cast: 13 | configs: 14 | core: 15 | dialect: sparksql 16 | fail_str: | 17 | INSERT OVERWRITE foo.bar SELECT CAST(extract(second from to_timestamp('2019-09-20 10:10:11.1')) AS STRING) AS a 18 | fix_str: | 19 | INSERT OVERWRITE foo.bar SELECT CAST(int(extract(second from to_timestamp('2019-09-20 10:10:11.1'))) AS STRING) AS a 20 | -------------------------------------------------------------------------------- /e2e_demo/scala/update_gradle_build.py: -------------------------------------------------------------------------------- 1 | import re,sys,os; 2 | 3 | original_build = sys.stdin.read() 4 | 5 | build_with_plugin = original_build 6 | 7 | version = os.getenv("SCALAFIX_RULES_VERSION", "0.1.14") 8 | 9 | if "scalafix" not in build_with_plugin: 10 | build_with_plugin = re.sub( 11 | r"plugins\s*{", 12 | "plugins {\n id 'io.github.cosmicsilence.scalafix' version '0.1.14'\n", 13 | build_with_plugin 14 | ) 15 | 16 | build_with_plugin_and_rules = re.sub( 17 | r"dependencies\s*{", 18 | "dependencies {\n scalafix group: 'com.holdenkarau', name: 'spark-scalafix-rules-2.4.8_2.12', version: '" + version +"'\n", 19 | build_with_plugin) 20 | 21 | print(build_with_plugin_and_rules) 22 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala-2.12/fix/OldReaderAddImports.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateDeprecatedDataFrameReaderFuns 3 | */ 4 | import org.apache.spark._ 5 | import org.apache.spark.rdd._ 6 | import org.apache.spark.sql.{SparkSession, Dataset} 7 | 8 | object BadReadsAddImports { 9 | def doMyWork(session: SparkSession, r: RDD[String], dataset: Dataset[String]) = { 10 | import session.implicits._ 11 | val shouldRewriteBasic = session.read.json(r) 12 | val r2 = session.sparkContext.parallelize(List("{}")) 13 | val shouldRewrite = session.read.json(r2) 14 | val r3: RDD[String] = session.sparkContext.parallelize(List("{}")) 15 | val shouldRewriteExplicit = session.read.json(r3) 16 | val noRewrite2 = session.read.json(dataset) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala-2.12/fix/OldReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateDeprecatedDataFrameReaderFuns 3 | */ 4 | import org.apache.spark._ 5 | import org.apache.spark.rdd._ 6 | import org.apache.spark.sql._ 7 | 8 | object BadReads { 9 | def doMyWork(session: SparkSession, r: RDD[String], dataset: Dataset[String]) = { 10 | import session.implicits._ 11 | val shouldRewriteBasic = session.read.json(r) 12 | val r2 = session.sparkContext.parallelize(List("{}")) 13 | val shouldRewrite = session.read.json(r2) 14 | val r3: RDD[String] = session.sparkContext.parallelize(List("{}")) 15 | val shouldRewriteExplicit = session.read.json(r3) 16 | val noRewrite1 = session.read.json(session.createDataset(r)(Encoders.STRING)) 17 | val noRewrite2 = 
session.read.json(dataset) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/OldReaderAddImports.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.rdd._ 3 | import org.apache.spark.sql.{SparkSession, Dataset} 4 | import org.apache.spark.sql.Encoders 5 | 6 | object BadReadsAddImports { 7 | def doMyWork(session: SparkSession, r: RDD[String], dataset: Dataset[String]) = { 8 | import session.implicits._ 9 | val shouldRewriteBasic = session.read.json(session.createDataset(r)(Encoders.STRING)) 10 | val r2 = session.sparkContext.parallelize(List("{}")) 11 | val shouldRewrite = session.read.json(session.createDataset(r2)(Encoders.STRING)) 12 | val r3: RDD[String] = session.sparkContext.parallelize(List("{}")) 13 | val shouldRewriteExplicit = session.read.json(session.createDataset(r3)(Encoders.STRING)) 14 | val noRewrite2 = session.read.json(dataset) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/OldReader.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.rdd._ 3 | import org.apache.spark.sql._ 4 | 5 | object BadReads { 6 | def doMyWork(session: SparkSession, r: RDD[String], dataset: Dataset[String]) = { 7 | import session.implicits._ 8 | val shouldRewriteBasic = session.read.json(session.createDataset(r)(Encoders.STRING)) 9 | val r2 = session.sparkContext.parallelize(List("{}")) 10 | val shouldRewrite = session.read.json(session.createDataset(r2)(Encoders.STRING)) 11 | val r3: RDD[String] = session.sparkContext.parallelize(List("{}")) 12 | val shouldRewriteExplicit = session.read.json(session.createDataset(r3)(Encoders.STRING)) 13 | val noRewrite1 = session.read.json(session.createDataset(r)(Encoders.STRING)) 14 | val noRewrite2 = session.read.json(dataset) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/ScalaTestExtendsFix.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | // Fix the extends with since the QQ matcher doesn't like it and I'm lazy. 7 | class ScalaTestExtendsFix 8 | extends SyntacticRule("ScalaTestExtendsFix") { 9 | override val description = 10 | """Handle the change with ScalaTest ( see https://www.scalatest.org/release_notes/3.1.0 ) """ 11 | 12 | override val isRewrite = true 13 | 14 | override def fix(implicit doc: SyntacticDocument): Patch = { 15 | println("Magicz!") 16 | doc.tree.collect { case v: Type.Name => 17 | println(v) 18 | if (v.toString == "FunSuite") { 19 | Patch.replaceTree(v, "AnyFunSuite") 20 | } else { 21 | println(s"No change to $v") 22 | Patch.empty 23 | } 24 | }.asPatch 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /pysparkler/pysparkler/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. 
The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/IsRunningLocally.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | case class IsRunningLocally(e: scala.meta.Tree) extends Diagnostic { 7 | override def position: Position = e.pos 8 | override def message: String = 9 | "TaskContext.isRunningLocally has been removed, see " + 10 | "https://spark.apache.org/docs/3.0.0/core-migration-guide.html " + 11 | " since local execution was removed you can probably delete this code path." 12 | } 13 | 14 | class IsRunningLocallyWarn extends SemanticRule("IsRunningLocallyWarn") { 15 | 16 | val matcher = SymbolMatcher.normalized("org.apache.spark.TaskContext.isRunningLocally") 17 | 18 | override def fix(implicit doc: SemanticDocument): Patch = { 19 | doc.tree.collect { 20 | case matcher(s) => 21 | Patch.lint(IsRunningLocally(s)) 22 | }.asPatch 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /docs/scala/sbt.md: -------------------------------------------------------------------------------- 1 | Hi Friend! It looks your trying to migrate a ASF Spark project! 2 | Let's make it happen! 3 | These instructions are written for sbt, there will be instructions for other builds as well. 4 | 5 | I've tried to update your build file for you, but there might be some mistakes. What I've tried to do is: 6 | 7 | `` 8 | scalafixDependencies in ThisBuild += 9 | "com.holdenkarau" %% "spark-scalafix-rules-2.4.8" % "0.1.15" 10 | semanticdbEnabled in ThisBuild := true 11 | `` 12 | 13 | Then add: 14 | 15 | `` 16 | resolvers += Resolver.sonatypeRepo("releases") 17 | 18 | addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.10.4") 19 | `` 20 | 21 | to project/plugins.sbt 22 | 23 | And update your build file to add a "-3" to the artifact name so I can tell the difference between your Spark 3 & Spark 2 jars. 24 | 25 | Thanks friend! 26 | 27 | (Note: we could also try and do this with some REs on your build file too, but... it's a demo) 28 | 29 | Add a .scalafix.conf file as patterned after the one in our scalafix directory. 30 | -------------------------------------------------------------------------------- /docs/scala/gradle.md: -------------------------------------------------------------------------------- 1 | Hi Friend! It looks your trying to migrate a ASF Spark project! 2 | Let's make it happen! 3 | These instructions are written for gradle, there will be instructions for other builds as well. 4 | 5 | I've tried to update your build file for you, but there might be some mistakes. 
What I've tried to do is: 6 | 7 | Add 8 | `` 9 | scalafix group: "com.holdenkarau", name: 'spark-scalafix-rules-2.4.8_2.12', version: '0.1.13' 10 | `` 11 | to your dependencies 12 | 13 | And add: 14 | 15 | `` 16 | id "io.github.cosmicsilence.scalafix" version "0.1.14" 17 | `` 18 | 19 | To your plugins. 20 | 21 | If your including ScalaFix through "classpath" rather than "plugins" you will want add `apply plugin: 'io.github.cosmicsilence.scalafix'`. 22 | 23 | 24 | And update your build file to add a "-3" to the artifact name so we can tell the difference between your Spark 3 & Spark 2 jars. 25 | 26 | Thanks friend! 27 | 28 | (Note: we could also try and do this with some REs on your build file too, but... it's a demo) 29 | 30 | Add a .scalafix.conf file as patterned after the one in our scalafix directory. 31 | -------------------------------------------------------------------------------- /scalafix/readme.md: -------------------------------------------------------------------------------- 1 | # Scalafix rules for Spark Auto Upgrade 2 | 3 | To use the scalafix rules, see the build tool specific docs https://github.com/holdenk/spark-upgrade/tree/main/docs/scala 4 | and the end to end demo https://github.com/holdenk/spark-upgrade/tree/main/e2e_demo/scala 5 | 6 | To migrate in-line SQL you will need to install sqlfluff + our extensions, which you can do with 7 | 8 | ```bash 9 | git clone https://github.com/holdenk/spark-upgrade.git 10 | cd spark-upgrade/sql; pip install . 11 | ``` 12 | 13 | ## Other (non-Spark Specific) rules you may wish to use 14 | 15 | https://github.com/scala/scala-rewrites - 16 | 17 | fix.scala213.Any2StringAdd 18 | fix.scala213.Core 19 | fix.scala213.ExplicitNonNullaryApply 20 | fix.scala213.ExplicitNullaryEtaExpansion 21 | fix.scala213.NullaryHashHash 22 | fix.scala213.ScalaSeq 23 | fix.scala213.Varargs 24 | 25 | 26 | ## Demo 27 | You can also watch the end to end demo at https://www.youtube.com/watch?v=bqpb84n9Dpk :) 28 | 29 | ## Extending 30 | 31 | To develop rule: 32 | ``` 33 | sbt ~tests/test 34 | # edit rules/src/main/scala/fix/SparkAutoUpgrade.scala 35 | ``` 36 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/AllEquivalentExprs.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class AllEquivalentExprs extends SemanticRule("AllEquivalentExprs") { 7 | 8 | override def fix(implicit doc: SemanticDocument): Patch = { 9 | val equivExprs = SymbolMatcher.normalized("org/apache/spark/sql/catalyst/expressions/EquivalentExpressions#getAllEquivalentExprs().") 10 | val utils = new Utils() 11 | 12 | def matchOnTree(e: Tree): Patch = { 13 | e match { 14 | case equivExprs(call) => 15 | // This is sketch because were messing with the string repr but it's easier 16 | // since we only want to replace some of our match. 
17 | val newCall = call.toString.replace(".getAllEquivalentExprs", ".getCommonSubexpressions.map(List(_))") 18 | Patch.replaceTree(call, newCall) 19 | case elem @ _ => 20 | elem.children match { 21 | case Nil => Patch.empty 22 | case _ => elem.children.map(matchOnTree).asPatch 23 | } 24 | } 25 | } 26 | matchOnTree(doc.tree) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/OnFailureFix.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class OnFailureFix extends SemanticRule("onFailureFix") { 7 | // See https://stackoverflow.com/questions/62047662/value-onsuccess-is-not-a-member-of-scala-concurrent-futureany 8 | val onFailureFunMatch = SymbolMatcher.normalized("scala.concurrent.onFuture") 9 | val onSuccessFunMatch = SymbolMatcher.normalized("scala.concurrent.onFuture") 10 | 11 | override def fix(implicit doc: SemanticDocument): Patch = { 12 | doc.tree.collect { 13 | case ns @ Term.Apply(j @ onFailureFunMatch(f), args) => 14 | val future = ns.children(0).children(0) 15 | List( 16 | Patch.addRight(j, "(ev) }"), 17 | Patch.replaceTree(j, s"${future}.onComplete { case Error(ev) => ") 18 | ) 19 | case ns @ Term.Apply(j @ onSuccessFunMatch(f), args) => 20 | val future = ns.children(0).children(0) 21 | List( 22 | Patch.addRight(j, "(sv) }"), 23 | Patch.replaceTree(j, s"${future}.onComplete { case Success(sv) => ") 24 | ) 25 | }.flatten.asPatch 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/src/test/scala/com/holdenkarau/sparkdemoproject/WordCountTest.scala: -------------------------------------------------------------------------------- 1 | package com.holdenkarau.sparkDemoProject 2 | 3 | /** 4 | * A simple test for everyone's favourite wordcount example. 5 | */ 6 | 7 | import com.holdenkarau.spark.testing.SharedSparkContext 8 | import org.scalatest.funsuite.AnyFunSuite 9 | 10 | class WordCountTest extends AnyFunSuite with SharedSparkContext { 11 | test("word count with Stop Words Removed"){ 12 | val linesRDD = sc.parallelize(Seq( 13 | "How happy was the panda? You ask.", 14 | "Panda is the most happy panda in all the#!?ing land!")) 15 | 16 | val stopWords: Set[String] = Set("a", "the", "in", "was", "there", "she", "he") 17 | val splitTokens: Array[Char] = "#%?!. 
".toCharArray 18 | 19 | val wordCounts = WordCount.withStopWordsFiltered( 20 | linesRDD, splitTokens, stopWords) 21 | val wordCountsAsMap = wordCounts.collectAsMap() 22 | assert(!wordCountsAsMap.contains("the")) 23 | assert(!wordCountsAsMap.contains("?")) 24 | assert(!wordCountsAsMap.contains("#!?ing")) 25 | assert(wordCountsAsMap.contains("ing")) 26 | assert(wordCountsAsMap.get("panda").get === 3L) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala-2.12/fix/ExecutorPluginWarn.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | case class ExecutorPluginWarning(e: scala.meta.Tree) extends Diagnostic { 7 | override def position: Position = e.pos 8 | override def message: String = 9 | "Executor Plugin is dropped in 3.0+, see " + 10 | "https://spark.apache.org/docs/3.0.0/core-migration-guide.html " + 11 | " https://spark.apache.org/docs/3.2.1/api/java/index.html?org/apache/spark/api/plugin/SparkPlugin.html" 12 | } 13 | 14 | class ExecutorPluginWarn extends SemanticRule("ExecutorPluginWarn") { 15 | // See https://spark.apache.org/docs/3.0.0/core-migration-guide.html + 16 | // + new docs at: 17 | // https://spark.apache.org/docs/3.2.1/api/java/index.html?org/apache/spark/api/plugin/SparkPlugin.html 18 | // https://spark.apache.org/docs/3.2.1/api/java/org/apache/spark/api/plugin/ExecutorPlugin.html 19 | 20 | val matcher = SymbolMatcher.normalized("org.apache.spark.ExecutorPlugin") 21 | 22 | override def fix(implicit doc: SemanticDocument): Patch = { 23 | doc.tree.collect { 24 | case matcher(s) => 25 | Patch.lint(ExecutorPluginWarning(s)) 26 | }.asPatch 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala-2.13/fix/ExecutorPluginWarn.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | case class ExecutorPluginWarning(e: scala.meta.Tree) extends Diagnostic { 7 | override def position: Position = e.pos 8 | override def message: String = 9 | "Executor Plugin is dropped in 3.0+, see " + 10 | "https://spark.apache.org/docs/3.0.0/core-migration-guide.html " + 11 | " https://spark.apache.org/docs/3.2.1/api/java/index.html?org/apache/spark/api/plugin/SparkPlugin.html" 12 | } 13 | 14 | class ExecutorPluginWarn extends SemanticRule("ExecutorPluginWarn") { 15 | // See https://spark.apache.org/docs/3.0.0/core-migration-guide.html + 16 | // + new docs at: 17 | // https://spark.apache.org/docs/3.2.1/api/java/index.html?org/apache/spark/api/plugin/SparkPlugin.html 18 | // https://spark.apache.org/docs/3.2.1/api/java/org/apache/spark/api/plugin/ExecutorPlugin.html 19 | 20 | val matcher = SymbolMatcher.normalized("org.apache.spark.ExecutorPlugin") 21 | 22 | override def fix(implicit doc: SemanticDocument): Patch = { 23 | doc.tree.collect { 24 | case matcher(s) => 25 | Patch.lint(ExecutorPluginWarning(s)) 26 | }.asPatch 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /sql/setup.py: -------------------------------------------------------------------------------- 1 | """Setup file for example plugin.""" 2 | from setuptools import find_packages, setup 3 | 4 | # Change these names in your plugin, e.g. company name or plugin purpose. 
5 | PLUGIN_LOGICAL_NAME = "sparksql-upgrade" 6 | PLUGIN_ROOT_MODULE = "sparksql_upgrade" 7 | 8 | setup( 9 | name="sqlfluff-plugin-{plugin_logical_name}".format( 10 | plugin_logical_name=PLUGIN_LOGICAL_NAME 11 | ), 12 | version="0.1.4", 13 | author="Holden Karau", 14 | author_email="holden@pigscanfly.ca", 15 | url="https://github.com/holdenk/spark-upgrade", 16 | description="SQLFluff rules to help migrate your Spark SQL from 2.X to 3.X", 17 | long_description="SQLFluff rules to help migrate your Spark SQL from 2.X to 3.X", 18 | test_requires=["nose", "coverage", "unittest2"], 19 | license="../LICENSE", 20 | include_package_data=True, 21 | package_dir={"": "src"}, 22 | packages=find_packages(where="src"), 23 | install_requires="sqlfluff==2.3.2", 24 | entry_points={ 25 | "sqlfluff": [ 26 | "{plugin_logical_name} = {plugin_root_module}.plugin".format( 27 | plugin_logical_name=PLUGIN_LOGICAL_NAME, 28 | plugin_root_module=PLUGIN_ROOT_MODULE, 29 | ) 30 | ] 31 | }, 32 | ) 33 | -------------------------------------------------------------------------------- /pipelinecompare/README.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | 3 | Install requirements from `requirements.txt`, create two different pipelines, and build the parent table comparison project. 4 | 5 | # Open questions 6 | 7 | Is shelling out through the command line the right approach? Benefit: we don't need to run inside of spark-submit. 8 | Do we want to support "raw" tables? 9 | 10 | 11 | # Samples 12 | 13 | ## Iceberg sample 14 | 15 | ## LakeFS Sample 16 | - sign up for the lakeFS demo 17 | - create a ~/.lakectl.yaml file with `username` `password` and `host`. 18 | - run the following command (compares two no-op pipelines on existing output; it should succeed). 19 | 20 | `python domagic.py --control-pipeline "ls /" --input-tables farts mcgee --lakeFS --repo my-repo --new-pipeline "ls /" --output-tables sample_data` 21 | 22 | OR if you're running in local mode: 23 | 24 | `python domagic.py --control-pipeline "ls /" --input-tables farts mcgee --lakeFS --repo my-repo --new-pipeline "ls /" --output-tables "sample_data/release=v1.9/type=relation/20220106_182445_00068_pa8u7_04924a3b-01b0-4174-9772-7285db53a68c" --format parquet` 25 | 26 | ### LakeFS FAQ 27 | 28 | Why don't you just use commit hashes? 29 | 30 | Many things can result in different binary data on disk but still have the same effective data stored. For example, two runs of an identical pipeline can write Parquet files that differ in row-group layout or file-level metadata while containing exactly the same rows.
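If you just want to sanity-check by hand that two output snapshots hold the same rows (outside of this tool), a minimal PySpark sketch looks like the following; the table paths are hypothetical placeholders.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical locations of the control and candidate pipeline outputs.
control = spark.read.format("parquet").load("s3://my-bucket/control/sample_data")
candidate = spark.read.format("parquet").load("s3://my-bucket/candidate/sample_data")

# Difference in both directions: both empty means the outputs hold the same rows.
missing_from_candidate = control.exceptAll(candidate)
extra_in_candidate = candidate.exceptAll(control)

assert missing_from_candidate.count() == 0 and extra_in_candidate.count() == 0
```

`exceptAll` (Spark 2.4+) keeps duplicate counts, so the two outputs are compared as multisets rather than sets.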
31 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/MigrateDeprecatedDataFrameReaderFuns.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class MigrateDeprecatedDataFrameReaderFuns extends SemanticRule("MigrateDeprecatedDataFrameReaderFuns") { 7 | 8 | override def fix(implicit doc: SemanticDocument): Patch = { 9 | val readerMatcher = SymbolMatcher.normalized("org.apache.spark.sql.DataFrameReader") 10 | val jsonReaderMatcher = SymbolMatcher.normalized("org.apache.spark.sql.DataFrameReader.json") 11 | val utils = new Utils() 12 | 13 | def matchOnTree(e: Tree): Patch = { 14 | e match { 15 | case ns @ Term.Apply(jsonReaderMatcher(reader), List(param)) => 16 | param match { 17 | case utils.rddMatcher(rdd) => 18 | (Patch.addLeft(rdd, "session.createDataset(") + Patch.addRight(rdd, ")(Encoders.STRING)") + 19 | utils.addImportIfNotPresent(importer"org.apache.spark.sql.Encoders")) 20 | case _ => 21 | Patch.empty 22 | } 23 | case elem @ _ => 24 | elem.children match { 25 | case Nil => Patch.empty 26 | case _ => elem.children.map(matchOnTree).asPatch 27 | } 28 | } 29 | } 30 | matchOnTree(doc.tree) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id "scala" 3 | id 'java-library' 4 | } 5 | 6 | scala { 7 | sourceCompatibility = "2.12" 8 | targetCompatibility = "2.12" 9 | } 10 | 11 | java { 12 | withSourcesJar() 13 | } 14 | 15 | repositories { 16 | mavenCentral() 17 | maven { 18 | name "sonatype-releases" 19 | url "https://oss.sonatype.org/content/repositories/releases/" 20 | } 21 | maven { 22 | name "Typesafe repository" 23 | url "https://repo.typesafe.com/typesafe/releases/" 24 | } 25 | maven { 26 | name "Second Typesafe repo" 27 | url "https://repo.typesafe.com/typesafe/maven-releases/" 28 | } 29 | } 30 | 31 | dependencies { 32 | compileOnly group: "org.apache.spark", name: 'spark-streaming_2.12', version: '2.4.8' 33 | compileOnly group: "org.apache.spark", name: 'spark-sql_2.12', version: '2.4.8' 34 | 35 | testImplementation group: "org.scalatest", name : "scalatest_2.12", version: "3.2.2" 36 | testImplementation group: "org.scalacheck", name: 'scalacheck_2.12', version: '1.15.2' 37 | testImplementation group: "com.holdenkarau", name: 'spark-testing-base_2.12', version: '2.4.8_1.3.0' 38 | 39 | } 40 | 41 | configurations { 42 | testImplementation.extendsFrom compileOnly 43 | } 44 | 45 | group "com.holdenkarau" 46 | version "0.0.1" -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/MigrateTrigger.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class MigrateTrigger extends SemanticRule("MigrateTrigger") { 7 | override val description = 8 | """Migrate Trigger.""" 9 | override val isRewrite = true 10 | 11 | val triggerMatcher = SymbolMatcher.normalized("org.apache.spark.sql.streaming.ProcessingTime") 12 | 13 | override def fix(implicit doc: SemanticDocument): Patch = { 14 | val utils = new Utils() 15 | def matchOnTree(e: Tree): Patch = { 16 | e match { 17 | // Trigger match seems to be matching too widly sometimes? 
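      // When the matched tree mentions ProcessingTime we add
      // `import org.apache.spark.sql.streaming.Trigger._`, presumably so that bare
      // ProcessingTime(...) calls still resolve (via Trigger.ProcessingTime) once the
      // deprecated org.apache.spark.sql.streaming.ProcessingTime is gone in 3.x.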
18 | case triggerMatcher(e) => 19 | if (e.toString.contains("ProcessingTime")) { 20 | utils.addImportIfNotPresent(importer"org.apache.spark.sql.streaming.Trigger._") 21 | } else { 22 | None.asPatch 23 | } 24 | case elem @ _ => 25 | elem.children match { 26 | case Nil => Patch.empty 27 | case _ => elem.children.map(matchOnTree).asPatch 28 | } 29 | } 30 | } 31 | // Deal with the spurious matches by only running on files that importing streaming. 32 | if (doc.input.text.contains("org.apache.spark.sql.streaming")) { 33 | matchOnTree(doc.tree) 34 | } else { 35 | None.asPatch 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /sql/src/sparksql_upgrade/plugin.py: -------------------------------------------------------------------------------- 1 | """Custom Spark SQL upgrade rules.""" 2 | 3 | import os.path 4 | from typing import List 5 | 6 | 7 | from sqlfluff.core.config import ConfigLoader 8 | from sqlfluff.core.plugin import hookimpl 9 | from sqlfluff.core.rules import BaseRule 10 | 11 | 12 | @hookimpl 13 | def get_rules() -> List[BaseRule]: 14 | """Get plugin rules.""" 15 | from .rules import ( 16 | Rule_SPARKSQLCAST_L001, 17 | Rule_RESERVEDROPERTIES_L002, 18 | Rule_NOCHARS_L003, 19 | Rule_FORMATSTRONEINDEX_L004, 20 | Rule_SPARKSQL_L004, 21 | Rule_SPARKSQL_L005, 22 | ) 23 | 24 | return [ 25 | Rule_SPARKSQLCAST_L001, 26 | Rule_RESERVEDROPERTIES_L002, 27 | Rule_NOCHARS_L003, 28 | Rule_FORMATSTRONEINDEX_L004, 29 | Rule_SPARKSQL_L004, 30 | Rule_SPARKSQL_L005, 31 | ] 32 | 33 | 34 | @hookimpl 35 | def load_default_config() -> dict: 36 | """Loads the default configuration for the plugin.""" 37 | return ConfigLoader.get_global().load_config_file( 38 | file_dir=os.path.dirname(__file__), 39 | file_name="plugin_default_config.cfg", 40 | ) 41 | 42 | 43 | @hookimpl 44 | def get_configs_info() -> dict: 45 | """Get rule config validations and descriptions.""" 46 | return { 47 | "forbidden_columns": {"definition": "A list of column to forbid"}, 48 | } 49 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/build.gradle-33-example: -------------------------------------------------------------------------------- 1 | plugins { 2 | id 'io.github.cosmicsilence.scalafix' version '0.1.14' 3 | 4 | id "scala" 5 | id 'java-library' 6 | } 7 | 8 | scala { 9 | sourceCompatibility = "2.12" 10 | targetCompatibility = "2.12" 11 | } 12 | 13 | java { 14 | withSourcesJar() 15 | } 16 | 17 | repositories { 18 | mavenCentral() 19 | maven { 20 | name "sonatype-releases" 21 | url "https://oss.sonatype.org/content/repositories/releases/" 22 | } 23 | maven { 24 | name "Typesafe repository" 25 | url "https://repo.typesafe.com/typesafe/releases/" 26 | } 27 | maven { 28 | name "Second Typesafe repo" 29 | url "https://repo.typesafe.com/typesafe/maven-releases/" 30 | } 31 | } 32 | 33 | dependencies { 34 | compileOnly group: "org.apache.spark", name: 'spark-streaming_2.12', version: '3.3.100' 35 | compileOnly group: "org.apache.spark", name: 'spark-sql_2.12', version: '3.3.100' 36 | 37 | testImplementation group: "org.scalatest", name : "scalatest_2.12", version: "3.2.2" 38 | testImplementation group: "org.scalacheck", name: 'scalacheck_2.12', version: '1.15.2' 39 | testImplementation group: "com.holdenkarau", name: 'spark-testing-base_2.12', version: '3.3.1_1.3.0' 40 | 41 | } 42 | 43 | configurations { 44 | testImplementation.extendsFrom compileOnly 45 | } 46 | 47 | group "com.holdenkarau" 48 | version "0.0.1" 49 | 
-------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/MultiLineDatasetReadWarn.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | case class MultiLineDatasetReadWarning(tn: scala.meta.Tree) extends Diagnostic { 7 | override def position: Position = tn.pos 8 | 9 | override def message: String = 10 | """In Spark 2.4.X and below, 11 | |reading multi-line textual input with \r\n (windows line feed) _might_ 12 | |leave \rs. You can get this legacy behaviour by specifying a lineSep of "\n", 13 | |but for most people this was a bug. 14 | |This linter rule is fuzzy.""".stripMargin 15 | } 16 | 17 | class MultiLineDatasetReadWarn extends SemanticRule("MultiLineDatasetReadWarn") { 18 | val matcher = SymbolMatcher.normalized("org.apache.spark.sql.DataFrameReader#option") 19 | override val description = "MultiLine text input dataframe warning." 20 | 21 | override def fix(implicit doc: SemanticDocument): Patch = { 22 | // Imperfect, maybe someone will have the string "multiline" while reading from a DataFrame, but it's an ok place to start. 23 | if (doc.input.text.contains("'multiline'") || doc.input.text.contains("\"multiline\"")) { 24 | doc.tree.collect { 25 | case matcher(read) => 26 | if (read.toString.contains("multiline")) { 27 | Patch.lint(MultiLineDatasetReadWarning(read)) 28 | } else { 29 | None.asPatch 30 | } 31 | }.asPatch 32 | } else { 33 | Patch.empty 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/ExpressionEncoder.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class ExpressionEncoder extends SemanticRule("ExpressionEncoder") { 7 | 8 | override def fix(implicit doc: SemanticDocument): Patch = { 9 | val toRowMatcher = SymbolMatcher.normalized("org/apache/spark/sql/catalyst/encoders/ExpressionEncoder#toRow().") 10 | val fromRowMatcher = SymbolMatcher.normalized("org/apache/spark/sql/catalyst/encoders/ExpressionEncoder#fromRow().") 11 | val utils = new Utils() 12 | 13 | def matchOnTree(e: Tree): Patch = { 14 | e match { 15 | case toRowMatcher(call) => 16 | // This is sketchy because we're messing with the string repr but it's easier 17 | // since we only want to replace some of our match. 18 | val newCall = call.toString.replace(".toRow", ".createSerializer()") 19 | Patch.replaceTree(call, newCall) 20 | case fromRowMatcher(call) => 21 | // This is sketchy because we're messing with the string repr but it's easier 22 | // since we only want to replace some of our match.
23 | val newCall = call.toString.replace(".fromRow", ".createDeserializer()") 24 | Patch.replaceTree(call, newCall) 25 | case elem @ _ => 26 | elem.children match { 27 | case Nil => Patch.empty 28 | case _ => elem.children.map(matchOnTree).asPatch 29 | } 30 | } 31 | } 32 | matchOnTree(doc.tree) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/build.gradle.scalafix: -------------------------------------------------------------------------------- 1 | plugins { 2 | id "scala" 3 | id 'java-library' 4 | id "io.github.cosmicsilence.scalafix" version "0.1.14" 5 | } 6 | 7 | scala { 8 | sourceCompatibility = "2.12" 9 | targetCompatibility = "2.12" 10 | } 11 | 12 | java { 13 | withSourcesJar() 14 | } 15 | 16 | repositories { 17 | mavenCentral() 18 | maven { 19 | name "sonatype-releases" 20 | url "https://oss.sonatype.org/content/repositories/releases/" 21 | } 22 | maven { 23 | name "Typesafe repository" 24 | url "https://repo.typesafe.com/typesafe/releases/" 25 | } 26 | maven { 27 | name "Second Typesafe repo" 28 | url "https://repo.typesafe.com/typesafe/maven-releases/" 29 | } 30 | } 31 | 32 | dependencies { 33 | compileOnly group: "org.apache.spark", name: 'spark-streaming_2.12', version: '2.4.8' 34 | compileOnly group: "org.apache.spark", name: 'spark-sql_2.12', version: '2.4.8' 35 | 36 | testImplementation group: "org.scalatest", name : "scalatest_2.12", version: "3.2.2" 37 | testImplementation group: "org.scalacheck", name: 'scalacheck_2.12', version: '1.15.2' 38 | testImplementation group: "com.holdenkarau", name: 'spark-testing-base_2.12', version: '2.4.8_1.3.0' 39 | 40 | scalafix group: "com.holdenkarau", name: 'spark-scalafix-rules-2.4.8_2.12', version: '0.1.9' 41 | } 42 | 43 | configurations { 44 | testImplementation.extendsFrom compileOnly 45 | } 46 | 47 | group "com.holdenkarau" 48 | version "0.0.1" -------------------------------------------------------------------------------- /pysparkler/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pysparkler" 3 | version = "0.9.dev" 4 | description = "A tool that upgrades your PySpark scripts to the latest Spark version as per Spark migration Guideline" 5 | authors = ["Dhruv Pratap "] 6 | readme = "README.md" 7 | license = "Apache-2.0" 8 | homepage = "https://github.com/holdenk/spark-upgrade" 9 | repository = "https://github.com/holdenk/spark-upgrade" 10 | maintainers = [ 11 | "Holden Karau ", 12 | ] 13 | 14 | 15 | [tool.poetry.dependencies] 16 | python = "^3.10" 17 | libcst = "^1.0.1" 18 | click = "^8.1.3" 19 | rich = "^13.3.3" 20 | nbformat = "^5.8.0" 21 | sqlfluff = "^1.0.0" 22 | sqlfluff-plugin-sparksql-upgrade = "^0.1.0" 23 | 24 | 25 | [tool.poetry.group.test.dependencies] 26 | pytest = "^7.2.2" 27 | 28 | 29 | [tool.poetry.group.lint.dependencies] 30 | pre-commit = "^3.2.1" 31 | 32 | 33 | [tool.poetry.scripts] 34 | pysparkler = "pysparkler.cli:run" 35 | 36 | 37 | [build-system] 38 | requires = ["poetry-core"] 39 | build-backend = "poetry.core.masonry.api" 40 | 41 | 42 | [tool.isort] 43 | src_paths = ["pysparkler/", "tests/"] 44 | profile = 'black' 45 | 46 | 47 | [[tool.mypy.overrides]] 48 | module = "libcst.*" 49 | ignore_missing_imports = true 50 | 51 | [[tool.mypy.overrides]] 52 | module = "click.*" 53 | ignore_missing_imports = true 54 | 55 | [[tool.mypy.overrides]] 56 | module = "rich.*" 57 | ignore_missing_imports = true 58 | 59 | [[tool.mypy.overrides]] 60 | module = 
"nbformat.*" 61 | ignore_missing_imports = true 62 | 63 | [[tool.mypy.overrides]] 64 | module = "sqlfluff.*" 65 | ignore_missing_imports = true 66 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/MigrateToSparkSessionBuilder.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class MigrateToSparkSessionBuilder extends SemanticRule("MigrateToSparkSessionBuilder") { 7 | 8 | override def fix(implicit doc: SemanticDocument): Patch = { 9 | val sqlSymbolMatcher = SymbolMatcher.normalized("org.apache.spark.sql.SQLContext") 10 | val sqlGetOrCreateMatcher = SymbolMatcher.normalized("org.apache.spark.sql.SQLContext.getOrCreate") 11 | val newCreate = "SparkSession.builder.getOrCreate().sqlContext" 12 | def matchOnTree(e: Tree): Patch = { 13 | e match { 14 | // Rewrite the construction of a SQLContext 15 | case ns @ Term.New(Init(initArgs)) => 16 | initArgs match { 17 | case (sqlSymbolMatcher(s), _, _) => 18 | List( 19 | Patch.replaceTree( 20 | ns, 21 | newCreate), 22 | Patch.addGlobalImport(importer"org.apache.spark.sql.SparkSession") 23 | ).asPatch 24 | case _ => Patch.empty 25 | } 26 | case ns @ Term.Apply(sqlGetOrCreateMatcher(_), _) => 27 | List( 28 | Patch.replaceTree( 29 | ns, 30 | newCreate), 31 | Patch.addGlobalImport(importer"org.apache.spark.sql.SparkSession") 32 | ).asPatch 33 | case elem @ _ => 34 | elem.children match { 35 | case Nil => Patch.empty 36 | case _ => elem.children.map(matchOnTree).asPatch 37 | } 38 | } 39 | } 40 | matchOnTree(doc.tree) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/MetadataWarnQQ.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import scalafix.v1._ 3 | import scala.meta._ 4 | 5 | case class MetadataWarning(t: scala.meta.Tree) extends Diagnostic { 6 | override def position = t.pos 7 | override def message = """ 8 | |In Spark 3.0, the column metadata 9 | |will always be propagated in the API Column.name and Column.as. 10 | |In Spark version 2.4 and earlier, the metadata of NamedExpression 11 | |is set as the explicitMetadata for the new column 12 | |at the time the API is called, 13 | |it won’t change even if the underlying NamedExpression changes metadata. 14 | |To restore the behavior before Spark 3.0, 15 | |you can use the API as(alias: String, metadata: Metadata) with explicit metadata.""".stripMargin 16 | } 17 | 18 | class MetadataWarnQQ extends SemanticRule("MetadataWarnQQ") { 19 | val matcher = SymbolMatcher.normalized("org.apache.spark.sql.types.Metadata") 20 | override val description = "Metadata warning." 
21 | 22 | override def fix(implicit doc: SemanticDocument): Patch = { 23 | def isSelectAndAs(t: Tree): Boolean = { 24 | val isSelect = t.collect { case q"""select""" => true } 25 | val isAs = t.collect { case q"""as""" => true } 26 | (isSelect.isEmpty.equals(false) && isSelect.head.equals( 27 | true 28 | )) && (isAs.isEmpty.equals(false) && isAs.head.equals(true)) 29 | } 30 | 31 | doc.tree.collect { case matcher(s) => 32 | if (isSelectAndAs(doc.tree)) Patch.lint(MetadataWarning(s)) 33 | else Patch.empty 34 | }.asPatch 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/build.sbt: -------------------------------------------------------------------------------- 1 | import Dependencies._ 2 | 3 | resolvers += Resolver.mavenLocal 4 | resolvers += Resolver.sonatypeRepo("public" ) 5 | resolvers += Resolver.typesafeRepo("releases") 6 | resolvers += Resolver.sbtPluginRepo("releases") 7 | 8 | 9 | ThisBuild / scalaVersion := "2.12.8" 10 | ThisBuild / version := "0.1.0-SNAPSHOT" 11 | ThisBuild / organization := "com.holdenkarau" 12 | ThisBuild / organizationName := "holdenkarau" 13 | ThisBuild / name := "Iceberg WAP plugin" 14 | ThisBuild / javacOptions ++= Seq("-source", "1.8", "-target", "1.8") 15 | 16 | Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat 17 | Test / parallelExecution := false 18 | Test / fork := true 19 | Test / javaOptions += "-javaagent:./target/scala-2.12/iceberg-spark-upgrade-wap-plugin_2.12-0.1.0-SNAPSHOT.jar" 20 | Test / compile := ((Test / compile) dependsOn( Compile / Keys.`package` )).value 21 | 22 | 23 | lazy val root = (project in file(".")) 24 | .settings( 25 | name := "Iceberg Spark Upgrade WAP Plugin", 26 | libraryDependencies += scalaTest % Test, 27 | libraryDependencies += icebergSparkRuntime % Test, 28 | libraryDependencies += sparkTestingBase % Test, 29 | libraryDependencies += iceberg % Provided, 30 | ) 31 | 32 | // Since sbt generates a MANIFEST.MF file rather than storing one in resources and dealing with the conflict, 33 | // just add our properties to the one sbt generates for us.
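// The Premain-Class / Agent-Class attributes below are what let the test JVM load this jar as the
// -javaagent configured in Test / javaOptions above, which is also why Test / compile depends on
// Compile / package.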
34 | Compile / packageBin / packageOptions ++= List( 35 | Package.ManifestAttributes("Premain-Class" -> "com.holdenkarau.spark.upgrade.wap.plugin.Agent"), 36 | Package.ManifestAttributes("Agent-Class" -> "com.holdenkarau.spark.upgrade.wap.plugin.Agent"), 37 | Package.ManifestAttributes("Can-Redefine-Classes" -> "com.holdenkarau.spark.upgrade.wap.plugin.Agent")) 38 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | branches: [master, main] 5 | tags: ["*"] 6 | jobs: 7 | publish: 8 | runs-on: ubuntu-20.04 9 | steps: 10 | - uses: actions/checkout@v3.0.2 11 | with: 12 | fetch-depth: 0 13 | - uses: olafurpg/setup-scala@v13 14 | - name: Release 2.4.8 -> 3.3 15 | run: | 16 | cd scalafix 17 | sbt ci-release 18 | env: 19 | PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} 20 | PGP_SECRET: ${{ secrets.PGP_SECRET }} 21 | SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} 22 | SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} 23 | - name: Release 2.3.2 -> 3.3 24 | run: | 25 | cd scalafix 26 | sbt ci-release -DsparkVersion=2.3.2 27 | env: 28 | PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} 29 | PGP_SECRET: ${{ secrets.PGP_SECRET }} 30 | SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} 31 | SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} 32 | - name: Release 2.1.1 -> 3.3 33 | run: | 34 | cd scalafix 35 | sbt ci-release -DsparkVersion=2.1.1 36 | env: 37 | PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} 38 | PGP_SECRET: ${{ secrets.PGP_SECRET }} 39 | SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} 40 | SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} 41 | pypi-publish-pysparkler: 42 | runs-on: ubuntu-latest 43 | steps: 44 | - name: Checkout PySparkler 45 | uses: actions/checkout@v3 46 | with: 47 | fetch-depth: 0 48 | - name: Publish PySparkler Package on PyPI 49 | run: | 50 | cd pysparkler 51 | make publish 52 | env: 53 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYSPARKLER_PYPI_API_TOKEN }} 54 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/GroupByKeyWarn.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=GroupByKeyWarn 3 | */ 4 | package fix 5 | import org.apache.spark.sql.{SparkSession, Dataset} 6 | 7 | class GroupByKeyWarn { 8 | def inSource(sparkSession: SparkSession): Unit = { 9 | import sparkSession.implicits._ 10 | val ds1 = List( // assert: GroupByKeyWarn 11 | "Person 1", 12 | "Person 2", 13 | "User 1", 14 | "User 2", 15 | "User 3", 16 | "Test", 17 | "Test Test" 18 | ).toDS() 19 | .groupByKey(l => l.substring(0, 3)) // assert: GroupByKeyWarn 20 | .count() 21 | 22 | val noChange: Dataset[String] = List("1").toDS() 23 | val noChangeMore = List("1").toDS().count() 24 | 25 | // Make sure we trigger not just on toDS 26 | val someChange = noChange.groupByKey(l => // assert: GroupByKeyWarn 27 | l.substring(0, 3).toUpperCase() 28 | ) 29 | .count() 30 | 31 | val ds2: Dataset[(String, Long)] = 32 | List("Test 1", "Test 2", "user 1", "Person 1", "Person 2") // assert: GroupByKeyWarn 33 | .toDS() 34 | .groupByKey(l => // assert: GroupByKeyWarn 35 | l.substring(0, 3).toUpperCase() 36 | ) 37 | .count() 38 | 39 | val ds3 = 40 | List(1, 2, 3, 4, 5, 6) // assert: GroupByKeyWarn 41 | .toDS() 42 | .groupByKey(l => l > 3) // assert: GroupByKeyWarn 43 | .count() 44 | 45 | val ds4 = 46 | List(Array(19, 12), 
Array(1, 2, 3, 4, 5, 6), Array(678, 99, 88)) // assert: GroupByKeyWarn 47 | .toDS() 48 | .groupByKey(l => l.length >= 3) // assert: GroupByKeyWarn 49 | .count() 50 | 51 | val ds5 = Seq("test 1", "test 2", "test 3").toDS() 52 | 53 | val ds6 = Seq("test 1", "test 2", "test 3").toDS().groupBy("value").count() 54 | 55 | val ds7 = Seq("test 1", "test 2", "test 3").toDF().groupBy("value").count() 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /pysparkler/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | 19 | import os 20 | 21 | import libcst as cst 22 | 23 | 24 | def rewrite(given_code: str, cst_transformer: cst.CSTTransformer): 25 | given_tree = cst.parse_module(given_code) 26 | modified_tree = given_tree.visit(cst_transformer) 27 | modified_code = modified_tree.code 28 | return modified_code 29 | 30 | 31 | def absolute_path(relative_path: str): 32 | cwd = os.getcwd() 33 | # Tokenize on the path separator 34 | relative_path_tokens = relative_path.split(os.path.sep) 35 | cwd_tokens = cwd.split(os.path.sep) 36 | 37 | # Check if last token of cwd is the same as the first token of the relative path 38 | if cwd_tokens[-1] == relative_path_tokens[0]: 39 | # Remove the first token of the relative path 40 | relative_path_tokens.pop(0) 41 | # Join the remaining tokens 42 | relative_path = os.path.sep.join(relative_path_tokens) 43 | # Return the relative path 44 | 45 | # Join with the current directory to make the absolute path 46 | return os.path.join(cwd, relative_path) 47 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/src/main/scala/com/holdenkarau/sparkdemoproject/CountingApp.scala: -------------------------------------------------------------------------------- 1 | package com.holdenkarau.sparkDemoProject 2 | 3 | import org.apache.spark.{SparkConf, SparkContext} 4 | import org.apache.spark.sql._ 5 | 6 | /** 7 | * Use this to test the app locally, from sbt: 8 | * sbt "run inputFile.txt outputFile.txt" 9 | * (+ select CountingLocalApp when prompted) 10 | */ 11 | object CountingLocalApp { 12 | def main(args: Array[String]) = { 13 | val (inputFile, outputFile) = (args(0), args(1)) 14 | val conf = new SparkConf() 15 | .setMaster("local") 16 | .setAppName("my awesome app") 17 | 18 | Runner.run(conf, inputFile, outputFile) 19 | } 20 | } 21 | 22 | /** 23 | * Use this when submitting the app to a cluster with spark-submit 24 | * */ 25 | object CountingApp { 26 | def main(args: Array[String]) = { 27 | val (inputFile, outputFile) = (args(0), args(1)) 28 | 29 | // spark-submit command 
should supply all necessary config elements 30 | Runner.run(new SparkConf(), inputFile, outputFile) 31 | } 32 | } 33 | 34 | object Runner { 35 | def run(conf: SparkConf, inputPath: String, outputTable: String): Unit = { 36 | val sc = new SparkContext(conf) 37 | val spark = SparkSession.builder().getOrCreate() 38 | val df = spark.read.format("text").load(inputPath) 39 | val counts = WordCount.dataFrameWC(df) 40 | counts.cache() 41 | counts.count() 42 | // Try and append, or create. 43 | try { 44 | counts.write.format("iceberg").mode("overwrite").save(outputTable) 45 | } catch { 46 | case e => 47 | println(s"Error $e writing to $outputTable creating fresh output table at location.") 48 | spark.sql(s"CREATE TABLE ${outputTable} (word string, count long) USING iceberg") 49 | counts.write.format("iceberg").saveAsTable(outputTable) 50 | spark.sql(s"ALTER TABLE ${outputTable} SET write.wap.enabled=true") 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /pysparkler/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | files: ^pysparkler/ 2 | exclude: ^pysparkler/tests/sample 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-docstring-first 11 | - id: debug-statements 12 | - id: check-yaml 13 | - id: check-ast 14 | - repo: https://github.com/ambv/black 15 | rev: 23.3.0 16 | hooks: 17 | - id: black 18 | - repo: https://github.com/pre-commit/mirrors-isort 19 | rev: v5.10.1 20 | hooks: 21 | - id: isort 22 | args: 23 | - --settings-path=pysparkler/pyproject.toml 24 | - repo: https://github.com/pre-commit/mirrors-mypy 25 | rev: v1.1.1 26 | hooks: 27 | - id: mypy 28 | args: 29 | - --install-types 30 | - --non-interactive 31 | - --config=pysparkler/pyproject.toml 32 | - repo: https://github.com/hadialqattan/pycln 33 | rev: v2.1.3 34 | hooks: 35 | - id: pycln 36 | args: 37 | - --config=pysparkler/pyproject.toml 38 | - repo: https://github.com/asottile/pyupgrade 39 | rev: v3.3.1 40 | hooks: 41 | - id: pyupgrade 42 | args: 43 | - --py310-plus 44 | - repo: https://github.com/pycqa/pylint 45 | rev: v2.17.1 46 | hooks: 47 | - id: pylint 48 | args: 49 | - --rcfile=pysparkler/pylintrc 50 | - repo: https://github.com/pycqa/flake8 51 | rev: 6.0.0 52 | hooks: 53 | - id: flake8 54 | args: 55 | - --max-line-length=120 56 | additional_dependencies: 57 | - flake8-bugbear 58 | - flake8-comprehensions 59 | - repo: https://github.com/executablebooks/mdformat 60 | rev: 0.7.17 61 | hooks: 62 | - id: mdformat 63 | additional_dependencies: 64 | - mdformat-black 65 | - mdformat-config 66 | - mdformat-beautysh 67 | - mdformat-admon 68 | -------------------------------------------------------------------------------- /pysparkler/tests/test_cli.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. 
You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | from tempfile import NamedTemporaryFile 19 | 20 | from click.testing import CliRunner 21 | 22 | from pysparkler.cli import run 23 | from tests.conftest import absolute_path 24 | 25 | 26 | def test_upgrade_cli(): 27 | with NamedTemporaryFile(mode="w", delete=False, encoding="utf-8") as output_file: 28 | runner = CliRunner() 29 | result = runner.invoke( 30 | cli=run, 31 | args=[ 32 | "--config-yaml", 33 | absolute_path("tests/sample/config.yaml"), 34 | "upgrade", 35 | "--input-file", 36 | absolute_path("tests/sample/input_pyspark.py"), 37 | "--output-file", 38 | output_file.name, 39 | ], 40 | ) 41 | 42 | print(result.output) 43 | assert result.exit_code == 0 # Check exit code 44 | 45 | with open(output_file.name, encoding="utf-8") as f: 46 | modified_code = f.read() 47 | assert "A new comment" in modified_code 48 | assert "PY24-30-002" not in modified_code 49 | 50 | # Clean up and delete the temporary file 51 | output_file.close() 52 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/GroupByKeyRewrite.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import org.apache.spark.sql.SparkSession 3 | import org.apache.spark.sql.functions._ 4 | 5 | object GroupByKeyRewrite { 6 | def isSource1(sparkSession: SparkSession): Unit = { 7 | import sparkSession.implicits._ 8 | val ds1 = 9 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 10 | .toDS() 11 | .groupByKey(l => l.substring(0, 3)) 12 | .count() 13 | .withColumnRenamed("key", "newName") 14 | 15 | val ds11 = 16 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 17 | .toDS() 18 | .withColumnRenamed("value", "newName") 19 | 20 | val df11 = 21 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 22 | .toDF() 23 | .withColumnRenamed("value", "newName") 24 | 25 | val ds2 = 26 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 27 | .toDS() 28 | .groupByKey(l => l.substring(0, 3)) 29 | .count() 30 | .select($"key", $"count(1)") 31 | 32 | val ds3 = 33 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 34 | .toDS() 35 | .groupByKey(l => l.substring(0, 3)) 36 | .count() 37 | .select(col("key"), col("count(1)")) 38 | 39 | val ds4 = 40 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 41 | .toDS() 42 | .groupByKey(l => l.substring(0, 3)) 43 | .count() 44 | .select('key, 'count (1)) 45 | 46 | val ds5 = 47 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 48 | .toDS() 49 | .groupByKey(l => l.substring(0, 3)) 50 | .count() 51 | .withColumn("newNameCol", upper(col("key"))) 52 | 53 | val ds6 = 54 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 55 | .toDS() 56 | .groupByKey(l => l.substring(0, 3)) 57 | .count() 58 | .withColumn("value", upper(col("key"))) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- 
/e2e_demo/scala/sparkdemoproject/src/main/scala/com/holdenkarau/sparkdemoproject/WordCount.scala: -------------------------------------------------------------------------------- 1 | package com.holdenkarau.sparkDemoProject 2 | 3 | /** 4 | * Everyone's favourite wordcount example. 5 | */ 6 | 7 | import org.apache.spark._ 8 | import org.apache.spark.rdd._ 9 | import org.apache.spark.sql._ 10 | import org.apache.spark.sql.functions._ 11 | 12 | object WordCount { 13 | /** 14 | * A slightly more complex than normal wordcount example with optional 15 | * separators and stopWords. Splits on the provided separators, removes 16 | * the stopwords, and converts everything to lower case. 17 | */ 18 | def dataFrameWC(df : DataFrame, 19 | separators : Array[Char] = " ".toCharArray, 20 | stopWords : Set[String] = Set("the")): DataFrame = { 21 | // Yes this is deprecated, but it should get rewritten by our magic. 22 | val spark = SQLContext.getOrCreate(SparkContext.getOrCreate()) 23 | import spark.implicits._ 24 | val splitPattern = "[" + separators.mkString("") + "]" 25 | val stopArray = array(stopWords.map(lit).toSeq:_*) 26 | val words = df.select(explode(split(lower(col("value")), splitPattern)).as("words")).filter( 27 | not(array_contains(stopArray, col("words")))) 28 | // This will need to be re-written in 2 -> 3 29 | // Normally we would use groupBy(col("words")) instead by that doesn't require the migration step :p 30 | def keyMe(x: Row): String = { 31 | x.apply(0).asInstanceOf[String] 32 | } 33 | words.groupByKey(keyMe).count().select(col("value").as("word"), col("count(1)").as("count")).orderBy("count") 34 | } 35 | 36 | def withStopWordsFiltered(rdd : RDD[String], 37 | separators : Array[Char] = " ".toCharArray, 38 | stopWords : Set[String] = Set("the")): RDD[(String, Long)] = { 39 | val spark = SQLContext.getOrCreate(SparkContext.getOrCreate()) 40 | import spark.implicits._ 41 | val df = rdd.toDF 42 | val resultDF = dataFrameWC(df, separators, stopWords) 43 | resultDF.rdd.map { row => 44 | (row.apply(0).asInstanceOf[String], row.apply(1).asInstanceOf[Long]) 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/GroupByKeyRewrite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=GroupByKeyRewrite 3 | */ 4 | package fix 5 | import org.apache.spark.sql.SparkSession 6 | import org.apache.spark.sql.functions._ 7 | 8 | object GroupByKeyRewrite { 9 | def isSource1(sparkSession: SparkSession): Unit = { 10 | import sparkSession.implicits._ 11 | val ds1 = 12 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 13 | .toDS() 14 | .groupByKey(l => l.substring(0, 3)) 15 | .count() 16 | .withColumnRenamed("value", "newName") 17 | 18 | val ds11 = 19 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 20 | .toDS() 21 | .withColumnRenamed("value", "newName") 22 | 23 | val df11 = 24 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 25 | .toDF() 26 | .withColumnRenamed("value", "newName") 27 | 28 | val ds2 = 29 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 30 | .toDS() 31 | .groupByKey(l => l.substring(0, 3)) 32 | .count() 33 | .select($"value", $"count(1)") 34 | 35 | val ds3 = 36 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 37 | .toDS() 38 | .groupByKey(l => l.substring(0, 3)) 39 | .count() 40 | .select(col("value"), col("count(1)")) 41 | 42 | val ds4 = 43 | List("Paerson 1", 
"Person 2", "User 1", "User 2", "test", "gggg") 44 | .toDS() 45 | .groupByKey(l => l.substring(0, 3)) 46 | .count() 47 | .select('value, 'count (1)) 48 | 49 | val ds5 = 50 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 51 | .toDS() 52 | .groupByKey(l => l.substring(0, 3)) 53 | .count() 54 | .withColumn("newNameCol", upper(col("value"))) 55 | 56 | val ds6 = 57 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 58 | .toDS() 59 | .groupByKey(l => l.substring(0, 3)) 60 | .count() 61 | .withColumn("value", upper(col("value"))) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala-2.11/fix/UnionRewrite.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import metaconfig.generic.Surface 3 | import metaconfig.{ConfDecoder, Configured} 4 | import scalafix.v1._ 5 | 6 | import scala.meta._ 7 | final case class UnionRewriteConfig( 8 | deprecatedMethod: Map[String, String] 9 | ) 10 | 11 | object UnionRewriteConfig { 12 | val default: UnionRewriteConfig = 13 | UnionRewriteConfig( 14 | deprecatedMethod = Map( 15 | "unionAll" -> "union" 16 | ) 17 | ) 18 | 19 | implicit val surface: Surface[UnionRewriteConfig] = 20 | metaconfig.generic.deriveSurface[UnionRewriteConfig] 21 | implicit val decoder: ConfDecoder[UnionRewriteConfig] = 22 | metaconfig.generic.deriveDecoder(default) 23 | } 24 | 25 | class UnionRewrite(config: UnionRewriteConfig) extends SemanticRule("UnionRewrite") { 26 | def this() = this(UnionRewriteConfig.default) 27 | 28 | override def withConfiguration(config: Configuration): Configured[Rule] = 29 | config.conf.getOrElse("UnionRewrite")(this.config).map { newConfig => 30 | new UnionRewrite(newConfig) 31 | } 32 | 33 | override val isRewrite = true 34 | 35 | override def fix(implicit doc: SemanticDocument): Patch = { 36 | def matchOnTree(t: Tree): Patch = { 37 | t.collect { 38 | case Term.Apply( 39 | Term.Select(_, deprecated @ Term.Name(name)), 40 | _ 41 | ) if config.deprecatedMethod.contains(name) => 42 | Patch.replaceTree( 43 | deprecated, 44 | config.deprecatedMethod(name) 45 | ) 46 | case Term.Apply( 47 | Term.Select(_, _ @Term.Name(name)), 48 | List( 49 | Term.AnonymousFunction( 50 | Term.ApplyInfix( 51 | _, 52 | deprecatedAnm @ Term.Name(nameAnm), 53 | _, 54 | _ 55 | ) 56 | ) 57 | ) 58 | ) if "reduce".contains(name) && config.deprecatedMethod.contains(nameAnm) => 59 | Patch.replaceTree( 60 | deprecatedAnm, 61 | config.deprecatedMethod(nameAnm) 62 | ) 63 | }.asPatch 64 | } 65 | 66 | matchOnTree(doc.tree) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/build.sbt: -------------------------------------------------------------------------------- 1 | // give the user a nice default project! 2 | 3 | val sparkVersion = settingKey[String]("Spark version") 4 | 5 | lazy val root = (project in file(".")). 
6 | 7 | settings( 8 | inThisBuild(List( 9 | organization := "com.holdenkarau", 10 | scalaVersion := "2.12.13" 11 | )), 12 | name := "sparkDemoProject", 13 | version := "0.0.1", 14 | 15 | sparkVersion := "2.4.8", 16 | 17 | javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), 18 | javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:MaxPermSize=2048M", "-XX:+CMSClassUnloadingEnabled"), 19 | scalacOptions ++= Seq("-deprecation", "-unchecked"), 20 | parallelExecution in Test := false, 21 | fork := true, 22 | 23 | coverageHighlighting := true, 24 | 25 | libraryDependencies ++= Seq( 26 | "org.apache.spark" %% "spark-streaming" % "2.4.8" % "provided", 27 | "org.apache.spark" %% "spark-sql" % "2.4.8" % "provided", 28 | 29 | "org.scalatest" %% "scalatest" % "3.2.2" % "test", 30 | "org.scalacheck" %% "scalacheck" % "1.15.2" % "test", 31 | "com.holdenkarau" %% "spark-testing-base" % "2.4.8_1.3.0" % "test" 32 | ), 33 | 34 | // uses compile classpath for the run task, including "provided" jar (cf http://stackoverflow.com/a/21803413/3827) 35 | run in Compile := Defaults.runTask(fullClasspath in Compile, mainClass in (Compile, run), runner in (Compile, run)).evaluated, 36 | 37 | scalacOptions ++= Seq("-deprecation", "-unchecked"), 38 | pomIncludeRepository := { x => false }, 39 | 40 | resolvers ++= Seq( 41 | "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/", 42 | "Typesafe repository" at "https://repo.typesafe.com/typesafe/releases/", 43 | "Second Typesafe repo" at "https://repo.typesafe.com/typesafe/maven-releases/", 44 | Resolver.sonatypeRepo("public") 45 | ), 46 | 47 | pomIncludeRepository := { _ => false }, 48 | 49 | // publish settings 50 | publishTo := { 51 | val nexus = "https://oss.sonatype.org/" 52 | if (isSnapshot.value) 53 | Some("snapshots" at nexus + "content/repositories/snapshots") 54 | else 55 | Some("releases" at nexus + "service/local/staging/deploy/maven2") 56 | } 57 | ) 58 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/SQLContextConstructor.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=MigrateToSparkSessionBuilder 3 | */ 4 | import org.apache.spark._ 5 | import org.apache.spark.sql._ 6 | 7 | import java.sql.{Date, Timestamp} 8 | 9 | import org.apache.spark.sql.types._ 10 | import org.apache.spark.sql.{DataFrame, Row, SQLContext} 11 | import org.scalacheck.{Arbitrary, Gen} 12 | 13 | object BadSessionBuilder { 14 | def getSQLContext(sc: SparkContext): SQLContext = { 15 | val ctx = new SQLContext(sc) 16 | ctx 17 | } 18 | 19 | def getOrCreateSQL(sc: SparkContext): SQLContext = { 20 | val ctx = SQLContext.getOrCreate(sc) 21 | val boop = SQLContext.clearActive() // We shouldn't rewrite this 22 | ctx 23 | } 24 | 25 | // This function is unrelated but early tests had arbLong rewrite for some reason. 
26 | private def getGenerator( 27 | dataType: DataType, generators: Seq[_] = Seq()): Gen[Any] = { 28 | dataType match { 29 | case ByteType => Arbitrary.arbitrary[Byte] 30 | case ShortType => Arbitrary.arbitrary[Short] 31 | case IntegerType => Arbitrary.arbitrary[Int] 32 | case LongType => Arbitrary.arbitrary[Long] 33 | case FloatType => Arbitrary.arbitrary[Float] 34 | case DoubleType => Arbitrary.arbitrary[Double] 35 | case StringType => Arbitrary.arbitrary[String] 36 | case BinaryType => Arbitrary.arbitrary[Array[Byte]] 37 | case BooleanType => Arbitrary.arbitrary[Boolean] 38 | case TimestampType => Arbitrary.arbLong.arbitrary.map(new Timestamp(_)) 39 | case DateType => Arbitrary.arbLong.arbitrary.map(new Date(_)) 40 | case arr: ArrayType => { 41 | val elementGenerator = getGenerator(arr.elementType) 42 | Gen.listOf(elementGenerator) 43 | } 44 | case map: MapType => { 45 | val keyGenerator = getGenerator(map.keyType) 46 | val valueGenerator = getGenerator(map.valueType) 47 | val keyValueGenerator: Gen[(Any, Any)] = for { 48 | key <- keyGenerator 49 | value <- valueGenerator 50 | } yield (key, value) 51 | 52 | Gen.mapOf(keyValueGenerator) 53 | } 54 | case row: StructType => None 55 | case _ => throw new UnsupportedOperationException( 56 | s"Type: $dataType not supported") 57 | } 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/GroupByKeyWarn.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | case class GroupByKeyWarning(tn: scala.meta.Tree) extends Diagnostic { 7 | override def position: Position = tn.pos 8 | 9 | override def message: String = 10 | """In Spark 2.4 and below, 11 | |Dataset.groupByKey results to a grouped dataset with key attribute is wrongly named as “value”, 12 | |if the key is non-struct type, for example, int, string, array, etc. 13 | |This is counterintuitive and makes the schema of aggregation queries unexpected. 14 | |For example, the schema of ds.groupByKey(...).count() is (value, count). 15 | |Since Spark 3.0, we name the grouping attribute to “key”. 16 | |The old behavior is preserved under a newly added configuration 17 | |spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue with a default value of false. 18 | |This linter rule is fuzzy.""".stripMargin 19 | } 20 | 21 | class GroupByKeyWarn extends SemanticRule("GroupByKeyWarn") { 22 | val matcher = SymbolMatcher.normalized("org.apache.spark.sql.Dataset.groupByKey") 23 | override val description = "GroupByKey Warning." 24 | 25 | override def fix(implicit doc: SemanticDocument): Patch = { 26 | // Hacky. 
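    // Cheap textual pre-filter first (the source must mention groupByKey, "value", and
    // org.apache.spark.sql); only then do we walk the tree and flag both direct symbol matches
    // and the ...toDS().groupByKey(...).count() shape, where the grouping column rename from
    // "value" to "key" in 3.0 is most likely to bite.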
27 | val grpByKey = "groupByKey" 28 | val funcToDS = "toDS" 29 | val agrFunCount = "count" 30 | 31 | if (doc.input.text.contains("groupByKey") && doc.input.text.contains("value") && 32 | doc.input.text.contains("org.apache.spark.sql")) { 33 | doc.tree.collect { 34 | case matcher(gbk) => 35 | Patch.lint(GroupByKeyWarning(gbk)) 36 | case t @ Term.Apply( 37 | Term.Select( 38 | Term.Apply( 39 | Term.Select( 40 | Term.Apply(Term.Select(_, _ @Term.Name(fName)), _), 41 | gbk @ Term.Name(name) 42 | ), 43 | _ 44 | ), 45 | _ @Term.Name(oprName) 46 | ), 47 | _ 48 | ) 49 | if grpByKey.equals(name) && funcToDS.equals(fName) && agrFunCount 50 | .equals(oprName) => 51 | Patch.lint(GroupByKeyWarning(t)) 52 | }.asPatch 53 | } else { 54 | Patch.empty 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala-2.12/fix/UnionRewrite.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import metaconfig.ConfDecoder.canBuildFromAnyMapWithStringKey 3 | import metaconfig.generic.Surface 4 | import metaconfig.{ConfDecoder, Configured} 5 | import scalafix.v1._ 6 | 7 | import scala.meta._ 8 | final case class UnionRewriteConfig( 9 | deprecatedMethod: Map[String, String] 10 | ) 11 | 12 | object UnionRewriteConfig { 13 | val default: UnionRewriteConfig = 14 | UnionRewriteConfig( 15 | deprecatedMethod = Map( 16 | "unionAll" -> "union" 17 | ) 18 | ) 19 | 20 | implicit val surface: Surface[UnionRewriteConfig] = 21 | metaconfig.generic.deriveSurface[UnionRewriteConfig] 22 | implicit val decoder: ConfDecoder[UnionRewriteConfig] = 23 | metaconfig.generic.deriveDecoder(default) 24 | } 25 | 26 | class UnionRewrite(config: UnionRewriteConfig) extends SemanticRule("UnionRewrite") { 27 | def this() = this(UnionRewriteConfig.default) 28 | 29 | override def withConfiguration(config: Configuration): Configured[Rule] = 30 | config.conf.getOrElse("UnionRewrite")(this.config).map { newConfig => 31 | new UnionRewrite(newConfig) 32 | } 33 | 34 | override val isRewrite = true 35 | 36 | override def fix(implicit doc: SemanticDocument): Patch = { 37 | def matchOnTree(t: Tree): Patch = { 38 | t.collect { 39 | case Term.Apply( 40 | Term.Select(_, deprecated @ Term.Name(name)), 41 | _ 42 | ) if config.deprecatedMethod.contains(name) => 43 | Patch.replaceTree( 44 | deprecated, 45 | config.deprecatedMethod(name) 46 | ) 47 | case Term.Apply( 48 | Term.Select(_, _ @Term.Name(name)), 49 | List( 50 | Term.AnonymousFunction( 51 | Term.ApplyInfix( 52 | _, 53 | deprecatedAnm @ Term.Name(nameAnm), 54 | _, 55 | _ 56 | ) 57 | ) 58 | ) 59 | ) if "reduce".contains(name) && config.deprecatedMethod.contains(nameAnm) => 60 | Patch.replaceTree( 61 | deprecatedAnm, 62 | config.deprecatedMethod(nameAnm) 63 | ) 64 | }.asPatch 65 | } 66 | 67 | matchOnTree(doc.tree) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala-2.13/fix/UnionRewrite.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | import metaconfig.ConfDecoder.canBuildFromAnyMapWithStringKey 3 | import metaconfig.generic.Surface 4 | import metaconfig.{ConfDecoder, Configured} 5 | import scalafix.v1._ 6 | 7 | import scala.meta._ 8 | final case class UnionRewriteConfig( 9 | deprecatedMethod: Map[String, String] 10 | ) 11 | 12 | object UnionRewriteConfig { 13 | val default: UnionRewriteConfig = 14 | UnionRewriteConfig( 15 | deprecatedMethod = 
Map( 16 | "unionAll" -> "union" 17 | ) 18 | ) 19 | 20 | implicit val surface: Surface[UnionRewriteConfig] = 21 | metaconfig.generic.deriveSurface[UnionRewriteConfig] 22 | implicit val decoder: ConfDecoder[UnionRewriteConfig] = 23 | metaconfig.generic.deriveDecoder(default) 24 | } 25 | 26 | class UnionRewrite(config: UnionRewriteConfig) extends SemanticRule("UnionRewrite") { 27 | def this() = this(UnionRewriteConfig.default) 28 | 29 | override def withConfiguration(config: Configuration): Configured[Rule] = 30 | config.conf.getOrElse("UnionRewrite")(this.config).map { newConfig => 31 | new UnionRewrite(newConfig) 32 | } 33 | 34 | override val isRewrite = true 35 | 36 | override def fix(implicit doc: SemanticDocument): Patch = { 37 | def matchOnTree(t: Tree): Patch = { 38 | t.collect { 39 | case Term.Apply( 40 | Term.Select(_, deprecated @ Term.Name(name)), 41 | _ 42 | ) if config.deprecatedMethod.contains(name) => 43 | Patch.replaceTree( 44 | deprecated, 45 | config.deprecatedMethod(name) 46 | ) 47 | case Term.Apply( 48 | Term.Select(_, _ @Term.Name(name)), 49 | List( 50 | Term.AnonymousFunction( 51 | Term.ApplyInfix( 52 | _, 53 | deprecatedAnm @ Term.Name(nameAnm), 54 | _, 55 | _ 56 | ) 57 | ) 58 | ) 59 | ) if "reduce".contains(name) && config.deprecatedMethod.contains(nameAnm) => 60 | Patch.replaceTree( 61 | deprecatedAnm, 62 | config.deprecatedMethod(nameAnm) 63 | ) 64 | }.asPatch 65 | } 66 | 67 | matchOnTree(doc.tree) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /sql/test/rules/test_cases/SPARK_SQL_RESERVED_PROPERTIES.yml: -------------------------------------------------------------------------------- 1 | rule: RESERVEDROPERTIES_L002 2 | 3 | create_db: 4 | configs: 5 | core: 6 | dialect: sparksql 7 | fail_str: | 8 | CREATE DATABASE boop WITH DBPROPERTIES("location" = "farts", "junk" = "farts") 9 | fix_str: | 10 | CREATE DATABASE boop LOCATION "farts" WITH DBPROPERTIES( "junk" = "farts") 11 | 12 | create_db_only_loc_prop: 13 | configs: 14 | core: 15 | dialect: sparksql 16 | fail_str: | 17 | CREATE DATABASE boop WITH DBPROPERTIES("location" = "farts") 18 | fix_str: | 19 | CREATE DATABASE boop LOCATION "farts" WITH DBPROPERTIES("legacy_location" = "farts") 20 | 21 | create_partitioned_tbl: 22 | configs: 23 | core: 24 | dialect: sparksql 25 | fail_str: | 26 | CREATE TABLE boop (id INT) PARTITIONED BY (id) TBLPROPERTIES("provider" = "parquet", "ok" = "three") 27 | fix_str: | 28 | CREATE TABLE boop (id INT) USING parquet PARTITIONED BY (id) TBLPROPERTIES( "ok" = "three") 29 | 30 | create_empty_parquet_tbl: 31 | configs: 32 | core: 33 | dialect: sparksql 34 | fail_str: | 35 | CREATE TABLE boop TBLPROPERTIES("provider" = "parquet", "ok" = "three") 36 | fix_str: | 37 | CREATE TABLE boop USING parquet TBLPROPERTIES( "ok" = "three") 38 | 39 | create_empty_loc_spec_tbl: 40 | configs: 41 | core: 42 | dialect: sparksql 43 | fail_str: | 44 | CREATE TABLE boop TBLPROPERTIES("location" = "butts") 45 | fix_str: | 46 | CREATE TABLE boop LOCATION "butts" TBLPROPERTIES("legacy_location" = "butts") 47 | 48 | alter_db: 49 | configs: 50 | core: 51 | dialect: sparksql 52 | fail_str: | 53 | ALTER DATABASE meaning_of_life SET DBPROPERTIES('location'='/storage/earth/42') 54 | fix_str: | 55 | ALTER DATABASE meaning_of_life SET DBPROPERTIES("legacy_location"='/storage/earth/42') 56 | 57 | alter_tbl: 58 | configs: 59 | core: 60 | dialect: sparksql 61 | fail_str: | 62 | ALTER TABLE meaning_of_life SET TBLPROPERTIES('location'='/storage/earth/42') 63 | fix_str: | 
64 | ALTER TABLE meaning_of_life SET TBLPROPERTIES("legacy_location"='/storage/earth/42') 65 | 66 | create_tbl_ok: 67 | configs: 68 | core: 69 | dialect: sparksql 70 | pass_str: | 71 | CREATE TABLE boop (id INT) USING parquet TBLPROPERTIES( "ok" = "three") 72 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/SparkSQLCallExternal.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import java.io._ 4 | 5 | 6 | import scalafix.v1._ 7 | import scala.meta._ 8 | import sys.process._ 9 | 10 | class SparkSQLCallExternal extends SemanticRule("SparkSQLCallExternal") { 11 | 12 | override def fix(implicit doc: SemanticDocument): Patch = { 13 | val sparkSQLFunMatch = SymbolMatcher.normalized("org.apache.spark.sql.SparkSession.sql") 14 | val utils = new Utils() 15 | 16 | def matchOnTree(e: Tree): Patch = { 17 | e match { 18 | // non-named accumulator 19 | case ns @ Term.Apply(j @ sparkSQLFunMatch(f), params) => 20 | // Find the spark context for rewriting 21 | params match { 22 | case List(param) => 23 | param match { 24 | case s @ Lit.String(_) => 25 | // Write out the SQL to a file 26 | val f = File.createTempFile("magic", ".sql") 27 | f.deleteOnExit() 28 | val bw = new BufferedWriter(new FileWriter(f)) 29 | bw.write(s.value.toString) 30 | bw.close() 31 | // Run SQL fluff 32 | val strToRun = s"sqlfluff fix --dialect sparksql -f ${f.toPath}" 33 | println(s"Running ${strToRun}") 34 | val ret = strToRun.! 35 | println(ret) 36 | val newSQL = scala.io.Source.fromFile(f).mkString 37 | // We don't care about whitespace only changes. 38 | if (newSQL.filterNot(_.isWhitespace) != s) { 39 | Patch.replaceTree(param, "\"\"\"" + newSQL + "\"\"\"") 40 | } else { 41 | Patch.empty 42 | } 43 | case _ => 44 | // TODO: Do we want to warn here about non migrated dynamically generated SQL 45 | // or no? 46 | Patch.empty 47 | } 48 | case _ => 49 | Patch.empty 50 | } 51 | case elem @ _ => 52 | elem.children match { 53 | case Nil => Patch.empty 54 | case _ => elem.children.map(matchOnTree).asPatch 55 | } 56 | } 57 | } 58 | matchOnTree(doc.tree) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/SQLContextConstructor.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark._ 2 | import org.apache.spark.sql._ 3 | 4 | import java.sql.{Date, Timestamp} 5 | 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.sql.{DataFrame, Row, SQLContext} 8 | import org.scalacheck.{Arbitrary, Gen} 9 | import org.apache.spark.sql.SparkSession 10 | 11 | object BadSessionBuilder { 12 | def getSQLContext(sc: SparkContext): SQLContext = { 13 | val ctx = SparkSession.builder.getOrCreate().sqlContext 14 | ctx 15 | } 16 | 17 | def getOrCreateSQL(sc: SparkContext): SQLContext = { 18 | val ctx = SparkSession.builder.getOrCreate().sqlContext 19 | val boop = SQLContext.clearActive() // We shouldn't rewrite this 20 | ctx 21 | } 22 | 23 | // This function is unrelated but early tests had arbLong rewrite for some reason. 
24 | private def getGenerator( 25 | dataType: DataType, generators: Seq[_] = Seq()): Gen[Any] = { 26 | dataType match { 27 | case ByteType => Arbitrary.arbitrary[Byte] 28 | case ShortType => Arbitrary.arbitrary[Short] 29 | case IntegerType => Arbitrary.arbitrary[Int] 30 | case LongType => Arbitrary.arbitrary[Long] 31 | case FloatType => Arbitrary.arbitrary[Float] 32 | case DoubleType => Arbitrary.arbitrary[Double] 33 | case StringType => Arbitrary.arbitrary[String] 34 | case BinaryType => Arbitrary.arbitrary[Array[Byte]] 35 | case BooleanType => Arbitrary.arbitrary[Boolean] 36 | case TimestampType => Arbitrary.arbLong.arbitrary.map(new Timestamp(_)) 37 | case DateType => Arbitrary.arbLong.arbitrary.map(new Date(_)) 38 | case arr: ArrayType => { 39 | val elementGenerator = getGenerator(arr.elementType) 40 | Gen.listOf(elementGenerator) 41 | } 42 | case map: MapType => { 43 | val keyGenerator = getGenerator(map.keyType) 44 | val valueGenerator = getGenerator(map.valueType) 45 | val keyValueGenerator: Gen[(Any, Any)] = for { 46 | key <- keyGenerator 47 | value <- valueGenerator 48 | } yield (key, value) 49 | 50 | Gen.mapOf(keyValueGenerator) 51 | } 52 | case row: StructType => None 53 | case _ => throw new UnsupportedOperationException( 54 | s"Type: $dataType not supported") 55 | } 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /iceberg-spark-upgrade-wap-plugin/src/test/scala/com/holdenkarau/spark/upgrade/wap/plugin/WAPIcebergSpec.scala: -------------------------------------------------------------------------------- 1 | package com.holdenkarau.spark.upgrade.wap.plugin 2 | 3 | import java.io.File 4 | 5 | import scala.util.matching.Regex 6 | 7 | import com.holdenkarau.spark.testing.ScalaDataFrameSuiteBase 8 | import com.holdenkarau.spark.testing.Utils 9 | 10 | import org.scalatest.funsuite.AnyFunSuite 11 | import org.scalatest.matchers.should.Matchers 12 | 13 | import org.apache.spark.SparkConf 14 | 15 | /** 16 | * This tests both the listener and the agent since if the agent is not loaded the 17 | * WAPIcebergListener will not be registered and none of the tests will pass :D 18 | */ 19 | class WAPIcebergSpec extends AnyFunSuite with ScalaDataFrameSuiteBase with Matchers { 20 | 21 | val delay = 2 22 | 23 | override protected def enableHiveSupport = false 24 | override protected def enableIcebergSupport = true 25 | 26 | override def conf: SparkConf = { 27 | new SparkConf(). 28 | setMaster("local[*]"). 29 | setAppName("test"). 30 | set("spark.ui.enabled", "false"). 31 | set("spark.app.id", appID). 32 | set("spark.driver.host", "localhost"). 33 | // Hack 34 | set("spark.driver.userClassPathFirst", "true"). 35 | set("spark.executor.userClassPathFirst", "true") 36 | } 37 | 38 | test("WAPIcebergSpec should be called on iceberg commit") { 39 | val re = """IcebergListener: Created snapshot (\d+) on table (.+?) summary .*? from operation (.+)""".r 40 | spark.sql("CREATE TABLE local.db.table (id bigint, data string) USING iceberg") 41 | // there _might be_ a timing race condition here since were using a listener 42 | // that is not blocking the write path. 
43 | spark.sql("INSERT INTO local.db.table VALUEs (1, 'timbit')") 44 | Thread.sleep(delay) 45 | val firstLog = WAPIcebergListener.lastLog 46 | firstLog should fullyMatch regex re 47 | firstLog match { 48 | case re(snapshot, table, op) => 49 | snapshot.toLong should be > 0L 50 | table should be ("local.db.table") 51 | op should be ("append") 52 | } 53 | spark.sql("INSERT INTO local.db.table VALUEs (2, 'timbot')") 54 | Thread.sleep(delay) 55 | val secondLog = WAPIcebergListener.lastLog 56 | secondLog should not equal firstLog 57 | secondLog should fullyMatch regex re 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /pysparkler/tests/sample/input_pyspark.py: -------------------------------------------------------------------------------- 1 | import pyspark 2 | import numpy as np 3 | import pandas as pd 4 | import pyspark.pandas as ps 5 | 6 | from pandas import DataFrame as df 7 | from pyspark.sql import SparkSession, Row 8 | from pyspark.sql.functions import pandas_udf, PandasUDFType 9 | from pyspark.ml.param.shared import * 10 | 11 | spark = SparkSession.builder.appName('example').getOrCreate() 12 | spark.conf.set("spark.sql.execution.arrow.enabled", "true") 13 | 14 | table_name = "my_table" 15 | result = spark.sql(f"select cast(dateint as int) val from {table_name} limit 10") 16 | 17 | data = [("James", "", "Smith", "36636", "M", 60000), 18 | ("Jen", "Mary", "Brown", "", "F", 0)] 19 | 20 | columns = ["first_name", "middle_name", "last_name", "dob", "gender", "salary"] 21 | pysparkDF = spark.createDataFrame(data=data, schema=columns, verifySchema=True) 22 | 23 | pandasDF = pysparkDF.toPandas() 24 | print(pandasDF) 25 | 26 | pysparkDF.write.partitionBy('gender').saveAsTable("persons") 27 | pysparkDF.write.insertInto("persons", overwrite=True) 28 | 29 | data = [Row(name="James,,Smith", lang=["Java", "Scala", "C++"], state="CA"), 30 | Row(name="Robert,,Williams", lang=["CSharp", "VB"], state="NV")] 31 | 32 | rdd = spark.sparkContext.parallelize(data) 33 | print(rdd.collect()) 34 | 35 | ps_df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 36 | ps_df.drop(['B', 'C']) 37 | 38 | a_column_values = list(ps_df['A'].unique()) 39 | repr_a_column_values = [repr(value) for value in a_column_values] 40 | 41 | spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") 42 | tz_df = spark.createDataFrame([28801], "long").selectExpr("timestamp(value) as ts") 43 | tz_df.show() 44 | 45 | rp_df = spark.createDataFrame([ 46 | (10, 80.5, "Alice", None), 47 | (5, None, "Bob", None), 48 | (None, None, "Tom", None), 49 | (None, None, None, True)], 50 | schema=["age", "height", "name", "bool"]) 51 | 52 | rp_df.na.replace('Alice').show() 53 | rp_df.na.fill(False).show() 54 | rp_df.fillna(True).show() 55 | 56 | 57 | def truncate(truncate=True): 58 | try: 59 | int_truncate = int(truncate) 60 | except ValueError as ex: 61 | raise TypeError( 62 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 63 | ) 64 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/ScalaTestImportChange.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class ScalaTestImportChange 7 | extends SemanticRule("ScalaTestImportChange") { 8 | override val description = 9 | """Handle the import change with ScalaTest ( see https://www.scalatest.org/release_notes/3.1.0 ) """ 
10 | 11 | override val isRewrite = true 12 | 13 | override def fix(implicit doc: SemanticDocument): Patch = { 14 | 15 | def matchOnTree(t: Tree): Patch = { 16 | t match { 17 | case q"""import org.scalatest.FunSuite""" => 18 | Patch.replaceTree(t, q"""import org.scalatest.funsuite.AnyFunSuite""".toString()) 19 | case q"""class $cls extends FunSuite { $expr }""" => 20 | Patch.replaceTree(t, f"class $cls extends AnyFunSuite { $expr }") 21 | case q"""import org.scalatest.FunSuiteLike""" => 22 | Patch.replaceTree(t, q"""import org.scalatest.funsuite.AnyFunSuiteLike""".toString()) 23 | case q"""class $cls extends FunSuiteLike { $expr }""" => 24 | Patch.replaceTree(t, q"class $cls extends AnyFunSuiteLike { $expr }".toString) 25 | case q"""import org.scalatest.AsyncFunSuite""" => 26 | Patch.replaceTree(t, q"""import org.scalatest.funsuite.AsyncFunSuiteLike""".toString()) 27 | case q"""import org.scalatest.fixture.FunSuite""" => 28 | Patch.replaceTree(t, q"""import org.scalatest.funsuite.FixtureAnyFunSuite""".toString()) 29 | case q"""import org.scalatest.Matchers._""" => 30 | Patch.replaceTree(t, q"""import org.scalatest.matchers.should.Matchers._""".toString()) 31 | case q"""import org.scalatest.Matchers""" => 32 | Patch.replaceTree(t, q"""import org.scalatest.matchers.should.Matchers""".toString()) 33 | case q"""import org.scalatest.MustMatchers""" => 34 | Patch.replaceTree(t, q"""import org.scalatest.matchers.must.{Matchers => MustMatchers}""".toString) 35 | case q"""import org.scalatest.MustMatchers._""" => 36 | Patch.replaceTree(t, """import org.scalatest.matchers.must.Matchers._\n""") 37 | case elem @ _ => 38 | elem.children match { 39 | case Nil => Patch.empty 40 | case _ => 41 | elem.children.map(matchOnTree).asPatch 42 | } 43 | } 44 | } 45 | 46 | matchOnTree(doc.tree) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/AccumulatorUpgrade.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | case class AccMigrationGuide(position: Position) extends Diagnostic { 7 | val migrationGuide = 8 | "https://spark.apache.org/docs/latest/core-migration-guide.html#upgrading-from-core-24-to-30" 9 | def message = s"sc.accumulator is removed see ${migrationGuide}" 10 | } 11 | 12 | class AccumulatorUpgrade extends SemanticRule("AccumulatorUpgrade") { 13 | 14 | override def fix(implicit doc: SemanticDocument): Patch = { 15 | val accumulatorFunMatch = SymbolMatcher.normalized("org.apache.spark.SparkContext.accumulator") 16 | val utils = new Utils() 17 | 18 | def matchOnTree(e: Tree): Patch = { 19 | e match { 20 | // non-named accumulator 21 | case ns @ Term.Apply(j @ accumulatorFunMatch(f), params) => 22 | // Find the spark context for rewriting 23 | val sc = ns.children(0).children(0) 24 | params match { 25 | case List(param) => 26 | param match { 27 | // TODO: Handle non zero values 28 | case utils.intMatcher(initialValue) => 29 | Seq( 30 | Patch.lint(AccMigrationGuide(e.pos)), 31 | Patch.addLeft(e, "/*"), 32 | Patch.addRight(e, "*/ null")).asPatch 33 | case q"0L" => 34 | Patch.replaceTree(ns, s"${sc}.longAccumulator") 35 | case utils.longMatcher(initialValue) => 36 | Patch.empty 37 | case q"0.0" => 38 | Patch.replaceTree(ns, s"${sc}.doubleAccumulator") 39 | case _ => 40 | Seq( 41 | Patch.lint(AccMigrationGuide(e.pos)), 42 | Patch.addLeft(e, "/*"), 43 | Patch.addRight(e, "*/ null")).asPatch 44 | } 45 | case 
List(param, name) => 46 | Seq( 47 | Patch.lint(AccMigrationGuide(e.pos)), 48 | Patch.addLeft(e, "/*"), 49 | Patch.addRight(e, "*/ null")).asPatch 50 | } 51 | case elem @ _ => 52 | elem.children match { 53 | case Nil => Patch.empty 54 | case _ => elem.children.map(matchOnTree).asPatch 55 | } 56 | } 57 | } 58 | matchOnTree(doc.tree) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /pysparkler/tests/test_base.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | 19 | from pysparkler.base import StatementLineCommentWriter 20 | from tests.conftest import rewrite 21 | 22 | 23 | def test_writes_comment_with_noqa_when_match_is_found(): 24 | given_code = """ 25 | import pyspark 26 | """ 27 | comment_writer = StatementLineCommentWriter(transformer_id="foo", comment="bar") 28 | comment_writer.match_found = True 29 | 30 | modified_code = rewrite(given_code, comment_writer) 31 | expected_code = """ 32 | import pyspark # foo: bar # noqa: E501 33 | """ 34 | assert modified_code == expected_code 35 | 36 | 37 | def test_overrides_comment_from_kwargs(): 38 | given_code = """ 39 | import pyspark 40 | """ 41 | overrides = {"comment": "baz"} 42 | comment_writer = StatementLineCommentWriter(transformer_id="foo", comment="bar") 43 | comment_writer.match_found = True 44 | comment_writer.override(**overrides) 45 | 46 | modified_code = rewrite(given_code, comment_writer) 47 | expected_code = """ 48 | import pyspark # foo: baz # noqa: E501 49 | """ 50 | assert modified_code == expected_code 51 | 52 | 53 | def test_overrides_with_unknown_attributes_are_silently_ignored(): 54 | given_code = """ 55 | import pyspark 56 | """ 57 | overrides = {"unknown": "attribute"} 58 | comment_writer = StatementLineCommentWriter(transformer_id="foo", comment="bar") 59 | comment_writer.match_found = True 60 | comment_writer.override(**overrides) 61 | 62 | modified_code = rewrite(given_code, comment_writer) 63 | expected_code = """ 64 | import pyspark # foo: bar # noqa: E501 65 | """ 66 | assert modified_code == expected_code 67 | -------------------------------------------------------------------------------- /e2e_demo/scala/sparkdemoproject/.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | build.sbt_back 4 | 5 | # sbt specific 6 | dist/* 7 | target/ 8 | lib_managed/ 9 | src_managed/ 10 | project/boot/ 11 | project/plugins/project/ 12 | sbt/*.jar 13 | mini-complete-example/sbt/*.jar 14 | 15 | # Scala-IDE specific 16 | .scala_dependencies 17 | 18 | #Emacs 19 | *~ 20 | 21 | #ignore the metastore 22 | metastore_db/* 23 | 24 | # Byte-compiled 
/ optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .env 33 | .Python 34 | env/ 35 | bin/ 36 | build/*.jar 37 | develop-eggs/ 38 | dist/ 39 | eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .coverage 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | 61 | # Translations 62 | *.mo 63 | 64 | # Mr Developer 65 | .mr.developer.cfg 66 | .project 67 | .pydevproject 68 | 69 | # Rope 70 | .ropeproject 71 | 72 | # Django stuff: 73 | *.log 74 | *.pot 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyCharm files 80 | *.idea 81 | 82 | # emacs stuff 83 | 84 | # Autoenv 85 | .env 86 | *~ 87 | # Byte-compiled / optimized / DLL files 88 | __pycache__/ 89 | *.py[cod] 90 | 91 | # C extensions 92 | *.so 93 | 94 | # Distribution / packaging 95 | .env 96 | .Python 97 | env/ 98 | bin/ 99 | build/ 100 | develop-eggs/ 101 | dist/ 102 | eggs/ 103 | lib/ 104 | lib64/ 105 | parts/ 106 | sdist/ 107 | var/ 108 | *.egg-info/ 109 | .installed.cfg 110 | *.egg 111 | 112 | # Installer logs 113 | pip-log.txt 114 | pip-delete-this-directory.txt 115 | 116 | # Unit test / coverage reports 117 | htmlcov/ 118 | .tox/ 119 | .coverage 120 | .cache 121 | nosetests.xml 122 | coverage.xml 123 | 124 | # Translations 125 | *.mo 126 | 127 | # Mr Developer 128 | .mr.developer.cfg 129 | .project 130 | .pydevproject 131 | 132 | # Rope 133 | .ropeproject 134 | 135 | # Django stuff: 136 | *.log 137 | *.pot 138 | 139 | # Sphinx documentation 140 | docs/_build/ 141 | 142 | # PyCharm files 143 | *.idea 144 | 145 | # emacs stuff 146 | \#*\# 147 | \.\#* 148 | 149 | # Autoenv 150 | .env 151 | *~ 152 | 153 | 154 | # Ignore Gradle project-specific cache directory 155 | .gradle 156 | 157 | # Ignore Gradle build output directory 158 | build 159 | -------------------------------------------------------------------------------- /pysparkler/tests/test_api.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
17 | # 18 | import nbformat 19 | 20 | from pysparkler.api import PySparkler 21 | from tests.conftest import absolute_path 22 | 23 | 24 | def test_upgrade_pyspark_python_script(): 25 | modified_code = PySparkler(dry_run=True).upgrade_script( 26 | input_file=absolute_path("tests/sample/input_pyspark.py") 27 | ) 28 | 29 | with open( 30 | file=absolute_path("tests/sample/output_pyspark.py"), encoding="utf-8" 31 | ) as f: 32 | expected_code = f.read() 33 | 34 | assert modified_code == expected_code 35 | 36 | 37 | def test_upgrade_pyspark_jupyter_notebook(): 38 | modified_code = PySparkler(dry_run=True).upgrade_notebook( 39 | input_file=absolute_path("tests/sample/InputPySparkNotebook.ipynb"), 40 | output_kernel_name="spark33-python3-venv", 41 | ) 42 | 43 | with open( 44 | file=absolute_path("tests/sample/OutputPySparkNotebook.ipynb"), encoding="utf-8" 45 | ) as f: 46 | expected_code = f.read() 47 | 48 | assert nbformat.reads( 49 | modified_code, as_version=nbformat.NO_CONVERT 50 | ) == nbformat.reads(expected_code, as_version=nbformat.NO_CONVERT) 51 | 52 | 53 | def test_disable_transformers_are_filtered_out(): 54 | transformer_id = "PY24-30-001" 55 | given_overrides = { 56 | transformer_id: {"enabled": False}, 57 | } 58 | transformers = PySparkler(**given_overrides).transformers 59 | 60 | assert transformer_id not in [t.transformer_id for t in transformers] 61 | 62 | 63 | def test_transformer_override_comments_are_taking_effect(): 64 | transformer_id = "PY24-30-001" 65 | overriden_comment = "A new comment" 66 | given_overrides = { 67 | transformer_id: {"comment": overriden_comment}, 68 | } 69 | 70 | modified_code = PySparkler(dry_run=True, **given_overrides).upgrade_script( 71 | input_file=absolute_path("tests/sample/input_pyspark.py") 72 | ) 73 | 74 | assert overriden_comment in modified_code 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Aims to help people upgrade to the latest version of Spark. 2 | 3 | # Tools 4 | 5 | ## Upgrade Validation 6 | 7 | While we all wish we had a great test suite that covered all of the possible issues, that is not the case for all of our pipelines. 8 | 9 | ### Side-By-Side Comparison 10 | 11 | To support migrations for pipelines with incomplete test coverage, we have tooling to compare two runs of the same pipeline on different versions of Spark. 12 | Right now it requires that you specify the table to be compared; for Iceberg tables we plan to provide a custom Iceberg library which automates this component (not yet started). 13 | 14 | ### Performance-only Comparison (not yet started) 15 | 16 | ## Semi-Automatic Upgrades 17 | 18 | Upgrading your code to a new version of Spark is perhaps not how most folks wish to spend their work day (let alone their after work day). Some parts of the migrations can be automated, and when combined with the upgrade validation described above they can (hopefully) lead to reasonably confident automatic upgrades. 19 | 20 | ### SQL (WIP) 21 | 22 | Spark SQL has some important changes between Spark 2.4 and 3.0 as well as some smaller changes between later versions. The [Spark SQL migration guide](https://spark.apache.org/docs/3.3.0/sql-migration-guide.html) covers most of the expected required changes. 23 | 24 | The SQL migration tool is built using [SQLFluff](https://sqlfluff.com/), which has a [Spark SQL dialect](https://docs.sqlfluff.com/en/stable/dialects.html).
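As a rough sketch of how that fits together (this snippet is illustrative and not part of the repository's code), the upgrade can be driven from SQLFluff's Python API once the `sparksql_upgrade` plugin has been installed (`pip install -e .` from the `sql/` directory, as the CI workflow does):

```python
# Hedged sketch: ask SQLFluff to auto-fix a Spark 2.x style statement using the
# sparksql dialect. With the sparksql_upgrade plugin installed, its upgrade rules
# run alongside the stock lint rules; the query below is just an example.
import sqlfluff

legacy_sql = "SELECT CAST(dateint AS INT) AS val FROM my_table LIMIT 10"

# sqlfluff.fix() returns the rewritten SQL as a string.
upgraded_sql = sqlfluff.fix(legacy_sql, dialect="sparksql")
print(upgraded_sql)
```

The same idea is what the Scala tooling relies on when it shells out to `sqlfluff fix --dialect sparksql` for SQL strings embedded in Scala code.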
25 | 26 | #### Limitations / Unique Challenges 27 | 28 | Out of the box SQLFluff lacks access to type information that is available when migrating Scala code, and the AST parser is not a 1:1 match with the underlying parser used by Spark SQL. A potential mitigation (if we end up needing type information) is integrating with Spark SQL to run an EXPLAIN on the input query and extract type information. 29 | 30 | 31 | Some migration rules are too much work to fully automate so instead output warnings for users to manually verify. 32 | 33 | We do not have an equivelent to "Scala Steward" for SQL files and SQL can target multiple backends. In most situations, the scheduler job type can be used to determine the engine. 34 | 35 | ### PySpark (Python Spark) Upgrade (WIP) 36 | 37 | The PySpark Upgrade tool - PySparkler - is currently built using [LibCST](https://github.com/Instagram/LibCST). 38 | More on the tool's design and challenges can be found in the subdirectory README of the tool [here](./pysparkler/README.md). 39 | 40 | ### Scala Upgrade (WIP) 41 | 42 | The Scala upgrade tooling is built on top of ScalaFix and has access to (most) of the type information. Spark's Scala APIs are perhaps the fastest changing of three primary languages used with Spark. 43 | 44 | 45 | #### Limitations / Unique Challenges 46 | 47 | While scalafix can be integrated with tools like Scala Steward (yay!), recompiling and publishing new artifacts is required to verify the changes. It is likely that dependencies will need to be manually upgraded. 48 | -------------------------------------------------------------------------------- /pysparkler/tests/sample/InputPySparkNotebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "b5e40a68", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pyspark\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import pyspark.pandas as ps\n", 14 | "\n", 15 | "from pandas import DataFrame as df\n", 16 | "from pyspark.sql import SparkSession, Row\n", 17 | "from pyspark.sql.functions import pandas_udf, PandasUDFType\n", 18 | "from pyspark.ml.param.shared import *\n", 19 | "\n", 20 | "spark = SparkSession.builder.appName('example').getOrCreate()\n", 21 | "spark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")\n", 22 | "\n", 23 | "table_name = \"my_table\"\n", 24 | "result = spark.sql(f\"select cast(dateint as int) val from {table_name} limit 10\")\n", 25 | "\n", 26 | "data = [(\"James\", \"\", \"Smith\", \"36636\", \"M\", 60000),\n", 27 | " (\"Jen\", \"Mary\", \"Brown\", \"\", \"F\", 0)]\n", 28 | "\n", 29 | "columns = [\"first_name\", \"middle_name\", \"last_name\", \"dob\", \"gender\", \"salary\"]\n", 30 | "pysparkDF = spark.createDataFrame(data=data, schema=columns, verifySchema=True)\n", 31 | "\n", 32 | "pandasDF = pysparkDF.toPandas()\n", 33 | "print(pandasDF)\n", 34 | "\n", 35 | "pysparkDF.write.partitionBy('gender').saveAsTable(\"persons\")\n", 36 | "pysparkDF.write.insertInto(\"persons\", overwrite=True)\n", 37 | "\n", 38 | "data = [Row(name=\"James,,Smith\", lang=[\"Java\", \"Scala\", \"C++\"], state=\"CA\"),\n", 39 | " Row(name=\"Robert,,Williams\", lang=[\"CSharp\", \"VB\"], state=\"NV\")]\n", 40 | "\n", 41 | "rdd = spark.sparkContext.parallelize(data)\n", 42 | "print(rdd.collect())\n", 43 | "\n", 44 | "ps_df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D'])\n", 45 | "ps_df.drop(['B', 
'C'])\n", 46 | "\n", 47 | "a_column_values = list(ps_df['A'].unique())\n", 48 | "repr_a_column_values = [repr(value) for value in a_column_values]\n", 49 | "\n", 50 | "spark.conf.set(\"spark.sql.session.timeZone\", \"America/Los_Angeles\")\n", 51 | "tz_df = spark.createDataFrame([28801], \"long\").selectExpr(\"timestamp(value) as ts\")\n", 52 | "tz_df.show()\n", 53 | "\n", 54 | "rp_df = spark.createDataFrame([\n", 55 | " (10, 80.5, \"Alice\", None),\n", 56 | " (5, None, \"Bob\", None),\n", 57 | " (None, None, \"Tom\", None),\n", 58 | " (None, None, None, True)],\n", 59 | " schema=[\"age\", \"height\", \"name\", \"bool\"])\n", 60 | "\n", 61 | "rp_df.na.replace('Alice').show()\n", 62 | "rp_df.na.fill(False).show()\n", 63 | "rp_df.fillna(True).show()\n", 64 | "\n", 65 | "\n", 66 | "def truncate(truncate=True):\n", 67 | " try:\n", 68 | " int_truncate = int(truncate)\n", 69 | " except ValueError as ex:\n", 70 | " raise TypeError(\n", 71 | " \"Parameter 'truncate={}' should be either bool or int.\".format(truncate)\n", 72 | " )\n" 73 | ] 74 | } 75 | ], 76 | "metadata": { 77 | "hide_input": false, 78 | "kernelspec": { 79 | "display_name": "Spark 2.4.4 - Python 3 (venv)", 80 | "language": "python", 81 | "name": "spark24-python3-venv" 82 | } 83 | }, 84 | "nbformat": 4, 85 | "nbformat_minor": 5 86 | } 87 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/MigrateHiveContext.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class MigrateHiveContext extends SemanticRule("MigrateHiveContext") { 7 | 8 | override def fix(implicit doc: SemanticDocument): Patch = { 9 | val hiveSymbolMatcher = SymbolMatcher.normalized("org.apache.spark.sql.hive.HiveContext") 10 | val hiveGetOrCreateMatcher = SymbolMatcher.normalized("org.apache.spark.sql.hive.HiveContext.getOrCreate") 11 | val newCreateHive = "SparkSession.builder.enableHiveSupport().getOrCreate().sqlContext" 12 | val utils = new Utils() 13 | def matchOnTree(e: Tree): Patch = { 14 | e match { 15 | // Rewrite the construction of a HiveContext 16 | case ns @ Term.New(Init(initArgs)) => 17 | initArgs match { 18 | case (hiveSymbolMatcher(_), _, _) => 19 | List( 20 | Patch.replaceTree( 21 | ns, 22 | newCreateHive), 23 | // TODO Add SparkSession import if missing -- addGlobalImport is broken 24 | // Patch.addGlobalImport(importer"org.apache.spark.sql.SparkSession") 25 | utils.addImportIfNotPresent(importer"org.apache.spark.sql.SparkSession") 26 | ).asPatch 27 | case _ => Patch.empty 28 | } 29 | case ns @ Term.Apply(hiveGetOrCreateMatcher(_), _) => 30 | List( 31 | Patch.replaceTree( 32 | ns, 33 | newCreateHive), 34 | // TODO Add SparkSession import if missing -- addGlobalImport is broken 35 | // Patch.addGlobalImport(importer"org.apache.spark.sql.SparkSession") 36 | utils.addImportIfNotPresent(importer"org.apache.spark.sql.SparkSession") 37 | ).asPatch 38 | 39 | // HiveContext type name rewrite to SQLContext 40 | // There should be a way to combine these two rules right? 41 | // Ideally we could rewrite the import to SqlContext symbol. 
42 | case Import(List( 43 | Importer(Term.Select(Term.Select(Term.Select( 44 | Term.Select(Term.Name("org"), Term.Name("apache")), 45 | Term.Name("spark")), 46 | Term.Name("sql")), 47 | Term.Name("hive")), 48 | List(hiveImports)))) => 49 | // Remove HiveContext it's deprecated 50 | hiveImports.collect { 51 | case i @ Importee.Name(Name("HiveContext")) => 52 | List( 53 | Patch.removeImportee(i), 54 | utils.addImportIfNotPresent(importer"org.apache.spark.sql.SQLContext") 55 | // TODO add SQLContext import if missing -- addGlobalImport is broken 56 | ).asPatch 57 | case i @ Importee.Rename(Name("HiveContext"), _) => 58 | List( 59 | Patch.removeImportee(i), 60 | utils.addImportIfNotPresent(importer"org.apache.spark.sql.SQLContext") 61 | // TODO add SQLContext import if missing -- addGlobalImport is broken 62 | ).asPatch 63 | case _ => Patch.empty 64 | }.asPatch 65 | case hiveSymbolMatcher(h) => 66 | Patch.replaceTree(h, "SQLContext") 67 | case elem @ _ => 68 | elem.children match { 69 | case Nil => Patch.empty 70 | case _ => elem.children.map(matchOnTree).asPatch 71 | } 72 | } 73 | } 74 | matchOnTree(doc.tree) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /pipelinecompare/spark_utils.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from pyspark.sql import DataFrame, Row, SparkSession 3 | import sys 4 | 5 | if sys.version_info < (3, 9): 6 | sys.exit("Please use Python 3.9+") 7 | 8 | def extract_catalog(table_name: str) -> str: 9 | """Extract the catalog.""" 10 | if "." in table_name: 11 | return table_name.split(".")[0] 12 | else: 13 | return "spark_catalog" 14 | 15 | 16 | def get_ancestors(spark: SparkSession, table_name: str, snapshot: str) -> list[Row]: 17 | """Get the ancestors of a given table at a given snapshot.""" 18 | catalog_name = extract_catalog(table_name) 19 | return spark.sql( 20 | f"""CALL {catalog_name}.system.ancestors_of( 21 | snapshot_id => {snapshot}, table => '{table_name}')""").collect() 22 | 23 | 24 | def create_changelog_view(spark: SparkSession, table_name: str, start_snapshot: str, end_snapshot: str, view_name: str) -> DataFrame: 25 | """Create a changelog view for the provided table.""" 26 | catalog_name = extract_catalog(table_name) 27 | return spark.sql( 28 | f"""CALL {catalog_name}.system.create_changelog_view( 29 | table => '{table_name}', 30 | options => map('start-snapshot-id','{start_snapshot}','end-snapshot-id', '{end_snapshot}'), 31 | changelog_view => '{view_name}' 32 | )""") 33 | 34 | 35 | def drop_iceberg_internal_columns(df: DataFrame) -> DataFrame: 36 | """Drop the iceberg internal columns from a changelog view that would make comparisons tricky.""" 37 | new_df = df 38 | # We don't drop "_change_type" because if one version inserts and the other deletes that's a diff we want to catch. 39 | # However change_orgidinal and _commit_snapshot_id are expected to differ even with identical end table states. 
40 | internal = {"_change_ordinal", "_commit_snapshot_id"} 41 | for c in df.columns: 42 | name = c.split("#")[0] 43 | if name in internal: 44 | new_df = new_df.drop(c) 45 | return new_df 46 | 47 | 48 | def get_cdc_views(spark: SparkSession, ctrl_name: str, target_name: str) -> tuple[DataFrame, DataFrame]: 49 | """Get the changelog/CDC views of two tables with a common ancestor.""" 50 | (ctrl_name, c_snapshot) = ctrl_name.split("@") 51 | (target_name, t_snapshot) = target_name.split("@") 52 | if ctrl_name != target_name: 53 | error(f"{ctrl_name} and {target_name} are not the same table.") 54 | # Now we need to get the table history and make sure that the table history intersects. 55 | ancestors_c = get_ancestors(spark, ctrl_name, c_snapshot) 56 | ancestors_t = get_ancestors(spark, target_name, t_snapshot) 57 | control_ancestor_set = set(ancestors_c) 58 | shared_ancestor = None 59 | for t in reversed(ancestors_t): 60 | if t in control_ancestor_set: 61 | shared_ancestor = t 62 | break 63 | if shared_ancestor is None: 64 | error(f"No shared ancestor between tables c:{ancestors_c} t:{ancestors_t}") 65 | try: 66 | c_diff_view_name = create_changelog_view(spark, ctrl_name, t.snapshot_id, c_snapshot, "c") 67 | t_diff_view_name = create_changelog_view(spark, ctrl_name, t.snapshot_id, t_snapshot, "t") 68 | c_diff_view = drop_iceberg_internal_columns(spark.sql("SELECT * FROM c")) 69 | t_diff_view = drop_iceberg_internal_columns(spark.sql("SELECT * FROM t")) 70 | except Exception as e: 71 | error(f"Iceberg may not support change log view, doing legacy compare {e}") 72 | return (c_diff_view, t_diff_view) 73 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/Utils.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | import scala.util.matching.Regex 6 | import scala.collection.mutable.HashSet 7 | import scala.reflect.ClassTag 8 | 9 | case class Utils()(implicit val doc: SemanticDocument) { 10 | /** 11 | * Match an RDD. Just the symbol matcher alone misses some cases so we also 12 | * look at type signatures etc. 13 | */ 14 | class MagicMatcher(matchers: List[SymbolMatcher]) { 15 | def unapply(param: Term) = { 16 | matchers.flatMap(unapplyMatcher(_, param)).headOption 17 | } 18 | 19 | def unapplyMatcher(matcher: SymbolMatcher, param: Term) = { 20 | param match { 21 | case matcher(e) => Some(e) 22 | case _ => 23 | param.symbol.info match { 24 | case None => 25 | None 26 | case Some(symbolInfo) => 27 | symbolInfo.signature match { 28 | case ValueSignature(tpe) => 29 | tpe match { 30 | case TypeRef(_, symbol, _) => 31 | symbol match { 32 | case matcher(e) => Some(param) 33 | case _ => None 34 | } 35 | case _ => 36 | None 37 | } 38 | case _ => None 39 | } 40 | } 41 | } 42 | } 43 | 44 | } 45 | 46 | /** 47 | * Strings, ints, doubles, etc.
can all be literals or regular symbols 48 | */ 49 | class MagicMatcherLit[T <: meta.Lit: ClassTag](matchers: List[SymbolMatcher]) 50 | extends MagicMatcher(matchers) { 51 | override def unapply(param: Term) = { 52 | param match { 53 | case e: T => Some(e) 54 | case _ => super.unapply(param) 55 | } 56 | } 57 | } 58 | 59 | 60 | object intMatcher extends MagicMatcherLit[Lit.Int]( 61 | List(SymbolMatcher.normalized("scala.Int"))) 62 | object longMatcher extends MagicMatcherLit[Lit.Long]( 63 | List(SymbolMatcher.normalized("scala.Long"))) 64 | object doubleMatcher extends MagicMatcherLit[Lit.Double]( 65 | List(SymbolMatcher.normalized("scala.Double"))) 66 | 67 | object rddMatcher extends MagicMatcher( 68 | List(SymbolMatcher.normalized("org.apache.spark.rdd.RDD#"))) 69 | 70 | lazy val imports = HashSet(doc.tree.collect { 71 | case Importer(term, importees) => 72 | importees.map { 73 | importee => (term.toString(), importee.toString()) 74 | } 75 | }.flatten:_*) 76 | 77 | private val importSplitRegex = "(.*?)\\.([a-zA-Z0-9_]+)".r 78 | 79 | /** 80 | * Add an import if the import it self is not present & 81 | * there is no corresponding import for this. Note this may make 82 | * mistakes with rename imports & local imports. 83 | */ 84 | def addImportIfNotPresent(importElem: Importer): Patch = { 85 | val importName = importElem.toString() 86 | importName match { 87 | case importSplitRegex(importTermName, importee) => 88 | if (imports contains ((importTermName, importee))) { 89 | Patch.empty 90 | } else if (imports contains ((importTermName, "_"))) { 91 | Patch.empty 92 | } else { 93 | importElem match { 94 | case Importer(term, importees) => 95 | imports ++= importees.map { 96 | importee => (term.toString(), importee.toString()) 97 | } 98 | } 99 | Patch.addGlobalImport(importElem) 100 | } 101 | } 102 | } 103 | 104 | def importPresent(importName: String): Boolean = { 105 | false 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /scalafix/rules/src/main/scala/fix/GroupByKeyRenameColumnQQ.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import scalafix.v1._ 4 | import scala.meta._ 5 | 6 | class GroupByKeyRenameColumnQQ 7 | extends SemanticRule("GroupByKeyRenameColumnQQ") { 8 | override val description = 9 | """Renaming column "value" with "key" when have Dataset.groupByKey(...).count()""" 10 | 11 | override val isRewrite = true 12 | 13 | override def fix(implicit doc: SemanticDocument): Patch = { 14 | 15 | def matchOnTerm(t: Term): Patch = { 16 | val p = t match { 17 | case q""""value"""" => Patch.replaceTree(t, q""""key"""".toString()) 18 | case q"""'value""" => Patch.replaceTree(t, q"""'key""".toString()) 19 | case q"""col("value")""" => 20 | Patch.replaceTree(t, q"""col("key")""".toString()) 21 | case q"""col("value").as""" => 22 | Patch.replaceTree(t, q"""col("key").as""".toString()) 23 | case q"""col("value").alias""" => 24 | Patch.replaceTree(t, q"""col("key").alias""".toString()) 25 | case q"""upper(col("value"))""" => 26 | Patch.replaceTree(t, q"""upper(col("key"))""".toString()) 27 | case q"""upper(col('value))""" => 28 | Patch.replaceTree(t, q"""upper(col('key))""".toString()) 29 | case _ if ! 
t.children.isEmpty => 30 | t.children.map { 31 | case e: scala.meta.Term => matchOnTerm(e) 32 | case _ => Patch.empty 33 | }.asPatch 34 | case _ => Patch.empty 35 | } 36 | p 37 | } 38 | 39 | val dsGBKmatcher = SymbolMatcher.normalized("org.apache.spark.sql.Dataset.groupByKey") 40 | val dsSelect = SymbolMatcher.normalized("org.apache.spark.sql.Dataset.select") 41 | val dsMatcher = SymbolMatcher.normalized("org.apache.spark.sql.Dataset") 42 | val dfMatcher = SymbolMatcher.normalized("org.apache.spark.sql.DataFrame") 43 | val keyedDs = SymbolMatcher.normalized("org.apache.spark.sql.KeyValueGroupedDataset") 44 | val keyedDsCount = SymbolMatcher.normalized("org.apache.spark.sql.KeyValueGroupedDataset.count") 45 | 46 | def isDSGroupByKey(t: Term): Boolean = { 47 | val isDataset = t.collect { 48 | case q"""DataFame""" => true 49 | case q"""Dataset""" => true 50 | case q"""Dataset[_]""" => true 51 | case dsGBKmatcher(_) => true 52 | case dfMatcher(_) => true 53 | case dsMatcher(_) => true 54 | case dsSelect(_) => true 55 | case keyedDs(_) => true 56 | case keyedDsCount(_) => true 57 | } 58 | val isGroupByKey = t.collect { case q"""groupByKey""" => true } 59 | (isGroupByKey.isEmpty.equals(false) && isGroupByKey.head.equals( 60 | true 61 | )) && (isDataset.isEmpty.equals(false) && isDataset.head.equals( 62 | true 63 | )) 64 | } 65 | 66 | def matchOnTree(t: Tree): Patch = { 67 | t match { 68 | case _ @Term.Apply(tr, params) if (isDSGroupByKey(tr)) => { 69 | val patch = List( 70 | params.map(matchOnTerm).asPatch, 71 | params.map(matchOnTree).asPatch, 72 | tr.children.map(matchOnTree).asPatch 73 | ).asPatch 74 | patch 75 | } 76 | case elem @ _ => { 77 | elem.children match { 78 | case Nil => Patch.empty 79 | case _ => { 80 | elem.children.map(matchOnTree).asPatch 81 | } 82 | } 83 | } 84 | } 85 | } 86 | 87 | // Bit of a hack, but limit our blast radius 88 | if (doc.input.text.contains("groupByKey") && doc.input.text.contains("value") && 89 | doc.input.text.contains("org.apache.spark.sql")) { 90 | matchOnTree(doc.tree) 91 | } else { 92 | Patch.empty 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /e2e_demo/scala/dl_dependencies.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | ######################################################################## 4 | # DL Spark 2 and 3 5 | ######################################################################## 6 | echo "Downloading Spark 2 and 3" 7 | if [ ! -f ${CORE_SPARK2}.tgz ]; then 8 | wget https://archive.apache.org/dist/spark/spark-2.4.8/${CORE_SPARK2}.tgz & 9 | fi 10 | if [ ! -f hadoop-2.7.0.tar.gz ]; then 11 | wget https://archive.apache.org/dist/hadoop/common/hadoop-2.7.0/hadoop-2.7.0.tar.gz & 12 | fi 13 | if [ ! -f ${SPARK2_DETAILS}.tgz ]; then 14 | wget https://archive.apache.org/dist/spark/spark-2.4.8/${SPARK2_DETAILS}.tgz & 15 | fi 16 | if [ ! -f ${SPARK3_DETAILS}.tgz ]; then 17 | wget https://archive.apache.org/dist/spark/spark-3.3.1/${SPARK3_DETAILS}.tgz & 18 | fi 19 | wait 20 | 21 | 22 | ######################################################################## 23 | # Extracting artifacts 24 | ######################################################################## 25 | echo "Unzipping downloaded files" 26 | if [ ! -d ${SPARK3_DETAILS} ]; then 27 | tar -xf ${SPARK3_DETAILS}.tgz 28 | fi 29 | if [ ! -d ${SPARK2_DETAILS} ]; then 30 | tar -xf ${SPARK2_DETAILS}.tgz 31 | fi 32 | if [ ! 
-d ${CORE_SPARK2} ]; then 33 | tar -xf ${CORE_SPARK2}.tgz 34 | fi 35 | if [ ! -d hadoop-2.7.0 ]; then 36 | tar -xf hadoop-2.7.0.tar.gz 37 | fi 38 | 39 | ######################################################################## 40 | # DLing iceberg dependencies 41 | ######################################################################## 42 | echo "Fetching iceberg dependencies" 43 | # We use Iceberg 1.3.0 for Spark 3.3 since we want to able to use changelog view 44 | if [ ! -f iceberg-spark-runtime-3.3_2.12-1.3.0.jar ]; then 45 | wget https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.3_2.12/1.3.0/iceberg-spark-runtime-3.3_2.12-1.3.0.jar -O iceberg-spark-runtime-3.3_2.12-1.3.0.jar & 46 | fi 47 | # For 2.4 we use Iceberg 1.1 since newer versions are not published for Spark 2.4 48 | if [ ! -f iceberg-spark-runtime-2.4-1.1.0.jar ]; then 49 | wget https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-2.4/1.1.0/iceberg-spark-runtime-2.4-1.1.0.jar -O iceberg-spark-runtime-2.4-1.1.0.jar & 50 | fi 51 | wait 52 | cp iceberg-spark-runtime-3.3_2.12-1.3.0.jar ${SPARK3_DETAILS}/jars/ 53 | cp iceberg-spark-runtime-2.4-1.1.0.jar ${SPARK2_DETAILS}/jars/ 54 | 55 | ######################################################################## 56 | # Bring over the hadoop parts we need, this is a bit of a hack but using hadoop-2.7.0 contents 57 | # does not work well either. 58 | ######################################################################## 59 | cp -f ${CORE_SPARK2}/jars/apache*.jar ${SPARK2_DETAILS}/jars/ 60 | cp -f ${CORE_SPARK2}/jars/guice*.jar ${SPARK2_DETAILS}/jars/ 61 | cp -f ${CORE_SPARK2}/jars/http*.jar ${SPARK2_DETAILS}/jars/ 62 | cp -f ${CORE_SPARK2}/jars/proto*.jar ${SPARK2_DETAILS}/jars/ 63 | cp -f ${CORE_SPARK2}/jars/parquet-hadoop*.jar ${SPARK2_DETAILS}/jars/ 64 | cp -f ${CORE_SPARK2}/jars/snappy*.jar ${SPARK2_DETAILS}/jars/ 65 | cp -f ${CORE_SPARK2}/jars/hadoop*.jar ${SPARK2_DETAILS}/jars/ 66 | cp -f ${CORE_SPARK2}/jars/guava*.jar ${SPARK2_DETAILS}/jars/ 67 | cp -f ${CORE_SPARK2}/jars/commons*.jar ${SPARK2_DETAILS}/jars/ 68 | cp -f ${CORE_SPARK2}/jars/libthrift*.jar ${SPARK2_DETAILS}/jars/ 69 | cp -f ${CORE_SPARK2}/jars/slf4j*.jar ${SPARK2_DETAILS}/jars/ 70 | cp -f ${CORE_SPARK2}/jars/log4j* ${SPARK2_DETAILS}/jars/ 71 | cp -f ${CORE_SPARK2}/jars/hive-*.jar ${SPARK2_DETAILS}/jars/ 72 | 73 | ######################################################################## 74 | # Bring over non-scala 2.11 jackson jars. 75 | ######################################################################## 76 | cp -f ${CORE_SPARK2}/jars/*jackson*.jar ${SPARK2_DETAILS}/jars/ 77 | rm ${SPARK2_DETAILS}/jars/*jackson*_2.11*.jar 78 | -------------------------------------------------------------------------------- /.github/workflows/github-actions-basic.yml: -------------------------------------------------------------------------------- 1 | name: Build and test 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | - '**' 8 | - '!branch-*.*' 9 | pull_request: 10 | types: [opened, reopened, edited] 11 | 12 | jobs: 13 | # Build: build and run the tests for specified modules. 
14 | build: 15 | runs-on: ubuntu-22.04 16 | strategy: 17 | fail-fast: false 18 | env: 19 | SPARK_VERSION: ${{ matrix.spark }} 20 | steps: 21 | - uses: coursier/cache-action@v6 22 | - name: sbt 23 | run: | 24 | sudo apt-get update 25 | sudo apt-get install -y apt-transport-https curl gnupg -yqq 26 | echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list 27 | echo "deb https://repo.scala-sbt.org/scalasbt/debian /" | sudo tee /etc/apt/sources.list.d/sbt_old.list 28 | curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo -H gpg --no-default-keyring --keyring gnupg-ring:/etc/apt/trusted.gpg.d/scalasbt-release.gpg --import 29 | sudo chmod 644 /etc/apt/trusted.gpg.d/scalasbt-release.gpg 30 | sudo apt-get update 31 | sudo apt-get install -y sbt 32 | - name: Checkout 33 | uses: actions/checkout@v3 34 | with: 35 | fetch-depth: 0 36 | # Install python deps 37 | - name: Install python deps 38 | run: pip install -r sql/requirements.txt 39 | - name: Make pip python programs findable 40 | run: export PATH=`python3 -m site --user-base`/bin:$PATH 41 | - name: SQL Module - Lint Checks 42 | run: cd sql; flake8 --max-line-length 100 --ignore=E129,W504 43 | - name: SQL Module - Install and Test 44 | run: cd sql; pip install -e .; pytest . 45 | # Run the scala tests. 46 | # We are exctracting the dynamic version from something like [info] 0.1.9+18-ddfaf3e6-SNAPSHOT (non-snapshots will not have +...) 47 | - name: Run scalafix sbt Spark 2.3.2 tests 48 | run: cd scalafix; sbt ";clean;compile;test" -DsparkVersion=2.3.2 ;cd .. 49 | - name: Run scalafix sbt Spark 2.1.1 tests 50 | run: cd scalafix; sbt ";clean;compile;test" -DsparkVersion=2.3.2 ;cd .. 51 | - name: Run sbt tests on scalafix & extract the dynver 52 | run: cd scalafix; sbt ";clean;compile;test;publishLocal;+publishLocal"; sbt "show rules/dynver" |grep "\[info\]" |grep "[0-9].[0-9].[0-9]" | cut -f 2 -d " " > ~/rules_version 53 | - name: Run sbt tests on our WAP plugin 54 | run: cd iceberg-spark-upgrade-wap-plugin; sbt ";clean;test" 55 | - name: PySparkler - Make Install 56 | run: | 57 | cd pysparkler 58 | make install 59 | - name: PySparkler - Make Lint 60 | run: | 61 | cd pysparkler 62 | make lint 63 | - name: PySparkler - Make Test 64 | run: | 65 | cd pysparkler 66 | make test 67 | - name: Cache tgzs 68 | id: cache-tgz 69 | uses: actions/cache@v3 70 | with: 71 | path: e2e_demo/scala/*.tgz 72 | key: ${{ runner.os }}-tgz 73 | - name: Cache tar.gz 74 | id: cache-targz 75 | uses: actions/cache@v3 76 | with: 77 | path: e2e_demo/scala/*.tar.gz 78 | key: ${{ runner.os }}-targz 79 | - name: Cache extractions spark 3 80 | id: cache-extract-spark3 81 | uses: actions/cache@v3 82 | with: 83 | path: e2e_demo/scala/spark-3.3.1-bin-hadoop2 84 | key: ${{ runner.os }}-extract-spark3 85 | - name: Cache extractions hadoop 86 | id: cache-extract-hadoop 87 | uses: actions/cache@v3 88 | with: 89 | path: e2e_demo/scala/hadoop-2* 90 | key: ${{ runner.os }}-extract-hadoop2 91 | # Run the sbt e2e demo with the local version 92 | - name: sbt e2e demo 93 | run: cd e2e_demo/scala; SCALAFIX_RULES_VERSION=$(cat ~/rules_version) NO_PROMPT="yes" ./run_demo.sh 94 | # Run the gradle e2e demo with the local version 95 | - name: gradle e2e demo 96 | run: cd e2e_demo/scala; SCALAFIX_RULES_VERSION=$(cat ~/rules_version) NO_PROMPT="yes" ./run_demo-gradle.sh 97 | -------------------------------------------------------------------------------- 
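For readers puzzling over the `grep`/`cut` pipeline in the workflow above: it only exists to pull the published rules version out of the `sbt "show rules/dynver"` output (lines like `[info] 0.1.9+18-ddfaf3e6-SNAPSHOT`). The standalone Python sketch below shows the same extraction; it is illustrative only, the workflow itself does this in shell, and the function name and regex are not part of the repository.

```python
import re
from typing import Optional


def extract_rules_version(sbt_output: str) -> Optional[str]:
    """Pull a dynver-style version (e.g. '0.1.9' or '0.1.9+18-ddfaf3e6-SNAPSHOT')
    out of the output of `sbt "show rules/dynver"`."""
    for line in sbt_output.splitlines():
        if not line.startswith("[info]"):
            continue
        # Released builds look like 0.1.9; snapshot builds carry a +commits-sha[-SNAPSHOT] suffix.
        match = re.search(r"\d+\.\d+\.\d+(?:\+\d+-[0-9a-f]+(?:-SNAPSHOT)?)?", line)
        if match:
            return match.group(0)
    return None


print(extract_rules_version("[info] 0.1.9+18-ddfaf3e6-SNAPSHOT"))  # -> 0.1.9+18-ddfaf3e6-SNAPSHOT
```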
/pysparkler/pysparkler/pyspark_32_to_33.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | import libcst as cst 19 | import libcst.matchers as m 20 | 21 | from pysparkler.base import ( 22 | RequiredDependencyVersionCommentWriter, 23 | StatementLineCommentWriter, 24 | ) 25 | 26 | 27 | class DataframeDropAxisIndexByDefault(StatementLineCommentWriter): 28 | """In Spark 3.3, the drop method of pandas API on Spark DataFrame supports dropping rows by index, and sets dropping 29 | by index instead of column by default. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | pyspark_version: str = "3.3", 35 | ): 36 | super().__init__( 37 | transformer_id="PY32-33-001", 38 | comment=f"As of PySpark {pyspark_version} the drop method of pandas API on Spark DataFrame sets drop by \ 39 | index as default, instead of drop by column. Please explicitly set axis argument to 1 to drop by column.", 40 | ) 41 | 42 | def visit_Call(self, node: cst.Call) -> None: 43 | """Check if drop method does not specify the axis argument and drops by labels""" 44 | if m.matches( 45 | node, 46 | m.Call( 47 | func=m.Attribute( 48 | attr=m.Name("drop"), 49 | ), 50 | args=[ 51 | m.OneOf( 52 | m.Arg(keyword=m.Name("labels")), 53 | m.Arg(keyword=None), 54 | ) 55 | ], 56 | ), 57 | ): 58 | self.match_found = True 59 | 60 | 61 | class RequiredPandasVersionCommentWriter(RequiredDependencyVersionCommentWriter): 62 | """In Spark 3.3, PySpark upgrades Pandas version, the new minimum required version changes from 0.23.2 to 1.0.5.""" 63 | 64 | def __init__( 65 | self, 66 | pyspark_version: str = "3.3", 67 | required_dependency_name: str = "pandas", 68 | required_dependency_version: str = "1.0.5", 69 | ): 70 | super().__init__( 71 | transformer_id="PY32-33-002", 72 | pyspark_version=pyspark_version, 73 | required_dependency_name=required_dependency_name, 74 | required_dependency_version=required_dependency_version, 75 | ) 76 | 77 | 78 | class SQLDataTypesReprReturnsObjectCommentWriter(StatementLineCommentWriter): 79 | """In Spark 3.3, the repr return values of SQL DataTypes have been changed to yield an object with the same value 80 | when passed to eval. 
81 | """ 82 | 83 | def __init__( 84 | self, 85 | pyspark_version: str = "3.3", 86 | ): 87 | super().__init__( 88 | transformer_id="PY32-33-003", 89 | comment=f"As of PySpark {pyspark_version}, the repr return values of SQL DataTypes have been changed to \ 90 | yield an object with the same value when passed to eval.", 91 | ) 92 | 93 | def visit_Call(self, node: cst.Call) -> None: 94 | """Check if the repr method of SQL DataTypes is called""" 95 | if m.matches( 96 | node, 97 | m.Call(func=m.Name("repr")), 98 | ): 99 | self.match_found = True 100 | 101 | 102 | def pyspark_32_to_33_transformers() -> list[cst.CSTTransformer]: 103 | """Return a list of transformers for PySpark 3.2 to 3.3 migration guide""" 104 | return [ 105 | DataframeDropAxisIndexByDefault(), 106 | RequiredPandasVersionCommentWriter(), 107 | SQLDataTypesReprReturnsObjectCommentWriter(), 108 | ] 109 | -------------------------------------------------------------------------------- /pysparkler/pysparkler/pyspark_31_to_32.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | import libcst as cst 19 | import libcst.matchers as m 20 | 21 | from pysparkler.base import StatementLineCommentWriter 22 | 23 | 24 | class SqlMlMethodsRaiseTypeErrorCommentWriter(StatementLineCommentWriter): 25 | """In Spark 3.2, the PySpark methods from sql, ml, spark_on_pandas modules raise the TypeError instead of ValueError 26 | when are applied to a param of inappropriate type. 
27 | """ 28 | 29 | def __init__( 30 | self, 31 | pyspark_version: str = "3.2", 32 | ): 33 | super().__init__( 34 | transformer_id="PY31-32-001", 35 | comment=f"As of PySpark {pyspark_version}, the methods from sql, ml, spark_on_pandas modules raise the \ 36 | TypeError instead of ValueError when are applied to a param of inappropriate type.", 37 | ) 38 | self._has_sql_or_ml_import = False 39 | 40 | def visit_ImportFrom(self, node: cst.ImportFrom) -> None: 41 | """Check if pyspark.sql.* or pyspark.ml.* is being used in a from import statement""" 42 | if m.matches( 43 | node, 44 | m.ImportFrom( 45 | module=m.Attribute( 46 | value=m.OneOf( 47 | m.Attribute( 48 | value=m.Name("pyspark"), 49 | attr=m.Name("sql"), 50 | ), 51 | m.Attribute( 52 | value=m.Name("pyspark"), 53 | attr=m.Name("ml"), 54 | ), 55 | ), 56 | ), 57 | ), 58 | ): 59 | self._has_sql_or_ml_import = True 60 | 61 | def visit_Import(self, node: cst.Import) -> None: 62 | """Check if pyspark.sql.* or pyspark.ml.* is being used in an import statement""" 63 | if m.matches( 64 | node, 65 | m.Import( 66 | names=[ 67 | m.OneOf( 68 | m.ImportAlias( 69 | name=m.Attribute( 70 | value=m.Attribute( 71 | value=m.Name("pyspark"), 72 | attr=m.Name("sql"), 73 | ) 74 | ), 75 | ), 76 | m.ImportAlias( 77 | name=m.Attribute( 78 | value=m.Attribute( 79 | value=m.Name("pyspark"), 80 | attr=m.Name("ml"), 81 | ) 82 | ), 83 | ), 84 | ), 85 | m.ZeroOrMore(), 86 | ] 87 | ), 88 | ): 89 | self._has_sql_or_ml_import = True 90 | 91 | def visit_ExceptHandler(self, node: cst.ExceptHandler) -> None: 92 | """Check if the except handler is catching the ValueError""" 93 | if m.matches( 94 | node, 95 | m.ExceptHandler( 96 | type=m.Name("ValueError"), 97 | ), 98 | ): 99 | if self._has_sql_or_ml_import: 100 | self.match_found = True 101 | 102 | 103 | def pyspark_31_to_32_transformers() -> list[cst.CSTTransformer]: 104 | """Return a list of transformers for PySpark 3.1 to 3.2 migration guide""" 105 | return [SqlMlMethodsRaiseTypeErrorCommentWriter()] 106 | -------------------------------------------------------------------------------- /pysparkler/tests/test_sql_21_to_33.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
17 | # 18 | 19 | from pysparkler.sql_21_to_33 import SqlStatementUpgradeAndCommentWriter 20 | from tests.conftest import rewrite 21 | 22 | 23 | def test_upgrades_non_templated_sql(): 24 | given_code = """\ 25 | from pyspark.sql import SparkSession 26 | 27 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 28 | result = spark.sql("select cast(dateint as int) val from my_table limit 10") 29 | spark.stop() 30 | """ 31 | modified_code = rewrite(given_code, SqlStatementUpgradeAndCommentWriter()) 32 | expected_code = """\ 33 | from pyspark.sql import SparkSession 34 | 35 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 36 | result = spark.sql("select int(dateint) val from my_table limit 10") # PY21-33-001: Spark SQL statement has been upgraded to Spark 3.3 compatible syntax. # noqa: E501 37 | spark.stop() 38 | """ 39 | assert modified_code == expected_code 40 | 41 | 42 | def test_upgrades_templated_sql(): 43 | given_code = """\ 44 | from pyspark.sql import SparkSession 45 | 46 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 47 | table_name = "my_table" 48 | result = spark.sql(f"select cast(dateint as int) val from {table_name} limit 10") 49 | spark.stop() 50 | """ 51 | modified_code = rewrite(given_code, SqlStatementUpgradeAndCommentWriter()) 52 | expected_code = """\ 53 | from pyspark.sql import SparkSession 54 | 55 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 56 | table_name = "my_table" 57 | result = spark.sql(f"select int(dateint) val from {table_name} limit 10") # PY21-33-001: Spark SQL statement has been upgraded to Spark 3.3 compatible syntax. # noqa: E501 58 | spark.stop() 59 | """ 60 | assert modified_code == expected_code 61 | 62 | 63 | def test_unable_to_upgrade_templated_sql_with_complex_expressions(): 64 | given_code = """\ 65 | from pyspark.sql import SparkSession 66 | 67 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 68 | table_name = "my_table" 69 | num = 10 70 | result = spark.sql(f"select cast(dateint as int) val from {table_name} where x < {num * 100} limit 10") 71 | spark.stop() 72 | """ 73 | modified_code = rewrite(given_code, SqlStatementUpgradeAndCommentWriter()) 74 | expected_code = """\ 75 | from pyspark.sql import SparkSession 76 | 77 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 78 | table_name = "my_table" 79 | num = 10 80 | result = spark.sql(f"select cast(dateint as int) val from {table_name} where x < {num * 100} limit 10") # PY21-33-001: Unable to inspect the Spark SQL statement since the formatted string SQL has complex expressions within. Please de-template the SQL and use the 'pysparkler upgrade-sql' CLI command to upcast the SQL yourself. # noqa: E501 81 | spark.stop() 82 | """ 83 | assert modified_code == expected_code 84 | 85 | 86 | def test_no_upgrades_required_after_inspecting_sql(): 87 | given_code = """\ 88 | from pyspark.sql import SparkSession 89 | 90 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 91 | result = spark.sql("select * from my_table limit 10") 92 | spark.stop() 93 | """ 94 | modified_code = rewrite(given_code, SqlStatementUpgradeAndCommentWriter()) 95 | expected_code = """\ 96 | from pyspark.sql import SparkSession 97 | 98 | spark = SparkSession.builder.appName("SQL Example").getOrCreate() 99 | result = spark.sql("select * from my_table limit 10") # PY21-33-001: Spark SQL statement has Spark 3.3 compatible syntax. 
# noqa: E501 100 | spark.stop() 101 | """ 102 | assert modified_code == expected_code 103 | -------------------------------------------------------------------------------- /scalafix/build.sbt: -------------------------------------------------------------------------------- 1 | val sparkVersion = settingKey[String]("Spark version") 2 | val srcSparkVersion = settingKey[String]("Source Spark version") 3 | val targetSparkVersion = settingKey[String]("Target Spark version") 4 | val sparkUpgradeVersion = settingKey[String]("Spark upgrade version") 5 | 6 | 7 | lazy val V = _root_.scalafix.sbt.BuildInfo 8 | inThisBuild( 9 | List( 10 | organization := "com.holdenkarau", 11 | homepage := Some(url("https://github.com/holdenk/spark-auto-upgrade")), 12 | licenses := List("Apache-2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0")), 13 | srcSparkVersion := System.getProperty("sparkVersion", "2.4.8"), 14 | targetSparkVersion := System.getProperty("targetSparkVersion", "3.3.0"), 15 | sparkVersion := srcSparkVersion.value, 16 | // actual version is pulled from tags. 17 | sparkUpgradeVersion := "0.0.1-SNAPSHOT", // dev version for testing. 18 | versionScheme := Some("early-semver"), 19 | publishMavenStyle := true, 20 | publishTo := { 21 | val nexus = "https://oss.sonatype.org/" 22 | Some("releases" at nexus + "service/local/staging/deploy/maven2") 23 | }, 24 | useGpg := true, 25 | developers := List( 26 | Developer( 27 | "holdenk", 28 | "Holden Karau", 29 | "holden@pigscanfly.ca", 30 | url("https://github.com/holdenk/spark-auto-upgrade") 31 | ) 32 | ), 33 | scalaVersion := { 34 | if (sparkVersion.value > "2.4") { 35 | V.scala212 36 | } else { 37 | V.scala211 38 | } 39 | }, 40 | crossScalaVersions := { 41 | if (sparkVersion.value > "3.1.0") { 42 | List(V.scala212, V.scala213) 43 | } else if (sparkVersion.value > "2.4") { 44 | List(V.scala211, V.scala212) 45 | } else { 46 | List(V.scala211) 47 | } 48 | }, 49 | addCompilerPlugin(scalafixSemanticdb), 50 | scalacOptions ++= List( 51 | "-Yrangepos", 52 | "-P:semanticdb:synthetics:on" 53 | ), 54 | scmInfo := Some(ScmInfo( 55 | url("https://github.com/holdenk/spark-testing-base.git"), 56 | "scm:git@github.com:holdenk/spark-testing-base.git" 57 | )), 58 | skip in publish := false 59 | ) 60 | ) 61 | 62 | skip in publish := true 63 | 64 | 65 | lazy val rules = project.settings( 66 | moduleName := s"spark-scalafix-rules-${sparkVersion.value}", 67 | libraryDependencies += "ch.epfl.scala" %% "scalafix-core" % V.scalafixVersion, 68 | ) 69 | 70 | lazy val input = project.settings( 71 | skip in publish := true, 72 | sparkVersion := srcSparkVersion.value, 73 | libraryDependencies ++= Seq( 74 | "org.scalacheck" %% "scalacheck" % "1.14.0", 75 | "org.apache.spark" %% "spark-core" % sparkVersion.value, 76 | "org.apache.spark" %% "spark-sql" % sparkVersion.value, 77 | "org.apache.spark" %% "spark-hive" % sparkVersion.value, 78 | "org.scalatest" %% "scalatest" % "3.0.0") 79 | ) 80 | 81 | lazy val output = project.settings( 82 | skip in publish := true, 83 | sparkVersion := targetSparkVersion.value, 84 | scalaVersion := V.scala212, 85 | crossScalaVersions := { 86 | if (sparkVersion.value > "3.1.0") { 87 | List(V.scala212, V.scala213) 88 | } else { 89 | List(V.scala211, V.scala212) 90 | } 91 | }, 92 | libraryDependencies ++= Seq( 93 | "org.scalacheck" %% "scalacheck" % "1.14.0", 94 | "org.apache.spark" %% "spark-core" % sparkVersion.value, 95 | "org.apache.spark" %% "spark-sql" % sparkVersion.value, 96 | "org.apache.spark" %% "spark-hive" % sparkVersion.value, 97 | 
"org.scalatest" %% "scalatest" % "3.2.14") 98 | ) 99 | 100 | lazy val tests = project 101 | .settings( 102 | skip in publish := true, 103 | libraryDependencies += "ch.epfl.scala" % "scalafix-testkit" % V.scalafixVersion % Test cross CrossVersion.full, 104 | compile.in(Compile) := 105 | compile.in(Compile).dependsOn(compile.in(input, Compile), compile.in(output, Compile)).value, 106 | scalafixTestkitOutputSourceDirectories := 107 | sourceDirectories.in(output, Compile).value, 108 | scalafixTestkitInputSourceDirectories := 109 | sourceDirectories.in(input, Compile).value, 110 | scalafixTestkitInputClasspath := 111 | fullClasspath.in(input, Compile).value, 112 | ) 113 | .dependsOn(rules) 114 | .enablePlugins(ScalafixTestkitPlugin) 115 | 116 | ThisBuild / libraryDependencySchemes ++= Seq( 117 | "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always 118 | ) 119 | -------------------------------------------------------------------------------- /scalafix/output/src/main/scala/fix/GroupByKeyRenameColumnQQ.scala: -------------------------------------------------------------------------------- 1 | package fix 2 | 3 | import org.apache.spark._ 4 | import org.apache.spark.sql._ 5 | import org.apache.spark.sql.functions._ 6 | 7 | object GroupByKeyRenameColumnQQ { 8 | case class City(countryName: String, cityName: String) 9 | def inSource(spark: SparkSession): Unit = { 10 | import spark.implicits._ 11 | 12 | val ds = List("Person 1", "Person 2", "User 1", "User2", "Test").toDS() 13 | // Don't change the RDD one. 14 | val sc = SparkContext.getOrCreate() 15 | val rdd = sc.parallelize(List((1,2))).groupByKey().map(x => "value") 16 | 17 | // Do change the inidrect ds ones 18 | val ds11 = 19 | ds.groupByKey(c => c.substring(0, 3)).count().select(col("key")) 20 | val df: DataFrame = null 21 | var words1: Dataset[Row] = null 22 | def keyMe(a: Row): String = { 23 | "1" 24 | } 25 | val stopArray = array(lit("hi")) 26 | val splitPattern = "" 27 | words1.groupByKey(keyMe).count().select(col("key").as("word"), col("count(1)")).orderBy("count(1)") 28 | val words = df.select(explode(split(lower(col("value")), splitPattern)).as("words")).filter( 29 | not(array_contains(stopArray, col("words")))) 30 | words.groupByKey(keyMe).count().select(col("key").as("word"), col("count(1)")).orderBy("count(1)") 31 | 32 | val ds10 = List("Person 1", "Person 2", "User 1", "User 3", "test") 33 | .toDS() 34 | .groupByKey(i => i.substring(0, 3)) 35 | .count() 36 | .select(col("key")) 37 | 38 | val ds1 = List("Person 1", "Person 2", "User 1", "User 3", "test") 39 | .toDS() 40 | .groupByKey(i => i.substring(0, 3)) 41 | .count() 42 | .select('key) 43 | 44 | val ds2 = List("Person 1", "Person 2", "User 1", "User 3", "test") 45 | .toDS() 46 | .groupByKey(i => i.substring(0, 3)) 47 | .count() 48 | .withColumn("key", upper(col("key"))) 49 | 50 | val ds3 = 51 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 52 | .toDS() 53 | .groupByKey(l => l.substring(0, 3)) 54 | .count() 55 | .withColumnRenamed("key", "newName") 56 | 57 | val ds5 = 58 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 59 | .toDS() 60 | .groupByKey(l => l.substring(0, 3)) 61 | .count() 62 | .withColumn("newNameCol", upper(col("key"))) 63 | 64 | val ds00 = List("Person 1", "Person 2", "User 1", "User 3", "test") 65 | .toDS() 66 | .select(col("value")) 67 | 68 | val c = List("Person 1", "Person 2", "User 1", "User 3", "test") 69 | .toDS() 70 | .withColumnRenamed("value", "newColName") 71 | .count() 72 | 73 | val s = "value" 74 | val 
value = 1 75 | } 76 | def inSource2(spark: SparkSession): Unit = { 77 | import spark.implicits._ 78 | 79 | val source = Seq( 80 | City("USA", "Seatle"), 81 | City("Canada", "Toronto"), 82 | City("Ukraine", "Kyev"), 83 | City("Ukraine", "Ternopil"), 84 | City("Canada", "Vancouver"), 85 | City("Germany", "Köln") 86 | ) 87 | 88 | val df = source.toDF().groupBy(col("countryName")).count 89 | val ds = source.toDS().groupBy(col("countryName")).count 90 | val res = source.toDF().as[City].groupBy(col("countryName")).count 91 | val res1 = res 92 | .select(col("countryName").alias("value")) 93 | .as[String] 94 | .groupByKey(l => l.substring(0, 3)) 95 | .count() 96 | val res2 = res 97 | .select(col("countryName").alias("newValue")) 98 | .as[String] 99 | .groupByKey(l => l.substring(0, 3)) 100 | .count() 101 | .select('key) 102 | } 103 | def inSource3(spark: SparkSession): Unit = { 104 | import spark.implicits._ 105 | 106 | val source = Seq( 107 | City("USA", "Seatle"), 108 | City("Canada", "Toronto"), 109 | City("Ukraine", "Kyev"), 110 | City("Ukraine", "Ternopil"), 111 | City("Canada", "Vancouver"), 112 | City("Germany", "Köln") 113 | ) 114 | val res = source.toDF().as[City].groupBy(col("countryName")).count 115 | val res1 = res 116 | .select(col("countryName").alias("value")) 117 | .as[String] 118 | .groupByKey(l => l.substring(0, 3)) 119 | .count() 120 | 121 | val ds = List("Person 1", "Person 2", "User 1", "User 2").toDS() 122 | 123 | val res2 = res1.union(ds.groupByKey(l => l.substring(0, 3)).count) 124 | 125 | val r = res2.select('value, col("count(1)")) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /scalafix/input/src/main/scala/fix/GroupByKeyRenameColumnQQ.scala: -------------------------------------------------------------------------------- 1 | /* 2 | rule=GroupByKeyRenameColumnQQ 3 | */ 4 | package fix 5 | 6 | import org.apache.spark._ 7 | import org.apache.spark.sql._ 8 | import org.apache.spark.sql.functions._ 9 | 10 | object GroupByKeyRenameColumnQQ { 11 | case class City(countryName: String, cityName: String) 12 | def inSource(spark: SparkSession): Unit = { 13 | import spark.implicits._ 14 | 15 | val ds = List("Person 1", "Person 2", "User 1", "User2", "Test").toDS() 16 | // Don't change the RDD one. 
17 | val sc = SparkContext.getOrCreate() 18 | val rdd = sc.parallelize(List((1,2))).groupByKey().map(x => "value") 19 | 20 | // Do change the inidrect ds ones 21 | val ds11 = 22 | ds.groupByKey(c => c.substring(0, 3)).count().select(col("value")) 23 | val df: DataFrame = null 24 | var words1: Dataset[Row] = null 25 | def keyMe(a: Row): String = { 26 | "1" 27 | } 28 | val stopArray = array(lit("hi")) 29 | val splitPattern = "" 30 | words1.groupByKey(keyMe).count().select(col("value").as("word"), col("count(1)")).orderBy("count(1)") 31 | val words = df.select(explode(split(lower(col("value")), splitPattern)).as("words")).filter( 32 | not(array_contains(stopArray, col("words")))) 33 | words.groupByKey(keyMe).count().select(col("value").as("word"), col("count(1)")).orderBy("count(1)") 34 | 35 | val ds10 = List("Person 1", "Person 2", "User 1", "User 3", "test") 36 | .toDS() 37 | .groupByKey(i => i.substring(0, 3)) 38 | .count() 39 | .select(col("value")) 40 | 41 | val ds1 = List("Person 1", "Person 2", "User 1", "User 3", "test") 42 | .toDS() 43 | .groupByKey(i => i.substring(0, 3)) 44 | .count() 45 | .select('value) 46 | 47 | val ds2 = List("Person 1", "Person 2", "User 1", "User 3", "test") 48 | .toDS() 49 | .groupByKey(i => i.substring(0, 3)) 50 | .count() 51 | .withColumn("value", upper(col("value"))) 52 | 53 | val ds3 = 54 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 55 | .toDS() 56 | .groupByKey(l => l.substring(0, 3)) 57 | .count() 58 | .withColumnRenamed("value", "newName") 59 | 60 | val ds5 = 61 | List("Paerson 1", "Person 2", "User 1", "User 2", "test", "gggg") 62 | .toDS() 63 | .groupByKey(l => l.substring(0, 3)) 64 | .count() 65 | .withColumn("newNameCol", upper(col("value"))) 66 | 67 | val ds00 = List("Person 1", "Person 2", "User 1", "User 3", "test") 68 | .toDS() 69 | .select(col("value")) 70 | 71 | val c = List("Person 1", "Person 2", "User 1", "User 3", "test") 72 | .toDS() 73 | .withColumnRenamed("value", "newColName") 74 | .count() 75 | 76 | val s = "value" 77 | val value = 1 78 | } 79 | def inSource2(spark: SparkSession): Unit = { 80 | import spark.implicits._ 81 | 82 | val source = Seq( 83 | City("USA", "Seatle"), 84 | City("Canada", "Toronto"), 85 | City("Ukraine", "Kyev"), 86 | City("Ukraine", "Ternopil"), 87 | City("Canada", "Vancouver"), 88 | City("Germany", "Köln") 89 | ) 90 | 91 | val df = source.toDF().groupBy(col("countryName")).count 92 | val ds = source.toDS().groupBy(col("countryName")).count 93 | val res = source.toDF().as[City].groupBy(col("countryName")).count 94 | val res1 = res 95 | .select(col("countryName").alias("value")) 96 | .as[String] 97 | .groupByKey(l => l.substring(0, 3)) 98 | .count() 99 | val res2 = res 100 | .select(col("countryName").alias("newValue")) 101 | .as[String] 102 | .groupByKey(l => l.substring(0, 3)) 103 | .count() 104 | .select('value) 105 | } 106 | def inSource3(spark: SparkSession): Unit = { 107 | import spark.implicits._ 108 | 109 | val source = Seq( 110 | City("USA", "Seatle"), 111 | City("Canada", "Toronto"), 112 | City("Ukraine", "Kyev"), 113 | City("Ukraine", "Ternopil"), 114 | City("Canada", "Vancouver"), 115 | City("Germany", "Köln") 116 | ) 117 | val res = source.toDF().as[City].groupBy(col("countryName")).count 118 | val res1 = res 119 | .select(col("countryName").alias("value")) 120 | .as[String] 121 | .groupByKey(l => l.substring(0, 3)) 122 | .count() 123 | 124 | val ds = List("Person 1", "Person 2", "User 1", "User 2").toDS() 125 | 126 | val res2 = res1.union(ds.groupByKey(l => l.substring(0, 
3)).count) 127 | 128 | val r = res2.select('value, col("count(1)")) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /pysparkler/pysparkler/pyspark_23_to_24.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | import libcst as cst 19 | import libcst.matchers as m 20 | 21 | from pysparkler.base import StatementLineCommentWriter 22 | 23 | 24 | class ToPandasAllowsFallbackOnArrowOptimization(StatementLineCommentWriter): 25 | """In PySpark 2.4, when Arrow optimization is enabled, previously toPandas just failed when Arrow optimization is 26 | unable to be used whereas createDataFrame from Pandas DataFrame allowed the fallback to non-optimization. Now, both 27 | toPandas and createDataFrame from Pandas DataFrame allow the fallback by default, which can be switched off by 28 | spark.sql.execution.arrow.fallback.enabled. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | pyspark_version: str = "2.4", 34 | ): 35 | super().__init__( 36 | transformer_id="PY23-24-001", 37 | comment=f"As of PySpark {pyspark_version} toPandas() allows fallback to non-optimization by default when \ 38 | Arrow optimization is unable to be used. This can be switched off by spark.sql.execution.arrow.fallback.enabled", 39 | ) 40 | 41 | def visit_Call(self, node: cst.Call) -> None: 42 | """Check if toPandas is being called""" 43 | if m.matches(node, m.Call(func=m.Attribute(attr=m.Name("toPandas")))): 44 | self.match_found = True 45 | 46 | 47 | class RecommendDataFrameWriterV2ApiForV1ApiSaveAsTable(StatementLineCommentWriter): 48 | """Spark 2.4 introduced the new DataFrameWriterV2 API for writing to tables using data frames. The v2 API is 49 | recommended for several reasons: 50 | CTAS, RTAS, and overwrite by filter are supported 51 | Hidden partition expressions are supported in partitionedBy 52 | """ 53 | 54 | def __init__( 55 | self, 56 | pyspark_version: str = "2.4", 57 | ): 58 | super().__init__( 59 | transformer_id="PY23-24-002", 60 | comment=f"""As of PySpark {pyspark_version} the new DataFrameWriterV2 API is recommended for creating or \ 61 | replacing tables using data frames. To run a CTAS or RTAS, use create(), replace(), or createOrReplace() operations. \ 62 | For example: df.writeTo("prod.db.table").partitionedBy("dateint").createOrReplace(). 
Please note that the v1 DataFrame \ 63 | write API is still supported, but is not recommended.""", 64 | ) 65 | 66 | def visit_Call(self, node: cst.Call) -> None: 67 | """Check if toPandas is being called""" 68 | if m.matches(node, m.Call(func=m.Attribute(attr=m.Name("saveAsTable")))): 69 | self.match_found = True 70 | 71 | 72 | class RecommendDataFrameWriterV2ApiForV1ApiInsertInto(StatementLineCommentWriter): 73 | """Spark 2.4 introduced the new DataFrameWriterV2 API for writing to tables using data frames. The v2 API is 74 | recommended for several reasons: 75 | All operations consistently write columns to a table by name 76 | Overwrite behavior is explicit, either dynamic or by a user-supplied filter 77 | """ 78 | 79 | def __init__( 80 | self, 81 | pyspark_version: str = "2.4", 82 | ): 83 | super().__init__( 84 | transformer_id="PY23-24-003", 85 | comment=f"""As of PySpark {pyspark_version} the new DataFrameWriterV2 API is recommended for writing into \ 86 | tables in append or overwrite mode. For example, to append use df.writeTo(t).append() and to overwrite partitions \ 87 | dynamically use df.writeTo(t).overwritePartitions() Please note that the v1 DataFrame write API is still supported, \ 88 | but is not recommended.""", 89 | ) 90 | 91 | def visit_Call(self, node: cst.Call) -> None: 92 | """Check if toPandas is being called""" 93 | if m.matches(node, m.Call(func=m.Attribute(attr=m.Name("insertInto")))): 94 | self.match_found = True 95 | 96 | 97 | def pyspark_23_to_24_transformers() -> list[cst.CSTTransformer]: 98 | """Return a list of transformers for PySpark 2.3 to 2.4 migration guide""" 99 | return [ 100 | ToPandasAllowsFallbackOnArrowOptimization(), 101 | RecommendDataFrameWriterV2ApiForV1ApiSaveAsTable(), 102 | RecommendDataFrameWriterV2ApiForV1ApiInsertInto(), 103 | ] 104 | -------------------------------------------------------------------------------- /pysparkler/tests/test_pyspark_32_to_33.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
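# --- Illustrative sketch, not part of the repository ---
# The tests below exercise the PySpark 3.2 -> 3.3 behaviour change where
# pyspark.pandas.DataFrame.drop() defaults to dropping by index instead of by
# column. Spelling the intent out explicitly is unambiguous on both versions;
# this sketch assumes a Spark runtime with pandas-on-Spark available:
import numpy as np
import pyspark.pandas as ps

df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=["A", "B", "C", "D"])

# Explicit column drops: same result on PySpark 3.2 and 3.3.
by_axis = df.drop(["B", "C"], axis=1)
by_keyword = df.drop(columns=["B", "C"])

# The form the transformer flags relies on the version-dependent default:
# ambiguous = df.drop(["B", "C"])
# --- end sketch ---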
17 | # 18 | 19 | from pysparkler.pyspark_32_to_33 import ( 20 | DataframeDropAxisIndexByDefault, 21 | RequiredPandasVersionCommentWriter, 22 | SQLDataTypesReprReturnsObjectCommentWriter, 23 | ) 24 | from tests.conftest import rewrite 25 | 26 | 27 | def test_adds_code_hint_to_drop_by_column_behavior_when_axis_not_specified_without_labels_keyword(): 28 | given_code = """ 29 | import pyspark.pandas as ps 30 | 31 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 32 | df.drop(['B', 'C']) 33 | display(df) 34 | """ 35 | modified_code = rewrite(given_code, DataframeDropAxisIndexByDefault()) 36 | expected_code = """ 37 | import pyspark.pandas as ps 38 | 39 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 40 | df.drop(['B', 'C']) # PY32-33-001: As of PySpark 3.3 the drop method of pandas API on Spark DataFrame sets drop by index as default, instead of drop by column. Please explicitly set axis argument to 1 to drop by column. # noqa: E501 41 | display(df) 42 | """ 43 | assert modified_code == expected_code 44 | 45 | 46 | def test_adds_code_hint_to_drop_by_column_behavior_when_axis_not_specified_with_labels_keyword(): 47 | given_code = """ 48 | import pyspark.pandas as ps 49 | 50 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 51 | df.drop(labels=['B', 'C']).withColumnRenamed('A', 'B') 52 | """ 53 | modified_code = rewrite(given_code, DataframeDropAxisIndexByDefault()) 54 | expected_code = """ 55 | import pyspark.pandas as ps 56 | 57 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 58 | df.drop(labels=['B', 'C']).withColumnRenamed('A', 'B') # PY32-33-001: As of PySpark 3.3 the drop method of pandas API on Spark DataFrame sets drop by index as default, instead of drop by column. Please explicitly set axis argument to 1 to drop by column. 
# noqa: E501 59 | """ 60 | assert modified_code == expected_code 61 | 62 | 63 | def test_does_nothing_when_drop_by_column_with_axis_one_specified(): 64 | given_code = """ 65 | import pyspark.pandas as ps 66 | 67 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 68 | df.drop(['B', 'C'], axis=1) 69 | """ 70 | modified_code = rewrite(given_code, DataframeDropAxisIndexByDefault()) 71 | assert modified_code == given_code 72 | 73 | 74 | def test_does_nothing_when_drop_by_column_with_axis_zero_specified(): 75 | given_code = """ 76 | import pyspark.pandas as ps 77 | 78 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 79 | df.drop(['B', 'C'], axis=0) 80 | """ 81 | modified_code = rewrite(given_code, DataframeDropAxisIndexByDefault()) 82 | assert modified_code == given_code 83 | 84 | 85 | def test_does_nothing_when_drop_by_columns_keyword(): 86 | given_code = """ 87 | import pyspark.pandas as ps 88 | 89 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 90 | df.drop(columns=['B', 'C']) 91 | """ 92 | modified_code = rewrite(given_code, DataframeDropAxisIndexByDefault()) 93 | assert modified_code == given_code 94 | 95 | 96 | def test_does_nothing_when_drop_by_index_keyword(): 97 | given_code = """ 98 | import pyspark.pandas as ps 99 | 100 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 101 | df.drop(index=[0, 1], columns='A') 102 | """ 103 | modified_code = rewrite(given_code, DataframeDropAxisIndexByDefault()) 104 | assert modified_code == given_code 105 | 106 | 107 | def test_adds_required_pandas_version_comment_to_import_statements(): 108 | given_code = """ 109 | import pandas 110 | import pyspark 111 | """ 112 | modified_code = rewrite(given_code, RequiredPandasVersionCommentWriter()) 113 | expected_code = """ 114 | import pandas # PY32-33-002: PySpark 3.3 requires pandas version 1.0.5 or higher # noqa: E501 115 | import pyspark 116 | """ 117 | assert modified_code == expected_code 118 | 119 | 120 | def test_adds_comment_when_repr_is_called_on_sql_data_types(): 121 | given_code = """ 122 | import pyspark.pandas as ps 123 | 124 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 125 | a_column_values = list(df['A'].unique()) 126 | repr_a_column_values = [repr(value) for value in a_column_values] 127 | """ 128 | modified_code = rewrite(given_code, SQLDataTypesReprReturnsObjectCommentWriter()) 129 | expected_code = """ 130 | import pyspark.pandas as ps 131 | 132 | df = ps.DataFrame(np.arange(12).reshape(3, 4), columns=['A', 'B', 'C', 'D']) 133 | a_column_values = list(df['A'].unique()) 134 | repr_a_column_values = [repr(value) for value in a_column_values] # PY32-33-003: As of PySpark 3.3, the repr return values of SQL DataTypes have been changed to yield an object with the same value when passed to eval. # noqa: E501 135 | """ 136 | assert modified_code == expected_code 137 | -------------------------------------------------------------------------------- /pysparkler/tests/test_pyspark_23_to_24.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. 
The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | from pysparkler.pyspark_23_to_24 import ( 19 | RecommendDataFrameWriterV2ApiForV1ApiInsertInto, 20 | RecommendDataFrameWriterV2ApiForV1ApiSaveAsTable, 21 | ToPandasAllowsFallbackOnArrowOptimization, 22 | ) 23 | from tests.conftest import rewrite 24 | 25 | 26 | def test_writes_comment_when_topandas_func_is_used_without_import(): 27 | given_code = """\ 28 | import pyspark 29 | from pyspark.sql import SparkSession 30 | 31 | spark = SparkSession.builder.appName('example').getOrCreate() 32 | 33 | data = [("James","","Smith","36636","M",60000), 34 | ("Jen","Mary","Brown","","F",0)] 35 | 36 | columns = ["first_name","middle_name","last_name","dob","gender","salary"] 37 | pysparkDF = spark.createDataFrame(data = data, schema = columns) 38 | 39 | pandasDF = pysparkDF.toPandas() 40 | print(pandasDF) 41 | """ 42 | modified_code = rewrite(given_code, ToPandasAllowsFallbackOnArrowOptimization()) 43 | expected_code = """\ 44 | import pyspark 45 | from pyspark.sql import SparkSession 46 | 47 | spark = SparkSession.builder.appName('example').getOrCreate() 48 | 49 | data = [("James","","Smith","36636","M",60000), 50 | ("Jen","Mary","Brown","","F",0)] 51 | 52 | columns = ["first_name","middle_name","last_name","dob","gender","salary"] 53 | pysparkDF = spark.createDataFrame(data = data, schema = columns) 54 | 55 | pandasDF = pysparkDF.toPandas() # PY23-24-001: As of PySpark 2.4 toPandas() allows fallback to non-optimization by default when Arrow optimization is unable to be used. 
This can be switched off by spark.sql.execution.arrow.fallback.enabled # noqa: E501 56 | print(pandasDF) 57 | """ 58 | assert modified_code == expected_code 59 | 60 | 61 | def test_writes_comment_when_data_frame_writer_v1_api_save_as_table_is_detected(): 62 | given_code = """\ 63 | import pyspark 64 | from pyspark.sql import SparkSession 65 | 66 | spark = SparkSession.builder.appName('example').getOrCreate() 67 | 68 | data = [("James","","Smith","36636","M",60000), 69 | ("Jen","Mary","Brown","","F",0)] 70 | 71 | columns = ["first_name","middle_name","last_name","dob","gender","salary"] 72 | pysparkDF = spark.createDataFrame(data = data, schema = columns) 73 | 74 | pysparkDF.write.partitionBy('gender').saveAsTable("persons") 75 | """ 76 | modified_code = rewrite( 77 | given_code, RecommendDataFrameWriterV2ApiForV1ApiSaveAsTable() 78 | ) 79 | expected_code = """\ 80 | import pyspark 81 | from pyspark.sql import SparkSession 82 | 83 | spark = SparkSession.builder.appName('example').getOrCreate() 84 | 85 | data = [("James","","Smith","36636","M",60000), 86 | ("Jen","Mary","Brown","","F",0)] 87 | 88 | columns = ["first_name","middle_name","last_name","dob","gender","salary"] 89 | pysparkDF = spark.createDataFrame(data = data, schema = columns) 90 | 91 | pysparkDF.write.partitionBy('gender').saveAsTable("persons") # PY23-24-002: As of PySpark 2.4 the new DataFrameWriterV2 API is recommended for creating or replacing tables using data frames. To run a CTAS or RTAS, use create(), replace(), or createOrReplace() operations. For example: df.writeTo("prod.db.table").partitionedBy("dateint").createOrReplace(). Please note that the v1 DataFrame write API is still supported, but is not recommended. # noqa: E501 92 | """ 93 | assert modified_code == expected_code 94 | 95 | 96 | def test_writes_comment_when_data_frame_writer_v1_api_insert_into_is_detected(): 97 | given_code = """\ 98 | import pyspark 99 | from pyspark.sql import SparkSession 100 | 101 | spark = SparkSession.builder.appName('example').getOrCreate() 102 | 103 | data = [("James","","Smith","36636","M",60000), 104 | ("Jen","Mary","Brown","","F",0)] 105 | 106 | columns = ["first_name","middle_name","last_name","dob","gender","salary"] 107 | pysparkDF = spark.createDataFrame(data = data, schema = columns) 108 | 109 | pysparkDF.write.insertInto("persons", overwrite=True) 110 | """ 111 | modified_code = rewrite( 112 | given_code, RecommendDataFrameWriterV2ApiForV1ApiInsertInto() 113 | ) 114 | expected_code = """\ 115 | import pyspark 116 | from pyspark.sql import SparkSession 117 | 118 | spark = SparkSession.builder.appName('example').getOrCreate() 119 | 120 | data = [("James","","Smith","36636","M",60000), 121 | ("Jen","Mary","Brown","","F",0)] 122 | 123 | columns = ["first_name","middle_name","last_name","dob","gender","salary"] 124 | pysparkDF = spark.createDataFrame(data = data, schema = columns) 125 | 126 | pysparkDF.write.insertInto("persons", overwrite=True) # PY23-24-003: As of PySpark 2.4 the new DataFrameWriterV2 API is recommended for writing into tables in append or overwrite mode. For example, to append use df.writeTo(t).append() and to overwrite partitions dynamically use df.writeTo(t).overwritePartitions() Please note that the v1 DataFrame write API is still supported, but is not recommended. 
# noqa: E501 127 | """ 128 | assert modified_code == expected_code 129 | -------------------------------------------------------------------------------- /pysparkler/pysparkler/api.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | import json 19 | from typing import Any 20 | 21 | import libcst as cst 22 | import nbformat 23 | 24 | from pysparkler.base import BaseTransformer 25 | from pysparkler.pyspark_22_to_23 import pyspark_22_to_23_transformers 26 | from pysparkler.pyspark_23_to_24 import pyspark_23_to_24_transformers 27 | from pysparkler.pyspark_24_to_30 import pyspark_24_to_30_transformers 28 | from pysparkler.pyspark_31_to_32 import pyspark_31_to_32_transformers 29 | from pysparkler.pyspark_32_to_33 import pyspark_32_to_33_transformers 30 | from pysparkler.sql_21_to_33 import sql_21_to_33_transformers 31 | 32 | 33 | class PySparkler: 34 | """Main class for PySparkler""" 35 | 36 | def __init__( 37 | self, 38 | from_pyspark: str = "2.2", 39 | to_pyspark: str = "3.3", 40 | dry_run: bool = False, 41 | **overrides: dict[str, Any] 42 | ): 43 | self.from_pyspark = from_pyspark 44 | self.to_pyspark = to_pyspark 45 | self.dry_run = dry_run 46 | self.overrides = overrides 47 | 48 | @property 49 | def transformers(self) -> list[BaseTransformer]: 50 | """Returns a list of transformers to be applied to the AST""" 51 | all_transformers = [ 52 | *pyspark_22_to_23_transformers(), 53 | *pyspark_23_to_24_transformers(), 54 | *pyspark_24_to_30_transformers(), 55 | *pyspark_31_to_32_transformers(), 56 | *pyspark_32_to_33_transformers(), 57 | *sql_21_to_33_transformers(), 58 | ] 59 | # Override the default values of the transformers with the user provided values 60 | for transformer in all_transformers: 61 | if transformer.transformer_id in self.overrides: 62 | transformer.override(**self.overrides[transformer.transformer_id]) 63 | 64 | # Filter out disabled transformers 65 | enabled_transformers = [ 66 | transformer for transformer in all_transformers if transformer.enabled 67 | ] 68 | 69 | # Return the list of enabled transformers 70 | return enabled_transformers 71 | 72 | def upgrade_script(self, input_file: str, output_file: str | None = None) -> str: 73 | """Upgrade a PySpark Python script file to the latest version and provides comments as hints for manual 74 | changes""" 75 | # Parse the PySpark script written in version 2.4 76 | with open(input_file, encoding="utf-8") as f: 77 | original_code = f.read() 78 | original_tree = cst.parse_module(original_code) 79 | 80 | # Apply the re-writer to the AST 81 | modified_tree = self.visit(original_tree) 82 | 83 | if not self.dry_run: 84 | if output_file: 85 | # Write the 
modified AST to the output file location 86 | with open(output_file, "w", encoding="utf-8") as f: 87 | f.write(modified_tree.code) 88 | else: 89 | # Re-write the modified AST back to the input file 90 | with open(input_file, "w", encoding="utf-8") as f: 91 | f.write(modified_tree.code) 92 | 93 | # Return the modified Python Script 94 | return modified_tree.code 95 | 96 | def upgrade_notebook( 97 | self, 98 | input_file: str, 99 | output_file: str | None = None, 100 | output_kernel_name: str | None = None, 101 | ) -> str: 102 | """Upgrade a Jupyter Notebook that contains PySpark code cells to the latest version and provides comments 103 | as hints for manual changes""" 104 | 105 | # Parse the Jupyter Notebook 106 | nb = nbformat.read(input_file, as_version=nbformat.NO_CONVERT) 107 | 108 | # Apply the re-writer to the AST to each code cell 109 | for cell in nb.cells: 110 | if cell.cell_type == "code": 111 | original_code = "".join(cell.source) 112 | original_tree = cst.parse_module(original_code) 113 | modified_tree = self.visit(original_tree) 114 | cell.source = modified_tree.code.splitlines(keepends=True) 115 | 116 | # Update the kernel name if requested 117 | if output_kernel_name: 118 | nb.metadata.kernelspec.name = output_kernel_name 119 | nb.metadata.kernelspec.display_name = output_kernel_name 120 | 121 | if not self.dry_run: 122 | if output_file: 123 | # Write the modified Notebook to the output file location 124 | nbformat.write(nb, output_file) 125 | else: 126 | # Re-write the modified AST back to the input file 127 | nbformat.write(nb, input_file) 128 | 129 | # Return the modified Jupyter Notebook as String 130 | return json.dumps(nb.dict(), indent=1) 131 | 132 | def visit(self, module: cst.Module) -> cst.Module: 133 | """Visit a CSTModule and apply the transformers""" 134 | for transformer in self.transformers: 135 | module = module.visit(transformer) 136 | return module 137 | -------------------------------------------------------------------------------- /scalafix/build/sbt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so 21 | # that we can run Hive to generate the golden answer. This is not required for normal development 22 | # or testing. 
23 | if [ -n "$HIVE_HOME" ]; then 24 | for i in "$HIVE_HOME"/lib/* 25 | do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" 26 | done 27 | export HADOOP_CLASSPATH 28 | fi 29 | 30 | realpath () { 31 | ( 32 | TARGET_FILE="$1" 33 | 34 | cd "$(dirname "$TARGET_FILE")" 35 | TARGET_FILE="$(basename "$TARGET_FILE")" 36 | 37 | COUNT=0 38 | while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] 39 | do 40 | TARGET_FILE="$(readlink "$TARGET_FILE")" 41 | cd $(dirname "$TARGET_FILE") 42 | TARGET_FILE="$(basename $TARGET_FILE)" 43 | COUNT=$(($COUNT + 1)) 44 | done 45 | 46 | echo "$(pwd -P)/"$TARGET_FILE"" 47 | ) 48 | } 49 | 50 | . "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash 51 | 52 | 53 | declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" 54 | declare -r sbt_opts_file=".sbtopts" 55 | declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" 56 | declare -r default_sbt_opts="-Xss4m" 57 | 58 | usage() { 59 | cat < path to global settings/plugins directory (default: ~/.sbt) 68 | -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) 69 | -ivy path to local Ivy repository (default: ~/.ivy2) 70 | -mem set memory options (default: $sbt_default_mem, which is $(get_mem_opts $sbt_default_mem)) 71 | -no-share use all local caches; no sharing 72 | -no-global uses global caches, but does not use global ~/.sbt directory. 73 | -jvm-debug Turn on JVM debugging, open at the given port. 74 | -batch Disable interactive mode 75 | 76 | # sbt version (default: from project/build.properties if present, else latest release) 77 | -sbt-version use the specified version of sbt 78 | -sbt-jar use the specified jar as the sbt launcher 79 | -sbt-rc use an RC version of sbt 80 | -sbt-snapshot use a snapshot version of sbt 81 | 82 | # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) 83 | -java-home alternate JAVA_HOME 84 | 85 | # jvm options and output control 86 | JAVA_OPTS environment variable, if unset uses "$java_opts" 87 | SBT_OPTS environment variable, if unset uses "$default_sbt_opts" 88 | .sbtopts if this file exists in the current directory, it is 89 | prepended to the runner args 90 | /etc/sbt/sbtopts if this file exists, it is prepended to the runner args 91 | -Dkey=val pass -Dkey=val directly to the java runtime 92 | -J-X pass option -X directly to the java runtime 93 | (-J is stripped) 94 | -S-X add -X to sbt's scalacOptions (-S is stripped) 95 | -PmavenProfiles Enable a maven profile for the build. 96 | 97 | In the case of duplicated or conflicting options, the order above 98 | shows precedence: JAVA_OPTS lowest, command line options highest. 99 | EOM 100 | } 101 | 102 | process_my_args () { 103 | while [[ $# -gt 0 ]]; do 104 | case "$1" in 105 | -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; 106 | -no-share) addJava "$noshare_opts" && shift ;; 107 | -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; 108 | -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; 109 | -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; 110 | -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; 111 | -batch) exec /dev/null) 148 | if [[ ! $? ]]; then 149 | saved_stty="" 150 | fi 151 | } 152 | 153 | saveSttySettings 154 | trap onExit INT 155 | 156 | run "$@" 157 | 158 | exit_status=$? 
159 | onExit 160 | -------------------------------------------------------------------------------- /scalafix/build/sbt-launch-lib.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | 4 | # A library to simplify using the SBT launcher from other packages. 5 | # Note: This should be used by tools like giter8/conscript etc. 6 | 7 | # TODO - Should we merge the main SBT script with this library? 8 | 9 | if test -z "$HOME"; then 10 | declare -r script_dir="$(dirname "$script_path")" 11 | else 12 | declare -r script_dir="$HOME/.sbt" 13 | fi 14 | 15 | declare -a residual_args 16 | declare -a java_args 17 | declare -a scalac_args 18 | declare -a sbt_commands 19 | declare -a maven_profiles 20 | declare sbt_default_mem=4096 21 | 22 | if test -x "$JAVA_HOME/bin/java"; then 23 | echo -e "Using $JAVA_HOME as default JAVA_HOME." 24 | echo "Note, this will be overridden by -java-home if it is set." 25 | declare java_cmd="$JAVA_HOME/bin/java" 26 | else 27 | declare java_cmd=java 28 | fi 29 | 30 | echoerr () { 31 | echo 1>&2 "$@" 32 | } 33 | vlog () { 34 | [[ $verbose || $debug ]] && echoerr "$@" 35 | } 36 | dlog () { 37 | [[ $debug ]] && echoerr "$@" 38 | } 39 | 40 | acquire_sbt_jar () { 41 | SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` 42 | # DEFAULT_ARTIFACT_REPOSITORY env variable can be used to only fetch 43 | # artifacts from internal repos only. 44 | # Ex: 45 | # DEFAULT_ARTIFACT_REPOSITORY=https://artifacts.internal.com/libs-release/ 46 | URL1=${DEFAULT_ARTIFACT_REPOSITORY:-https://repo1.maven.org/maven2/}org/scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch-${SBT_VERSION}.jar 47 | JAR=build/sbt-launch-${SBT_VERSION}.jar 48 | 49 | sbt_jar=$JAR 50 | 51 | if [[ ! -f "$sbt_jar" ]]; then 52 | # Download sbt launch jar if it hasn't been downloaded yet 53 | if [ ! -f "${JAR}" ]; then 54 | # Download 55 | printf "Attempting to fetch sbt\n" 56 | JAR_DL="${JAR}.part" 57 | if [ $(command -v curl) ]; then 58 | curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ 59 | mv "${JAR_DL}" "${JAR}" 60 | elif [ $(command -v wget) ]; then 61 | wget --quiet ${URL1} -O "${JAR_DL}" &&\ 62 | mv "${JAR_DL}" "${JAR}" 63 | else 64 | printf "You do not have curl or wget installed, please install sbt manually from https://www.scala-sbt.org/\n" 65 | exit -1 66 | fi 67 | fi 68 | if [ ! -f "${JAR}" ]; then 69 | # We failed to download 70 | printf "Our attempt to download sbt locally to ${JAR} failed. 
Please install sbt manually from https://www.scala-sbt.org/\n" 71 | exit -1 72 | fi 73 | printf "Launching sbt from ${JAR}\n" 74 | fi 75 | } 76 | 77 | execRunner () { 78 | # print the arguments one to a line, quoting any containing spaces 79 | [[ $verbose || $debug ]] && echo "# Executing command line:" && { 80 | for arg; do 81 | if printf "%s\n" "$arg" | grep -q ' '; then 82 | printf "\"%s\"\n" "$arg" 83 | else 84 | printf "%s\n" "$arg" 85 | fi 86 | done 87 | echo "" 88 | } 89 | 90 | "$@" 91 | } 92 | 93 | addJava () { 94 | dlog "[addJava] arg = '$1'" 95 | java_args=( "${java_args[@]}" "$1" ) 96 | } 97 | 98 | enableProfile () { 99 | dlog "[enableProfile] arg = '$1'" 100 | maven_profiles=( "${maven_profiles[@]}" "$1" ) 101 | export SBT_MAVEN_PROFILES="${maven_profiles[@]}" 102 | } 103 | 104 | addSbt () { 105 | dlog "[addSbt] arg = '$1'" 106 | sbt_commands=( "${sbt_commands[@]}" "$1" ) 107 | } 108 | addResidual () { 109 | dlog "[residual] arg = '$1'" 110 | residual_args=( "${residual_args[@]}" "$1" ) 111 | } 112 | addDebugger () { 113 | addJava "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$1" 114 | } 115 | 116 | # a ham-fisted attempt to move some memory settings in concert 117 | # so they need not be dicked around with individually. 118 | get_mem_opts () { 119 | local mem=${1:-$sbt_default_mem} 120 | local codecache=$(( $mem / 8 )) 121 | (( $codecache > 128 )) || codecache=128 122 | (( $codecache < 2048 )) || codecache=2048 123 | 124 | echo "-Xms${mem}m -Xmx${mem}m -XX:ReservedCodeCacheSize=${codecache}m" 125 | } 126 | 127 | require_arg () { 128 | local type="$1" 129 | local opt="$2" 130 | local arg="$3" 131 | if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then 132 | echo "$opt requires <$type> argument" 1>&2 133 | exit 1 134 | fi 135 | } 136 | 137 | is_function_defined() { 138 | declare -f "$1" > /dev/null 139 | } 140 | 141 | process_args () { 142 | while [[ $# -gt 0 ]]; do 143 | case "$1" in 144 | -h|-help) usage; exit 1 ;; 145 | -v|-verbose) verbose=1 && shift ;; 146 | -d|-debug) debug=1 && shift ;; 147 | 148 | -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; 149 | -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;; 150 | -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;; 151 | -batch) exec None: 31 | 32 | if isinstance(truncate, bool) and truncate: 33 | print(self._jdf.showString(n, 20, vertical)) 34 | else: 35 | try: 36 | int_truncate = int(truncate) 37 | except ValueError: 38 | raise TypeError( 39 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 40 | ) 41 | 42 | print(self._jdf.showString(n, int_truncate, vertical)) 43 | """ 44 | modified_code = rewrite(given_code, SqlMlMethodsRaiseTypeErrorCommentWriter()) 45 | expected_code = """ 46 | from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column 47 | from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 48 | from pyspark.sql.streaming import DataStreamWriter 49 | 50 | 51 | def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: 52 | 53 | if isinstance(truncate, bool) and truncate: 54 | print(self._jdf.showString(n, 20, vertical)) 55 | else: 56 | try: 57 | int_truncate = int(truncate) 58 | except ValueError: 59 | raise TypeError( 60 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 61 | ) # PY31-32-001: As of PySpark 3.2, the methods from sql, ml, spark_on_pandas modules raise the TypeError instead of ValueError when are applied to a 
param of inappropriate type. # noqa: E501 62 | 63 | print(self._jdf.showString(n, int_truncate, vertical)) 64 | """ 65 | assert modified_code == expected_code 66 | 67 | 68 | def test_adds_may_raise_type_error_with_alias_when_catching_value_errors_on_sql_or_ml_from_imports(): 69 | given_code = """ 70 | from pyspark.ml.streaming import DataStreamWriter 71 | 72 | def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: 73 | 74 | if isinstance(truncate, bool) and truncate: 75 | print(self._jdf.showString(n, 20, vertical)) 76 | else: 77 | try: 78 | int_truncate = int(truncate) 79 | except ValueError as ex: 80 | raise TypeError( 81 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 82 | ) 83 | 84 | print(self._jdf.showString(n, int_truncate, vertical)) 85 | """ 86 | modified_code = rewrite(given_code, SqlMlMethodsRaiseTypeErrorCommentWriter()) 87 | expected_code = """ 88 | from pyspark.ml.streaming import DataStreamWriter 89 | 90 | def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: 91 | 92 | if isinstance(truncate, bool) and truncate: 93 | print(self._jdf.showString(n, 20, vertical)) 94 | else: 95 | try: 96 | int_truncate = int(truncate) 97 | except ValueError as ex: 98 | raise TypeError( 99 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 100 | ) # PY31-32-001: As of PySpark 3.2, the methods from sql, ml, spark_on_pandas modules raise the TypeError instead of ValueError when are applied to a param of inappropriate type. # noqa: E501 101 | 102 | print(self._jdf.showString(n, int_truncate, vertical)) 103 | """ 104 | assert modified_code == expected_code 105 | 106 | 107 | def test_adds_may_raise_type_error_with_alias_when_catching_value_errors_on_sql_or_ml_import(): 108 | given_code = """ 109 | import pyspark.sql.functions as f 110 | 111 | def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: 112 | 113 | if isinstance(truncate, bool) and truncate: 114 | print(self._jdf.showString(n, 20, vertical)) 115 | else: 116 | try: 117 | int_truncate = int(truncate) 118 | except ValueError as ex: 119 | raise TypeError( 120 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 121 | ) 122 | 123 | print(self._jdf.showString(n, int_truncate, vertical)) 124 | """ 125 | modified_code = rewrite(given_code, SqlMlMethodsRaiseTypeErrorCommentWriter()) 126 | expected_code = """ 127 | import pyspark.sql.functions as f 128 | 129 | def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: 130 | 131 | if isinstance(truncate, bool) and truncate: 132 | print(self._jdf.showString(n, 20, vertical)) 133 | else: 134 | try: 135 | int_truncate = int(truncate) 136 | except ValueError as ex: 137 | raise TypeError( 138 | "Parameter 'truncate={}' should be either bool or int.".format(truncate) 139 | ) # PY31-32-001: As of PySpark 3.2, the methods from sql, ml, spark_on_pandas modules raise the TypeError instead of ValueError when are applied to a param of inappropriate type. # noqa: E501 140 | 141 | print(self._jdf.showString(n, int_truncate, vertical)) 142 | """ 143 | assert modified_code == expected_code 144 | --------------------------------------------------------------------------------
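# --- Illustrative usage sketch, not part of the repository ---
# Tying the pieces above together: the PySparkler class in pysparkler/api.py runs
# every enabled transformer over a script or notebook. A minimal driver, where
# "job.py", "job_nb.ipynb" and the "spark33" kernel name are hypothetical:
from pysparkler.api import PySparkler

# dry_run=True returns the rewritten code without touching the input files.
pysparkler = PySparkler(dry_run=True)

# Upgrade a plain PySpark script; the rewritten module code is returned as a string.
print(pysparkler.upgrade_script("job.py"))

# Upgrade each code cell of a Jupyter notebook and retarget its kernel metadata.
print(pysparkler.upgrade_notebook("job_nb.ipynb", output_kernel_name="spark33"))
# --- end sketch ---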