├── Dockerfile
├── README.md
├── docker-compose.yaml
├── main.py
└── test
    ├── conftest.py
    └── test_pipeline.py

/Dockerfile:
--------------------------------------------------------------------------------
FROM ubuntu:18.04

RUN apt-get update && \
    apt-get install -y default-jdk scala wget vim software-properties-common python3.8 python3-pip curl unzip libpq-dev build-essential libssl-dev libffi-dev python3-dev && \
    apt-get clean

RUN wget https://archive.apache.org/dist/spark/spark-3.0.1/spark-3.0.1-bin-hadoop3.2.tgz && \
    tar xvf spark-3.0.1-bin-hadoop3.2.tgz && \
    mv spark-3.0.1-bin-hadoop3.2/ /usr/local/spark && \
    ln -s /usr/local/spark spark

WORKDIR /app
COPY . /app
RUN pip3 install cython==0.29.21 numpy==1.18.5 && pip3 install pytest pyspark pandas==1.0.5
ENV PYSPARK_PYTHON=python3
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## How to unit test PySpark code ##

Read the full blog post here: https://www.confessionsofadataguy.com/introduction-to-unit-testing-with-pyspark/

Sample code and instructions for writing simple unit tests, and for setting up Docker to run unit tests for PySpark code.
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
version: "3.9"
services:
  test:
    environment:
      - PYTHONPATH=.
    image: "spark-test"
    volumes:
      - .:/app
    command: python3 -m pytest
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import pyspark.sql.functions as F
from pyspark.sql import DataFrame


def sample_transform(input_df: DataFrame) -> DataFrame:
    # Keep only the 'hobbit' rows, then sum 'yet_another' per 'another_column'.
    inter_df = (
        input_df.where(input_df['that_column'] == F.lit('hobbit'))
        .groupBy('another_column')
        .agg(F.sum('yet_another').alias('new_column'))
    )
    # Flag groups whose sum exceeds 10, and keep only the flagged rows.
    output_df = inter_df.select(
        'another_column',
        'new_column',
        F.when(F.col('new_column') > 10, 'yes').otherwise('no').alias('indicator'),
    ).where(F.col('indicator') == F.lit('yes'))
    return output_df
--------------------------------------------------------------------------------
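A quick way to eyeball `sample_transform` before wiring up the tests — a minimal sketch, assuming a local SparkSession and the column names from main.py; the row values here are purely illustrative:

```python
from pyspark.sql import SparkSession

from main import sample_transform

# Minimal sketch: a throwaway local session and a tiny illustrative DataFrame.
spark = SparkSession.builder.master("local[*]").appName("demo").getOrCreate()

df = spark.createDataFrame(
    [
        ('hobbit', 'Frodo', 9),     # group sums to 9  -> indicator 'no', filtered out
        ('hobbit', 'Merry', 12),    # group sums to 12 -> indicator 'yes', kept
        ('elf', 'Legolas', 100),    # not a hobbit     -> dropped before grouping
    ],
    ['that_column', 'another_column', 'yet_another'],
)

sample_transform(df).show()
# Expect a single row: another_column='Merry', new_column=12, indicator='yes'
```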
/test/conftest.py:
--------------------------------------------------------------------------------
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    # One local SparkSession shared by every test in the session.
    spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate()
    return spark
--------------------------------------------------------------------------------
/test/test_pipeline.py:
--------------------------------------------------------------------------------
import pytest
from main import sample_transform


@pytest.mark.usefixtures("spark_session")
def test_sample_transform(spark_session):
    test_df = spark_session.createDataFrame(
        [
            ('hobbit', 'Samwise', 5),
            ('hobbit', 'Bilbo', 50),
            ('hobbit', 'Bilbo', 20),
            ('wizard', 'Gandalf', 1000)
        ],
        ['that_column', 'another_column', 'yet_another']
    )
    new_df = sample_transform(test_df)
    # Only the Bilbo group (50 + 20 = 70) clears the > 10 threshold.
    assert new_df.count() == 1
    assert new_df.toPandas().to_dict('list')['new_column'][0] == 70
--------------------------------------------------------------------------------
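To run the suite, build the image the compose file expects (`docker build -t spark-test .`), then `docker-compose run test`; Compose mounts the project into /app and invokes `python3 -m pytest`.

A negative-path test is also worth having. Here is a hedged sketch that could sit alongside `test_sample_transform` in test/test_pipeline.py, reusing its imports and the `spark_session` fixture; the test name is illustrative:

```python
def test_sample_transform_filters_low_sums(spark_session):
    # Every hobbit group sums to 10 or less, so the 'new_column' > 10
    # indicator filter should leave nothing behind.
    test_df = spark_session.createDataFrame(
        [
            ('hobbit', 'Samwise', 5),
            ('hobbit', 'Pippin', 10),
        ],
        ['that_column', 'another_column', 'yet_another']
    )
    assert sample_transform(test_df).count() == 0
```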